Source code for ieeg.decoding.decode

from sklearn import config_context
try:
    import cupy as cp
except ImportError:
    cp = None

from ieeg.decoding.models import PcaLdaClassification
from ieeg.arrays.label import LabeledArray
from ieeg.calc.oversample import MinimumNaNSplit
from ieeg.arrays.api import array_namespace, Array, is_torch, is_numpy
from ieeg.arrays.reshape import sliding_window_view
from ieeg.calc.fast import mixup
import numpy as np
import matplotlib.pyplot as plt
from ieeg.viz.ensemble import plot_dist
from joblib import Parallel, delayed
import itertools
from tqdm import tqdm


class Decoder(MinimumNaNSplit):
    """Cross-validated decoder producing confusion matrices.

    Splits are generated by the inherited ``MinimumNaNSplit`` machinery and
    every fold is fit/scored by ``_proc`` with a fresh
    ``PcaLdaClassification`` model built from the stored keyword arguments.
    """

    def __init__(self, categories: dict, n_splits: int = 5,
                 n_repeats: int = 1, min_samples: int = 1,
                 which: str = 'test', **kwargs):
        """Initialize the Decoder.

        Parameters
        ----------
        categories : dict
            Dictionary mapping category names to category indices.
        n_splits : int, optional
            Number of splits for cross-validation, by default 5.
        n_repeats : int, optional
            Number of repetitions for cross-validation, by default 1.
        min_samples : int, optional
            Minimum number of samples required for each category,
            by default 1.
        which : str, optional
            Which set to use for validation ('test' or 'train'),
            by default 'test'.
        **kwargs
            Additional keyword arguments passed to the
            PcaLdaClassification model.
        """
        # The model itself is built per fold inside ``_proc``; only its
        # keyword arguments are stored here.
        self.kwargs = kwargs
        MinimumNaNSplit.__init__(self, n_splits, n_repeats,
                                 None, min_samples, which)
        self.categories = categories
        self.current_job = "Repetitions"

    def cv_cm(self, x_data: Array, labels: Array,
              normalize: str = None, obs_axs: int = -2, n_jobs: int = 1,
              average_repetitions: bool = True, window: int = None,
              shuffle: bool = False, oversample: bool = True, step: int = 1
              ) -> Array:
        """Cross-validated confusion matrix

        Parameters
        ----------
        x_data : np.ndarray
            The data to be decoded
        labels : np.ndarray
            The labels for the data
        normalize : str, optional
            How to normalize the confusion matrix, by default None.
            'true' divides by the row sums, 'pred' by the column sums,
            'all' by ``n_repeats``; any other value leaves the counts
            unnormalized.
        obs_axs : int, optional
            The axis containing the observations, by default -2
        n_jobs : int, optional
            The number of jobs to run in parallel, by default 1
        average_repetitions : bool, optional
            Whether to average the repetitions, by default True
        window : int, optional
            The window size for time sliding, by default None
        shuffle : bool, optional
            Whether to shuffle the labels, by default False. When True,
            NaNs in the data are replaced in place with random noise
            before the labels are shuffled.
        oversample : bool, optional
            Whether to oversample the training data, by default True
        step : int, optional
            The step size for time sliding, by default 1

        Returns
        -------
        np.ndarray
            The confusion matrix. The fold axis is always summed; a leading
            window axis is present when ``window`` is given and a
            repetition axis is kept when ``average_repetitions`` is False.

        Examples
        --------
        >>> np.random.seed(42)
        >>> decoder = Decoder({'heat': 1, 'hoot': 2, 'hot': 3, 'hut': 4},
        ...                   5, 10, explained_variance=0.8, da_type='lda')
        >>> X = np.random.randn(100, 50, 100)
        >>> labels = np.random.randint(1, 5, 50)
        >>> decoder.cv_cm(X, labels, normalize='true')
        array([[0.11111111, 0.        , 0.03333333, 0.85555556],
               [0.1       , 0.        , 0.04      , 0.86      ],
               [0.10666667, 0.        , 0.04      , 0.85333333],
               [0.10625   , 0.        , 0.03125   , 0.8625    ]])
        >>> decoder = Decoder({'heat': 1, 'hoot': 2, 'hot': 3, 'hut': 4},
        ...                   5, 10, explained_variance=0.8, da_type='lda')
        >>> decoder.cv_cm(X, labels, normalize='true', window=20, step=5)[0]
        array([[0.04444444, 0.        , 0.36666667, 0.58888889],
               [0.02      , 0.01      , 0.41      , 0.56      ],
               [0.03333333, 0.        , 0.50666667, 0.46      ],
               [0.03125   , 0.        , 0.5       , 0.46875   ]])
        >>> decoder.cv_cm(X, labels, normalize='true', window=20, step=5,
        ...               shuffle=True, oversample=True)[0]
        array([[0.        , 0.12222222, 0.52222222, 0.35555556],
               [0.01      , 0.12      , 0.5       , 0.37      ],
               [0.00666667, 0.10666667, 0.50666667, 0.38      ],
               [0.        , 0.09375   , 0.525     , 0.38125   ]])
        >>> import cupy as cp # doctest: +SKIP
        >>> X = cp.random.randn(100, 100, 50, 100) # doctest: +SKIP
        >>> X[0, 0, 0, :] = np.nan # doctest: +SKIP
        >>> labels = cp.random.randint(1, 5, 50) # doctest: +SKIP
        >>> with config_context(array_api_dispatch=True): # doctest: +SKIP
        ...     decoder.cv_cm(X, labels, normalize='true') # doctest: +SKIP
        array([[0.        , 0.36666667, 0.63333333, 0.        ],
               [0.        , 0.32777778, 0.67222222, 0.        ],
               [0.        , 0.33157895, 0.66842105, 0.        ],
               [0.        , 0.35714286, 0.64285714, 0.        ]])
        """
        assert all(lab in self.categories.values() for lab in labels), \
            "Labels must be in the categories"
        xp = array_namespace(x_data)
        n_cats = len(self.categories)

        # Layout: ([window,] repetition, fold, true class, predicted class)
        out_shape = (self.n_repeats, self.n_splits, n_cats, n_cats)
        if window is not None:
            n_windows = (x_data.shape[-1] - window) // step + 1
            out_shape = (n_windows,) + out_shape
        mats = xp.zeros(out_shape, dtype=xp.int16)

        # Observations go first; this is a view, not a copy.
        data = x_data.swapaxes(0, obs_axs)

        if shuffle:
            # Replace NaNs with wide noise (in place) before shuffling.
            isnan = xp.isnan(data)
            std = float(xp.nanstd(data, dtype='f8'))
            data[isnan] = xp.random.normal(
                0, 3 * std, int(xp.sum(isnan, dtype='i8')))

            # shuffled label pool, one independent shuffle per repetition
            label_stack = [labels.copy() for _ in range(self.n_repeats)]
            for rep_labels in label_stack:
                self.shuffle_labels(data, rep_labels, 0)

            # build the test/train indices from the shuffled labels for
            # each repetition, then chain together the repetitions
            per_rep = ((itertools.islice(self.split(data, rep_labels),
                                         self.n_splits),
                        itertools.repeat(rep_labels, self.n_splits))
                       for rep_labels in label_stack)
            rep_splits, rep_labels_iter = zip(*per_rep)
            idxs = zip(itertools.chain.from_iterable(rep_splits),
                       itertools.chain.from_iterable(rep_labels_iter))
        else:
            idxs = ((fold_split, labels)
                    for fold_split in self.split(data, labels))

        # One argument tuple per (repetition, fold); ``i`` encodes both.
        jobs = ((train_idx, test_idx, lab, data, i, self.n_splits,
                 self.categories, window, step, oversample, self.kwargs)
                for i, ((train_idx, test_idx), lab) in enumerate(idxs))

        # loop over folds and repetitions
        if n_jobs == 1:
            results = (_proc(*args) for args in jobs)
        else:
            results = Parallel(n_jobs=n_jobs, verbose=0,
                               require='sharedmem',
                               return_as="generator_unordered")(
                delayed(_proc)(*args) for args in jobs)

        # Collect the results; the bar is closed even if a fold raises.
        with tqdm(desc=self.current_job,
                  total=self.n_splits * self.n_repeats) as progress:
            for result, rep, fold in results:
                mats[..., rep, fold, :, :] = result
                progress.update()

        # average the repetitions. The repetition axis is -4 in both the
        # windowed and un-windowed layouts (a positive index would address
        # the fold axis when no window axis is present).
        if average_repetitions:
            mats = xp.mean(mats, axis=-4)

        # normalize, sum the folds
        mats = xp.sum(mats, axis=-3)
        if normalize == 'true':
            divisor = xp.sum(mats, axis=-1, keepdims=True)
        elif normalize == 'pred':
            divisor = xp.sum(mats, axis=-2, keepdims=True)
        elif normalize == 'all':
            # NOTE(review): divides by the repetition count, not by the
            # total number of samples -- confirm this is intended.
            divisor = self.n_repeats
        else:
            divisor = 1
        return mats / divisor
