from pathlib import Path
###########
## SETUP ##
############

configfile: "config/config.yaml"

# Host-side path to the shared reference data directory, read from
# config["vhmnet"]["data_dir"].
#
# The containerized rules (run_vhmnet, run_wish) cannot use this host path
# directly. They expect the directory to be bind-mounted into the container
# at /data, e.g. with
#     snakemake --singularity-args "-B path/to/data_dir:/data" ...
# DATA_DIR is still required on the host, because process_wish (which has
# no container) reads the host taxonomy table tables/hostTaxa.pkl from it.
_vhmnet_cfg = config.get("vhmnet") or {}
if not _vhmnet_cfg.get("data_dir"):
    # Fail early with an actionable message instead of an AttributeError
    # (missing "vhmnet" section) or TypeError (Path(None)).
    raise ValueError(
        "config is missing 'vhmnet: data_dir:' (path to the reference data "
        "directory); set it in config/config.yaml or via --config"
    )
DATA_DIR = Path(_vhmnet_cfg["data_dir"])

def parse_samplesheet(samples_tsv):
    """Read a tab-separated samplesheet into a ``{sample: fasta_path}`` dict.

    The first line must be exactly the header ``sample<TAB>fasta``
    (surrounding whitespace on each field is ignored). Each following line
    maps a sample name (column 1) to its input FASTA path (column 2); extra
    columns are ignored. Lines starting with ``#`` and blank or
    whitespace-only lines are skipped.

    Parameters
    ----------
    samples_tsv : str or path-like
        Path to the samplesheet.

    Returns
    -------
    dict
        Sample name -> FASTA path, in file order. If a sample name appears
        more than once, its last path wins.

    Raises
    ------
    ValueError
        If the header is not ``sample``/``fasta``, or a data line has fewer
        than two tab-separated columns.
    """
    samples = {}
    with open(samples_tsv, 'r') as fin:
        header_fields = [f.strip() for f in fin.readline().split('\t')]
        if header_fields != ['sample', 'fasta']:
            raise ValueError(
                "Malformatted samplesheet {}: expected header "
                "'sample<TAB>fasta', got {!r}".format(samples_tsv, header_fields)
            )
        # Line numbers start at 2 because the header was line 1.
        for lineno, line in enumerate(fin, start=2):
            # Skip comments and blank lines (e.g. a trailing empty line),
            # which would otherwise fail the two-column check below.
            if line.startswith('#') or not line.strip():
                continue
            fields = [f.strip() for f in line.split('\t')]
            if len(fields) < 2:
                raise ValueError(
                    "Malformatted samplesheet {}: line {} needs at least two "
                    "tab-separated columns, got {!r}".format(
                        samples_tsv, lineno, line.rstrip('\n'))
                )
            samples[fields[0]] = fields[1]
    return samples

# Sample name -> input FASTA, from the samplesheet given in the config
# (defaults to samples.tsv in the working directory).
samples_dic = parse_samplesheet(config.get('samplesheet', 'samples.tsv'))

# Sample names in samplesheet order; expanded as {sample} in rule all.
SAMPLES = list(samples_dic)

# Host-prediction tools; each produces results/{sample}/<tool>/predictions.tsv.
TOOLS = ["vhulk", "rafah", "vhmnet", "wish", "htp"]

def get_sample_fasta(wc):
    """Input function: the samplesheet FASTA path for ``wc.sample``.

    Raises KeyError if the sample is not in the samplesheet.
    """
    sample_name = wc.sample
    return samples_dic[sample_name]

def collect_prediction_tsvs(wc):
    """Input function: one ``predictions.tsv`` path per tool in TOOLS.

    Paths follow ``results/<sample>/<tool>/predictions.tsv`` and are
    returned in TOOLS order.
    """
    return [
        "results/{}/{}/predictions.tsv".format(wc.sample, tool_name)
        for tool_name in TOOLS
    ]


#############
## TARGETS ##
#############

