import os
import re
import time
from textwrap import wrap

import matplotlib.pyplot as plt
import numpy as np
import pathos.multiprocessing as mp
import plotly.graph_objs as plotly_graph
import seaborn as sns
from plotly.offline import plot, iplot, init_notebook_mode

import magine.html_templates.html_tools as ht
from import log2_normalize_df

# Column names (and related labels) used to index magine-formatted
# DataFrames throughout this module.
fold_change = 'fold_change'
flag = 'significant'  # used as a boolean mask selecting significant points
exp_method = 'source'
p_val = 'p_value'
# NOTE(review): rna, gene, protein, metabolites and species_type are not
# referenced by the plotting functions visible here; presumably shared magine
# vocabulary -- confirm whether other modules import them.
rna = 'rna_seq'
gene = 'gene'
protein = 'protein'
metabolites = 'metabolites'
species_type = 'species_type'
sample_id = 'sample_id'  # values form the x axis of every species plot
identifier = 'identifier'  # plots / dropdown entries are grouped by this
label_col = 'label'  # one plotted line per label within an identifier

# NOTE(review): the 'jet' colormap is not used by the visible plotting code
# (it uses seaborn's "tab20" palette); possibly imported elsewhere -- verify.
cm = plt.get_cmap('jet')

def write_table_to_html(data, save_name='index', out_dir=None,
                        run_parallel=False, exp_data=None,
                        plot_type='matplotlib'):
    """ Creates a html table of plots of genes for each ontology term.

    One figure is created per unique ``term_name`` in `data` (see
    ``plot_genes_by_ont``). Each term name in the table is then replaced by
    an HTML link to its figure, and the table is written with
    ``ht.write_filter_table``. The caller's `data` is not modified.

    Parameters
    ----------
    data : magine.enrichment.enrichment_result.EnrichmentResult
        Enrichment table; needs ``term_name`` and ``genes`` columns.
    save_name : str
        Base name for the figures; the table is written under the name
        ``save_name + '_filter'``.
    out_dir : str, optional
        output path for all plots
    run_parallel : bool
        Create plots in parallel
    exp_data : magine.ExperimentalData
        Measured data whose species are plotted for each term.
    plot_type : str
        {'matplotlib', 'plotly'}
    """
    # Work on a copy: term_name is rewritten with HTML links below and that
    # must not leak back into the caller's enrichment result.
    data = data.copy()
    list_of_terms = list(data['term_name'].unique())

    fig_dict, to_remove = plot_genes_by_ont(
        data=data, list_of_terms=list_of_terms, save_name=save_name,
        out_dir=out_dir, exp_data=exp_data, run_parallel=run_parallel,
        plot_type=plot_type,
    )

    # to_remove holds original term names, so compute the mask before the
    # names are replaced by links.
    keep = ~data['term_name'].isin(to_remove)
    data['term_name'] = data['term_name'].map(
        lambda term: fig_dict.get(term, term)
    )
    data = data[keep]

    ht.write_filter_table(data, save_name + '_filter')
