Source code for magine.plotting.heatmaps

import warnings
from itertools import chain

import matplotlib.pyplot as plt
import numpy as np
import scipy.cluster.hierarchy as sch
import seaborn as sns


def heatmap_from_array(data, convert_to_log=False, y_tick_labels='auto',
                       cluster_row=False, cluster_col=False,
                       columns='sample_id', index='term_name',
                       values='combined_score', div_colors=False,
                       num_colors=7, figsize=(6, 4), sort_row=None,
                       annotate_sig=False, rank_index=None, linewidths=0.0,
                       cluster_by_set=False, min_sig=0):
    """ Pivot ``data`` into a 2D array and draw it as a (clustered) heatmap.

    Parameters
    ----------
    data : magine.data.base.BaseData
    convert_to_log : bool
        Convert fold_change column to log2 scale
    y_tick_labels : list_like
    columns : str
        Name of columns of df for pivot
    index : str
        Name of index of df for pivot
    values : str
        Name of values of df for pivot
    cluster_col : bool
        Cluster the columns using seaborn.clustermap
    cluster_row : bool
        Cluster the rows using seaborn.clustermap
    div_colors : bool
        Use divergent colors for plotting
    figsize : tuple
        Size of figure, passed to matplotlib/seaborn
    sort_row : str or list_like
        Either an explicit row order (list or np.ndarray of index labels)
        or one of 'index', 'mean', 'max', 'min', 'sum'. 'index' sorts the
        row labels ascending; 'mean', 'max' and 'min' sort by that row
        statistic, largest first; 'sum' sorts by row sum, smallest first.
    num_colors : int
        Number of colors for color bar
    annotate_sig : bool
        Add '+' annotation to plot for significant changed terms
    linewidths : float or None
        Add white line between plots
    cluster_by_set: bool
        Cluster by gene set column. Only works for enrichment_array
    min_sig : int
        Minimum number of significant 'index' across samples. Can be used to
        remove rows that are not significant across any sample.
    rank_index: bool
        Rank rows by index. Deprecated, please use sort_row arg instead.

    Returns
    -------
    seaborn ClusterGrid, plt.Figure or None
        The object returned by ``sns.clustermap`` when any clustering or
        column grouping is requested, otherwise the ``plt.Figure`` holding
        a plain heatmap. ``None`` is returned (after a warning) when the
        pivoted array is empty.

    Raises
    ------
    ValueError
        If ``sort_row`` is neither list-like nor a supported keyword.
    """
    if min_sig:
        d_copy = data.require_n_sig(columns=columns, index=index,
                                    n_sig=min_sig)
    else:
        d_copy = data.copy()

    array = d_copy.pivoter(convert_to_log, columns=columns, index=index,
                           fill_value=0.0, values=values, min_sig=min_sig)
    if not len(array):
        warnings.warn("Empty array after filtering.")
        return

    # default values to be overwritten below
    col_colors = None
    annotations = None
    col_color_map = None
    col_labels = None
    fmt = None
    linkage = None
    add_col_group = False

    if rank_index is not None:
        warnings.warn("rank_index is deprecated; use sort_row='index'",
                      DeprecationWarning)
        if rank_index:
            # assuming provided true, will sort by rank.
            array = array.sort_index(ascending=True)

    if sort_row is not None:
        # One mutually exclusive chain: an explicit ordering (list/ndarray)
        # must never reach the string comparisons below, because comparing
        # an ndarray with a str is element-wise and cannot be used in `if`.
        if isinstance(sort_row, (list, np.ndarray)):
            array = array.reindex(sort_row)
        elif sort_row not in ('index', 'max', 'mean', 'min', 'sum'):
            raise ValueError("Can sort rows by 'index' name or 'max', 'min',"
                             "'mean' of values")
        elif sort_row == 'index':
            array = array.sort_index(ascending=True)
        else:
            if sort_row == 'mean':
                row_stat = array.mean(axis=1)
            elif sort_row == 'max':
                row_stat = array.max(axis=1)
            elif sort_row == 'min':
                # previously used array.max here, which made 'min' an
                # alias of 'max'
                row_stat = array.min(axis=1)
            else:  # 'sum'
                row_stat = array.sum(axis=1)
            # 'sum' is the only statistic sorted smallest-first
            ascending = sort_row == 'sum'
            new_index = row_stat.sort_values(ascending=ascending).index
            array = array.reindex(new_index)

    if cluster_by_set and "genes" in d_copy.columns:
        # clustering will be based on jaccard index of terms
        dist_mat, names = d_copy.calc_dist(level='sample')
        linkage = sch.linkage(dist_mat, method='average')
        # Add row cluster flag in case user didn't set
        cluster_row = True
        array = array.reindex(names)

    # set coloring scheme for heatmap
    if div_colors:
        pal = sns.color_palette("coolwarm", num_colors)
        center = 0
    else:
        pal = sns.light_palette("purple", as_cmap=True)
        center = None

    # Group together by columns if provided
    if isinstance(columns, (list, tuple)) and len(columns) == 2:
        add_col_group = True
        # NOTE: _set_col_colors also flattens array.columns in place
        col_labels, col_colors, col_color_map = _set_col_colors(array)

    # check annotations exist
    if annotate_sig:
        annotate_sig, annotations, fmt = _get_sig_annotations(
            array, d_copy, columns, index, min_sig
        )

    if cluster_row or cluster_col or add_col_group:
        clustermap_kwargs = dict(
            cmap=pal, center=center, yticklabels=y_tick_labels,
            col_colors=col_colors, col_cluster=cluster_col,
            row_cluster=cluster_row, row_linkage=linkage, figsize=figsize,
            linewidths=linewidths, fmt=fmt,
            method='complete', metric='correlation',
        )
        fig = sns.clustermap(array, annot=annotations, **clustermap_kwargs)

        # We need to reorder the annotations if we cluster
        # NOTE(review): recent seaborn releases appear to reorder a
        # user-supplied ``annot`` array themselves; confirm against the
        # seaborn version in use that the permutation is not applied twice.
        if annotate_sig:
            if cluster_row:
                annotations = annotations[fig.dendrogram_row.reordered_ind]
            if cluster_col:
                annotations = annotations[:,
                                          fig.dendrogram_col.reordered_ind]
            # Only need figure for dendrogram ordering, not actual plot.
            # Can probably do this without the plotting interface, but this
            # seems to do the job for now.
            plt.close()
            # make final figure
            fig = sns.clustermap(array, annot=annotations,
                                 **clustermap_kwargs)

        # add labels to column colors
        if add_col_group:
            fig = _add_column_color_groups(d_copy, fig, col_color_map,
                                           col_labels, columns)

        # add clustered columns and rows, basically allows us to extract out
        # the clusters from the figure, if we wanted to do something with
        # them ie run enrichment analysis.
        if cluster_col:
            col_cltrs = sch.fcluster(fig.dendrogram_col.linkage, t=2,
                                     criterion='maxclust')
            # align cluster ids with the plotted (reordered) columns
            col_cltrs = col_cltrs[fig.dendrogram_col.reordered_ind]
            col_clusters = dict()
            for i in sorted(set(col_cltrs)):
                cols = fig.data2d.columns.values[col_cltrs == i]
                col_clusters[i] = fig.data2d[cols]
            fig.col_clusters = col_clusters

        if cluster_row:
            row_cltrs = sch.fcluster(fig.dendrogram_row.linkage, t=2,
                                     criterion='maxclust')
            # align cluster ids with the plotted (reordered) rows
            row_cltrs = row_cltrs[fig.dendrogram_row.reordered_ind]
            row_clusters = dict()
            for i in sorted(set(row_cltrs)):
                row_clusters[i] = fig.data2d.loc[row_cltrs == i].index.values
            fig.row_clusters = row_clusters

        fig.ax_heatmap.set_ylabel('')
        fig.ax_heatmap.set_xlabel('')
    else:
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111)
        sns.heatmap(array, ax=ax, yticklabels=y_tick_labels, cmap=pal,
                    center=center, annot=annotations, fmt=fmt,
                    linewidths=linewidths)
        ax.set_ylabel('')
        ax.set_xlabel('')
    return fig
