diff --git a/workflow/Snakefile b/workflow/Snakefile new file mode 100644 index 0000000000000000000000000000000000000000..3212bff6ce99abe8d19cc9e40e732b8e43eec4a6 --- /dev/null +++ b/workflow/Snakefile @@ -0,0 +1,102 @@ +configfile: "config/config.yaml" + +# This defines the number of negatives +NEGATIVES = config["negatives"] + +# This defines the names of the negatives +DATASETS = ['N{}'.format(i+1) for i in range(0,NEGATIVES)] +# Append the 'positives'. Now a rule can expand on all datasets +DATASETS.append('positives') + + +include: "rules/get_data.smk" +include: "rules/pre_process.smk" +include: "rules/construct_datasets.smk" + + +rule all: + input: + # Raw data required + # Produced by get_data.smk + "data/genomes/phages_refseq.fasta", + "data/interactions/intact.txt", + "data/pvogs/all.hmm", + "data/pvogs/VOGProteinTable.txt", + "data/taxonomy_db/taxa.sqlite", + "data/taxonomy_db/taxa.sqlite.traverse.pkl", + + # Pre-processing for calculating matrices + # Produced by pre_process.smk + "results/scores.tsv", + "results/filtered_scores.tsv", + "results/annotations.tsv", + + # Getting to pvogs from proteins and calcualating features + # Produced by pre_process.smk + expand(["results/interaction_datasets/{dataset}/{dataset}.interactions.tsv", + "results/interaction_datasets/{dataset}/{dataset}.proteins.faa", + "results/interaction_datasets/{dataset}/{dataset}.pvogs_interactions.tsv", + "results/interaction_datasets/{dataset}/{dataset}.features.tsv"], + dataset=DATASETS), + + # Major results + # Produced by this file. + "results/RF/best_model.pkl", + "results/RF/best_model_id.txt", + "results/predictions.tsv", + "results/final_training_set.tsv" + + + +checkpoint random_forest: + input: + expand("results/interaction_datasets/{dataset}/{dataset}.features.tsv", dataset=DATASETS), + filtered_master_tsv = rules.filter_scores_table.output.filtered_master_tsv + output: + best_model = "results/RF/best_model.pkl", + best_model_id = "results/RF/best_model_id.txt" # This only contains the name as a string... + log: + "results/logs/processed_notebook.py.ipynb" + conda: + "envs/pvogs_jupy.yml" + threads: 16 + notebook: + "notebooks/data_processing.py.ipynb" + + +def get_best_model_id(wildcards): + """ + Helper function to get the id of the dataset that gave the + best results from the notebook search. 
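As an aside on the target expansion at the top of this Snakefile: the sketch below is standalone Python (not part of the workflow) showing how DATASETS resolves into the per-dataset feature tables that `rule all` and the `random_forest` checkpoint collect. The value `negatives: 3` is an assumption for illustration; the real number comes from config/config.yaml.

```python
# Standalone sketch, assuming snakemake is importable and config["negatives"] == 3.
from snakemake.io import expand

NEGATIVES = 3
DATASETS = ['N{}'.format(i + 1) for i in range(NEGATIVES)]
DATASETS.append('positives')
# DATASETS == ['N1', 'N2', 'N3', 'positives']

targets = expand(
    "results/interaction_datasets/{dataset}/{dataset}.features.tsv",
    dataset=DATASETS,
)
# -> one features.tsv path per negative set plus one for the positives,
#    which is exactly what rule all and the checkpoint take as input.
```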
+ """ + checkpoint_output = checkpoints.random_forest.get(**wildcards).output[1] + with open(checkpoint_output, 'r') as fin: + best_model_id = fin.read().strip() + return best_model_id + +rule predict: + input: + positives_features_tsv = "results/interaction_datasets/positives/positives.features.tsv", + model_fp = "results/RF/best_model.pkl", + filtered_scores_tsv = rules.filter_scores_table.output.filtered_master_tsv + output: + predictions_tsv = "results/predictions.tsv", + final_training_set_tsv = "results/final_training_set.tsv" + log: + "results/logs/predict.log" + conda: + "envs/pvogs_jupy.yml" + params: + # This is read from the notebook output + dataset_string = get_best_model_id, + threads: 16 + shell: + "python workflow/scripts/predict.py " + "-j {threads} " + "-m {input.model_fp} " + "-p {input.positives_features_tsv} " + "-n results/interaction_datasets/{params.dataset_string}/{params.dataset_string}.features.tsv " + "-t {input.filtered_scores_tsv} " + "-o {output.predictions_tsv} " + "&>{log}" + diff --git a/workflow/envs/pvogs.yml b/workflow/envs/pvogs.yml new file mode 100644 index 0000000000000000000000000000000000000000..04a6e5aef80c2aae205054dd08aa902ae23e5479 --- /dev/null +++ b/workflow/envs/pvogs.yml @@ -0,0 +1,12 @@ +channels: + - bioconda + - conda-forge +dependencies: + - emboss==6.6.0.0 + - biopython==1.77 + - fastani==1.3 + - comparem==0.1.1 + - pandas==1.0.3 + - ete3==3.1.1 + - requests==2.23.0 + - hmmer==3.2.1 diff --git a/workflow/envs/pvogs_jupy.yml b/workflow/envs/pvogs_jupy.yml new file mode 100644 index 0000000000000000000000000000000000000000..805b9c75f5c715c5e6a555beb5112b0adad9633f --- /dev/null +++ b/workflow/envs/pvogs_jupy.yml @@ -0,0 +1,11 @@ +channels: + - bioconda + - conda-forge +dependencies: + - python==3.7 + - jupyter + - scipy==1.4.1 + - scikit-learn==0.21.3 + - seaborn==0.10.1 + - pandas>=1.0* + diff --git a/workflow/notebooks/data_processing.py.ipynb b/workflow/notebooks/data_processing.py.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..6cc3cf7311fac6232197d0d63f634a0035f06d60 --- /dev/null +++ b/workflow/notebooks/data_processing.py.ipynb @@ -0,0 +1,1228 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "# ML\n", + "from sklearn.model_selection import train_test_split, RandomizedSearchCV\n", + "from sklearn.ensemble import RandomForestClassifier\n", + "from sklearn import metrics\n", + "\n", + "from scipy.cluster.hierarchy import dendrogram, linkage\n", + "from scipy.spatial.distance import squareform, pdist\n", + "\n", + "# Data handling\n", + "import pandas as pd\n", + "import numpy as np\n", + "\n", + "from collections import Counter\n", + "import operator\n", + "\n", + "# Plotting\n", + "%matplotlib inline\n", + "# import matplotlib.patches as mpatches\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "\n", + "# Read and write\n", + "from pathlib import Path # filesystem\n", + "\n", + "## sklearn is giving me\n", + "## sklearn/model_selection/_search.py:814: DeprecationWarning: \n", + "## The default of the `iid` parameter will change from True to False in version 0.22 \n", + "## and will be removed in 0.24. 
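One thing worth flagging in this imports cell: the model-saving cells further down call `pickle.dump(...)`, but `pickle` is not imported anywhere in the notebook, so those cells would presumably fail with a NameError. A one-line addition along these lines should be all that is needed:

```python
import pickle  # used below to serialize the per-dataset and best RandomForest models
```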
This will change numeric results when test-set sizes are unequal.\n", + "## DeprecationWarning)\n", + "\n", + "# Suppress that\n", + "import warnings\n", + "warnings.filterwarnings('ignore', category=DeprecationWarning)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "outputs": [], + "source": [ + "# Initial setup" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "outputs": [], + "source": [ + "## GLOBAL VARIABLES\n", + "\n", + "These are variables that are used throughout the notebook" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "# Set this to False if you don't want to scale the feature values\n", + "SCALE_DATA = True" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "# Only change these if you know what you are doing\n", + "\n", + "# A dictionary that maps original feature names to final ones\n", + "FEATURES_MAP = dict(zip(['jaccard_score', 'same_score', \n", + " 'inwards_score', 'outwards_score', \n", + " 'avg_distance', 'mean_ani', \n", + " 'mean_aai'],\n", + " ['Co-occurrence', 'Co-orientation',\n", + " 'Convergent', 'Divergent',\n", + " 'Average Distance', 'Mean ANI',\n", + " 'Mean AAI']\n", + " )\n", + " )\n", + "\n", + "# A list to only get the features\n", + "# Useful for slicing data frames\n", + "FEATURES = ['jaccard_score', 'same_score', \n", + " 'inwards_score', 'outwards_score', \n", + " 'avg_distance', 'mean_ani', \n", + " 'mean_aai']\n", + " \n", + "# Features with label appended\n", + "# useful for plotting based on label\n", + "FEAT_LAB = FEATURES.copy()\n", + "FEAT_LAB.append('label')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "outputs": [], + "source": [ + "## HELPER FUNCTIONS" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "outputs": [], + "source": [ + "* Reading data\n", + "\n", + "Some of these functions were written **before** filtering for max distance, so they try to address this issue." 
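A small usage sketch of how the FEATURES_MAP defined above is typically applied when preparing a frame for plotting; the toy frame and the two-entry subset of the map are only here so the sketch is self-contained.

```python
import pandas as pd

# Toy frame with two of the raw feature column names used throughout the notebook.
toy = pd.DataFrame({'jaccard_score': [0.2, 0.7], 'avg_distance': [0.4, 0.9]})

# Subset of the notebook's FEATURES_MAP, redefined here for self-containment.
FEATURES_MAP = {'jaccard_score': 'Co-occurrence', 'avg_distance': 'Average Distance'}

# Rename raw feature names to the display names used in the figures.
display_df = toy.rename(columns=FEATURES_MAP)
```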
+ ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + "def read_scores_table(scores_fp, label=None):\n", + " \"\"\"\n", + " Read a table from a tsv file provided as a path.\n", + " Return a dataframe.\n", + " \n", + " If label is set, append a column named `label` with\n", + " the provided value\n", + " \"\"\"\n", + " scores_df = pd.read_csv(scores_fp, sep='\\t')\n", + " scores_df['interaction'] = scores_df['pvog1'] + '-' + scores_df['pvog2']\n", + " # Unpack features list because I want interaction prepended.\n", + " scores_df = scores_df[['interaction', *FEATURES]] \n", + " if label is not None:\n", + " scores_df['label'] = label\n", + " return scores_df\n", + "\n", + "\n", + "def concat_data_frames(pos_df, \n", + " neg_df, \n", + " subsample=False,\n", + " clean=True,\n", + " is_scaled=True,\n", + " ):\n", + " \"\"\"\n", + " Concatenate two dataframes.\n", + " \n", + " subsample:bool \n", + " Subsample the `neg_df` to a number of\n", + " observations equal to the number of pos_df \n", + " (balance datasets)\n", + " \n", + " clean:bool\n", + " Remove observations with a value of `avg_distance` == 100000 or 1.\n", + " \n", + " is_scaled:bool\n", + " The features have been scaled to a range 0-1\n", + " \n", + " Return:\n", + " concat_df: pd.DataFrame\n", + " The concatenated data frame\n", + " \"\"\"\n", + " \n", + " if clean is True:\n", + " pos_df = remove_ambiguous(pos_df)\n", + " neg_df = remove_ambiguous(neg_df)\n", + " \n", + " n_positives = pos_df.shape[0]\n", + " n_negatives = neg_df.shape[0]\n", + " \n", + " # Remove possible duplicate interactions from the negatives\n", + " # This might happen because of the random selection when creating the set\n", + " # Why I also select more negatives to begin with\n", + " neg_df = neg_df.loc[~neg_df['interaction'].isin(pos_df['interaction'])]\n", + " \n", + " if (n_positives != n_negatives) and (subsample is True):\n", + " neg_df = neg_df.sample(n=n_positives, random_state=1)\n", + " concat_df = pd.concat([pos_df, neg_df])\n", + " \n", + " assert concat_df[concat_df.duplicated(subset=['interaction'])].empty == True, concat_df.loc[concat_df.duplicated(subset=['interaction'], keep=False)]\n", + " \n", + " return concat_df\n", + "\n", + "def scale_df(input_df):\n", + " \"\"\"\n", + " Scale all feature values in the data frame to [0-1].\n", + " \"\"\"\n", + " maxes = input_df[FEATURES].max(axis=0)\n", + " scaled_data = input_df[FEATURES].divide(maxes)\n", + " if 'label' in input_df.columns:\n", + " scaled_df = pd.concat([input_df['interaction'], scaled_data, input_df['label']], axis=1)\n", + " else:\n", + " scaled_df = pd.concat([input_df['interaction'], scaled_data], axis=1)\n", + " return scaled_df\n", + "\n", + "def remove_ambiguous(input_df):\n", + " \"\"\"\n", + " Select observations in the `input_df` that have feature values\n", + " \"\"\"\n", + " df_clean = input_df[(input_df.jaccard_score != 0)] # This is true if they don't co-occur\n", + " return df_clean\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "outputs": [], + "source": [ + "# Analysis\n", + "\n", + "From this point on all the magic happens" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [], + "source": [ + "# Get all the files in a list of Path objects\n", + "filepaths = list(map(Path, snakemake.input))\n", + "# Remove the last element\n", + "filtered_scores_tsv = filepaths.pop(-1)\n", + "negatives_fp_list = []\n", + "for fp in filepaths:\n", + " # Grab the 
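A minimal, self-contained sketch of what `scale_df` does to the feature columns (toy values). Note this is max-scaling: it maps non-negative values into [0, 1] by dividing by the column maximum, but it does not subtract the minimum as full min-max scaling would.

```python
import pandas as pd

toy = pd.DataFrame({'jaccard_score': [0.0, 2.0, 4.0],
                    'avg_distance': [10.0, 5.0, 20.0]})

# Same operation as scale_df: divide every column by its own maximum.
scaled = toy.divide(toy.max(axis=0))
# jaccard_score -> [0.0, 0.5, 1.0]; avg_distance -> [0.5, 0.25, 1.0]
```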
positives\n", + " if 'positives' in fp.name:\n", + " positives_tsv = fp\n", + " else:\n", + " # Append to the list of negatives\n", + " negatives_fp_list.append(fp)" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "# Positives are always the same\n", + "pos_df = read_scores_table(positives_tsv, label = 1)\n", + "\n", + "# Negatives are stored in the list above\n", + "# Select one for visualization\n", + "neg_df = read_scores_table(negatives_fp_list[3], label = 0)\n", + "if SCALE_DATA is True:\n", + " pos_df = scale_df(pos_df)\n", + " neg_df = scale_df(neg_df)\n", + " \n", + "posneg = concat_data_frames(pos_df, neg_df,\n", + " clean=True,\n", + " subsample=True,\n", + " is_scaled=SCALE_DATA)" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [], + "source": [ + "def plot_features_correlation(data_df):\n", + " # Subset to features only\n", + " cor_data = data_df[FEATURES]\n", + " cor_df = cor_data.corr()\n", + " \n", + " plt.figure(figsize=(10, 8))\n", + " sns.heatmap(cor_df, annot=True, \n", + " cmap=sns.color_palette(\"RdBu_r\"), \n", + " vmin=-1, \n", + " vmax=1, \n", + " cbar_kws = {\"shrink\" : 0.5, \n", + " \"ticks\" : [-1, -0.5, 0, 0.5, 1],\n", + " }\n", + " ) \n", + " plt.show()\n", + "\n", + "## THIS IS NO LONGER USED in favor of the pairplot\n", + "# you can call it with\n", + "# plot_features_dist(posneg, FEATURES, is_scaled = SCALE_DATA)\n", + "def plot_features_dist(data, features, is_scaled=False):\n", + " \"\"\"\n", + " Make a gridplot of all features distributions\n", + " \n", + " data: A pd.DataFrame with features and a label column\n", + " features: A list of features to use\n", + " \"\"\"\n", + " if is_scaled:\n", + " distance_threshold = 1\n", + " else:\n", + " distance_threshold = 1000000\n", + " \n", + " df_list=[data.loc[:,[f, 'label']] for f in features]\n", + " # Initialize a figure\n", + " fig, axes = plt.subplots(3, 3, figsize = (8,4), dpi=300)\n", + " # Flatten the axes for plotting in the correct ax object\n", + " ax_list = axes.flatten()\n", + "\n", + " # Workaround to add the legend\n", + " red_patch = mpatches.Patch(color='red', label='positives', alpha=0.5)\n", + " blue_patch = mpatches.Patch(color='b', label='negatives', alpha=0.5)\n", + " plt.figlegend(handles=[red_patch, blue_patch], loc='lower center')\n", + " # Make the plots\n", + " for i in range(len(df_list)):\n", + " df=df_list[i]\n", + " if features[i] == \"avg_distance\":\n", + " x0 = df.loc[((df.label==0) & (df.avg_distance != distance_threshold)), features[i]]\n", + " x1 = df.loc[((df.label==1) & (df.avg_distance != distance_threshold)), features[i]]\n", + " else:\n", + " x0 = df.loc[df.label==0, features[i]] \n", + " x1 = df.loc[df.label==1, features[i]]\n", + "# a1 = sns.distplot(x0, ax=axes[coords[0], coords[1]], axlabel=features[i],color='b', label='negatives')\n", + "# a2 =sns.distplot(x1, ax=axes[coords[0], coords[1]], axlabel=features[i], color='r', label='positives')\n", + " ax = ax_list[i]\n", + " ax.hist(x0, color = 'b', alpha=.5)\n", + " ax.hist(x1, color = 'r', alpha= .5)\n", + " ax.set_title(\"{} (N={}/{})\".format(features[i], \n", + " x0.shape[0], \n", + " x1.shape[0]), \n", + " fontsize=8\n", + " )\n", + "\n", + " plt.subplots_adjust(hspace=0.7)\n", + "\n", + " axes[2,2].remove()\n", + " axes[2,1].remove() " + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [], + "source": [ + "pair = sns.pairplot(posneg[FEAT_LAB],\n", + " hue = 
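The cell above builds the visualization set with `concat_data_frames`. The core of that balancing step (drop negatives that collide with a positive interaction id, then subsample the negatives down to the number of positives) can be sketched on toy data like this:

```python
import pandas as pd

pos = pd.DataFrame({'interaction': ['A-B', 'C-D'], 'label': [1, 1]})
neg = pd.DataFrame({'interaction': ['A-B', 'E-F', 'G-H', 'I-J'],
                    'label': [0, 0, 0, 0]})

# Remove negatives that duplicate a positive pair, as concat_data_frames does.
neg = neg.loc[~neg['interaction'].isin(pos['interaction'])]

# Balance the classes by subsampling the negatives (fixed seed as in the notebook).
neg = neg.sample(n=len(pos), random_state=1)

balanced = pd.concat([pos, neg])  # 2 positives + 2 negatives
```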
'label', \n", + " vars=FEATURES,\n", + " markers = [\"+\", \"x\"],\n", + " palette = {\n", + " 0 : 'b',\n", + " 1 : 'r',\n", + " },\n", + "# diag_kind=\"hist\",\n", + " plot_kws = {\n", + " \"alpha\": 0.65,\n", + " },\n", + " diag_kws={\n", + " \"clip\" : [0. , 1.],\n", + " },\n", + " )\n", + "pair.set(xlim=(-0.1, 1.1))" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "plot_features_correlation(posneg)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "outputs": [], + "source": [ + "## Cluster interactions" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "# Make a copy of the original df for plotting\n", + "d_posneg = posneg\n", + "\n", + "D_posneg = pdist(d_posneg[FEATURES])\n", + "D_posneg = D_posneg / D_posneg.max() # Scale the distances to be [0-1]\n", + "Z_posneg = linkage(D_posneg, method=\"average\")" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [], + "source": [ + "# Append a color column based on the label value\n", + "# https://stackoverflow.com/a/26887820\n", + "\n", + "def color_label_interaction(row):\n", + " if row['label'] == 1:\n", + " return 'red'\n", + " else:\n", + " return 'blue'\n", + "\n", + "d_posneg['color'] = d_posneg.apply(lambda row: color_label_interaction(row), axis=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [], + "source": [ + "plt.figure(figsize=(20, 30))\n", + "dn = dendrogram(Z_posneg,\n", + " labels=d_posneg['interaction'].values, \n", + " orientation='left',\n", + " leaf_font_size=12,\n", + " color_threshold=0.1,\n", + " above_threshold_color='black',\n", + " )\n", + "plt.axvline(x=0.1, linestyle='--')\n", + "plt.title(\"Average linkage based on Euclidean distances\", size=14)\n", + "# plt.tick_params(axis='y', labelsize=12, direction='out')\n", + "ax=plt.gca()\n", + "ylbls = ax.get_ymajorticklabels()\n", + "for lbl in ylbls:\n", + " lbl.set_color(d_posneg.loc[d_posneg.interaction == lbl.get_text(), 'color'].values[0])\n", + "plt.show()\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "outputs": [], + "source": [ + "### Distribution of distances pos vs neg" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [], + "source": [ + "D_pos = pdist(pos_df[FEATURES])\n", + "D_neg = pdist(neg_df[FEATURES])" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(figsize=(10,8))\n", + "\n", + "ax = sns.distplot(D_pos ,\n", + " hist=False,\n", + " kde_kws = {\"shade\" : True,\n", + " \"alpha\" : 0.5,\n", + " \"clip\" : [0. , 1.],\n", + " \"label\": 'Positives'\n", + " \n", + " },\n", + " color='r')\n", + "ax = sns.distplot(D_neg,\n", + " color='b',\n", + " hist=False,\n", + " kde_kws = {\"shade\" : True,\n", + " \"alpha\": 0.5,\n", + " \"clip\" : [0. 
, 1.],\n", + " \"label\": \"Negatives\"}\n", + " )\n", + "ax.set_title(\"Distances distribution\", size=14)\n", + "ax.set_xlabel(\"Distance\", size=14)\n", + "ax.set_ylabel(\"Density\", size=14)\n", + "ax.tick_params(axis='both', which='major', labelsize=12)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "outputs": [], + "source": [ + "# Random Forest" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "outputs": [], + "source": [ + "## Model selection and hyperparameter tuning" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "outputs": [], + "source": [ + "Define a search space for the parameters to be sampled from. \n", + "Less exhaustive than a full grid search, but quicker." + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [], + "source": [ + "# Number of trees\n", + "n_estimators = [int(x) for x in range(100,5000,200)]\n", + "\n", + "# Number of features to consider at every split\n", + "# 'auto' uses sqrt(n_features), a float uses int(float * n_features)\n", + "max_features = ['auto', .5, 1.]\n", + "\n", + "# Maximum number of levels in each tree\n", + "max_depth = [int(x) for x in range(10, 60, 10)]\n", + "max_depth.append(None)\n", + "\n", + "# Minimum number of samples required to split a node\n", + "min_samples_split = [2, 5, 10]\n", + "\n", + "# Minimum number of samples required at each leaf node\n", + "min_samples_leaf = [1, 2, 4]\n", + "\n", + "# Method for selecting samples for training\n", + "bootstrap = [True, False]\n", + "\n", + "# Put them all in a dictionary\n", + "params_space = {\n", + " \"max_features\": max_features,\n", + " \"n_estimators\" : n_estimators,\n", + " \"max_depth\" : max_depth,\n", + " \"min_samples_split\" : min_samples_split,\n", + " \"min_samples_leaf\" : min_samples_leaf,\n", + " \"bootstrap\": bootstrap,\n", + " }" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a color mapping for the negative sets\n", + "\n", + "negative_sets = []\n", + "for negative in negatives_fp_list:\n", + " set_name = negative.name.split('.')[0]\n", + " negative_sets.append(set_name)\n", + " \n", + "color_palette = sns.color_palette(\"muted\", len(negative_sets))\n", + "\n", + "color_dict = dict(zip(negative_sets, color_palette))" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [], + "source": [ + "def get_metric_value(metric_string, \n", + " truth_array, \n", + " prediction_array, \n", + " prediction_proba_array):\n", + " \"\"\"\n", + " Helper function to get metric values based on a string representation.\n", + " \"\"\"\n", + " if metric_string == 'accuracy':\n", + " return metrics.accuracy_score(truth_array, prediction_array)\n", + " if metric_string == 'accuracy_abs':\n", + " return metrics.accuracy_score(truth_array, prediction_array, normalize = False)\n", + " if metric_string == 'precision':\n", + " return metrics.precision_score(truth_array, prediction_array)\n", + " if metric_string == 'recall':\n", + " return metrics.recall_score(truth_array, prediction_array)\n", + " if metric_string == 'f1_score':\n", + " return metrics.f1_score(truth_array, prediction_array)\n", + " if metric_string == 'roc_auc_score':\n", + " return metrics.roc_auc_score(truth_array, prediction_proba_array[:,1])\n", + " if metric_string == 'fpr':\n", + " return metrics.roc_curve(truth_array, prediction_proba_array[:,1])[0]\n", + " if metric_string == 'tpr':\n", + " return 
metrics.roc_curve(truth_array, prediction_proba_array[:,1])[1]" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [], + "source": [ + "# A list of all the metric we are storing\n", + "_metrics = [\n", + " 'accuracy', \n", + " 'accuracy_abs', \n", + " 'f1_score', \n", + " 'precision', \n", + " 'recall', \n", + " 'roc_auc_score', \n", + " 'fpr', \n", + " 'tpr'\n", + " ]\n", + "# A dictionary that will hold all the values\n", + "all_metrics = {}\n", + "# A dictionary that holds the parameters for the best model\n", + "# after each iteration\n", + "best_models = {}\n", + "\n", + "# A dictionary that holds feature importances\n", + "# for each iteration\n", + "feat_importances = {}\n", + "\n", + "for i, negative in enumerate(negatives_fp_list):\n", + " \n", + " # Make a copy of the list with the negative_files\n", + " copy_of_list = negatives_fp_list.copy() \n", + " # Get the name of the current negative\n", + " training_set = negative.name.split('.')[0]\n", + " # Read the current negative data in a data frame\n", + " neg_df = read_scores_table(negative, label=0)\n", + " # ... scale it\n", + " neg = scale_df(neg_df)\n", + " # Concatenate with the positive\n", + " # A ground truth set is constructed\n", + " posneg = concat_data_frames(pos_df, \n", + " neg, \n", + " subsample=True, \n", + " clean=True, \n", + " is_scaled=True)\n", + " \n", + " # Split the frame to features and the label\n", + " X = posneg[FEATURES]\n", + " y = posneg['label']\n", + " \n", + " # Train test split\n", + " X_train, X_holdout, y_train, y_holdout = train_test_split(X,\n", + " y,\n", + " test_size=0.3,\n", + " random_state=1)\n", + " \n", + " # Instantiate the Randomized Search\n", + " models = RandomizedSearchCV(estimator = RandomForestClassifier(random_state=1),\n", + " param_distributions = params_space,\n", + " n_iter = 500, # number of parameter settings to sample\n", + " cv = 5, # Use 5-folds for CV\n", + " random_state = 1, # Set random state for reproducibility\n", + " n_jobs = snakemake.threads,\n", + "# verbose=1 # Print some messages to keep track of progress\n", + " )\n", + " # Execute the search\n", + " search = models.fit(X_train, y_train)\n", + " \n", + " # Make a classifier out of the best model\n", + " best_model = RandomForestClassifier(**search.best_params_, \n", + " random_state=1)\n", + " \n", + " # Fit the best model on the training data\n", + " best_model.fit(X_train, y_train)\n", + " \n", + " # Predict labels and probabilities for the test/holdout set\n", + " holdout_pred = best_model.predict(X_holdout)\n", + " holdout_proba = best_model.predict_proba(X_holdout)\n", + " \n", + " ###############\n", + " # Store results\n", + " ###############\n", + " # Create a tuple of the strings of the training set\n", + " # as a key for the all_metrics dict\n", + " self_key = (training_set, training_set)\n", + " \n", + " # The get_metric_value is defined above\n", + " all_metrics[self_key] = {metric : get_metric_value(metric,\n", + " y_holdout,\n", + " holdout_pred,\n", + " holdout_proba)\n", + " for metric in _metrics}\n", + " \n", + " best_models[training_set] = best_model\n", + " \n", + " importances = best_model.feature_importances_ \n", + " feat_importances[self_key] = dict(zip(FEATURES, importances))\n", + " \n", + " # Remove the current negative set file from the copy of the list\n", + " # The rest will be used for validation\n", + " validation_sets = [i for i in copy_of_list if i != negative]\n", + "\n", + " for validation_set in validation_sets:\n", + " # 
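For readers less familiar with the metric helpers used in this loop, here is a self-contained toy run on synthetic data (not the workflow's features) showing why `get_metric_value` passes the second column of `predict_proba` to the ROC functions: that column holds the probability of the positive class.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn import metrics

# Synthetic stand-in for a scaled feature table with 7 features.
X, y = make_classification(n_samples=200, n_features=7, random_state=1)
X_train, X_holdout, y_train, y_holdout = train_test_split(
    X, y, test_size=0.3, random_state=1)

clf = RandomForestClassifier(n_estimators=100, random_state=1).fit(X_train, y_train)
pred = clf.predict(X_holdout)
proba = clf.predict_proba(X_holdout)   # shape (n_samples, 2)

# Column 1 = probability of class 1, the score expected by the ROC metrics.
auc = metrics.roc_auc_score(y_holdout, proba[:, 1])
fpr, tpr, _ = metrics.roc_curve(y_holdout, proba[:, 1])
f1 = metrics.f1_score(y_holdout, pred)
```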
Repeat the same process as above, but without optimizing\n", + " validation_set_name = validation_set.name.split('.')[0]\n", + " negg = read_scores_table(validation_set, label=0)\n", + " scaled_negg = scale_df(negg)\n", + " \n", + " pn = concat_data_frames(pos_df, \n", + " scaled_negg, \n", + " subsample=True, \n", + " clean=True, \n", + " is_scaled=True)\n", + " XX = pn[FEATURES]\n", + " yy = pn['label']\n", + " \n", + " XX_train, XX_holdout, yy_train, yy_holdout = train_test_split(XX, yy, test_size=0.3, random_state=1)\n", + " best_model.fit(XX_train, yy_train)\n", + " holdout_pred = best_model.predict(XX_holdout)\n", + " holdout_proba = best_model.predict_proba(XX_holdout)\n", + " \n", + " \n", + " combo_key = (training_set, validation_set_name)\n", + " all_metrics[combo_key] = {metric : get_metric_value(metric, yy_holdout, holdout_pred, holdout_proba) \n", + " for metric in _metrics}\n", + " \n", + " importances = best_model.feature_importances_ \n", + " feat_importances[combo_key] = dict(zip(FEATURES, importances))\n", + " \n", + " print(\"Finished set : {}\".format(i))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "outputs": [], + "source": [ + "## Save results" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [], + "source": [ + "# Results are stored in the results/RF directory\n", + "# Create it, or do nothing\n", + "RF_dir = Path(\"results/RF\")\n", + "RF_dir.mkdir(exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [], + "source": [ + "## metrics\n", + "metrics_df = pd.DataFrame.from_dict(all_metrics, orient='index', columns=_metrics)\n", + "\n", + "# Append columns TS (Training Set) and VS (Validation Set)\n", + "# For slicing\n", + "metrics_df['TS'] = [i[0] for i in metrics_df.index.values]\n", + "metrics_df['VS'] = [i[1] for i in metrics_df.index.values]\n", + "\n", + "# Pickle is used to store the fpr, tpr values as arrays\n", + "# otherwise it messes up the column formatting\n", + "metrics_fp = RF_dir / Path('metrics.pkl')\n", + "metrics_tsv = RF_dir / Path('metrics.tsv')\n", + "metrics_df.to_pickle(metrics_fp)\n", + "## Drop the fpr and tpr to write to tsv\n", + "metrics_df.drop(columns=['fpr', \n", + " 'tpr']).to_csv(metrics_tsv, \n", + " sep='\\t', \n", + " index=False)\n", + "\n", + "## features\n", + "features_df = pd.DataFrame.from_dict(feat_importances,\n", + " orient = 'index',\n", + " columns = FEATURES)\n", + "features_df['TS'] = [i[0] for i in features_df.index.values]\n", + "features_df['VS'] = [i[1] for i in features_df.index.values]\n", + "features_fp = RF_dir / Path(\"features.tsv\")\n", + "\n", + "features_df.to_csv(features_fp, sep='\\t', index=False)\n", + "\n", + "\n", + "## models\n", + "models_dir = RF_dir / Path(\"models\")\n", + "models_dir.mkdir(exist_ok=True)\n", + "for dataset in best_models:\n", + " pkl_fname = Path('{}.RF.pkl'.format(dataset))\n", + " pkl_fpath = models_dir.joinpath(pkl_fname)\n", + " with open(pkl_fpath, 'wb') as fout:\n", + " pickle.dump(best_models[dataset], fout)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "outputs": [], + "source": [ + "## Best model across the board" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [], + "source": [ + "# Calculate min, max, mean, std for all metrics across datasets\n", + "## this probably is not the most pythonic way of doing this\n", + "\n", + "metrics_stats = {}\n", + "for i, negative in 
enumerate(negatives_fp_list):\n", + " dataset_name = negative.name.split('.')[0]\n", + " # Get the stats for a particular dataset and trnaspose it\n", + " set_df = metrics_df.loc[metrics_df[\"TS\"] == dataset_name, _metrics].describe().T\n", + " \n", + " metrics_stats[dataset_name] = {'accuracy_mean' : set_df.loc['accuracy', 'mean'],\n", + " 'accuracy_std' : set_df.loc['accuracy', 'std'],\n", + " 'accuracy_min' : set_df.loc['accuracy', 'min'],\n", + " 'accuracy_max' : set_df.loc['accuracy', 'max'],\n", + " 'precision_mean' : set_df.loc['precision', 'mean'],\n", + " 'precision_std' : set_df.loc['precision', 'std'],\n", + " 'precision_min' : set_df.loc['precision', 'min'],\n", + " 'precision_max' : set_df.loc['precision', 'max'],\n", + " 'recall_mean' : set_df.loc['recall', 'mean'],\n", + " 'recall_std' : set_df.loc['recall', 'std'],\n", + " 'recall_min' : set_df.loc['recall', 'min'],\n", + " 'recall_max' : set_df.loc['recall', 'max'],\n", + " 'f1_score_mean': set_df.loc['f1_score', 'mean'],\n", + " 'f1_score_std': set_df.loc['f1_score', 'std'],\n", + " 'f1_score_min' : set_df.loc['f1_score', 'min'],\n", + " 'f1_score_max' : set_df.loc['f1_score', 'max'],\n", + " 'roc_auc_score_mean': set_df.loc['roc_auc_score', 'mean'],\n", + " 'roc_auc_score_std': set_df.loc['roc_auc_score', 'std'],\n", + " 'roc_auc_score_min' : set_df.loc['roc_auc_score', 'min'],\n", + " 'roc_auc_score_max' : set_df.loc['roc_auc_score', 'max'],\n", + " 'accuracy_abs_mean' : set_df.loc['accuracy_abs', 'mean'],\n", + " 'accuracy_abs_std' : set_df.loc['accuracy_abs', 'std'],\n", + " 'accuracy_abs_min' : set_df.loc['accuracy_abs', 'min'],\n", + " 'accuracy_abs_max' : set_df.loc['accuracy_abs', 'max']\n", + " }\n", + "\n", + "metrics_stats_df = pd.DataFrame.from_dict(metrics_stats, orient='index')\n", + "\n", + "# Save them\n", + "metrics_stats_fp = RF_dir / Path(\"metrics.stats.tsv\")\n", + "metrics_stats_df.to_csv(metrics_stats_fp, sep='\\t',)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [], + "source": [ + "# Select the best model automatically\n", + "\n", + "# Initialize a dictionary {N1 : 0, N2 : 0, ...}\n", + "dataset_scores = dict(zip(metrics_stats_df.T.columns, \n", + " [0 for i in range(len(metrics_stats_df.T.columns))]))\n", + "# Iterate over the rows of the dataframe\n", + "for stat, row in metrics_stats_df.T.iterrows():\n", + " if stat.endswith('std'):\n", + " #row is a series. idxmax() returns the value of the index of the series\n", + " dataset_scores[row.idxmin()] += 1 \n", + " else:\n", + " dataset_scores[row.idxmax()] += 1\n", + " \n", + "# Select best scoring model\n", + "best_model = max(dataset_scores.items(), key=operator.itemgetter(1))[0]\n", + "print(\"AND THE BEST MODEL IS...: {}\".format(best_model))\n", + "\n", + "best_model_fp = RF_dir / Path(\"best_model.pkl\")\n", + "# Should be identical with the one in the models dir\n", + "with open(best_model_fp, 'wb') as fout:\n", + " pickle.dump(best_models[best_model], fout)\n", + "\n", + "# Write the best model id in a file.\n", + "## This is used internally in the workflow as a checkpoint\n", + "# Should be a file containing just the id , e.g. 
N2\n", + "best_model_id_fp = RF_dir / Path(\"best_model_id.txt\")\n", + "with open(best_model_id_fp, 'w') as fout:\n", + " fout.write(best_model)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "outputs": [], + "source": [ + "# PLOTS" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "outputs": [], + "source": [ + "## Metrics" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [], + "source": [ + "def plot_dataset_roc_curve(training_set,\n", + " data_df,\n", + " plot_all=True,\n", + " ax=None,\n", + " label_xy=True):\n", + " \"\"\"\n", + " Plot the roc curve for a dataset.\n", + " training_set:str\n", + " The name of the dataset\n", + " data_df: pandas.Dataframe\n", + " A data frame that holds all metrics\n", + " plot_all:bool\n", + " If roc curves for all other datasets used for validation\n", + " should be plotted\n", + " ax:pyplot.axes\n", + " An axes object (used for subplotting)\n", + " label_xy:bool\n", + " If axes x,y should be annotated\n", + " \n", + " \"\"\"\n", + " \n", + " if ax is None:\n", + " ax = plt.gca()\n", + " \n", + " rc = ax.plot([0,1], [0,1], linestyle='--', color='k', label='Baseline')\n", + " \n", + " # Select all rows for a particular dataset used for the optimization\n", + " # and create a new dataframe\n", + " train_df = data_df.loc[data_df['TS'] == training_set]\n", + " # Grab the fpr and tpr arrays for that set\n", + " \n", + " fpr = train_df.loc[train_df['VS'] == training_set, 'fpr'].values\n", + " tpr = train_df.loc[train_df['VS'] == training_set, 'tpr'].values\n", + " # Get the auc_score value\n", + " auc_score = data_df.loc[(training_set, training_set), 'roc_auc_score']\n", + " # Plot the roc curve\n", + " ## fpr and tpr are returned as an array of arrays\n", + " ## with length == 1. Hence, we grab the first ellement which\n", + " ## is the array itself\n", + " rc = ax.plot(fpr[0], tpr[0], \n", + " label=\"{}, auc={:.2f}\".format(training_set, auc_score),\n", + " color = color_dict.get(training_set), \n", + " linewidth=3, \n", + " alpha=1)\n", + " # For plotting all other datasets roc curves\n", + " if plot_all is True:\n", + " for validation_set in train_df['VS'].values:\n", + " if validation_set == training_set:\n", + " pass\n", + " else:\n", + " # This is done on the subset train_df\n", + " fpr = train_df.loc[train_df['VS'] == validation_set, 'fpr'].values\n", + " tpr = train_df.loc[train_df['VS'] == validation_set, 'tpr'].values\n", + " auc_score = data_df.loc[(training_set, validation_set), 'roc_auc_score']\n", + " rc = ax.plot(fpr[0], tpr[0], \n", + " label=\"{}, auc={:.2f}\".format(validation_set, auc_score),\n", + " color = color_dict.get(validation_set),\n", + " linewidth=1,\n", + " alpha=0.7)\n", + " \n", + " rc = ax.legend(loc='lower right', fontsize='small')\n", + " \n", + " # When plotting multiple datasets in a subplots\n", + " if label_xy is True:\n", + " rc = ax.set_ylabel('True Positive Rate')\n", + " rc = ax.set_xlabel('False Positive Rate')\n", + " # Set the master title \n", + " rc = ax.set_title('ROC curves (dataset={})'.format(training_set))\n", + " return rc" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [], + "source": [ + "figures_dir = RF_dir / Path(\"figures\")\n", + "figures_dir.mkdir(exist_ok=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "outputs": [], + "source": [ + "Figure 1a." 
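The automatic selection a few cells above scores each negative set by counting how many summary statistics it "wins": the lowest value wins for `*_std` rows, the highest wins otherwise, and the set with the most points becomes the best model. A self-contained toy version of that vote, for clarity:

```python
import operator
import pandas as pd

# Rows are summary statistics, columns are candidate datasets (toy values).
stats = pd.DataFrame(
    {'N1': [0.90, 0.05, 0.88], 'N2': [0.92, 0.07, 0.91]},
    index=['accuracy_mean', 'accuracy_std', 'f1_score_mean'],
)

scores = dict.fromkeys(stats.columns, 0)
for stat, row in stats.iterrows():
    winner = row.idxmin() if stat.endswith('std') else row.idxmax()
    scores[winner] += 1

best = max(scores.items(), key=operator.itemgetter(1))[0]
# best == 'N2' here: it wins accuracy_mean and f1_score_mean, N1 wins accuracy_std.
```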
+ ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [], + "source": [ + "fig1_fp = figures_dir / Path(\"Figure_1a.svg\")\n", + "\n", + "fig1, ax = plt.subplots(figsize=(5.32, 5.11))\n", + "\n", + "plot_dataset_roc_curve(best_model, metrics_df, plot_all=True)\n", + "\n", + "fig1.savefig(fig1_fp, dpi=600)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "outputs": [], + "source": [ + "* Figure S1" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [], + "source": [ + "fig_s1_fp = figures_dir / Path(\"Figure_S1.svg\")\n", + "fig, axs = plt.subplots(4,3, figsize=(18, 20))\n", + "\n", + "for i, ts in enumerate(negative_sets):\n", + " plot_dataset_roc_curve(ts, metrics_df, plot_all=True, ax = axs.flat[i], label_xy=False)\n", + "# https://napsterinblue.github.io/notes/python/viz/subplots/\n", + "else:\n", + " [ax.set_visible(False) for ax in axs.flatten()[i+1:]]\n", + " \n", + "# Manually placing axes labels with trial and error\n", + "fig.text(0.5, 0.09, \n", + " 'False Positive Rate', \n", + " ha='center', \n", + " va='center', \n", + " fontsize=\"x-large\")\n", + "fig.text(0.09, 0.5, \n", + " 'True Positive Rate', \n", + " ha='center', \n", + " va='center', \n", + " rotation='vertical', \n", + " fontsize=\"x-large\")\n", + "\n", + "# Super title\n", + "fig.suptitle(\"ROC Curves for all datasets\", \n", + " x=0.5, \n", + " y=0.92, \n", + " fontsize = \"x-large\")\n", + "# Save it\n", + "fig.savefig(fig_s1_fp, dpi=600, bbox_inches='tight')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "outputs": [], + "source": [ + "* Figure S2" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [], + "source": [ + "fig_s2_fp = figures_dir / Path(\"Figure_S2.svg\")\n", + "\n", + "boxplot_metrics = ['accuracy', 'precision', 'recall', 'f1_score', 'roc_auc_score',]\n", + "fig, axs = plt.subplots(len(boxplot_metrics), 1,\n", + " figsize=(12, len(boxplot_metrics)*6),\n", + " )\n", + "for i, metric in enumerate(boxplot_metrics):\n", + " title_string = ' '.join(metric.split('_'))\n", + " \n", + " axs.flat[i].set_title(title_string.title(), fontsize=\"x-large\", )\n", + " \n", + " a = sns.boxplot(data=metrics_df, \n", + " x='TS', \n", + " y=metric, \n", + " palette=color_dict, \n", + " ax = axs.flat[i]\n", + " )\n", + " \n", + " a.set_xticklabels([lbl.get_text() for lbl in a.get_xticklabels()],\n", + "# rotation=60,\n", + "# horizontalalignment='right',\n", + " fontsize=\"large\"\n", + " )\n", + "\n", + " a.set_ylim([0.4, 1.])\n", + " \n", + " a.set_yticklabels(np.around(a.get_yticks(), 1),\n", + " fontsize=\"large\")\n", + " a.set_xlabel('')\n", + " a.set_ylabel('')\n", + " \n", + "fig.savefig(fig_s2_fp, dpi=600, bbox_inches='tight')\n", + "# Gives a \n", + "# .../python3.7/site-packages/ipykernel_launcher.py:28: UserWarning: \n", + "# FixedFormatter should only be used together with FixedLocator" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "outputs": [], + "source": [ + "## Features" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [], + "source": [ + "# Calculate basic stats for all features across all datasets\n", + "features_stats = {}\n", + "for i, negative in enumerate(negatives_fp_list):\n", + " dataset_name = negative.name.split('.')[0]\n", + " \n", + " set_df = features_df.loc[features_df[\"TS\"] == dataset_name, FEATURES].describe().T\n", + " features_stats[dataset_name] = {'jaccard_score_mean' : 
set_df.loc['jaccard_score', 'mean'],\n", + " 'jaccard_score_std' : set_df.loc['jaccard_score', 'std'],\n", + " 'jaccard_score_min' : set_df.loc['jaccard_score', 'min'],\n", + " 'jaccard_score_max' : set_df.loc['jaccard_score', 'max'],\n", + " 'same_score_mean' : set_df.loc['same_score', 'mean'],\n", + " 'same_score_std' : set_df.loc['same_score', 'std'],\n", + " 'same_score_min' : set_df.loc['same_score', 'min'],\n", + " 'same_score_max' : set_df.loc['same_score', 'max'],\n", + " 'inwards_score_mean' : set_df.loc['inwards_score', 'mean'],\n", + " 'inwards_score_std' : set_df.loc['inwards_score', 'std'],\n", + " 'inwards_score_min' : set_df.loc['inwards_score', 'min'],\n", + " 'inwards_score_max' : set_df.loc['inwards_score', 'max'],\n", + " 'outwards_score_mean': set_df.loc['outwards_score', 'mean'],\n", + " 'outwards_score_std': set_df.loc['outwards_score', 'std'],\n", + " 'outwards_score_min' : set_df.loc['outwards_score', 'min'],\n", + " 'outwards_score_max' : set_df.loc['outwards_score', 'max'],\n", + " 'avg_distance_mean': set_df.loc['avg_distance', 'mean'],\n", + " 'avg_distance_std': set_df.loc['avg_distance', 'std'],\n", + " 'avg_distance_min' : set_df.loc['avg_distance', 'min'],\n", + " 'avg_distance_max' : set_df.loc['avg_distance', 'max'],\n", + " 'mean_ani_mean' : set_df.loc['mean_ani', 'mean'],\n", + " 'mean_ani_std' : set_df.loc['mean_ani', 'std'],\n", + " 'mean_ani_min' : set_df.loc['mean_ani', 'min'],\n", + " 'mean_ani_max' : set_df.loc['mean_ani', 'max'],\n", + " 'mean_aai_mean' : set_df.loc['mean_aai', 'mean'],\n", + " 'mean_aai_std' : set_df.loc['mean_aai', 'std'],\n", + " 'mean_aai_min' : set_df.loc['mean_aai', 'min'],\n", + " 'mean_aai_max' : set_df.loc['mean_aai', 'max'], \n", + " }\n", + " \n", + "features_stats_df = pd.DataFrame.from_dict(features_stats, orient='index')\n", + "\n", + "# Store them\n", + "features_stats_fp = RF_dir / Path(\"features_stats.tsv\")\n", + "features_stats_df.T.to_csv(features_stats_fp, sep='\\t')\n" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": {}, + "outputs": [], + "source": [ + "feat_imp = features_df.loc[features_df['TS'] == best_model, FEATURES]\n", + "\n", + "fig, ax = plt.subplots(figsize=(5.32,5.11))\n", + "\n", + "features_color_palette = sns.color_palette(\"YlOrBr\", len(FEATURES))\n", + "\n", + "features_color_dict = dict(zip(FEATURES, features_color_palette))\n", + "\n", + "# fimp : For feature importances\n", + "fimp = sns.barplot(data=feat_imp,\n", + " orient='h',\n", + " ci=\"sd\",\n", + " order=['avg_distance',\n", + " 'jaccard_score', \n", + " 'mean_aai', \n", + " 'mean_ani', \n", + " 'inwards_score', \n", + " 'outwards_score', \n", + " 'same_score'],\n", + " palette = features_color_dict,\n", + " capsize=0.1,\n", + " )\n", + "# Trial and error\n", + "fimp.set_xlim([0., 0.45])\n", + " \n", + "fimp.set_yticklabels([FEATURES_MAP[lbl.get_text()] for lbl in fimp.get_yticklabels()],\n", + " fontsize='medium')\n", + "fimp.set_xlabel('Gini Importance', fontsize='large')\n", + "fimp.set_ylabel('')\n", + "fimp.set_title('Feature Importances', fontsize=\"x-large\")\n", + "\n", + "fig.tight_layout()\n", + "\n", + "# Save the fig\n", + "feat_importances_fp = figures_dir / Path(\"Figure_1b.svg\")\n", + "fig.savefig(feat_importances_fp, dpi= 600)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": 
"text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/workflow/rules/construct_datasets.smk b/workflow/rules/construct_datasets.smk new file mode 100644 index 0000000000000000000000000000000000000000..a71612091568a5ac45810d3420927e6a25b52576 --- /dev/null +++ b/workflow/rules/construct_datasets.smk @@ -0,0 +1,231 @@ +rule filter_intact: + input: + intact_raw = ancient("data/interactions/intact.txt"), + tax_db = ancient("data/taxonomy_db/taxa.sqlite") + output: + intact_phages = "results/interaction_datasets/01_filter_intact/intact_phages.txt" + conda: + "../envs/pvogs.yml" + log: + "results/logs/construct_datasets/filter_intact.log" + shell: + "python workflow/scripts/filter_intact.py " + "--tax-db {input.tax_db} " + "--phages-only " + "-i {input.intact_raw} " + "-o {output.intact_phages} &>{log}" + +rule summarize_intact: + input: + intact_phages = rules.filter_intact.output.intact_phages, + tax_db = ancient("data/taxonomy_db/taxa.sqlite") + output: + metadata_tsv = "results/interaction_datasets/02_summarize_intact/metadata.tsv" + params: + summary_dir = "results/interaction_datasets/02_summarize_intact" + conda: + "../envs/pvogs.yml" + log: + "results/logs/construct_datasets/summarize_intact.log" + shell: + "python workflow/scripts/summarize_intact.py " + "--tax-db {input.tax_db} " + "-i {input.intact_phages} " + "-o {params.summary_dir} &>{log}" + +rule get_uniprot_ids: + input: + metadata_tsv = rules.summarize_intact.output.metadata_tsv + output: + uniprot_ids_txt = "results/interaction_datasets/03_uniprot/uniprot_ids.txt", + uniprot_only = "results/interaction_datasets/03_uniprot/uniprot.phages_only.tsv" + conda: + "../envs/pvogs.yml" + log: + "results/logs/construct_datasets/get_uniprot_ids.log" + shell: + "python workflow/scripts/get_uniprot_ids.py " + "-i {input.metadata_tsv} " + "-l {output.uniprot_ids_txt} " + "-o {output.uniprot_only} &>{log}" + +rule download_uniprots: + input: + uniprot_ids_txt = rules.get_uniprot_ids.output.uniprot_ids_txt + output: + swissprot_txt = "results/interaction_datasets/03_uniprot/uniprot.sprot.txt" + log: + "results/logs/construct_datasets/query_uniprot.log" + conda: + "../envs/pvogs.yml" + shell: + "python workflow/scripts/query_uniprot.py " + "--filter-list " + "-i {input.uniprot_ids_txt} " + "-o {output.swissprot_txt} &>{log}" + +rule process_uniprot: + input: + uniprot_ids_txt = rules.get_uniprot_ids.output.uniprot_ids_txt, + uniprot_only = rules.get_uniprot_ids.output.uniprot_only, + swissprot_txt = rules.download_uniprots.output.swissprot_txt + output: + proteins_fasta = "results/interaction_datasets/04_process_uniprot/proteins.faa", + skipped = "results/interaction_datasets/04_process_uniprot/skipped.txt", + duplicates = "results/interaction_datasets/04_process_uniprot/duplicates.tsv", + interactions_filtered = "results/interaction_datasets/04_process_uniprot/interactions_filtered.tsv", + ncbi_interactions = "results/interaction_datasets/04_process_uniprot/ncbi_interactions.tsv", + genomes_accessions = "results/interaction_datasets/04_process_uniprot/genome_accessions.txt", + ncbi_to_uniprot = "results/interaction_datasets/04_process_uniprot/ncbi2uniprot.mapping.txt", + uniprot_to_ncbi = "results/interaction_datasets/04_process_uniprot/uniprot2ncbi.mapping.txt" + conda: + "../envs/pvogs.yml" + log: + "results/logs/construct_datasets/process_uniprot.log" + params: + output_dir = 
"results/interaction_datasets/04_process_uniprot" + shell: + "python workflow/scripts/process_uniprot.py " + "-i {input.uniprot_only} " + "-l {input.uniprot_ids_txt} " + "-s {input.swissprot_txt} " + "-p {params.output_dir} &>{log}" + +rule download_genomes: + input: + genomes_accessions = rules.process_uniprot.output.genomes_accessions + output: + genomes_gb = "results/interaction_datasets/05_interaction_datasets/genomes.gb" + log: + "results/logs/construct_datasets/download_genomes.log" + params: + email = config.get('email') + conda: + "../envs/pvogs.yml" + shell: + "python workflow/scripts/download_genomes.py " + "-i {input.genomes_accessions} " + "-o {output.genomes_gb} " + "-e {params.email} &>{log}" + +rule extract_proteins_from_genomes: + input: + genomes_gb = rules.download_genomes.output.genomes_gb + output: + all_proteins_faa = "results/interaction_datasets/05_genomes/all_proteins.faa" + conda: + "../envs/pvogs.yml" + log: + "results/logs/interaction_datasets/05_genomes/.extract_proteins.log" + shell: + "python workflow/scripts/extract_proteins_from_gb.py " + "-i {input.genomes_gb} " + "-o {output.all_proteins_faa} &>{log}" + +## Copy the proteins file and create a 2-column tsv in a positives directory +## The same naming structure helps with expansion for the rules. +rule copy_positives_proteins: + input: + proteins_faa = rules.process_uniprot.output.proteins_fasta + output: + positives_faa = "results/interaction_datasets/positives/positives.proteins.faa" + log: + "results/logs/construct_datasets/copy_positives_scores.log" + shell: + "cp {input.proteins_faa} {output.positives_faa} 2>{log}" + + +rule ncbi_positives: + input: + ncbi_interactions = rules.process_uniprot.output.ncbi_interactions, + genomes_gb = rules.download_genomes.output.genomes_gb + output: + ncbi_positives_tsv = "results/interaction_datasets/positives/positives.interactions.tsv" + log: + "results/logs/construct_datasets/ncbi_positives.log" + shell: + "cut -f3,4 {input.ncbi_interactions} > {output.ncbi_positives_tsv} 2>{log}" + + +rule make_negatives: + input: + ncbi_positives_tsv = rules.ncbi_positives.output.ncbi_positives_tsv, + all_proteins_faa = rules.extract_proteins_from_genomes.output.all_proteins_faa, + genomes_gb = rules.download_genomes.output.genomes_gb + + output: + expand(["results/interaction_datasets/N{I}/N{I}.interactions.tsv", + "results/interaction_datasets/N{I}/N{I}.proteins.faa"], + I = [i+1 for i in range(0, NEGATIVES)]) + conda: + "../envs/pvogs.yml" + params: + no_negatives = NEGATIVES + log: + "results/logs/construct_datasets/make_negatives.log" + shell: + "for i in `seq 1 {params.no_negatives}`;do " + " python workflow/scripts/make_protein_combos.py " + " -i {input.genomes_gb} " + " -a {input.all_proteins_faa} " + " -o results/interaction_datasets/N${{i}}/N${{i}}.proteins.faa " + " -x results/interaction_datasets/N${{i}}/N${{i}}.interactions.tsv " + " --exclude {input.ncbi_positives_tsv} " + " --sample-size 2 " + " --random-seed ${{i}}; " + "done &>{log}" + +rule hmmsearch: + input: + proteins_fasta = "results/interaction_datasets/{dataset}/{dataset}.proteins.faa", + all_pvogs_profiles = "data/pvogs/all.hmm" + output: + hmm_out_txt = "results/interaction_datasets/06_map_proteins_to_pvogs/{dataset}/{dataset}.hmmout.txt", + hmm_tblout_tsv = "results/interaction_datasets/06_map_proteins_to_pvogs/{dataset}/{dataset}.hmmtblout.tsv" + log: + "results/logs/construct_datasets/{dataset}/hmmsearch.log" + threads: 8 + conda: + "../envs/pvogs.yml" + shell: + "hmmsearch --cpu {threads} " + "-o 
{output.hmm_out_txt} " + "--tblout {output.hmm_tblout_tsv} " + "{input.all_pvogs_profiles} " + "{input.proteins_fasta}" + +rule refseqs_to_pvogs: + input: + interactions_tsv = "results/interaction_datasets/{dataset}/{dataset}.interactions.tsv", + proteins_faa = "results/interaction_datasets/{dataset}/{dataset}.proteins.faa", + hmm_tblout = "results/interaction_datasets/06_map_proteins_to_pvogs/{dataset}/{dataset}.hmmtblout.tsv" + output: + pvogs_interactions = "results/interaction_datasets/{dataset}/{dataset}.pvogs_interactions.tsv" + conda: + "../envs/pvogs.yml" + log: + "results/logs/construct_datasets/{dataset}/refseqs_to_pvogs.log" + shell: + "python workflow/scripts/refseqs_to_pvogs.py " + "-i {input.interactions_tsv} " + "-f {input.proteins_faa} " + "-hmm {input.hmm_tblout} " + "-o {output.pvogs_interactions} " + "&> {log}" + +rule create_features_tables: + input: + filtered_master_tsv = rules.filter_scores_table.output.filtered_master_tsv, + interactions_tsv = "results/interaction_datasets/{dataset}/{dataset}.pvogs_interactions.tsv" + output: + features_tsv = "results/interaction_datasets/{dataset}/{dataset}.features.tsv" + log: + "results/logs/construct_datasets/{dataset}/create_features_tables.log" + conda: + "../envs/pvogs.yml" + shell: + "python workflow/scripts/subset_scores.py " + "-s {input.filtered_master_tsv} " + "-i {input.interactions_tsv} " + "-o {output.features_tsv}" + diff --git a/workflow/rules/envs/zenodo.yml b/workflow/rules/envs/zenodo.yml new file mode 100644 index 0000000000000000000000000000000000000000..57a1cad0cf9c614115bf996a6eb1cb3789e384b8 --- /dev/null +++ b/workflow/rules/envs/zenodo.yml @@ -0,0 +1,8 @@ +channels: + - conda-forge +dependencies: + - python>=3.6 + - pip + - pip: + - zenodo_get==1.3.2 + diff --git a/workflow/rules/get_data.smk b/workflow/rules/get_data.smk new file mode 100644 index 0000000000000000000000000000000000000000..bfc779dcf13956de80edc6804182fcbd6efcdb86 --- /dev/null +++ b/workflow/rules/get_data.smk @@ -0,0 +1,28 @@ +rule download_archive: + output: + tar_gz = "pvogs_function.data.tar.gz" + log: + "results/logs/get_data/download_archive.log" + conda: + "envs/zenodo.yml" + params: + sandbox = "--sandbox" if config["zenodo"]["use_sandbox"] is True else "", + zenodo_id = config["zenodo"]["doi"] + shell: + "zenodo_get {params.sandbox} --record={params.zenodo_id}" + +rule extract_data: + input: + tar_gz = rules.download_archive.output.tar_gz + output: + genomes_fasta = "data/genomes/phages_refseq.fasta", + intact_txt= "data/interactions/intact.txt", + pvogs_profiles = "data/pvogs/all.hmm", + pvogs_annotations = "data/pvogs/VOGProteinTable.txt", + taxonomy_db = "data/taxonomy_db/taxa.sqlite", + taxonomy_pkl = "data/taxonomy_db/taxa.sqlite.traverse.pkl" + log: + "results/logs/get_data/extract_data.log" + shell: + "tar -xzvf {input.tar_gz} &>{log}" + diff --git a/workflow/rules/pre_process.smk b/workflow/rules/pre_process.smk new file mode 100644 index 0000000000000000000000000000000000000000..53cc8dc49d8933e4ddda4a3777ebfd267da4f5a7 --- /dev/null +++ b/workflow/rules/pre_process.smk @@ -0,0 +1,234 @@ +## pre_process +## PREPARE NECESSARY FILES FOR FEATURE CALCULATIONS + +rule translate_genomes: + input: + refseq_genomes = ancient("data/genomes/phages_refseq.fasta") + output: + transeq_genomes = "results/pre_process/transeq/transeq.genomes.fasta" + log: "results/logs/pre_process/transeq/transeq_genomes.log" + conda: + "../envs/pvogs.yml" + shell: + "transeq -frame 6 " + "-table 11 -clean " + "-sequence {input.refseq_genomes} " + 
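The hmmsearch rule above writes a `--tblout` table that the downstream `refseqs_to_pvogs.py` script consumes. As a reference only (that script is not shown here), Biopython can read this table with the `hmmer3-tab` parser; the path below is illustrative, assuming the N1 dataset.

```python
from Bio import SearchIO

# Illustrative path for the N1 dataset; any {dataset} follows the same layout.
tblout = "results/interaction_datasets/06_map_proteins_to_pvogs/N1/N1.hmmtblout.tsv"

# Each query result is one pVOG profile; each hit is a protein it matched.
for qresult in SearchIO.parse(tblout, "hmmer3-tab"):
    for hit in qresult.hits:
        print(qresult.id, hit.id, hit.evalue, hit.bitscore)
```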
"-outseq {output.transeq_genomes} 2>{log}" + +rule hmmsearch_transeq: + input: + transeq_genomes = rules.translate_genomes.output.transeq_genomes, + pvogs_all_profiles = ancient("data/pvogs/all.hmm") + output: + hmmout_txt = "results/pre_process/hmmsearch/transeq.hmmout.txt", + hmmtblout_tsv = "results/pre_process/hmmsearch/transeq.hmmtblout.tsv" + log: + "results/logs/pre_process/hmmsearch/hmmsearch_transeq.log" + conda: + "../envs/pvogs.yml" + threads: 10 + shell: + "hmmsearch --cpu {threads} " + "-o {output.hmmout_txt} " + "--tblout {output.hmmtblout_tsv} " + "{input.pvogs_all_profiles} " + "{input.transeq_genomes} " + "&>{log}" + + +rule split_genomes: + input: + refseq_genomes = ancient("data/genomes/phages_refseq.fasta") + output: + reflist_txt = "results/pre_process/reflist.txt" + log: + "results/logs/pre_process/split_genomes.log" + conda: + "../envs/pvogs.yml" + params: + genomes_dir = "results/pre_process/all_genomes" + shell: + "mkdir -p {params.genomes_dir} && " + "python workflow/scripts/split_multifasta.py " + "--write-reflist " + "-i {input.refseq_genomes} " + "-o {params.genomes_dir} 2>{log}" + + +## FASTANI +rule fastani: + input: + reflist_txt = rules.split_genomes.output.reflist_txt + output: + fastani_raw = "results/pre_process/fastani/fastani.out", + fastani_matrix = "results/pre_process/fastani/fastani.out.matrix" + log: + "results/logs/pre_process/fastani/fastani.log" + conda: + "../envs/pvogs.yml" + threads: 8 + params: + fragLen = 300, + minFraction = 0.1 + shell: + "fastANI -t {threads} " + "--ql {input.reflist_txt} " + "--rl {input.reflist_txt} " + "--fragLen {params.fragLen} " + "--minFraction {params.minFraction} " + "-o {output.fastani_raw} " + "--matrix 2>{log}" # Output the matrix + +rule fastani_matrix_to_square: + input: + fastani_matrix = rules.fastani.output.fastani_matrix + output: + fastani_square_mat = "results/pre_process/fastani/fastani.square.matrix" + conda: + "../envs/pvogs.yml" + log: + "results/logs/pre_process/fastani_matrix_to_square.log" + shell: + "python workflow/scripts/fastani_mat_to_square.py " + "-i {input.fastani_matrix} " + "-o {output.fastani_square_mat} " + "--process-names &>{log}" # Chop full paths to last node, which is the accession number + + + +## COMPAREM +rule comparem_call_genes: + input: + rules.split_genomes.output.reflist_txt + output: + done_file = touch("results/pre_process/comparem/aai_wf/genes/.done.txt") + conda: + "../envs/pvogs.yml" + log: + # This is produced by comparem by default + "results/pre_process/comparem/aai_wf/genes/comparem.log" + params: + genes_dir = "results/pre_process/comparem/aai_wf/genes", + genomes_dir = "results/pre_process/all_genomes" + threads: 10 + shell: + "comparem call_genes -c {threads} --silent -x fasta " + "{params.genomes_dir} " + "{params.genes_dir}" + +rule remove_empty_files: + input: + rules.comparem_call_genes.output.done_file + output: + info_file = "results/pre_process/comparem/aai_wf/genes/genomes_skipped.txt" + params: + genes_dir = "results/pre_process/comparem/aai_wf/genes" + log: + "results/logs/pre_process/remove_empty_files.log" + shell: + "python workflow/scripts/remove_empty_files.py " + "-i {params.genes_dir} " + "-o {output.info_file}" + +rule comparem_similarity: + input: + info_file = rules.remove_empty_files.output.info_file + output: + hits_sorted_tsv = "results/pre_process/comparem/aai_wf/similarity/hits_sorted.tsv", + query_genes_dmnd = "results/pre_process/comparem/aai_wf/similarity/query_genes.dmnd", + query_genes_faa = 
"results/pre_process/comparem/aai_wf/similarity/query_genes.faa" + params: + similarity_dir = "results/pre_process/comparem/aai_wf/similarity", + genes_dir = "results/pre_process/comparem/aai_wf/genes" + log: + # Produced by comparem by default + "results/pre_process/comparem/aai_wf/similarity/comparem.log" + threads: 10 + conda: + "../envs/pvogs.yml" + shell: + "comparem similarity " + "-c {threads} -x faa " + "--silent " + "{params.genes_dir} {params.genes_dir} {params.similarity_dir}" + + +rule comparem_aai: + input: + query_genes_faa = rules.comparem_similarity.output.query_genes_faa, + hits_sorted_tsv = rules.comparem_similarity.output.hits_sorted_tsv + output: + aai_summary_tsv = "results/pre_process/comparem/aai_wf/aai/aai_summary.tsv" + conda: + "../envs/pvogs.yml" + log: + # Produced by comparem by default + "results/pre_process/comparem/aai_wf/aai/comparem.log" + params: + aai_dir = "results/pre_process/comparem/aai_wf/aai" + shell: + "comparem aai --silent " + "{input.query_genes_faa} {input.hits_sorted_tsv} " + "{params.aai_dir}" + +rule process_comparem_matrix: + input: + aai_summary_tsv = rules.comparem_aai.output.aai_summary_tsv + output: + aai_summary_square = "results/pre_process/comparem/aai_wf/aai/aai_summary.square.tsv" + conda: + "../envs/pvogs.yml" + log: + "results/logs/pre_process/proecess_comparem_matrix.log" + shell: + "python workflow/scripts/process_comparem.py " + "-i {input.aai_summary_tsv} " + "-o {output.aai_summary_square} &>{log}" + +rule calculate_all_scores: + input: + ani_square_matrix = rules.fastani_matrix_to_square.output.fastani_square_mat, + aai_square_matrix = rules.process_comparem_matrix.output.aai_summary_square, + phages_genomes_fasta = ancient("data/genomes/phages_refseq.fasta"), + pvogs_all_profiles = ancient("data/pvogs/all.hmm"), + hmmout_txt = rules.hmmsearch_transeq.output.hmmout_txt + output: + master_table = "results/scores.tsv" + conda: "../envs/pvogs.yml" + log: + "results/logs/pre_process/calculate_scores.log" + shell: + "python workflow/scripts/calculate_all_scores.py " + "--profiles-file {input.pvogs_all_profiles} " + "--genomes {input.phages_genomes_fasta} " + "--input-hmm {input.hmmout_txt} " + "--ani-matrix {input.ani_square_matrix} " + "--aai-matrix {input.aai_square_matrix} " + "-o {output.master_table} &>{log}" + +rule filter_scores_table: + input: + master_table = rules.calculate_all_scores.output.master_table + output: + filtered_master_tsv = "results/filtered_scores.tsv" + log: + "results/pre_process/filter_scores_table.log" + shell: + """awk '{{if ( $10 != "1000000") print $0}}' {input.master_table} > {output.filtered_master_tsv}""" + +rule parse_annotations: + input: + raw_annotations_fp = ancient("data/pvogs/VOGProteinTable.txt") + output: + processed_annotations_fp = "results/annotations.tsv" + conda: + "../envs/pvogs.yml" + log: + "results/pre_process/process_annotations.log" + shell: + "python workflow/scripts/process_annotations.py " + "-i {input.raw_annotations_fp} " + "-o {output.processed_annotations_fp} " + "&>{log}" + + diff --git a/workflow/scripts/calculate_all_scores.py b/workflow/scripts/calculate_all_scores.py new file mode 100755 index 0000000000000000000000000000000000000000..d40a85ba993162c3b41420f2a8dd3ac2051a34a0 --- /dev/null +++ b/workflow/scripts/calculate_all_scores.py @@ -0,0 +1,340 @@ +#!/usr/bin/env python + +import argparse +from pathlib import Path +from Bio import SeqIO, SearchIO +import pandas as pd +import numpy as np +from itertools import combinations + +def parse_args(): + parser = 
argparse.ArgumentParser(description='Calculate scores for a set of pVOG interactions, ' + 'provided as a tsv file') + + optionalArgs = parser._action_groups.pop() + + requiredArgs = parser.add_argument_group("required arguments") + requiredArgs.add_argument('-p', '--profiles-file', + dest='profiles_file', + type=lambda p: Path(p).resolve(strict=True), + required=True, + help="The all.hmm file from pVOGs database" + ) + requiredArgs.add_argument('-g', '--genomes', + dest='genomes_fasta', + type=lambda p: Path(p).resolve(strict=True), + required=True, + help="A fasta file with the genomes") + requiredArgs.add_argument('-hmm', '--input-hmm', + dest='hmmer_in', + type=lambda p: Path(p).resolve(strict=True), + required=True, + help="The regular output file from hmmsearch all pvogs against" + "the translated genomes" + ) + requiredArgs.add_argument('-ani_f', '--ani-matrix', + dest='ani_matrix', + type=lambda p: Path(p).resolve(strict=True), + required=True, + help="The square matrix resulting from fastANI with all genomes" + ) + + requiredArgs.add_argument('-aai_f', '--aai-matrix', + dest='aai_matrix', + type=lambda p: Path(p).resolve(strict=True), + required=True, + help="The aai square matrix from compareM on all genomes" + ) + requiredArgs.add_argument('-o', '--output-file', + dest='outfile', + required=True, + type=lambda p: Path(p).resolve(), + help="File path to write the results in") + + parser._action_groups.append(optionalArgs) + + return parser.parse_args() + +def get_seq_sizes(seq_fasta): + """ + Create a dict that holds length sizes for all records + in a fasta file. + """ + seq_sizes = {} + with open(seq_fasta, 'r') as fin: + for record in SeqIO.parse(fin, "fasta"): + seq_sizes[record.id] = len(record.seq) + return seq_sizes + +def get_maximum_index(values_list): + """ + Get the index of the maximum value in a list + + param: list: values_list A list of values + return: int: The index of the maximum value + """ + if len(values_list) == 1: + max_index = 0 + else: + max_index = values_list.index(max(values_list)) + return max_index + +def translate_to_genomic_coords(start, end, frame, genome_size): + """ + Translate the coordinates of the protein from transeq to genomic + coordinates. + + Strand is used here as orientation and not really [non-]conding. + If the frame is 1,2 or 3 (-->) I call this (+) strand. 
+ Else, it is the (-) strand + + param: int: start The starting coordinate on the protein + param: int: end The ending coordinate on the protein + param: int: frame The frame on which it is found (1-6) + param: int: genome_size The size of the genomes + + return: tuple: (genomic start, genomic end, strand) + """ + nucleic_start = start * 3 + nucleic_end = end * 3 + if frame == 1: + genomic_start = nucleic_start - 2 + genomic_end = nucleic_end - 2 + if frame == 2: + genomic_start = nucleic_start - 1 + genomic_end = nucleic_end - 1 + if frame == 3: + genomic_start = nucleic_start + genomic_end = nucleic_end + if frame == 4: + genomic_start = genome_size - (nucleic_start - 2) + genomic_end = genome_size - (nucleic_end - 2) + if frame == 5: + genomic_start = genome_size - (nucleic_start - 1) + genomic_end = genome_size - (nucleic_end -1) + if frame == 6: + genomic_start = genome_size - nucleic_start + genomic_end = genome_size - nucleic_end + + if frame in [1,2,3]: + strand = '+' + elif frame in [4,5,6]: + strand = '-' + else: + raise ValueError("frame should be one of 1,2,3,4,5,6") + + return genomic_start, genomic_end, strand + +def collect_hmmsearch_info(hmmer_in, genome_sizes): + """ + Parse the information in a dictionary of the form + { pvog_id: + { genome_id : + { frame : [(genomic_start, genomic_end, accuracy, pvog_coverage), ...], + ... , } + } + } + """ + + hmm_results = {} + with open(hmmer_in, 'r') as fin: + for record in SearchIO.parse(fin, "hmmer3-text"): + # Initialize an empty dictionary for the pvog + hmm_results[record.id] = {} + for hit in record.hits: + if hit.is_included: + # From the transeq output, accessions are suffixed + # with _1, _2, _3, ..., depending on the strand + genome_id = '_'.join(hit.id.split('_')[0:2]) + frame = int(hit.id.split('_')[2]) + hsps_included = [hsp.is_included for hsp in hit.hsps] + # For multiple hsps that pass the threshold + if any(hsps_included): + # Get their score + scores = [hsp.bitscore for hsp in hit.hsps] + # and select the best one + max_i = get_maximum_index(scores) + best_hsp = hit.hsps[max_i] + + # Translate back to genomic coordinates + genomic_coords = translate_to_genomic_coords(best_hsp.env_start, + best_hsp.env_end, + frame, + genome_sizes.get(genome_id)) + + span = (best_hsp.env_end - best_hsp.env_start) / record.seq_len + + if genome_id not in hmm_results[record.id]: + hmm_results[record.id][genome_id] = (frame, genomic_coords[0], + genomic_coords[1], + best_hsp.acc_avg, + span, + genomic_coords[2]) + return hmm_results + + +def get_mean_from_df(subset_list, df): + df_subset = df.loc[subset_list, subset_list] + return df_subset.values.mean() + + +def get_pvogs_ids(profiles_file): + all_pvogs = [] + with open(profiles_file, 'r') as fin: + for line in fin: + if line.startswith('NAME'): + pvog = line.split()[1] + all_pvogs.append(pvog) + return all_pvogs + +def get_shortest_distance(startA, endA, startB, endB): + start_to_start = abs(startA - startB) + start_to_end = abs(startA - endB) + end_to_start = abs(endA - startB) + end_to_end = abs(endA - endB) + all_distances = [start_to_start, start_to_end, end_to_start, end_to_end] + return min(all_distances) + +def calculate_scores(interaction_tuple, hmm_results, ani_df, aai_df): + same_score = 0 + inwards_score = 0 + outwards_score = 0 + avg_distance = 1000000 # 1000000 for pairs that are never on the same genome + js = 0 + # include the number of genomes for each participating pvog + mean_ani = 0 + mean_aai = 0 + + pvogA = interaction_tuple[0] + pvogB = interaction_tuple[1] + + 
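+    # All distance/orientation features below are defined only on genomes where both
+    # pVOGs have an accepted hmmsearch hit. The two genome sets are intersected
+    # first; if the intersection is empty, the defaults initialized above
+    # (zero scores and the sentinel avg_distance of 1000000) are returned as-is.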
genomesA = set((hmm_results[pvogA].keys())) + genomesB = set(hmm_results[pvogB].keys()) + + common_genomes = genomesA.intersection(genomesB) + if len(common_genomes) > 0: + all_genomes = genomesA.union(genomesB) + + # Jaccard score + js = len(common_genomes) / len(all_genomes) + + # Mean ANI + mean_ani = get_mean_from_df(common_genomes, ani_df) + + # Mean AAI + mean_aai = get_mean_from_df(common_genomes, aai_df) + + # Distances + sum_of_distances = 0 + + for genome in common_genomes: + hitA = hmm_results[pvogA][genome] + hitB = hmm_results[pvogB][genome] + + # Get the proper starts for distance calculation + if hitA[-1] == '+': + startA = hitA[1] + endA = hitA[2] + else: + startA = hitA[2] + endA = hitA[1] + + if hitB[-1] == '+': + startB = hitB[1] + endB = hitB[2] + else: + startB = hitB[2] + endB = hitB[1] + ## 1. Start to start + #sum_of_distances += abs(startA - startB) + + ## 2. Shortest distance + sum_of_distances += get_shortest_distance(startA, endA, startB, endB) + + # If they have the same orientation + # Regardless of '+' or '-' + if hitA[-1] == hitB[-1]: + same_score += 1 + + if hitA[-1] != hitB[-1]: + dstarts = abs(startA - startB) + dends = abs(endA - endB) + if dstarts >= dends: + inwards_score += 1 + else: + outwards_score += 1 + + same_score = same_score / len(common_genomes) + inwards_score = inwards_score / len(common_genomes) + outwards_score = outwards_score / len(common_genomes) + avg_distance = sum_of_distances / len(common_genomes) + + return len(genomesA), len(genomesB), len(common_genomes), js, same_score, inwards_score, outwards_score, avg_distance, mean_ani, mean_aai + +#def get_ani_for_genomes(genomes, ani_df): +# """ +# Calculate mean ani for the genomes list from the +# ani df +# """ +# genomes_df = ani_df.loc[genomes, genomes] +# return genomes_df.values.mean() + +#def get_aai_for_genomes(genomes, aai_df): +# genomes_df = aai_df.loc[genomes, genomes] +# return genomes_df.values.mean() + +def main(): + # Read in the arguments + args = parse_args() + print("Loading data...") + + # Store the genome sizes + genome_sizes = get_seq_sizes(args.genomes_fasta) + print("Loaded sequence size info for {} input sequences".format(len(genome_sizes))) + + print("Reading hmmsearch information...") + # Get the hmmsearch results info for all pvogs in + hmm_results = collect_hmmsearch_info(args.hmmer_in, genome_sizes) + print("Done!") + + print("Reading ANI matrix...") + # Read in the matrices + ani_df = pd.read_csv(args.ani_matrix, index_col=0, header=0, sep="\t") + print("Done!") + + print("Reading AAI matrix...") + aai_df = pd.read_csv(args.aai_matrix, index_col=0, header=0, sep="\t") + print("Done!") + + # Create a list that holds all pvog ids + all_pvogs = get_pvogs_ids(args.profiles_file) + + # Create all pairs of pvogs + all_combos = list(combinations(all_pvogs, 2)) + print("Created {} possible combinations".format(len(all_combos))) + + # Calculate scores + print("Caclulating and writing to file...") + + counter = 0 + with open(args.outfile, 'w') as fout: + fout.write("{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\n".format('pvog1', 'pvog2', + 'genomes1', 'genomes2', + 'overlap_genomes', 'jaccard_score', + 'same_score', 'inwards_score', 'outwards_score', + 'avg_distance', 'mean_ani', 'mean_aai')) + for combo in all_combos: + int_string = '{}\t{}\t'.format(combo[0], combo[1]) + scores = calculate_scores(combo, hmm_results, ani_df, aai_df) + scores_string = '\t'.join(map(str, scores)) + fout.write(int_string + scores_string + '\n') + counter += 1 + if counter % 1000000 == 
0: + print("{}/{} processed".format(counter, len(all_combos))) + + print("Scores are written to {}".format(str(args.outfile))) + + +if __name__ == '__main__': + main() + diff --git a/workflow/scripts/download_genomes.py b/workflow/scripts/download_genomes.py new file mode 100755 index 0000000000000000000000000000000000000000..5d393d5b8ae7a4b2fd7d531c508103269e951990 --- /dev/null +++ b/workflow/scripts/download_genomes.py @@ -0,0 +1,162 @@ +#!/usr/bin/env python + +import argparse +from Bio import SeqIO, Entrez +from Bio.Seq import Seq +from Bio.SeqRecord import SeqRecord +from Bio.Alphabet import IUPAC +import time +from math import ceil + +parser = argparse.ArgumentParser(description='Download a list of ncbi accessions to the output file') +optionalArgs = parser._action_groups.pop() + +requiredArgs = parser.add_argument_group("required arguments") +requiredArgs.add_argument('-i', '--input-list', + dest='input_list', + required=True, + help="A txt file containing accessions to get from genbank") +requiredArgs.add_argument('-o', '--output-file', + dest='output_file', + type=str, + required=True, + help="The output file to write sequences to") +requiredArgs.add_argument('-e', '--e-mail', + help='E-mail address to be used with Bio.Entrez.email. Required by NCBI to notify if something is ' + 'off or if you overload their servers', + dest='email', + type=str, + required=True + ) +optionalArgs.add_argument('--output-fmt', + required=False, + dest='output_fmt', + default='gb', + type=str, + help="Store the results in this file format [default='gb' (genbank)]") + +parser._action_groups.append(optionalArgs) + +def sequence_info_to_dic(sequence_info_file): + """ + Collapse protein ids per genome + Args: + sequence_info_file (str): A tsv file containing 3 columns uniprot_id, genome_id, protein_id, + with a header. + Return: + sequence_info (dict): A dictionary of the form {genome_id: [(protein_id_1, uniprot_id_1), + ... ], + ... 
} + """ + sequence_info = {} + with open(sequence_info_file, 'r') as f: + # Skip header + next(f) + for line in f: + fields = [field.strip() for field in line.split('\t')] + uniprot_id = fields[0] + genome_id = fields[1] + protein_id = fields[2] + if genome_id not in sequence_info: + sequence_info[genome_id] = [(uniprot_id, protein_id,)] + else: + sequence_info[genome_id].append((uniprot_id, protein_id)) + return sequence_info + +def txt_file_to_list(genomes_txt): + with open(genomes_txt, 'r') as fin: + genomes_list = [line.strip() for line in fin] + return genomes_list + + +def download_sequences(genomes_list, genomes_file, email_address, output_fmt="gb", batch_size = 100): + + # Required by Bio.Entrez + Entrez.email = email_address + + # Some progress tracking + total_batches = ceil(len(genomes_list) / batch_size) + batch_no = 0 + + with open(genomes_file, 'w') as fout: + for i in range(0, len(genomes_list), 100): + batch_no += 1 + batch = genomes_list[i:i + 100] + print('Downloading batch {}/{}'.format(batch_no, total_batches)) + handle = Entrez.efetch(db="nuccore", id=batch, rettype=output_fmt, retmode="text") + batch_data = handle.read() + fout.write(batch_data) + handle.close() + # Wait 2 seconds before next batch + time.sleep(2) + + +if __name__ == '__main__': + args = parser.parse_args() + + genomes_list = txt_file_to_list(args.input_list) + download_sequences(genomes_list, args.output_file, args.email, output_fmt=args.output_fmt) + + # genomes_fna = args.prefix_out + '_genomes.fna' + # proteins_faa = args.prefix_out + '_proteins.faa' + # metadata = args.prefix_out + '_metadata.tsv' + # + # with open(genbank_file, 'r') as fin, \ + # open(genomes_fna, 'w') as fg, \ + # open(proteins_faa, 'w') as fp, \ + # open(metadata, 'w') as fout: + # # Write the header to the metadata file + # fout.write(metadata_header + '\n') + # + # # Loop through the genbank records to extract the information + # for record in SeqIO.parse(fin, format="genbank"): + # unversioned = record.id.split('.')[0] # some entries don't come with their version + # # Check if the versioned accession is in the dictionary + # if record.id in sequence_info: + # record_id = record.id + # else: # otherwise use the unversioned string + # record_id = unversioned + # + # # Get a list of associated proteins with the record + # proteins = [prot_info[1] for prot_info in sequence_info.get(record_id)] + # + # # Create a mapping of protein INSDC accession to uniprot accession + # prot_map = {prot_info[1]: prot_info[0] for prot_info in sequence_info.get(record_id)} + # + # # Polyproteins come with one accession + # # TO DO + # # Get the specific location of each chain + # # after it is post-translationally modified + # # Maybe use the FT attribute in the uniprot file? + # if len(proteins) > 1 and len(set(proteins)) == 1: + # # For now print these out and skip + # print("Polyprotein? 
Nucleotide accession: {}, Protein list: {}".format(record_id, proteins)) + # pass + # + # else: + # # Get the genome sequence anyway + # SeqIO.write(record, fg, "fasta") + # # The list of protein accessions associated with the genome + # # TO DO + # # turning them in a set first otherwise there are some duplicates + # lproteins = list(set(proteins)) + # for p in lproteins: + # # The dictionary of features of the genbank record + # for f in record.features: + # # If there is a protein_id field in the qualifiers + # # and is the same as the protein + # if ('protein_id' in f.qualifiers) and (p in f.qualifiers.get('protein_id')): + # # Extract genome length, start, end, strand of the feature + # str_out = '\t'.join(map(str, [p, prot_map[p], record.id, len(record.seq), + # f.location.start.position, + # f.location.end.position, + # f.strand])) + # fout.write(str_out + '\n') + # # Get the translation of the feature + # prot_seq = f.qualifiers.get('translation')[0] + # # Construct a SeqRecord object + # prot_rec = SeqRecord(Seq(prot_seq, IUPAC.protein), + # id=p, + # description='') + # # Write the sequence to the proteins file + # SeqIO.write(prot_rec, fp, "fasta") diff --git a/workflow/scripts/extract_proteins_from_gb.py b/workflow/scripts/extract_proteins_from_gb.py new file mode 100755 index 0000000000000000000000000000000000000000..d37232278cfde6045774beaa0906747c0d2c745e --- /dev/null +++ b/workflow/scripts/extract_proteins_from_gb.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python + +import argparse +from Bio import SeqIO +from Bio.Seq import Seq +from Bio.SeqRecord import SeqRecord +from Bio.Alphabet import IUPAC + + +parser = argparse.ArgumentParser(description='Get all protein sequences from a genbank file with' + 'annotated genomes.') +optionalArgs = parser._action_groups.pop() + +requiredArgs = parser.add_argument_group("required arguments") +requiredArgs.add_argument('-i', '--input-gb', + dest='input_gb', + required=True, + help="A genbank file containing annotated genomes") +requiredArgs.add_argument('-o', '--output-fasta', + dest='output_fasta', + type=str, + required=True, + help="The output file to write sequences to") + +parser._action_groups.append(optionalArgs) + + +def extract_all_proteins_from_genbank_file(genbank_fin, fasta_out): + genomes, proteins = 0, 0 + with open(fasta_out, 'w') as fout: + for record in SeqIO.parse(genbank_fin, "genbank"): + genomes += 1 + for f in record.features: + if f.type == 'CDS': + protein = f.qualifiers.get('protein_id') + if protein is not None: + prot_seq = f.qualifiers.get('translation')[0] + prot_rec = SeqRecord(Seq(prot_seq, IUPAC.protein), + id=protein[0], + description='') + proteins += SeqIO.write(prot_rec, fout, "fasta") + + return genomes, proteins + + +if __name__ == '__main__': + args = parser.parse_args() + g, p = extract_all_proteins_from_genbank_file(args.input_gb, args.output_fasta) + print("Extracted {} from {} input genomes".format(p, g)) diff --git a/workflow/scripts/fastani_mat_to_square.py b/workflow/scripts/fastani_mat_to_square.py new file mode 100755 index 0000000000000000000000000000000000000000..f82cf89f1a6ce2634d7c86d85abea6c8f738cd64 --- /dev/null +++ b/workflow/scripts/fastani_mat_to_square.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python + +import pandas as pd +import numpy as np +import argparse +from pathlib import Path + +def parse_args(): + parser = argparse.ArgumentParser(description='Convert the matrix output of fastANI (v. 
1.3) to a square matrix') + + optionalArgs = parser._action_groups.pop() + + requiredArgs = parser.add_argument_group("required arguments") + requiredArgs.add_argument('-i', '--input-matrix', + dest='mat_in', + type=lambda p: Path(p).resolve(strict=True), + required=True, + help="A matrix file from fastANI, produced with its -matrix option") + requiredArgs.add_argument('-o', '--output-matrix', + dest='mat_out', + type=lambda p: Path(p).resolve(), + required=True, + help="File name of the square matrix") + optionalArgs.add_argument('--process-names', + action='store_true', + dest='process_names', + help='Use the accession as the row/column names and not the full path' + 'Assumes that the results look like /path/to/<accession>.fasta' + ) + + parser._action_groups.append(optionalArgs) + return parser.parse_args() + +def convert_ani_matrix_to_square(ani_output_fp, process_names=True): + """ + Helper function to convert the output fastANI matrix to a square matrix. + Makes some calculations easier. + :param ani_output_fp: Path to fastANI output matrix file + :return: A tuple of (the matrix itself, the names of the files) + """ + with open(ani_output_fp, 'r') as fin: + # first line is the number of genomes compared + no_of_elements = int(fin.readline().strip()) + # The rest of the lines contain the values + data_lines = [line.strip() for line in fin] + # Each line's values start from the second column + data_fields = [line.split('\t')[1:] for line in data_lines] + # First element is the genome_file + genomes_files = [line.split('\t')[0] for line in data_lines] + if process_names: + names = [] + for gf in genomes_files: + # I am expecting something like 'NC_000866.4.fasta' + fname = Path(gf).name + name = fname.rstrip('.fasta') + names.append(name) + else: + names = genomes_files + # Initialize a square matrix of shape no_of_elements x no_of_elements + # filled with zeros + mat = np.zeros((no_of_elements, no_of_elements)) + + # A loop to get values to fill the zero-filled matrix + for i in range(0, no_of_elements): + for j in range(0, no_of_elements): + # Diagonal + if i == j: + value = 100. + # Weird but works + # TO DO add better documentation of what is happening here + if i < j: + value = data_fields[j][i] + if value == 'NA': + value = 0. + if i > j: + value = data_fields[i][j] + if value == 'NA': + value = 0. 
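+                # The branches above read each pair's value from the triangular
+                # fastANI layout: for (i, j) with i < j the value sits on row j
+                # at position i, and vice versa for i > j. Writing it to
+                # mat[i, j] below mirrors the values into a symmetric square
+                # matrix; 'NA' (no ANI estimate) has already been mapped to 0.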
+ mat[i, j] = value + return mat, names + +def rewrite_ani_matrix_as_square(fastani_raw, output_fp, **kwargs): + """ + Helper function for writing the result + """ + # Get the matrix and names to use + mat, names = convert_ani_matrix_to_square(fastani_raw, process_names=kwargs['process_names']) + # Create a data frame + df = pd.DataFrame(mat, index=names, columns=names) + # Write the matrix to the file + df.to_csv(output_fp, header=True, index=True, sep='\t') + +def main(): + args = parse_args() + rewrite_ani_matrix_as_square(args.mat_in, args.mat_out, process_names=args.process_names) + +if __name__ == '__main__': + main() diff --git a/workflow/scripts/filter_intact.py b/workflow/scripts/filter_intact.py new file mode 100755 index 0000000000000000000000000000000000000000..cbe5146ebd6789efd4a0ea49f338e31fd3bcfcf2 --- /dev/null +++ b/workflow/scripts/filter_intact.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python +from ete3 import NCBITaxa, TextFace, TreeStyle, faces +import argparse + + +parser = argparse.ArgumentParser(description='Parse the IntAct DB text file to get viral related entries') + +# Store the default optional groups to the optionalArgs +optionalArgs = parser._action_groups.pop() + +# Add requiredArgs as an argument group for displaying help better +requiredArgs = parser.add_argument_group("required arguments") +requiredArgs.add_argument('-i', dest="intact_in", + help="The raw txt intact file", + required=True) +requiredArgs.add_argument('-o', dest="intact_out", + required=True, + help="File path to write results to") + +# Add the rest of the optional arguments +optionalArgs.add_argument('-v', '--vhost', + dest="vhost", + required=False, + help="A txt file containing a list of bacteria and archaea associated virus taxids") +optionalArgs.add_argument('--phages-only', + dest='phages_only', + action='store_true', + help='Filter taxids to only contain phages provided in the vhost file') +optionalArgs.add_argument("--tax-db", + dest="tax_db", + required=False, + help="Location of the NCBI taxonomy sqlite db, created" + ) +# Add the optionalArgs to the action_groups +parser._action_groups.append(optionalArgs) + +def file_to_list(filename): + """ + Read in a one-column txt file to a list + :param filename: + :return: A list where each line is an element + """ + with open(filename, 'r') as fin: + alist = [line.strip() for line in fin] + return alist + + +def parse_field(field_string): + """ + Get a tuple of the first entries found in an psimitab column + (xref, value, desc) + :param field_string: + :return: A tuple (xref, value, desc) - desc is optional + """ + parts = field_string.split(":") + xref = parts[0] + rest = parts[1] + if "(" in rest: + subfields = rest.split("(") + value = subfields[0] + desc = subfields[1][:-1] + else: + value = rest + desc = "NA" + return xref, value, desc + + +def parse_column_string(astring): + """ + Parses the information stored in a column field of psi-mitab + :param astring: The string of psimitab + e.g. 
`taxid:9606(human)|taxid:9606(Homo sapiens)` + :return: A dictionary of the values + {'xref': ['taxid'], + 'value': ['9606'], + 'desc': ['human', 'Homo sapiens']} + """ + + # Initialize an empty list for each xref, value, desc key + column_data = {k: [] for k in ['xref', 'value', 'desc']} + if "|" in astring: + fields = astring.split("|") + else: + fields=[astring] + + for f in fields: + result = parse_field(f) + xref, value, desc = result[0], result[1], result[2] + if xref not in column_data['xref']: + column_data['xref'].append(xref) + if value not in column_data['value']: + column_data['value'].append(value) + if desc not in column_data['desc']: + column_data['desc'].append(desc) + return column_data + + +## TREE STUFF ## +def layout(node): + """ + Helper function to annotate nodes with scientific name + :param node: A node instance of a ete3.Tree + :return: Annotated node with sci_name + """ + faces.add_face_to_node(TextFace(node.sci_name), node, 0) + + +if __name__ == '__main__': + # Parse the args + args = parser.parse_args() + ncbi = NCBITaxa(dbfile=args.tax_db) + ncbi_viruses = ncbi.get_descendant_taxa('10239', intermediate_nodes=True) + + intact_in = args.intact_in + intact_out = args.intact_out + + # If a file with host taxids is provided + if args.vhost: + phage_taxids_file = args.vhost + # Get the list of taxids + phage_taxids = list(map(int, file_to_list(phage_taxids_file))) + print("Phages list : {} entries".format(len(phage_taxids))) + else: + phage_taxids = set() + # If we want to only focus on phages + if args.phages_only: + ncbi_phages = ncbi.get_descendant_taxa('28883', intermediate_nodes=True) # 28883 -> Caudovirales + compare_list = set.union(set(phage_taxids), set(ncbi_phages)) + else: + ncbi_virus_taxids = ncbi.get_descendant_taxa('10239', intermediate_nodes=True) + print("NCBI Viruses descendants: {}".format(len(ncbi_virus_taxids))) + compare_list = set.union(set(ncbi_virus_taxids), set(phage_taxids)) + + print("Contents of input file will be scanned against {} viral associated taxids.". 
+ format(len(compare_list))) + # Initialize an interaction counter and a virus-virus counter + int_counter, vv_counter = 0,0 + # Store the parsed taxids + taxids = [] + with open(intact_in, 'r') as fin, open(intact_out, 'w') as fout: + header_line = fin.readline() + fout.write(header_line) + for line in fin: + int_counter += 1 + # Print some progress + if int_counter % 10000 == 0: + print("Parsed {} entries".format(int_counter)) + columns = [field.strip() for field in line.split("\t")] + # On PSIMI-TAB v25 columns 8,9 have the taxonomy info + if columns[9] != '-' and columns[10] != '-': + taxidA_data = parse_column_string(columns[9]) + taxidB_data = parse_column_string(columns[10]) + taxidA, taxidB = taxidA_data['value'], taxidB_data['value'] + # If the taxids are both viral + if (int(taxidA[0]) in compare_list) and (int(taxidB[0]) in compare_list): + vv_counter += 1 + fout.write(line) + taxids.append(taxidA[0]) + taxids.append(taxidB[0]) + + print("{} entries were parsed.".format(int_counter)) + print("{} entries were viral".format(vv_counter)) + taxids_set = set(taxids) + print("Filtered taxids list contains {} taxa".format(len(taxids_set))) diff --git a/workflow/scripts/get_uniprot_ids.py b/workflow/scripts/get_uniprot_ids.py new file mode 100755 index 0000000000000000000000000000000000000000..b26e81973bf06c1d46982ba1143eded7d2543f76 --- /dev/null +++ b/workflow/scripts/get_uniprot_ids.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python + +import argparse +import pandas as pd + +parser = argparse.ArgumentParser(description="Produce a list of uniprot ids from the metadata.tsv (output of\n" + "summarize_intact.py", + formatter_class=argparse.RawTextHelpFormatter) +optionalArgs = parser._action_groups.pop() + +requiredArgs = parser.add_argument_group("required arguments") +requiredArgs.add_argument('-i', '--input_tsv', + dest='input_tsv', + type=str, + required=True, + help="The metadata.tsv file as produced from summarize_intact.py") +requiredArgs.add_argument('-l', '--output-list', + dest='output_list', + help="A txt file containing the raw uniprot ids from intact.\n" + "This will be used to post a query to the uniprot mapping tool", + required=True) +requiredArgs.add_argument('-o', '--output-tsv', + dest='output_tsv', + required=True, + help='A tsv file that contains the result of filtering') + +parser._action_groups.append(optionalArgs) + +if __name__ == '__main__': + args = parser.parse_args() + df = pd.read_csv(args.input_tsv, sep='\t') + # Select non-self interactions where both interactors come from uniprot + interactions = df.loc[(df['same_protein'] == 0) & + (df['source_A'] == 'uniprotkb') & + (df['source_B'] == 'uniprotkb'), ] + + interactions.to_csv(args.output_tsv, sep='\t', index=False) + + print(f'Number of non-self interactions with uniprotkb identifiers: {interactions.shape[0]}') + + # Create a protA_info data-frame with two columns prot_id, source_db + protA_info = interactions[['prot_A', 'source_A']].reset_index(drop=True) + protA_info = protA_info.rename(columns={'prot_A': 'prot_id', + 'source_A': 'source_db'}) + # Create a protB_info data-frame with two columns prot_id, source_db + protB_info = interactions[['prot_B', 'source_B']].reset_index(drop=True) + protB_info = protB_info.rename(columns={'prot_B': 'prot_id', + 'source_B': 'source_db'}) + # Concatenate the two data-frames + prot_info = pd.concat([protA_info, protB_info], ignore_index=True) + # Drop the duplicates based on uniprot_id + prot_info.drop_duplicates(subset='prot_id', inplace=True) + print(f'{prot_info.shape[0]} 
unique proteins come from uniprot') + + uniprot_ids = prot_info['prot_id'].to_list() + + # Write the identifiers to the output file + with open(args.output_list, 'w') as fout: + for uniprot_id in uniprot_ids: + fout.write(f'{uniprot_id}\n') diff --git a/workflow/scripts/make_protein_combos.py b/workflow/scripts/make_protein_combos.py new file mode 100755 index 0000000000000000000000000000000000000000..457f75ad346fca22c3ecd9d69c2c53d22989569b --- /dev/null +++ b/workflow/scripts/make_protein_combos.py @@ -0,0 +1,167 @@ +#!/usr/bin/env python + +import argparse +from itertools import combinations +import random +from Bio import SeqIO +import pathlib +from Bio.Seq import Seq +from Bio.SeqRecord import SeqRecord +from Bio.Alphabet import IUPAC + + +parser = argparse.ArgumentParser(description='Create a 2-column tsv file with --number-of combinations of proteins' + 'from the same genome, for all genomes that have more than 1 ' + 'protein. ' + 'If a 2-column tsv files is provided with --exclude, the set of these ' + 'interactions will be excluded.') +optionalArgs = parser._action_groups.pop() + +requiredArgs = parser.add_argument_group("required arguments") +requiredArgs.add_argument('-i', '--input-gb', + dest='input_gb', + required=True, + help="A genbank file containing annotated genomes") +requiredArgs.add_argument('-o', '--fasta-out', + dest='fasta_out', + type=str, + required=True, + help="The output faa file to write to") +requiredArgs.add_argument('-x', '--interactions-out', + dest='interactions_out', + type=str, + required=True, + help="The output tsv file to write the interactions in (x for A *x* B)") + +requiredArgs.add_argument('-a', '--all-proteins-faa', + dest='all_proteins_fasta', + help='A fasta file that contains all proteins as extracted from the genbank file', + required=True) +optionalArgs.add_argument('--number-of-pairs', + required=False, + help="select this amount of pairs to output", + dest='sample_no', + type=int) +optionalArgs.add_argument('--exclude', + required=False, + type=str, + help="A 2-column tsv file that contains interactions to be excluded") +optionalArgs.add_argument('--sample-size', + required=False, + dest='sample_size', + type=int, + help="If an --exclude file is present, set the number of interactions" + "to sample to N * sample-size, where N=number of interactions in the exclude file", + default=1 + ) +optionalArgs.add_argument('--random-seed', + required=False, + dest='random_seed', + type=int, + default=46, + help="Set the random seed for reproducibility [default = 46]" + ) + +parser._action_groups.append(optionalArgs) + +def create_all_possible_combos(genbank_fin): + """ + Create a set of all possible protein interactions, per genome, + from the input genbank_file + """ + # all_pairs_dict = {} + all_pairs = [] + for record in SeqIO.parse(genbank_fin, format='genbank'): + proteins = [] + for f in record.features: + if f.type == 'CDS': + protein = f.qualifiers.get('protein_id') + + if protein: + protein_id = protein[0] + proteins.append(protein_id) + + if len(proteins) >= 2: + protein_combos = list(combinations(sorted(proteins), 2)) + + # all_pairs_dict[genome_acc] = protein_combos + for pair in protein_combos: + all_pairs.append(pair) + elif len(proteins) == 1: + pass + + # Sort the tuples by protein id, to make comparisons easier + all_pairs_sorted = sorted([tuple(sorted(i)) for i in all_pairs]) + + return all_pairs_sorted + + +def create_exclude_set(exclude_file): + """ + Create a list of interactions to be excluded from analysis, + e.g. 
if they are part of the positive set. + """ + excludes = [] + with open(exclude_file, 'r') as fin: + for line in fin: + fields = line.split('\t') + p1, p2 = fields[0].strip(), fields[1].strip() + interaction = (p1,p2) + excludes.append(sorted(interaction)) + exclude_pairs = sorted(set([tuple(sorted(i)) for i in excludes])) + + return exclude_pairs + + +if __name__ == '__main__': + args = parser.parse_args() + + interactions_pool = create_all_possible_combos(args.input_gb) + #print("Total raw pool: {}".format(len(interactions_pool))) + + if args.exclude: + exclude_set = create_exclude_set(args.exclude) + print('{} were input for exclusion'.format(len(exclude_set))) + interactions_pool = sorted(set(interactions_pool) - set(exclude_set)) + print("Total pool after exclusion: {}".format(len(interactions_pool))) + if not args.sample_size: + print("The number of interactions to be selected will be set to {}".format(len(exclude_set))) + print("You can use the --sample-size argument to sample more") + sample_size = args.sample_size + else: + sample_size = args.sample_size * len(exclude_set) + elif not args.exclude and not args.sample_no: + # Just testing this + parser.error("Either provide an input file with interactions to exclude and set the sample-size option" + "or set the number of interactions you want to sample with the --number-of-pairs option") + elif args.sample_no: + sample_size = args.sample_no + else: + parser.error("This is unexpected") + + with open(args.all_proteins_fasta, 'r') as fin: + seq_data = SeqIO.to_dict(SeqIO.parse(fin, format="fasta")) + + random.seed(args.random_seed) + + final_interactions = random.sample(interactions_pool, sample_size) + final_interactions = sorted([tuple(sorted(i)) for i in final_interactions]) + #print(final_interactions[:5]) + #print("Produced {} interactions".format(len(final_interactions))) + + with open(args.interactions_out, 'w') as fout: + for interaction in final_interactions: + fout.write('{}\t{}\n'.format(interaction[0], interaction[1])) + #print("Interaction data are stored in {}".format(args.interactions_out)) + + accessions_list = [] + for interaction in final_interactions: + accessions_list.append(interaction[0]) + accessions_list.append(interaction[1]) + accessions_list = sorted(set(accessions_list)) + + writer = 0 + with open(args.fasta_out, 'w') as fout: + for accession in accessions_list: + writer += SeqIO.write(seq_data.get(accession), fout, format="fasta") + #print("Wrote {} unique sequences in file {}".format(writer, args.fasta_out)) diff --git a/workflow/scripts/predict.py b/workflow/scripts/predict.py new file mode 100755 index 0000000000000000000000000000000000000000..3b5764fb0b40b8e7c597553b73116962cf3910e4 --- /dev/null +++ b/workflow/scripts/predict.py @@ -0,0 +1,245 @@ +#!/usr/bin/env python + +import argparse +from pathlib import Path + +import pandas as pd + +import pickle + +from sklearn.ensemble import RandomForestClassifier +from sklearn.model_selection import train_test_split + + +FEATURES = ['jaccard_score', + 'same_score', + 'inwards_score', + 'outwards_score', + 'avg_distance', + 'mean_ani', + 'mean_aai'] + +# Random State - for reproducibility +RS = 1 + +def parse_args(): + parser = argparse.ArgumentParser(description='Predict interactions on target') + + optionalArgs = parser._action_groups.pop() + + requiredArgs = parser.add_argument_group("required arguments") + requiredArgs.add_argument('-m', '--model', + dest='model_fp', + type=lambda p: Path(p).resolve(strict=True), + required=True, + help="A pickle file 
containing the RF parameters" + ) + requiredArgs.add_argument('-t', '--target_tsv', + dest='target_tsv', + type=lambda p: Path(p).resolve(strict=True), + required=True, + help="A tsv file containing feature calculations for the interactions") + requiredArgs.add_argument('-p', '--positive-set', + dest='positives_tsv', + type=lambda p: Path(p).resolve(strict=True), + required=True, + help="Positive set") + requiredArgs.add_argument('-n', '--negative-set', + dest='negatives_tsv', + type=lambda p: Path(p).resolve(strict=True), + required=True, + help="Negative set with labels") + requiredArgs.add_argument('-o', '--output-file', + dest='outfile', + required=True, + type=lambda p: Path(p).resolve(), + help="File path to write the results in") + optionalArgs.add_argument('-j', '--jobs', + dest='n_jobs', + type=int, + required=False, + default=1, + help="Number of parallel jobs to run for classification (default : 1)" + ) + parser._action_groups.append(optionalArgs) + + return parser.parse_args() + + +def read_scores_table(scores_fp, label=None): + """ + Read a table from a tsv file provided as a path. + Return a dataframe. + + If label is set, append a column named `label` with + the provided value + """ + scores_df = pd.read_csv(scores_fp, sep='\t') + scores_df['interaction'] = scores_df['pvog1'] + '-' + scores_df['pvog2'] + # Unpack features list because I want interaction prepended. + scores_df = scores_df[['interaction', *FEATURES]] + if label is not None: + scores_df['label'] = label + return scores_df + + +def concat_data_frames(pos_df, + neg_df, + subsample=False, + clean=True, + is_scaled=True, + ): + """ + Concatenate two dataframes + + subsample:bool + Subsample the `neg_df` to a number of + observations equal to the number of pos_df + (balance datasets) + + clean:bool + Remove observations with a value of `avg_distance` == 100000 or 1. + + is_scaled:bool + The features have been scaled to a range 0-1 + + Return: + concat_df: pd.DataFrame + The concatenated data frame + """ + + if clean is True: + pos_df = remove_ambiguous(pos_df) + neg_df = remove_ambiguous(neg_df) + + n_positives = pos_df.shape[0] + n_negatives = neg_df.shape[0] + + # Remove possible duplicate interactions from the negatives + # This might happen because of the random selection when creating the set + # Why I also select more negatives to begin with + neg_df = neg_df.loc[~neg_df['interaction'].isin(pos_df['interaction'])] + + if (n_positives != n_negatives) and (subsample is True): + neg_df = neg_df.sample(n=n_positives, random_state=1) + concat_df = pd.concat([pos_df, neg_df]).reset_index(drop=True) + + assert concat_df[concat_df.duplicated(subset=['interaction'])].empty == True, concat_df.loc[concat_df.duplicated(subset=['interaction'], keep=False)] + + return concat_df + + +def scale_df(input_df): + """ + Scale all feature values in the data frame to [0-1]. 
+ """ + maxes = input_df[FEATURES].max(axis=0) + scaled_data = input_df[FEATURES].divide(maxes) + if 'label' in input_df.columns: + scaled_df = pd.concat([input_df['interaction'], scaled_data, input_df['label']], axis=1) + else: + scaled_df = pd.concat([input_df['interaction'], scaled_data], axis=1) + return scaled_df + + +def remove_ambiguous(input_df): + """ + Select observations in the `input_df` that have feature values + """ + df_clean = input_df[(input_df.jaccard_score != 0)] # This is true if they don't co-occur + return df_clean + + + +if __name__ == '__main__': + args=parse_args() + + # Create the training set dataframe + pos_df = read_scores_table(args.positives_tsv, label=1) + print("Raw positive set interactions : {}".format(pos_df.shape[0])) + + neg_df = read_scores_table(args.negatives_tsv, label=0) + print("Raw negative set interactions : {}".format(neg_df.shape[0])) + + scaled_pos_df = scale_df(pos_df) + scaled_neg_df = scale_df(neg_df) + training_df = concat_data_frames(scaled_pos_df, + scaled_neg_df, + subsample=True, + clean=True, + is_scaled=True) + print("Processed training set interactions: {}".format(training_df.shape[0])) + + # Read in the target dataset + target_df = read_scores_table(args.target_tsv) + input_targets = target_df.shape[0] + print("Target set interactions: {} ".format(input_targets)) + + # Clean the target from distances + target_df = remove_ambiguous(target_df) + # Scale it + target_df = scale_df(target_df) + + # Remove posnegs + target_df = target_df.loc[~target_df['interaction'].isin(training_df['interaction'])] + final_targets = target_df.shape[0] + print("Removed {} interactions from target ( {} remaining)".format((input_targets - final_targets), final_targets)) + + # Read in the model + with open(args.model_fp, 'rb') as fin: + RF = pickle.load(fin) + + if args.n_jobs: + RF.n_jobs = args.n_jobs + + print("Classifier: {}".format(RF)) + + + Xt = training_df[FEATURES] + yt = training_df['label'] + # Keep these to append + interactions = training_df['interaction'] + # Get the feature values only, for predictions on the target + X = target_df[FEATURES] + + # Split the training set to train/holdout (0.7/0.3) + # Holdout is thrown away here... + # This is required for consistency with the whole model selection process + X_train, X_holdout, y_train, y_holdout = train_test_split(Xt, + yt, + test_size=0.3, + random_state=RS) + # write the actual training set used to a file + # So this can be appended to the complete final table + # Make a copy + X_training_out = X_train.copy() + # Append the label + X_training_out['label'] = y_train + # Append the interaction information + X_training_out['interaction'] = training_df['interaction'] + # Give the training set a proba of 1 + X_training_out['proba'] = 1. + # Rearrange columns + X_training_out = X_training_out[['interaction', *FEATURES, 'label', 'proba']] + + # Construct the name of the output file + dir_name = args.outfile.parent + training_set_fp = dir_name / Path("final_training_set.tsv") + X_training_out.to_csv(training_set_fp, sep = '\t', index=False) + + + + + print("Fitting...") + RF.fit(X_train, y_train) + + print("Predicting...") + X_pred = RF.predict(X) + X_pred_proba = RF.predict_proba(X) + + target_df['label'] = X_pred + target_df['proba'] = X_pred_proba[:,1] + + target_df.to_csv(args.outfile, sep="\t", index=False) + print("Finished! 
Results are written in {}".format(str(args.outfile))) + diff --git a/workflow/scripts/process_annotations.py b/workflow/scripts/process_annotations.py new file mode 100644 index 0000000000000000000000000000000000000000..ae82e20ddab49e56ad9a515fac3fd78264926303 --- /dev/null +++ b/workflow/scripts/process_annotations.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python + +import argparse +import pandas as pd +from pathlib import Path +from collections import Counter +import operator + +def parse_args(): + parser = argparse.ArgumentParser(description="Parse pvogs annotations from the VOGProteinTable.txt " + "to a table that contains most frequently occurring raw and " + "processed annotation") + optionalArgs = parser._action_groups.pop() + + requiredArgs = parser.add_argument_group("Required arguments") + requiredArgs.add_argument('-i', '--input-file', + dest='input_fp', + type=str, + required=True, + help="The VOGProteinTable.txt from the database") + requiredArgs.add_argument('-o', '--output-file', + dest='output_fp', + type=str, + required=True, + help="The output table location") + + parser._action_groups.append(optionalArgs) + return parser.parse_args() + +def parse_annotations(vog_table_fp): + """ + Parse the pvogs annotations from the VOGProteinTable.txt + This tries to normalize the annotations into a standard + set of descriptions that is used for counting. + + Anything that is annotated as hypothetical is translated + to 'unknown'. + + 'protein' and 'putative' are also stripped. + + Returns + annotations_dic:dict + A dictionary that holds the processed terms + of the form + {pvog_id : + [annotation_1, annotation_2, ...], + ...} + + annotations_original:dict + A dictionary that holds the original terms + of the form + {pvog_id: + [annotation_1, annotation_2, ...], + ...} + + """ + + annotations_dic = {} + annotations_original = {} + with open(vog_table_fp, 'r') as fin: + for line in fin: + fields = line.split('|') + pvog = fields[0].split(':')[0] + if len(fields) > 2: + annotation = fields[2].split(':')[0] + processed_annotation = annotation + if 'hypothetical' in processed_annotation:#or annotation == '-': + processed_annotation=processed_annotation.replace('hypothetical protein', '').strip() + if processed_annotation == '': + processed_annotation = 'unknown' + else: + processed_annotation = processed_annotation.strip() + + if 'protein' in annotation: + processed_annotation = processed_annotation.replace('protein', '') + if 'putative' in annotation: + processed_annotation = processed_annotation.replace('putative', '') + else: + annotation = 'unknown' + processed_annotation = 'unknown' + + if pvog not in annotations_dic: + annotations_dic[pvog] = [processed_annotation.strip()] + else: + annotations_dic[pvog].append(processed_annotation.strip()) + + if pvog not in annotations_original: + annotations_original[pvog] = [annotation.strip()] + else: + annotations_original[pvog].append(annotation.strip()) + + return annotations_dic, annotations_original + +def get_max_count_annotation(annotations_dic): + """ + Get the annotation with the max count from a given `annoations_dic` + + annotations_dic:dict + A dictionary of the form + {pvog_id : + [annotation_1, annotation_2, ...], + ...} + Return: + unique_only:dict + A dictionary with a pvog as key and + the annotation with the highest occurring frequency + """ + unique_only = {} + for pvog in annotations_dic: + c = Counter(annotations_dic.get(pvog)) + most_common = max(c.items(), key=operator.itemgetter(1))[0] + unique_only[pvog] = most_common + 
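+        # Ties on the count are resolved by whichever annotation max() sees
+        # first (i.e. Counter insertion order), not alphabetically.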
return unique_only + +def create_annotations_df(vog_table_fp): + """ + A wrapper function that parses the annotations file + into a table + + vog_table_fp:Path-like object + The filepath of VOGProteinTable.txt + + Return + df_processed:pd.DataFrame + A dataframe that holds processed and raw annotations + for each pvog + """ + parsed_annotations = parse_annotations(vog_table_fp) + processed_annotations = parsed_annotations[0] + raw_annotations = parsed_annotations[1] + max_data_processed = get_max_count_annotation(processed_annotations) + max_data_raw = get_max_count_annotation(raw_annotations) + df_processed = pd.DataFrame.from_dict(max_data_processed, + orient='index', + columns=['annotation_processed']) + df_raw = pd.DataFrame.from_dict(max_data_raw, + orient='index', + columns=['annotation_raw']) + df_processed = df_processed.reset_index().rename(columns = {'index': 'pvog'}) + df_raw = df_raw.reset_index().rename(columns={'index': 'pvog'}) + df_processed['annotation_raw'] = df_raw['annotation_raw'] + df_processed = df_processed.set_index('pvog') + + return df_processed + +def main(): + args = parse_args() + annotations_df = create_annotations_df(Path(args.input_fp)) + annotations_df.to_csv(Path(args.output_fp), sep='\t') + + +if __name__ == '__main__': + main() + + + diff --git a/workflow/scripts/process_comparem.py b/workflow/scripts/process_comparem.py new file mode 100755 index 0000000000000000000000000000000000000000..65586b3321a36638b020fdc33e1cb170abcb2abe --- /dev/null +++ b/workflow/scripts/process_comparem.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python + +import pandas as pd +import numpy as np +import argparse +from pathlib import Path + +parser = argparse.ArgumentParser(description="Rewrite comparem output as square-matrix and calculate " + "mean aai and nonzero mean aai for the records.") +parser.add_argument("--input-tsv", "-i", + required=True, + dest="input_tsv", + help="The aai_summary.tsv file produced with comparem's aai_wf") +parser.add_argument("--output-tsv", "-o", + dest='output_tsv', + required=True, + help="Output tsv file to write square matrix to") + +def process_comparem_output(aai_out, matrix_out): + with open(aai_out, 'r') as fin: + # Skip header + next(fin) + comparisons = {} + for line in fin: + fields = [f.strip() for f in line.split()] + + genomeA, genomeB, mean_aai = fields[0], fields[2], float(fields[5]) + + if genomeA not in comparisons: + comparisons[genomeA] = {genomeA: 100.} + if genomeB not in comparisons: + comparisons[genomeB] = {genomeB: 100.} + + comparisons[genomeA][genomeB] = mean_aai + comparisons[genomeB][genomeA] = mean_aai + + df = pd.DataFrame.from_dict(comparisons) + # Write the square matrix + df.to_csv(matrix_out, sep='\t', header=True) + return True + +# # Convert df to numpy array +# mat = df.to_numpy() +# # Calculate mean aai for all comparisons +# aai = mat.mean() +# # Calculate non zero mean +# nonzeros = np.count_nonzero(mat) +# nonzero_mean_aai = mat.sum() / nonzeros +# +# return aai, nonzero_mean_aai + + +if __name__ == "__main__": + args = parser.parse_args() + input_fp = Path(args.input_tsv) + square_matrix_fp = Path(args.output_tsv) + process_comparem_output(input_fp, square_matrix_fp) + +# if args.write_result: +# # I assume the aai_summary.tsv is in /path/to/INTERACTION/aai_out/aai/aai_summary.tsv +# interaction = input_dir.parent.parent.name +# result_file = Path.joinpath(input_dir, 'aai_means.tsv') +# with open(result_file, 'w') as fout: +# fout.write('interaction\tmean_aai\tnonzero_mean_aai\n') +# 
fout.write(f'{interaction}\t{aai}\t{nonzero_mean_aai}\n') +# else: +# print(f'Mean aai : {aai}') +# print(f'Non-zero mean aai: {nonzero_mean_aai}') diff --git a/workflow/scripts/process_uniprot.py b/workflow/scripts/process_uniprot.py new file mode 100755 index 0000000000000000000000000000000000000000..2f633ad3c081a6727cbee02b20ab4083dfde2b61 --- /dev/null +++ b/workflow/scripts/process_uniprot.py @@ -0,0 +1,376 @@ +#!/usr/bin/env python + +import argparse +from Bio import SwissProt, SeqIO +from Bio.Seq import Seq +from Bio.SeqRecord import SeqRecord +from Bio.Alphabet import IUPAC +import pathlib +from ast import literal_eval as make_tuple +import pandas as pd + +parser = argparse.ArgumentParser(description="This is to get several mapping files for the interactions.", + formatter_class=argparse.RawTextHelpFormatter) +optionalArgs = parser._action_groups.pop() + +requiredArgs = parser.add_argument_group("required arguments") +requiredArgs.add_argument('-i', '--interactions-tsv', + dest='interactions_tsv', + required=True, + help="Filtered interactions_tsv from get_uniprot_ids.py") +requiredArgs.add_argument('-l', '--input_list', + dest='uniprots_list', + type=str, + required=True, + help="The output of get_uniprot_ids.py") +requiredArgs.add_argument('-s', '--swissprot-file', + dest='swissprot_file', + required=True, + help="A SwissProt file, containing results based on the primary ids") +requiredArgs.add_argument('-p', '--prefix-out', + dest='prefix_out', + required=True, + help="The prefix of the file names to write to. This will produce files:\n" + "<prefix>/proteins.faa,\n" + "<prefix>/uniprot2ncbi.mapping.txt\n" + "<prefix>/ncbi2uniprot.mapping.txt\n" + "<prefix>/skipped.uniprot.txt\n" + "<prefix>/ncbi_interactions.txt\n" + "<prefix>/ncbi_genomes.txt") +parser._action_groups.append(optionalArgs) + + +def has_refseq(db_list): + """ + Return the index of the list where the 'RefSeq' string is located. + Otherwise return None + :param db_list: A list of db names taken as the first element of the tuples in a + Swissprot.record.cross_references list + :return: int: index or None + """ + if 'RefSeq' in db_list: + return db_list.index('RefSeq') + else: + return None + + +def give_me_proper_embl(cross_refs): + """ + Filter for references where the first element == 'EMBL', + then search for the first occurence where the genome accession is not '-'. + This is to get both a valid protein accession and genome accession. 
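+    If none of the EMBL cross-references carries a protein accession, the
+    first EMBL entry is returned as a fallback.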
+ + :param cross_refs: The full list of SwissProt.record.cross_references + :return: + """ + # Get embl records first + embl_refs = [] + for ref in cross_refs: + if ref[0] == 'EMBL': + embl_refs.append(ref) + + genome_acc, prot_acc = embl_refs[0][1], embl_refs[0][2] + + for ref in embl_refs: + if ref[2] != '-': + genome_acc, prot_acc = ref[1], ref[2] + break + return genome_acc, prot_acc + + +def has_embl(db_list): + """ + Check if 'EMBL' is in a list of database names + :param db_list: + :return: True or None + """ + if 'EMBL' in db_list: + return True + else: + return None + + +def get_ncbi_accessions(cross_refs): + """ + Given a SwissProt.record.cross_references list, extract a tuple + of genome,protein accessions + :param cross_refs: Bio.SwissProt.record.cross_references list + :return: (genome_accession, protein_accession) associated with the record + """ + dbs = [ref[0] for ref in cross_refs] + + if has_refseq(dbs) is not None: + loc = has_refseq(dbs) + genome_acc = cross_refs[loc][2] + if '[' in genome_acc: + genome_acc = genome_acc.split()[0].rstrip('.') + prot_acc = cross_refs[loc][1] + elif has_embl(dbs) is not None and has_refseq(dbs) is None: + genome_acc, prot_acc = give_me_proper_embl(cross_refs) + # genome_acc, prot_acc = cross_refs[loc][1], cross_refs[loc][2] + else: + genome_acc, prot_acc = '-', '-' + + return genome_acc, prot_acc + + +def construct_description(protein_seq, uniprotkb_id, ncbi_genome, ncbi_protein, start=None, end=None, taxid=None): + if start is not None and end is not None: + return f'uniprotkb={uniprotkb_id};' \ + f'ncbi_protein={ncbi_protein};' \ + f'start={start};' \ + f'end={end};' \ + f'ncbi_genome={ncbi_genome};' \ + f'taxid={taxid}' + else: + return f'uniprotkb={uniprotkb_id};' \ + f'ncbi_protein={ncbi_protein};' \ + f'start=1;' \ + f'end={len(protein_seq)};' \ + f'ncbi_genome={ncbi_genome};' \ + f'taxid={taxid};' + + +def extract_prochain_sequence_from_record(record, + primary_id, + prochain_id, + genome_acc='-', + protein_acc='-'): + r_taxid = record.taxonomy_id + for feature in record.features: + if feature[0] == 'CHAIN' and feature[-1] == prochain_id: + start, end = feature[1], feature[2] + rec_seq = Seq(record.sequence[start - 1:end], IUPAC.protein) + original = '-'.join([primary_id, prochain_id]) + if protein_acc != '-': + new_id = f'{protein_acc}_at_{start}-{end}' + prot_rec = SeqRecord(rec_seq, id=new_id, + description=construct_description(str(rec_seq), + original, + genome_acc, + protein_acc, + start, + end, + taxid=r_taxid)) + else: + prot_rec = SeqRecord(rec_seq, id=original, + description=construct_description(str(rec_seq), + original, + genome_acc, + protein_acc, + start, + end, + taxid=r_taxid)) + return prot_rec + +def uniprot_id_type(uniprot_id): + """ + Check the type of a given uniprot_id + P04591 -> 'primary' + P04591-PRO_0000038593 -> 'prochain' + P04591-3 -> 'isoform + :param uniprot_id: String of uniprot_id + :return: + """ + if '-' in uniprot_id: + second_part = uniprot_id.split('-')[1] + if second_part.startswith('PRO'): + return 'prochain' + else: + return 'isoform' + else: + return 'primary' + +def give_me_the_record(primary_id, swissprot_file): + """ + Return a single record given with the primary id + :param primary_id: A primary id + :param swissprot_file: A swissprot file + :return: A record with accession == primary id + """ + with open(swissprot_file, 'r') as fh: + for record in SwissProt.parse(fh): + if primary_id in record.accessions: + return record + +def extract_protein_from_record(record): + """ + Grab the 
protein sequence as a string from a SwissProt record + :param record: A Bio.SwissProt.SeqRecord instance + :return: + """ + return str(record.sequence) + + +if __name__ == '__main__': + args = parser.parse_args() + + with open(args.uniprots_list, 'r') as fin: + uniprots_list = [line.strip() for line in fin] + + # Create the names of the output files + output_dir = pathlib.Path(args.prefix_out) + output_dir.mkdir(exist_ok=True) + # This is the fasta file that contains the proteins + output_fasta = pathlib.Path.joinpath(output_dir, 'proteins.faa') + + + # Initialize several book-keeping lists and dicts + # Some records I skip + skipped = [] + # A dict {uniprotkb : ncbi-accession , ...} + uniprot2ncbi_protein_mapping = {} + # Keep track of how many sequences are being written + written = 0 + # A dict { ncbi_protein_accession : [ uniprot_accession1, ...] , ... } + duplicates = {} + genomes = [] + + with open(output_fasta, 'w') as fout: + # Loop over the initial list of interactors + for protein_id in uniprots_list: + # Is it primary/prochain/isoform? + entry_type= uniprot_id_type(protein_id) + if entry_type == 'primary': + # Grab the associated record + r = give_me_the_record(protein_id, args.swissprot_file) + # Get the genome an protein accessions + ncbi_genome, ncbi_protein = get_ncbi_accessions(r.cross_references) + + # I want to map the uniprot ids to ncbi ones. + # If there is nothing to map skip the record altogether + if ncbi_protein == '-': + skipped.append(protein_id) + pass + else: + # For primary accessions the protein is the same + rec_id = ncbi_protein + rec_seq = extract_protein_from_record(r) + rec_taxid = r.taxonomy_id + description = construct_description(rec_seq, protein_id, ncbi_genome, ncbi_protein, taxid=rec_taxid) + # Construct a SeqRecord object to write to file + prot_rec = SeqRecord(Seq(rec_seq, IUPAC.protein), + id=rec_id, + description=description) + # Book-keeping + # match uniprot id to its ncbi counterpart + uniprot2ncbi_protein_mapping[protein_id] = prot_rec.id + # get the genome + genomes.append(ncbi_genome) + + # Avoid duplicates + if prot_rec.id not in duplicates: + # Put this unique identifier in the seen_proteins dict + duplicates[prot_rec.id] = [protein_id] + # And write to file + written += SeqIO.write(prot_rec, fout, "fasta") + else: + print("Possible duplicate:{}".format(prot_rec.id)) + duplicates[prot_rec.id].append(protein_id) + pass + elif entry_type == 'prochain': + # When a PRO is present, things need to be different + primary_id = protein_id.split('-')[0] + prochain = protein_id.split('-')[1] + r = give_me_the_record(primary_id, args.swissprot_file) + ncbi_genome, ncbi_protein = get_ncbi_accessions(r.cross_references) + if ncbi_protein == '-': + skipped.append(protein_id) + pass + else: + # The actual sequence is in the features table + prot_rec = extract_prochain_sequence_from_record(r, primary_id, prochain, + genome_acc=ncbi_genome, + protein_acc=ncbi_protein) + uniprot2ncbi_protein_mapping[protein_id] = prot_rec.id + genomes.append(ncbi_genome) + + # Avoid duplicates + if prot_rec.id not in duplicates: + written += SeqIO.write(prot_rec, fout, "fasta") + duplicates[prot_rec.id] = [protein_id] + else: + print("Possible duplicate {}".format(prot_rec.id)) + duplicates[prot_rec.id].append(protein_id) + pass + elif entry_type == 'isoform': + skipped.append(protein_id) + print("{} is an isoform and I do not handle these for now".format(protein_id)) + pass + print(30*'=') + print("{} proteins were extracted".format(written)) + + # This is a file that 
contains for each uniprot id its mapping to an ncbi protein + uniprot2ncbi_mapping_txt = pathlib.Path.joinpath(output_dir, 'uniprot2ncbi.mapping.txt') + uniprot_counter = 0 + with open(uniprot2ncbi_mapping_txt, 'w') as fout: + for uniprot_id, ncbi_id in uniprot2ncbi_protein_mapping.items(): + uniprot_counter += 1 + fout.write('{}\t{}\n'.format(uniprot_id, ncbi_id)) + print('{} uniprot ids are mapped to NCBI accessions'.format(uniprot_counter)) + + # A list of genomes to download from NCBI + genome_accessions_txt = pathlib.Path.joinpath(output_dir, 'genome_accessions.txt') + with open(genome_accessions_txt, 'w') as fout: + for genome in set(genomes): + fout.write('{}\n'.format(genome)) + print('{} genomes are written to file'.format(len(set(genomes)))) + + # Each ncbi protein, might map to more uniprot ids + # This files contains this information - an aggregation at ncbi level + ncbi2uniprot_mapping_txt = pathlib.Path.joinpath(output_dir, 'ncbi2uniprot.mapping.txt') + with open(ncbi2uniprot_mapping_txt, 'w') as fout: + for ncbi_p in duplicates: + fout.write('{}\t{}\n'.format(ncbi_p, ','.join(duplicates.get(ncbi_p)))) + + # This files contains only the duplicates + duplicates_tsv = pathlib.Path.joinpath(output_dir, 'duplicates.tsv') + with open(duplicates_tsv, 'w') as fout: + for ncbi_p in duplicates: + if len(duplicates.get(ncbi_p)) > 1: + fout.write('{}\t{}\n'.format(ncbi_p, ','.join(duplicates.get(ncbi_p)))) + + # This files contains the list of records that were originally input + # but were skipped. For now this happens + # (1) when no valid ncbi_protein accession is found + # (2) for isoforms + skipped_txt = pathlib.Path.joinpath(output_dir, 'skipped.txt') + with open(skipped_txt, 'w') as fout: + for uniprot in skipped: + fout.write('{}\n'.format(uniprot)) + + # Get the proteins that are actually covered through this filtering process + proteins_covered = [] + for ncbi_p in duplicates: + for uniprot_p in duplicates.get(ncbi_p): + proteins_covered.append(uniprot_p) + print('{} uniprot ids are covered'.format(len(proteins_covered))) + + # Get the interactions from the input file + interactions = pd.read_csv(args.interactions_tsv, sep='\t') + print('Input file contains {} interactions'.format(interactions.shape[0])) + # Filter the interactions dataframe to these interactions only + interactions_filtered = interactions.loc[interactions.prot_A.isin(proteins_covered) + & interactions.prot_B.isin(proteins_covered), ] + print('{} interactions are covered from the {} proteins above'.format(interactions_filtered.shape[0], + len(proteins_covered))) + + # Write the final set of interactions to file + interactions_filtered_tsv = pathlib.Path.joinpath(output_dir, 'interactions_filtered.tsv') + interactions_filtered.to_csv(interactions_filtered_tsv, sep='\t', index=False) + + # Also provide a tsv file that contains the uniprot-uniprot interaction + # as ncbi-ncbi interaction (for the negatives set construction) + ncbi_interactions_tsv = pathlib.Path.joinpath(output_dir, 'ncbi_interactions.tsv') + # This is a list of strings ["('intA', 'intB')" , ...] + interaction_pairs = interactions_filtered['interaction'].to_list() + with open(ncbi_interactions_tsv, 'w') as fout: + for pair in interaction_pairs: + # Convert the literal tuple string to actual tuple + # TO DO + # Just make it a tuple to begin with in the interactions file... 
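# NOTE: `make_tuple` is assumed to be `ast.literal_eval`, imported under that
# alias near the top of this script (the import is outside this excerpt).
# It converts the stringified pair stored in the 'interaction' column back
# into a real tuple, e.g.:
#   make_tuple("('P03692', 'P04591')")  ->  ('P03692', 'P04591')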
+ p = make_tuple(pair) + fout.write('{}\t{}\t{}\t{}\n'.format(p[0], p[1], + uniprot2ncbi_protein_mapping.get(p[0]), + uniprot2ncbi_protein_mapping.get(p[1]))) + print('Done!') diff --git a/workflow/scripts/query_uniprot.py b/workflow/scripts/query_uniprot.py new file mode 100755 index 0000000000000000000000000000000000000000..f5a56aca3e0fcd78a3592c16050da2cfff492033 --- /dev/null +++ b/workflow/scripts/query_uniprot.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python + +import argparse +import requests +from requests.adapters import HTTPAdapter +from urllib3.util.retry import Retry +from math import ceil + +parser = argparse.ArgumentParser(description="Given a txt file with uniprot ids, post a query to uniprot\n" + "and retrieve a SwissProt file with the result", + formatter_class=argparse.RawTextHelpFormatter) +optionalArgs = parser._action_groups.pop() + +requiredArgs = parser.add_argument_group("required arguments") +requiredArgs.add_argument('-i', '--input_txt', + dest='input_txt', + type=str, + required=True, + help="The output of get_uniprot_ids.txt") +requiredArgs.add_argument('-o', '--output_sp', + dest='output_sp', + help="A txt file containing the raw uniprot ids from intact.\n" + "This will contain the response from uniprot", + required=True) +optionalArgs.add_argument("--filter-list", + action="store_true", + help="Specify if the input list needs to be filtered for PROs and isoforms", + dest="filter", + required=False) +optionalArgs.add_argument("--output-fmt", + default='txt', + type=str, + dest="output_fmt", + help="SwissProt records are returned by default.\n" + "You can use 'fasta' for retrieving only the fasta sequences\n" + "or 'xml' to get an xml file.") + +parser._action_groups.append(optionalArgs) + + +def get_primary_ids_from_list(uniprot_ids): + """ + Filter the input list to return only querrable primary ids. + As primary I define the part of the identifier's string that + serves as the accession. + E.g. P04591 is a primary id + P04591-PRO_0000038593 is a PTM chain with primary id P04591 + P04591-3 is an isoform with primary id P04591 + + This will only return P04591. The rest may be retrieved from parsing + the results of the query. + + :param uniprot_ids: A list of uniprot specific identifiers + :return: primary_ids: A list of primary ids + """ + primary_ids = [] + for i in uniprot_ids: + if '-' in i: + fields = i.split('-') + primary_id = fields[0] + # Append the primary id to the query list + if primary_id not in primary_ids: + primary_ids.append(primary_id) + + # Make a mapping of primary ids to their pro-chains and isoforms + elif i not in primary_ids: + primary_ids.append(i) + + # Double check I didn't put any duplicates in the final list + if len(primary_ids) != len(set(primary_ids)): + primary_ids = list(set(primary_ids)) + + return primary_ids + +# SHAMELESSLY STOLEN FROM +# https://www.ebi.ac.uk/training/online/sites/ebi.ac.uk.training.online/files/UniProt_programmatically_py3.pdf +# and modified to default to text + +def map_retrieve(ids2map, source_fmt='ACC+ID', target_fmt='ACC', output_fmt='txt'): + """ + Map database identifiers from/to UniProt accessions. + The mapping is achieved using the RESTful mapping service provided by UniProt. + While a great many identifiers can be mapped the documentation + has to be consulted to check which options there are and what the database codes are. + Mapping UniProt to UniProt effectlvely allows batch retrieval + of entries. 
+ Args: + ids2map (list or string): identifiers to be mapped + source_fmt (str, optional): format of identifiers to be mapped. Defaults to ACC+ID, which are UniProt accessions or IDs. + target_fmt (str, optional): desired identifier format. Defaults to ACC, which is UniProt accessions. + output_fmt (str, optional): return format of data. Defaults to list. + Returns: + mapped information (str) + """ + + BASE = 'http://www.uniprot.org' + KB_ENDPOINT = '/uniprot/' + TOOL_ENDPOINT = '/uploadlists/' + + if hasattr(ids2map, 'pop'): + ids2map = ' '.join(ids2map) + + payload = {'from': source_fmt, + 'to': target_fmt, + 'format': output_fmt, + 'query': ids2map, + 'include': 'yes', # This is to include isoforms + } + + with requests.Session() as s: + retries = Retry(total=3, + backoff_factor=0.5, + status_forcelist=[500, 502, 503, 504]) + + s.mount('http://', HTTPAdapter(max_retries=retries)) + # s.mount('https://', HTTPAdapter(max_retries=retries)) + + response = s.get(BASE + TOOL_ENDPOINT, params=payload) + + if response.ok: + return response.text + else: + print(response.url) + response.raise_for_status() + + +if __name__ == '__main__': + args = parser.parse_args() + + with open(args.input_txt, 'r') as fin: + uniprots_list = [line.strip() for line in fin] + print("Input file contains {} records".format(len(uniprots_list))) + + if args.filter: + print("Filtering out PROs and isoforms") + uniprots_list = get_primary_ids_from_list(uniprots_list) + print("Final list contains {} primary ids".format(len(uniprots_list))) + + print("Posting queries of 100 to uniprot") + final_result = [] + total_batches = ceil(len(uniprots_list) / 100) + batch_no = 0 + for i in range(0, len(uniprots_list), 100): + batch_no += 1 + slise = uniprots_list[i:i + 100] + # Run the query for each batch + print("Running batch {}/{}".format(batch_no, total_batches)) + result = map_retrieve(slise, output_fmt=args.output_fmt) + # And store the result in a list + final_result.append(result) + + with open(args.output_sp, 'w') as fout: + fout.write(''.join(final_result)) diff --git a/workflow/scripts/refseqs_to_pvogs.py b/workflow/scripts/refseqs_to_pvogs.py new file mode 100755 index 0000000000000000000000000000000000000000..d78a9f202200ede11c11322e30ec12043b07655e --- /dev/null +++ b/workflow/scripts/refseqs_to_pvogs.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python + +from Bio import SearchIO, SeqIO +from pathlib import Path +import argparse + + +def parse_args(): + parser = argparse.ArgumentParser(description='From uniprot ids to pvogs') + + optionalArgs = parser._action_groups.pop() + + requiredArgs = parser.add_argument_group("required arguments") + requiredArgs.add_argument('-i', '--interactions-file', + dest='int_file', + type=lambda p: Path(p).resolve(strict=True), + required=True, + help="A 2-column tsv file containing interaction data, with refseq ids at" + ) + requiredArgs.add_argument('-f', '--input-fasta', + dest='fasta_in', + type=lambda p: Path(p).resolve(strict=True), + required=True, + help="A fasta file with protein sequences retrieved from uniprot") + requiredArgs.add_argument('-hmm', '--input-hmm', + dest='hmmer_tblout', + type=lambda p: Path(p).resolve(strict=True), + required=True, + help="The tblout file from hmmsearching the proteins in <input_fasta> " + "against all pvogs") + requiredArgs.add_argument('-o', '--output-file', + dest='outfile', + required=True, + type=lambda p: Path(p).resolve(), + help="File path to write the results in") + + parser._action_groups.append(optionalArgs) + + return parser.parse_args() + 
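# Example invocation (a sketch for illustration only; the file names below are
# hypothetical placeholders, not the paths produced by the workflow rules):
#
#   ./refseqs_to_pvogs.py \
#       -i refseq_interactions.tsv \
#       -f proteins.faa \
#       -hmm proteins_vs_pvogs.tblout \
#       -o pvog_interactions.tsv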
+ +# This is higly specific for my case... +remapper = {'P17312-2': 'P17312', + 'P17312-3': 'P17312', + 'P17312-4': 'P17312', + 'P03705': 'P03705-2', + 'P03711-2': 'P03711', + 'P03692-2': 'P03692'} + + +def get_interactions_info(interactions_fp): + interactions = [] + with open(interactions_fp, 'r') as fin: + for line in fin: + fields = line.split('\t') + # Specific for get_uniprot_ids.py + intA, intB = fields[0].strip(), fields[1].strip() + interaction = tuple(sorted([intA, intB])) + interactions.append(interaction) + return interactions + + +def get_unique_interactor_ids(interactions_list): + interactor_ids = [] + for i in interactions_list: + interactor_ids.extend([i[0], i[1]]) + return set(interactor_ids) + + +def get_ids_from_fasta_file(fasta_in): + record_ids = [] + with open(fasta_in, 'r') as fin: + for record in SeqIO.parse(fin, "fasta"): + record_ids.append(record.id) + return record_ids + + +def parse_hmm_table(hmm_table): + """ + Create a dictionary that stores all results per protein + """ + hmm_results = {} + with open(hmm_table, 'r') as fin: + for record in SearchIO.parse(fin, "hmmer3-tab"): + for hit in record.hits: + prot_acc = hit.id + target = remapper.get(prot_acc, prot_acc) + if target not in hmm_results: + hmm_results[target] = [(record.id, hit.evalue, hit.bitscore)] + else: + hmm_results[target].append((record.id, hit.evalue, hit.bitscore)) + return hmm_results + + +def translate_proteins_to_pvogs(hmm_results_dic): + + protein_to_pvog = {} + for p in hmm_results_dic: + results = hmm_results_dic[p] + if len(results) == 1: + best_result = results[0] + else: + # Select the best scoring pvog + scores = [r[2] for r in results] + best_index = scores.index(max(scores)) + best_result = results[best_index] + + protein_to_pvog[p] = best_result[0] + + return protein_to_pvog + + +def main(): + args = parse_args() + + interactions = get_interactions_info(args.int_file) + print("Parsed {} interactions".format(len(interactions))) + + unique_ids = get_unique_interactor_ids(interactions) + print("Unique ids: {}".format(len(unique_ids))) + + fasta_ids = get_ids_from_fasta_file(args.fasta_in) + print("Entries in fasta: {}".format(len(fasta_ids))) + + hmm_results = parse_hmm_table(args.hmmer_tblout) + print("Proteins with hmmer hits: {}".format(len(list(hmm_results.keys())))) + + protein_pvog_map = translate_proteins_to_pvogs(hmm_results) + + no_results = [] + counter = 0 + with open(args.outfile, 'w') as fout: + for inter in interactions: + intA, intB = inter[0], inter[1] + if intA in protein_pvog_map: + if intB in protein_pvog_map: + fout.write("{}\t{}\t{}\t{}\n".format(intA, protein_pvog_map[intA], + intB, protein_pvog_map[intB])) + counter += 1 + else: + no_results.append(intB) + else: + no_results.append(intA) + + print("Interactions translated to pvogs: {}".format(counter)) + print("Proteins with no hits: {}".format(len(set(no_results)))) + + if len(no_results) > 0: + print(20 * '=') + for i in no_results: + print(i) + print(20 * '=') + + print("Results written in {}".format(args.outfile.__str__())) + + +if __name__ == '__main__': + main() diff --git a/workflow/scripts/remove_empty_files.py b/workflow/scripts/remove_empty_files.py new file mode 100755 index 0000000000000000000000000000000000000000..57eed45d72de25b67f0fab9e9c9744d9f2a4c91f --- /dev/null +++ b/workflow/scripts/remove_empty_files.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python + +import argparse +import pathlib +import os + +parser = argparse.ArgumentParser(description="Remove all files associated with a genome that are 
empty", + formatter_class=argparse.RawTextHelpFormatter) +optionalArgs = parser._action_groups.pop() + +requiredArgs = parser.add_argument_group("required arguments") +requiredArgs.add_argument('-i', '--inut-dir', + dest='input_dir', + required=True, + help="Input genes dir that contains the results of `comparem call_genes`") +requiredArgs.add_argument('-o', '--output-txt', + dest='output_txt', + type=str, + required=True, + help="Output txt to write result. Will contain the genomes removed, if any,\n" + "else it will have a 'All genomes passed!' so that is not empty") + +def check_file_is_empty(fin): + """ + + :param fin: A pathlib.Path object + :return: the file if true, nothing if false + """ + if fin.stat().st_size == 0: + return fin + else: + pass + +if __name__ == '__main__': + args = parser.parse_args() + + input_path = pathlib.Path(args.input_dir) + remove_these = [] + for f in list(input_path.glob('*.faa')): + if check_file_is_empty(f): + remove_these.append(f) + + with open(args.output_txt, 'w') as fout: + if len(remove_these) != 0: + for f in remove_these: + print('Removing file {}'.format(f.as_posix())) + f.unlink() + fout.write('{}\n'.format(f.as_posix())) + else: + fout.write("All genomes passed!\n") diff --git a/workflow/scripts/split_multifasta.py b/workflow/scripts/split_multifasta.py new file mode 100755 index 0000000000000000000000000000000000000000..8f09d42fc12613d36745dcd8587d12a8dcf1cf5e --- /dev/null +++ b/workflow/scripts/split_multifasta.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python + +import argparse +import pathlib +import gzip +from Bio import SeqIO + +def parse_args(): + parser = argparse.ArgumentParser( + description="Split a multi-fasta to single files per sequence" + ) + optionalArgs = parser._action_groups.pop() + + requiredArgs = parser.add_argument_group("required arguments") + + requiredArgs.add_argument( + "-i", + "--input", + dest="input_fp", + required=True, + help="Input fasta. Can be gz", + ) + requiredArgs.add_argument( + "-o", + "--outdir", + dest="out_dir", + required=True, + help="A directory to store the files. 
It is NOT created", + ) + optionalArgs.add_argument( + "--write-reflist", + action="store_true", + required=False, + dest="write_reflist", + help="Write a file that contains all paths to the output fasta files, " + "one per line, in the parent directory " + ) + + parser._action_groups.append(optionalArgs) + + return parser.parse_args() + +def is_gz(path_string): + """ + Return true if gzipped file + :param path: path to file + :return: boolean + """ + return path_string.endswith(".gz") or path_string.endswith(".z") + + +def optionally_compressed_handle(path, mode): + """ + Return a file handle that is optionally gzip compressed + :param path: path + :param mode: mode + :return: handle + """ + if mode == "r" or mode == "rb": + mode = "rt" + if mode == "w" or mode == "wb": + mode = "wt" + if is_gz(path): + return gzip.open(path, mode=mode) + else: + return open(path, mode=mode) + +def split_multifasta(input_fp, output_dir, write_reflist=False): + record_no = 0 + filenames = [] + with optionally_compressed_handle(str(input_fp), 'r') as fin: + for record in SeqIO.parse(fin, "fasta"): + genome_acc = record.id + single_fasta = "{}.fasta".format(genome_acc) + single_fasta_fp = output_dir.joinpath(single_fasta) + with open(single_fasta_fp, 'w') as fout: + record_no += SeqIO.write(record, fout, "fasta") + filenames.append(single_fasta_fp) + + if record_no % 10000 == 0: + print("processed {} records".format(record_no)) + if write_reflist: + reflist_txt = output_dir.parent.joinpath("reflist.txt") + with open(reflist_txt, 'w') as refout: + for f in filenames: + refout.write('{}\n'.format(f)) + + return record_no + +def main(): + args = parse_args() + a = split_multifasta(pathlib.Path(args.input_fp), + pathlib.Path(args.out_dir), + args.write_reflist) + + +if __name__ == '__main__': + main() diff --git a/workflow/scripts/subset_scores.py b/workflow/scripts/subset_scores.py new file mode 100755 index 0000000000000000000000000000000000000000..e6cf4a05bb813848fe1a8b1bb8f48ddaf05e05a7 --- /dev/null +++ b/workflow/scripts/subset_scores.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python + +import argparse +from pathlib import Path + +def parse_args(): + parser = argparse.ArgumentParser(description='Subset the scores to a given list of interactions') + + optionalArgs = parser._action_groups.pop() + + requiredArgs = parser.add_argument_group("required arguments") + requiredArgs.add_argument('-s', '--scores-file', + dest='scores_file', + type=lambda p: Path(p).resolve(strict=True), + required=True, + help="The output of calculate_all_scores.py" + ) + requiredArgs.add_argument('-i', '--input-ints', + dest='input_ints', + type=lambda p: Path(p).resolve(strict=True), + required=True, + help="A tsv file containing interaction mappings between ncbi and pvogs") + requiredArgs.add_argument('-o', '--output-file', + dest='outfile', + required=True, + type=lambda p: Path(p).resolve(), + help="File path to write the results in") + + parser._action_groups.append(optionalArgs) + + return parser.parse_args() + +def parse_pvogs_interactions(interactions_fp): + interactions = [] + uniques = 0 + all_in = 0 + with open(interactions_fp, 'r') as fp: + for line in fp: + all_in += 1 + fields = line.split('\t') + interaction = tuple(sorted([fields[1].strip(), fields[3].strip()])) + if interaction[0] == interaction[1]: + print("Skipping: {} (self)".format(interaction)) + + elif interaction not in interactions: + interactions.append(interaction) + uniques +=1 + else: + print("Skipping: {} (duplicate)".format(interaction)) + print("Parsed {} / 
{} total input interactions".format(uniques, all_in)) + return interactions + +def subset_all_scores(all_scores_fp, interactions_list, interactions_scores_fp): + counter = 0 + with open(all_scores_fp, 'r') as fin, open(interactions_scores_fp, 'w') as fout: + header = fin.readline() + fout.write(header) + for line in fin: + fields = line.split('\t') + interaction = tuple(sorted([fields[0].strip(), fields[1].strip()])) + if interaction in interactions_list: + fout.write(line) + counter += 1 + return counter + +def main(): + args = parse_args() + ints_list = parse_pvogs_interactions(args.input_ints) + total = subset_all_scores(args.scores_file, ints_list, args.outfile) + print("{} interaction scores written to file {}".format(total, str(args.outfile))) + +if __name__ == '__main__': + main() + diff --git a/workflow/scripts/summarize_intact.py b/workflow/scripts/summarize_intact.py new file mode 100755 index 0000000000000000000000000000000000000000..ff2406790d365a016dfe2c5feaedcaeaa84f30b8 --- /dev/null +++ b/workflow/scripts/summarize_intact.py @@ -0,0 +1,426 @@ +#!/usr/bin/env python + +import argparse +from ete3 import NCBITaxa, TextFace, TreeStyle, faces +from collections import namedtuple +import pandas as pd +from Bio import SwissProt +from pathlib import Path + + +parser = argparse.ArgumentParser(description="A script that parses IntAct.txt (filtered for viruses) " + "to produce a visual summary of the taxa represented in it\n", + formatter_class=argparse.RawTextHelpFormatter) +optionalArgs = parser._action_groups.pop() + +requiredArgs = parser.add_argument_group("required arguments") +requiredArgs.add_argument('-o', '--output_dir', + dest='output_dir', + type=str, + required=True, + help="The output directory where all resulting files will be stored. " + "It is created if it's not there") +requiredArgs.add_argument('-i', '--intact_in', + dest='intact_in', + help="A txt file containing the filtered IntAct data for viruses", + required=True) +optionalArgs.add_argument("--tax-db", + dest="tax_db", + required=False, + help="Path to taxa.sqlite for ete3" + ) + +parser._action_groups.append(optionalArgs) + + +## Functions for parsing IntAct fields + +def parse_field(field_string): + """ + Args: + field_string (str) : A string of the form 'xref:value(desc)'. + Desc is optional + """ + field_data = {'xref': None} + + parts = field_string.split(":") + xref = parts[0] + field_data['xref'] = {xref: {'value': None, 'desc': None}} + try: + rest = parts[1] + if "(" in rest: + subfields = rest.split("(") + field_data['xref'][xref]['value'] = subfields[0] + field_data['xref'][xref]['desc'] = subfields[1][:-1] + else: + field_data['xref'][xref]['value'] = rest + field_data['xref'][xref]['desc'] = '-' + except: + print(field_data) + return field_data + + +def parse_column_string(astring): + """ + Args: + astring (str): A full string after splitting the PSI-MITAB on tabs + Returns: + column_data (dict): A dictionary of the form {'xref': + {xref: + {'value': value, + 'desc': desc} + } + } + """ + column_data = {} + if "|" in astring: + fields = astring.split("|") + else: + fields = [astring] # LAAAAZY + + for f in fields: + result = parse_field(f) + k = list(result['xref'].keys())[0] + if 'xref' not in column_data: + column_data['xref'] = {k: result['xref'][k]} + else: + column_data['xref'][k] = result['xref'][k] + return column_data + + +def get_single_value_from_xref_dic(xref_dic): + """ + If the given dictionary contains only one entry, get its value, + independent of key. 
+ + Args: + xref_dic (dict): A nested dictionary of the form {'xref': + {xref: + {'value': value, + 'desc': desc} + } + } + Returns: + value (str): The value of 'value' the innermost dict + """ + if xref_dic: + if len(xref_dic['xref'].keys()) > 1: + return None + else: + k = next(iter(xref_dic['xref'])) + value = xref_dic['xref'][k]['value'] + else: + value = '-' + return value + + +def parse_intact_data(int_file, skip_header=True): + ProtInfo = namedtuple('ProtInfo', ['id', 'srcdb', 'taxid']) + interactions_data = {} + + with open(int_file, 'r') as f: + if skip_header: + next(f) + for line in f: + # Make some parsable fields + columns = [column.strip() for column in line.split('\t')] + + # Get the proteinA info + protA_data = parse_column_string(columns[0]) + + protA_id = get_single_value_from_xref_dic(protA_data) + sourceA_db = list(protA_data['xref'].keys())[0] + taxA_data = parse_column_string(columns[9]) + taxidA = int(taxA_data['xref']['taxid']['value']) + protAraw = ProtInfo(id=protA_id, srcdb=sourceA_db, taxid=taxidA) + + # Get the proteinB info + protB_data = parse_column_string(columns[1]) + protB_id = get_single_value_from_xref_dic(protB_data) + sourceB_db = list(protB_data['xref'].keys())[0] + taxB_data = parse_column_string(columns[10]) + taxidB = int(taxB_data['xref']['taxid']['value']) + protBraw = ProtInfo(id=protB_id, srcdb=sourceB_db, taxid=taxidB) + + + # Sorting on interactors to make unique keys for the same interaction + # otherwise (A,B) != (B,A) + interaction = tuple(sorted((protA_id, protB_id,))) + + if interaction[0] == protAraw.id: + protAinfo, protBinfo = protAraw, protBraw + else: + protAinfo, protBinfo = protBraw, protAraw + + if protAinfo.taxid == protBinfo.taxid: + same_taxid = 1 + else: + same_taxid = 0 + + if protAinfo.id == protBinfo.id: + same_protein = 1 + else: + same_protein = 0 + + if str(interaction) not in interactions_data: + interactions_data[str(interaction)] = {'no_of_evidence': 1, + 'prot_A': protAinfo.id, + 'prot_B': protBinfo.id, + 'source_A': protAinfo.srcdb, + 'source_B': protBinfo.srcdb, + 'taxid_A': protAinfo.taxid, + 'taxid_B': protBinfo.taxid, + 'same_taxid': same_taxid, + 'same_protein': same_protein} + + elif str(interaction) in interactions_data: + interactions_data[str(interaction)]['no_of_evidence'] += 1 + + return interactions_data + + + + + +def write_response_text_to_file(response_fn, response_text): + with open(response_fn, 'w') as fh: + fh.write(response_text) + + +def sync_query_list_with_response(response_fn, query_list): + db_data = {} + with open(response_fn, 'r') as fh: + for record in SwissProt.parse(fh): + acc = record.accessions[0] + # Select only EMBL and RefSeq crossrefs + refseq_refs, embl_refs = [], [] + for db_ref in record.cross_references: + if db_ref[0] == 'RefSeq': + refseq_refs.append(db_ref[1:]) + elif db_ref[0] == 'EMBL': + embl_refs.append(db_ref[1:]) + db_data[acc] = {'RefSeq': refseq_refs, + 'EMBL': embl_refs} + + # This is to handle isoforms + # E.g. 
P03692 and P03692-1 can both be included in the query list + # P03705-2 can be in the query list but not P03705 + for prot in query_list: # For each of the original queries + if (prot not in db_data) and ('-' in prot): # If the query is not returned + base_name = prot.split('-')[0] # Search for the fist part + if base_name in db_data: # If it is present in the db_data + if base_name not in query_list: # If it's not in the original query list + db_data[prot] = db_data[base_name] # Fill in the information for the corresponding protein + db_data.pop(base_name) # AND remove the original part + elif base_name in query_list: # If the first part is in the original query + db_data[prot] = db_data[base_name] # Fill in the information WITHOUT removing the original part + elif prot not in db_data: + print("I don't know what to do with this id: {}".format(prot)) + pass + + return db_data + + +def uniprots_list_to_query(uniprots_list): + query_data = {} + for i in uniprots_list: + primary_acc = i.split('-')[0] + query_data[i] = primary_acc + return query_data + + +def iso_is_present(k, values): + ind = None + for i, v in enumerate(values): + for m in v: + if '[{}]'.format(k) in m: + ind = int(i) + return ind + + +def select_embl_info(k, values): + ind = None + for i, v in enumerate(values): + if 'ALT' in v[2]: + pass + else: + ind = int(i) + return ind + + +def get_refseq_info(prot_id, refs): + if len(refs) == 1: + refseq_genome = refs[0][1].split()[0].strip('.') + refseq_protein = refs[0][0] + elif ('-' in k) and (len(refs) > 1): + isoform_index = iso_is_present(prot_id, refs) + if isoform_index is not None: + refseq_genome = refs[isoform_index][1].split()[0].strip('.') + refseq_protein = refs[isoform_index][0] + else: + refseq_genome = refs[0][1].split()[0].strip('.') + refseq_protein = refs[0][0] + else: + refseq_genome = refs[0][1].split()[0].strip('.') + refseq_protein = refs[0][0] + return refseq_genome, refseq_protein + + +def get_embl_info(prot_id, refs): + if len(refs) == 1: + embl_genome = refs[0][0] + embl_protein = refs[0][1] + elif len(refs) > 1: + info_index = select_embl_info(prot_id, refs) + if info_index is not None: + embl_genome = refs[info_index][0] + embl_protein = refs[info_index][1] + else: + embl_genome = refs[0][0] + embl_protein = refs[0][1] + else: + embl_genome = refs[0][0] + embl_protein = refs[0][1] + return embl_genome, embl_protein + + +def map_proteins_to_genomic_accessions(db_data): + accessions_map = {} + for k in db_data: + if db_data[k]['RefSeq']: + refs = db_data[k]['RefSeq'] + (genome_id, protein_id) = get_refseq_info(k, refs) + accessions_map[k] = {'genome_id': genome_id, + 'protein_id': protein_id} + elif not db_data[k]['RefSeq'] and db_data[k]['EMBL']: + refs = db_data[k]['EMBL'] + (genome_id, protein_id) = get_embl_info(k, refs) + accessions_map[k] = {'genome_id': genome_id, + 'protein_id': protein_id} + return accessions_map + +# ETE plotting options and functions +def layout(node): + faces.add_face_to_node(TextFace(node.sci_name), node, 0) + + +def plot_taxids(taxids_list, tree_png, tree_nw, tax_db=None): + if tax_db is not None: + ncbi = NCBITaxa(dbfile=tax_db) + else: + ncbi=NCBITaxa() + + tree = ncbi.get_topology(taxids_list) + ts = TreeStyle() + ncbi.annotate_tree(tree, taxid_attr="sci_name") + ts.show_leaf_name = False + ts.mode = "c" + ts.layout_fn = layout + tree.render(tree_png, tree_style=ts) + tree.write(format=1, outfile=tree_nw) + + +def write_sequence_dic_to_file(accession_map, fileout): + header_string = '\t'.join(['uniprot_id', 'genome_id', 
'protein_id']) + with open(fileout, 'w') as f: + f.write(header_string+'\n') + for acc in accession_map: + info_string = '\t'.join([acc, accession_map[acc]['genome_id'], accession_map[acc]['protein_id']]) + f.write(info_string + '\n') + + +if __name__ == '__main__': + args = parser.parse_args() + + # Create the output directory if it doesn't exist + output_dir = Path(args.output_dir) + if not output_dir.exists(): + output_dir.mkdir() + else: + print("Output directory {} already exists! Contents will be overwritten". + format(args.output_dir)) + + plot_dir = output_dir.joinpath('plots') + plot_dir.mkdir(exist_ok=True) + tree_png = plot_dir.joinpath('taxa_tree.png') + tree_nw = plot_dir.joinpath('taxa_tree.nw') + + # Get the intact file + intact_raw = args.intact_in + intact_data = parse_intact_data(intact_raw) + + # Calculate the number of initial entries + no_of_entries = 0 + for k in intact_data.keys(): + no_of_entries += intact_data[k]['no_of_evidence'] + + print("No. of entries in input: {}".format(no_of_entries)) + print("No. of interactions: {}".format(len(intact_data.keys()))) + + # Put the data in a data frame + df = pd.DataFrame.from_dict(intact_data, orient='index') + df.index.name = 'interaction' + metadata_tsv = output_dir.joinpath('metadata.tsv') + df.to_csv(metadata_tsv, sep='\t') + + # Interactors information + unique_protAs = set(df['prot_A']) + unique_protBs = set(df['prot_B']) + unique_prots = len(set.union(unique_protAs, unique_protBs)) + print("No. of unique proteins: {}".format(unique_prots)) + + # Taxids information + unique_taxidsA = set(df['taxid_A']) + unique_taxidsB = set(df['taxid_B']) + unique_taxids = set.union(unique_taxidsA, unique_taxidsB) + print("No. of unique taxids: {}".format(len(unique_taxids))) + + print("Plotting NCBI taxonomy information. Results are stored in: {}".format(plot_dir)) + plot_taxids(unique_taxids, tree_png.as_posix(), tree_nw.as_posix(), args.tax_db) + + ###################################### + ### UP TO HERE EVERYTHING IS RAW + ### NOW WE START SUBSETTING + ####################################### + # Getting the uniprot related entries to dispatch a query + # protAs = df.loc[(df['source_A'] == 'uniprotkb')]['prot_A'].unique().tolist() + # protBs = df.loc[(df['source_B'] == 'uniprotkb')]['prot_B'].unique().tolist() + # uniprots = set.union(set(protAs), set(protBs)) + # uniprots_list = list(uniprots) + # + # print("{} identifiers will be searched against uniprot for genome and protein sequence retrieval". 
+ # format(len(uniprots_list))) + # + # uniprot_response = output_dir.joinpath('uniprot_info.txt') + # print("Dispatching queries...") + # final_result = [] + # # Split the query into batches of 100 + # for i in range(0, len(uniprots_list), 100): + # slise = uniprots_list[i:i + 100] + # # Run the query for each batch + # result = map_retrieve(slise) + # # And store the result in a list + # final_result.append(result) + # + # # Store the result in a file + # response_text = ''.join(final_result) + # write_response_text_to_file(uniprot_response, response_text) + # + # db_data = sync_query_list_with_response(uniprot_response, uniprots_list) + # # For each original entry in the protein list, + # # store its primary accession + # # TO DO + # # the sync_query_list_with_response function needs revisiting + # query_data = uniprots_list_to_query(uniprots_list) + # + # accession_map = map_proteins_to_genomic_accessions(db_data) + # for protein in query_data: # {alt_acc : primary_acc} + # if protein not in accession_map: # alt_acc not in results + # try: + # primary_acc = query_data[protein] # get the primary accession + # accession_map[protein] = accession_map[primary_acc] + # except KeyError: + # accession_map[protein] = {'genome_id': '-', 'protein_id': '-'} + # + # sequence_info = output_dir.joinpath('sequence_info.txt') + # write_sequence_dic_to_file(accession_map, sequence_info)
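For readers who want to sanity-check the PSI-MITAB helpers in workflow/scripts/summarize_intact.py, the short sketch below shows what parse_field and parse_column_string return for typical IntAct column values. It assumes the module can be imported (i.e. workflow/scripts is on PYTHONPATH and ete3, pandas and biopython are installed); the accession and taxid values are illustrative placeholders.

    from summarize_intact import parse_field, parse_column_string

    # A single cross-reference of the form 'xref:value(desc)'; the parenthesised
    # description is optional and defaults to '-'
    print(parse_field("taxid:10665(Escherichia virus T4)"))
    # {'xref': {'taxid': {'value': '10665', 'desc': 'Escherichia virus T4'}}}

    # A '|'-separated column yields one entry per cross-reference database
    print(parse_column_string("uniprotkb:P04591|intact:EBI-1036907"))
    # {'xref': {'uniprotkb': {'value': 'P04591', 'desc': '-'},
    #           'intact': {'value': 'EBI-1036907', 'desc': '-'}}}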