def plot_genes_by_ont(data, list_of_terms, save_name, out_dir=None,
                      exp_data=None, run_parallel=False, plot_type='plotly'):
    """ Creates a figure for each ontology term in `list_of_terms`.

    `data` should be the result of an enrichment analysis (e.g.
    ``calculate_enrichment``). For every term, the genes listed in the
    ``genes`` column of its rows are collected and their measurements in
    `exp_data` are plotted with ``plot_species``. With
    ``plot_type='matplotlib'``, terms with more than 100 genes are not
    plotted; their entry is an un-linked ``<a>`` tag.

    Parameters
    ----------
    data : pandas.DataFrame
        previously ran enrichment analysis; needs ``term_name`` and ``genes``
        columns. ``genes`` entries are lists/tuples/sets or comma-separated
        strings; other values (e.g. NaN) contribute no genes.
    list_of_terms : list_like
        values of ``term_name`` to plot
    save_name : str
        name to save file
    out_dir : str, optional
        output path; figures are written to ``<out_dir>/Figures``
        (``./Figures`` when None), which is created if needed.
    exp_data : magine.ExperimentalData or pandas.DataFrame
        data to plot; required unless `list_of_terms` is empty.
    run_parallel : bool
        To run in parallel using pathos.multiprocessing
    plot_type : str
        plotly or matplotlib

    Returns
    -------
    figure_locations : dict
        maps each term to an HTML link to its figure
        (``Figures/<n>_<save_name>.html`` for plotly, ``.pdf`` for
        matplotlib).
    to_remove : set
        terms to drop from the table; currently always empty.

    Raises
    ------
    AssertionError
        if `plot_type` is not 'plotly' or 'matplotlib'.
    ValueError
        if `exp_data` is None and there are terms to plot.
    """
    if plot_type not in {'plotly', 'matplotlib'}:
        raise AssertionError("Please pass plotly or matplotlib as plot_type")

    figure_locations = {}
    to_remove = set()  # never populated here; kept for the caller's API

    if len(list_of_terms) == 0:
        print("No significant ontology terms!!!")
        return figure_locations, to_remove

    if exp_data is None:
        raise ValueError("exp_data is required to plot genes for each term")

    # Accept a species-level DataFrame (has the identifier column) directly,
    # otherwise use the ExperimentalData's species table.
    # NOTE(review): the original assignment here was truncated in the source;
    # ``exp_data.species`` is assumed to be ExperimentalData's combined
    # species table -- confirm against magine.ExperimentalData.
    if identifier in getattr(exp_data, 'columns', ()):
        species_data = exp_data
    else:
        species_data = exp_data.species

    # plot_species saves into <out_dir>/Figures (./Figures when out_dir is
    # None), so that directory must exist before any plot is written.
    base_dir = out_dir if out_dir is not None else os.curdir
    os.makedirs(os.path.join(base_dir, 'Figures'), exist_ok=True)

    # plotly output is saved as .html; matplotlib output uses the 'pdf'
    # image format that _make_plots requests, so the link must match.
    suffix = 'html' if plot_type == 'plotly' else 'pdf'

    # Iterate through all terms and build the list of plots to create. For
    # the HTML side each term gets a link to its figure's location.
    plots_to_create = []
    for n, term in enumerate(list_of_terms):
        term_rows = data.loc[data['term_name'] == term]
        names = term_rows['term_name'].unique()
        name = names[0] if len(names) > 0 else term

        gene_set = set()
        for genes in term_rows['genes']:
            if isinstance(genes, (list, tuple, set)):
                gene_set.update(genes)
            elif isinstance(genes, str):
                gene_set.update(genes.split(','))

        if plot_type == 'matplotlib' and len(gene_set) > 100:
            # too many genes isn't helpful on plots, so skip them
            figure_locations[term] = '<a>{0}</a>'.format(name)
            continue

        local_save_name = os.path.join('Figures',
                                       '{0}_{1}'.format(n, save_name))
        local_save_name = local_save_name.replace(':', '')
        figure_locations[term] = '<a href="{0}.{1}">{2}</a>'.format(
            local_save_name, suffix, name)

        title = "{0} : {1}".format(str(term), name)
        gene_list = list(gene_set)
        local_df = species_data.loc[
            species_data[identifier].isin(gene_list)].copy()
        plots_to_create.append([local_df, gene_list, local_save_name,
                                out_dir, title, plot_type])

    print("Starting to create plots for each term")
    _make_plots(plots_to_create, plot_species, run_parallel)
    return figure_locations, to_remove
def plot_dataframe(exp_data, html_filename, out_dir='proteins',
                   plot_type='plotly', run_parallel=False):
    """ Creates one plot per species and an HTML table linking to them.

    Every unique value of the identifier column gets its own figure (via
    ``plot_species``) saved in `out_dir`. The table written with
    ``ht.write_filter_table`` shows the measurement columns, with each
    identifier replaced by a link to its figure. `exp_data` is not modified.

    Parameters
    ----------
    exp_data : magine.BaseData
        magine formatted data; must contain the identifier, label,
        fold_change, p_value, sample_id, source and significant columns.
    html_filename : str
        output name passed to ``ht.write_filter_table``
    out_dir: str, path
        Directory that will contain all species plots (created if missing)
    plot_type : str
        'plotly' (html output) or 'matplotlib' (pdf output)
    run_parallel : bool
        create plots in parallel

    Raises
    ------
    AssertionError
        if `plot_type` is not 'plotly' or 'matplotlib'.
    """
    if plot_type not in {'plotly', 'matplotlib'}:
        raise AssertionError("Please pass plotly or matplotlib as plot_type")

    os.makedirs(out_dir, exist_ok=True)

    local_data = exp_data.copy()
    # plotly output is html; matplotlib output uses the 'pdf' format that
    # _make_plots requests.
    suffix = 'html' if plot_type == 'plotly' else 'pdf'

    fig_loc = {}
    plots = []
    for species in local_data[identifier].unique():
        # '/', '_' and '.' are removed from the file name.
        # NOTE(review): distinct identifiers can collapse to the same name
        # (e.g. 'A_B' and 'AB') and overwrite each other's plot.
        file_stem = re.sub('[/_.]', '', species)
        plots.append([local_data, [species], file_stem, out_dir, species,
                      plot_type])
        fig_loc[species] = '<a href="{0}/{1}.{2}">{1}</a>'.format(
            out_dir, file_stem, suffix)

    _make_plots(plots, plot_species, run_parallel)

    # Every identifier has an entry in fig_loc, so this swaps each one for
    # the link to its figure (done after plotting, which reads local_data).
    local_data[identifier] = local_data[identifier].map(fig_loc)

    cols = [identifier, label_col, fold_change, p_val, sample_id, exp_method,
            flag]
    ht.write_filter_table(local_data[cols], html_filename)