def _proc(train_idx, test_idx, lab, orig_data, pid, n_splits, cats, window,
          step, oversample, model_kwargs):
    """Process a single fold of data for cross-validation.

    Parameters
    ----------
    train_idx : Array
        Indices of the training data.
    test_idx : Array
        Indices of the test data.
    lab : Array
        Labels for the data.
    orig_data : Array
        The original data to be processed (observations on axis 0).
    pid : int
        Process ID, used to determine repetition and fold.
    n_splits : int
        Number of splits for cross-validation.
    cats : dict
        Dictionary mapping category names to category indices.
    window : int or None
        Window size for time sliding. If None, no windowing is applied.
    step : int
        Step size for time sliding.
    oversample : bool
        Whether to oversample the training data.
    model_kwargs : dict
        Keyword arguments for the PcaLdaClassification model.

    Returns
    -------
    tuple
        Confusion matrix (with a leading window axis when ``window`` is
        given), repetition index, and fold index.
    """
    xp = array_namespace(orig_data)
    label_cats = xp.asarray(list(cats.values()))
    x_stacked, y_train, y_test = sample_fold(
        train_idx, test_idx, orig_data, lab, label_cats, 0, oversample, xp)
    model = PcaLdaClassification(**model_kwargs)
    n_train = train_idx.shape[0]
    n_cats = label_cats.shape[0]

    def _fit_predict(x_flat):
        """Fit model on training data and predict on test data.

        Parameters
        ----------
        x_flat : Array
            Flattened input data containing both training and test data;
            the first ``n_train`` rows are the training observations.

        Returns
        -------
        Array
            Confusion matrix of predictions.
        """
        x_train, x_test = x_flat[:n_train], x_flat[n_train:]

        # fit model and score results
        model.fit(x_train, y_train)
        pred = model.predict(x_test)
        return confusion_matrix(y_test, pred, label_cats, namespace=xp)

    # ``pid`` enumerates folds within repetitions.
    rep, fold = divmod(pid, n_splits)

    if window is None:
        x_flattened = x_stacked.reshape(x_stacked.shape[0], -1)
        return _fit_predict(x_flattened), rep, fold

    # (obs, ..., n_windows, window) after sub-sampling the window starts
    windowed = sliding_window_view(x_stacked, window, axis=-1, subok=True)[
        ..., ::step, :]
    n_windows = windowed.shape[-2]

    # Counts are kept as int32 (the dtype ``confusion_matrix`` produces);
    # a narrower unsigned type would silently wrap for cells above 255.
    if is_numpy(xp):
        # (n_windows, obs, features): one 2-D problem per window
        per_window = xp.moveaxis(
            windowed.swapaxes(-1, -2).reshape(
                windowed.shape[0], -1, n_windows), -1, 0)
        func = np.vectorize(_fit_predict, signature='(a,b) -> (d,d)',
                            otypes=[xp.int32])
        out = func(per_window)
    else:
        out = xp.zeros((n_windows, n_cats, n_cats), dtype=xp.int32)
        for i in range(n_windows):
            x_window = windowed[..., i, :]
            out[i] = _fit_predict(x_window.reshape(x_window.shape[0], -1))

    return out, rep, fold
