Source code for driftplots.amplitudes
from __future__ import annotations
from pathlib import Path
import numpy as np
import spikeinterface as si
from driftplots.data_loader import DataLoader
[docs]
def get_amplitudes(
    list_of_path_or_analyzer: list[Path | si.SortingAnalyzer],
    good_units_only: bool | str = False,
    concatenate: bool = False,
    verbose: bool = True,
) -> np.ndarray | list[np.ndarray]:
    """Return spike amplitudes from one or more sorter outputs.

    Parameters
    ----------
    list_of_path_or_analyzer
        Kilosort output directory paths or SpikeInterface ``SortingAnalyzer``
        objects to load amplitudes from. Can be a mix of both.
    good_units_only
        If ``True``, only spikes belonging to "good" units are returned.
        For Kilosort, this is taken from the
        cluster_groups.csv / cluster_group.tsv file that reflects labels
        set in Phy. For a SortingAnalyzer, a string must be passed.
        The labels are taken from the sorting property with the
        passed name (e.g. "KSLabel").
    concatenate
        If ``True``, concatenate all per-session amplitude arrays into a
        single 1-D array. If ``False`` (default), return a list with one
        array per session.
    verbose
        If ``True``, messages are printed.

    Returns
    -------
    np.ndarray or list of np.ndarray
        When ``concatenate=True``, a single 1-D array of all spike
        amplitudes (an empty array if ``list_of_path_or_analyzer`` is
        empty). When ``concatenate=False``, a list of 1-D arrays,
        one per entry in ``list_of_path_or_analyzer``.
    """
    all_spike_amplitudes = []
    for path_or_analyzer in list_of_path_or_analyzer:
        loader = DataLoader(path_or_analyzer, verbose)
        # Ask for the processed data without decimation or amplitude
        # filtering. Presumably this keeps every (optionally good-unit-only)
        # spike; see DataLoader.get_processed_data to confirm.
        processed_data = loader.get_processed_data(
            good_units_only,
            decimate=False,
            filter_amplitude_mode=None,
            filter_amplitude_values=None,
            verbose=verbose,
        )
        all_spike_amplitudes.append(processed_data.spike_amplitudes)

    if not concatenate:
        return all_spike_amplitudes

    # np.concatenate raises ValueError on an empty sequence, so "no sessions"
    # maps to an empty 1-D array rather than an error.
    if not all_spike_amplitudes:
        return np.empty(0)
    return np.concatenate(all_spike_amplitudes)