rule all:
    input:
        # Per-sample targets (independent of the tool list).
        # "results/{sample}/tmp/filtered.fa.gz" is disabled together with
        # the commented-out size_filter rule.
        expand(
            [
                "results/{sample}/tmp/reflist.txt",
                "results/{sample}/all_predictions.tsv",
                "results/{sample}/lca.tsv",
            ],
            sample=SAMPLES,
        ),
        # One predictions table per sample and tool.
        expand(
            "results/{sample}/{tool}/predictions.tsv",
            sample=SAMPLES,
            tool=TOOLS,
        )

###########
## RULES ##
###########

#rule size_filter:
#    input:
#        multifasta_fp = get_sample_fasta
#    output:
#        filtered_fasta = "results/{sample}/tmp/filtered.fa.gz"
#    threads: 4
#    log:
#        "logs/{sample}/size_filter.log"
#    params:
#        min_size = 5000
#    shell:
#        "seqkit seq -g -j {threads} -m {params.min_size} "
#        "{input.multifasta_fp} | gzip -c >{output.filtered_fasta} 2>{log}" 

rule split_multifasta:
    # Split the sample's multi-FASTA into per-genome files under genomes_dir
    # and write the reflist; downstream rules depend on the reflist and read
    # the FASTA files from genomes_dir.
    input:
        fasta = get_sample_fasta
    output:
        reflist = "results/{sample}/tmp/reflist.txt"
    params:
        genomes_dir = "results/{sample}/tmp/genomes",
        scrpt = srcdir("scripts/split_multifasta.py")
    log:
        "logs/{sample}/split_multifasta.log"
    shell:
        """
        mkdir -p {params.genomes_dir} && \
        python {params.scrpt} \
            -i {input.fasta} \
            -o {params.genomes_dir} \
            --write-reflist &>{log}
        """

# vHULK
rule run_vhulk:
    # Runs vHULK on the per-genome FASTA directory written by
    # split_multifasta; the reflist input only orders this rule after it.
    input:
        reflist = rules.split_multifasta.output.reflist
    output:
        results_csv = "results/{sample}/vhulk/results.csv"
    log:
        "logs/{sample}/vhulk.log"
    threads: 8
    container:
        "library://papanikos_182/default/vhulk:1.0.0"
    params:
        fasta_dir = "results/{sample}/tmp/genomes",
        output_dir = "results/{sample}/vhulk"
    shell:
        """
        vHULK.py \
            -i {params.fasta_dir} \
            -o {params.output_dir} \
            -t {threads} \
            --all \
            &>{log}
        """

rule process_vhulk:
    # Drop the CSV header, keep columns 1, 10 and 11, convert to tab-separated
    # and sort by the first column. Only sort's stderr reaches the log.
    input:
        vhulk_csv = rules.run_vhulk.output.results_csv
    output:
        vhulk_tsv = "results/{sample}/vhulk/predictions.tsv"
    log:
        "logs/{sample}/process_vhulk.log"
    shell:
        """
        tail -n+2 {input.vhulk_csv} \
            | cut -d ',' -f 1,10,11 \
            | tr ',' '\t' \
            | sort -k1 1>{output.vhulk_tsv} 2>{log}
        """

# RAFAH
rule run_rafah:
    # Runs RaFAH over the *.fasta files in the per-genome directory from
    # split_multifasta. The Seq_Info output presumably gets its name from
    # --file_prefix (prefix + "_Seq_Info.tsv") -- confirm against RaFAH docs.
    input:
        reflist = rules.split_multifasta.output.reflist,
    output:
        seq_info = "results/{sample}/rafah/{sample}_Seq_Info.tsv"
    log:
        "logs/{sample}/rafah.log"
    threads: 8
    container:
        "library://papanikos_182/default/rafah:0.1"
    params:
        fasta_dir = "results/{sample}/tmp/genomes",
        prefix = "results/{sample}/rafah/{sample}"
    shell:
        """
        RaFAH_v0.1.pl \
            --genomes_dir {params.fasta_dir}/ \
            --extension fasta \
            --threads {threads} \
            --file_prefix {params.prefix} \
            &>{log}
        """