def confusion_matrix(
        y_true, y_pred, labels=None, namespace=None
):
    """Compute confusion matrix to evaluate the accuracy of a classification.

    By definition a confusion matrix :math:`C` is such that :math:`C_{i, j}`
    is equal to the number of observations known to be in group :math:`i` and
    predicted to be in group :math:`j`.

    Thus in binary classification, the count of true negatives is
    :math:`C_{0,0}`, false negatives is :math:`C_{1,0}`, true positives is
    :math:`C_{1,1}` and false positives is :math:`C_{0,1}`.

    Parameters
    ----------
    y_true : array-like of shape (n_samples,)
        Ground truth (correct) target values.
    y_pred : array-like of shape (n_samples,)
        Estimated targets as returned by a classifier.
    labels : array-like of shape (n_classes), default=None
        List of labels to index the matrix. This may be used to reorder
        or select a subset of labels; the order given is the row/column
        order of the result and samples whose true or predicted value is
        not in ``labels`` are ignored.
        If ``None`` is given, those that appear at least once
        in ``y_true`` or ``y_pred`` are used in sorted order.
    namespace : module, default=None
        Array namespace (e.g. numpy or cupy) to compute with. If ``None``,
        numpy is used for list inputs and the namespace is otherwise
        inferred from ``y_true`` and ``y_pred``.

    Returns
    -------
    C : ndarray of shape (n_classes, n_classes)
        Confusion matrix whose i-th row and j-th column entry indicates the
        number of samples with true label being i-th class
        and predicted label being j-th class.

    References
    ----------
    .. [1] `Wikipedia entry for the Confusion matrix
           <https://en.wikipedia.org/wiki/Confusion_matrix>`_
           (Wikipedia and other references may use a different
           convention for axes).

    Examples
    --------
    >>> y_true = [2, 0, 2, 2, 0, 1]
    >>> y_pred = [0, 0, 2, 2, 0, 2]
    >>> confusion_matrix(y_true, y_pred)
    array([[2, 0, 0],
           [0, 0, 1],
           [1, 0, 2]], dtype=int32)

    >>> y_true = ["cat", "ant", "cat", "cat", "ant", "bird"]
    >>> y_pred = ["ant", "ant", "cat", "cat", "ant", "cat"]
    >>> confusion_matrix(y_true, y_pred, labels=["ant", "bird", "cat"])
    array([[2, 0, 0],
           [0, 0, 1],
           [1, 0, 2]], dtype=int32)

    In the binary case, we can extract true positives, etc. as follows:

    >>> tn, fp, fn, tp = confusion_matrix([0, 1, 0, 1], [1, 1, 1, 0]).ravel()
    >>> (int(tn), int(fp), int(fn), int(tp))
    (0, 2, 1, 1)
    """
    if namespace is not None:
        xp = namespace
    elif isinstance(y_true, list) or isinstance(y_pred, list):
        xp = np
    else:
        xp = array_namespace(y_true, y_pred)
    y_true = xp.asarray(y_true)
    y_pred = xp.asarray(y_pred)

    if labels is None:
        # Classes seen on either side, so a class that is only ever
        # predicted still gets its own row/column.
        labels = xp.unique(xp.concatenate((y_true, y_pred)))
    else:
        labels = xp.asarray(labels)

    n_labels = labels.shape[0]
    cm = xp.zeros((n_labels, n_labels), dtype=xp.int32)
    if n_labels == 0:
        return cm

    # searchsorted needs sorted input; ``order`` maps positions in the
    # sorted labels back to the caller's label order.
    order = xp.argsort(labels)
    sorted_labels = labels[order]

    def _label_positions(values):
        """Return (index into ``labels``, found mask) for each value."""
        pos = xp.clip(xp.searchsorted(sorted_labels, values),
                      0, n_labels - 1)
        return order[pos], sorted_labels[pos] == values

    true_idx, true_found = _label_positions(y_true)
    pred_idx, pred_found = _label_positions(y_pred)

    # Only count samples whose true and predicted values are both known.
    keep = true_found & pred_found
    xp.add.at(cm, (true_idx[keep], pred_idx[keep]), 1)
    return cm
