import numpy as np
import pandas as pd
from magine.plotting.heatmaps import heatmap_from_array
# Name of the boolean column marking a row as statistically significant.
flag = 'significant'


class BaseData(pd.DataFrame):
    """``pandas.DataFrame`` subclass with pivoting and filtering helpers.

    The helpers operate on long-format data: one row per measurement, a
    column naming the sample (``columns``), a column naming the measured
    entity (``index``) and, for the filters, a boolean ``'significant'``
    column.
    """

    # Column used as the pivot index whenever a method receives index=None.
    _index = None

    def __init__(self, *args, **kwargs):
        super(BaseData, self).__init__(*args, **kwargs)

    @property
    def _constructor(self):
        # Makes pandas operations (copy, slicing, ...) return BaseData.
        return BaseData

    @property
    def sig(self):
        """ Copy of the rows whose ``'significant'`` flag is True. """
        return self.loc[self[flag]].copy()

    @staticmethod
    def _n_unique(data, index):
        """ Count distinct `index` values (value combinations for a list). """
        if isinstance(index, list):
            return len(data[index].drop_duplicates())
        return len(data[index].unique())

    @staticmethod
    def _filter_by_pivot_index(data, index, kept_index, verbose):
        """ Keep the rows of `data` whose index label is in `kept_index`.

        For a list `index` only its first column is matched, using the first
        level of `kept_index`. If `index` is neither a str nor a list a
        message is printed and `data` is returned unfiltered.
        """
        if isinstance(index, list):
            key = index[0]
        elif isinstance(index, str):
            key = index
        else:
            print("Index is not a str or a list. What is it?")
            return data
        # get_level_values(0) handles both a MultiIndex and the plain Index
        # that pivot_table returns for a single-element list.
        keepers = set(kept_index.get_level_values(0))
        filtered = data.loc[data[key].isin(keepers)]
        if verbose:
            n_before = BaseData._n_unique(data, index)
            n_after = BaseData._n_unique(filtered, index)
            print("Number in index went from {} to {}"
                  "".format(n_before, n_after))
        return filtered

    def pivoter(self, convert_to_log=False, columns='sample_id',
                values='fold_change', index=None, fill_value=None, min_sig=0):
        """ Pivot data on provided axis.

        Parameters
        ----------
        convert_to_log : bool
            Convert values column to log2 (see ``log2_normalize_df``)
        index : str or list
            Index for pivot table, defaults to ``self._index``
        columns : str or list
            Columns to pivot
        values : str or list
            Values of pivot table
        fill_value : float, optional
            Value used to fill missing cells of the pivot table
        min_sig : int
            Required number of significant terms to keep in a row, default 0

        Returns
        -------
        pandas.DataFrame
            Pivoted table. Unless `values` is a list, rows are sorted in
            descending order by the sorted column labels. An empty
            ``pandas.DataFrame`` is returned when no rows remain.

        Raises
        ------
        AssertionError
            If `min_sig` is truthy but not an integer, or if filtering is
            requested and the ``'significant'`` column is missing.
        """
        data = self
        if index is None:
            index = self._index

        if convert_to_log:
            data = data.log2_normalize_df(values)

        if min_sig:
            # np.integer is accepted so counts computed with numpy work too.
            if not isinstance(min_sig, (int, np.integer)):
                raise AssertionError(
                    'min_sig must be an integer, got {!r}'.format(min_sig)
                )
            if flag not in data.columns:
                print('In order to filter based on minimum sig figs, '
                      'please add a "significant" column')
            data = data.require_n_sig(index=index, columns=columns,
                                      n_sig=min_sig)

        if not data.shape[0]:
            return pd.DataFrame()

        array = pd.pivot_table(data, index=index, fill_value=fill_value,
                               columns=columns, values=values)
        if isinstance(values, list):
            return array

        if isinstance(columns, list):
            # De-duplicate: one key per distinct column combination rather
            # than one per input row.
            sort_keys = sorted(set(map(tuple, data[columns].values)))
        elif isinstance(columns, str):
            sort_keys = sorted(data[columns].unique())
        else:
            return array

        # pivot_table drops columns that are entirely NaN; sorting by a
        # label that is no longer present would raise KeyError.
        sort_keys = [key for key in sort_keys if key in array.columns]
        if sort_keys:
            array.sort_values(by=sort_keys, ascending=False, inplace=True)
        return array

    def require_n_sig(self, columns='sample_id', index=None,
                      n_sig=3, inplace=False,
                      verbose=False):
        """ Filter index to have at least `n_sig` significant columns.

        Parameters
        ----------
        columns : str or list
            Columns to consider
        index : str, list
            The column with which to filter by counts, defaults to
            ``self._index``. For a list, rows are matched on its first
            column only.
        n_sig : int
            Number of columns in which the index must be significant
        inplace : bool
            Filter in place or return a copy of the filtered data
        verbose : bool
            Print the number of unique index values before and after

        Returns
        -------
        new_data : BaseData
            Filtered data, or None when `inplace` is True.

        Raises
        ------
        AssertionError
            If the ``'significant'`` column is missing.
        """
        if index is None:
            index = self._index
        if flag not in self.columns:
            raise AssertionError('Requires significant column')

        # create safe copy of array
        new_data = self.copy()

        sig = pd.pivot_table(new_data, index=index, fill_value=0,
                             values=flag, columns=columns)
        # A cell counts once if any row for that (index, column) is flagged.
        n_hits = (sig > 0).sum(axis=1)
        sig = sig[n_hits >= n_sig]

        new_data = self._filter_by_pivot_index(new_data, index, sig.index,
                                               verbose)
        if inplace:
            self._update_inplace(new_data)
        else:
            return new_data

    def present_in_all_columns(self, columns='sample_id',
                               index=None, inplace=False, verbose=True):
        """ Require index to be present in all columns

        Parameters
        ----------
        columns : str or list
            Columns to consider
        index : str, list
            The column with which to filter by counts, defaults to
            ``self._index``. For a list, rows are matched on its first
            column only.
        inplace : bool
            Filter in place or return a copy of the filtered data
        verbose : bool
            Print the number of unique index values before and after
            (default True)

        Returns
        -------
        new_data : BaseData
            Filtered data, or None when `inplace` is True.

        Raises
        ------
        AssertionError
            If the ``'significant'`` column is missing.
        """
        if index is None:
            index = self._index
        if flag not in self.columns:
            raise AssertionError("Missing {} column in data".format(flag))

        # create safe copy of array
        new_data = self.copy()

        # A NaN cell means there is no flag value for that (index, column).
        pivoted_df = pd.pivot_table(new_data, index=index, values=flag,
                                    columns=columns)
        complete = pivoted_df.dropna(how='any')

        new_data = self._filter_by_pivot_index(new_data, index,
                                               complete.index, verbose)
        if inplace:
            self._update_inplace(new_data)
        else:
            return new_data

    def log2_normalize_df(self, column='fold_change', inplace=False):
        """ Convert "fold_change" column to log2.

        Does so by taking log2 of all positive values and -log2 of the
        absolute value of all negative values. Zeros and NaN are left
        unchanged.

        Parameters
        ----------
        column : str
            Column to convert
        inplace : bool
            Whether to apply log2 in place or return a new dataframe

        Returns
        -------
        new_data : BaseData
            Converted data, or None when `inplace` is True.
        """
        new_data = self.copy()
        series = new_data[column]
        # log2 yields floats; cast first so integer columns are not written
        # with incompatible values.
        if not pd.api.types.is_float_dtype(series):
            series = series.astype(float)

        positive = series > 0
        negative = series < 0
        converted = series.copy()
        converted[positive] = np.log2(series[positive])
        converted[negative] = -np.log2(-series[negative])
        new_data[column] = converted

        if inplace:
            self._update_inplace(new_data)
        else:
            return new_data

    def heatmap(self, subset=None, subset_index=None,
                convert_to_log=True, y_tick_labels='auto',
                cluster_row=False, cluster_col=False, cluster_by_set=False,
                index=None, values=None, columns=None,
                annotate_sig=True, figsize=(8, 12), div_colors=True,
                linewidths=0, num_colors=21, sort_row=None, min_sig=0,
                rank_index=None):
        """ Creates heatmap of data, providing pivot and formatting.

        Optionally filters the rows by `subset`, then forwards the data and
        all remaining options unchanged to ``heatmap_from_array``.

        Parameters
        ----------
        subset : list or str
            Will filter to only contain a provided list.
            If a str, will filter based on .str.contains(subset)
        subset_index : str
            Column the subset is matched against, defaults to `index`
        convert_to_log : bool
            Convert values to log2 scale
        y_tick_labels : str
            Column of values, default = 'auto'
        cluster_row : bool
        cluster_col : bool
        cluster_by_set : bool
            Clusters by gene set, only used in EnrichmentResult derived class
        index : str
            Index of heatmap, will be 'row' variables. Defaults to
            ``self._identifier``.
        values : str
            Values to display in heatmap. Defaults to ``self._value_name``.
        columns : str
            Value that will be used as columns. Defaults to
            ``self._sample_id_name``.
        annotate_sig : bool
            Annotate cells based on the 'significant' column
            (handled by ``heatmap_from_array``)
        figsize : tuple
            Figure size to pass to matplotlib
        div_colors : bool
            Use colors that are divergent
            (red to blue, instead of shades of blue)
        num_colors : int
            How many colors to include on color bar
        linewidths : float
            line width between individual cols and rows
        sort_row : str
            Rank by 'mean', 'max', 'min' or 'index'
        min_sig : int
            Minimum number of significant 'index' across samples. Can be used
            to remove rows that are not significant across any sample.
        rank_index : bool
            Deprecated, please use sort_row='index' to sort alphabetically

        Returns
        -------
        matplotlib.figure
            Whatever ``heatmap_from_array`` returns, or None when `subset`
            matches no rows.

        Raises
        ------
        DeprecationWarning
            Raised (not merely issued) if `rank_index` is not None.
        """
        if rank_index is not None:
            raise DeprecationWarning("Please use sort_row='index'")
        # NOTE(review): _identifier, _value_name and _sample_id_name are not
        # defined on BaseData in this file; presumably subclasses provide
        # them -- confirm before calling heatmap on a bare BaseData.
        if index is None:
            index = self._identifier
        if values is None:
            values = self._value_name
        if columns is None:
            columns = self._sample_id_name

        df = self.copy()
        if subset is not None:
            if subset_index is None:
                subset_index = index
            if isinstance(subset, str):
                df = df.loc[df[subset_index].str.contains(subset)]
            else:
                df = df.loc[df[subset_index].isin(subset)]
        if not df.shape[0]:
            print("No terms match subset")
            return

        return heatmap_from_array(
            df, convert_to_log=convert_to_log,
            y_tick_labels=y_tick_labels,
            cluster_row=cluster_row, cluster_col=cluster_col,
            cluster_by_set=cluster_by_set, figsize=figsize,
            columns=columns, index=index, values=values,
            div_colors=div_colors, num_colors=num_colors,
            sort_row=sort_row, annotate_sig=annotate_sig,
            linewidths=linewidths, min_sig=min_sig, rank_index=rank_index
        )