Source code for ieeg.calc.oversample

from typing import Literal, Tuple

import numpy as np
from numpy.typing import NDArray
from sklearn.model_selection import RepeatedStratifiedKFold

import itertools
from functools import partial
from ieeg.calc.fast import mixup, norm
from ieeg.arrays.api import array_namespace, is_numpy, intersect1d, setdiff1d
from decimal import Decimal

# Shape-flavoured array aliases intended for annotations; neither alias is
# referenced in the visible part of this module.
# NOTE(review): numpy.typing.NDArray is documented as generic over the scalar
# dtype, so these subscripts presumably do not constrain the array shape for a
# type checker -- confirm the intended meaning before relying on them.
Array2D = NDArray[Tuple[Literal[2], ...]]
Vector = NDArray[Literal[1]]


class MinimumNaNSplit(RepeatedStratifiedKFold):
    """A Repeated Stratified KFold iterator that guards against NaN-heavy folds

    Splits are generated by the parent ``RepeatedStratifiedKFold``. When the
    data contain NaN observations, every repetition (one group of
    ``n_splits`` folds) is inspected and thrown away if, in any of its folds,
    every class has fewer than ``min_non_nan`` non-NaN observations in the
    partition selected by ``which``. A rejected repetition is replaced by a
    freshly generated one.

    Parameters
    ----------
    n_splits : int
        The number of splits.
    n_repeats : int, optional
        The number of times to repeat the splits, by default 10.
    random_state : int, optional
        The random state to use, by default None.
    min_non_nan : int, optional
        The number of non-NaN observations a class needs in the checked
        partition for a fold to be accepted, by default 2.
    which : {'train', 'test'}, optional
        The partition of each fold whose non-NaN observations are counted,
        by default 'train'.

    Raises
    ------
    ValueError
        If ``which`` is neither 'train' nor 'test'.

    Examples
    --------
    >>> import numpy as np
    >>> np.random.seed(0)
    >>> X = np.vstack((np.arange(1, 9).reshape(4, 2), np.full((4, 2), np.nan)))
    >>> y = np.array([0, 0, 1, 1, 0, 0, 1, 1])
    >>> msn = MinimumNaNSplit(2, 3)
    >>> for train, test in msn.split(X, y):
    ...     print("train:", train, "test:", test)
    train: [2 3 4 5] test: [0 1 6 7]
    train: [0 1 6 7] test: [2 3 4 5]
    train: [2 3 4 5] test: [0 1 6 7]
    train: [0 1 6 7] test: [2 3 4 5]
    train: [2 3 4 5] test: [0 1 6 7]
    train: [0 1 6 7] test: [2 3 4 5]
    >>> msn = MinimumNaNSplit(2, 3, which='test', min_non_nan=1)
    >>> for train, test in msn.split(X, y):
    ...     print("train:", train, "test:", test)
    train: [1 3 4 7] test: [0 2 5 6]
    train: [0 2 5 6] test: [1 3 4 7]
    train: [0 3 5 7] test: [1 2 4 6]
    train: [1 2 4 6] test: [0 3 5 7]
    train: [1 2 5 6] test: [0 3 4 7]
    train: [0 3 4 7] test: [1 2 5 6]
    """

    def __init__(self, n_splits: int, n_repeats: int = 10,
                 random_state: int = None, min_non_nan: int = 2,
                 which: str = 'train'):
        super().__init__(n_splits=n_splits, n_repeats=n_repeats,
                         random_state=random_state)
        # kept as a plain attribute because split() and shuffle_labels()
        # read self.n_splits directly
        self.n_splits = n_splits
        self.min_non_nan = min_non_nan
        if which not in ('train', 'test'):
            raise ValueError("which must be either 'train' or 'test'")
        self.which = which

    def split(self, X, y=None, groups=None):
        """Yield (train, test) index pairs, skipping NaN-starved repetitions

        Parameters
        ----------
        X : array
            The data; an observation (index along axis 0) counts as NaN when
            any of its values is NaN.
        y : array
            The class labels used for stratification and for the per-class
            non-NaN count.
        groups : array, optional
            Passed through to the parent splitter.

        Yields
        ------
        train, test : array
            Index arrays for one fold.

        Raises
        ------
        ValueError
            If X contains NaN observations and fewer than
            ``min_non_nan + 1`` observations are NaN-free. Because this is a
            generator, the error surfaces on the first iteration.
        """
        xp = array_namespace(*[a for a in (X, y, groups) if a is not None])

        # find where the nans are (reduce over every axis except the first)
        where = xp.isnan(X).any(axis=tuple(range(X.ndim))[1:])
        not_where = xp.nonzero(~where)[0]
        where = xp.nonzero(where)[0]
        splits = self._splits(X, y, groups, xp)

        # if there are no nans, then just split the data
        if len(where) == 0:
            yield from splits
            return
        elif (n_non_nan := not_where.shape[0]) < \
                (n_min := self.min_non_nan + 1):
            raise ValueError(f"Need at least {n_min} non-nan values, but only"
                             f" have {n_non_nan}")

        # both selectors are driven by the *test* indices of a fold:
        # 'train' keeps the non-nan observations outside the test set,
        # 'test' keeps the non-nan observations inside it
        check = {'train': lambda t: setdiff1d(not_where, t, xp=xp,
                                              assume_unique=True),
                 'test': lambda t: intersect1d(not_where, t, xp=xp,
                                               assume_unique=True)}

        # indices of the observations belonging to each class
        idxs = [xp.nonzero(i == y)[0] for i in xp.unique(y)]
        kfold_set = [None] * self.n_splits

        # consume the stream one repetition (n_splits folds) at a time
        while element := next(splits, False):
            for i in range(self.n_splits):
                if i == 0:
                    train, test = element
                else:
                    train, test = next(splits)

                # reject the repetition when every class has fewer than
                # min_non_nan non-nan observations in the checked partition
                # NOTE(review): `all` means one adequately populated class is
                # enough to accept the fold -- confirm that `any` was not
                # intended.
                if all(intersect1d(check[self.which](test), cls_idx, xp=xp,
                                   assume_unique=True).shape[0]
                       < self.min_non_nan for cls_idx in idxs):
                    # drop the remaining folds of the rejected repetition
                    for _ in range(i + 1, self.n_splits):
                        next(splits)
                    # and queue one freshly generated repetition in its place
                    # NOTE(review): with an integer random_state the new
                    # generator presumably replays the same shuffles as the
                    # first one -- confirm this cannot loop forever.
                    extra = self._splits(X, y, groups, xp)
                    one_rep = itertools.islice(extra, self.n_splits)
                    splits = itertools.chain(one_rep, splits)
                    break
                kfold_set[i] = (train, test)
            else:
                # no fold triggered a rejection: emit the whole repetition
                yield from kfold_set

    def _splits(self, X, y, groups, xp):
        """Return the parent's split generator, converted to namespace xp"""
        splits = super().split(X, y, groups)
        if not is_numpy(xp):
            splits = ((xp.asarray(train), xp.asarray(test))
                      for train, test in splits)
        return splits

    @staticmethod
    def oversample(arr: np.ndarray, func: callable = mixup,
                   axis: int = 1, copy: bool = True, seed=None) -> np.ndarray:
        """Oversample nan rows using func

        Parameters
        ----------
        arr : array
            The data to oversample.
        func : callable
            The function to use to oversample the data. It is called as
            ``func(arr, axis)`` and is expected to modify arr in place.
        axis : int
            The axis along which to apply func.
        copy : bool
            Whether to copy the data before oversampling.
        seed : int, optional
            When not None it is forwarded as ``func(arr, axis, seed=seed)``,
            so func must accept a ``seed`` keyword in that case.

        Returns
        -------
        array
            The oversampled data (a copy of arr when ``copy`` is True,
            otherwise arr itself).

        Raises
        ------
        ValueError
            If arr is 0-dimensional.

        Examples
        --------
        >>> np.random.seed(0)
        >>> arr = np.array([[1, 2], [4, 5], [7, 8],
        ...                 [float("nan"), float("nan")]])
        >>> MinimumNaNSplit.oversample(arr, norm, 0)
        array([[1.        , 2.        ],
               [4.        , 5.        ],
               [7.        , 8.        ],
               [8.32102813, 5.98018098]])
        >>> MinimumNaNSplit.oversample(arr, mixup, 0, seed=42)
        ... # doctest: +SKIP
        array([[1.        , 2.        ],
               [4.        , 5.        ],
               [7.        , 8.        ],
               [5.24946679, 6.24946679]])
        """
        if copy:
            arr = arr.copy()
        # func receives a non-negative axis
        axis = arr.ndim + axis if axis < 0 else axis
        if arr.ndim <= 0:
            raise ValueError("Cannot apply func to a 0-dimensional array")
        elif seed is not None:
            func(arr, axis, seed=seed)
        else:
            func(arr, axis)
        return arr

    def shuffle_labels(self, arr: np.ndarray, labels: np.ndarray,
                       trials_ax: int = 0, min_trials: int = 1,
                       max_iter: int = 100000):
        """Shuffle the labels while making sure that the minimum non nan
        trials are kept

        The labels are shuffled in place, repeatedly, until every category
        retains at least ``min_trials * self.n_splits`` non-NaN trials.

        Parameters
        ----------
        arr : array
            The data the labels refer to.
            NOTE(review): NaNs are reduced over axis 2, so arr is presumably
            3-dimensional with trials on axis 0 or 1 -- confirm.
        labels : array
            The labels to shuffle (modified in place).
        trials_ax : int
            The axis of arr that indexes trials.
        min_trials : int
            The minimum number of non-nan trials to keep per split, by
            default 1. It is multiplied by ``self.n_splits`` before use.
        max_iter : int
            The number of shuffles to attempt before giving up, by default
            100000.

        Raises
        ------
        ValueError
            If a shuffle leaves fewer non-NaN trials in total than all
            categories need together, or if no satisfying shuffle is found
            within ``max_iter`` attempts.

        Examples
        --------
        >>> np.random.seed(0)
        >>> arr = np.array([[[1, 2], [4, 5], [7, 8],
        ...                  [float("nan"), float("nan")]]])
        >>> labels = np.array([0, 0, 1, 1])
        >>> MinimumNaNSplit(1).shuffle_labels(arr, labels, 1, 1)
        >>> labels
        array([1, 1, 0, 0])
        """
        xp = array_namespace(arr)
        cats = xp.unique(labels)
        # smallest non-nan trial count found for each category
        gt_labels = [0] * cats.shape[0]
        min_trials *= self.n_splits
        i = 0
        while not all(g >= min_trials for g in gt_labels):
            xp.random.shuffle(labels)
            for j, cat in enumerate(cats):
                eval_arr = xp.take(arr, xp.flatnonzero(labels == cat),
                                   trials_ax)
                # count the trials that are nan-free along axis 2, then keep
                # the worst case over the remaining axis
                gt_labels[j] = xp.min(xp.sum(
                    xp.all(~xp.isnan(eval_arr), axis=2), axis=trials_ax))
            # the total is too small for every category to reach its minimum
            if sum(gt_labels) < min_trials * cats.shape[0]:
                raise ValueError("Not enough non-nan trials to shuffle")
            i += 1
            if i > max_iter:
                raise ValueError("Could not find a shuffle that satisfies the"
                                 " minimum number of non-nan trials "
                                 f"{gt_labels}")