def nan_common_denom(array: LabeledArray, sort: bool = True,
                     trials_ax: int = 1, min_trials: int = 0,
                     ch_ax: int = 0, crop_trials: bool = True,
                     verbose: bool = False) -> LabeledArray:
    """Remove trials with NaNs from all channels.

    A trial counts as NaN for a channel if any value of that
    (channel, trial) pair is NaN. Channels with too few clean trials are
    dropped and, optionally, the trial axis is cropped to a common length.

    Parameters
    ----------
    array : LabeledArray
        The input array to process.
    sort : bool, optional
        Whether to sort trials by NaN presence (clean trials first),
        by default True.
    trials_ax : int, optional
        The axis containing trials, by default 1. May be negative.
    min_trials : int, optional
        Minimum number of clean trials a channel needs to be kept,
        by default 0.
    ch_ax : int, optional
        The axis containing channels, by default 0. May be negative.
    crop_trials : bool, optional
        Whether to crop the trial axis to
        ``max(lowest clean trial count, min_trials)``, by default True.
    verbose : bool, optional
        Whether to print verbose output, by default False.

    Returns
    -------
    LabeledArray
        The processed array with NaN trials removed.

    Raises
    ------
    ValueError
        If ``ch_ax`` and ``trials_ax`` refer to the same axis.

    Examples
    --------
    >>> import numpy as np
    >>> from ieeg.arrays.label import LabeledArray
    >>> data = np.array([[1, 2, np.nan], [4, 5, 6], [7, np.nan, 9]])
    >>> labels = [['ch1', 'ch2', 'ch3'], ['trial1', 'trial2', 'trial3']]
    >>> array = LabeledArray(data, labels)
    >>> processed_array = nan_common_denom(array, sort=True, trials_ax=1,
    ...     ch_ax=0, min_trials=3, crop_trials=True, verbose=True)
    Lowest trials 2 at ch1
    Channels excluded (too few trials): ['ch1', 'ch3']
    """
    ndim = array.ndim
    ch_ax, trials_ax = ch_ax % ndim, trials_ax % ndim
    if ch_ax == trials_ax:
        raise ValueError("ch_ax and trials_ax must be different axes")

    raw = array.__array__()
    others = tuple(i for i in range(ndim) if i not in (ch_ax, trials_ax))
    nan_trials = np.any(np.isnan(raw), axis=others)

    # After the reduction only the channel and trial axes remain, in their
    # original relative order.
    reduced_trials_ax = 1 if ch_ax < trials_ax else 0

    # Sort the trials by whether they are nan or not
    if sort:
        order = np.argsort(nan_trials, axis=reduced_trials_ax)
        remaining = list(order.shape)
        # Re-insert singleton axes so ``order`` broadcasts against ``raw``.
        new_shape = [remaining.pop(0) if i in (ch_ax, trials_ax) else 1
                     for i in range(ndim)]
        order = np.reshape(order, new_shape)
        data = np.take_along_axis(raw, order, axis=trials_ax)
        data = LabeledArray(data, array.labels.copy())
    else:
        data = array

    # Number of clean trials per channel
    ch_tnum = array.shape[trials_ax] - np.sum(nan_trials,
                                              axis=reduced_trials_ax)
    ch_min = ch_tnum.min()
    if verbose:
        print(f"Lowest trials {ch_min} at "
              f"{array.labels[ch_ax][ch_tnum.argmin()]}")

    ntrials = max(ch_min, min_trials)
    if ch_min < min_trials and verbose:
        ch = np.array(array.labels[ch_ax])[ch_tnum < ntrials].tolist()
        print(f"Channels excluded (too few trials): {ch}")

    idx = [np.arange(ntrials) if i == trials_ax and crop_trials
           else np.arange(s) for i, s in enumerate(array.shape)]
    idx[ch_ax] = np.where(ch_tnum >= ntrials)[0]

    return data[np.ix_(*idx)]
