Source code for gneiss.plot._heatmap

# ----------------------------------------------------------------------------
# Copyright (c) 2016--, gneiss development team.
#
# Distributed under the terms of the Modified BSD License.
#
# The full license is in the file COPYING.txt, distributed with this software.
# ----------------------------------------------------------------------------
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import pandas as pd
from gneiss.plot._dendrogram import SquareDendrogram
from gneiss.util import match_tips, NUMERATOR, DENOMINATOR


def heatmap(table, tree, mdvar, highlights=None, cmap='viridis',
            linewidth=0.5, grid_col='w', grid_width=2,
            highlight_width=0.02, figsize=(5, 5)):
    """ Creates heatmap plotting object

    The figure is laid out left to right as: dendrogram, optional
    highlight strips (one per highlighted internal node), heatmap.

    Parameters
    ----------
    table : pd.DataFrame
        Contain sample/feature labels along with table of values.
        Rows correspond to samples, and columns correspond to features.
    tree : skbio.TreeNode
        Tree representing the feature hierarchy.
    mdvar : pd.Series
        Metadata values for samples.  The index must correspond to the
        index of `table`.
    highlights : pd.DataFrame or dict of tuple of str
        Internal nodes in the tree to highlight.  Each internal node must
        contain two colors, one for the left subtree and the other for the
        right subtree highlight.  The first color will always correspond
        to the left subtree, and the second color will always correspond
        to the right subtree.  A dict is converted to a DataFrame indexed
        by its keys before plotting.
    cmap : str
        Specifies the matplotlib colormap for the heatmap
        (default='viridis')
    linewidth : float
        Width of dendrogram lines. (default=0.5)
    grid_col : str
        Color of vertical lines for highlighting sample metadata.
        (default='w')
    grid_width : int
        Width of vertical lines for highlighting sample metadata.
        (default=2)
    highlight_width : float
        Width of each highlight strip, as a fraction of the figure width.
        (default=0.02)
    figsize : tuple of int
        Specifies (width, height) for figure. (default=(5, 5))

    Returns
    -------
    matplotlib.pyplot.figure
        Matplotlib figure object

    Note
    ----
    The highlights parameter assumes that the tree is bifurcating.
    This function sets ``plt.rcParams['axes.facecolor']`` to ``'white'``,
    which persists after the call.
    """
    dendrogram_width = 20

    # The highlight plotting code needs `.index` / `.iloc`, so a plain dict
    # (documented as accepted) is normalized to a DataFrame up front.
    if isinstance(highlights, dict):
        highlights = pd.DataFrame.from_dict(highlights, orient='index')

    # match the tips; after transposing, rows are features and columns
    # are samples
    table, tree = match_tips(table, tree)
    table = table.T

    # get edges from tree
    t = SquareDendrogram.from_tree(tree)
    t = _tree_coordinates(t)
    pts = t.coords(width=dendrogram_width, height=table.shape[0])
    edge_list = _dendrogram_edges(pts)

    # now plot the stuff
    fig = plt.figure(figsize=figsize, facecolor='white')
    plt.rcParams['axes.facecolor'] = 'white'

    # All rectangles below are [left, bottom, width, height] in figure
    # fractions; `top_buffer` is passed as the bottom coordinate.
    xwidth = 0.2
    top_buffer = 0.1
    height = 0.8

    # one strip of width `highlight_width` per highlighted node
    n_highlights = len(highlights) if highlights is not None else 0
    strips_width = highlight_width * n_highlights

    # dendrogram axes on the left side
    dendrogram_rect = [0, top_buffer, xwidth, height]
    # highlight strips sit between the dendrogram and the heatmap
    highlights_rect = [xwidth, top_buffer, strips_width, height]
    # heatmap axes take up the remaining width on the right side
    hstart = xwidth + (n_highlights * highlight_width)
    heatmap_rect = [hstart, top_buffer, 1 - hstart, height]

    # plot heatmap
    ax_heatmap = fig.add_axes(heatmap_rect, frame_on=True)
    _plot_heatmap(ax_heatmap, table, mdvar, grid_col, grid_width, cmap)

    # plot dendrogram
    ax_dendrogram = fig.add_axes(dendrogram_rect, frame_on=True,
                                 sharey=ax_heatmap)
    _plot_dendrogram(ax_dendrogram, table, edge_list, linewidth=linewidth)

    # plot highlights for dendrogram
    if highlights is not None:
        ax_highlights = fig.add_axes(highlights_rect, frame_on=True,
                                     sharey=ax_heatmap)
        _plot_highlights_dendrogram(ax_highlights, table, t, highlights)
    return fig