def oversample_nan(arr: np.ndarray, func: callable, axis: int = 1,
                   copy: bool = True, seed: int = None) -> np.ndarray:
    """Fill NaN rows of an array by applying an oversampling function

    Parameters
    ----------
    arr : array
        The data to oversample.
    func : callable
        The oversampling function, called as ``func(arr, axis)``.
    axis : int
        The axis along which to apply func; negative values count from the
        end.
    copy : bool
        Whether to work on a copy of arr instead of arr itself.
    seed : int
        Forwarded as ``seed=seed`` only when func is ``mixup`` and seed is
        not None.

    Returns
    -------
    array
        The array func was applied to.

    Examples
    --------
    >>> np.random.seed(0)
    >>> arr = np.array([[1, 2], [4, 5], [7, 8],
    ...                 [float("nan"), float("nan")]])
    >>> oversample_nan(arr, norm, 0)
    array([[1.        , 2.        ],
           [4.        , 5.        ],
           [7.        , 8.        ],
           [8.32102813, 5.98018098]])
    >>> oversample_nan(arr, mixup, 0, seed=42) # doctest: +SKIP
    array([[1.        , 2.        ],
           [4.        , 5.        ],
           [7.        , 8.        ],
           [5.24946679, 6.24946679]])
    >>> arr3 = np.arange(24, dtype=float).reshape(2, 3, 4)
    >>> arr3[0, 2, :] = [float("nan")] * 4
    >>> oversample_nan(arr3, mixup, 1, seed=42) # doctest: +SKIP
    array([[[ 0.        ,  1.        ,  2.        ,  3.        ],
            [ 4.        ,  5.        ,  6.        ,  7.        ],
            [ 2.33404428,  3.33404428,  4.33404428,  5.33404428]],
    <BLANKLINE>
           [[12.        , 13.        , 14.        , 15.        ],
            [16.        , 17.        , 18.        , 19.        ],
            [20.        , 21.        , 22.        , 23.        ]]])
    >>> oversample_nan(arr3, norm, 1)
    array([[[ 0.        ,  1.        ,  2.        ,  3.        ],
            [ 4.        ,  5.        ,  6.        ,  7.        ],
            [ 3.95747597,  7.4817864 ,  7.73511598,  3.04544424]],
    <BLANKLINE>
           [[12.        , 13.        , 14.        , 15.        ],
            [16.        , 17.        , 18.        , 19.        ],
            [20.        , 21.        , 22.        , 23.        ]]])
    """
    # guard clause: there is no axis to work along on a scalar array
    if arr.ndim <= 0:
        raise ValueError("Cannot apply func to a 0-dimensional array")

    target = arr.copy() if copy else arr
    if axis < 0:
        axis += target.ndim

    # only mixup is handed the seed
    extra = {"seed": seed} if (func is mixup and seed is not None) else {}
    func(target, axis, **extra)
    return target