def sample_fold(train_idx: Array, test_idx: Array,
                x_data: Array, labels: Array, unique: Array,
                axis: int, oversample: bool, xp
                ) -> tuple[Array, Array, Array]:
    """Sample a fold of data for cross-validation.

    Parameters
    ----------
    train_idx : Array
        Indices of the training data.
    test_idx : Array
        Indices of the test data.
    x_data : Array
        The data to be sampled.
    labels : Array
        Labels corresponding to the data.
    unique : Array
        Unique labels to be used for oversampling.
    axis : int
        Axis of ``x_data`` that holds the observations. May be negative.
    oversample : bool
        Whether to oversample the training data. When True, training NaNs
        are filled per class with ``mixup`` and test NaNs are filled with
        standard-normal noise.
    xp : module
        The array namespace (numpy, cupy or torch).

    Returns
    -------
    tuple[Array, Array, Array]
        Stacked data (training observations first, then test
        observations), training labels, and test labels.
    """
    ndim = x_data.ndim
    axis = axis % ndim  # so the positional comparisons below also match -1

    def _along_axis(selector):
        """Index tuple applying ``selector`` on the observation axis."""
        return tuple(selector if i == axis else slice(None)
                     for i in range(ndim))

    # make first and only copy of x_data
    idx_stacked = xp.concatenate((train_idx, test_idx))
    x_stacked = x_data[_along_axis(idx_stacked)]
    y_stacked = labels[idx_stacked]

    # define train and test as views of x_stacked
    sep = train_idx.shape[0]
    y_train, y_test = y_stacked[:sep], y_stacked[sep:]

    if not oversample:
        return x_stacked, y_train, y_test

    x_train = x_stacked[_along_axis(slice(None, sep))]
    x_test = x_stacked[_along_axis(slice(sep, None))]
    torch_backend = is_torch(xp)

    for cat in unique:
        # fill in train data nans with random combinations of
        # existing train data trials (mixup)
        isin = y_train == cat
        selection = _along_axis(isin)
        out = x_train[selection]
        # ``shape`` is used because ``size`` is a method on torch tensors
        if 0 not in out.shape:
            mixup(out, axis)
        if torch_backend:
            mask_idx = tuple(slice(None) if j == axis else None
                             for j in range(ndim))
            x_train.masked_scatter_(isin[mask_idx], out)
        else:
            x_train[selection] = out

    # fill in test data nans with noise from distribution
    is_nan = xp.isnan(x_test)
    n_nan = int(xp.sum(is_nan))
    if torch_backend:
        noise = xp.randn(n_nan, dtype=x_data.dtype, device=x_data.device)
        x_test.masked_scatter_(is_nan, noise)
    else:
        x_test[is_nan] = xp.random.normal(0, 1, n_nan)

    return x_stacked, y_train, y_test