def heatmap_by_terms(data, term_labels, term_sets, colors=None, min_sig=None,
                     convert_to_log=False, y_tick_labels='auto',
                     columns='sample_id', index='identifier',
                     values='fold_change', linewidths=0, cluster_row=False,
                     cluster_col=False, div_colors=False, num_colors=21,
                     figsize=None, annotate_sig=False, **kwargs):
    """ Draw a clustermap whose rows are grouped and color coded by term.

    Rows are ordered set by set (sorted within each set). A row that
    belongs to several sets is assigned to the first set that contains it.

    Parameters
    ----------
    data : pd.DataFrame
        Must also provide the ``pivoter`` method used to build the array.
    term_labels : list_like
        List of labels for grouping
    term_sets : list_like
        List of list like that create the terms
    colors : list_like
        Colors for plotting, if not provided it will be created
    min_sig : int
        Minimum number of significant 'index' across samples. Can be used to
        remove rows that are not significant across any sample.
    convert_to_log : bool
    y_tick_labels : list_like
    columns : str
        Name of columns of df for pivot
    index : str
        Name of index of df for pivot
    values : str
        Name of values of df for pivot
    cluster_col : bool
        Cluster the columns using seaborn.clustermap
    cluster_row : bool
        Cluster rows
    div_colors : bool
        Use divergent colors for plotting
    figsize : tuple
        Size of figure, passed to matplotlib/seaborn
    num_colors : int
        Number of colors for color bar
    annotate_sig : bool
        Add '+' annotation to plot for significant changed terms
    linewidths : float or None
        Add white line between plots
    **kwargs
        Accepted for call compatibility; not used by this function.

    Returns
    -------
    seaborn ClusterGrid or None
        The object returned by ``sns.clustermap``; ``None`` (after a
        warning) when the pivoted array is empty.

    Raises
    ------
    AssertionError
        If ``term_labels``/``term_sets`` or ``colors``/``term_sets`` differ
        in length.
    """
    if len(term_labels) != len(term_sets):
        raise AssertionError("Number of term_labels must "
                             "equal number of term_sets")
    # default values to be overwritten below
    annotations = None
    fmt = None
    add_col_group = False

    tmp_d = data.copy()
    if index == 'label':
        # term_sets hold identifiers; translate them to the labels that
        # will actually form the pivot index
        id_to_label = dict()
        for identifier, label in tmp_d[['identifier', 'label']].values:
            id_to_label.setdefault(identifier, set()).add(label)
        new_term_sets = [
            set(chain.from_iterable(
                id_to_label[j] for j in term if j in id_to_label
            ))
            for term in term_sets
        ]
    else:
        # coerce to sets so list/tuple inputs support `.intersection` below
        new_term_sets = [set(term) for term in term_sets]

    all_items = set(chain.from_iterable(new_term_sets))
    tmp_d = tmp_d.loc[tmp_d[index].isin(all_items)]

    # pivot datatable
    array = tmp_d.pivoter(convert_to_log, columns=columns, index=index,
                          fill_value=0.0, values=values, min_sig=min_sig)
    if not len(array):
        warnings.warn("Empty array after filtering.")
        return

    if colors is None:
        colors = sns.color_palette("Paired", n_colors=len(term_labels))
    elif len(colors) != len(new_term_sets):
        raise AssertionError("Number of colors must "
                             "equal number of term_labels")

    vals = set(array.index.values)
    final_sorted = []   # row order of the final heatmap
    row_colors = []     # one color per entry of final_sorted
    added = set()
    to_remove = set()   # labels that contributed no rows (hidden in legend)
    for term, color, cname in zip(new_term_sets, colors, term_labels):
        new_rows = [i for i in sorted(term.intersection(vals))
                    if i not in added]
        if not new_rows:
            to_remove.add(cname)
            continue
        final_sorted.extend(new_rows)
        row_colors.extend([color] * len(new_rows))
        added.update(new_rows)

    # keep only rows found in the provided sets, ordered by set/color
    array = array.reindex(final_sorted)

    if isinstance(columns, list) and len(columns) == 2:
        add_col_group = True
        # NOTE: _set_col_colors also flattens array.columns in place
        col_labels, col_colors, col_color_map = _set_col_colors(array)
    else:
        col_labels, col_colors, col_color_map = None, None, None

    # set colors map for heatmap
    if div_colors:
        pal = sns.color_palette("coolwarm", num_colors)
        center = 0
    else:
        pal = sns.light_palette("red", n_colors=len(new_term_sets),
                                as_cmap=True)
        center = None

    if annotate_sig:
        annotate_sig, annotations, fmt = _get_sig_annotations(
            array, data, columns, index, min_sig
        )

    clustermap_kwargs = dict(
        yticklabels=y_tick_labels, figsize=figsize, linewidths=linewidths,
        row_colors=row_colors, col_colors=col_colors,
        col_cluster=cluster_col, row_cluster=cluster_row, cmap=pal,
        center=center, fmt=fmt, method='single', metric='correlation',
    )
    fig = sns.clustermap(array, annot=annotations, **clustermap_kwargs)
    if annotate_sig:
        # first figure is only used to obtain the dendrogram ordering so
        # the annotations can be permuted to match; then redraw.
        if cluster_col:
            annotations = annotations[:, fig.dendrogram_col.reordered_ind]
        if cluster_row:
            annotations = annotations[fig.dendrogram_row.reordered_ind]
        plt.close()
        fig = sns.clustermap(array, annot=annotations, **clustermap_kwargs)

    # zero-sized bars exist only to create legend entries for the row colors
    for color, label in zip(colors, term_labels):
        if label in to_remove:
            continue
        fig.ax_row_dendrogram.bar(0, 0, color=color, label=label,
                                  linewidth=0)
    fig.ax_row_dendrogram.legend(loc=0, ncol=1)

    if add_col_group:
        fig = _add_column_color_groups(tmp_d, fig, col_color_map,
                                       col_labels, columns)
    fig.ax_heatmap.set_ylabel('')
    fig.ax_heatmap.set_xlabel('')
    plt.subplots_adjust(right=0.7)
    return fig
