Source code for schlichtanders.mymatplotlib

""" plotting helpers """
from __future__ import division
from numpy import ma
import matplotlib as mpl
from matplotlib.colors import Normalize
import matplotlib.pyplot as plt
from contextlib import contextmanager
import numpy as np
from collections import Counter

__author__ = 'Stephan Sahm <Stephan.Sahm@gmx.de>'


# interactive plotting
# ====================

[docs]def plt_sync(plt_object, margin_size=0.1): """ updates a plot dynamically given a certain axes object (e.g. a line) """ plt_object.axes.relim() plt_object.axes.autoscale_view() plt_object.axes.margins(margin_size, margin_size) plt_object.figure.tight_layout() plt_object.figure.canvas.draw()
[docs]def add_val(hl, val, iteration_nr=None, update_fig=True): """ adds a value to a given line handle with x=#iteartion in a dynamic way default iteration nr is extracted from ``hl`` """ if iteration_nr is None: iteration_nr = len(hl.get_xdata()) add_point(hl, x=iteration_nr, y=val, update_fig=update_fig)
[docs]def add_point(hl, x, y, update_fig=True): """ adds a value to a given line handle in a dynamic way """ hl.set_xdata(np.append(hl.get_xdata(), x)) hl.set_ydata(np.append(hl.get_ydata(), y)) if update_fig: plt_sync(hl)
# colormaps # =========
[docs]class Centre(Normalize): """ a Normalizer which centers a colorbar, etc., around a midpoint """ def __init__(self, midpoint=0, vmin=None, vmax=None, clip=False): Normalize.__init__(self, vmin, vmax, clip) self.midpoint = midpoint def __call__(self, value, clip=None): if clip is None: clip = self.clip result, is_scalar = self.process_value(value) self.autoscale_None(result) vmin, vmax, midpoint = self.vmin, self.vmax, self.midpoint # if not (vmin < midpoint < vmax): # raise ValueError("midpoint must be between maxvalue and minvalue.") if midpoint < vmin: vmin = midpoint elif vmax < midpoint: vmax = midpoint elif vmin == vmax: result.fill(0) # Or should it be all masked? Or 0.5? elif vmin > vmax: raise ValueError("maxvalue must be bigger than minvalue") else: vmin = float(vmin) vmax = float(vmax) if clip: mask = ma.getmask(result) result = ma.array(np.clip(result.filled(vmax), vmin, vmax), mask=mask) # ma division is very slow; we can take a shortcut resdat = result.data #First scale to -1 to 1 range, than to from 0 to 1. resdat -= midpoint resdat[resdat>0] /= abs(vmax - midpoint) resdat[resdat<0] /= abs(vmin - midpoint) resdat /= 2. resdat += 0.5 result = ma.array(resdat, mask=result.mask, copy=False) if is_scalar: result = result[0] return result
[docs] def inverse(self, value): if not self.scaled(): raise ValueError("Not invertible until scaled") vmin, vmax, midpoint = self.vmin, self.vmax, self.midpoint if mpl.cbook.iterable(value): val = ma.asarray(value) val = 2 * (val-0.5) val[val>0] *= abs(vmax - midpoint) val[val<0] *= abs(vmin - midpoint) val += midpoint return val else: val = 2 * (val - 0.5) if val < 0: return val*abs(vmin-midpoint) + midpoint else: return val*abs(vmax-midpoint) + midpoint
# general helpers # ===============
[docs]def gca(): """ this version of gca won't create a new axis if none exists, however return None instead""" fig = plt.gcf() ckey, cax = fig._axstack.current_key_axes() if cax is None: return None return plt.gca()
@contextmanager
[docs]def useax(ax): """ sets current axes within scope and restores old axes on exits (can be nested) """ if ax is not None: old_ax = gca() plt.sca(ax) yield if old_ax is not None: plt.sca(old_ax) else: yield
[docs]def to_colors(cluster, ignore_cnt=5): # group all small cluster together: cluster_count = Counter(cluster) ignore = set(k for k, v in cluster_count.items() if v <= ignore_cnt) unique_cluster = np.unique(cluster) unique_colors = { k: (0.0, 0.0, 0.0, 0.75) if k in ignore else plt.cm.Set1(l) # (0.5, 0.5, 0.5, 0.75) for k, l in zip(unique_cluster, np.linspace(0, 1, num=len(unique_cluster), endpoint=False)) } return np.array([unique_colors[k] for k in cluster])