diff --git a/src/spikeinterface/curation/curation_tools.py b/src/spikeinterface/curation/curation_tools.py index ff2c32d07f..281a623161 100644 --- a/src/spikeinterface/curation/curation_tools.py +++ b/src/spikeinterface/curation/curation_tools.py @@ -104,35 +104,35 @@ def _find_duplicated_spikes_random(spike_train: np.ndarray, censored_period: int @numba.jit(nopython=True, nogil=True, cache=False) def _find_duplicated_spikes_keep_first_iterative(spike_train, censored_period): - indices_of_duplicates = numba.typed.List() N = len(spike_train) + is_duplicate = np.zeros(N, dtype=np.bool_) for i in range(N - 1): - if i in indices_of_duplicates: + if is_duplicate[i]: continue for j in range(i + 1, N): if spike_train[j] - spike_train[i] > censored_period: break - indices_of_duplicates.append(j) + is_duplicate[j] = True - return np.asarray(indices_of_duplicates) + return np.nonzero(is_duplicate)[0] - @numba.jit(nopython=True, nogil=True, cache=True) + @numba.jit(nopython=True, nogil=True, cache=False) def _find_duplicated_spikes_keep_last_iterative(spike_train, censored_period): - indices_of_duplicates = numba.typed.List() N = len(spike_train) + is_duplicate = np.zeros(N, dtype=np.bool_) for i in range(N - 1, 0, -1): - if i in indices_of_duplicates: + if is_duplicate[i]: continue for j in range(i - 1, -1, -1): if spike_train[i] - spike_train[j] > censored_period: break - indices_of_duplicates.append(j) + is_duplicate[j] = True - return np.asarray(indices_of_duplicates) + return np.nonzero(is_duplicate)[0] def find_duplicated_spikes(