Source code for binoculars.plot

from typing import List, Optional, Union

import numpy
import matplotlib.colors
import matplotlib.cm
import mpl_toolkits.mplot3d

from matplotlib.artist import Artist
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.lines import Line2D

from binoculars.space import Space

# Adapted from http://www.ster.kuleuven.be/~pieterd/python/html/plotting/interactive_colorbar.html
# which in turn is based on an example from http://matplotlib.org/users/event_handling.html


[docs]class DraggableColorbar(object): def __init__(self, cbar, mappable): self.cbar = cbar self.mappable = mappable self.press = None self.cycle = sorted( [i for i in dir(matplotlib.cm) if hasattr(getattr(matplotlib.cm, i), "N")] ) try: # matploltib 2.x cmap_name = cbar.get_cmap().name except AttributeError: # matplotlib 3.x cmap_name = mappable.get_cmap().name self.index = self.cycle.index(cmap_name) self.canvas = self.cbar.patch.figure.canvas
[docs] def connect(self): self.cidpress = self.canvas.mpl_connect("button_press_event", self.on_press) self.cidrelease = self.canvas.mpl_connect( "button_release_event", self.on_release ) self.cidmotion = self.canvas.mpl_connect("motion_notify_event", self.on_motion) self.cidkeypress = self.canvas.mpl_connect("key_press_event", self.key_press)
[docs] def disconnect(self): self.canvas.mpl_disconnect(self.cidpress) self.canvas.mpl_disconnect(self.cidrelease) self.canvas.mpl_disconnect(self.cidmotion) self.canvas.mpl_disconnect(self.cidkeypress)
[docs] def on_press(self, event): if event.inaxes == self.cbar.ax: self.press = event.x, event.y
[docs] def key_press(self, event): if event.key == "down": self.index += 1 elif event.key == "up": self.index -= 1 if self.index < 0: self.index = len(self.cycle) elif self.index >= len(self.cycle): self.index = 0 cmap = self.cycle[self.index] self.mappable.set_cmap(cmap) self.cbar.patch.figure.canvas.draw()
[docs] def on_motion(self, event): if self.press is None or event.inaxes != self.cbar.ax: return xprev, yprev = self.press # dx = event.x - xprev # unused for now dy = event.y - yprev self.press = event.x, event.y if isinstance(self.cbar.norm, matplotlib.colors.LogNorm): scale = 0.999 * numpy.log10(self.cbar.norm.vmax / self.cbar.norm.vmin) if event.button == 1: self.cbar.norm.vmin *= scale ** numpy.sign(dy) self.cbar.norm.vmax *= scale ** numpy.sign(dy) elif event.button == 3: self.cbar.norm.vmin *= scale ** numpy.sign(dy) self.cbar.norm.vmax /= scale ** numpy.sign(dy) else: scale = 0.03 * (self.cbar.norm.vmax - self.cbar.norm.vmin) if event.button == 1: self.cbar.norm.vmin -= scale * numpy.sign(dy) self.cbar.norm.vmax -= scale * numpy.sign(dy) elif event.button == 3: self.cbar.norm.vmin -= scale * numpy.sign(dy) self.cbar.norm.vmax += scale * numpy.sign(dy) self.mappable.set_norm(self.cbar.norm) self.canvas.draw()
[docs] def on_release(self, event): # force redraw on mouse release self.press = None self.mappable.set_norm(self.cbar.norm) self.canvas.draw()
[docs]def get_clipped_norm(data, clipping=0.0, log=True): if hasattr(data, "compressed"): data = data.compressed() else: data = data.flatten() if log: data = data[data > 0] if len(data) == 0: return matplotlib.colors.LogNorm(1, 10) if clipping: chop = int(round(data.size * clipping)) clip = sorted(data)[chop : -(1 + chop)] vmin, vmax = clip[0], clip[-1] else: vmin, vmax = data.min(), data.max() if log: return matplotlib.colors.LogNorm(vmin, vmax) else: return matplotlib.colors.Normalize(vmin, vmax)
[docs]def plot( space : Space, fig: Figure, ax : Axes, log: bool=True, loglog: bool=False, clipping: float=0.0, fit: Optional[bool]=None, norm: Optional[float]=None, colorbar: bool=True, labels: bool=True, interpolation: str="nearest", **plotopts ) -> Union[List[Line2D]]: if space.dimension == 1: data = space.get_masked() xrange = numpy.ma.array(space.axes[0][:], mask=data.mask) if fit is not None: if log: p1 = ax.semilogy(xrange, data, "wo", **plotopts) p2 = ax.semilogy(xrange, fit, "r", linewidth=2, **plotopts) elif loglog: p1 = ax.loglog(xrange, data, "wo", **plotopts) p2 = ax.loglog(xrange, fit, "r", linewidth=2, **plotopts) else: p1 = ax.plot(xrange, data, "wo", **plotopts) p2 = ax.plot(xrange, fit, "r", linewidth=2, **plotopts) else: if log: p1 = ax.semilogy(xrange, data, **plotopts) elif loglog: p1 = ax.loglog(xrange, data, **plotopts) else: p1 = ax.plot(xrange, data, **plotopts) p2 = [] if labels: ax.set_xlabel(space.axes[0].label) ax.set_ylabel("Intensity (a.u.)") return p1 + p2 elif space.dimension == 2: data = space.get_masked() # 2D IMSHOW PLOT xmin = space.axes[0].min xmax = space.axes[0].max ymin = space.axes[1].min ymax = space.axes[1].max if not norm: norm = get_clipped_norm(data, clipping, log) if fit is not None: im = ax.imshow( fit.transpose(), origin="lower", extent=(xmin, xmax, ymin, ymax), aspect="auto", norm=norm, interpolation=interpolation, **plotopts ) else: im = ax.imshow( data.transpose(), origin="lower", extent=(xmin, xmax, ymin, ymax), aspect="auto", norm=norm, interpolation=interpolation, **plotopts ) if labels: ax.set_xlabel(space.axes[0].label) ax.set_ylabel(space.axes[1].label) if colorbar: cbarwidget = fig.colorbar(im) fig._draggablecbar = DraggableColorbar( cbarwidget, im ) # we need to store this instance somewhere fig._draggablecbar.connect() return im elif space.dimension == 3: if not isinstance(ax, mpl_toolkits.mplot3d.Axes3D): raise ValueError( "For 3D plots, the 'ax' parameter must be an Axes3D instance (use for example 
gca(projection='3d') to get one)" ) cmap = getattr(matplotlib.cm, plotopts.pop("cmap", "jet")) if norm is None: norm = get_clipped_norm(space.get_masked(), clipping, log) data = space.get() mask = numpy.bitwise_or(~numpy.isfinite(data), data == 0) gridx, gridy, gridz = tuple(grid[~mask] for grid in space.get_grid()) im = ax.scatter( gridx, gridy, gridz, c=cmap(norm(data[~mask])), marker=",", alpha=0.7, linewidths=0, ) # p1 = ax.plot_surface(gridx[0,:,:], gridy[0,:,:], gridz[0,:,:], facecolors=cmap(norm(space.project(0).get_masked())), shade=False, cstride=1, rstride=1) # p2 = ax.plot_surface(gridx[:,-1,:], gridy[:,-1,:], gridz[:,-1,:], facecolors=cmap(norm(space.project(1).get_masked())), shade=False, cstride=1, rstride=1) # p3 = ax.plot_surface(gridx[:,:,0], gridy[:,:,0], gridz[:,:,0], facecolors=cmap(norm(space.project(2).get_masked())), shade=False, cstride=1, rstride=1) if labels: ax.set_xlabel(space.axes[0].label) ax.set_ylabel(space.axes[1].label) ax.set_zlabel(space.axes[2].label) if fig._draggablecbar: fig._draggablecbar.disconnect() return im elif space.dimension > 3: raise ValueError( "Cannot plot 4 or higher dimensional spaces, use projections or slices to decrease dimensionality." )