def cluster_distance_mat(dist_mat, names, figsize=(8, 8)):
    """ Creates heatmap from distance matrix.

    The matrix is reordered by average-linkage hierarchical clustering and
    drawn below the corresponding dendrogram.

    Parameters
    ----------
    dist_mat : np.array
        Distance matrix array (square, indexed on both axes).
    names : list_like
        Names of ticks for distance matrix
    figsize : tuple
        Size of figure, passed to matplotlib

    Returns
    -------
    plt.Figure
    """
    # accept plain lists as well as arrays: fancy indexing below needs numpy
    dist_mat = np.asarray(dist_mat)
    names = np.asarray(names)

    fig = plt.figure(figsize=figsize)

    # Compute and plot the dendrogram on top of the matrix.
    # NOTE(review): scipy's linkage treats a 2D input as observation
    # vectors, not as a precomputed distance matrix; confirm that is the
    # intended behaviour here.
    ax_dendro = fig.add_axes([0.3, 0.71, 0.6, 0.2])
    linkage = sch.linkage(dist_mat, method='average')
    dendro = sch.dendrogram(linkage, ax=ax_dendro)
    ax_dendro.set_xticks([])
    ax_dendro.set_yticks([])

    # Plot distance matrix.
    axmatrix = fig.add_axes([0.3, 0.1, 0.6, 0.6])

    # reorder matrix rows, columns and names to the dendrogram leaf order
    leaf_order = dendro['leaves']
    dist_mat = dist_mat[leaf_order, :][:, leaf_order]
    names = names[leaf_order]

    # create figure
    im = axmatrix.matshow(dist_mat, aspect='auto', origin='lower',
                          cmap=plt.cm.Reds, vmin=0, vmax=1)

    tick_positions = range(len(names))

    # add xtick labels
    axmatrix.set_xticks(tick_positions)
    axmatrix.set_xticklabels(names, minor=False)
    axmatrix.xaxis.set_label_position('bottom')
    axmatrix.xaxis.tick_bottom()
    plt.setp(axmatrix.get_xticklabels(), rotation=90, fontsize=8)

    # add ytick labels
    axmatrix.set_yticks(tick_positions)
    axmatrix.set_yticklabels(names, minor=False)
    axmatrix.yaxis.set_label_position('left')
    axmatrix.yaxis.tick_left()
    plt.setp(axmatrix.get_yticklabels(), rotation=0, fontsize=8)

    # add colorbar
    axcolor = fig.add_axes([0.94, 0.1, 0.02, 0.6])
    plt.colorbar(im, cax=axcolor)
    return fig