rule process_rafah:
    # Drop the Seq_Info header, keep columns 1, 6 and 7, and sort by the
    # first column. The pipeline runs in a subshell so stderr from every
    # stage goes to the declared log; previously the log was declared but
    # never written.
    input:
        seq_info = rules.run_rafah.output.seq_info
    output:
        rafah_tsv = "results/{sample}/rafah/predictions.tsv"
    log:
        "logs/{sample}/process_rafah.log"
    shell:
        "(tail -n+2 {input.seq_info} | cut -f1,6,7 | sort -k1) "
        "> {output.rafah_tsv} 2>{log}"


# VirHostMatcher-Net
rule run_vhmnet:
    # Runs VirHostMatcher-Net on the per-genome FASTA directory. Completion is
    # marked with a touch()ed .done.txt; process_vhmnet then reads per-contig
    # CSVs from <output_dir>/predictions.
    input:
        reflist = rules.split_multifasta.output.reflist,
    output:
        done = touch("results/{sample}/vhmnet/.done.txt")
    params:
        # Container-side path; bind-mount the host data directory onto it:
        #   snakemake --singularity-args "-B /path/to/data/:/data" ...
        data_dir = "/data",
        tmp_dir = "results/{sample}/vhmnet/tmp",
        output_dir = "results/{sample}/vhmnet",
        fasta_dir = "results/{sample}/tmp/genomes"
    threads: 12
    container:
        "library://papanikos_182/default/vhmnet:0.1"
    log:
        "logs/{sample}/vhmnet.log"
    # -q (query FASTA directory) is passed exactly once; it was duplicated.
    shell:
        "VirHostMatcher-Net.py -q {params.fasta_dir} "
        "-t {threads} "
        "--short-contig "
        "-i {params.tmp_dir} "
        "-d {params.data_dir} "
        "-o {params.output_dir} "
        "&>{log}"

rule process_vhmnet:
    input:
        rules.run_vhmnet.output.done
    output:
        vhmnet_tsv = "results/{sample}/vhmnet/predictions.tsv"
    params:
        predictions_dir = "./results/{sample}/vhmnet/predictions"
    # Builds one row per *.csv found under predictions_dir:
    #   contig id (file name with "_prediction.csv" removed), then fields 8
    #   and 10 of the CSV's last line, tab-separated; rows sorted on column 1.
    # - find starts at predictions_dir instead of walking the whole working
    #   directory with a path-less `find -wholename` (also GNU-only).
    # - The scratch file is reset before the loop appends to it, so rows left
    #   by an interrupted earlier run are not duplicated. As a consequence,
    #   no CSVs yields an empty table instead of a failing sort.
    shell:
        """
        tmp_tsv={output.vhmnet_tsv}.tmp
        : > "$tmp_tsv"
        for f in $(find {params.predictions_dir} -type f -name "*.csv");
        do
            contig_id=$(basename "$f" | sed -e 's/_prediction.csv//')
            host_score=$(tail -n1 "$f" | cut -f8,10 -d',' | tr ',' '\t')
            echo -e "$contig_id\t$host_score" >> "$tmp_tsv"
        done
        sort -k1 "$tmp_tsv" > {output.vhmnet_tsv}
        rm -f "$tmp_tsv"
        """
# WIsH
rule run_wish:
    # Runs WIsH in predict mode over the per-genome FASTA directory, writing
    # prediction.list and llikelihood.matrix into output_dir.
    input:
        reflist = rules.split_multifasta.output.reflist,
    output:
        prediction_list = "results/{sample}/wish/prediction.list",
        ll_mat = "results/{sample}/wish/llikelihood.matrix"
    params:
        fasta_dir = "results/{sample}/tmp/genomes",
        output_dir = "results/{sample}/wish",
        # Container-side path: the whole data dir must be bind-mounted to
        # /data via --singularity-args (same setup as run_vhmnet).
        models_dir = "/data/host_wish_model"
    threads: 8
    log:
        "logs/{sample}/wish.log"
    container:
        "library://papanikos_182/default/wish:1.0"
    shell:
        """
        mkdir -p {params.output_dir} && \
        WIsH -c predict \
            -g {params.fasta_dir} \
            -t {threads} \
            -b \
            -m {params.models_dir} \
            -r {params.output_dir} \
            &>{log}
        """

