Source code for driftplots.driftplotter

from __future__ import annotations

from pathlib import Path
from typing import Literal

from matplotlib.figure import Figure
from pyqtgraph.Qt import QtWidgets

from driftplots import mpl_plotting
from driftplots.data_loader import DataLoader
from driftplots.interactive.driftmap_plot_widget import DriftmapPlotWidget

# Test ideas:
# - check that the default arguments of the interactive and matplotlib
#   plotting methods match.


class DriftPlotter:
    """Load Kilosort sorter output and provide interactive or static drift map plots.

    On construction, a :class:`~driftplots.data_loader.DataLoader` is built
    from the Kilosort output directory. Each plotting method asks that loader
    for processed data (optional noise exclusion, amplitude filtering and
    decimation) and hands the result to a plot backend: pyqtgraph for
    :meth:`drift_map_plot_interactive` and matplotlib for
    :meth:`drift_map_plot_matplotlib`. Both methods share the same filtering
    arguments and defaults, so the two backends are fed the same spikes.

    Parameters
    ----------
    sorter_path
        Path to a Kilosort sorter output directory. Must contain exactly one
        ``kilosort*.log`` file, which is used to detect the KS version.

    Attributes
    ----------
    data_loader
        The ``DataLoader`` constructed from ``sorter_path``. The loaded spike
        data (spike times -- seconds for KS 1-3, samples for KS4 -- spike
        amplitudes, spike depths along the probe in µm, per-spike template
        ids, template waveforms and channel locations) is held by this
        loader rather than by ``DriftPlotter``; see ``DataLoader`` for the
        exact attribute names.
    plot
        The widget most recently created by
        :meth:`drift_map_plot_interactive`, or ``None`` before the first call.
    """

    def __init__(self, sorter_path: str | Path) -> None:
        """Load spike data from a Kilosort output directory.

        Parameters
        ----------
        sorter_path
            Path to the Kilosort sorter output.

        Raises
        ------
        AssertionError
            Raised by ``DataLoader`` if the directory does not contain exactly
            one ``kilosort*.log`` file, or if the loaded spike arrays have
            mismatched sizes.
        """
        self.data_loader = DataLoader(sorter_path)  # TODO: rename
        # Set by drift_map_plot_interactive(); initialised here so the
        # attribute exists (as None) before any interactive plot is made.
        self.plot: DriftmapPlotWidget | None = None

    def _get_processed_data(
        self,
        decimate: int | bool | None | Literal["estimate"],
        exclude_noise: bool | str,
        filter_amplitude_mode: str | None,
        filter_amplitude_values: tuple[float, ...],
    ):
        """Return spike data filtered/decimated by the loader for plotting.

        Shared by both plotting backends so they always apply identical
        preprocessing for identical arguments.
        """
        # DataLoader.get_processed_data takes these positionally, with
        # exclude_noise *before* decimate (the reverse of the public
        # plotting-method parameter order).
        return self.data_loader.get_processed_data(
            exclude_noise, decimate, filter_amplitude_mode, filter_amplitude_values
        )

    def drift_map_plot_interactive(
        self,
        decimate: int | bool | None | Literal["estimate"] = "estimate",
        exclude_noise: bool | str = False,
        amplitude_cmap_scaling: str | tuple[float, float] = "linear",
        n_color_bins: int = 20,
        point_size: float = 5.0,
        filter_amplitude_mode: str | None = None,
        filter_amplitude_values: tuple[float, ...] = (),
        title: bool | str | None = None,
    ) -> DriftmapPlotWidget:
        """Create an interactive pyqtgraph-based drift map widget.

        Parameters
        ----------
        decimate
            Keep every *n*-th spike; too many spikes slow the plot down.
            If ``"estimate"``, the number of spikes is decimated to a
            reasonable range (e.g. 50,000). ``False``, ``None`` or ``0``
            disables decimation. Otherwise pass an integer, e.g. ``2`` to
            keep every 2nd spike.
        exclude_noise
            Remove spikes labelled as noise. A ``str`` is also accepted and
            forwarded unchanged to the data loader.
        amplitude_cmap_scaling
            Colour-scaling mode, or an explicit ``(min, max)`` range.
        n_color_bins
            Number of grey-scale colour bins used for amplitude.
        point_size
            Scatter-point diameter in pixels.
        filter_amplitude_mode
            Amplitude filtering mode; ``None`` disables amplitude filtering
            at this level (forwarded to the data loader).
        filter_amplitude_values
            Bounds used by ``filter_amplitude_mode``.
        title
            Plot title, forwarded to the widget. Non-string values
            (``True``/``False``/``None``) are interpreted by the widget.

        Returns
        -------
        DriftmapPlotWidget
            The populated pyqtgraph widget. It is not shown yet. The
            ``QApplication`` used here is the one returned by
            ``QtWidgets.QApplication.instance()``; run its ``exec()`` to
            display the widget. The widget is also stored on ``self.plot``.
        """
        # Qt widgets require a QApplication; reuse the running one if present.
        app = QtWidgets.QApplication.instance() or QtWidgets.QApplication([])
        processed_data = self._get_processed_data(
            decimate, exclude_noise, filter_amplitude_mode, filter_amplitude_values
        )
        self.plot = DriftmapPlotWidget(
            processed_data,
            app,
            amplitude_cmap_scaling=amplitude_cmap_scaling,
            n_color_bins=n_color_bins,
            point_size=point_size,
            title=title,
        )
        return self.plot

    def drift_map_plot_matplotlib(
        self,
        decimate: int | bool | None | Literal["estimate"] = "estimate",
        exclude_noise: bool | str = False,
        amplitude_cmap_scaling: str | tuple[float, float] = "linear",
        n_color_bins: int = 20,
        point_size: float = 5.0,
        filter_amplitude_mode: str | None = None,
        filter_amplitude_values: tuple[float, ...] = (),
        add_histogram_plot: bool = False,
        weight_histogram_by_amplitude: bool = False,
        title: bool | str | None = None,
    ) -> Figure:
        """Create a static matplotlib drift map figure.

        The filtering arguments and their defaults are the same as for
        :meth:`drift_map_plot_interactive`.

        Parameters
        ----------
        decimate
            Keep every *n*-th spike. ``"estimate"`` decimates to a reasonable
            number of spikes, ``False``, ``None`` or ``0`` disables
            decimation, and an integer ``n`` keeps every *n*-th spike.
        exclude_noise
            Remove spikes labelled as noise. A ``str`` is also accepted and
            forwarded unchanged to the data loader.
        amplitude_cmap_scaling
            Colour-scaling mode, or an explicit ``(min, max)`` range.
        n_color_bins
            Number of colour bins used for amplitude.
        point_size
            Scatter-point size, forwarded to the matplotlib backend.
        filter_amplitude_mode
            Amplitude filtering mode (forwarded to the data loader).
        filter_amplitude_values
            Bounds used by ``filter_amplitude_mode``.
        add_histogram_plot
            Forwarded to ``mpl_plotting.plot_matplotlib``; presumably adds a
            histogram panel alongside the drift map (confirm in
            ``mpl_plotting``).
        weight_histogram_by_amplitude
            Forwarded to ``mpl_plotting.plot_matplotlib``; presumably weights
            that histogram by spike amplitude (confirm in ``mpl_plotting``).
        title
            Plot title, forwarded to the backend. Non-string values
            (``True``/``False``/``None``) are interpreted by the backend.

        Returns
        -------
        matplotlib.figure.Figure
            The figure produced by ``mpl_plotting.plot_matplotlib``.
        """
        processed_data = self._get_processed_data(
            decimate, exclude_noise, filter_amplitude_mode, filter_amplitude_values
        )
        # plot_matplotlib's leading parameters are passed positionally, in
        # exactly this order.
        fig = mpl_plotting.plot_matplotlib(
            processed_data,
            amplitude_cmap_scaling,
            n_color_bins,
            point_size,
            add_histogram_plot,
            weight_histogram_by_amplitude,
            title=title,
        )
        return fig