def find_nan_indices(arr: np.ndarray, obs_axis: int) -> tuple:
    """Find the indices of rows with and without NaN values

    Parameters
    ----------
    arr : array
        The data to find indices.
    obs_axis : int
        The observation axis; a position along it counts as NaN when any
        value across all other axes is NaN. Negative values count from the
        end.

    Returns
    -------
    tuple
        A tuple of two arrays containing the indices of rows with and without
        NaN values.

    Examples
    --------
    >>> arr = np.array([[1, 2], [4, 5], [7, 8],
    ...                 [float("nan"), float("nan")]])
    >>> find_nan_indices(arr, 0) # doctest: +ELLIPSIS
    (array([3]... array([0, 1, 2]...)
    >>> find_nan_indices(arr, -2) # doctest: +ELLIPSIS
    (array([3]... array([0, 1, 2]...)
    """
    # a negative axis never equals an entry of range(arr.ndim), which would
    # make the reduction below collapse the observation axis as well
    if obs_axis < 0:
        obs_axis += arr.ndim

    # every axis except the observation axis is reduced
    not_obs = tuple(ax for ax in range(arr.ndim) if ax != obs_axis)
    nan = np.any(np.isnan(arr), axis=not_obs)

    return np.flatnonzero(nan), np.flatnonzero(~nan)