def _dendrogram_edges(pts):
    """ Builds the line segments of a square dendrogram.

    Parameters
    ----------
    pts : pd.DataFrame
        Node coordinates indexed by node name, with columns `x`, `y`,
        `child0` and `child1`.  Rows without both children are skipped.

    Returns
    -------
    pd.DataFrame
        One row per segment with columns `x0`, `y0`, `x1`, `y1`.  Every
        parent/child link contributes two segments: one from the parent
        along y to the child's y, then one along x to the child.
    """
    links = pts[['child0', 'child1']]
    links = links.dropna(subset=['child0', 'child1'])
    # unstack() yields a Series keyed by (column, parent name) whose values
    # are the child names.
    links = links.unstack()
    parents = links.index.get_level_values(1)

    segments = []
    for src, dest in zip(parents, links.values):
        src_pt, dest_pt = pts.loc[src], pts.loc[dest]
        sx, sy = src_pt['x'], src_pt['y']
        dx, dy = dest_pt['x'], dest_pt['y']
        segments.append({'x0': sx, 'y0': sy, 'x1': sx, 'y1': dy})
        segments.append({'x0': sx, 'y0': dy, 'x1': dx, 'y1': dy})
    return pd.DataFrame(segments)
# TODO: Refactor and place in utils.  This can also
# be used for the balance_basis calculations
def _tree_coordinates(t):
    """ Builds a matrix to link tree positions to matrix

    Annotates every node of `t` in place:

    - ``_n_tips``: number of tips below the node (1 for a tip).
    - ``_l`` / ``_r``: tip counts of the ``NUMERATOR`` and ``DENOMINATOR``
      children (internal nodes only).
    - ``_k`` / ``_t``: offsets accumulated from the root, which starts at
      0 for both.  A ``NUMERATOR`` child is shifted by its parent's ``_r``
      in ``_k``; any other child is shifted by its parent's ``_l`` in
      ``_t``.

    Parameters
    ----------
    t : tree node
        Root of the tree.  Must provide `postorder`, `levelorder`,
        `is_tip`, `is_root`, `parent` and `children`.

    Returns
    -------
    The same object `t`, with the attributes above set on its nodes.
    """
    # first traverse the tree to count the children
    for node in t.postorder(include_self=True):
        if node.is_tip():
            node._n_tips = 1
        else:
            node._n_tips = sum(child._n_tips for child in node.children)

    # Parents are visited before their children in level order, so the
    # parent's _k, _t, _l and _r are already set when a child reads them.
    for node in t.levelorder(include_self=True):
        if node.is_root():
            node._k = 0
            node._t = 0
        else:
            parent = node.parent
            if node is parent.children[NUMERATOR]:
                node._k = parent._k + parent._r
                node._t = parent._t
            else:
                node._k = parent._k
                node._t = parent._t + parent._l

        if node.is_tip():
            continue

        node._l = node.children[NUMERATOR]._n_tips
        node._r = node.children[DENOMINATOR]._n_tips
    return t


def _plot_highlights_dendrogram(ax_highlights, table, t, highlights):
    """ Plots highlights for subtrees in the dendrograms.

    Note that this assumes that the dendrograms are strictly bifurcating
    and the highlights only specify the children for a given subtree.

    Parameters
    ----------
    ax_highlights : matplotlib axes object
        Axes in which the highlight strips are drawn.
    table : pd.DataFrame
        Only the number of rows is used, to set the y limits.
    t : tree node
        Tree whose nodes carry the `_k`, `_l` and `_r` attributes set by
        `_tree_coordinates`; nodes are looked up with `t.find`.
    highlights : pd.DataFrame
        Indexed by internal node name, with two color columns.  Each node
        gets one vertical strip made of two stacked rectangles.
    """
    offset = 0.5
    num_h = len(highlights)
    hcoords = []
    for i, name in enumerate(highlights.index):
        node = t.find(name)
        k, n_left, n_right = node._k, node._l, node._r
        strip_x = i / num_h
        strip_w = 1 / num_h

        # lower rectangle: spans `_r` rows starting at `_k`, first color
        ax_highlights.add_patch(
            patches.Rectangle(
                (strip_x, k - offset),  # x, y
                strip_w,                # width
                n_right,                # height
                facecolor=highlights.iloc[i, 0]
            ))
        # upper rectangle: spans the following `_l` rows, second color
        ax_highlights.add_patch(
            patches.Rectangle(
                (strip_x, k + n_right - offset),  # x, y
                strip_w,                          # width
                n_left,                           # height
                facecolor=highlights.iloc[i, 1]
            ))
        # label goes at the horizontal center of the strip
        hcoords.append((i + offset) / num_h)

    ax_highlights.set_ylim([-offset, table.shape[0] - offset])
    ax_highlights.set_yticks([])
    ax_highlights.set_xticks(hcoords)
    ax_highlights.set_xticklabels(highlights.index, rotation=90)