def flatten_features(arr: np.ndarray, obs_axs: int = -2) -> np.ndarray:
    """Collapse an array to two dimensions, observations first.

    The observation axis is exchanged with axis 0 and every remaining
    dimension is merged into a single feature dimension.

    Parameters
    ----------
    arr : np.ndarray
        The input array to flatten.
    obs_axs : int, optional
        The axis containing observations, by default -2.

    Returns
    -------
    np.ndarray
        The flattened array with shape (n_observations, n_features).

    Examples
    --------
    >>> import numpy as np
    >>> arr = np.arange(24).reshape(4, 3, 2)
    >>> flatten_features(arr, obs_axs=-2)
    array([[ 0,  1,  6,  7, 12, 13, 18, 19],
           [ 2,  3,  8,  9, 14, 15, 20, 21],
           [ 4,  5, 10, 11, 16, 17, 22, 23]])
    """
    obs_first = arr.swapaxes(0, obs_axs)
    n_obs = obs_first.shape[0]
    return obs_first.reshape(n_obs, -1)
def classes_from_labels(labels: np.ndarray, delim: str = '-',
                        which: int = 0, crop: slice = slice(None),
                        cats: dict = None
                        ) -> tuple[dict, np.ndarray]:
    """Extract class IDs from string labels.

    Each label is split on ``delim``; part number ``which`` (optionally
    sliced by ``crop``) becomes the class name of that label.

    Parameters
    ----------
    labels : np.ndarray
        Array of string labels to process.
    delim : str, optional
        Delimiter to split the labels, by default '-'.
    which : int, optional
        Which part of the split label to use, by default 0.
    crop : slice, optional
        Slice to apply to each label part, by default slice(None).
    cats : dict, optional
        Existing category mapping to use. If None, a new mapping is
        created that numbers the sorted unique class names from 0.

    Returns
    -------
    tuple[dict, np.ndarray]
        A tuple containing:
        - Dictionary mapping class names to indices (``cats`` itself when
          it was given)
        - Array of class indices corresponding to the input labels

    Raises
    ------
    KeyError
        If ``cats`` is given and a class name is missing from it.

    Examples
    --------
    >>> labels = np.array(['cat-dog', 'dog-cat', 'cat-bird'])
    >>> classes_from_labels(labels, delim='-')
    ({'cat': 0, 'dog': 1}, array([0, 1, 0]))
    """
    class_ids = np.array([k.split(delim)[which][crop] for k in labels])
    if cats is None:
        # ``tolist`` yields plain ``str`` keys rather than numpy scalars,
        # so the mapping prints and serializes like an ordinary dict.
        cats = {name: i
                for i, name in enumerate(np.unique(class_ids).tolist())}
    return cats, np.array([cats[k] for k in class_ids])