def sortbased_rand(n_range: int, iterations: int, n_picks: int = -1):
    """Generate random numbers using sort-based sampling, resulting in a
    random choice generation without replacement along the first axis.

    Parameters
    ----------
    n_range : int
        The range of numbers to sample from.
    iterations : int
        The number of iterations to run.
    n_picks : int
        The number of numbers to pick from the range. If -1, then the number
        of picks is equal to the range.

    Returns
    -------
    array
        An array of shape (iterations, n_picks) containing the random numbers.

    References
    ----------
    [1] `stackoverflow link <https://stackoverflow.com/questions/31955660/efficiently-generating-multiple-instances-of-numpy-random-choice-without-replace/31958263#31958263>`_

    Examples
    --------
    >>> np.random.seed(0)
    >>> sortbased_rand(10, 5, 3)
    array([[9, 4, 6],
           [6, 4, 5],
           [4, 6, 9],
           [4, 0, 2],
           [3, 7, 6]])
    >>> sortbased_rand(4, 2).shape
    (2, 4)
    """
    # -1 is the "pick everything" sentinel; used directly as a slice stop it
    # would silently drop the last column instead
    if n_picks == -1:
        n_picks = n_range

    # the argsort of i.i.d. uniform draws is a uniformly random permutation
    order = np.argsort(np.random.rand(iterations, n_range), axis=1)
    return order[:, :n_picks]
def mixup2(arr: np.ndarray, labels: np.ndarray, obs_axis: int,
           alpha: float = 1., seed: int = None, _isnan=None) -> None:
    """Mixup the data using the labels

    Every observation that contains a NaN is overwritten, in place, by a
    convex combination of a NaN-free observation with the same label and a
    NaN-free observation with any label.

    Parameters
    ----------
    arr : array
        The data to mixup (modified in place). The last axis left after
        moving ``obs_axis`` to position -2 is the one scanned for NaNs.
    labels : array
        The labels to use for mixing, one per observation.
    obs_axis : int
        The observation axis of arr; negative values count from the end.
    alpha : float
        The alpha value for the beta distribution.
    seed : int
        The seed for the random number generator. When given, the global
        NumPy generator is re-seeded for every 2-D slice that is processed,
        so all slices share the same draws.
    _isnan : array, optional
        Internal: the precomputed NaN mask handed down by the recursion.

    Examples
    --------
    >>> np.random.seed(0)
    >>> arr = np.array([[1, 2], [4, 5], [7, 8],
    ...                 [float("nan"), float("nan")]])
    >>> labels = np.array([0, 0, 1, 1])
    >>> mixup2(arr, labels, 0)
    >>> arr
    array([[1.        , 2.        ],
           [4.        , 5.        ],
           [7.        , 8.        ],
           [6.03943491, 7.03943491]])
    """
    if obs_axis < 0:
        obs_axis += arr.ndim

    if arr.ndim > 2:
        # bring the observations to the second-to-last axis *before* building
        # the mask, so mask and data always agree on the layout
        arr = arr.swapaxes(obs_axis, -2)
        isnan = np.isnan(arr).any(-1) if _isnan is None else _isnan
        for i in range(arr.shape[0]):
            if isnan[i].any():
                # in the slice the observations already sit on axis -2
                mixup2(arr[i], labels, -2, alpha, seed, isnan[i])
        return

    if seed is not None:
        np.random.seed(seed)
    if obs_axis == 1:
        # transposed view: writes below still land in the caller's array
        arr = arr.T
    isnan = np.isnan(arr).any(-1) if _isnan is None else _isnan

    n_nan = np.where(isnan)[0]
    n_non_nan = np.where(~isnan)[0]
    for i in n_nan:
        l_class = labels[i]
        possible_choices = np.nonzero(np.logical_and(
            ~isnan, labels == l_class))[0]
        # the order of the three draws is part of the seeded output
        choice1 = np.random.choice(possible_choices)
        choice2 = np.random.choice(n_non_nan)
        lam = np.random.beta(alpha, alpha)
        # keep the same-label observation as the dominant component
        if lam < .5:
            lam = 1 - lam
        arr[i] = lam * arr[choice1] + (1 - lam) * arr[choice2]