def _make_plots(plots_to_make, plot_func, parallel=False):
    """ Calls `plot_func` once per argument list in `plots_to_make`.

    Each argument list is extended with ``'pdf'`` and ``True``, the
    trailing ``image_format`` and ``close_plots`` parameters of
    ``plot_species``. The input lists themselves are not modified.

    Parameters
    ----------
    plots_to_make : list of list
        positional arguments for `plot_func`, without the last two
    plot_func : callable
    parallel : bool
        run the calls in a pathos multiprocessing pool; an exception raised
        by a worker is re-raised here instead of being silently dropped.
    """
    # Fresh tuples instead of appending to the caller's lists, so reusing
    # the same input does not accumulate extra arguments.
    calls = [tuple(args) + ('pdf', True) for args in plots_to_make]

    def _run(args):
        return plot_func(*args)

    if parallel:
        start = time.time()
        pool = mp.Pool()
        try:
            # map (unlike an un-awaited map_async) blocks until every call
            # finishes and re-raises a worker's exception.
            pool.map(_run, calls)
        finally:
            pool.close()
            pool.join()
        print("parallel time = {}".format(time.time() - start))
        print("Done creating plots for each GO term")
    else:
        start = time.time()
        for args in calls:
            _run(args)
        print("sequential time = {}".format(time.time() - start))
    plt.close('all')
def plot_species(df, species_list=None, save_name='test', out_dir=None,
                 title=None, plot_type='plotly', image_format='pdf',
                 close_plots=False):
    """ Plots log2 fold change against sample id for each species.

    One line is drawn per (identifier, label) group, ordered by sample id.
    Points whose ``significant`` flag is set get an extra marker.

    Parameters
    ----------
    df: pandas.DataFrame
        magine formatted dataframe
    species_list: list
        List of genes to be plotted; all identifiers when None
    save_name: str
        Filename to be saved as
    out_dir: str
        Path for output to be saved (created if missing)
    title: str
        Title of plot, useful when list of genes corresponds to a GO term
        (plotly only)
    plot_type : str
        Use plotly to generate html output or matplotlib to generate pdf
    image_format : str
        pdf or png, only used if plot_type="matplotlib"
    close_plots : bool
        Close plot after making, use when creating lots of plots in parallel.

    Returns
    -------
    matplotlib.figure.Figure or None
        The figure for ``plot_type='matplotlib'`` when `close_plots` is
        False; otherwise None.
    """
    ldf = df.copy()
    if out_dir is not None:
        # exist_ok avoids a race when parallel workers share out_dir
        os.makedirs(out_dir, exist_ok=True)

    # gather x axis points
    x_points = sorted(ldf[sample_id].unique())
    if len(x_points) == 0:
        return

    # Float sample ids are used as x positions directly; other types are
    # mapped to their rank. (np.float, removed in NumPy 1.24, was just float.)
    if isinstance(x_points[0], (float, np.floating)):
        x_point_dict = {point: point for point in x_points}
    else:
        x_point_dict = {point: rank for rank, point in enumerate(x_points)}

    if species_list is not None:
        ldf = ldf.loc[ldf[identifier].isin(species_list)].copy()
    ldf = log2_normalize_df(ldf, column=fold_change)

    num_colors = len(ldf[label_col].unique())
    color_list = sns.color_palette("tab20", num_colors)

    if plot_type == 'matplotlib':
        fig = plt.figure()
        ax = fig.add_subplot(111)
        ax.set_prop_cycle(plt.cycler('color', color_list))

    plotly_traces = []
    # [identifier, number of consecutive traces it owns] for the dropdown
    names_list = []
    total_counter = 0
    color_counter = 0
    for name, species_df in ldf.groupby(identifier):
        index_counter = 0
        for label, label_df in species_df.groupby(label_col):
            x = np.array(label_df[sample_id])
            if len(x) < 1:
                continue
            order = np.argsort(x)
            x = x[order]
            y = np.array(label_df[fold_change])[order]
            s_flag = np.array(label_df[flag])[order]

            # x positions (only differ from x for non-float sample ids)
            x_index = np.array([x_point_dict[value] for value in x])
            index_counter += 1
            total_counter += 1

            if plot_type == 'matplotlib':
                lines = ax.plot(x_index, y, '.-',
                                label="\n".join(wrap(label, 40)))
                if s_flag.any():
                    ax.plot(x_index[s_flag], y[s_flag], '^',
                            color=lines[0].get_color())
            elif plot_type == 'plotly':
                # cycle so more groups than palette entries cannot exhaust it
                color = color_list[color_counter % len(color_list)]
                color_counter += 1
                plotly_traces.append(
                    _ploty_graph(x_index, y, label, label, color))
                if s_flag.any():
                    index_counter += 1
                    total_counter += 1
                    plotly_traces.append(
                        _ploty_graph(x_index[s_flag], y[s_flag], label,
                                     label, color, marker='x-open-dot'))
        names_list.append([name, index_counter])

    if plot_type == 'matplotlib':
        lgd = _format_mpl(ax, x_point_dict, x_points)
        if save_name is not None:
            out_name = "{}.{}".format(save_name, image_format)
            if out_dir is not None:
                out_name = os.path.join(out_dir, out_name)
            # save this figure explicitly rather than pyplot's current one
            fig.savefig(out_name, bbox_extra_artists=(lgd,),
                        bbox_inches='tight')
        if close_plots:
            plt.close(fig)
        else:
            return fig
    elif plot_type == 'plotly':
        # len(names_list), not the number of unique identifiers: groupby
        # drops NaN identifiers, and the dropdown must match names_list.
        fig = _create_plotly(total_counter, len(names_list), names_list,
                             x_point_dict, title, x_points, plotly_traces)
        if save_name:
            _save_ploty_output(fig, out_dir, save_name)
        else:
            init_notebook_mode(connected=True)
            iplot(fig)
