diff --git a/workflow/Snakefile b/workflow/Snakefile new file mode 100644 index 0000000000000000000000000000000000000000..3212bff6ce99abe8d19cc9e40e732b8e43eec4a6 --- /dev/null +++ b/workflow/Snakefile @@ -0,0 +1,102 @@ +configfile: "config/config.yaml" + +# This defines the number of negatives +NEGATIVES = config["negatives"] + +# This defines the names of the negatives +DATASETS = ['N{}'.format(i+1) for i in range(0,NEGATIVES)] +# Append the 'positives'. Now a rule can expand on all datasets +DATASETS.append('positives') + + +include: "rules/get_data.smk" +include: "rules/pre_process.smk" +include: "rules/construct_datasets.smk" + + +rule all: + input: + # Raw data required + # Produced by get_data.smk + "data/genomes/phages_refseq.fasta", + "data/interactions/intact.txt", + "data/pvogs/all.hmm", + "data/pvogs/VOGProteinTable.txt", + "data/taxonomy_db/taxa.sqlite", + "data/taxonomy_db/taxa.sqlite.traverse.pkl", + + # Pre-processing for calculating matrices + # Produced by pre_process.smk + "results/scores.tsv", + "results/filtered_scores.tsv", + "results/annotations.tsv", + + # Getting to pvogs from proteins and calcualating features + # Produced by pre_process.smk + expand(["results/interaction_datasets/{dataset}/{dataset}.interactions.tsv", + "results/interaction_datasets/{dataset}/{dataset}.proteins.faa", + "results/interaction_datasets/{dataset}/{dataset}.pvogs_interactions.tsv", + "results/interaction_datasets/{dataset}/{dataset}.features.tsv"], + dataset=DATASETS), + + # Major results + # Produced by this file. + "results/RF/best_model.pkl", + "results/RF/best_model_id.txt", + "results/predictions.tsv", + "results/final_training_set.tsv" + + + +checkpoint random_forest: + input: + expand("results/interaction_datasets/{dataset}/{dataset}.features.tsv", dataset=DATASETS), + filtered_master_tsv = rules.filter_scores_table.output.filtered_master_tsv + output: + best_model = "results/RF/best_model.pkl", + best_model_id = "results/RF/best_model_id.txt" # This only contains the name as a string... + log: + "results/logs/processed_notebook.py.ipynb" + conda: + "envs/pvogs_jupy.yml" + threads: 16 + notebook: + "notebooks/data_processing.py.ipynb" + + +def get_best_model_id(wildcards): + """ + Helper function to get the id of the dataset that gave the + best results from the notebook search. 
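As an aside on the target expansion at the top of this Snakefile: the sketch below is standalone Python (not part of the workflow) showing how DATASETS resolves into the per-dataset feature tables that `rule all` and the `random_forest` checkpoint collect. The value `negatives: 3` is an assumption for illustration; the real number comes from config/config.yaml.

```python
# Standalone sketch, assuming snakemake is importable and config["negatives"] == 3.
from snakemake.io import expand

NEGATIVES = 3
DATASETS = ['N{}'.format(i + 1) for i in range(NEGATIVES)]
DATASETS.append('positives')
# DATASETS == ['N1', 'N2', 'N3', 'positives']

targets = expand(
    "results/interaction_datasets/{dataset}/{dataset}.features.tsv",
    dataset=DATASETS,
)
# -> one features.tsv path per negative set plus one for the positives,
#    which is exactly what rule all and the checkpoint take as input.
```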
+ """ + checkpoint_output = checkpoints.random_forest.get(**wildcards).output[1] + with open(checkpoint_output, 'r') as fin: + best_model_id = fin.read().strip() + return best_model_id + +rule predict: + input: + positives_features_tsv = "results/interaction_datasets/positives/positives.features.tsv", + model_fp = "results/RF/best_model.pkl", + filtered_scores_tsv = rules.filter_scores_table.output.filtered_master_tsv + output: + predictions_tsv = "results/predictions.tsv", + final_training_set_tsv = "results/final_training_set.tsv" + log: + "results/logs/predict.log" + conda: + "envs/pvogs_jupy.yml" + params: + # This is read from the notebook output + dataset_string = get_best_model_id, + threads: 16 + shell: + "python workflow/scripts/predict.py " + "-j {threads} " + "-m {input.model_fp} " + "-p {input.positives_features_tsv} " + "-n results/interaction_datasets/{params.dataset_string}/{params.dataset_string}.features.tsv " + "-t {input.filtered_scores_tsv} " + "-o {output.predictions_tsv} " + "&>{log}" + diff --git a/workflow/envs/pvogs.yml b/workflow/envs/pvogs.yml new file mode 100644 index 0000000000000000000000000000000000000000..04a6e5aef80c2aae205054dd08aa902ae23e5479 --- /dev/null +++ b/workflow/envs/pvogs.yml @@ -0,0 +1,12 @@ +channels: + - bioconda + - conda-forge +dependencies: + - emboss==6.6.0.0 + - biopython==1.77 + - fastani==1.3 + - comparem==0.1.1 + - pandas==1.0.3 + - ete3==3.1.1 + - requests==2.23.0 + - hmmer==3.2.1 diff --git a/workflow/envs/pvogs_jupy.yml b/workflow/envs/pvogs_jupy.yml new file mode 100644 index 0000000000000000000000000000000000000000..805b9c75f5c715c5e6a555beb5112b0adad9633f --- /dev/null +++ b/workflow/envs/pvogs_jupy.yml @@ -0,0 +1,11 @@ +channels: + - bioconda + - conda-forge +dependencies: + - python==3.7 + - jupyter + - scipy==1.4.1 + - scikit-learn==0.21.3 + - seaborn==0.10.1 + - pandas>=1.0* + diff --git a/workflow/notebooks/data_processing.py.ipynb b/workflow/notebooks/data_processing.py.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..6cc3cf7311fac6232197d0d63f634a0035f06d60 --- /dev/null +++ b/workflow/notebooks/data_processing.py.ipynb @@ -0,0 +1,1228 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "# ML\n", + "from sklearn.model_selection import train_test_split, RandomizedSearchCV\n", + "from sklearn.ensemble import RandomForestClassifier\n", + "from sklearn import metrics\n", + "\n", + "from scipy.cluster.hierarchy import dendrogram, linkage\n", + "from scipy.spatial.distance import squareform, pdist\n", + "\n", + "# Data handling\n", + "import pandas as pd\n", + "import numpy as np\n", + "\n", + "from collections import Counter\n", + "import operator\n", + "\n", + "# Plotting\n", + "%matplotlib inline\n", + "# import matplotlib.patches as mpatches\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "\n", + "# Read and write\n", + "from pathlib import Path # filesystem\n", + "\n", + "## sklearn is giving me\n", + "## sklearn/model_selection/_search.py:814: DeprecationWarning: \n", + "## The default of the `iid` parameter will change from True to False in version 0.22 \n", + "## and will be removed in 0.24. 
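One thing worth flagging in this imports cell: the model-saving cells further down call `pickle.dump(...)`, but `pickle` is not imported anywhere in the notebook, so those cells would presumably fail with a NameError. A one-line addition along these lines should be all that is needed:

```python
import pickle  # used below to serialize the per-dataset and best RandomForest models
```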
This will change numeric results when test-set sizes are unequal.\n", + "## DeprecationWarning)\n", + "\n", + "# Suppress that\n", + "import warnings\n", + "warnings.filterwarnings('ignore', category=DeprecationWarning)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "outputs": [], + "source": [ + "# Initial setup" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "outputs": [], + "source": [ + "## GLOBAL VARIABLES\n", + "\n", + "These are variables that are used throughout the notebook" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "# Set this to False if you don't want to scale the feature values\n", + "SCALE_DATA = True" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "# Only change these if you know what you are doing\n", + "\n", + "# A dictionary that maps original feature names to final ones\n", + "FEATURES_MAP = dict(zip(['jaccard_score', 'same_score', \n", + " 'inwards_score', 'outwards_score', \n", + " 'avg_distance', 'mean_ani', \n", + " 'mean_aai'],\n", + " ['Co-occurrence', 'Co-orientation',\n", + " 'Convergent', 'Divergent',\n", + " 'Average Distance', 'Mean ANI',\n", + " 'Mean AAI']\n", + " )\n", + " )\n", + "\n", + "# A list to only get the features\n", + "# Useful for slicing data frames\n", + "FEATURES = ['jaccard_score', 'same_score', \n", + " 'inwards_score', 'outwards_score', \n", + " 'avg_distance', 'mean_ani', \n", + " 'mean_aai']\n", + " \n", + "# Features with label appended\n", + "# useful for plotting based on label\n", + "FEAT_LAB = FEATURES.copy()\n", + "FEAT_LAB.append('label')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "outputs": [], + "source": [ + "## HELPER FUNCTIONS" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "outputs": [], + "source": [ + "* Reading data\n", + "\n", + "Some of these functions were written **before** filtering for max distance, so they try to address this issue." 
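A small usage sketch of how the FEATURES_MAP defined above is typically applied when preparing a frame for plotting; the toy frame and the two-entry subset of the map are only here so the sketch is self-contained.

```python
import pandas as pd

# Toy frame with two of the raw feature column names used throughout the notebook.
toy = pd.DataFrame({'jaccard_score': [0.2, 0.7], 'avg_distance': [0.4, 0.9]})

# Subset of the notebook's FEATURES_MAP, redefined here for self-containment.
FEATURES_MAP = {'jaccard_score': 'Co-occurrence', 'avg_distance': 'Average Distance'}

# Rename raw feature names to the display names used in the figures.
display_df = toy.rename(columns=FEATURES_MAP)
```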
+ ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + "def read_scores_table(scores_fp, label=None):\n", + " \"\"\"\n", + " Read a table from a tsv file provided as a path.\n", + " Return a dataframe.\n", + " \n", + " If label is set, append a column named `label` with\n", + " the provided value\n", + " \"\"\"\n", + " scores_df = pd.read_csv(scores_fp, sep='\\t')\n", + " scores_df['interaction'] = scores_df['pvog1'] + '-' + scores_df['pvog2']\n", + " # Unpack features list because I want interaction prepended.\n", + " scores_df = scores_df[['interaction', *FEATURES]] \n", + " if label is not None:\n", + " scores_df['label'] = label\n", + " return scores_df\n", + "\n", + "\n", + "def concat_data_frames(pos_df, \n", + " neg_df, \n", + " subsample=False,\n", + " clean=True,\n", + " is_scaled=True,\n", + " ):\n", + " \"\"\"\n", + " Concatenate two dataframes.\n", + " \n", + " subsample:bool \n", + " Subsample the `neg_df` to a number of\n", + " observations equal to the number of pos_df \n", + " (balance datasets)\n", + " \n", + " clean:bool\n", + " Remove observations with a value of `avg_distance` == 100000 or 1.\n", + " \n", + " is_scaled:bool\n", + " The features have been scaled to a range 0-1\n", + " \n", + " Return:\n", + " concat_df: pd.DataFrame\n", + " The concatenated data frame\n", + " \"\"\"\n", + " \n", + " if clean is True:\n", + " pos_df = remove_ambiguous(pos_df)\n", + " neg_df = remove_ambiguous(neg_df)\n", + " \n", + " n_positives = pos_df.shape[0]\n", + " n_negatives = neg_df.shape[0]\n", + " \n", + " # Remove possible duplicate interactions from the negatives\n", + " # This might happen because of the random selection when creating the set\n", + " # Why I also select more negatives to begin with\n", + " neg_df = neg_df.loc[~neg_df['interaction'].isin(pos_df['interaction'])]\n", + " \n", + " if (n_positives != n_negatives) and (subsample is True):\n", + " neg_df = neg_df.sample(n=n_positives, random_state=1)\n", + " concat_df = pd.concat([pos_df, neg_df])\n", + " \n", + " assert concat_df[concat_df.duplicated(subset=['interaction'])].empty == True, concat_df.loc[concat_df.duplicated(subset=['interaction'], keep=False)]\n", + " \n", + " return concat_df\n", + "\n", + "def scale_df(input_df):\n", + " \"\"\"\n", + " Scale all feature values in the data frame to [0-1].\n", + " \"\"\"\n", + " maxes = input_df[FEATURES].max(axis=0)\n", + " scaled_data = input_df[FEATURES].divide(maxes)\n", + " if 'label' in input_df.columns:\n", + " scaled_df = pd.concat([input_df['interaction'], scaled_data, input_df['label']], axis=1)\n", + " else:\n", + " scaled_df = pd.concat([input_df['interaction'], scaled_data], axis=1)\n", + " return scaled_df\n", + "\n", + "def remove_ambiguous(input_df):\n", + " \"\"\"\n", + " Select observations in the `input_df` that have feature values\n", + " \"\"\"\n", + " df_clean = input_df[(input_df.jaccard_score != 0)] # This is true if they don't co-occur\n", + " return df_clean\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "outputs": [], + "source": [ + "# Analysis\n", + "\n", + "From this point on all the magic happens" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [], + "source": [ + "# Get all the files in a list of Path objects\n", + "filepaths = list(map(Path, snakemake.input))\n", + "# Remove the last element\n", + "filtered_scores_tsv = filepaths.pop(-1)\n", + "negatives_fp_list = []\n", + "for fp in filepaths:\n", + " # Grab the 
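A minimal, self-contained sketch of what `scale_df` does to the feature columns (toy values). Note this is max-scaling: it maps non-negative values into [0, 1] by dividing by the column maximum, but it does not subtract the minimum as full min-max scaling would.

```python
import pandas as pd

toy = pd.DataFrame({'jaccard_score': [0.0, 2.0, 4.0],
                    'avg_distance': [10.0, 5.0, 20.0]})

# Same operation as scale_df: divide every column by its own maximum.
scaled = toy.divide(toy.max(axis=0))
# jaccard_score -> [0.0, 0.5, 1.0]; avg_distance -> [0.5, 0.25, 1.0]
```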
positives\n", + " if 'positives' in fp.name:\n", + " positives_tsv = fp\n", + " else:\n", + " # Append to the list of negatives\n", + " negatives_fp_list.append(fp)" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "# Positives are always the same\n", + "pos_df = read_scores_table(positives_tsv, label = 1)\n", + "\n", + "# Negatives are stored in the list above\n", + "# Select one for visualization\n", + "neg_df = read_scores_table(negatives_fp_list[3], label = 0)\n", + "if SCALE_DATA is True:\n", + " pos_df = scale_df(pos_df)\n", + " neg_df = scale_df(neg_df)\n", + " \n", + "posneg = concat_data_frames(pos_df, neg_df,\n", + " clean=True,\n", + " subsample=True,\n", + " is_scaled=SCALE_DATA)" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [], + "source": [ + "def plot_features_correlation(data_df):\n", + " # Subset to features only\n", + " cor_data = data_df[FEATURES]\n", + " cor_df = cor_data.corr()\n", + " \n", + " plt.figure(figsize=(10, 8))\n", + " sns.heatmap(cor_df, annot=True, \n", + " cmap=sns.color_palette(\"RdBu_r\"), \n", + " vmin=-1, \n", + " vmax=1, \n", + " cbar_kws = {\"shrink\" : 0.5, \n", + " \"ticks\" : [-1, -0.5, 0, 0.5, 1],\n", + " }\n", + " ) \n", + " plt.show()\n", + "\n", + "## THIS IS NO LONGER USED in favor of the pairplot\n", + "# you can call it with\n", + "# plot_features_dist(posneg, FEATURES, is_scaled = SCALE_DATA)\n", + "def plot_features_dist(data, features, is_scaled=False):\n", + " \"\"\"\n", + " Make a gridplot of all features distributions\n", + " \n", + " data: A pd.DataFrame with features and a label column\n", + " features: A list of features to use\n", + " \"\"\"\n", + " if is_scaled:\n", + " distance_threshold = 1\n", + " else:\n", + " distance_threshold = 1000000\n", + " \n", + " df_list=[data.loc[:,[f, 'label']] for f in features]\n", + " # Initialize a figure\n", + " fig, axes = plt.subplots(3, 3, figsize = (8,4), dpi=300)\n", + " # Flatten the axes for plotting in the correct ax object\n", + " ax_list = axes.flatten()\n", + "\n", + " # Workaround to add the legend\n", + " red_patch = mpatches.Patch(color='red', label='positives', alpha=0.5)\n", + " blue_patch = mpatches.Patch(color='b', label='negatives', alpha=0.5)\n", + " plt.figlegend(handles=[red_patch, blue_patch], loc='lower center')\n", + " # Make the plots\n", + " for i in range(len(df_list)):\n", + " df=df_list[i]\n", + " if features[i] == \"avg_distance\":\n", + " x0 = df.loc[((df.label==0) & (df.avg_distance != distance_threshold)), features[i]]\n", + " x1 = df.loc[((df.label==1) & (df.avg_distance != distance_threshold)), features[i]]\n", + " else:\n", + " x0 = df.loc[df.label==0, features[i]] \n", + " x1 = df.loc[df.label==1, features[i]]\n", + "# a1 = sns.distplot(x0, ax=axes[coords[0], coords[1]], axlabel=features[i],color='b', label='negatives')\n", + "# a2 =sns.distplot(x1, ax=axes[coords[0], coords[1]], axlabel=features[i], color='r', label='positives')\n", + " ax = ax_list[i]\n", + " ax.hist(x0, color = 'b', alpha=.5)\n", + " ax.hist(x1, color = 'r', alpha= .5)\n", + " ax.set_title(\"{} (N={}/{})\".format(features[i], \n", + " x0.shape[0], \n", + " x1.shape[0]), \n", + " fontsize=8\n", + " )\n", + "\n", + " plt.subplots_adjust(hspace=0.7)\n", + "\n", + " axes[2,2].remove()\n", + " axes[2,1].remove() " + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [], + "source": [ + "pair = sns.pairplot(posneg[FEAT_LAB],\n", + " hue = 
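The cell above builds the visualization set with `concat_data_frames`. The core of that balancing step (drop negatives that collide with a positive interaction id, then subsample the negatives down to the number of positives) can be sketched on toy data like this:

```python
import pandas as pd

pos = pd.DataFrame({'interaction': ['A-B', 'C-D'], 'label': [1, 1]})
neg = pd.DataFrame({'interaction': ['A-B', 'E-F', 'G-H', 'I-J'],
                    'label': [0, 0, 0, 0]})

# Remove negatives that duplicate a positive pair, as concat_data_frames does.
neg = neg.loc[~neg['interaction'].isin(pos['interaction'])]

# Balance the classes by subsampling the negatives (fixed seed as in the notebook).
neg = neg.sample(n=len(pos), random_state=1)

balanced = pd.concat([pos, neg])  # 2 positives + 2 negatives
```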
'label', \n", + " vars=FEATURES,\n", + " markers = [\"+\", \"x\"],\n", + " palette = {\n", + " 0 : 'b',\n", + " 1 : 'r',\n", + " },\n", + "# diag_kind=\"hist\",\n", + " plot_kws = {\n", + " \"alpha\": 0.65,\n", + " },\n", + " diag_kws={\n", + " \"clip\" : [0. , 1.],\n", + " },\n", + " )\n", + "pair.set(xlim=(-0.1, 1.1))" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "plot_features_correlation(posneg)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "outputs": [], + "source": [ + "## Cluster interactions" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "# Make a copy of the original df for plotting\n", + "d_posneg = posneg\n", + "\n", + "D_posneg = pdist(d_posneg[FEATURES])\n", + "D_posneg = D_posneg / D_posneg.max() # Scale the distances to be [0-1]\n", + "Z_posneg = linkage(D_posneg, method=\"average\")" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [], + "source": [ + "# Append a color column based on the label value\n", + "# https://stackoverflow.com/a/26887820\n", + "\n", + "def color_label_interaction(row):\n", + " if row['label'] == 1:\n", + " return 'red'\n", + " else:\n", + " return 'blue'\n", + "\n", + "d_posneg['color'] = d_posneg.apply(lambda row: color_label_interaction(row), axis=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [], + "source": [ + "plt.figure(figsize=(20, 30))\n", + "dn = dendrogram(Z_posneg,\n", + " labels=d_posneg['interaction'].values, \n", + " orientation='left',\n", + " leaf_font_size=12,\n", + " color_threshold=0.1,\n", + " above_threshold_color='black',\n", + " )\n", + "plt.axvline(x=0.1, linestyle='--')\n", + "plt.title(\"Average linkage based on Euclidean distances\", size=14)\n", + "# plt.tick_params(axis='y', labelsize=12, direction='out')\n", + "ax=plt.gca()\n", + "ylbls = ax.get_ymajorticklabels()\n", + "for lbl in ylbls:\n", + " lbl.set_color(d_posneg.loc[d_posneg.interaction == lbl.get_text(), 'color'].values[0])\n", + "plt.show()\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "outputs": [], + "source": [ + "### Distribution of distances pos vs neg" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [], + "source": [ + "D_pos = pdist(pos_df[FEATURES])\n", + "D_neg = pdist(neg_df[FEATURES])" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(figsize=(10,8))\n", + "\n", + "ax = sns.distplot(D_pos ,\n", + " hist=False,\n", + " kde_kws = {\"shade\" : True,\n", + " \"alpha\" : 0.5,\n", + " \"clip\" : [0. , 1.],\n", + " \"label\": 'Positives'\n", + " \n", + " },\n", + " color='r')\n", + "ax = sns.distplot(D_neg,\n", + " color='b',\n", + " hist=False,\n", + " kde_kws = {\"shade\" : True,\n", + " \"alpha\": 0.5,\n", + " \"clip\" : [0. 
, 1.],\n", + " \"label\": \"Negatives\"}\n", + " )\n", + "ax.set_title(\"Distances distribution\", size=14)\n", + "ax.set_xlabel(\"Distance\", size=14)\n", + "ax.set_ylabel(\"Density\", size=14)\n", + "ax.tick_params(axis='both', which='major', labelsize=12)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "outputs": [], + "source": [ + "# Random Forest" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "outputs": [], + "source": [ + "## Model selection and hyperparameter tuning" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "outputs": [], + "source": [ + "Define a search space for the parameters to be sampled from. \n", + "Less exhaustive than a full grid search, but quicker." + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [], + "source": [ + "# Number of trees\n", + "n_estimators = [int(x) for x in range(100,5000,200)]\n", + "\n", + "# Number of features to consider at every split\n", + "# 'auto' uses sqrt(n_features), a float uses int(float * n_features)\n", + "max_features = ['auto', .5, 1.]\n", + "\n", + "# Maximum number of levels in each tree\n", + "max_depth = [int(x) for x in range(10, 60, 10)]\n", + "max_depth.append(None)\n", + "\n", + "# Minimum number of samples required to split a node\n", + "min_samples_split = [2, 5, 10]\n", + "\n", + "# Minimum number of samples required at each leaf node\n", + "min_samples_leaf = [1, 2, 4]\n", + "\n", + "# Method for selecting samples for training\n", + "bootstrap = [True, False]\n", + "\n", + "# Put them all in a dictionary\n", + "params_space = {\n", + " \"max_features\": max_features,\n", + " \"n_estimators\" : n_estimators,\n", + " \"max_depth\" : max_depth,\n", + " \"min_samples_split\" : min_samples_split,\n", + " \"min_samples_leaf\" : min_samples_leaf,\n", + " \"bootstrap\": bootstrap,\n", + " }" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a color mapping for the negative sets\n", + "\n", + "negative_sets = []\n", + "for negative in negatives_fp_list:\n", + " set_name = negative.name.split('.')[0]\n", + " negative_sets.append(set_name)\n", + " \n", + "color_palette = sns.color_palette(\"muted\", len(negative_sets))\n", + "\n", + "color_dict = dict(zip(negative_sets, color_palette))" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [], + "source": [ + "def get_metric_value(metric_string, \n", + " truth_array, \n", + " prediction_array, \n", + " prediction_proba_array):\n", + " \"\"\"\n", + " Helper function to get metric values based on a string representation.\n", + " \"\"\"\n", + " if metric_string == 'accuracy':\n", + " return metrics.accuracy_score(truth_array, prediction_array)\n", + " if metric_string == 'accuracy_abs':\n", + " return metrics.accuracy_score(truth_array, prediction_array, normalize = False)\n", + " if metric_string == 'precision':\n", + " return metrics.precision_score(truth_array, prediction_array)\n", + " if metric_string == 'recall':\n", + " return metrics.recall_score(truth_array, prediction_array)\n", + " if metric_string == 'f1_score':\n", + " return metrics.f1_score(truth_array, prediction_array)\n", + " if metric_string == 'roc_auc_score':\n", + " return metrics.roc_auc_score(truth_array, prediction_proba_array[:,1])\n", + " if metric_string == 'fpr':\n", + " return metrics.roc_curve(truth_array, prediction_proba_array[:,1])[0]\n", + " if metric_string == 'tpr':\n", + " return 
metrics.roc_curve(truth_array, prediction_proba_array[:,1])[1]" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [], + "source": [ + "# A list of all the metric we are storing\n", + "_metrics = [\n", + " 'accuracy', \n", + " 'accuracy_abs', \n", + " 'f1_score', \n", + " 'precision', \n", + " 'recall', \n", + " 'roc_auc_score', \n", + " 'fpr', \n", + " 'tpr'\n", + " ]\n", + "# A dictionary that will hold all the values\n", + "all_metrics = {}\n", + "# A dictionary that holds the parameters for the best model\n", + "# after each iteration\n", + "best_models = {}\n", + "\n", + "# A dictionary that holds feature importances\n", + "# for each iteration\n", + "feat_importances = {}\n", + "\n", + "for i, negative in enumerate(negatives_fp_list):\n", + " \n", + " # Make a copy of the list with the negative_files\n", + " copy_of_list = negatives_fp_list.copy() \n", + " # Get the name of the current negative\n", + " training_set = negative.name.split('.')[0]\n", + " # Read the current negative data in a data frame\n", + " neg_df = read_scores_table(negative, label=0)\n", + " # ... scale it\n", + " neg = scale_df(neg_df)\n", + " # Concatenate with the positive\n", + " # A ground truth set is constructed\n", + " posneg = concat_data_frames(pos_df, \n", + " neg, \n", + " subsample=True, \n", + " clean=True, \n", + " is_scaled=True)\n", + " \n", + " # Split the frame to features and the label\n", + " X = posneg[FEATURES]\n", + " y = posneg['label']\n", + " \n", + " # Train test split\n", + " X_train, X_holdout, y_train, y_holdout = train_test_split(X,\n", + " y,\n", + " test_size=0.3,\n", + " random_state=1)\n", + " \n", + " # Instantiate the Randomized Search\n", + " models = RandomizedSearchCV(estimator = RandomForestClassifier(random_state=1),\n", + " param_distributions = params_space,\n", + " n_iter = 500, # number of parameter settings to sample\n", + " cv = 5, # Use 5-folds for CV\n", + " random_state = 1, # Set random state for reproducibility\n", + " n_jobs = snakemake.threads,\n", + "# verbose=1 # Print some messages to keep track of progress\n", + " )\n", + " # Execute the search\n", + " search = models.fit(X_train, y_train)\n", + " \n", + " # Make a classifier out of the best model\n", + " best_model = RandomForestClassifier(**search.best_params_, \n", + " random_state=1)\n", + " \n", + " # Fit the best model on the training data\n", + " best_model.fit(X_train, y_train)\n", + " \n", + " # Predict labels and probabilities for the test/holdout set\n", + " holdout_pred = best_model.predict(X_holdout)\n", + " holdout_proba = best_model.predict_proba(X_holdout)\n", + " \n", + " ###############\n", + " # Store results\n", + " ###############\n", + " # Create a tuple of the strings of the training set\n", + " # as a key for the all_metrics dict\n", + " self_key = (training_set, training_set)\n", + " \n", + " # The get_metric_value is defined above\n", + " all_metrics[self_key] = {metric : get_metric_value(metric,\n", + " y_holdout,\n", + " holdout_pred,\n", + " holdout_proba)\n", + " for metric in _metrics}\n", + " \n", + " best_models[training_set] = best_model\n", + " \n", + " importances = best_model.feature_importances_ \n", + " feat_importances[self_key] = dict(zip(FEATURES, importances))\n", + " \n", + " # Remove the current negative set file from the copy of the list\n", + " # The rest will be used for validation\n", + " validation_sets = [i for i in copy_of_list if i != negative]\n", + "\n", + " for validation_set in validation_sets:\n", + " # 
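For readers less familiar with the metric helpers used in this loop, here is a self-contained toy run on synthetic data (not the workflow's features) showing why `get_metric_value` passes the second column of `predict_proba` to the ROC functions: that column holds the probability of the positive class.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn import metrics

# Synthetic stand-in for a scaled feature table with 7 features.
X, y = make_classification(n_samples=200, n_features=7, random_state=1)
X_train, X_holdout, y_train, y_holdout = train_test_split(
    X, y, test_size=0.3, random_state=1)

clf = RandomForestClassifier(n_estimators=100, random_state=1).fit(X_train, y_train)
pred = clf.predict(X_holdout)
proba = clf.predict_proba(X_holdout)   # shape (n_samples, 2)

# Column 1 = probability of class 1, the score expected by the ROC metrics.
auc = metrics.roc_auc_score(y_holdout, proba[:, 1])
fpr, tpr, _ = metrics.roc_curve(y_holdout, proba[:, 1])
f1 = metrics.f1_score(y_holdout, pred)
```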
Repeat the same process as above, but without optimizing\n", + " validation_set_name = validation_set.name.split('.')[0]\n", + " negg = read_scores_table(validation_set, label=0)\n", + " scaled_negg = scale_df(negg)\n", + " \n", + " pn = concat_data_frames(pos_df, \n", + " scaled_negg, \n", + " subsample=True, \n", + " clean=True, \n", + " is_scaled=True)\n", + " XX = pn[FEATURES]\n", + " yy = pn['label']\n", + " \n", + " XX_train, XX_holdout, yy_train, yy_holdout = train_test_split(XX, yy, test_size=0.3, random_state=1)\n", + " best_model.fit(XX_train, yy_train)\n", + " holdout_pred = best_model.predict(XX_holdout)\n", + " holdout_proba = best_model.predict_proba(XX_holdout)\n", + " \n", + " \n", + " combo_key = (training_set, validation_set_name)\n", + " all_metrics[combo_key] = {metric : get_metric_value(metric, yy_holdout, holdout_pred, holdout_proba) \n", + " for metric in _metrics}\n", + " \n", + " importances = best_model.feature_importances_ \n", + " feat_importances[combo_key] = dict(zip(FEATURES, importances))\n", + " \n", + " print(\"Finished set : {}\".format(i))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "outputs": [], + "source": [ + "## Save results" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [], + "source": [ + "# Results are stored in the results/RF directory\n", + "# Create it, or do nothing\n", + "RF_dir = Path(\"results/RF\")\n", + "RF_dir.mkdir(exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [], + "source": [ + "## metrics\n", + "metrics_df = pd.DataFrame.from_dict(all_metrics, orient='index', columns=_metrics)\n", + "\n", + "# Append columns TS (Training Set) and VS (Validation Set)\n", + "# For slicing\n", + "metrics_df['TS'] = [i[0] for i in metrics_df.index.values]\n", + "metrics_df['VS'] = [i[1] for i in metrics_df.index.values]\n", + "\n", + "# Pickle is used to store the fpr, tpr values as arrays\n", + "# otherwise it messes up the column formatting\n", + "metrics_fp = RF_dir / Path('metrics.pkl')\n", + "metrics_tsv = RF_dir / Path('metrics.tsv')\n", + "metrics_df.to_pickle(metrics_fp)\n", + "## Drop the fpr and tpr to write to tsv\n", + "metrics_df.drop(columns=['fpr', \n", + " 'tpr']).to_csv(metrics_tsv, \n", + " sep='\\t', \n", + " index=False)\n", + "\n", + "## features\n", + "features_df = pd.DataFrame.from_dict(feat_importances,\n", + " orient = 'index',\n", + " columns = FEATURES)\n", + "features_df['TS'] = [i[0] for i in features_df.index.values]\n", + "features_df['VS'] = [i[1] for i in features_df.index.values]\n", + "features_fp = RF_dir / Path(\"features.tsv\")\n", + "\n", + "features_df.to_csv(features_fp, sep='\\t', index=False)\n", + "\n", + "\n", + "## models\n", + "models_dir = RF_dir / Path(\"models\")\n", + "models_dir.mkdir(exist_ok=True)\n", + "for dataset in best_models:\n", + " pkl_fname = Path('{}.RF.pkl'.format(dataset))\n", + " pkl_fpath = models_dir.joinpath(pkl_fname)\n", + " with open(pkl_fpath, 'wb') as fout:\n", + " pickle.dump(best_models[dataset], fout)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "outputs": [], + "source": [ + "## Best model across the board" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [], + "source": [ + "# Calculate min, max, mean, std for all metrics across datasets\n", + "## this probably is not the most pythonic way of doing this\n", + "\n", + "metrics_stats = {}\n", + "for i, negative in 
enumerate(negatives_fp_list):\n", + " dataset_name = negative.name.split('.')[0]\n", + " # Get the stats for a particular dataset and trnaspose it\n", + " set_df = metrics_df.loc[metrics_df[\"TS\"] == dataset_name, _metrics].describe().T\n", + " \n", + " metrics_stats[dataset_name] = {'accuracy_mean' : set_df.loc['accuracy', 'mean'],\n", + " 'accuracy_std' : set_df.loc['accuracy', 'std'],\n", + " 'accuracy_min' : set_df.loc['accuracy', 'min'],\n", + " 'accuracy_max' : set_df.loc['accuracy', 'max'],\n", + " 'precision_mean' : set_df.loc['precision', 'mean'],\n", + " 'precision_std' : set_df.loc['precision', 'std'],\n", + " 'precision_min' : set_df.loc['precision', 'min'],\n", + " 'precision_max' : set_df.loc['precision', 'max'],\n", + " 'recall_mean' : set_df.loc['recall', 'mean'],\n", + " 'recall_std' : set_df.loc['recall', 'std'],\n", + " 'recall_min' : set_df.loc['recall', 'min'],\n", + " 'recall_max' : set_df.loc['recall', 'max'],\n", + " 'f1_score_mean': set_df.loc['f1_score', 'mean'],\n", + " 'f1_score_std': set_df.loc['f1_score', 'std'],\n", + " 'f1_score_min' : set_df.loc['f1_score', 'min'],\n", + " 'f1_score_max' : set_df.loc['f1_score', 'max'],\n", + " 'roc_auc_score_mean': set_df.loc['roc_auc_score', 'mean'],\n", + " 'roc_auc_score_std': set_df.loc['roc_auc_score', 'std'],\n", + " 'roc_auc_score_min' : set_df.loc['roc_auc_score', 'min'],\n", + " 'roc_auc_score_max' : set_df.loc['roc_auc_score', 'max'],\n", + " 'accuracy_abs_mean' : set_df.loc['accuracy_abs', 'mean'],\n", + " 'accuracy_abs_std' : set_df.loc['accuracy_abs', 'std'],\n", + " 'accuracy_abs_min' : set_df.loc['accuracy_abs', 'min'],\n", + " 'accuracy_abs_max' : set_df.loc['accuracy_abs', 'max']\n", + " }\n", + "\n", + "metrics_stats_df = pd.DataFrame.from_dict(metrics_stats, orient='index')\n", + "\n", + "# Save them\n", + "metrics_stats_fp = RF_dir / Path(\"metrics.stats.tsv\")\n", + "metrics_stats_df.to_csv(metrics_stats_fp, sep='\\t',)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [], + "source": [ + "# Select the best model automatically\n", + "\n", + "# Initialize a dictionary {N1 : 0, N2 : 0, ...}\n", + "dataset_scores = dict(zip(metrics_stats_df.T.columns, \n", + " [0 for i in range(len(metrics_stats_df.T.columns))]))\n", + "# Iterate over the rows of the dataframe\n", + "for stat, row in metrics_stats_df.T.iterrows():\n", + " if stat.endswith('std'):\n", + " #row is a series. idxmax() returns the value of the index of the series\n", + " dataset_scores[row.idxmin()] += 1 \n", + " else:\n", + " dataset_scores[row.idxmax()] += 1\n", + " \n", + "# Select best scoring model\n", + "best_model = max(dataset_scores.items(), key=operator.itemgetter(1))[0]\n", + "print(\"AND THE BEST MODEL IS...: {}\".format(best_model))\n", + "\n", + "best_model_fp = RF_dir / Path(\"best_model.pkl\")\n", + "# Should be identical with the one in the models dir\n", + "with open(best_model_fp, 'wb') as fout:\n", + " pickle.dump(best_models[best_model], fout)\n", + "\n", + "# Write the best model id in a file.\n", + "## This is used internally in the workflow as a checkpoint\n", + "# Should be a file containing just the id , e.g. 
N2\n", + "best_model_id_fp = RF_dir / Path(\"best_model_id.txt\")\n", + "with open(best_model_id_fp, 'w') as fout:\n", + " fout.write(best_model)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "outputs": [], + "source": [ + "# PLOTS" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "outputs": [], + "source": [ + "## Metrics" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [], + "source": [ + "def plot_dataset_roc_curve(training_set,\n", + " data_df,\n", + " plot_all=True,\n", + " ax=None,\n", + " label_xy=True):\n", + " \"\"\"\n", + " Plot the roc curve for a dataset.\n", + " training_set:str\n", + " The name of the dataset\n", + " data_df: pandas.Dataframe\n", + " A data frame that holds all metrics\n", + " plot_all:bool\n", + " If roc curves for all other datasets used for validation\n", + " should be plotted\n", + " ax:pyplot.axes\n", + " An axes object (used for subplotting)\n", + " label_xy:bool\n", + " If axes x,y should be annotated\n", + " \n", + " \"\"\"\n", + " \n", + " if ax is None:\n", + " ax = plt.gca()\n", + " \n", + " rc = ax.plot([0,1], [0,1], linestyle='--', color='k', label='Baseline')\n", + " \n", + " # Select all rows for a particular dataset used for the optimization\n", + " # and create a new dataframe\n", + " train_df = data_df.loc[data_df['TS'] == training_set]\n", + " # Grab the fpr and tpr arrays for that set\n", + " \n", + " fpr = train_df.loc[train_df['VS'] == training_set, 'fpr'].values\n", + " tpr = train_df.loc[train_df['VS'] == training_set, 'tpr'].values\n", + " # Get the auc_score value\n", + " auc_score = data_df.loc[(training_set, training_set), 'roc_auc_score']\n", + " # Plot the roc curve\n", + " ## fpr and tpr are returned as an array of arrays\n", + " ## with length == 1. Hence, we grab the first ellement which\n", + " ## is the array itself\n", + " rc = ax.plot(fpr[0], tpr[0], \n", + " label=\"{}, auc={:.2f}\".format(training_set, auc_score),\n", + " color = color_dict.get(training_set), \n", + " linewidth=3, \n", + " alpha=1)\n", + " # For plotting all other datasets roc curves\n", + " if plot_all is True:\n", + " for validation_set in train_df['VS'].values:\n", + " if validation_set == training_set:\n", + " pass\n", + " else:\n", + " # This is done on the subset train_df\n", + " fpr = train_df.loc[train_df['VS'] == validation_set, 'fpr'].values\n", + " tpr = train_df.loc[train_df['VS'] == validation_set, 'tpr'].values\n", + " auc_score = data_df.loc[(training_set, validation_set), 'roc_auc_score']\n", + " rc = ax.plot(fpr[0], tpr[0], \n", + " label=\"{}, auc={:.2f}\".format(validation_set, auc_score),\n", + " color = color_dict.get(validation_set),\n", + " linewidth=1,\n", + " alpha=0.7)\n", + " \n", + " rc = ax.legend(loc='lower right', fontsize='small')\n", + " \n", + " # When plotting multiple datasets in a subplots\n", + " if label_xy is True:\n", + " rc = ax.set_ylabel('True Positive Rate')\n", + " rc = ax.set_xlabel('False Positive Rate')\n", + " # Set the master title \n", + " rc = ax.set_title('ROC curves (dataset={})'.format(training_set))\n", + " return rc" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [], + "source": [ + "figures_dir = RF_dir / Path(\"figures\")\n", + "figures_dir.mkdir(exist_ok=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "outputs": [], + "source": [ + "Figure 1a." 
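The automatic selection a few cells above scores each negative set by counting how many summary statistics it "wins": the lowest value wins for `*_std` rows, the highest wins otherwise, and the set with the most points becomes the best model. A self-contained toy version of that vote, for clarity:

```python
import operator
import pandas as pd

# Rows are summary statistics, columns are candidate datasets (toy values).
stats = pd.DataFrame(
    {'N1': [0.90, 0.05, 0.88], 'N2': [0.92, 0.07, 0.91]},
    index=['accuracy_mean', 'accuracy_std', 'f1_score_mean'],
)

scores = dict.fromkeys(stats.columns, 0)
for stat, row in stats.iterrows():
    winner = row.idxmin() if stat.endswith('std') else row.idxmax()
    scores[winner] += 1

best = max(scores.items(), key=operator.itemgetter(1))[0]
# best == 'N2' here: it wins accuracy_mean and f1_score_mean, N1 wins accuracy_std.
```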
+ ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [], + "source": [ + "fig1_fp = figures_dir / Path(\"Figure_1a.svg\")\n", + "\n", + "fig1, ax = plt.subplots(figsize=(5.32, 5.11))\n", + "\n", + "plot_dataset_roc_curve(best_model, metrics_df, plot_all=True)\n", + "\n", + "fig1.savefig(fig1_fp, dpi=600)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "outputs": [], + "source": [ + "* Figure S1" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [], + "source": [ + "fig_s1_fp = figures_dir / Path(\"Figure_S1.svg\")\n", + "fig, axs = plt.subplots(4,3, figsize=(18, 20))\n", + "\n", + "for i, ts in enumerate(negative_sets):\n", + " plot_dataset_roc_curve(ts, metrics_df, plot_all=True, ax = axs.flat[i], label_xy=False)\n", + "# https://napsterinblue.github.io/notes/python/viz/subplots/\n", + "else:\n", + " [ax.set_visible(False) for ax in axs.flatten()[i+1:]]\n", + " \n", + "# Manually placing axes labels with trial and error\n", + "fig.text(0.5, 0.09, \n", + " 'False Positive Rate', \n", + " ha='center', \n", + " va='center', \n", + " fontsize=\"x-large\")\n", + "fig.text(0.09, 0.5, \n", + " 'True Positive Rate', \n", + " ha='center', \n", + " va='center', \n", + " rotation='vertical', \n", + " fontsize=\"x-large\")\n", + "\n", + "# Super title\n", + "fig.suptitle(\"ROC Curves for all datasets\", \n", + " x=0.5, \n", + " y=0.92, \n", + " fontsize = \"x-large\")\n", + "# Save it\n", + "fig.savefig(fig_s1_fp, dpi=600, bbox_inches='tight')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "outputs": [], + "source": [ + "* Figure S2" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [], + "source": [ + "fig_s2_fp = figures_dir / Path(\"Figure_S2.svg\")\n", + "\n", + "boxplot_metrics = ['accuracy', 'precision', 'recall', 'f1_score', 'roc_auc_score',]\n", + "fig, axs = plt.subplots(len(boxplot_metrics), 1,\n", + " figsize=(12, len(boxplot_metrics)*6),\n", + " )\n", + "for i, metric in enumerate(boxplot_metrics):\n", + " title_string = ' '.join(metric.split('_'))\n", + " \n", + " axs.flat[i].set_title(title_string.title(), fontsize=\"x-large\", )\n", + " \n", + " a = sns.boxplot(data=metrics_df, \n", + " x='TS', \n", + " y=metric, \n", + " palette=color_dict, \n", + " ax = axs.flat[i]\n", + " )\n", + " \n", + " a.set_xticklabels([lbl.get_text() for lbl in a.get_xticklabels()],\n", + "# rotation=60,\n", + "# horizontalalignment='right',\n", + " fontsize=\"large\"\n", + " )\n", + "\n", + " a.set_ylim([0.4, 1.])\n", + " \n", + " a.set_yticklabels(np.around(a.get_yticks(), 1),\n", + " fontsize=\"large\")\n", + " a.set_xlabel('')\n", + " a.set_ylabel('')\n", + " \n", + "fig.savefig(fig_s2_fp, dpi=600, bbox_inches='tight')\n", + "# Gives a \n", + "# .../python3.7/site-packages/ipykernel_launcher.py:28: UserWarning: \n", + "# FixedFormatter should only be used together with FixedLocator" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "outputs": [], + "source": [ + "## Features" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [], + "source": [ + "# Calculate basic stats for all features across all datasets\n", + "features_stats = {}\n", + "for i, negative in enumerate(negatives_fp_list):\n", + " dataset_name = negative.name.split('.')[0]\n", + " \n", + " set_df = features_df.loc[features_df[\"TS\"] == dataset_name, FEATURES].describe().T\n", + " features_stats[dataset_name] = {'jaccard_score_mean' : 
set_df.loc['jaccard_score', 'mean'],\n", + " 'jaccard_score_std' : set_df.loc['jaccard_score', 'std'],\n", + " 'jaccard_score_min' : set_df.loc['jaccard_score', 'min'],\n", + " 'jaccard_score_max' : set_df.loc['jaccard_score', 'max'],\n", + " 'same_score_mean' : set_df.loc['same_score', 'mean'],\n", + " 'same_score_std' : set_df.loc['same_score', 'std'],\n", + " 'same_score_min' : set_df.loc['same_score', 'min'],\n", + " 'same_score_max' : set_df.loc['same_score', 'max'],\n", + " 'inwards_score_mean' : set_df.loc['inwards_score', 'mean'],\n", + " 'inwards_score_std' : set_df.loc['inwards_score', 'std'],\n", + " 'inwards_score_min' : set_df.loc['inwards_score', 'min'],\n", + " 'inwards_score_max' : set_df.loc['inwards_score', 'max'],\n", + " 'outwards_score_mean': set_df.loc['outwards_score', 'mean'],\n", + " 'outwards_score_std': set_df.loc['outwards_score', 'std'],\n", + " 'outwards_score_min' : set_df.loc['outwards_score', 'min'],\n", + " 'outwards_score_max' : set_df.loc['outwards_score', 'max'],\n", + " 'avg_distance_mean': set_df.loc['avg_distance', 'mean'],\n", + " 'avg_distance_std': set_df.loc['avg_distance', 'std'],\n", + " 'avg_distance_min' : set_df.loc['avg_distance', 'min'],\n", + " 'avg_distance_max' : set_df.loc['avg_distance', 'max'],\n", + " 'mean_ani_mean' : set_df.loc['mean_ani', 'mean'],\n", + " 'mean_ani_std' : set_df.loc['mean_ani', 'std'],\n", + " 'mean_ani_min' : set_df.loc['mean_ani', 'min'],\n", + " 'mean_ani_max' : set_df.loc['mean_ani', 'max'],\n", + " 'mean_aai_mean' : set_df.loc['mean_aai', 'mean'],\n", + " 'mean_aai_std' : set_df.loc['mean_aai', 'std'],\n", + " 'mean_aai_min' : set_df.loc['mean_aai', 'min'],\n", + " 'mean_aai_max' : set_df.loc['mean_aai', 'max'], \n", + " }\n", + " \n", + "features_stats_df = pd.DataFrame.from_dict(features_stats, orient='index')\n", + "\n", + "# Store them\n", + "features_stats_fp = RF_dir / Path(\"features_stats.tsv\")\n", + "features_stats_df.T.to_csv(features_stats_fp, sep='\\t')\n" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": {}, + "outputs": [], + "source": [ + "feat_imp = features_df.loc[features_df['TS'] == best_model, FEATURES]\n", + "\n", + "fig, ax = plt.subplots(figsize=(5.32,5.11))\n", + "\n", + "features_color_palette = sns.color_palette(\"YlOrBr\", len(FEATURES))\n", + "\n", + "features_color_dict = dict(zip(FEATURES, features_color_palette))\n", + "\n", + "# fimp : For feature importances\n", + "fimp = sns.barplot(data=feat_imp,\n", + " orient='h',\n", + " ci=\"sd\",\n", + " order=['avg_distance',\n", + " 'jaccard_score', \n", + " 'mean_aai', \n", + " 'mean_ani', \n", + " 'inwards_score', \n", + " 'outwards_score', \n", + " 'same_score'],\n", + " palette = features_color_dict,\n", + " capsize=0.1,\n", + " )\n", + "# Trial and error\n", + "fimp.set_xlim([0., 0.45])\n", + " \n", + "fimp.set_yticklabels([FEATURES_MAP[lbl.get_text()] for lbl in fimp.get_yticklabels()],\n", + " fontsize='medium')\n", + "fimp.set_xlabel('Gini Importance', fontsize='large')\n", + "fimp.set_ylabel('')\n", + "fimp.set_title('Feature Importances', fontsize=\"x-large\")\n", + "\n", + "fig.tight_layout()\n", + "\n", + "# Save the fig\n", + "feat_importances_fp = figures_dir / Path(\"Figure_1b.svg\")\n", + "fig.savefig(feat_importances_fp, dpi= 600)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": 
"text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/workflow/rules/construct_datasets.smk b/workflow/rules/construct_datasets.smk new file mode 100644 index 0000000000000000000000000000000000000000..a71612091568a5ac45810d3420927e6a25b52576 --- /dev/null +++ b/workflow/rules/construct_datasets.smk @@ -0,0 +1,231 @@ +rule filter_intact: + input: + intact_raw = ancient("data/interactions/intact.txt"), + tax_db = ancient("data/taxonomy_db/taxa.sqlite") + output: + intact_phages = "results/interaction_datasets/01_filter_intact/intact_phages.txt" + conda: + "../envs/pvogs.yml" + log: + "results/logs/construct_datasets/filter_intact.log" + shell: + "python workflow/scripts/filter_intact.py " + "--tax-db {input.tax_db} " + "--phages-only " + "-i {input.intact_raw} " + "-o {output.intact_phages} &>{log}" + +rule summarize_intact: + input: + intact_phages = rules.filter_intact.output.intact_phages, + tax_db = ancient("data/taxonomy_db/taxa.sqlite") + output: + metadata_tsv = "results/interaction_datasets/02_summarize_intact/metadata.tsv" + params: + summary_dir = "results/interaction_datasets/02_summarize_intact" + conda: + "../envs/pvogs.yml" + log: + "results/logs/construct_datasets/summarize_intact.log" + shell: + "python workflow/scripts/summarize_intact.py " + "--tax-db {input.tax_db} " + "-i {input.intact_phages} " + "-o {params.summary_dir} &>{log}" + +rule get_uniprot_ids: + input: + metadata_tsv = rules.summarize_intact.output.metadata_tsv + output: + uniprot_ids_txt = "results/interaction_datasets/03_uniprot/uniprot_ids.txt", + uniprot_only = "results/interaction_datasets/03_uniprot/uniprot.phages_only.tsv" + conda: + "../envs/pvogs.yml" + log: + "results/logs/construct_datasets/get_uniprot_ids.log" + shell: + "python workflow/scripts/get_uniprot_ids.py " + "-i {input.metadata_tsv} " + "-l {output.uniprot_ids_txt} " + "-o {output.uniprot_only} &>{log}" + +rule download_uniprots: + input: + uniprot_ids_txt = rules.get_uniprot_ids.output.uniprot_ids_txt + output: + swissprot_txt = "results/interaction_datasets/03_uniprot/uniprot.sprot.txt" + log: + "results/logs/construct_datasets/query_uniprot.log" + conda: + "../envs/pvogs.yml" + shell: + "python workflow/scripts/query_uniprot.py " + "--filter-list " + "-i {input.uniprot_ids_txt} " + "-o {output.swissprot_txt} &>{log}" + +rule process_uniprot: + input: + uniprot_ids_txt = rules.get_uniprot_ids.output.uniprot_ids_txt, + uniprot_only = rules.get_uniprot_ids.output.uniprot_only, + swissprot_txt = rules.download_uniprots.output.swissprot_txt + output: + proteins_fasta = "results/interaction_datasets/04_process_uniprot/proteins.faa", + skipped = "results/interaction_datasets/04_process_uniprot/skipped.txt", + duplicates = "results/interaction_datasets/04_process_uniprot/duplicates.tsv", + interactions_filtered = "results/interaction_datasets/04_process_uniprot/interactions_filtered.tsv", + ncbi_interactions = "results/interaction_datasets/04_process_uniprot/ncbi_interactions.tsv", + genomes_accessions = "results/interaction_datasets/04_process_uniprot/genome_accessions.txt", + ncbi_to_uniprot = "results/interaction_datasets/04_process_uniprot/ncbi2uniprot.mapping.txt", + uniprot_to_ncbi = "results/interaction_datasets/04_process_uniprot/uniprot2ncbi.mapping.txt" + conda: + "../envs/pvogs.yml" + log: + "results/logs/construct_datasets/process_uniprot.log" + params: + output_dir = 
"results/interaction_datasets/04_process_uniprot" + shell: + "python workflow/scripts/process_uniprot.py " + "-i {input.uniprot_only} " + "-l {input.uniprot_ids_txt} " + "-s {input.swissprot_txt} " + "-p {params.output_dir} &>{log}" + +rule download_genomes: + input: + genomes_accessions = rules.process_uniprot.output.genomes_accessions + output: + genomes_gb = "results/interaction_datasets/05_interaction_datasets/genomes.gb" + log: + "results/logs/construct_datasets/download_genomes.log" + params: + email = config.get('email') + conda: + "../envs/pvogs.yml" + shell: + "python workflow/scripts/download_genomes.py " + "-i {input.genomes_accessions} " + "-o {output.genomes_gb} " + "-e {params.email} &>{log}" + +rule extract_proteins_from_genomes: + input: + genomes_gb = rules.download_genomes.output.genomes_gb + output: + all_proteins_faa = "results/interaction_datasets/05_genomes/all_proteins.faa" + conda: + "../envs/pvogs.yml" + log: + "results/logs/interaction_datasets/05_genomes/.extract_proteins.log" + shell: + "python workflow/scripts/extract_proteins_from_gb.py " + "-i {input.genomes_gb} " + "-o {output.all_proteins_faa} &>{log}" + +## Copy the proteins file and create a 2-column tsv in a positives directory +## The same naming structure helps with expansion for the rules. +rule copy_positives_proteins: + input: + proteins_faa = rules.process_uniprot.output.proteins_fasta + output: + positives_faa = "results/interaction_datasets/positives/positives.proteins.faa" + log: + "results/logs/construct_datasets/copy_positives_scores.log" + shell: + "cp {input.proteins_faa} {output.positives_faa} 2>{log}" + + +rule ncbi_positives: + input: + ncbi_interactions = rules.process_uniprot.output.ncbi_interactions, + genomes_gb = rules.download_genomes.output.genomes_gb + output: + ncbi_positives_tsv = "results/interaction_datasets/positives/positives.interactions.tsv" + log: + "results/logs/construct_datasets/ncbi_positives.log" + shell: + "cut -f3,4 {input.ncbi_interactions} > {output.ncbi_positives_tsv} 2>{log}" + + +rule make_negatives: + input: + ncbi_positives_tsv = rules.ncbi_positives.output.ncbi_positives_tsv, + all_proteins_faa = rules.extract_proteins_from_genomes.output.all_proteins_faa, + genomes_gb = rules.download_genomes.output.genomes_gb + + output: + expand(["results/interaction_datasets/N{I}/N{I}.interactions.tsv", + "results/interaction_datasets/N{I}/N{I}.proteins.faa"], + I = [i+1 for i in range(0, NEGATIVES)]) + conda: + "../envs/pvogs.yml" + params: + no_negatives = NEGATIVES + log: + "results/logs/construct_datasets/make_negatives.log" + shell: + "for i in `seq 1 {params.no_negatives}`;do " + " python workflow/scripts/make_protein_combos.py " + " -i {input.genomes_gb} " + " -a {input.all_proteins_faa} " + " -o results/interaction_datasets/N${{i}}/N${{i}}.proteins.faa " + " -x results/interaction_datasets/N${{i}}/N${{i}}.interactions.tsv " + " --exclude {input.ncbi_positives_tsv} " + " --sample-size 2 " + " --random-seed ${{i}}; " + "done &>{log}" + +rule hmmsearch: + input: + proteins_fasta = "results/interaction_datasets/{dataset}/{dataset}.proteins.faa", + all_pvogs_profiles = "data/pvogs/all.hmm" + output: + hmm_out_txt = "results/interaction_datasets/06_map_proteins_to_pvogs/{dataset}/{dataset}.hmmout.txt", + hmm_tblout_tsv = "results/interaction_datasets/06_map_proteins_to_pvogs/{dataset}/{dataset}.hmmtblout.tsv" + log: + "results/logs/construct_datasets/{dataset}/hmmsearch.log" + threads: 8 + conda: + "../envs/pvogs.yml" + shell: + "hmmsearch --cpu {threads} " + "-o 
{output.hmm_out_txt} " + "--tblout {output.hmm_tblout_tsv} " + "{input.all_pvogs_profiles} " + "{input.proteins_fasta}" + +rule refseqs_to_pvogs: + input: + interactions_tsv = "results/interaction_datasets/{dataset}/{dataset}.interactions.tsv", + proteins_faa = "results/interaction_datasets/{dataset}/{dataset}.proteins.faa", + hmm_tblout = "results/interaction_datasets/06_map_proteins_to_pvogs/{dataset}/{dataset}.hmmtblout.tsv" + output: + pvogs_interactions = "results/interaction_datasets/{dataset}/{dataset}.pvogs_interactions.tsv" + conda: + "../envs/pvogs.yml" + log: + "results/logs/construct_datasets/{dataset}/refseqs_to_pvogs.log" + shell: + "python workflow/scripts/refseqs_to_pvogs.py " + "-i {input.interactions_tsv} " + "-f {input.proteins_faa} " + "-hmm {input.hmm_tblout} " + "-o {output.pvogs_interactions} " + "&> {log}" + +rule create_features_tables: + input: + filtered_master_tsv = rules.filter_scores_table.output.filtered_master_tsv, + interactions_tsv = "results/interaction_datasets/{dataset}/{dataset}.pvogs_interactions.tsv" + output: + features_tsv = "results/interaction_datasets/{dataset}/{dataset}.features.tsv" + log: + "results/logs/construct_datasets/{dataset}/create_features_tables.log" + conda: + "../envs/pvogs.yml" + shell: + "python workflow/scripts/subset_scores.py " + "-s {input.filtered_master_tsv} " + "-i {input.interactions_tsv} " + "-o {output.features_tsv}" + diff --git a/workflow/rules/envs/zenodo.yml b/workflow/rules/envs/zenodo.yml new file mode 100644 index 0000000000000000000000000000000000000000..57a1cad0cf9c614115bf996a6eb1cb3789e384b8 --- /dev/null +++ b/workflow/rules/envs/zenodo.yml @@ -0,0 +1,8 @@ +channels: + - conda-forge +dependencies: + - python>=3.6 + - pip + - pip: + - zenodo_get==1.3.2 + diff --git a/workflow/rules/get_data.smk b/workflow/rules/get_data.smk new file mode 100644 index 0000000000000000000000000000000000000000..bfc779dcf13956de80edc6804182fcbd6efcdb86 --- /dev/null +++ b/workflow/rules/get_data.smk @@ -0,0 +1,28 @@ +rule download_archive: + output: + tar_gz = "pvogs_function.data.tar.gz" + log: + "results/logs/get_data/download_archive.log" + conda: + "envs/zenodo.yml" + params: + sandbox = "--sandbox" if config["zenodo"]["use_sandbox"] is True else "", + zenodo_id = config["zenodo"]["doi"] + shell: + "zenodo_get {params.sandbox} --record={params.zenodo_id}" + +rule extract_data: + input: + tar_gz = rules.download_archive.output.tar_gz + output: + genomes_fasta = "data/genomes/phages_refseq.fasta", + intact_txt= "data/interactions/intact.txt", + pvogs_profiles = "data/pvogs/all.hmm", + pvogs_annotations = "data/pvogs/VOGProteinTable.txt", + taxonomy_db = "data/taxonomy_db/taxa.sqlite", + taxonomy_pkl = "data/taxonomy_db/taxa.sqlite.traverse.pkl" + log: + "results/logs/get_data/extract_data.log" + shell: + "tar -xzvf {input.tar_gz} &>{log}" + diff --git a/workflow/rules/pre_process.smk b/workflow/rules/pre_process.smk new file mode 100644 index 0000000000000000000000000000000000000000..53cc8dc49d8933e4ddda4a3777ebfd267da4f5a7 --- /dev/null +++ b/workflow/rules/pre_process.smk @@ -0,0 +1,234 @@ +## pre_process +## PREPARE NECESSARY FILES FOR FEATURE CALCULATIONS + +rule translate_genomes: + input: + refseq_genomes = ancient("data/genomes/phages_refseq.fasta") + output: + transeq_genomes = "results/pre_process/transeq/transeq.genomes.fasta" + log: "results/logs/pre_process/transeq/transeq_genomes.log" + conda: + "../envs/pvogs.yml" + shell: + "transeq -frame 6 " + "-table 11 -clean " + "-sequence {input.refseq_genomes} " + 
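The hmmsearch rule above writes a `--tblout` table that the downstream `refseqs_to_pvogs.py` script consumes. As a reference only (that script is not shown here), Biopython can read this table with the `hmmer3-tab` parser; the path below is illustrative, assuming the N1 dataset.

```python
from Bio import SearchIO

# Illustrative path for the N1 dataset; any {dataset} follows the same layout.
tblout = "results/interaction_datasets/06_map_proteins_to_pvogs/N1/N1.hmmtblout.tsv"

# Each query result is one pVOG profile; each hit is a protein it matched.
for qresult in SearchIO.parse(tblout, "hmmer3-tab"):
    for hit in qresult.hits:
        print(qresult.id, hit.id, hit.evalue, hit.bitscore)
```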
"-outseq {output.transeq_genomes} 2>{log}" + +rule hmmsearch_transeq: + input: + transeq_genomes = rules.translate_genomes.output.transeq_genomes, + pvogs_all_profiles = ancient("data/pvogs/all.hmm") + output: + hmmout_txt = "results/pre_process/hmmsearch/transeq.hmmout.txt", + hmmtblout_tsv = "results/pre_process/hmmsearch/transeq.hmmtblout.tsv" + log: + "results/logs/pre_process/hmmsearch/hmmsearch_transeq.log" + conda: + "../envs/pvogs.yml" + threads: 10 + shell: + "hmmsearch --cpu {threads} " + "-o {output.hmmout_txt} " + "--tblout {output.hmmtblout_tsv} " + "{input.pvogs_all_profiles} " + "{input.transeq_genomes} " + "&>{log}" + + +rule split_genomes: + input: + refseq_genomes = ancient("data/genomes/phages_refseq.fasta") + output: + reflist_txt = "results/pre_process/reflist.txt" + log: + "results/logs/pre_process/split_genomes.log" + conda: + "../envs/pvogs.yml" + params: + genomes_dir = "results/pre_process/all_genomes" + shell: + "mkdir -p {params.genomes_dir} && " + "python workflow/scripts/split_multifasta.py " + "--write-reflist " + "-i {input.refseq_genomes} " + "-o {params.genomes_dir} 2>{log}" + + +## FASTANI +rule fastani: + input: + reflist_txt = rules.split_genomes.output.reflist_txt + output: + fastani_raw = "results/pre_process/fastani/fastani.out", + fastani_matrix = "results/pre_process/fastani/fastani.out.matrix" + log: + "results/logs/pre_process/fastani/fastani.log" + conda: + "../envs/pvogs.yml" + threads: 8 + params: + fragLen = 300, + minFraction = 0.1 + shell: + "fastANI -t {threads} " + "--ql {input.reflist_txt} " + "--rl {input.reflist_txt} " + "--fragLen {params.fragLen} " + "--minFraction {params.minFraction} " + "-o {output.fastani_raw} " + "--matrix 2>{log}" # Output the matrix + +rule fastani_matrix_to_square: + input: + fastani_matrix = rules.fastani.output.fastani_matrix + output: + fastani_square_mat = "results/pre_process/fastani/fastani.square.matrix" + conda: + "../envs/pvogs.yml" + log: + "results/logs/pre_process/fastani_matrix_to_square.log" + shell: + "python workflow/scripts/fastani_mat_to_square.py " + "-i {input.fastani_matrix} " + "-o {output.fastani_square_mat} " + "--process-names &>{log}" # Chop full paths to last node, which is the accession number + + + +## COMPAREM +rule comparem_call_genes: + input: + rules.split_genomes.output.reflist_txt + output: + done_file = touch("results/pre_process/comparem/aai_wf/genes/.done.txt") + conda: + "../envs/pvogs.yml" + log: + # This is produced by comparem by default + "results/pre_process/comparem/aai_wf/genes/comparem.log" + params: + genes_dir = "results/pre_process/comparem/aai_wf/genes", + genomes_dir = "results/pre_process/all_genomes" + threads: 10 + shell: + "comparem call_genes -c {threads} --silent -x fasta " + "{params.genomes_dir} " + "{params.genes_dir}" + +rule remove_empty_files: + input: + rules.comparem_call_genes.output.done_file + output: + info_file = "results/pre_process/comparem/aai_wf/genes/genomes_skipped.txt" + params: + genes_dir = "results/pre_process/comparem/aai_wf/genes" + log: + "results/logs/pre_process/remove_empty_files.log" + shell: + "python workflow/scripts/remove_empty_files.py " + "-i {params.genes_dir} " + "-o {output.info_file}" + +rule comparem_similarity: + input: + info_file = rules.remove_empty_files.output.info_file + output: + hits_sorted_tsv = "results/pre_process/comparem/aai_wf/similarity/hits_sorted.tsv", + query_genes_dmnd = "results/pre_process/comparem/aai_wf/similarity/query_genes.dmnd", + query_genes_faa = 
"results/pre_process/comparem/aai_wf/similarity/query_genes.faa" + params: + similarity_dir = "results/pre_process/comparem/aai_wf/similarity", + genes_dir = "results/pre_process/comparem/aai_wf/genes" + log: + # Produced by comparem by default + "results/pre_process/comparem/aai_wf/similarity/comparem.log" + threads: 10 + conda: + "../envs/pvogs.yml" + shell: + "comparem similarity " + "-c {threads} -x faa " + "--silent " + "{params.genes_dir} {params.genes_dir} {params.similarity_dir}" + + +rule comparem_aai: + input: + query_genes_faa = rules.comparem_similarity.output.query_genes_faa, + hits_sorted_tsv = rules.comparem_similarity.output.hits_sorted_tsv + output: + aai_summary_tsv = "results/pre_process/comparem/aai_wf/aai/aai_summary.tsv" + conda: + "../envs/pvogs.yml" + log: + # Produced by comparem by default + "results/pre_process/comparem/aai_wf/aai/comparem.log" + params: + aai_dir = "results/pre_process/comparem/aai_wf/aai" + shell: + "comparem aai --silent " + "{input.query_genes_faa} {input.hits_sorted_tsv} " + "{params.aai_dir}" + +rule process_comparem_matrix: + input: + aai_summary_tsv = rules.comparem_aai.output.aai_summary_tsv + output: + aai_summary_square = "results/pre_process/comparem/aai_wf/aai/aai_summary.square.tsv" + conda: + "../envs/pvogs.yml" + log: + "results/logs/pre_process/proecess_comparem_matrix.log" + shell: + "python workflow/scripts/process_comparem.py " + "-i {input.aai_summary_tsv} " + "-o {output.aai_summary_square} &>{log}" + +rule calculate_all_scores: + input: + ani_square_matrix = rules.fastani_matrix_to_square.output.fastani_square_mat, + aai_square_matrix = rules.process_comparem_matrix.output.aai_summary_square, + phages_genomes_fasta = ancient("data/genomes/phages_refseq.fasta"), + pvogs_all_profiles = ancient("data/pvogs/all.hmm"), + hmmout_txt = rules.hmmsearch_transeq.output.hmmout_txt + output: + master_table = "results/scores.tsv" + conda: "../envs/pvogs.yml" + log: + "results/logs/pre_process/calculate_scores.log" + shell: + "python workflow/scripts/calculate_all_scores.py " + "--profiles-file {input.pvogs_all_profiles} " + "--genomes {input.phages_genomes_fasta} " + "--input-hmm {input.hmmout_txt} " + "--ani-matrix {input.ani_square_matrix} " + "--aai-matrix {input.aai_square_matrix} " + "-o {output.master_table} &>{log}" + +rule filter_scores_table: + input: + master_table = rules.calculate_all_scores.output.master_table + output: + filtered_master_tsv = "results/filtered_scores.tsv" + log: + "results/pre_process/filter_scores_table.log" + shell: + """awk '{{if ( $10 != "1000000") print $0}}' {input.master_table} > {output.filtered_master_tsv}""" + +rule parse_annotations: + input: + raw_annotations_fp = ancient("data/pvogs/VOGProteinTable.txt") + output: + processed_annotations_fp = "results/annotations.tsv" + conda: + "../envs/pvogs.yml" + log: + "results/pre_process/process_annotations.log" + shell: + "python workflow/scripts/process_annotations.py " + "-i {input.raw_annotations_fp} " + "-o {output.processed_annotations_fp} " + "&>{log}" + + diff --git a/workflow/scripts/calculate_all_scores.py b/workflow/scripts/calculate_all_scores.py new file mode 100755 index 0000000000000000000000000000000000000000..d40a85ba993162c3b41420f2a8dd3ac2051a34a0 --- /dev/null +++ b/workflow/scripts/calculate_all_scores.py @@ -0,0 +1,340 @@ +#!/usr/bin/env python + +import argparse +from pathlib import Path +from Bio import SeqIO, SearchIO +import pandas as pd +import numpy as np +from itertools import combinations + +def parse_args(): + parser = 
argparse.ArgumentParser(description='Calculate scores for a set of pVOG interactions, ' + 'provided as a tsv file') + + optionalArgs = parser._action_groups.pop() + + requiredArgs = parser.add_argument_group("required arguments") + requiredArgs.add_argument('-p', '--profiles-file', + dest='profiles_file', + type=lambda p: Path(p).resolve(strict=True), + required=True, + help="The all.hmm file from pVOGs database" + ) + requiredArgs.add_argument('-g', '--genomes', + dest='genomes_fasta', + type=lambda p: Path(p).resolve(strict=True), + required=True, + help="A fasta file with the genomes") + requiredArgs.add_argument('-hmm', '--input-hmm', + dest='hmmer_in', + type=lambda p: Path(p).resolve(strict=True), + required=True, + help="The regular output file from hmmsearch all pvogs against" + "the translated genomes" + ) + requiredArgs.add_argument('-ani_f', '--ani-matrix', + dest='ani_matrix', + type=lambda p: Path(p).resolve(strict=True), + required=True, + help="The square matrix resulting from fastANI with all genomes" + ) + + requiredArgs.add_argument('-aai_f', '--aai-matrix', + dest='aai_matrix', + type=lambda p: Path(p).resolve(strict=True), + required=True, + help="The aai square matrix from compareM on all genomes" + ) + requiredArgs.add_argument('-o', '--output-file', + dest='outfile', + required=True, + type=lambda p: Path(p).resolve(), + help="File path to write the results in") + + parser._action_groups.append(optionalArgs) + + return parser.parse_args() + +def get_seq_sizes(seq_fasta): + """ + Create a dict that holds length sizes for all records + in a fasta file. + """ + seq_sizes = {} + with open(seq_fasta, 'r') as fin: + for record in SeqIO.parse(fin, "fasta"): + seq_sizes[record.id] = len(record.seq) + return seq_sizes + +def get_maximum_index(values_list): + """ + Get the index of the maximum value in a list + + param: list: values_list A list of values + return: int: The index of the maximum value + """ + if len(values_list) == 1: + max_index = 0 + else: + max_index = values_list.index(max(values_list)) + return max_index + +def translate_to_genomic_coords(start, end, frame, genome_size): + """ + Translate the coordinates of the protein from transeq to genomic + coordinates. + + Strand is used here as orientation and not really [non-]conding. + If the frame is 1,2 or 3 (-->) I call this (+) strand. 
+ Else, it is the (-) strand + + param: int: start The starting coordinate on the protein + param: int: end The ending coordinate on the protein + param: int: frame The frame on which it is found (1-6) + param: int: genome_size The size of the genomes + + return: tuple: (genomic start, genomic end, strand) + """ + nucleic_start = start * 3 + nucleic_end = end * 3 + if frame == 1: + genomic_start = nucleic_start - 2 + genomic_end = nucleic_end - 2 + if frame == 2: + genomic_start = nucleic_start - 1 + genomic_end = nucleic_end - 1 + if frame == 3: + genomic_start = nucleic_start + genomic_end = nucleic_end + if frame == 4: + genomic_start = genome_size - (nucleic_start - 2) + genomic_end = genome_size - (nucleic_end - 2) + if frame == 5: + genomic_start = genome_size - (nucleic_start - 1) + genomic_end = genome_size - (nucleic_end -1) + if frame == 6: + genomic_start = genome_size - nucleic_start + genomic_end = genome_size - nucleic_end + + if frame in [1,2,3]: + strand = '+' + elif frame in [4,5,6]: + strand = '-' + else: + raise ValueError("frame should be one of 1,2,3,4,5,6") + + return genomic_start, genomic_end, strand + +def collect_hmmsearch_info(hmmer_in, genome_sizes): + """ + Parse the information in a dictionary of the form + { pvog_id: + { genome_id : + { frame : [(genomic_start, genomic_end, accuracy, pvog_coverage), ...], + ... , } + } + } + """ + + hmm_results = {} + with open(hmmer_in, 'r') as fin: + for record in SearchIO.parse(fin, "hmmer3-text"): + # Initialize an empty dictionary for the pvog + hmm_results[record.id] = {} + for hit in record.hits: + if hit.is_included: + # From the transeq output, accessions are suffixed + # with _1, _2, _3, ..., depending on the strand + genome_id = '_'.join(hit.id.split('_')[0:2]) + frame = int(hit.id.split('_')[2]) + hsps_included = [hsp.is_included for hsp in hit.hsps] + # For multiple hsps that pass the threshold + if any(hsps_included): + # Get their score + scores = [hsp.bitscore for hsp in hit.hsps] + # and select the best one + max_i = get_maximum_index(scores) + best_hsp = hit.hsps[max_i] + + # Translate back to genomic coordinates + genomic_coords = translate_to_genomic_coords(best_hsp.env_start, + best_hsp.env_end, + frame, + genome_sizes.get(genome_id)) + + span = (best_hsp.env_end - best_hsp.env_start) / record.seq_len + + if genome_id not in hmm_results[record.id]: + hmm_results[record.id][genome_id] = (frame, genomic_coords[0], + genomic_coords[1], + best_hsp.acc_avg, + span, + genomic_coords[2]) + return hmm_results + + +def get_mean_from_df(subset_list, df): + df_subset = df.loc[subset_list, subset_list] + return df_subset.values.mean() + + +def get_pvogs_ids(profiles_file): + all_pvogs = [] + with open(profiles_file, 'r') as fin: + for line in fin: + if line.startswith('NAME'): + pvog = line.split()[1] + all_pvogs.append(pvog) + return all_pvogs + +def get_shortest_distance(startA, endA, startB, endB): + start_to_start = abs(startA - startB) + start_to_end = abs(startA - endB) + end_to_start = abs(endA - startB) + end_to_end = abs(endA - endB) + all_distances = [start_to_start, start_to_end, end_to_start, end_to_end] + return min(all_distances) + +def calculate_scores(interaction_tuple, hmm_results, ani_df, aai_df): + same_score = 0 + inwards_score = 0 + outwards_score = 0 + avg_distance = 1000000 # 1000000 for pairs that are never on the same genome + js = 0 + # include the number of genomes for each participating pvog + mean_ani = 0 + mean_aai = 0 + + pvogA = interaction_tuple[0] + pvogB = interaction_tuple[1] + + 
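+    # All distance/orientation features below are defined only on genomes where both
+    # pVOGs have an accepted hmmsearch hit. The two genome sets are intersected
+    # first; if the intersection is empty, the defaults initialized above
+    # (zero scores and the sentinel avg_distance of 1000000) are returned as-is.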
genomesA = set((hmm_results[pvogA].keys())) + genomesB = set(hmm_results[pvogB].keys()) + + common_genomes = genomesA.intersection(genomesB) + if len(common_genomes) > 0: + all_genomes = genomesA.union(genomesB) + + # Jaccard score + js = len(common_genomes) / len(all_genomes) + + # Mean ANI + mean_ani = get_mean_from_df(common_genomes, ani_df) + + # Mean AAI + mean_aai = get_mean_from_df(common_genomes, aai_df) + + # Distances + sum_of_distances = 0 + + for genome in common_genomes: + hitA = hmm_results[pvogA][genome] + hitB = hmm_results[pvogB][genome] + + # Get the proper starts for distance calculation + if hitA[-1] == '+': + startA = hitA[1] + endA = hitA[2] + else: + startA = hitA[2] + endA = hitA[1] + + if hitB[-1] == '+': + startB = hitB[1] + endB = hitB[2] + else: + startB = hitB[2] + endB = hitB[1] + ## 1. Start to start + #sum_of_distances += abs(startA - startB) + + ## 2. Shortest distance + sum_of_distances += get_shortest_distance(startA, endA, startB, endB) + + # If they have the same orientation + # Regardless of '+' or '-' + if hitA[-1] == hitB[-1]: + same_score += 1 + + if hitA[-1] != hitB[-1]: + dstarts = abs(startA - startB) + dends = abs(endA - endB) + if dstarts >= dends: + inwards_score += 1 + else: + outwards_score += 1 + + same_score = same_score / len(common_genomes) + inwards_score = inwards_score / len(common_genomes) + outwards_score = outwards_score / len(common_genomes) + avg_distance = sum_of_distances / len(common_genomes) + + return len(genomesA), len(genomesB), len(common_genomes), js, same_score, inwards_score, outwards_score, avg_distance, mean_ani, mean_aai + +#def get_ani_for_genomes(genomes, ani_df): +# """ +# Calculate mean ani for the genomes list from the +# ani df +# """ +# genomes_df = ani_df.loc[genomes, genomes] +# return genomes_df.values.mean() + +#def get_aai_for_genomes(genomes, aai_df): +# genomes_df = aai_df.loc[genomes, genomes] +# return genomes_df.values.mean() + +def main(): + # Read in the arguments + args = parse_args() + print("Loading data...") + + # Store the genome sizes + genome_sizes = get_seq_sizes(args.genomes_fasta) + print("Loaded sequence size info for {} input sequences".format(len(genome_sizes))) + + print("Reading hmmsearch information...") + # Get the hmmsearch results info for all pvogs in + hmm_results = collect_hmmsearch_info(args.hmmer_in, genome_sizes) + print("Done!") + + print("Reading ANI matrix...") + # Read in the matrices + ani_df = pd.read_csv(args.ani_matrix, index_col=0, header=0, sep="\t") + print("Done!") + + print("Reading AAI matrix...") + aai_df = pd.read_csv(args.aai_matrix, index_col=0, header=0, sep="\t") + print("Done!") + + # Create a list that holds all pvog ids + all_pvogs = get_pvogs_ids(args.profiles_file) + + # Create all pairs of pvogs + all_combos = list(combinations(all_pvogs, 2)) + print("Created {} possible combinations".format(len(all_combos))) + + # Calculate scores + print("Caclulating and writing to file...") + + counter = 0 + with open(args.outfile, 'w') as fout: + fout.write("{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\n".format('pvog1', 'pvog2', + 'genomes1', 'genomes2', + 'overlap_genomes', 'jaccard_score', + 'same_score', 'inwards_score', 'outwards_score', + 'avg_distance', 'mean_ani', 'mean_aai')) + for combo in all_combos: + int_string = '{}\t{}\t'.format(combo[0], combo[1]) + scores = calculate_scores(combo, hmm_results, ani_df, aai_df) + scores_string = '\t'.join(map(str, scores)) + fout.write(int_string + scores_string + '\n') + counter += 1 + if counter % 1000000 == 
0: + print("{}/{} processed".format(counter, len(all_combos))) + + print("Scores are written to {}".format(str(args.outfile))) + + +if __name__ == '__main__': + main() + diff --git a/workflow/scripts/download_genomes.py b/workflow/scripts/download_genomes.py new file mode 100755 index 0000000000000000000000000000000000000000..5d393d5b8ae7a4b2fd7d531c508103269e951990 --- /dev/null +++ b/workflow/scripts/download_genomes.py @@ -0,0 +1,162 @@ +#!/usr/bin/env python + +import argparse +from Bio import SeqIO, Entrez +from Bio.Seq import Seq +from Bio.SeqRecord import SeqRecord +from Bio.Alphabet import IUPAC +import time +from math import ceil + +parser = argparse.ArgumentParser(description='Download a list of ncbi accessions to the output file') +optionalArgs = parser._action_groups.pop() + +requiredArgs = parser.add_argument_group("required arguments") +requiredArgs.add_argument('-i', '--input-list', + dest='input_list', + required=True, + help="A txt file containing accessions to get from genbank") +requiredArgs.add_argument('-o', '--output-file', + dest='output_file', + type=str, + required=True, + help="The output file to write sequences to") +requiredArgs.add_argument('-e', '--e-mail', + help='E-mail address to be used with Bio.Entrez.email. Required by NCBI to notify if something is ' + 'off or if you overload their servers', + dest='email', + type=str, + required=True + ) +optionalArgs.add_argument('--output-fmt', + required=False, + dest='output_fmt', + default='gb', + type=str, + help="Store the results in this file format [default='gb' (genbank)]") + +parser._action_groups.append(optionalArgs) + +def sequence_info_to_dic(sequence_info_file): + """ + Collapse protein ids per genome + Args: + sequence_info_file (str): A tsv file containing 3 columns uniprot_id, genome_id, protein_id, + with a header. + Return: + sequence_info (dict): A dictionary of the form {genome_id: [(protein_id_1, uniprot_id_1), + ... ], + ... 
} + """ + sequence_info = {} + with open(sequence_info_file, 'r') as f: + # Skip header + next(f) + for line in f: + fields = [field.strip() for field in line.split('\t')] + uniprot_id = fields[0] + genome_id = fields[1] + protein_id = fields[2] + if genome_id not in sequence_info: + sequence_info[genome_id] = [(uniprot_id, protein_id,)] + else: + sequence_info[genome_id].append((uniprot_id, protein_id)) + return sequence_info + +def txt_file_to_list(genomes_txt): + with open(genomes_txt, 'r') as fin: + genomes_list = [line.strip() for line in fin] + return genomes_list + + +def download_sequences(genomes_list, genomes_file, email_address, output_fmt="gb", batch_size = 100): + + # Required by Bio.Entrez + Entrez.email = email_address + + # Some progress tracking + total_batches = ceil(len(genomes_list) / batch_size) + batch_no = 0 + + with open(genomes_file, 'w') as fout: + for i in range(0, len(genomes_list), 100): + batch_no += 1 + batch = genomes_list[i:i + 100] + print('Downloading batch {}/{}'.format(batch_no, total_batches)) + handle = Entrez.efetch(db="nuccore", id=batch, rettype=output_fmt, retmode="text") + batch_data = handle.read() + fout.write(batch_data) + handle.close() + # Wait 2 seconds before next batch + time.sleep(2) + + +if __name__ == '__main__': + args = parser.parse_args() + + genomes_list = txt_file_to_list(args.input_list) + download_sequences(genomes_list, args.output_file, args.email, output_fmt=args.output_fmt) + + # genomes_fna = args.prefix_out + '_genomes.fna' + # proteins_faa = args.prefix_out + '_proteins.faa' + # metadata = args.prefix_out + '_metadata.tsv' + # + # with open(genbank_file, 'r') as fin, \ + # open(genomes_fna, 'w') as fg, \ + # open(proteins_faa, 'w') as fp, \ + # open(metadata, 'w') as fout: + # # Write the header to the metadata file + # fout.write(metadata_header + '\n') + # + # # Loop through the genbank records to extract the information + # for record in SeqIO.parse(fin, format="genbank"): + # unversioned = record.id.split('.')[0] # some entries don't come with their version + # # Check if the versioned accession is in the dictionary + # if record.id in sequence_info: + # record_id = record.id + # else: # otherwise use the unversioned string + # record_id = unversioned + # + # # Get a list of associated proteins with the record + # proteins = [prot_info[1] for prot_info in sequence_info.get(record_id)] + # + # # Create a mapping of protein INSDC accession to uniprot accession + # prot_map = {prot_info[1]: prot_info[0] for prot_info in sequence_info.get(record_id)} + # + # # Polyproteins come with one accession + # # TO DO + # # Get the specific location of each chain + # # after it is post-translationally modified + # # Maybe use the FT attribute in the uniprot file? + # if len(proteins) > 1 and len(set(proteins)) == 1: + # # For now print these out and skip + # print("Polyprotein? 
Nucleotide accession: {}, Protein list: {}".format(record_id, proteins)) + # pass + # + # else: + # # Get the genome sequence anyway + # SeqIO.write(record, fg, "fasta") + # # The list of protein accessions associated with the genome + # # TO DO + # # turning them in a set first otherwise there are some duplicates + # lproteins = list(set(proteins)) + # for p in lproteins: + # # The dictionary of features of the genbank record + # for f in record.features: + # # If there is a protein_id field in the qualifiers + # # and is the same as the protein + # if ('protein_id' in f.qualifiers) and (p in f.qualifiers.get('protein_id')): + # # Extract genome length, start, end, strand of the feature + # str_out = '\t'.join(map(str, [p, prot_map[p], record.id, len(record.seq), + # f.location.start.position, + # f.location.end.position, + # f.strand])) + # fout.write(str_out + '\n') + # # Get the translation of the feature + # prot_seq = f.qualifiers.get('translation')[0] + # # Construct a SeqRecord object + # prot_rec = SeqRecord(Seq(prot_seq, IUPAC.protein), + # id=p, + # description='') + # # Write the sequence to the proteins file + # SeqIO.write(prot_rec, fp, "fasta") diff --git a/workflow/scripts/extract_proteins_from_gb.py b/workflow/scripts/extract_proteins_from_gb.py new file mode 100755 index 0000000000000000000000000000000000000000..d37232278cfde6045774beaa0906747c0d2c745e --- /dev/null +++ b/workflow/scripts/extract_proteins_from_gb.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python + +import argparse +from Bio import SeqIO +from Bio.Seq import Seq +from Bio.SeqRecord import SeqRecord +from Bio.Alphabet import IUPAC + + +parser = argparse.ArgumentParser(description='Get all protein sequences from a genbank file with' + 'annotated genomes.') +optionalArgs = parser._action_groups.pop() + +requiredArgs = parser.add_argument_group("required arguments") +requiredArgs.add_argument('-i', '--input-gb', + dest='input_gb', + required=True, + help="A genbank file containing annotated genomes") +requiredArgs.add_argument('-o', '--output-fasta', + dest='output_fasta', + type=str, + required=True, + help="The output file to write sequences to") + +parser._action_groups.append(optionalArgs) + + +def extract_all_proteins_from_genbank_file(genbank_fin, fasta_out): + genomes, proteins = 0, 0 + with open(fasta_out, 'w') as fout: + for record in SeqIO.parse(genbank_fin, "genbank"): + genomes += 1 + for f in record.features: + if f.type == 'CDS': + protein = f.qualifiers.get('protein_id') + if protein is not None: + prot_seq = f.qualifiers.get('translation')[0] + prot_rec = SeqRecord(Seq(prot_seq, IUPAC.protein), + id=protein[0], + description='') + proteins += SeqIO.write(prot_rec, fout, "fasta") + + return genomes, proteins + + +if __name__ == '__main__': + args = parser.parse_args() + g, p = extract_all_proteins_from_genbank_file(args.input_gb, args.output_fasta) + print("Extracted {} from {} input genomes".format(p, g)) diff --git a/workflow/scripts/fastani_mat_to_square.py b/workflow/scripts/fastani_mat_to_square.py new file mode 100755 index 0000000000000000000000000000000000000000..f82cf89f1a6ce2634d7c86d85abea6c8f738cd64 --- /dev/null +++ b/workflow/scripts/fastani_mat_to_square.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python + +import pandas as pd +import numpy as np +import argparse +from pathlib import Path + +def parse_args(): + parser = argparse.ArgumentParser(description='Convert the matrix output of fastANI (v. 
1.3) to a square matrix') + + optionalArgs = parser._action_groups.pop() + + requiredArgs = parser.add_argument_group("required arguments") + requiredArgs.add_argument('-i', '--input-matrix', + dest='mat_in', + type=lambda p: Path(p).resolve(strict=True), + required=True, + help="A matrix file from fastANI, produced with its -matrix option") + requiredArgs.add_argument('-o', '--output-matrix', + dest='mat_out', + type=lambda p: Path(p).resolve(), + required=True, + help="File name of the square matrix") + optionalArgs.add_argument('--process-names', + action='store_true', + dest='process_names', + help='Use the accession as the row/column names and not the full path' + 'Assumes that the results look like /path/to/<accession>.fasta' + ) + + parser._action_groups.append(optionalArgs) + return parser.parse_args() + +def convert_ani_matrix_to_square(ani_output_fp, process_names=True): + """ + Helper function to convert the output fastANI matrix to a square matrix. + Makes some calculations easier. + :param ani_output_fp: Path to fastANI output matrix file + :return: A tuple of (the matrix itself, the names of the files) + """ + with open(ani_output_fp, 'r') as fin: + # first line is the number of genomes compared + no_of_elements = int(fin.readline().strip()) + # The rest of the lines contain the values + data_lines = [line.strip() for line in fin] + # Each line's values start from the second column + data_fields = [line.split('\t')[1:] for line in data_lines] + # First element is the genome_file + genomes_files = [line.split('\t')[0] for line in data_lines] + if process_names: + names = [] + for gf in genomes_files: + # I am expecting something like 'NC_000866.4.fasta' + fname = Path(gf).name + name = fname.rstrip('.fasta') + names.append(name) + else: + names = genomes_files + # Initialize a square matrix of shape no_of_elements x no_of_elements + # filled with zeros + mat = np.zeros((no_of_elements, no_of_elements)) + + # A loop to get values to fill the zero-filled matrix + for i in range(0, no_of_elements): + for j in range(0, no_of_elements): + # Diagonal + if i == j: + value = 100. + # Weird but works + # TO DO add better documentation of what is happening here + if i < j: + value = data_fields[j][i] + if value == 'NA': + value = 0. + if i > j: + value = data_fields[i][j] + if value == 'NA': + value = 0. 
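+                # The branches above read each pair's value from the triangular
+                # fastANI layout: for (i, j) with i < j the value sits on row j
+                # at position i, and vice versa for i > j. Writing it to
+                # mat[i, j] below mirrors the values into a symmetric square
+                # matrix; 'NA' (no ANI estimate) has already been mapped to 0.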
+ mat[i, j] = value + return mat, names + +def rewrite_ani_matrix_as_square(fastani_raw, output_fp, **kwargs): + """ + Helper function for writing the result + """ + # Get the matrix and names to use + mat, names = convert_ani_matrix_to_square(fastani_raw, process_names=kwargs['process_names']) + # Create a data frame + df = pd.DataFrame(mat, index=names, columns=names) + # Write the matrix to the file + df.to_csv(output_fp, header=True, index=True, sep='\t') + +def main(): + args = parse_args() + rewrite_ani_matrix_as_square(args.mat_in, args.mat_out, process_names=args.process_names) + +if __name__ == '__main__': + main() diff --git a/workflow/scripts/filter_intact.py b/workflow/scripts/filter_intact.py new file mode 100755 index 0000000000000000000000000000000000000000..cbe5146ebd6789efd4a0ea49f338e31fd3bcfcf2 --- /dev/null +++ b/workflow/scripts/filter_intact.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python +from ete3 import NCBITaxa, TextFace, TreeStyle, faces +import argparse + + +parser = argparse.ArgumentParser(description='Parse the IntAct DB text file to get viral related entries') + +# Store the default optional groups to the optionalArgs +optionalArgs = parser._action_groups.pop() + +# Add requiredArgs as an argument group for displaying help better +requiredArgs = parser.add_argument_group("required arguments") +requiredArgs.add_argument('-i', dest="intact_in", + help="The raw txt intact file", + required=True) +requiredArgs.add_argument('-o', dest="intact_out", + required=True, + help="File path to write results to") + +# Add the rest of the optional arguments +optionalArgs.add_argument('-v', '--vhost', + dest="vhost", + required=False, + help="A txt file containing a list of bacteria and archaea associated virus taxids") +optionalArgs.add_argument('--phages-only', + dest='phages_only', + action='store_true', + help='Filter taxids to only contain phages provided in the vhost file') +optionalArgs.add_argument("--tax-db", + dest="tax_db", + required=False, + help="Location of the NCBI taxonomy sqlite db, created" + ) +# Add the optionalArgs to the action_groups +parser._action_groups.append(optionalArgs) + +def file_to_list(filename): + """ + Read in a one-column txt file to a list + :param filename: + :return: A list where each line is an element + """ + with open(filename, 'r') as fin: + alist = [line.strip() for line in fin] + return alist + + +def parse_field(field_string): + """ + Get a tuple of the first entries found in an psimitab column + (xref, value, desc) + :param field_string: + :return: A tuple (xref, value, desc) - desc is optional + """ + parts = field_string.split(":") + xref = parts[0] + rest = parts[1] + if "(" in rest: + subfields = rest.split("(") + value = subfields[0] + desc = subfields[1][:-1] + else: + value = rest + desc = "NA" + return xref, value, desc + + +def parse_column_string(astring): + """ + Parses the information stored in a column field of psi-mitab + :param astring: The string of psimitab + e.g. 
`taxid:9606(human)|taxid:9606(Homo sapiens)` + :return: A dictionary of the values + {'xref': ['taxid'], + 'value': ['9606'], + 'desc': ['human', 'Homo sapiens']} + """ + + # Initialize an empty list for each xref, value, desc key + column_data = {k: [] for k in ['xref', 'value', 'desc']} + if "|" in astring: + fields = astring.split("|") + else: + fields=[astring] + + for f in fields: + result = parse_field(f) + xref, value, desc = result[0], result[1], result[2] + if xref not in column_data['xref']: + column_data['xref'].append(xref) + if value not in column_data['value']: + column_data['value'].append(value) + if desc not in column_data['desc']: + column_data['desc'].append(desc) + return column_data + + +## TREE STUFF ## +def layout(node): + """ + Helper function to annotate nodes with scientific name + :param node: A node instance of a ete3.Tree + :return: Annotated node with sci_name + """ + faces.add_face_to_node(TextFace(node.sci_name), node, 0) + + +if __name__ == '__main__': + # Parse the args + args = parser.parse_args() + ncbi = NCBITaxa(dbfile=args.tax_db) + ncbi_viruses = ncbi.get_descendant_taxa('10239', intermediate_nodes=True) + + intact_in = args.intact_in + intact_out = args.intact_out + + # If a file with host taxids is provided + if args.vhost: + phage_taxids_file = args.vhost + # Get the list of taxids + phage_taxids = list(map(int, file_to_list(phage_taxids_file))) + print("Phages list : {} entries".format(len(phage_taxids))) + else: + phage_taxids = set() + # If we want to only focus on phages + if args.phages_only: + ncbi_phages = ncbi.get_descendant_taxa('28883', intermediate_nodes=True) # 28883 -> Caudovirales + compare_list = set.union(set(phage_taxids), set(ncbi_phages)) + else: + ncbi_virus_taxids = ncbi.get_descendant_taxa('10239', intermediate_nodes=True) + print("NCBI Viruses descendants: {}".format(len(ncbi_virus_taxids))) + compare_list = set.union(set(ncbi_virus_taxids), set(phage_taxids)) + + print("Contents of input file will be scanned against {} viral associated taxids.". 
+ format(len(compare_list))) + # Initialize an interaction counter and a virus-virus counter + int_counter, vv_counter = 0,0 + # Store the parsed taxids + taxids = [] + with open(intact_in, 'r') as fin, open(intact_out, 'w') as fout: + header_line = fin.readline() + fout.write(header_line) + for line in fin: + int_counter += 1 + # Print some progress + if int_counter % 10000 == 0: + print("Parsed {} entries".format(int_counter)) + columns = [field.strip() for field in line.split("\t")] + # On PSIMI-TAB v25 columns 8,9 have the taxonomy info + if columns[9] != '-' and columns[10] != '-': + taxidA_data = parse_column_string(columns[9]) + taxidB_data = parse_column_string(columns[10]) + taxidA, taxidB = taxidA_data['value'], taxidB_data['value'] + # If the taxids are both viral + if (int(taxidA[0]) in compare_list) and (int(taxidB[0]) in compare_list): + vv_counter += 1 + fout.write(line) + taxids.append(taxidA[0]) + taxids.append(taxidB[0]) + + print("{} entries were parsed.".format(int_counter)) + print("{} entries were viral".format(vv_counter)) + taxids_set = set(taxids) + print("Filtered taxids list contains {} taxa".format(len(taxids_set))) diff --git a/workflow/scripts/get_uniprot_ids.py b/workflow/scripts/get_uniprot_ids.py new file mode 100755 index 0000000000000000000000000000000000000000..b26e81973bf06c1d46982ba1143eded7d2543f76 --- /dev/null +++ b/workflow/scripts/get_uniprot_ids.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python + +import argparse +import pandas as pd + +parser = argparse.ArgumentParser(description="Produce a list of uniprot ids from the metadata.tsv (output of\n" + "summarize_intact.py", + formatter_class=argparse.RawTextHelpFormatter) +optionalArgs = parser._action_groups.pop() + +requiredArgs = parser.add_argument_group("required arguments") +requiredArgs.add_argument('-i', '--input_tsv', + dest='input_tsv', + type=str, + required=True, + help="The metadata.tsv file as produced from summarize_intact.py") +requiredArgs.add_argument('-l', '--output-list', + dest='output_list', + help="A txt file containing the raw uniprot ids from intact.\n" + "This will be used to post a query to the uniprot mapping tool", + required=True) +requiredArgs.add_argument('-o', '--output-tsv', + dest='output_tsv', + required=True, + help='A tsv file that contains the result of filtering') + +parser._action_groups.append(optionalArgs) + +if __name__ == '__main__': + args = parser.parse_args() + df = pd.read_csv(args.input_tsv, sep='\t') + # Select non-self interactions where both interactors come from uniprot + interactions = df.loc[(df['same_protein'] == 0) & + (df['source_A'] == 'uniprotkb') & + (df['source_B'] == 'uniprotkb'), ] + + interactions.to_csv(args.output_tsv, sep='\t', index=False) + + print(f'Number of non-self interactions with uniprotkb identifiers: {interactions.shape[0]}') + + # Create a protA_info data-frame with two columns prot_id, source_db + protA_info = interactions[['prot_A', 'source_A']].reset_index(drop=True) + protA_info = protA_info.rename(columns={'prot_A': 'prot_id', + 'source_A': 'source_db'}) + # Create a protB_info data-frame with two columns prot_id, source_db + protB_info = interactions[['prot_B', 'source_B']].reset_index(drop=True) + protB_info = protB_info.rename(columns={'prot_B': 'prot_id', + 'source_B': 'source_db'}) + # Concatenate the two data-frames + prot_info = pd.concat([protA_info, protB_info], ignore_index=True) + # Drop the duplicates based on uniprot_id + prot_info.drop_duplicates(subset='prot_id', inplace=True) + print(f'{prot_info.shape[0]} 
unique proteins come from uniprot') + + uniprot_ids = prot_info['prot_id'].to_list() + + # Write the identifiers to the output file + with open(args.output_list, 'w') as fout: + for uniprot_id in uniprot_ids: + fout.write(f'{uniprot_id}\n') diff --git a/workflow/scripts/make_protein_combos.py b/workflow/scripts/make_protein_combos.py new file mode 100755 index 0000000000000000000000000000000000000000..457f75ad346fca22c3ecd9d69c2c53d22989569b --- /dev/null +++ b/workflow/scripts/make_protein_combos.py @@ -0,0 +1,167 @@ +#!/usr/bin/env python + +import argparse +from itertools import combinations +import random +from Bio import SeqIO +import pathlib +from Bio.Seq import Seq +from Bio.SeqRecord import SeqRecord +from Bio.Alphabet import IUPAC + + +parser = argparse.ArgumentParser(description='Create a 2-column tsv file with --number-of combinations of proteins' + 'from the same genome, for all genomes that have more than 1 ' + 'protein. ' + 'If a 2-column tsv files is provided with --exclude, the set of these ' + 'interactions will be excluded.') +optionalArgs = parser._action_groups.pop() + +requiredArgs = parser.add_argument_group("required arguments") +requiredArgs.add_argument('-i', '--input-gb', + dest='input_gb', + required=True, + help="A genbank file containing annotated genomes") +requiredArgs.add_argument('-o', '--fasta-out', + dest='fasta_out', + type=str, + required=True, + help="The output faa file to write to") +requiredArgs.add_argument('-x', '--interactions-out', + dest='interactions_out', + type=str, + required=True, + help="The output tsv file to write the interactions in (x for A *x* B)") + +requiredArgs.add_argument('-a', '--all-proteins-faa', + dest='all_proteins_fasta', + help='A fasta file that contains all proteins as extracted from the genbank file', + required=True) +optionalArgs.add_argument('--number-of-pairs', + required=False, + help="select this amount of pairs to output", + dest='sample_no', + type=int) +optionalArgs.add_argument('--exclude', + required=False, + type=str, + help="A 2-column tsv file that contains interactions to be excluded") +optionalArgs.add_argument('--sample-size', + required=False, + dest='sample_size', + type=int, + help="If an --exclude file is present, set the number of interactions" + "to sample to N * sample-size, where N=number of interactions in the exclude file", + default=1 + ) +optionalArgs.add_argument('--random-seed', + required=False, + dest='random_seed', + type=int, + default=46, + help="Set the random seed for reproducibility [default = 46]" + ) + +parser._action_groups.append(optionalArgs) + +def create_all_possible_combos(genbank_fin): + """ + Create a set of all possible protein interactions, per genome, + from the input genbank_file + """ + # all_pairs_dict = {} + all_pairs = [] + for record in SeqIO.parse(genbank_fin, format='genbank'): + proteins = [] + for f in record.features: + if f.type == 'CDS': + protein = f.qualifiers.get('protein_id') + + if protein: + protein_id = protein[0] + proteins.append(protein_id) + + if len(proteins) >= 2: + protein_combos = list(combinations(sorted(proteins), 2)) + + # all_pairs_dict[genome_acc] = protein_combos + for pair in protein_combos: + all_pairs.append(pair) + elif len(proteins) == 1: + pass + + # Sort the tuples by protein id, to make comparisons easier + all_pairs_sorted = sorted([tuple(sorted(i)) for i in all_pairs]) + + return all_pairs_sorted + + +def create_exclude_set(exclude_file): + """ + Create a list of interactions to be excluded from analysis, + e.g. 
if they are part of the positive set. + """ + excludes = [] + with open(exclude_file, 'r') as fin: + for line in fin: + fields = line.split('\t') + p1, p2 = fields[0].strip(), fields[1].strip() + interaction = (p1,p2) + excludes.append(sorted(interaction)) + exclude_pairs = sorted(set([tuple(sorted(i)) for i in excludes])) + + return exclude_pairs + + +if __name__ == '__main__': + args = parser.parse_args() + + interactions_pool = create_all_possible_combos(args.input_gb) + #print("Total raw pool: {}".format(len(interactions_pool))) + + if args.exclude: + exclude_set = create_exclude_set(args.exclude) + print('{} were input for exclusion'.format(len(exclude_set))) + interactions_pool = sorted(set(interactions_pool) - set(exclude_set)) + print("Total pool after exclusion: {}".format(len(interactions_pool))) + if not args.sample_size: + print("The number of interactions to be selected will be set to {}".format(len(exclude_set))) + print("You can use the --sample-size argument to sample more") + sample_size = args.sample_size + else: + sample_size = args.sample_size * len(exclude_set) + elif not args.exclude and not args.sample_no: + # Just testing this + parser.error("Either provide an input file with interactions to exclude and set the sample-size option" + "or set the number of interactions you want to sample with the --number-of-pairs option") + elif args.sample_no: + sample_size = args.sample_no + else: + parser.error("This is unexpected") + + with open(args.all_proteins_fasta, 'r') as fin: + seq_data = SeqIO.to_dict(SeqIO.parse(fin, format="fasta")) + + random.seed(args.random_seed) + + final_interactions = random.sample(interactions_pool, sample_size) + final_interactions = sorted([tuple(sorted(i)) for i in final_interactions]) + #print(final_interactions[:5]) + #print("Produced {} interactions".format(len(final_interactions))) + + with open(args.interactions_out, 'w') as fout: + for interaction in final_interactions: + fout.write('{}\t{}\n'.format(interaction[0], interaction[1])) + #print("Interaction data are stored in {}".format(args.interactions_out)) + + accessions_list = [] + for interaction in final_interactions: + accessions_list.append(interaction[0]) + accessions_list.append(interaction[1]) + accessions_list = sorted(set(accessions_list)) + + writer = 0 + with open(args.fasta_out, 'w') as fout: + for accession in accessions_list: + writer += SeqIO.write(seq_data.get(accession), fout, format="fasta") + #print("Wrote {} unique sequences in file {}".format(writer, args.fasta_out)) diff --git a/workflow/scripts/predict.py b/workflow/scripts/predict.py new file mode 100755 index 0000000000000000000000000000000000000000..3b5764fb0b40b8e7c597553b73116962cf3910e4 --- /dev/null +++ b/workflow/scripts/predict.py @@ -0,0 +1,245 @@ +#!/usr/bin/env python + +import argparse +from pathlib import Path + +import pandas as pd + +import pickle + +from sklearn.ensemble import RandomForestClassifier +from sklearn.model_selection import train_test_split + + +FEATURES = ['jaccard_score', + 'same_score', + 'inwards_score', + 'outwards_score', + 'avg_distance', + 'mean_ani', + 'mean_aai'] + +# Random State - for reproducibility +RS = 1 + +def parse_args(): + parser = argparse.ArgumentParser(description='Predict interactions on target') + + optionalArgs = parser._action_groups.pop() + + requiredArgs = parser.add_argument_group("required arguments") + requiredArgs.add_argument('-m', '--model', + dest='model_fp', + type=lambda p: Path(p).resolve(strict=True), + required=True, + help="A pickle file 
containing the RF parameters" + ) + requiredArgs.add_argument('-t', '--target_tsv', + dest='target_tsv', + type=lambda p: Path(p).resolve(strict=True), + required=True, + help="A tsv file containing feature calculations for the interactions") + requiredArgs.add_argument('-p', '--positive-set', + dest='positives_tsv', + type=lambda p: Path(p).resolve(strict=True), + required=True, + help="Positive set") + requiredArgs.add_argument('-n', '--negative-set', + dest='negatives_tsv', + type=lambda p: Path(p).resolve(strict=True), + required=True, + help="Negative set with labels") + requiredArgs.add_argument('-o', '--output-file', + dest='outfile', + required=True, + type=lambda p: Path(p).resolve(), + help="File path to write the results in") + optionalArgs.add_argument('-j', '--jobs', + dest='n_jobs', + type=int, + required=False, + default=1, + help="Number of parallel jobs to run for classification (default : 1)" + ) + parser._action_groups.append(optionalArgs) + + return parser.parse_args() + + +def read_scores_table(scores_fp, label=None): + """ + Read a table from a tsv file provided as a path. + Return a dataframe. + + If label is set, append a column named `label` with + the provided value + """ + scores_df = pd.read_csv(scores_fp, sep='\t') + scores_df['interaction'] = scores_df['pvog1'] + '-' + scores_df['pvog2'] + # Unpack features list because I want interaction prepended. + scores_df = scores_df[['interaction', *FEATURES]] + if label is not None: + scores_df['label'] = label + return scores_df + + +def concat_data_frames(pos_df, + neg_df, + subsample=False, + clean=True, + is_scaled=True, + ): + """ + Concatenate two dataframes + + subsample:bool + Subsample the `neg_df` to a number of + observations equal to the number of pos_df + (balance datasets) + + clean:bool + Remove observations with a value of `avg_distance` == 100000 or 1. + + is_scaled:bool + The features have been scaled to a range 0-1 + + Return: + concat_df: pd.DataFrame + The concatenated data frame + """ + + if clean is True: + pos_df = remove_ambiguous(pos_df) + neg_df = remove_ambiguous(neg_df) + + n_positives = pos_df.shape[0] + n_negatives = neg_df.shape[0] + + # Remove possible duplicate interactions from the negatives + # This might happen because of the random selection when creating the set + # Why I also select more negatives to begin with + neg_df = neg_df.loc[~neg_df['interaction'].isin(pos_df['interaction'])] + + if (n_positives != n_negatives) and (subsample is True): + neg_df = neg_df.sample(n=n_positives, random_state=1) + concat_df = pd.concat([pos_df, neg_df]).reset_index(drop=True) + + assert concat_df[concat_df.duplicated(subset=['interaction'])].empty == True, concat_df.loc[concat_df.duplicated(subset=['interaction'], keep=False)] + + return concat_df + + +def scale_df(input_df): + """ + Scale all feature values in the data frame to [0-1]. 
+ """ + maxes = input_df[FEATURES].max(axis=0) + scaled_data = input_df[FEATURES].divide(maxes) + if 'label' in input_df.columns: + scaled_df = pd.concat([input_df['interaction'], scaled_data, input_df['label']], axis=1) + else: + scaled_df = pd.concat([input_df['interaction'], scaled_data], axis=1) + return scaled_df + + +def remove_ambiguous(input_df): + """ + Select observations in the `input_df` that have feature values + """ + df_clean = input_df[(input_df.jaccard_score != 0)] # This is true if they don't co-occur + return df_clean + + + +if __name__ == '__main__': + args=parse_args() + + # Create the training set dataframe + pos_df = read_scores_table(args.positives_tsv, label=1) + print("Raw positive set interactions : {}".format(pos_df.shape[0])) + + neg_df = read_scores_table(args.negatives_tsv, label=0) + print("Raw negative set interactions : {}".format(neg_df.shape[0])) + + scaled_pos_df = scale_df(pos_df) + scaled_neg_df = scale_df(neg_df) + training_df = concat_data_frames(scaled_pos_df, + scaled_neg_df, + subsample=True, + clean=True, + is_scaled=True) + print("Processed training set interactions: {}".format(training_df.shape[0])) + + # Read in the target dataset + target_df = read_scores_table(args.target_tsv) + input_targets = target_df.shape[0] + print("Target set interactions: {} ".format(input_targets)) + + # Clean the target from distances + target_df = remove_ambiguous(target_df) + # Scale it + target_df = scale_df(target_df) + + # Remove posnegs + target_df = target_df.loc[~target_df['interaction'].isin(training_df['interaction'])] + final_targets = target_df.shape[0] + print("Removed {} interactions from target ( {} remaining)".format((input_targets - final_targets), final_targets)) + + # Read in the model + with open(args.model_fp, 'rb') as fin: + RF = pickle.load(fin) + + if args.n_jobs: + RF.n_jobs = args.n_jobs + + print("Classifier: {}".format(RF)) + + + Xt = training_df[FEATURES] + yt = training_df['label'] + # Keep these to append + interactions = training_df['interaction'] + # Get the feature values only, for predictions on the target + X = target_df[FEATURES] + + # Split the training set to train/holdout (0.7/0.3) + # Holdout is thrown away here... + # This is required for consistency with the whole model selection process + X_train, X_holdout, y_train, y_holdout = train_test_split(Xt, + yt, + test_size=0.3, + random_state=RS) + # write the actual training set used to a file + # So this can be appended to the complete final table + # Make a copy + X_training_out = X_train.copy() + # Append the label + X_training_out['label'] = y_train + # Append the interaction information + X_training_out['interaction'] = training_df['interaction'] + # Give the training set a proba of 1 + X_training_out['proba'] = 1. + # Rearrange columns + X_training_out = X_training_out[['interaction', *FEATURES, 'label', 'proba']] + + # Construct the name of the output file + dir_name = args.outfile.parent + training_set_fp = dir_name / Path("final_training_set.tsv") + X_training_out.to_csv(training_set_fp, sep = '\t', index=False) + + + + + print("Fitting...") + RF.fit(X_train, y_train) + + print("Predicting...") + X_pred = RF.predict(X) + X_pred_proba = RF.predict_proba(X) + + target_df['label'] = X_pred + target_df['proba'] = X_pred_proba[:,1] + + target_df.to_csv(args.outfile, sep="\t", index=False) + print("Finished! 
Results are written in {}".format(str(args.outfile))) + diff --git a/workflow/scripts/process_annotations.py b/workflow/scripts/process_annotations.py new file mode 100644 index 0000000000000000000000000000000000000000..ae82e20ddab49e56ad9a515fac3fd78264926303 --- /dev/null +++ b/workflow/scripts/process_annotations.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python + +import argparse +import pandas as pd +from pathlib import Path +from collections import Counter +import operator + +def parse_args(): + parser = argparse.ArgumentParser(description="Parse pvogs annotations from the VOGProteinTable.txt " + "to a table that contains most frequently occurring raw and " + "processed annotation") + optionalArgs = parser._action_groups.pop() + + requiredArgs = parser.add_argument_group("Required arguments") + requiredArgs.add_argument('-i', '--input-file', + dest='input_fp', + type=str, + required=True, + help="The VOGProteinTable.txt from the database") + requiredArgs.add_argument('-o', '--output-file', + dest='output_fp', + type=str, + required=True, + help="The output table location") + + parser._action_groups.append(optionalArgs) + return parser.parse_args() + +def parse_annotations(vog_table_fp): + """ + Parse the pvogs annotations from the VOGProteinTable.txt + This tries to normalize the annotations into a standard + set of descriptions that is used for counting. + + Anything that is annotated as hypothetical is translated + to 'unknown'. + + 'protein' and 'putative' are also stripped. + + Returns + annotations_dic:dict + A dictionary that holds the processed terms + of the form + {pvog_id : + [annotation_1, annotation_2, ...], + ...} + + annotations_original:dict + A dictionary that holds the original terms + of the form + {pvog_id: + [annotation_1, annotation_2, ...], + ...} + + """ + + annotations_dic = {} + annotations_original = {} + with open(vog_table_fp, 'r') as fin: + for line in fin: + fields = line.split('|') + pvog = fields[0].split(':')[0] + if len(fields) > 2: + annotation = fields[2].split(':')[0] + processed_annotation = annotation + if 'hypothetical' in processed_annotation:#or annotation == '-': + processed_annotation=processed_annotation.replace('hypothetical protein', '').strip() + if processed_annotation == '': + processed_annotation = 'unknown' + else: + processed_annotation = processed_annotation.strip() + + if 'protein' in annotation: + processed_annotation = processed_annotation.replace('protein', '') + if 'putative' in annotation: + processed_annotation = processed_annotation.replace('putative', '') + else: + annotation = 'unknown' + processed_annotation = 'unknown' + + if pvog not in annotations_dic: + annotations_dic[pvog] = [processed_annotation.strip()] + else: + annotations_dic[pvog].append(processed_annotation.strip()) + + if pvog not in annotations_original: + annotations_original[pvog] = [annotation.strip()] + else: + annotations_original[pvog].append(annotation.strip()) + + return annotations_dic, annotations_original + +def get_max_count_annotation(annotations_dic): + """ + Get the annotation with the max count from a given `annoations_dic` + + annotations_dic:dict + A dictionary of the form + {pvog_id : + [annotation_1, annotation_2, ...], + ...} + Return: + unique_only:dict + A dictionary with a pvog as key and + the annotation with the highest occurring frequency + """ + unique_only = {} + for pvog in annotations_dic: + c = Counter(annotations_dic.get(pvog)) + most_common = max(c.items(), key=operator.itemgetter(1))[0] + unique_only[pvog] = most_common + 
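+        # Ties on the count are resolved by whichever annotation max() sees
+        # first (i.e. Counter insertion order), not alphabetically.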
return unique_only + +def create_annotations_df(vog_table_fp): + """ + A wrapper function that parses the annotations file + into a table + + vog_table_fp:Path-like object + The filepath of VOGProteinTable.txt + + Return + df_processed:pd.DataFrame + A dataframe that holds processed and raw annotations + for each pvog + """ + parsed_annotations = parse_annotations(vog_table_fp) + processed_annotations = parsed_annotations[0] + raw_annotations = parsed_annotations[1] + max_data_processed = get_max_count_annotation(processed_annotations) + max_data_raw = get_max_count_annotation(raw_annotations) + df_processed = pd.DataFrame.from_dict(max_data_processed, + orient='index', + columns=['annotation_processed']) + df_raw = pd.DataFrame.from_dict(max_data_raw, + orient='index', + columns=['annotation_raw']) + df_processed = df_processed.reset_index().rename(columns = {'index': 'pvog'}) + df_raw = df_raw.reset_index().rename(columns={'index': 'pvog'}) + df_processed['annotation_raw'] = df_raw['annotation_raw'] + df_processed = df_processed.set_index('pvog') + + return df_processed + +def main(): + args = parse_args() + annotations_df = create_annotations_df(Path(args.input_fp)) + annotations_df.to_csv(Path(args.output_fp), sep='\t') + + +if __name__ == '__main__': + main() + + + diff --git a/workflow/scripts/process_comparem.py b/workflow/scripts/process_comparem.py new file mode 100755 index 0000000000000000000000000000000000000000..65586b3321a36638b020fdc33e1cb170abcb2abe --- /dev/null +++ b/workflow/scripts/process_comparem.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python + +import pandas as pd +import numpy as np +import argparse +from pathlib import Path + +parser = argparse.ArgumentParser(description="Rewrite comparem output as square-matrix and calculate " + "mean aai and nonzero mean aai for the records.") +parser.add_argument("--input-tsv", "-i", + required=True, + dest="input_tsv", + help="The aai_summary.tsv file produced with comparem's aai_wf") +parser.add_argument("--output-tsv", "-o", + dest='output_tsv', + required=True, + help="Output tsv file to write square matrix to") + +def process_comparem_output(aai_out, matrix_out): + with open(aai_out, 'r') as fin: + # Skip header + next(fin) + comparisons = {} + for line in fin: + fields = [f.strip() for f in line.split()] + + genomeA, genomeB, mean_aai = fields[0], fields[2], float(fields[5]) + + if genomeA not in comparisons: + comparisons[genomeA] = {genomeA: 100.} + if genomeB not in comparisons: + comparisons[genomeB] = {genomeB: 100.} + + comparisons[genomeA][genomeB] = mean_aai + comparisons[genomeB][genomeA] = mean_aai + + df = pd.DataFrame.from_dict(comparisons) + # Write the square matrix + df.to_csv(matrix_out, sep='\t', header=True) + return True + +# # Convert df to numpy array +# mat = df.to_numpy() +# # Calculate mean aai for all comparisons +# aai = mat.mean() +# # Calculate non zero mean +# nonzeros = np.count_nonzero(mat) +# nonzero_mean_aai = mat.sum() / nonzeros +# +# return aai, nonzero_mean_aai + + +if __name__ == "__main__": + args = parser.parse_args() + input_fp = Path(args.input_tsv) + square_matrix_fp = Path(args.output_tsv) + process_comparem_output(input_fp, square_matrix_fp) + +# if args.write_result: +# # I assume the aai_summary.tsv is in /path/to/INTERACTION/aai_out/aai/aai_summary.tsv +# interaction = input_dir.parent.parent.name +# result_file = Path.joinpath(input_dir, 'aai_means.tsv') +# with open(result_file, 'w') as fout: +# fout.write('interaction\tmean_aai\tnonzero_mean_aai\n') +# 
fout.write(f'{interaction}\t{aai}\t{nonzero_mean_aai}\n') +# else: +# print(f'Mean aai : {aai}') +# print(f'Non-zero mean aai: {nonzero_mean_aai}') diff --git a/workflow/scripts/process_uniprot.py b/workflow/scripts/process_uniprot.py new file mode 100755 index 0000000000000000000000000000000000000000..2f633ad3c081a6727cbee02b20ab4083dfde2b61 --- /dev/null +++ b/workflow/scripts/process_uniprot.py @@ -0,0 +1,376 @@ +#!/usr/bin/env python + +import argparse +from Bio import SwissProt, SeqIO +from Bio.Seq import Seq +from Bio.SeqRecord import SeqRecord +from Bio.Alphabet import IUPAC +import pathlib +from ast import literal_eval as make_tuple +import pandas as pd + +parser = argparse.ArgumentParser(description="This is to get several mapping files for the interactions.", + formatter_class=argparse.RawTextHelpFormatter) +optionalArgs = parser._action_groups.pop() + +requiredArgs = parser.add_argument_group("required arguments") +requiredArgs.add_argument('-i', '--interactions-tsv', + dest='interactions_tsv', + required=True, + help="Filtered interactions_tsv from get_uniprot_ids.py") +requiredArgs.add_argument('-l', '--input_list', + dest='uniprots_list', + type=str, + required=True, + help="The output of get_uniprot_ids.py") +requiredArgs.add_argument('-s', '--swissprot-file', + dest='swissprot_file', + required=True, + help="A SwissProt file, containing results based on the primary ids") +requiredArgs.add_argument('-p', '--prefix-out', + dest='prefix_out', + required=True, + help="The prefix of the file names to write to. This will produce files:\n" + "<prefix>/proteins.faa,\n" + "<prefix>/uniprot2ncbi.mapping.txt\n" + "<prefix>/ncbi2uniprot.mapping.txt\n" + "<prefix>/skipped.uniprot.txt\n" + "<prefix>/ncbi_interactions.txt\n" + "<prefix>/ncbi_genomes.txt") +parser._action_groups.append(optionalArgs) + + +def has_refseq(db_list): + """ + Return the index of the list where the 'RefSeq' string is located. + Otherwise return None + :param db_list: A list of db names taken as the first element of the tuples in a + Swissprot.record.cross_references list + :return: int: index or None + """ + if 'RefSeq' in db_list: + return db_list.index('RefSeq') + else: + return None + + +def give_me_proper_embl(cross_refs): + """ + Filter for references where the first element == 'EMBL', + then search for the first occurence where the genome accession is not '-'. + This is to get both a valid protein accession and genome accession. 
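+    If none of the EMBL cross-references carries a protein accession, the
+    first EMBL entry is returned as a fallback.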
+ + :param cross_refs: The full list of SwissProt.record.cross_references + :return: + """ + # Get embl records first + embl_refs = [] + for ref in cross_refs: + if ref[0] == 'EMBL': + embl_refs.append(ref) + + genome_acc, prot_acc = embl_refs[0][1], embl_refs[0][2] + + for ref in embl_refs: + if ref[2] != '-': + genome_acc, prot_acc = ref[1], ref[2] + break + return genome_acc, prot_acc + + +def has_embl(db_list): + """ + Check if 'EMBL' is in a list of database names + :param db_list: + :return: True or None + """ + if 'EMBL' in db_list: + return True + else: + return None + + +def get_ncbi_accessions(cross_refs): + """ + Given a SwissProt.record.cross_references list, extract a tuple + of genome,protein accessions + :param cross_refs: Bio.SwissProt.record.cross_references list + :return: (genome_accession, protein_accession) associated with the record + """ + dbs = [ref[0] for ref in cross_refs] + + if has_refseq(dbs) is not None: + loc = has_refseq(dbs) + genome_acc = cross_refs[loc][2] + if '[' in genome_acc: + genome_acc = genome_acc.split()[0].rstrip('.') + prot_acc = cross_refs[loc][1] + elif has_embl(dbs) is not None and has_refseq(dbs) is None: + genome_acc, prot_acc = give_me_proper_embl(cross_refs) + # genome_acc, prot_acc = cross_refs[loc][1], cross_refs[loc][2] + else: + genome_acc, prot_acc = '-', '-' + + return genome_acc, prot_acc + + +def construct_description(protein_seq, uniprotkb_id, ncbi_genome, ncbi_protein, start=None, end=None, taxid=None): + if start is not None and end is not None: + return f'uniprotkb={uniprotkb_id};' \ + f'ncbi_protein={ncbi_protein};' \ + f'start={start};' \ + f'end={end};' \ + f'ncbi_genome={ncbi_genome};' \ + f'taxid={taxid}' + else: + return f'uniprotkb={uniprotkb_id};' \ + f'ncbi_protein={ncbi_protein};' \ + f'start=1;' \ + f'end={len(protein_seq)};' \ + f'ncbi_genome={ncbi_genome};' \ + f'taxid={taxid};' + + +def extract_prochain_sequence_from_record(record, + primary_id, + prochain_id, + genome_acc='-', + protein_acc='-'): + r_taxid = record.taxonomy_id + for feature in record.features: + if feature[0] == 'CHAIN' and feature[-1] == prochain_id: + start, end = feature[1], feature[2] + rec_seq = Seq(record.sequence[start - 1:end], IUPAC.protein) + original = '-'.join([primary_id, prochain_id]) + if protein_acc != '-': + new_id = f'{protein_acc}_at_{start}-{end}' + prot_rec = SeqRecord(rec_seq, id=new_id, + description=construct_description(str(rec_seq), + original, + genome_acc, + protein_acc, + start, + end, + taxid=r_taxid)) + else: + prot_rec = SeqRecord(rec_seq, id=original, + description=construct_description(str(rec_seq), + original, + genome_acc, + protein_acc, + start, + end, + taxid=r_taxid)) + return prot_rec + +def uniprot_id_type(uniprot_id): + """ + Check the type of a given uniprot_id + P04591 -> 'primary' + P04591-PRO_0000038593 -> 'prochain' + P04591-3 -> 'isoform + :param uniprot_id: String of uniprot_id + :return: + """ + if '-' in uniprot_id: + second_part = uniprot_id.split('-')[1] + if second_part.startswith('PRO'): + return 'prochain' + else: + return 'isoform' + else: + return 'primary' + +def give_me_the_record(primary_id, swissprot_file): + """ + Return a single record given with the primary id + :param primary_id: A primary id + :param swissprot_file: A swissprot file + :return: A record with accession == primary id + """ + with open(swissprot_file, 'r') as fh: + for record in SwissProt.parse(fh): + if primary_id in record.accessions: + return record + +def extract_protein_from_record(record): + """ + Grab the 
protein sequence as a string from a SwissProt record + :param record: A Bio.SwissProt.SeqRecord instance + :return: + """ + return str(record.sequence) + + +if __name__ == '__main__': + args = parser.parse_args() + + with open(args.uniprots_list, 'r') as fin: + uniprots_list = [line.strip() for line in fin] + + # Create the names of the output files + output_dir = pathlib.Path(args.prefix_out) + output_dir.mkdir(exist_ok=True) + # This is the fasta file that contains the proteins + output_fasta = pathlib.Path.joinpath(output_dir, 'proteins.faa') + + + # Initialize several book-keeping lists and dicts + # Some records I skip + skipped = [] + # A dict {uniprotkb : ncbi-accession , ...} + uniprot2ncbi_protein_mapping = {} + # Keep track of how many sequences are being written + written = 0 + # A dict { ncbi_protein_accession : [ uniprot_accession1, ...] , ... } + duplicates = {} + genomes = [] + + with open(output_fasta, 'w') as fout: + # Loop over the initial list of interactors + for protein_id in uniprots_list: + # Is it primary/prochain/isoform? + entry_type= uniprot_id_type(protein_id) + if entry_type == 'primary': + # Grab the associated record + r = give_me_the_record(protein_id, args.swissprot_file) + # Get the genome an protein accessions + ncbi_genome, ncbi_protein = get_ncbi_accessions(r.cross_references) + + # I want to map the uniprot ids to ncbi ones. + # If there is nothing to map skip the record altogether + if ncbi_protein == '-': + skipped.append(protein_id) + pass + else: + # For primary accessions the protein is the same + rec_id = ncbi_protein + rec_seq = extract_protein_from_record(r) + rec_taxid = r.taxonomy_id + description = construct_description(rec_seq, protein_id, ncbi_genome, ncbi_protein, taxid=rec_taxid) + # Construct a SeqRecord object to write to file + prot_rec = SeqRecord(Seq(rec_seq, IUPAC.protein), + id=rec_id, + description=description) + # Book-keeping + # match uniprot id to its ncbi counterpart + uniprot2ncbi_protein_mapping[protein_id] = prot_rec.id + # get the genome + genomes.append(ncbi_genome) + + # Avoid duplicates + if prot_rec.id not in duplicates: + # Put this unique identifier in the seen_proteins dict + duplicates[prot_rec.id] = [protein_id] + # And write to file + written += SeqIO.write(prot_rec, fout, "fasta") + else: + print("Possible duplicate:{}".format(prot_rec.id)) + duplicates[prot_rec.id].append(protein_id) + pass + elif entry_type == 'prochain': + # When a PRO is present, things need to be different + primary_id = protein_id.split('-')[0] + prochain = protein_id.split('-')[1] + r = give_me_the_record(primary_id, args.swissprot_file) + ncbi_genome, ncbi_protein = get_ncbi_accessions(r.cross_references) + if ncbi_protein == '-': + skipped.append(protein_id) + pass + else: + # The actual sequence is in the features table + prot_rec = extract_prochain_sequence_from_record(r, primary_id, prochain, + genome_acc=ncbi_genome, + protein_acc=ncbi_protein) + uniprot2ncbi_protein_mapping[protein_id] = prot_rec.id + genomes.append(ncbi_genome) + + # Avoid duplicates + if prot_rec.id not in duplicates: + written += SeqIO.write(prot_rec, fout, "fasta") + duplicates[prot_rec.id] = [protein_id] + else: + print("Possible duplicate {}".format(prot_rec.id)) + duplicates[prot_rec.id].append(protein_id) + pass + elif entry_type == 'isoform': + skipped.append(protein_id) + print("{} is an isoform and I do not handle these for now".format(protein_id)) + pass + print(30*'=') + print("{} proteins were extracted".format(written)) + + # This is a file that 
contains for each uniprot id its mapping to an ncbi protein + uniprot2ncbi_mapping_txt = pathlib.Path.joinpath(output_dir, 'uniprot2ncbi.mapping.txt') + uniprot_counter = 0 + with open(uniprot2ncbi_mapping_txt, 'w') as fout: + for uniprot_id, ncbi_id in uniprot2ncbi_protein_mapping.items(): + uniprot_counter += 1 + fout.write('{}\t{}\n'.format(uniprot_id, ncbi_id)) + print('{} uniprot ids are mapped to NCBI accessions'.format(uniprot_counter)) + + # A list of genomes to download from NCBI + genome_accessions_txt = pathlib.Path.joinpath(output_dir, 'genome_accessions.txt') + with open(genome_accessions_txt, 'w') as fout: + for genome in set(genomes): + fout.write('{}\n'.format(genome)) + print('{} genomes are written to file'.format(len(set(genomes)))) + + # Each ncbi protein, might map to more uniprot ids + # This files contains this information - an aggregation at ncbi level + ncbi2uniprot_mapping_txt = pathlib.Path.joinpath(output_dir, 'ncbi2uniprot.mapping.txt') + with open(ncbi2uniprot_mapping_txt, 'w') as fout: + for ncbi_p in duplicates: + fout.write('{}\t{}\n'.format(ncbi_p, ','.join(duplicates.get(ncbi_p)))) + + # This files contains only the duplicates + duplicates_tsv = pathlib.Path.joinpath(output_dir, 'duplicates.tsv') + with open(duplicates_tsv, 'w') as fout: + for ncbi_p in duplicates: + if len(duplicates.get(ncbi_p)) > 1: + fout.write('{}\t{}\n'.format(ncbi_p, ','.join(duplicates.get(ncbi_p)))) + + # This files contains the list of records that were originally input + # but were skipped. For now this happens + # (1) when no valid ncbi_protein accession is found + # (2) for isoforms + skipped_txt = pathlib.Path.joinpath(output_dir, 'skipped.txt') + with open(skipped_txt, 'w') as fout: + for uniprot in skipped: + fout.write('{}\n'.format(uniprot)) + + # Get the proteins that are actually covered through this filtering process + proteins_covered = [] + for ncbi_p in duplicates: + for uniprot_p in duplicates.get(ncbi_p): + proteins_covered.append(uniprot_p) + print('{} uniprot ids are covered'.format(len(proteins_covered))) + + # Get the interactions from the input file + interactions = pd.read_csv(args.interactions_tsv, sep='\t') + print('Input file contains {} interactions'.format(interactions.shape[0])) + # Filter the interactions dataframe to these interactions only + interactions_filtered = interactions.loc[interactions.prot_A.isin(proteins_covered) + & interactions.prot_B.isin(proteins_covered), ] + print('{} interactions are covered from the {} proteins above'.format(interactions_filtered.shape[0], + len(proteins_covered))) + + # Write the final set of interactions to file + interactions_filtered_tsv = pathlib.Path.joinpath(output_dir, 'interactions_filtered.tsv') + interactions_filtered.to_csv(interactions_filtered_tsv, sep='\t', index=False) + + # Also provide a tsv file that contains the uniprot-uniprot interaction + # as ncbi-ncbi interaction (for the negatives set construction) + ncbi_interactions_tsv = pathlib.Path.joinpath(output_dir, 'ncbi_interactions.tsv') + # This is a list of strings ["('intA', 'intB')" , ...] + interaction_pairs = interactions_filtered['interaction'].to_list() + with open(ncbi_interactions_tsv, 'w') as fout: + for pair in interaction_pairs: + # Convert the literal tuple string to actual tuple + # TO DO + # Just make it a tuple to begin with in the interactions file... 
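# NOTE: `make_tuple` is assumed to be `ast.literal_eval`, imported under that
# alias near the top of this script (the import is outside this excerpt).
# It converts the stringified pair stored in the 'interaction' column back
# into a real tuple, e.g.:
#   make_tuple("('P03692', 'P04591')")  ->  ('P03692', 'P04591')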
+ p = make_tuple(pair) + fout.write('{}\t{}\t{}\t{}\n'.format(p[0], p[1], + uniprot2ncbi_protein_mapping.get(p[0]), + uniprot2ncbi_protein_mapping.get(p[1]))) + print('Done!') diff --git a/workflow/scripts/query_uniprot.py b/workflow/scripts/query_uniprot.py new file mode 100755 index 0000000000000000000000000000000000000000..f5a56aca3e0fcd78a3592c16050da2cfff492033 --- /dev/null +++ b/workflow/scripts/query_uniprot.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python + +import argparse +import requests +from requests.adapters import HTTPAdapter +from urllib3.util.retry import Retry +from math import ceil + +parser = argparse.ArgumentParser(description="Given a txt file with uniprot ids, post a query to uniprot\n" + "and retrieve a SwissProt file with the result", + formatter_class=argparse.RawTextHelpFormatter) +optionalArgs = parser._action_groups.pop() + +requiredArgs = parser.add_argument_group("required arguments") +requiredArgs.add_argument('-i', '--input_txt', + dest='input_txt', + type=str, + required=True, + help="The output of get_uniprot_ids.txt") +requiredArgs.add_argument('-o', '--output_sp', + dest='output_sp', + help="A txt file containing the raw uniprot ids from intact.\n" + "This will contain the response from uniprot", + required=True) +optionalArgs.add_argument("--filter-list", + action="store_true", + help="Specify if the input list needs to be filtered for PROs and isoforms", + dest="filter", + required=False) +optionalArgs.add_argument("--output-fmt", + default='txt', + type=str, + dest="output_fmt", + help="SwissProt records are returned by default.\n" + "You can use 'fasta' for retrieving only the fasta sequences\n" + "or 'xml' to get an xml file.") + +parser._action_groups.append(optionalArgs) + + +def get_primary_ids_from_list(uniprot_ids): + """ + Filter the input list to return only querrable primary ids. + As primary I define the part of the identifier's string that + serves as the accession. + E.g. P04591 is a primary id + P04591-PRO_0000038593 is a PTM chain with primary id P04591 + P04591-3 is an isoform with primary id P04591 + + This will only return P04591. The rest may be retrieved from parsing + the results of the query. + + :param uniprot_ids: A list of uniprot specific identifiers + :return: primary_ids: A list of primary ids + """ + primary_ids = [] + for i in uniprot_ids: + if '-' in i: + fields = i.split('-') + primary_id = fields[0] + # Append the primary id to the query list + if primary_id not in primary_ids: + primary_ids.append(primary_id) + + # Make a mapping of primary ids to their pro-chains and isoforms + elif i not in primary_ids: + primary_ids.append(i) + + # Double check I didn't put any duplicates in the final list + if len(primary_ids) != len(set(primary_ids)): + primary_ids = list(set(primary_ids)) + + return primary_ids + +# SHAMELESSLY STOLEN FROM +# https://www.ebi.ac.uk/training/online/sites/ebi.ac.uk.training.online/files/UniProt_programmatically_py3.pdf +# and modified to default to text + +def map_retrieve(ids2map, source_fmt='ACC+ID', target_fmt='ACC', output_fmt='txt'): + """ + Map database identifiers from/to UniProt accessions. + The mapping is achieved using the RESTful mapping service provided by UniProt. + While a great many identifiers can be mapped the documentation + has to be consulted to check which options there are and what the database codes are. + Mapping UniProt to UniProt effectlvely allows batch retrieval + of entries. 
+ Args: + ids2map (list or string): identifiers to be mapped + source_fmt (str, optional): format of identifiers to be mapped. Defaults to ACC+ID, which are UniProt accessions or IDs. + target_fmt (str, optional): desired identifier format. Defaults to ACC, which is UniProt accessions. + output_fmt (str, optional): return format of data. Defaults to list. + Returns: + mapped information (str) + """ + + BASE = 'http://www.uniprot.org' + KB_ENDPOINT = '/uniprot/' + TOOL_ENDPOINT = '/uploadlists/' + + if hasattr(ids2map, 'pop'): + ids2map = ' '.join(ids2map) + + payload = {'from': source_fmt, + 'to': target_fmt, + 'format': output_fmt, + 'query': ids2map, + 'include': 'yes', # This is to include isoforms + } + + with requests.Session() as s: + retries = Retry(total=3, + backoff_factor=0.5, + status_forcelist=[500, 502, 503, 504]) + + s.mount('http://', HTTPAdapter(max_retries=retries)) + # s.mount('https://', HTTPAdapter(max_retries=retries)) + + response = s.get(BASE + TOOL_ENDPOINT, params=payload) + + if response.ok: + return response.text + else: + print(response.url) + response.raise_for_status() + + +if __name__ == '__main__': + args = parser.parse_args() + + with open(args.input_txt, 'r') as fin: + uniprots_list = [line.strip() for line in fin] + print("Input file contains {} records".format(len(uniprots_list))) + + if args.filter: + print("Filtering out PROs and isoforms") + uniprots_list = get_primary_ids_from_list(uniprots_list) + print("Final list contains {} primary ids".format(len(uniprots_list))) + + print("Posting queries of 100 to uniprot") + final_result = [] + total_batches = ceil(len(uniprots_list) / 100) + batch_no = 0 + for i in range(0, len(uniprots_list), 100): + batch_no += 1 + slise = uniprots_list[i:i + 100] + # Run the query for each batch + print("Running batch {}/{}".format(batch_no, total_batches)) + result = map_retrieve(slise, output_fmt=args.output_fmt) + # And store the result in a list + final_result.append(result) + + with open(args.output_sp, 'w') as fout: + fout.write(''.join(final_result)) diff --git a/workflow/scripts/refseqs_to_pvogs.py b/workflow/scripts/refseqs_to_pvogs.py new file mode 100755 index 0000000000000000000000000000000000000000..d78a9f202200ede11c11322e30ec12043b07655e --- /dev/null +++ b/workflow/scripts/refseqs_to_pvogs.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python + +from Bio import SearchIO, SeqIO +from pathlib import Path +import argparse + + +def parse_args(): + parser = argparse.ArgumentParser(description='From uniprot ids to pvogs') + + optionalArgs = parser._action_groups.pop() + + requiredArgs = parser.add_argument_group("required arguments") + requiredArgs.add_argument('-i', '--interactions-file', + dest='int_file', + type=lambda p: Path(p).resolve(strict=True), + required=True, + help="A 2-column tsv file containing interaction data, with refseq ids at" + ) + requiredArgs.add_argument('-f', '--input-fasta', + dest='fasta_in', + type=lambda p: Path(p).resolve(strict=True), + required=True, + help="A fasta file with protein sequences retrieved from uniprot") + requiredArgs.add_argument('-hmm', '--input-hmm', + dest='hmmer_tblout', + type=lambda p: Path(p).resolve(strict=True), + required=True, + help="The tblout file from hmmsearching the proteins in <input_fasta> " + "against all pvogs") + requiredArgs.add_argument('-o', '--output-file', + dest='outfile', + required=True, + type=lambda p: Path(p).resolve(), + help="File path to write the results in") + + parser._action_groups.append(optionalArgs) + + return parser.parse_args() + 
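# Example invocation (a sketch for illustration only; the file names below are
# hypothetical placeholders, not the paths produced by the workflow rules):
#
#   ./refseqs_to_pvogs.py \
#       -i refseq_interactions.tsv \
#       -f proteins.faa \
#       -hmm proteins_vs_pvogs.tblout \
#       -o pvog_interactions.tsv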
+ +# This is higly specific for my case... +remapper = {'P17312-2': 'P17312', + 'P17312-3': 'P17312', + 'P17312-4': 'P17312', + 'P03705': 'P03705-2', + 'P03711-2': 'P03711', + 'P03692-2': 'P03692'} + + +def get_interactions_info(interactions_fp): + interactions = [] + with open(interactions_fp, 'r') as fin: + for line in fin: + fields = line.split('\t') + # Specific for get_uniprot_ids.py + intA, intB = fields[0].strip(), fields[1].strip() + interaction = tuple(sorted([intA, intB])) + interactions.append(interaction) + return interactions + + +def get_unique_interactor_ids(interactions_list): + interactor_ids = [] + for i in interactions_list: + interactor_ids.extend([i[0], i[1]]) + return set(interactor_ids) + + +def get_ids_from_fasta_file(fasta_in): + record_ids = [] + with open(fasta_in, 'r') as fin: + for record in SeqIO.parse(fin, "fasta"): + record_ids.append(record.id) + return record_ids + + +def parse_hmm_table(hmm_table): + """ + Create a dictionary that stores all results per protein + """ + hmm_results = {} + with open(hmm_table, 'r') as fin: + for record in SearchIO.parse(fin, "hmmer3-tab"): + for hit in record.hits: + prot_acc = hit.id + target = remapper.get(prot_acc, prot_acc) + if target not in hmm_results: + hmm_results[target] = [(record.id, hit.evalue, hit.bitscore)] + else: + hmm_results[target].append((record.id, hit.evalue, hit.bitscore)) + return hmm_results + + +def translate_proteins_to_pvogs(hmm_results_dic): + + protein_to_pvog = {} + for p in hmm_results_dic: + results = hmm_results_dic[p] + if len(results) == 1: + best_result = results[0] + else: + # Select the best scoring pvog + scores = [r[2] for r in results] + best_index = scores.index(max(scores)) + best_result = results[best_index] + + protein_to_pvog[p] = best_result[0] + + return protein_to_pvog + + +def main(): + args = parse_args() + + interactions = get_interactions_info(args.int_file) + print("Parsed {} interactions".format(len(interactions))) + + unique_ids = get_unique_interactor_ids(interactions) + print("Unique ids: {}".format(len(unique_ids))) + + fasta_ids = get_ids_from_fasta_file(args.fasta_in) + print("Entries in fasta: {}".format(len(fasta_ids))) + + hmm_results = parse_hmm_table(args.hmmer_tblout) + print("Proteins with hmmer hits: {}".format(len(list(hmm_results.keys())))) + + protein_pvog_map = translate_proteins_to_pvogs(hmm_results) + + no_results = [] + counter = 0 + with open(args.outfile, 'w') as fout: + for inter in interactions: + intA, intB = inter[0], inter[1] + if intA in protein_pvog_map: + if intB in protein_pvog_map: + fout.write("{}\t{}\t{}\t{}\n".format(intA, protein_pvog_map[intA], + intB, protein_pvog_map[intB])) + counter += 1 + else: + no_results.append(intB) + else: + no_results.append(intA) + + print("Interactions translated to pvogs: {}".format(counter)) + print("Proteins with no hits: {}".format(len(set(no_results)))) + + if len(no_results) > 0: + print(20 * '=') + for i in no_results: + print(i) + print(20 * '=') + + print("Results written in {}".format(args.outfile.__str__())) + + +if __name__ == '__main__': + main() diff --git a/workflow/scripts/remove_empty_files.py b/workflow/scripts/remove_empty_files.py new file mode 100755 index 0000000000000000000000000000000000000000..57eed45d72de25b67f0fab9e9c9744d9f2a4c91f --- /dev/null +++ b/workflow/scripts/remove_empty_files.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python + +import argparse +import pathlib +import os + +parser = argparse.ArgumentParser(description="Remove all files associated with a genome that are 
empty", + formatter_class=argparse.RawTextHelpFormatter) +optionalArgs = parser._action_groups.pop() + +requiredArgs = parser.add_argument_group("required arguments") +requiredArgs.add_argument('-i', '--inut-dir', + dest='input_dir', + required=True, + help="Input genes dir that contains the results of `comparem call_genes`") +requiredArgs.add_argument('-o', '--output-txt', + dest='output_txt', + type=str, + required=True, + help="Output txt to write result. Will contain the genomes removed, if any,\n" + "else it will have a 'All genomes passed!' so that is not empty") + +def check_file_is_empty(fin): + """ + + :param fin: A pathlib.Path object + :return: the file if true, nothing if false + """ + if fin.stat().st_size == 0: + return fin + else: + pass + +if __name__ == '__main__': + args = parser.parse_args() + + input_path = pathlib.Path(args.input_dir) + remove_these = [] + for f in list(input_path.glob('*.faa')): + if check_file_is_empty(f): + remove_these.append(f) + + with open(args.output_txt, 'w') as fout: + if len(remove_these) != 0: + for f in remove_these: + print('Removing file {}'.format(f.as_posix())) + f.unlink() + fout.write('{}\n'.format(f.as_posix())) + else: + fout.write("All genomes passed!\n") diff --git a/workflow/scripts/split_multifasta.py b/workflow/scripts/split_multifasta.py new file mode 100755 index 0000000000000000000000000000000000000000..8f09d42fc12613d36745dcd8587d12a8dcf1cf5e --- /dev/null +++ b/workflow/scripts/split_multifasta.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python + +import argparse +import pathlib +import gzip +from Bio import SeqIO + +def parse_args(): + parser = argparse.ArgumentParser( + description="Split a multi-fasta to single files per sequence" + ) + optionalArgs = parser._action_groups.pop() + + requiredArgs = parser.add_argument_group("required arguments") + + requiredArgs.add_argument( + "-i", + "--input", + dest="input_fp", + required=True, + help="Input fasta. Can be gz", + ) + requiredArgs.add_argument( + "-o", + "--outdir", + dest="out_dir", + required=True, + help="A directory to store the files. 
It is NOT created", + ) + optionalArgs.add_argument( + "--write-reflist", + action="store_true", + required=False, + dest="write_reflist", + help="Write a file that contains all paths to the output fasta files, " + "one per line, in the parent directory " + ) + + parser._action_groups.append(optionalArgs) + + return parser.parse_args() + +def is_gz(path_string): + """ + Return true if gzipped file + :param path: path to file + :return: boolean + """ + return path_string.endswith(".gz") or path_string.endswith(".z") + + +def optionally_compressed_handle(path, mode): + """ + Return a file handle that is optionally gzip compressed + :param path: path + :param mode: mode + :return: handle + """ + if mode == "r" or mode == "rb": + mode = "rt" + if mode == "w" or mode == "wb": + mode = "wt" + if is_gz(path): + return gzip.open(path, mode=mode) + else: + return open(path, mode=mode) + +def split_multifasta(input_fp, output_dir, write_reflist=False): + record_no = 0 + filenames = [] + with optionally_compressed_handle(str(input_fp), 'r') as fin: + for record in SeqIO.parse(fin, "fasta"): + genome_acc = record.id + single_fasta = "{}.fasta".format(genome_acc) + single_fasta_fp = output_dir.joinpath(single_fasta) + with open(single_fasta_fp, 'w') as fout: + record_no += SeqIO.write(record, fout, "fasta") + filenames.append(single_fasta_fp) + + if record_no % 10000 == 0: + print("processed {} records".format(record_no)) + if write_reflist: + reflist_txt = output_dir.parent.joinpath("reflist.txt") + with open(reflist_txt, 'w') as refout: + for f in filenames: + refout.write('{}\n'.format(f)) + + return record_no + +def main(): + args = parse_args() + a = split_multifasta(pathlib.Path(args.input_fp), + pathlib.Path(args.out_dir), + args.write_reflist) + + +if __name__ == '__main__': + main() diff --git a/workflow/scripts/subset_scores.py b/workflow/scripts/subset_scores.py new file mode 100755 index 0000000000000000000000000000000000000000..e6cf4a05bb813848fe1a8b1bb8f48ddaf05e05a7 --- /dev/null +++ b/workflow/scripts/subset_scores.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python + +import argparse +from pathlib import Path + +def parse_args(): + parser = argparse.ArgumentParser(description='Subset the scores to a given list of interactions') + + optionalArgs = parser._action_groups.pop() + + requiredArgs = parser.add_argument_group("required arguments") + requiredArgs.add_argument('-s', '--scores-file', + dest='scores_file', + type=lambda p: Path(p).resolve(strict=True), + required=True, + help="The output of calculate_all_scores.py" + ) + requiredArgs.add_argument('-i', '--input-ints', + dest='input_ints', + type=lambda p: Path(p).resolve(strict=True), + required=True, + help="A tsv file containing interaction mappings between ncbi and pvogs") + requiredArgs.add_argument('-o', '--output-file', + dest='outfile', + required=True, + type=lambda p: Path(p).resolve(), + help="File path to write the results in") + + parser._action_groups.append(optionalArgs) + + return parser.parse_args() + +def parse_pvogs_interactions(interactions_fp): + interactions = [] + uniques = 0 + all_in = 0 + with open(interactions_fp, 'r') as fp: + for line in fp: + all_in += 1 + fields = line.split('\t') + interaction = tuple(sorted([fields[1].strip(), fields[3].strip()])) + if interaction[0] == interaction[1]: + print("Skipping: {} (self)".format(interaction)) + + elif interaction not in interactions: + interactions.append(interaction) + uniques +=1 + else: + print("Skipping: {} (duplicate)".format(interaction)) + print("Parsed {} / 
{} total input interactions".format(uniques, all_in)) + return interactions + +def subset_all_scores(all_scores_fp, interactions_list, interactions_scores_fp): + counter = 0 + with open(all_scores_fp, 'r') as fin, open(interactions_scores_fp, 'w') as fout: + header = fin.readline() + fout.write(header) + for line in fin: + fields = line.split('\t') + interaction = tuple(sorted([fields[0].strip(), fields[1].strip()])) + if interaction in interactions_list: + fout.write(line) + counter += 1 + return counter + +def main(): + args = parse_args() + ints_list = parse_pvogs_interactions(args.input_ints) + total = subset_all_scores(args.scores_file, ints_list, args.outfile) + print("{} interaction scores written to file {}".format(total, str(args.outfile))) + +if __name__ == '__main__': + main() + diff --git a/workflow/scripts/summarize_intact.py b/workflow/scripts/summarize_intact.py new file mode 100755 index 0000000000000000000000000000000000000000..ff2406790d365a016dfe2c5feaedcaeaa84f30b8 --- /dev/null +++ b/workflow/scripts/summarize_intact.py @@ -0,0 +1,426 @@ +#!/usr/bin/env python + +import argparse +from ete3 import NCBITaxa, TextFace, TreeStyle, faces +from collections import namedtuple +import pandas as pd +from Bio import SwissProt +from pathlib import Path + + +parser = argparse.ArgumentParser(description="A script that parses IntAct.txt (filtered for viruses) " + "to produce a visual summary of the taxa represented in it\n", + formatter_class=argparse.RawTextHelpFormatter) +optionalArgs = parser._action_groups.pop() + +requiredArgs = parser.add_argument_group("required arguments") +requiredArgs.add_argument('-o', '--output_dir', + dest='output_dir', + type=str, + required=True, + help="The output directory where all resulting files will be stored. " + "It is created if it's not there") +requiredArgs.add_argument('-i', '--intact_in', + dest='intact_in', + help="A txt file containing the filtered IntAct data for viruses", + required=True) +optionalArgs.add_argument("--tax-db", + dest="tax_db", + required=False, + help="Path to taxa.sqlite for ete3" + ) + +parser._action_groups.append(optionalArgs) + + +## Functions for parsing IntAct fields + +def parse_field(field_string): + """ + Args: + field_string (str) : A string of the form 'xref:value(desc)'. + Desc is optional + """ + field_data = {'xref': None} + + parts = field_string.split(":") + xref = parts[0] + field_data['xref'] = {xref: {'value': None, 'desc': None}} + try: + rest = parts[1] + if "(" in rest: + subfields = rest.split("(") + field_data['xref'][xref]['value'] = subfields[0] + field_data['xref'][xref]['desc'] = subfields[1][:-1] + else: + field_data['xref'][xref]['value'] = rest + field_data['xref'][xref]['desc'] = '-' + except: + print(field_data) + return field_data + + +def parse_column_string(astring): + """ + Args: + astring (str): A full string after splitting the PSI-MITAB on tabs + Returns: + column_data (dict): A dictionary of the form {'xref': + {xref: + {'value': value, + 'desc': desc} + } + } + """ + column_data = {} + if "|" in astring: + fields = astring.split("|") + else: + fields = [astring] # LAAAAZY + + for f in fields: + result = parse_field(f) + k = list(result['xref'].keys())[0] + if 'xref' not in column_data: + column_data['xref'] = {k: result['xref'][k]} + else: + column_data['xref'][k] = result['xref'][k] + return column_data + + +def get_single_value_from_xref_dic(xref_dic): + """ + If the given dictionary contains only one entry, get its value, + independent of key. 
+ + Args: + xref_dic (dict): A nested dictionary of the form {'xref': + {xref: + {'value': value, + 'desc': desc} + } + } + Returns: + value (str): The value of 'value' the innermost dict + """ + if xref_dic: + if len(xref_dic['xref'].keys()) > 1: + return None + else: + k = next(iter(xref_dic['xref'])) + value = xref_dic['xref'][k]['value'] + else: + value = '-' + return value + + +def parse_intact_data(int_file, skip_header=True): + ProtInfo = namedtuple('ProtInfo', ['id', 'srcdb', 'taxid']) + interactions_data = {} + + with open(int_file, 'r') as f: + if skip_header: + next(f) + for line in f: + # Make some parsable fields + columns = [column.strip() for column in line.split('\t')] + + # Get the proteinA info + protA_data = parse_column_string(columns[0]) + + protA_id = get_single_value_from_xref_dic(protA_data) + sourceA_db = list(protA_data['xref'].keys())[0] + taxA_data = parse_column_string(columns[9]) + taxidA = int(taxA_data['xref']['taxid']['value']) + protAraw = ProtInfo(id=protA_id, srcdb=sourceA_db, taxid=taxidA) + + # Get the proteinB info + protB_data = parse_column_string(columns[1]) + protB_id = get_single_value_from_xref_dic(protB_data) + sourceB_db = list(protB_data['xref'].keys())[0] + taxB_data = parse_column_string(columns[10]) + taxidB = int(taxB_data['xref']['taxid']['value']) + protBraw = ProtInfo(id=protB_id, srcdb=sourceB_db, taxid=taxidB) + + + # Sorting on interactors to make unique keys for the same interaction + # otherwise (A,B) != (B,A) + interaction = tuple(sorted((protA_id, protB_id,))) + + if interaction[0] == protAraw.id: + protAinfo, protBinfo = protAraw, protBraw + else: + protAinfo, protBinfo = protBraw, protAraw + + if protAinfo.taxid == protBinfo.taxid: + same_taxid = 1 + else: + same_taxid = 0 + + if protAinfo.id == protBinfo.id: + same_protein = 1 + else: + same_protein = 0 + + if str(interaction) not in interactions_data: + interactions_data[str(interaction)] = {'no_of_evidence': 1, + 'prot_A': protAinfo.id, + 'prot_B': protBinfo.id, + 'source_A': protAinfo.srcdb, + 'source_B': protBinfo.srcdb, + 'taxid_A': protAinfo.taxid, + 'taxid_B': protBinfo.taxid, + 'same_taxid': same_taxid, + 'same_protein': same_protein} + + elif str(interaction) in interactions_data: + interactions_data[str(interaction)]['no_of_evidence'] += 1 + + return interactions_data + + + + + +def write_response_text_to_file(response_fn, response_text): + with open(response_fn, 'w') as fh: + fh.write(response_text) + + +def sync_query_list_with_response(response_fn, query_list): + db_data = {} + with open(response_fn, 'r') as fh: + for record in SwissProt.parse(fh): + acc = record.accessions[0] + # Select only EMBL and RefSeq crossrefs + refseq_refs, embl_refs = [], [] + for db_ref in record.cross_references: + if db_ref[0] == 'RefSeq': + refseq_refs.append(db_ref[1:]) + elif db_ref[0] == 'EMBL': + embl_refs.append(db_ref[1:]) + db_data[acc] = {'RefSeq': refseq_refs, + 'EMBL': embl_refs} + + # This is to handle isoforms + # E.g. 
P03692 and P03692-1 can both be included in the query list + # P03705-2 can be in the query list but not P03705 + for prot in query_list: # For each of the original queries + if (prot not in db_data) and ('-' in prot): # If the query is not returned + base_name = prot.split('-')[0] # Search for the fist part + if base_name in db_data: # If it is present in the db_data + if base_name not in query_list: # If it's not in the original query list + db_data[prot] = db_data[base_name] # Fill in the information for the corresponding protein + db_data.pop(base_name) # AND remove the original part + elif base_name in query_list: # If the first part is in the original query + db_data[prot] = db_data[base_name] # Fill in the information WITHOUT removing the original part + elif prot not in db_data: + print("I don't know what to do with this id: {}".format(prot)) + pass + + return db_data + + +def uniprots_list_to_query(uniprots_list): + query_data = {} + for i in uniprots_list: + primary_acc = i.split('-')[0] + query_data[i] = primary_acc + return query_data + + +def iso_is_present(k, values): + ind = None + for i, v in enumerate(values): + for m in v: + if '[{}]'.format(k) in m: + ind = int(i) + return ind + + +def select_embl_info(k, values): + ind = None + for i, v in enumerate(values): + if 'ALT' in v[2]: + pass + else: + ind = int(i) + return ind + + +def get_refseq_info(prot_id, refs): + if len(refs) == 1: + refseq_genome = refs[0][1].split()[0].strip('.') + refseq_protein = refs[0][0] + elif ('-' in k) and (len(refs) > 1): + isoform_index = iso_is_present(prot_id, refs) + if isoform_index is not None: + refseq_genome = refs[isoform_index][1].split()[0].strip('.') + refseq_protein = refs[isoform_index][0] + else: + refseq_genome = refs[0][1].split()[0].strip('.') + refseq_protein = refs[0][0] + else: + refseq_genome = refs[0][1].split()[0].strip('.') + refseq_protein = refs[0][0] + return refseq_genome, refseq_protein + + +def get_embl_info(prot_id, refs): + if len(refs) == 1: + embl_genome = refs[0][0] + embl_protein = refs[0][1] + elif len(refs) > 1: + info_index = select_embl_info(prot_id, refs) + if info_index is not None: + embl_genome = refs[info_index][0] + embl_protein = refs[info_index][1] + else: + embl_genome = refs[0][0] + embl_protein = refs[0][1] + else: + embl_genome = refs[0][0] + embl_protein = refs[0][1] + return embl_genome, embl_protein + + +def map_proteins_to_genomic_accessions(db_data): + accessions_map = {} + for k in db_data: + if db_data[k]['RefSeq']: + refs = db_data[k]['RefSeq'] + (genome_id, protein_id) = get_refseq_info(k, refs) + accessions_map[k] = {'genome_id': genome_id, + 'protein_id': protein_id} + elif not db_data[k]['RefSeq'] and db_data[k]['EMBL']: + refs = db_data[k]['EMBL'] + (genome_id, protein_id) = get_embl_info(k, refs) + accessions_map[k] = {'genome_id': genome_id, + 'protein_id': protein_id} + return accessions_map + +# ETE plotting options and functions +def layout(node): + faces.add_face_to_node(TextFace(node.sci_name), node, 0) + + +def plot_taxids(taxids_list, tree_png, tree_nw, tax_db=None): + if tax_db is not None: + ncbi = NCBITaxa(dbfile=tax_db) + else: + ncbi=NCBITaxa() + + tree = ncbi.get_topology(taxids_list) + ts = TreeStyle() + ncbi.annotate_tree(tree, taxid_attr="sci_name") + ts.show_leaf_name = False + ts.mode = "c" + ts.layout_fn = layout + tree.render(tree_png, tree_style=ts) + tree.write(format=1, outfile=tree_nw) + + +def write_sequence_dic_to_file(accession_map, fileout): + header_string = '\t'.join(['uniprot_id', 'genome_id', 
'protein_id']) + with open(fileout, 'w') as f: + f.write(header_string+'\n') + for acc in accession_map: + info_string = '\t'.join([acc, accession_map[acc]['genome_id'], accession_map[acc]['protein_id']]) + f.write(info_string + '\n') + + +if __name__ == '__main__': + args = parser.parse_args() + + # Create the output directory if it doesn't exist + output_dir = Path(args.output_dir) + if not output_dir.exists(): + output_dir.mkdir() + else: + print("Output directory {} already exists! Contents will be overwritten". + format(args.output_dir)) + + plot_dir = output_dir.joinpath('plots') + plot_dir.mkdir(exist_ok=True) + tree_png = plot_dir.joinpath('taxa_tree.png') + tree_nw = plot_dir.joinpath('taxa_tree.nw') + + # Get the intact file + intact_raw = args.intact_in + intact_data = parse_intact_data(intact_raw) + + # Calculate the number of initial entries + no_of_entries = 0 + for k in intact_data.keys(): + no_of_entries += intact_data[k]['no_of_evidence'] + + print("No. of entries in input: {}".format(no_of_entries)) + print("No. of interactions: {}".format(len(intact_data.keys()))) + + # Put the data in a data frame + df = pd.DataFrame.from_dict(intact_data, orient='index') + df.index.name = 'interaction' + metadata_tsv = output_dir.joinpath('metadata.tsv') + df.to_csv(metadata_tsv, sep='\t') + + # Interactors information + unique_protAs = set(df['prot_A']) + unique_protBs = set(df['prot_B']) + unique_prots = len(set.union(unique_protAs, unique_protBs)) + print("No. of unique proteins: {}".format(unique_prots)) + + # Taxids information + unique_taxidsA = set(df['taxid_A']) + unique_taxidsB = set(df['taxid_B']) + unique_taxids = set.union(unique_taxidsA, unique_taxidsB) + print("No. of unique taxids: {}".format(len(unique_taxids))) + + print("Plotting NCBI taxonomy information. Results are stored in: {}".format(plot_dir)) + plot_taxids(unique_taxids, tree_png.as_posix(), tree_nw.as_posix(), args.tax_db) + + ###################################### + ### UP TO HERE EVERYTHING IS RAW + ### NOW WE START SUBSETTING + ####################################### + # Getting the uniprot related entries to dispatch a query + # protAs = df.loc[(df['source_A'] == 'uniprotkb')]['prot_A'].unique().tolist() + # protBs = df.loc[(df['source_B'] == 'uniprotkb')]['prot_B'].unique().tolist() + # uniprots = set.union(set(protAs), set(protBs)) + # uniprots_list = list(uniprots) + # + # print("{} identifiers will be searched against uniprot for genome and protein sequence retrieval". 
+ # format(len(uniprots_list))) + # + # uniprot_response = output_dir.joinpath('uniprot_info.txt') + # print("Dispatching queries...") + # final_result = [] + # # Split the query into batches of 100 + # for i in range(0, len(uniprots_list), 100): + # slise = uniprots_list[i:i + 100] + # # Run the query for each batch + # result = map_retrieve(slise) + # # And store the result in a list + # final_result.append(result) + # + # # Store the result in a file + # response_text = ''.join(final_result) + # write_response_text_to_file(uniprot_response, response_text) + # + # db_data = sync_query_list_with_response(uniprot_response, uniprots_list) + # # For each original entry in the protein list, + # # store its primary accession + # # TO DO + # # the sync_query_list_with_response function needs revisiting + # query_data = uniprots_list_to_query(uniprots_list) + # + # accession_map = map_proteins_to_genomic_accessions(db_data) + # for protein in query_data: # {alt_acc : primary_acc} + # if protein not in accession_map: # alt_acc not in results + # try: + # primary_acc = query_data[protein] # get the primary accession + # accession_map[protein] = accession_map[primary_acc] + # except KeyError: + # accession_map[protein] = {'genome_id': '-', 'protein_id': '-'} + # + # sequence_info = output_dir.joinpath('sequence_info.txt') + # write_sequence_dic_to_file(accession_map, sequence_info)
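For readers who want to sanity-check the PSI-MITAB helpers in workflow/scripts/summarize_intact.py, the short sketch below shows what parse_field and parse_column_string return for typical IntAct column values. It assumes the module can be imported (i.e. workflow/scripts is on PYTHONPATH and ete3, pandas and biopython are installed); the accession and taxid values are illustrative placeholders.

    from summarize_intact import parse_field, parse_column_string

    # A single cross-reference of the form 'xref:value(desc)'; the parenthesised
    # description is optional and defaults to '-'
    print(parse_field("taxid:10665(Escherichia virus T4)"))
    # {'xref': {'taxid': {'value': '10665', 'desc': 'Escherichia virus T4'}}}

    # A '|'-separated column yields one entry per cross-reference database
    print(parse_column_string("uniprotkb:P04591|intact:EBI-1036907"))
    # {'xref': {'uniprotkb': {'value': 'P04591', 'desc': '-'},
    #           'intact': {'value': 'EBI-1036907', 'desc': '-'}}}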