rule process_wish:
    # Adds host taxonomy to WIsH predictions with wish_add_taxonomy.py. Runs
    # on the host (no container), so it reads hostTaxa.pkl via DATA_DIR.
    input:
        prediction_list = rules.run_wish.output.prediction_list
    output:
        predictions_tsv = "results/{sample}/wish/predictions.tsv"
    log:
        "logs/{sample}/process_wish.log"
    params:
        scrpt = srcdir("scripts/wish_add_taxonomy.py"),
        hostTaxa_pkl = DATA_DIR / "tables" / "hostTaxa.pkl"
    shell:
        """
        python {params.scrpt} \
            -i {input.prediction_list} \
            -t {params.hostTaxa_pkl} \
            -o {output.predictions_tsv} 2>{log}
        """

# HTP
rule run_htp:
    input:
        reflist = rules.split_multifasta.output.reflist
    output:
        htp_raw = "results/{sample}/htp/raw.txt"
    log:
        stderr = "logs/{sample}/htp.stderr"
    params:
        fasta_dir = "./results/{sample}/tmp/genomes"
    container:
        "library://papanikos_182/default/htp:1.0.2"
    # Runs viruses_classifier once per *.fasta in fasta_dir. It writes a
    # header line and then one line per contig to raw.txt.
    # - The log is truncated once up front and each classifier call appends
    #   (2>>). A plain 2> inside the loop would keep only the last contig's
    #   stderr.
    # - The echo arguments are intentionally unquoted: word splitting turns
    #   the separator between contig id and classifier output into a single
    #   space, and process_htp's sed pattern (" {'phage': ") relies on it.
    shell:
        """
        : > {log.stderr}
        printf "contig\traw_pred\n" > {output.htp_raw}
        for f in $(find {params.fasta_dir} -wholename "*.fasta");
        do
            contig_id=$(basename ${{f}} | sed -e 's/\.fasta//' )
            pred=$(viruses_classifier \
                    --classifier svc \
                    --nucleic_acid dna \
                    -p ${{f}} 2>>{log.stderr})
            echo -e $contig_id\t$pred >> {output.htp_raw}
        done
        """

rule process_htp:
    input:
        htp_raw = rules.run_htp.output.htp_raw
    output:
        predictions_tsv = "results/{sample}/htp/predictions.tsv"
    params:
        # Literal "{" for the sed pattern; writing it inline in the shell
        # string would be parsed as a Snakemake placeholder.
        ob = "{"
    # Skip raw.txt's header line and keep everything before the first comma.
    # Then replace " {'phage': " with a tab, which leaves "contig<TAB>score".
    # Rows are sorted on column 1. The input is read via {input.htp_raw}, not
    # a hard-coded copy of the path, so it stays in sync with run_htp.
    shell:
        """
        tail -n +2 {input.htp_raw} | cut -f1 -d','| \
                sed -r "s/ \{params.ob}'phage'\: /\t/" | sort -k1 \
                >{output.predictions_tsv}
        """


# Aggregate
rule collect_hosts:
    # Merges the per-tool predictions.tsv files of one sample (paths from
    # collect_prediction_tsvs, passed space-separated to -i) with
    # cat_predictions.py.
    input:
        collect_prediction_tsvs
    output:
        sample_tsv = "results/{sample}/all_predictions.tsv"
    log:
        "logs/{sample}/cat_predictions.log"
    params:
        scrpt = srcdir("scripts/cat_predictions.py")
    shell:
        """
        python {params.scrpt} \
            -i {input} \
            -o {output.sample_tsv} \
            2>{log}
        """


rule lca:
    # Runs get_lca.py on the sample's combined prediction table and writes
    # the result to lca.tsv.
    input:
        predictions_tsv = rules.collect_hosts.output.sample_tsv
    output:
        lca_tsv = "results/{sample}/lca.tsv"
    log:
        "logs/{sample}/get_lca.log"
    params:
        scrpt = srcdir("scripts/get_lca.py")
    shell:
        """
        python {params.scrpt} \
            -i {input.predictions_tsv} \
            -o {output.lca_tsv} 2>{log}
        """