from pathlib import Path
###########
## SETUP ##
###########

configfile: "config/config.yaml"

# The DATA_DIR path below is not baked into the sif containers;
# bind-mount the data directory on the containers with
# --singularity-args "-B path/to/data_dir:/data"
# DATA_DIR itself is still needed to add taxonomy to the WIsH predictions.
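# e.g. (exact flag names depend on your snakemake version):
#   snakemake --use-singularity --singularity-args "-B /path/to/data_dir:/data" --cores 16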
DATA_DIR = Path(config.get("vhmnet").get("data_dir"))

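# Expected samplesheet layout: a tab-separated file with a 'sample' and a 'fasta'
# header, followed by one line per sample giving its name and the path to its
# (multi)fasta of contigs.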
def parse_samplesheet(samples_tsv):
	samples_dic = {}
	with open(samples_tsv, 'r') as fin:
		header_line = fin.readline()
		header_fields = [f.strip() for f in header_line.split('\t')]
		assert header_fields == ['sample', 'fasta'], "Malformatted samplesheet"
		for line in fin:
			# Skip empty lines, e.g. a trailing newline at the end of the file
			if not line.strip():
				continue
			fields = [f.strip() for f in line.split('\t')]
			samples_dic[fields[0]] = fields[1]
	return samples_dic

samples_dic = parse_samplesheet(config.get('samplesheet', 'samples.tsv'))

SAMPLES = list(samples_dic.keys())
TOOLS = [
		"vhulk",
		"rafah",
		"vhmnet",
		"wish",
		"htp"
		]

def get_sample_fasta(wc):
	return samples_dic[wc.sample]

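# One predictions.tsv per tool for the given sample, in TOOLS order
# (vhulk, rafah, vhmnet, wish, htp); rule collect_hosts indexes into this list.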
def collect_prediction_tsvs(wc):
	tsvs = []
	for tool in TOOLS:
		tool_tsv = "results/{}/{}/predictions.tsv".format(wc.sample, tool)
		tsvs.append(tool_tsv)
	return tsvs


#############
## TARGETS ##
#############

rule all:
	input:
		expand(
			[
				"results/{sample}/tmp/filtered.fa.gz",
				"results/{sample}/tmp/reflist.txt",
				"results/{sample}/all_predictions.tsv",
				"results/{sample}/lca.tsv"
			],
			sample=SAMPLES),
		expand(
			"results/{sample}/{tool}/predictions.tsv",
			sample=SAMPLES, tool=TOOLS)

###########
## RULES ##
###########

rule size_filter:
	input:
		multifasta_fp = get_sample_fasta
	output:
		filtered_fasta = "results/{sample}/tmp/filtered.fa.gz"
	threads: 4
	log:
		"logs/{sample}/size_filter.log"
	params:
		min_size = 5000
	shell:
		"(seqkit seq -g -j {threads} -m {params.min_size} "
		"{input.multifasta_fp} | gzip -c >{output.filtered_fasta}) 2>{log}"

rule split_multifasta:
	input:
		multifasta_fp = rules.size_filter.output.filtered_fasta
	output:
		reflist = "results/{sample}/tmp/reflist.txt"
	log: "logs/{sample}/split_multifasta.log"
	params:
		genomes_dir = "results/{sample}/tmp/genomes",
	shell:
		"mkdir -p {params.genomes_dir} && "
		"python workflow/scripts/split_multifasta.py "
		"-i {input.multifasta_fp} "
		"-o {params.genomes_dir} "
		"--write-reflist &>{log}"

# vHULK
rule run_vhulk:
	input:
		reflist = rules.split_multifasta.output.reflist
	output:
		done_txt = touch("results/{sample}/vhulk/.done.txt"),
		results_csv = "results/{sample}/vhulk/results/results.csv"
	params:
		fasta_dir = "results/{sample}/tmp/genomes",
		output_dir = "results/{sample}/vhulk"
	log: 
		"logs/{sample}/vhulk.log"
	container:
		"library://papanikos_182/default/vhulk:0.1"
	threads: 8
	shell:
		"vHULK-v0.1.py -i {params.fasta_dir} "
		"-t {threads} &>{log} && "
		"mv {params.fasta_dir}/results {params.output_dir}"

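# Keep the contig id plus the vHULK prediction and score columns (1, 10, 11),
# sorted by contig id so the per-tool tables can be pasted together later.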
rule filter_vhulk:
	input:
		rules.run_vhulk.output.done_txt,
		vhulk_csv = rules.run_vhulk.output.results_csv
	output:
		vhulk_tsv = "results/{sample}/vhulk/predictions.tsv"
	shell:
		"tail -n+2 {input.vhulk_csv} | cut -d ',' -f 1,10,11 "
		"| tr ',' '\t' | sort -k1 > {output.vhulk_tsv}"

# RAFAH
rule run_rafah:
	input:
		reflist = rules.split_multifasta.output.reflist,
	output:
		seq_info = "results/{sample}/rafah/{sample}_Seq_Info.tsv"
	params:
		prefix = "results/{sample}/rafah/{sample}",
		fasta_dir = "results/{sample}/tmp/genomes"
	log:
		"logs/{sample}/rafah.log"
	container: 
		"library://papanikos_182/default/rafah:0.1"
	threads: 8
	shell:
		"RaFAH_v0.1.pl --genomes_dir {params.fasta_dir}/ "
		"--extension fasta --threads {threads} "
		"--file_prefix {params.prefix} "
		"&>{log}"

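# Keep the contig id plus the RaFAH predicted host and score (columns 1, 6, 7),
# sorted by contig id.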
rule filter_rafah:
	input:
		seq_info = rules.run_rafah.output.seq_info
	output:
		rafah_tsv = "results/{sample}/rafah/predictions.tsv"
	shell:
		"tail -n+2 {input.seq_info} | cut -f1,6,7 | sort -k1 "
		"> {output.rafah_tsv}"

# VirHostMatcher-Net
rule run_vhmnet:
	input:
		reflist = rules.split_multifasta.output.reflist,
	output:
		done = touch("results/{sample}/vhmnet/.done.txt")
	params:
		# use it with
		# snakemake --singularity-args "-B /path/to/data/:/data" ...
		data_dir = "/data",
		tmp_dir = "results/{sample}/vhmnet/tmp",
		output_dir = "results/{sample}/vhmnet",
		fasta_dir = "results/{sample}/tmp/genomes"
	threads: 12
	container:
		"library://papanikos_182/default/vhmnet:0.1"
	log:
		"logs/{sample}/vhmnet.log"
	shell:
		"VirHostMatcher-Net.py -q {params.fasta_dir} "
		"-t {threads} "
		"--short-contig "
		"-i {params.tmp_dir} "
		"-d {params.data_dir} "
		"-o {params.output_dir} "
		"&>{log}"

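# Collect the per-contig VirHostMatcher-Net predictions: host and score
# (csv columns 8 and 10) from the last line of each *_prediction.csv.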
rule filter_vhmnet:
	input:
		rules.run_vhmnet.output.done
	output:
		vhmnet_tsv = "results/{sample}/vhmnet/predictions.tsv"
	params:
		predictions_dir = "./results/{sample}/vhmnet/predictions"
	shell:
		"""
		for f in $(find {params.predictions_dir} -type f -name "*.csv");
		do 
			contig_id=$(basename ${{f}} | sed -e 's/_prediction.csv//')
			host_score=$(tail -n1 ${{f}} | cut -f8,10 -d',' | tr ',' '\t')
			echo -e "$contig_id\t$host_score" >> {output.vhmnet_tsv}.tmp;
		done
		sort -k1 {output.vhmnet_tsv}.tmp > {output.vhmnet_tsv}
		rm -f {output.vhmnet_tsv}.tmp
		"""

# WIsH
rule run_wish:
	input:
		# Run vHULK first since it writes into the input genomes dir
		rules.run_vhulk.output.done_txt,
		reflist = rules.split_multifasta.output.reflist,
	output:
		prediction_list = "results/{sample}/wish/prediction.list",
		ll_mat = "results/{sample}/wish/llikelihood.matrix"
	log:
		"logs/{sample}/wish.log"
	threads: 8
	container:
		"library://papanikos_182/default/wish:1.0"
	params:
		# snakemake --singularity-args binds the whole data dir to /data
		# see run_vhmnet rule
		models_dir = "/data/host_wish_model",
		output_dir = "results/{sample}/wish",
		fasta_dir = "results/{sample}/tmp/genomes"
	shell:
		"mkdir -p {params.output_dir} && "
		"WIsH -c predict -g {params.fasta_dir} "
		"-t {threads} -b "
		"-m {params.models_dir} -r {params.output_dir} "
		"&>{log}"

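# Attach host taxonomy to the WIsH prediction list, using the hostTaxa table
# shipped with the VirHostMatcher-Net data directory (hence the bind mount above).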
rule process_wish:
	input:
		prediction_list = rules.run_wish.output.prediction_list
	output:
		predictions_tsv = "results/{sample}/wish/predictions.tsv"
	params:
		hostTaxa_pkl = DATA_DIR.joinpath("tables/hostTaxa.pkl")
	shell:
		"python workflow/scripts/wish_add_taxonomy.py "
		"-i {input.prediction_list} -t {params.hostTaxa_pkl} "
		"-o {output.predictions_tsv}"

# HTP
rule run_htp:
	input:
		reflist = rules.split_multifasta.output.reflist
	output:
		htp_raw = "results/{sample}/htp/raw.txt"
	log:
		stderr = "logs/{sample}/htp.stderr"
	params:
		fasta_dir = "./results/{sample}/tmp/genomes"
	container:
		"library://papanikos_182/default/htp:1.0.2"
	shell:
		"""
		printf "contig\traw_pred\n" > {output.htp_raw}
		# start a fresh log; each viruses_classifier call appends to it below
		: > {log.stderr}
		for f in $(find {params.fasta_dir} -type f -name "*.fasta");
		do
			contig_id=$(basename ${{f}} | sed -e 's/\.fasta//')
			pred=$(viruses_classifier \
					--classifier svc \
					--nucleic_acid dna \
					-p ${{f}} 2>>{log.stderr})
			echo -e "${{contig_id}}\t${{pred}}" >> {output.htp_raw}
		done
		"""

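# Reduce the raw viruses_classifier output to two columns:
# contig id and the reported 'phage' probability.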
rule process_htp:
	input:
		htp_raw = rules.run_htp.output.htp_raw
	output:
		predictions_tsv = "results/{sample}/htp/predictions.tsv"
	params:
		# literal '{' passed through to sed; keeps it out of snakemake's
		# brace formatting of the shell command
		ob = "{"
	shell:
		"""
		tail -n +2 {input.htp_raw} | cut -f1 -d',' | \
				sed -r "s/\t\{params.ob}'phage'\: /\t/" | sort -k1 \
				>{output.predictions_tsv}
		"""


# Aggregate
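# Each per-tool predictions.tsv is expected to hold one row per contig in the
# same sorted order, so pasting them column-wise keeps rows aligned across tools.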
rule collect_hosts:
	input:
		collect_prediction_tsvs
	output:
		sample_tsv = "results/{sample}/all_predictions.tsv"
	shell:
		"echo -e 'contig\thtp_proba\tvhulk_pred\tvhulk_score\trafah_pred\trafah_score\t"
		"vhmnet_pred\tvhmnet_score\twish_pred\twish_score'"
		">{output.sample_tsv} && "
		"paste <(cat {input[4]}) "
		"<(cut -f2,3 {input[0]}) "
		"<(cut -f2,3 {input[1]}) "
		"<(cut -f2,3 {input[2]}) "
		"<(cut -f2,3 {input[3]}) "
		">>{output.sample_tsv}"

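# Reduce the per-tool host predictions to a single consensus (LCA) taxon
# per contig (see workflow/scripts/get_lca.py).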
rule lca:
	input:
		predictions_tsv = rules.collect_hosts.output.sample_tsv
	output:
		lca_tsv = "results/{sample}/lca.tsv"
	log:
		"logs/{sample}/lca.log"
	shell:
		"python workflow/scripts/get_lca.py "
		"-i {input.predictions_tsv} "
		"-o {output.lca_tsv} &>{log}"