def _plot_dendrogram(ax_dendrogram, table, edges, linewidth=1):
    """ Plots the actual dendrogram.

    Parameters
    ----------
    ax_dendrogram : matplotlib axes object
        Contains the matplotlib axes in which the dendrogram will
        be plotted.
    table : pd.DataFrame
        Only the number of rows is used, to set the y limits.
    edges : pd.DataFrame
        One row per line segment, with columns `x0`, `y0`, `x1`, `y1`.
    linewidth : float
        Width of the dendrogram lines. (default=1)
    """
    offset = 0.5
    for seg in edges.itertuples(index=False):
        # y values are shifted down by `offset`, matching the y limits below
        ax_dendrogram.plot([seg.x0, seg.x1],
                           [seg.y0 - offset, seg.y1 - offset],
                           '-k', lw=linewidth)
    ax_dendrogram.set_ylim([-offset, table.shape[0] - offset])
    ax_dendrogram.set_yticks([])
    ax_dendrogram.set_xticks([])
    ax_dendrogram.axis('off')


def _sort_table(table, mdvar):
    """ Sorts metadata category and aligns with table.

    Parameters
    ----------
    table : pd.DataFrame
        Table of values whose columns are labelled like the index of
        `mdvar`.
    mdvar : pd.Series
        Metadata values for samples.  The index must correspond to the
        columns of `table`.

    Returns
    -------
    pd.DataFrame
        Table with its columns reordered to follow the sorted metadata.
    pd.Series
        Metadata sorted by value.
    """
    mdvar = mdvar.sort_values()
    table = table.reindex(columns=mdvar.index)
    return table, mdvar


def _plot_heatmap(ax_heatmap, table, mdvar, grid_col, grid_width, cmap):
    """ Plots the heatmap, with columns grouped by metadata value.

    Columns are ordered by sorted metadata value, a grid line is drawn at
    the end of every group of equal values, and each group is labelled at
    its midpoint.

    Parameters
    ----------
    ax_heatmap : matplotlib axes object
        Contains the matplotlib axes in which the heatmap will
        be plotted.
    table : pd.DataFrame
        Table of values whose columns are labelled like the index of
        `mdvar`.
    mdvar : pd.Series
        Metadata values for samples.  The index must correspond to the
        columns of `table`.  Its name is used as the x axis label.
    grid_col : str
        Color of vertical lines for highlighting sample metadata.
    grid_width : int
        Width of vertical lines for highlighting sample metadata.
    cmap : str
        Matplotlib colormap passed to `imshow`.
    """
    # TODO add explicit test for this, since matplotlib orientation
    # is from top to down (i.e. is backwards)
    table, mdvar = _sort_table(table, mdvar)
    table = table.iloc[::-1, :]
    ax_heatmap.imshow(table, aspect='auto', interpolation='nearest',
                      cmap=cmap)
    ax_heatmap.set_ylim([0, table.shape[0]])

    # group sizes, ordered by metadata value so they line up with the
    # sorted columns
    vcounts = mdvar.value_counts().sort_index()
    # cumulative column count at the end of each group
    ticks = vcounts.cumsum()
    # cumulative column count at the start of each group
    starts = np.array([0] + list(ticks.values[:-1]))
    midpoints = ticks - (ticks - starts) / 2.0

    # major ticks: unlabeled grid lines on the group boundaries
    # (pixel edges sit at half-integers, hence the -0.5)
    ax_heatmap.set_xticks(ticks.values - 0.5, minor=False)
    ax_heatmap.set_xticklabels([], minor=False)
    ax_heatmap.xaxis.grid(True, which='major', color=grid_col,
                          linestyle='-', linewidth=grid_width)

    # minor ticks: group labels centered on each group
    ax_heatmap.set_xticks(midpoints - 0.5, minor=True)
    ax_heatmap.set_xticklabels(vcounts.index, minor=True)
    ax_heatmap.set_xlabel(mdvar.name)