Source code for pylbo.visualisation.matrices

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LogNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable
from pylbo.visualisation.figure_window import FigureWindow


def _scatter_with_colorbar(ax, rows, cols, vals, title):
    """Scatter matrix entries on ``ax`` and attach a colorbar to its right.

    Entries are drawn at ``(col, row)`` and coloured by ``vals`` on a
    logarithmic colour scale using the "plasma" colormap.
    """
    im = ax.scatter(cols, rows, c=vals, s=6, cmap="plasma", norm=LogNorm())
    ax.set_title(title)
    # carve a narrow strip off the right-hand side of ``ax`` for the colorbar
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(im, cax=cax)


def _draw_gridlines(ax, ticks, upper, alpha):
    """Draw grey vertical and horizontal lines at ``ticks``, from 0.5 to ``upper``.

    All lines of one orientation are passed to matplotlib in a single call,
    so each orientation yields one line collection instead of one per line.
    """
    ax.vlines(x=ticks, ymin=0.5, ymax=upper, color="grey", alpha=alpha)
    ax.hlines(y=ticks, xmin=0.5, xmax=upper, color="grey", alpha=alpha)


class MatrixFigure(FigureWindow):
    """
    Figure showing both matrices from a dataset.

    Matrix A is drawn on the main axes and matrix B on a second axes added
    to its right. The figure is drawn immediately on construction.

    Parameters
    ----------
    dataset
        Object providing ``get_matrix_A()`` and ``get_matrix_B()``, each
        returning ``(rows, cols, vals)``, and a ``header["dims"]`` mapping with
        the keys ``"dim_subblock"``, ``"dim_quadblock"`` and ``"dim_matrix"``.
    figsize
        Figure size, forwarded to ``create_default_figure``.
    **kwargs
        Stored on ``self.kwargs``; not used further by this class.
    """

    def __init__(self, dataset, figsize, **kwargs):
        fig, ax = super().create_default_figure(figlabel="matrices", figsize=figsize)
        super().__init__(fig)
        self.dataset = dataset
        self.kwargs = kwargs
        self.ax = ax
        self.ax2 = super().add_subplot_axes(self.ax, loc="right")
        self.draw()

    def draw(self):
        """Draws the matrices, their colorbars and the block grid overlay."""
        # matrix A: plot the modulus of the values
        rows, cols, vals = self.dataset.get_matrix_A()
        _scatter_with_colorbar(
            self.ax, rows, cols, np.absolute(vals), "Matrix A (modulus)"
        )

        # matrix B: values are plotted as returned by the dataset
        # NOTE(review): unlike matrix A no modulus is taken here, while the
        # colour scale is logarithmic -- confirm B holds only positive values.
        rows, cols, vals = self.dataset.get_matrix_B()
        _scatter_with_colorbar(self.ax2, rows, cols, vals, "Matrix B")

        self.fig.canvas.draw()

        dims = self.dataset.header["dims"]
        dim_subblock = dims["dim_subblock"]
        dim_quadblock = dims["dim_quadblock"]
        dim_matrix = dims["dim_matrix"]

        # Grid positions are identical for both axes, so compute them once.
        # Lines sit on half-integers, i.e. between integer positions
        # (presumably the integer row/column indices of the entries).
        upper = dim_matrix + 0.5
        maingrid_ticks = np.arange(0.5, dim_matrix + dim_subblock + 0.5, dim_subblock)
        minorgrid_ticks = np.arange(0.5, upper, 2)
        # the + 0.1 keeps dim_matrix itself in the range when it is a
        # multiple of dim_quadblock
        visualticks = np.arange(0, dim_matrix + 0.1, dim_quadblock)

        for ax in (self.ax, self.ax2):
            # main grid: one line every dim_subblock
            _draw_gridlines(ax, maingrid_ticks, upper, alpha=0.6)
            # minor grid: a fainter line every 2 units
            _draw_gridlines(ax, minorgrid_ticks, upper, alpha=0.1)
            ax.set_xticks(visualticks)
            ax.set_yticks(visualticks)
            ax.set_xlim(0, dim_matrix + 1)
            ax.set_ylim(0, dim_matrix + 1)
            ax.tick_params(which="both", labelsize=13)
            ax.set_aspect("equal")
            # rows increase downwards, as in the usual written matrix layout
            ax.invert_yaxis()