[docs] def resample(arr: np.ndarray, sfreq: int | float, new_sfreq: int | float, axis: int = -1) -> np.ndarray: """Resample an array through linear interpolation. Parameters ---------- arr : array The array to resample. sfreq : int The original sampling frequency. new_sfreq : int The new sampling frequency. Returns ------- resampled : array The resampled array. Examples -------- >>> import numpy as np >>> arr = np.arange(10) >>> resample(arr, 10, 5) array([0. , 2.25, 4.5 , 6.75, 9. ]) >>> resample(arr, 10.2, 5.1) array([0. , 2.25, 4.5 , 6.75, 9. ]) >>> resample(arr, 10, 20) array([0. , 0.47368421, 0.94736842, 1.42105263, 1.89473684, 2.36842105, 2.84210526, 3.31578947, 3.78947368, 4.26315789, 4.73684211, 5.21052632, 5.68421053, 6.15789474, 6.63157895, 7.10526316, 7.57894737, 8.05263158, 8.52631579, 9. ]) >>> resample(arr, 10, 7) array([0. , 1.5, 3. , 4.5, 6. , 7.5, 9. ]) >>> arr = np.arange(30).reshape(5, 6) >>> resample(arr, 6, 10) array([[ 0. , 0.55555556, 1.11111111, 1.66666667, 2.22222222, 2.77777778, 3.33333333, 3.88888889, 4.44444444, 5. ], [ 6. , 6.55555556, 7.11111111, 7.66666667, 8.22222222, 8.77777778, 9.33333333, 9.88888889, 10.44444444, 11. ], [12. , 12.55555556, 13.11111111, 13.66666667, 14.22222222, 14.77777778, 15.33333333, 15.88888889, 16.44444444, 17. ], [18. , 18.55555556, 19.11111111, 19.66666667, 20.22222222, 20.77777778, 21.33333333, 21.88888889, 22.44444444, 23. ], [24. , 24.55555556, 25.11111111, 25.66666667, 26.22222222, 26.77777778, 27.33333333, 27.88888889, 28.44444444, 29. 
]]) """ while axis < 0: axis += arr.ndim if sfreq == new_sfreq: return arr elif not sfreq % 1 == 0: num, denom = Decimal(str(sfreq)).as_integer_ratio() return resample(arr, num, new_sfreq * denom, axis) elif not new_sfreq % 1 == 0: num, denom = Decimal(str(new_sfreq)).as_integer_ratio() return resample(arr, sfreq * denom, num, axis) else: # Directly calculate the new sample points seconds = arr.shape[axis] / sfreq o_indices = np.arange(arr.shape[axis]) new_samps = int(round(new_sfreq * seconds)) indices = np.linspace(0, arr.shape[axis] - 1, new_samps) if arr.ndim == 1: return np.interp(indices, o_indices, arr) # for multi-dimensional arrays, we flatten non-axis dimensions, then # apply the 1d interpolation, then reshape func = partial(np.interp, indices, o_indices) arr_in = np.swapaxes(arr, axis, -1).reshape(-1, arr.shape[axis]) out_flat = np.apply_along_axis(func, 1, arr_in) out_shape = arr.shape[:axis] + (new_samps,) + arr.shape[axis + 1:] return out_flat.reshape(out_shape)