def _format_mpl(ax, x_point_dict, x_points): ax.set_xlim(min(x_point_dict.values()) - 2, max(x_point_dict.values()) + 2) ax.set_xticks(sorted(x_point_dict.values())) ax.set_xticklabels(x_points, rotation=90) plt.ylabel('log$_2$ Fold Change') plt.axhline(y=np.log2(1.5), linestyle='--') plt.axhline(y=-np.log2(1.5), linestyle='--') handles, labels = ax.get_legend_handles_labels() lgd = ax.legend(handles, labels, loc='best', ncol=3, bbox_to_anchor=(1.01, 1.0)) return lgd def _create_plotly(total_counter, n_plots, names_list, x_point_dict, title, x_points, plotly_list): true_list = [True] * total_counter scroll_list = [dict(args=['visible', true_list], label='All', method='restyle')] prev = 0 # making all false except group defined by protein name for i in range(n_plots): t_row = [False] * total_counter for j in range(prev, prev + names_list[i][1]): t_row[j] = True prev += names_list[i][1] scroll = dict(args=['visible', t_row], label=names_list[i][0], method='restyle') scroll_list.append(scroll) update_menu = list([dict(x=-0.05, y=1, yanchor='top', buttons=scroll_list, )]) ticks = np.sort(list(x_point_dict.values())) min_tick = np.min(ticks) max_tick = np.max(ticks) layout = plotly_graph.Layout( title=title, showlegend=True, xaxis=dict(title='Sample index', range=[min_tick, max_tick], showticklabels=True, ticktext=x_points, tickmode='array', tickvals=ticks, ), yaxis=dict(title='log2fc'), hovermode="closest", updatemenus=update_menu ) return plotly_graph.Figure(data=plotly_list, layout=layout) def _save_ploty_output(fig, out_dir, save_name): tmp_savename = "{}.html".format(save_name) if out_dir is not None: tmp_savename = os.path.join(out_dir, tmp_savename) x = plot(fig, filename=tmp_savename, auto_open=False, include_plotlyjs=False, output_type='div') ht.format_ploty(x, tmp_savename) def _ploty_graph(x, y, label, enum, color, marker='circle'): """ Creates a single scatter plot Parameters ---------- x : list_like y : list_like label : str enum : int color : str marker : 
str Returns ------- """ l_color = 'rgba({},{},{},1.)'.format(color[0], color[1], color[2]) if marker != 'circle': mode = 'markers' show = False size = 12 else: mode = 'lines+markers' show = True size = 8 legend = 'group_{}'.format(enum) g = plotly_graph.Scatter( x=x, y=y, hoveron='points', name=label, visible=True, mode=mode, legendgroup=legend, showlegend=show, line=dict(color=l_color), marker=dict(symbol=marker, size=size, color=l_color), ) return g