"""driftplots.driftplotter: load spike-sorting output and plot drift maps."""

from __future__ import annotations

from pathlib import Path
from typing import Literal

from matplotlib.axes import Axes
from matplotlib.figure import Figure
from pyqtgraph.Qt import QtWidgets

from driftplots import mpl_plotting
from driftplots.data_loader import DataLoader
from driftplots.interactive.driftmap_plot_widget import DriftmapPlotWidget


class DriftPlotter:
    """Load Kilosort or SpikeInterface output and plot drift maps.

    On construction, spike data is loaded (via ``DataLoader``) from a Kilosort
    output directory or a SpikeInterface ``SortingAnalyzer`` and stored as
    read-only arrays.

    Parameters
    ----------
    path_or_analyzer
        Path to a Kilosort sorter output directory, or a SpikeInterface
        ``SortingAnalyzer`` object. When a path is given it must contain
        exactly one ``kilosort*.log`` file, which is used to detect the
        Kilosort version.
    verbose
        If ``True``, messages are printed.

    Attributes
    ----------
    plot
        The widget created by the most recent call to
        :meth:`drift_map_plot_interactive`, or ``None`` if no interactive
        plot has been created yet.
    """

    # NOTE(review): the annotation only lists str | Path, but the docstring
    # (and DataLoader) also accept a SpikeInterface SortingAnalyzer.
    def __init__(self, path_or_analyzer: str | Path, verbose: bool = True) -> None:
        """Load spike data from a Kilosort output directory or SortingAnalyzer.

        Parameters
        ----------
        path_or_analyzer
            Path to a Kilosort sorter output directory, or a SpikeInterface
            ``SortingAnalyzer`` object.
        verbose
            If ``True``, messages are printed.
        """
        self._data_loader = DataLoader(path_or_analyzer, verbose)
        # Always define ``plot`` so it can be inspected before any interactive
        # plot is made (it used to appear only after drift_map_plot_interactive
        # ran, so earlier access raised AttributeError).
        self.plot: DriftmapPlotWidget | None = None

    def drift_map_plot_interactive(
        self,
        decimate: int | bool | None | Literal["estimate"] = "estimate",
        good_units_only: bool | str = False,
        amplitude_cmap_scaling: str | tuple[float, float] = "linear",
        n_color_bins: int = 20,
        point_size: float = 5.0,
        filter_amplitude_mode: str | None = None,
        filter_amplitude_values: tuple[float, ...] = (),
        title: bool | str | None = None,
        verbose: bool = True,
    ) -> DriftmapPlotWidget:
        """Create an interactive pyqtgraph-based drift map widget.

        Parameters
        ----------
        decimate
            Thin the spike dataset before plotting. Pass ``"estimate"`` to
            automatically reduce spikes to a reasonable count (≈ 100 000).
            Pass ``False``, ``None``, or ``0`` to disable decimation. Pass an
            integer *n* to keep every *n*-th spike.
        good_units_only
            If ``True``, only spikes belonging to "good" units are displayed.
            For Kilosort, this is taken from the cluster_groups.csv /
            cluster_group.tsv file that reflects labels set in Phy. For a
            SortingAnalyzer, a string must be passed. The labels are taken
            from the sorting property with the passed name (e.g. "KSLabel").
        amplitude_cmap_scaling
            Controls how spike amplitudes are mapped to the greyscale
            colormap. Pass ``"linear"``, ``"log2"`` or ``"log10"`` for
            automatic scaling, or a ``(min, max)`` tuple to set explicit
            bounds. When explicit bounds are set, the scaling is linear.
        n_color_bins
            Number of discrete greyscale bins used to colour spikes by
            amplitude.
        point_size
            Diameter of each scatter point in pixels.
        filter_amplitude_mode
            Controls how spikes are filtered based on amplitude before
            plotting. ``"percentile"`` treats the bounds as percentile ranks;
            ``"absolute"`` treats them as raw amplitude values. ``None``
            disables amplitude filtering.
        filter_amplitude_values
            ``(low, high)`` bounds for amplitude filtering, interpreted
            according to ``filter_amplitude_mode``. Ignored when
            ``filter_amplitude_mode`` is ``None``.
        title
            Plot title. Pass a string to set a custom title, ``True`` to use
            a default title, or ``None`` / ``False`` to suppress the title
            entirely.
        verbose
            If ``True``, messages are printed.

        Returns
        -------
        DriftmapPlotWidget
            The pyqtgraph widget, also stored on ``self.plot``. The widget is
            already populated but not yet shown; call ``exec()`` on the
            running ``QApplication`` (obtainable with
            ``QtWidgets.QApplication.instance()``) to display it.
        """
        # Qt permits a single QApplication per process, so reuse an existing
        # one (e.g. inside an IPython/Jupyter Qt session) before creating one.
        app = QtWidgets.QApplication.instance() or QtWidgets.QApplication([])

        processed_data = self._load_processed_data(
            good_units_only,
            decimate,
            filter_amplitude_mode,
            filter_amplitude_values,
            verbose,
        )

        self.plot = DriftmapPlotWidget(
            processed_data,
            app,
            amplitude_cmap_scaling=amplitude_cmap_scaling,
            n_color_bins=n_color_bins,
            point_size=point_size,
            title=title,
        )
        return self.plot

    def drift_map_plot_matplotlib(
        self,
        decimate: int | bool | None | Literal["estimate"] = "estimate",
        good_units_only: bool | str = False,
        amplitude_cmap_scaling: str | tuple[float, float] = "linear",
        n_color_bins: int = 20,
        point_size: float = 5.0,
        filter_amplitude_mode: str | None = None,
        filter_amplitude_values: tuple[float, ...] = (),
        add_histogram_plot: bool = False,
        weight_histogram_by_amplitude: bool = False,
        title: bool | str | None = None,
        ax: Axes | None = None,
        verbose: bool = True,
    ) -> Figure:
        """Create a static Matplotlib drift map figure.

        Parameters
        ----------
        decimate
            Thin the spike dataset before plotting. Pass ``"estimate"`` to
            automatically reduce spikes to a reasonable count (≈ 100 000).
            Pass ``False``, ``None``, or ``0`` to disable decimation. Pass an
            integer *n* to keep every *n*-th spike.
        good_units_only
            If ``True``, only spikes belonging to "good" units are displayed.
            For Kilosort, this is taken from the cluster_groups.csv /
            cluster_group.tsv file that reflects labels set in Phy. For a
            SortingAnalyzer, a string must be passed. The labels are taken
            from the sorting property with the passed name (e.g. "KSLabel").
        amplitude_cmap_scaling
            Controls how spike amplitudes are mapped to the greyscale
            colormap. Pass ``"linear"``, ``"log2"`` or ``"log10"`` for
            automatic scaling, or a ``(min, max)`` tuple to set explicit
            bounds. When explicit bounds are set, the scaling is linear.
        n_color_bins
            Number of discrete greyscale bins used to colour spikes by
            amplitude.
        point_size
            Diameter of each scatter point in pixels.
        filter_amplitude_mode
            Controls how spikes are filtered based on amplitude before
            plotting. ``"percentile"`` treats the bounds as percentile ranks;
            ``"absolute"`` treats them as raw amplitude values. ``None``
            disables amplitude filtering.
        filter_amplitude_values
            ``(low, high)`` bounds for amplitude filtering, interpreted
            according to ``filter_amplitude_mode``. Ignored when
            ``filter_amplitude_mode`` is ``None``.
        add_histogram_plot
            If ``True``, add a side panel showing a depth histogram of spike
            activity.
        weight_histogram_by_amplitude
            If ``True``, weight the depth histogram by spike amplitude rather
            than counting spikes uniformly. Only used when
            ``add_histogram_plot`` is ``True``.
        title
            Plot title. Pass a string to set a custom title, ``True`` to use
            a default title, or ``None`` / ``False`` to suppress the title
            entirely.
        ax
            Existing Matplotlib axis to draw the drift map on. When
            ``add_histogram_plot`` is ``True``, a histogram axis is added
            beside this axis.
        verbose
            If ``True``, messages are printed.

        Returns
        -------
        Figure
            The populated Matplotlib figure.
        """
        processed_data = self._load_processed_data(
            good_units_only,
            decimate,
            filter_amplitude_mode,
            filter_amplitude_values,
            verbose,
        )

        # Positional order must match mpl_plotting.plot_matplotlib's signature.
        fig = mpl_plotting.plot_matplotlib(
            processed_data,
            amplitude_cmap_scaling,
            n_color_bins,
            point_size,
            add_histogram_plot,
            weight_histogram_by_amplitude,
            title=title,
            ax=ax,
        )
        return fig

    def _load_processed_data(
        self,
        good_units_only: bool | str,
        decimate: int | bool | None | Literal["estimate"],
        filter_amplitude_mode: str | None,
        filter_amplitude_values: tuple[float, ...],
        verbose: bool,
    ):
        """Return the data loader's processed spike data for plotting.

        Shared by both plot methods so the (positional) argument order passed
        to ``DataLoader.get_processed_data`` is defined in exactly one place.
        """
        return self._data_loader.get_processed_data(
            good_units_only,
            decimate,
            filter_amplitude_mode,
            filter_amplitude_values,
            verbose,
        )