[docs] def flatten_list(nested_list: list[list[str] | str]) -> list[str]: """Flatten a nested list of strings. This function takes a list that may contain both strings and lists of strings, and returns a single flat list containing all the strings. Parameters ---------- nested_list : list[list[str] | str] A list containing strings and/or lists of strings. Returns ------- list[str] A flattened list containing all strings from the input. Examples -------- >>> flatten_list(['a', ['b', 'c'], 'd']) ['a', 'b', 'c', 'd'] """ flat_list = [] for element in nested_list: if isinstance(element, list): flat_list.extend(element) else: flat_list.append(element) return flat_list
def plot_all_scores(all_scores: dict[str, np.ndarray],
                    conds: list[str], idxs: dict[str, list[int]],
                    colors: list[list[float]], suptitle: str = None,
                    fig: plt.Figure = None, axs: plt.Axes = None,
                    ylims: tuple[float, float] = (0.1, 0.8),
                    **plot_kwargs
                    ) -> tuple[plt.Figure, plt.Axes]:
    """Plot scores for different conditions and categories.

    One axes is used per condition; within each axes one distribution is
    drawn per category via ``plot_dist``.

    Parameters
    ----------
    all_scores : dict[str, np.ndarray]
        Dictionary mapping ``"<category>-<condition>"`` keys to score
        arrays.
    conds : list[str]
        List of condition names to plot. A condition given as a list of
        strings is joined with '-'.
    idxs : dict[str, list[int]]
        Dictionary mapping category names to indices. Only the keys are
        used here.
    colors : list[list[float]]
        List of colors for each category.
    suptitle : str, optional
        Super title for the figure, by default None.
    fig : plt.Figure, optional
        Existing figure to plot on, by default None.
    axs : plt.Axes, optional
        Existing axes to plot on, by default None. A single axes or a
        sequence of axes is accepted.
    ylims : tuple[float, float], optional
        Y-axis limits, by default (0.1, 0.8).
    **plot_kwargs
        Additional keyword arguments passed to plot_dist.

    Returns
    -------
    tuple[plt.Figure, plt.Axes]
        The figure and axes objects containing the plots.

    Raises
    ------
    ValueError
        If a condition is not 'resp' and contains neither 'aud' nor 'go'.
    """
    names = list(idxs.keys())
    if axs is None:
        if fig is None:
            fig, axs = plt.subplots(1, len(conds))
        else:
            axs = fig.get_axes()
    # A lone Axes (e.g. from ``plt.subplots(1, 1)``) is not iterable; wrap
    # only that case so axes that are already a sequence stay usable.
    if isinstance(axs, plt.Axes):
        axs = [axs]
    if fig is None:
        # Only axes were supplied: recover their figure for the suptitle
        # and the return value.
        fig = axs[0].figure

    for color, name in zip(colors, names):
        for cond, ax in zip(conds, axs):
            if isinstance(cond, list):
                cond = "-".join(cond)
            ax.set_title(cond)
            if cond == 'resp':
                times = (-0.9, 0.9)
                ax.set_xlabel("Time from response (s)")
            else:
                times = (-0.4, 1.4)
                if 'aud' in cond:
                    ax.set_xlabel("Time from stim (s)")
                elif 'go' in cond:
                    ax.set_xlabel("Time from go (s)")
                else:
                    raise ValueError("Condition not recognized")

            scores = all_scores["-".join([name, cond])]
            # Keep the first axis, merge the rest, then transpose.
            pl_sc = np.reshape(scores, (scores.shape[0], -1)).T
            plot_dist(pl_sc, mode='std', times=times,
                      color=color, label=name, ax=ax,
                      **plot_kwargs)
            # The legend is drawn once every category has been added.
            if name == names[-1]:
                ax.legend()
            ax.set_ylim(*ylims)

    axs[0].set_ylabel("Accuracy (%)")
    if suptitle is not None:
        fig.suptitle(suptitle)
    return fig, axs