From 92ea617a5a34764ffbc5c51a6e7b4082c4007a21 Mon Sep 17 00:00:00 2001 From: Graham Findlay Date: Wed, 10 Jun 2026 19:47:50 -0500 Subject: [PATCH] Vectorize _find_duplicated_spikes_keep_first/last_iterative kernels. Let `N` be the number of spikes in the train, and `D` be the number of spikes flagged as duplicates (i.e. the fall within the censored period of an earlier kept spike). The old kernels were using `if i in indices_of_duplicates` for each of `N` spikes to check if the spike was a duplicate, which is `O(N*D)`. But for heavily contaminated units, `D` can be quite large. On long (48h) recordings, this was blowing up, and taking 56min to compute `sd_ratio` for 232 units. One unit took nearly 5min by itself. This approach is constant-time lookup, so `O(N)` total. New time on the 48h dataset is 10.6s total to compute `sd_ratio` for all units. Tested equivalence with the old method on both the real data and synthetic bursty trains, plus the existing test still passes (note for the future: the existing test should probably check the exact indices). --- src/spikeinterface/curation/curation_tools.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/curation/curation_tools.py b/src/spikeinterface/curation/curation_tools.py index ff2c32d07f..281a623161 100644 --- a/src/spikeinterface/curation/curation_tools.py +++ b/src/spikeinterface/curation/curation_tools.py @@ -104,35 +104,35 @@ def _find_duplicated_spikes_random(spike_train: np.ndarray, censored_period: int @numba.jit(nopython=True, nogil=True, cache=False) def _find_duplicated_spikes_keep_first_iterative(spike_train, censored_period): - indices_of_duplicates = numba.typed.List() N = len(spike_train) + is_duplicate = np.zeros(N, dtype=np.bool_) for i in range(N - 1): - if i in indices_of_duplicates: + if is_duplicate[i]: continue for j in range(i + 1, N): if spike_train[j] - spike_train[i] > censored_period: break - indices_of_duplicates.append(j) + is_duplicate[j] = True - return np.asarray(indices_of_duplicates) + return np.nonzero(is_duplicate)[0] - @numba.jit(nopython=True, nogil=True, cache=True) + @numba.jit(nopython=True, nogil=True, cache=False) def _find_duplicated_spikes_keep_last_iterative(spike_train, censored_period): - indices_of_duplicates = numba.typed.List() N = len(spike_train) + is_duplicate = np.zeros(N, dtype=np.bool_) for i in range(N - 1, 0, -1): - if i in indices_of_duplicates: + if is_duplicate[i]: continue for j in range(i - 1, -1, -1): if spike_train[i] - spike_train[j] > censored_period: break - indices_of_duplicates.append(j) + is_duplicate[j] = True - return np.asarray(indices_of_duplicates) + return np.nonzero(is_duplicate)[0] def find_duplicated_spikes(