# Source code for vivarium.plots.colonies

'''
=====================
Colony Plotting Tools
=====================

This module contains tools to help you plot colony data.
'''

from __future__ import division, absolute_import, print_function

from matplotlib import pyplot as plt
import numpy as np
from scipy import stats

from vivarium.library.units import remove_units
from vivarium.processes.derive_colony_shape import Variables


#: Figure width contributed by each column of subplots; multiplied by
#: the number of columns to form the ``figsize`` width in
#: :py:func:`plot_colony_metrics`.
INCH_PER_COL = 4
#: Figure height contributed by each row of subplots; multiplied by
#: the number of rows to form the ``figsize`` height in
#: :py:func:`plot_colony_metrics`.
INCH_PER_ROW = 2
#: Passed as ``wspace`` to the ``GridSpec`` that lays out the subplots.
SUBPLOT_W_SPACE = 0.4
#: Passed as ``hspace`` to the ``GridSpec`` that lays out the subplots.
SUBPLOT_H_SPACE = 1.5


#: Key for circumference in path timeseries (a one-element tuple path)
CIRCUMFERENCE_PATH = (Variables.CIRCUMFERENCE,)
#: Key for surface area in path timeseries (a one-element tuple path)
AREA_PATH = (Variables.AREA,)
#: Key for major axis in path timeseries (a one-element tuple path)
MAJOR_AXIS_PATH = (Variables.MAJOR_AXIS,)
#: Key for minor axis in path timeseries (a one-element tuple path)
MINOR_AXIS_PATH = (Variables.MINOR_AXIS,)
#: Key to which the circumference-to-area ratio will be written in the
#: path timeseries. Unlike the keys above, this is a plain string
#: rather than a tuple.
CIRCUMFERENCE_AREA_RATIO_PATH = 'circumference / surface_area'
#: Key to which the number of colonies is written in the path
#: timeseries. This is a plain string rather than a tuple.
NUM_COLONIES_PATH = 'Number of Colonies'
#: Key to which the ratio of major axis to minor axis is written in the
#: path timeseries. This is a plain string rather than a tuple.
AXIS_RATIO_PATH = '(Major Axis) / (Minor Axis)'


# METRIC DERIVERS

def _derive_circumference_area_ratio(path_ts):
    '''Add the per-colony circumference-to-area ratio to ``path_ts``.

    The ratio is stored in place under
    ``CIRCUMFERENCE_AREA_RATIO_PATH`` as one list per timepoint, with
    one entry per colony. Nothing is written unless both the
    circumference and the area series are present.

    Arguments:
        path_ts (dict): Path timeseries to read from and write to.
    '''
    if CIRCUMFERENCE_PATH not in path_ts or AREA_PATH not in path_ts:
        return
    ratios_over_time = []
    timepoints = zip(path_ts[CIRCUMFERENCE_PATH], path_ts[AREA_PATH])
    for circumferences, areas in timepoints:
        # Pair up the values of each colony within a single timepoint.
        ratios_over_time.append([
            circumference / area
            for circumference, area in zip(circumferences, areas)
        ])
    path_ts[CIRCUMFERENCE_AREA_RATIO_PATH] = ratios_over_time

def _derive_axis_ratio(path_ts):
    '''Add the per-colony major-to-minor axis ratio to ``path_ts``.

    The ratio is stored in place under ``AXIS_RATIO_PATH`` as one list
    per timepoint, with one entry per colony. Nothing is written
    unless both the major-axis and the minor-axis series are present.

    Arguments:
        path_ts (dict): Path timeseries to read from and write to.
    '''
    if MAJOR_AXIS_PATH not in path_ts or MINOR_AXIS_PATH not in path_ts:
        return
    ratios_over_time = []
    timepoints = zip(path_ts[MAJOR_AXIS_PATH], path_ts[MINOR_AXIS_PATH])
    for major_axes, minor_axes in timepoints:
        # Pair up the values of each colony within a single timepoint.
        ratios_over_time.append([
            major_axis / minor_axis
            for major_axis, minor_axis in zip(major_axes, minor_axes)
        ])
    path_ts[AXIS_RATIO_PATH] = ratios_over_time


#: List of metric derivers that will be applied to the path timeseries.
#: :py:func:`plot_colony_metrics` calls each one, in list order, with
#: the path timeseries as the only argument; each deriver adds its
#: derived metric to that dictionary in place.
_METRIC_DERIVERS = [
    _derive_circumference_area_ratio,
    _derive_axis_ratio,
]