def _set_col_colors(array):
    """ Build per-column group colors for a two-level column index.

    Side effect: ``array.columns`` is replaced in place by the flat list of
    second-level labels, so callers see single-level columns afterwards.

    Parameters
    ----------
    array : pd.DataFrame
        Pivoted array whose columns expose ``levels`` and ``codes``
        (i.e. a two-level column index).

    Returns
    -------
    col_labels
        First-level labels (one per column group).
    col_colors : list
        One color per column, chosen by the column's first-level code.
    col_color_map
        "Dark2" palette with one color per first-level label.
    """
    col_labels = array.columns.levels[0]
    sample_labels = list(array.columns.levels[1])
    group_codes, sample_codes = array.columns.codes[:2]

    col_color_map = sns.color_palette("Dark2", len(col_labels))
    col_colors = [col_color_map[code] for code in group_codes]
    array.columns = [sample_labels[code] for code in sample_codes]
    return col_labels, col_colors, col_color_map


def _add_column_color_groups(data, fig, colors, color_labels, columns):
    """ Add a legend and separator lines for column groups to a clustermap.

    Parameters
    ----------
    data : pd.DataFrame
        Data used to count the unique ``columns[1]`` values per group.
    fig : seaborn ClusterGrid
        Grid returned by ``sns.clustermap``.
    colors : list_like
        One color per entry of ``color_labels``.
    color_labels : list_like
        Values of ``columns[0]`` naming each column group.
    columns : list_like
        Pair ``(group_column, sample_column)`` used for the pivot.

    Returns
    -------
    seaborn ClusterGrid
        The same ``fig`` object, modified in place.
    """
    # zero-sized bars exist only to create legend entries for the groups
    for color, label in zip(colors, color_labels):
        fig.ax_col_dendrogram.bar(0, 0, color=color, label=label,
                                  linewidth=0)
    plt.setp(fig.ax_col_dendrogram.yaxis.get_majorticklabels(), rotation=0,
             fontsize=16)
    fig.ax_col_dendrogram.legend(loc="center", fontsize=12, ncol=2,
                                 bbox_to_anchor=(0.5, 1., 0.5, 0.5))

    # cumulative number of samples gives the right edge of each group
    group_col, sample_col = columns[0], columns[1]
    v_line_list = []
    boundary = 0
    for label in color_labels:
        boundary += len(data[data[group_col] == label][sample_col].unique())
        v_line_list.append(boundary)

    # Address the axes by name rather than by position in fig.fig.axes:
    # the positional index shifts when the grid also has row colors.
    for ax in (fig.ax_col_colors, fig.ax_heatmap):
        if ax is not None:
            ax.vlines(v_line_list, *ax.get_ylim())
    return fig


def _get_sig_annotations(arr, dat, columns, index, min_sig):
    """ Create the '+' annotation matrix marking significant entries.

    Parameters
    ----------
    arr : pd.DataFrame
        Pivoted array being plotted; its index fixes the row order.
    dat
        Data object providing ``columns`` and ``pivoter``.
    columns, index, min_sig
        Passed through to ``dat.pivoter``.

    Returns
    -------
    tuple
        ``(True, annotations, '')`` where ``annotations`` is a 2D object
        array holding '+' where the pivoted 'significant' value is > 0 and
        '' everywhere else (including rows missing from the pivot), or
        ``(False, None, None)`` when ``dat`` has no 'significant' column.
    """
    if 'significant' not in dat.columns:
        warnings.warn("To annotate please add a significant column to data")
        return False, None, None

    # Have to rank by column for this to work
    sig = dat.pivoter(False, columns=columns, index=index,
                      values='significant', fill_value=0, min_sig=min_sig)
    sig = sig.reindex(arr.index)
    # single boolean mask instead of chained replace(): NaN and non-positive
    # values all become the empty annotation
    is_sig = (sig > 0).values
    annotations = np.where(is_sig, '+', '').astype(object)
    return True, annotations, ''