From 39f266267a3611daa34135ac382524321afaa5d8 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 12 Jun 2026 13:22:46 +0200 Subject: [PATCH] fix: propagate self.tmp_data_to_save when selecting units in BasEMetricExtension --- .../core/analyzer_extension_core.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 1c261e3cad..3eeede22b5 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -1361,8 +1361,22 @@ def _select_extension_data(self, unit_ids: list[int | str]): dict Dictionary containing the selected metrics DataFrame. """ + import pandas as pd + + new_data = dict() new_metrics = self.data["metrics"].loc[np.array(unit_ids)] - return dict(metrics=new_metrics) + new_data["metrics"] = new_metrics + if self.tmp_data_to_save is not None: + for k in self.tmp_data_to_save: + old_data = self.data[k] + if isinstance(old_data, pd.DataFrame): + new_df = old_data.loc[np.array(unit_ids)] + new_data[k] = new_df + elif isinstance(old_data, np.ndarray): + old_arr = self.data[k] + new_arr = old_arr[self.sorting_analyzer.sorting.ids_to_indices(unit_ids), ...] + new_data[k] = new_arr + return new_data def _merge_extension_data( self,