def plot_colony_metrics(
    path_ts, title_size=16, tick_label_size=12, max_cols=5
):
    '''Plot colony metrics over time.

    Metric mean is plotted with SEM error bands.

    Arguments:
        path_ts (dict): Path timeseries of the data to plot. Each item
            in the dictionary should have as its key the path and as
            its value a list of values for each timepoint. Each value
            should be a list of metric values, one entry per colony.
            The dictionary should have one additional key, ``time``,
            whose value is a list of times for each timepoint.
        title_size (int): Font size for the title of each plot
        tick_label_size (int): Font size for each plot's axis tick
            labels.
        max_cols (int): The maximum number of columns. We add columns
            until we hit this limit, and only then do we add rows.

    Returns:
        matplotlib.figure.Figure: The plot as a Figure object.
    '''
    path_ts = remove_units(path_ts)
    for deriver in _METRIC_DERIVERS:
        deriver(path_ts)
    times = path_ts['time']
    del path_ts['time']
    # path_ts has a mix of tuple and string keys. Turn them all into
    # strings so that they sort together and can be used directly as
    # subplot titles.
    path_ts = {
        str(key): val for key, val in path_ts.items()
    }
    # Every metric has one entry per colony at each timepoint, so any
    # metric can be used to count the colonies.
    arbitrary_metric = list(path_ts.keys())[0]
    path_ts[NUM_COLONIES_PATH] = [
        len(timepoint) for timepoint in path_ts[arbitrary_metric]
    ]

    # Create Figure
    paths = sorted(path_ts.keys())
    n_cols = min(len(paths), max_cols)
    n_rows = int(np.ceil(len(paths) / n_cols))
    fig = plt.figure(
        figsize=(INCH_PER_COL * n_cols, INCH_PER_ROW * n_rows))
    grid = plt.GridSpec(
        ncols=n_cols, nrows=n_rows, wspace=SUBPLOT_W_SPACE,
        hspace=SUBPLOT_H_SPACE
    )

    # Create the subplots, filling the grid row by row. Cells past the
    # last path are simply left empty.
    for index, path in enumerate(paths):
        row, col = divmod(index, n_cols)
        ax = fig.add_subplot(grid[row, col])

        # Configure axes and titles
        ax.tick_params(
            axis='both', which='both', labelsize=tick_label_size)
        ax.title.set_text(path)
        ax.title.set_fontsize(title_size)
        ax.set_xlim([times[0], times[-1]])
        ax.xaxis.get_offset_text().set_fontsize(tick_label_size)
        ax.yaxis.get_offset_text().set_fontsize(tick_label_size)
        ax.set_xlabel('time (s)', fontsize=title_size)

        # Plot data
        data = path_ts[path]
        if path == NUM_COLONIES_PATH:
            # The colony count is a single value per timepoint, so
            # there is nothing to average.
            ax.plot(times, data)
            continue

        means = []
        sems = []
        plot_times = []
        for i_time, metrics_list in enumerate(data):
            # Skip timepoints that have no colonies.
            if not metrics_list:
                continue
            array = np.array(metrics_list)
            means.append(np.mean(array))
            # With only one colony there is no spread to report.
            sems.append(stats.sem(array) if len(array) > 1 else 0)
            plot_times.append(times[i_time])
        x = np.array(plot_times)
        y = np.array(means)
        yerr = np.array(sems)
        # Draw a zero-width band wherever the SEM is undefined.
        yerr[np.isnan(yerr)] = 0
        ax.plot(x, y)
        ax.fill_between(x, y - yerr, y + yerr, alpha=0.2)
    return fig


def plot_metric_across_experiments(
    path_ts_dict, path, title=None, xlabel='time (s)', ylabel=None,
    title_size=16, tick_label_size=12,
):
    '''Overlay plots of a single metric from different experiments.

    For every experiment, the mean of the metric across colonies is
    plotted at each timepoint that has at least one colony. The
    number-of-colonies metric is plotted as-is.

    Parameters:
        path_ts_dict (dict): Map from the string to use as the label
            for the experiment in the legend to that experiment's path
            timeseries.
        path (tuple): Path to plot. Should be a key in each value of
            ``path_ts_dict``.
        title (str): Plot title. If None, no title is set.
        xlabel (str): X-axis label. If None, no label is set.
        ylabel (str): Y-axis label. If None, no label is set.
        title_size (float): Font size for plot and axis titles.
        tick_label_size (float): Font size for tick labels.

    Returns:
        The figure with the plot.
    '''
    fig, ax = plt.subplots()

    # Set labels and font sizes
    if title is not None:
        ax.set_title(title, fontsize=title_size)
    if xlabel is not None:
        ax.set_xlabel(xlabel, fontsize=title_size)
    if ylabel is not None:
        ax.set_ylabel(ylabel, fontsize=title_size)
    # Apply the size to the tick labels themselves, not only to the
    # offset text, so that ``tick_label_size`` does what it documents.
    ax.tick_params(
        axis='both', which='both', labelsize=tick_label_size)
    ax.xaxis.get_offset_text().set_fontsize(tick_label_size)
    ax.yaxis.get_offset_text().set_fontsize(tick_label_size)

    # Plot data
    for label, path_ts in path_ts_dict.items():
        data = path_ts[path]
        times = path_ts['time']
        if path == NUM_COLONIES_PATH:
            # The colony count is a single value per timepoint, so
            # there is nothing to average.
            ax.plot(times, data, label=label)
            continue

        plot_times = []
        means = []
        for i, metrics_list in enumerate(data):
            # Skip timepoints that have no colonies.
            if not metrics_list:
                continue
            means.append(np.mean(metrics_list))
            plot_times.append(times[i])
        ax.plot(plot_times, means, label=label)
    ax.legend()
    fig.tight_layout()
    return fig