Source code for ieeg.process

import inspect
import operator
from itertools import chain
from os import environ
from typing import Generator, Iterable, TypeVar, Union

import numpy as np
import pandas as pd
from joblib import Parallel, cpu_count, delayed
from mne.utils import config, logger
from scipy.signal import get_window


[docs] def iterate_axes(arr: np.ndarray, axes: tuple[int, ...]): """Iterate over all possible indices for a set of axes Parameters ---------- arr : np.ndarray The array to iterate over axes : tuple[int] The axes to iterate over index : tuple[int] The current index axis : int The current axis Yields ------ tuple[slice] The indices for the current iteration Examples -------- >>> arr = np.arange(24).reshape(2, 3, 4) >>> for sl in iterate_axes(arr, (0, 1)): ... print(arr[sl], sl) [0 1 2 3] (0, 0, slice(None, None, None)) [4 5 6 7] (0, 1, slice(None, None, None)) [ 8 9 10 11] (0, 2, slice(None, None, None)) [12 13 14 15] (1, 0, slice(None, None, None)) [16 17 18 19] (1, 1, slice(None, None, None)) [20 21 22 23] (1, 2, slice(None, None, None)) >>> for sl in iterate_axes(arr, (0, 2)): ... print(arr[sl], sl) [0 4 8] (0, slice(None, None, None), 0) [1 5 9] (0, slice(None, None, None), 1) [ 2 6 10] (0, slice(None, None, None), 2) [ 3 7 11] (0, slice(None, None, None), 3) [12 16 20] (1, slice(None, None, None), 0) [13 17 21] (1, slice(None, None, None), 1) [14 18 22] (1, slice(None, None, None), 2) [15 19 23] (1, slice(None, None, None), 3) """ shape = [arr.shape[a] for a in axes] for idx in np.ndindex(*shape): slices = [slice(None)] * arr.ndim for axis, i in zip(axes, idx): slices[axis] = i yield tuple(slices)
[docs] def ensure_int(x, name: str = 'unknown', must_be: str = 'an int', *, extra=''): """Ensure a variable is an integer. Parameters ---------- x : object The object to check. name : str The name of the variable to check. must_be : str The type of the variable to check. extra : str Extra text to add to the error message. Notes ----- This is preferred over numbers.Integral, see: https://github.com/scipy/scipy/pull/7351#issuecomment-299713159 Examples -------- >>> ensure_int(1) 1 >>> ensure_int(1.0) Traceback (most recent call last): ... TypeError: unknown must be an int, got <class 'float'> >>> ensure_int('1') Traceback (most recent call last): ... TypeError: unknown must be an int, got <class 'str'> >>> ensure_int('1.0', extra='a string') Traceback (most recent call last): ... TypeError: unknown must be an int a string, got <class 'str'> """ extra = f' {extra}' if extra else extra try: # someone passing True/False is much more likely to be an error than # intentional usage if isinstance(x, bool): raise TypeError() x = int(operator.index(x)) except TypeError: raise TypeError(f'{name} must be {must_be}{extra}, got {type(x)}') return x
[docs] def validate_type(item, types): """Validate the type of an object. Parameters ---------- item : object The object to check. types : type The type to check against. """ try: if isinstance(types, TypeVar): check = isinstance(item, types.__constraints__) elif types is int: ensure_int(item) check = True elif types is float: check = is_number(item) else: check = isinstance(item, types) except TypeError: check = False if not check: raise TypeError( f"must be an instance of {types}, " f"got {type(item)} instead.")
[docs] def is_number(s) -> bool: """Check if an object is a number Parameters ---------- s : object The object to check Returns ------- bool True if the object is a number, False otherwise """ if isinstance(s, str): try: float(s) return True except ValueError: return False elif isinstance(s, (np.number, int, float)): return True elif isinstance(s, pd.DataFrame): try: s.astype(float) return True except Exception: return False elif isinstance(s, pd.Series): try: pd.to_numeric(s) return True except Exception: return False else: return False
[docs] def proc_array(func: callable, arr_in: np.ndarray, axes: int | tuple[int] = 0, n_jobs: int = None, desc: str = "Slices", inplace: bool = True, **kwargs) -> np.ndarray: """Execute a function in parallel over slices of an array Parameters ---------- func : callable The function to execute arr_in : np.ndarray The array to slice axes : int | tuple[int] The axes to slice over n_jobs : int The number of jobs to run in parallel desc : str The description to use for the progress bar inplace : bool Whether to modify the input array in place Returns ------- np.ndarray The output of the function, same shape as the input array Examples -------- >>> def square(x): ... return x ** 2 >>> proc_array(square, np.arange(10)) array([ 0, 1, 4, 9, 16, 25, 36, 49, 64, 81]) """ if isinstance(axes, int): axes = (axes,) if inplace: arr_out = arr_in else: arr_out = arr_in.copy() # Get the cross-section indices and array input generator cross_sect_ind = list(np.ndindex(*[arr_in.shape[axis] for axis in axes])) array_gen = list(arr_in[indices] for indices in cross_sect_ind) gen = Parallel(n_jobs, return_as='generator', verbose=40)( delayed(func)(x_, **kwargs) for x_ in array_gen) # Create process pool and apply the function in parallel for out, ind in zip(gen, cross_sect_ind): arr_out[ind] = out return arr_out
[docs] def parallelize(func: callable, ins: Iterable, verbose: int = 10, n_jobs: int = None, **kwargs) -> list | None: """Parallelize a function to run on multiple processors. This function is a wrapper for joblib.Parallel. It will automatically determine the number of jobs to run in parallel based on the number of cores available on the system. It will also automatically set the temp_folder and max_nbytes parameters for joblib.Parallel based on the MNE_CACHE_DIR and MNE_MEMMAP_MIN_SIZE parameters in mne-python's configuration file. Notes ----- If the elements of the par_var iterable are tuples, the function will be called with the tuple unpacked, setting each item in the tuple to be assigned to a separate argument. If the elements of the par_var iterable are not tuples, the function will be called with the element as the first argument. Parameters ---------- func : callable The function to parallelize ins : Iterable The iterable to parallelize over n_jobs : int The number of jobs to run in parallel. If None, will use all available cores. If -1, will use all available cores. **kwargs Additional keyword arguments to pass to the function Returns ------- list The output of the function for each element in par_var Examples -------- >>> def square(x): ... return x ** 2 >>> parallelize(square, [1, 2, 3]) [1, 4, 9] """ assert ins if n_jobs is None: if 'n_jobs' in inspect.getfullargspec(func).args: kwargs['n_jobs'] = -2 else: n_jobs = -2 settings = dict(verbose=verbose) settings['prefer'] = kwargs.pop('prefer', None) settings['backend'] = kwargs.pop('backend', None) settings['mmap_mode'] = kwargs.pop('mmap_mode', 'r') settings['require'] = kwargs.pop('require', None) env = dict(**environ) if config.get_config('MNE_CACHE_DIR') is not None: settings['temp_folder'] = config.get_config('MNE_CACHE_DIR') elif 'TEMP' in env.keys(): settings['temp_folder'] = env['TEMP'] else: settings['temp_folder'] = None if config.get_config('MNE_MEMMAP_MIN_SIZE') is not None: settings['max_nbytes'] = config.get_config('MNE_MEMMAP_MIN_SIZE') else: settings['max_nbytes'] = get_mem() for var in ins: if isinstance(var, tuple): x_is_tup = True elif isinstance(ins, Generator): x_is_tup = False ins = chain((var,), ins) else: x_is_tup = False break if x_is_tup: return Parallel(n_jobs, **settings)(delayed(func)( *x_, **kwargs) for x_ in ins) else: return Parallel(n_jobs, **settings)(delayed(func)( x_, **kwargs) for x_ in ins)
[docs] def get_mem(n_jobs: int = None) -> int: """Get the amount of memory to use for parallelization. Returns ------- float | int The amount of memory to use for parallelization """ from psutil import virtual_memory, swap_memory ram_per = virtual_memory().available + swap_memory().free if n_jobs is None or n_jobs == -1: return ram_per while n_jobs < 0: n_jobs += cpu_count() return ram_per // cpu_count() * n_jobs
[docs] def sliding_window(x_data: np.ndarray, labels: np.ndarray, scorer: callable, window_size: int = 20, axis: int = -1, n_jobs: int = -3, **kwargs) -> np.ndarray: """Compute a function over a sliding window. Parameters ---------- x_data : np.ndarray, shape (..., trials, time) The data to compute the function over labels : np.ndarray, shape (trials,) The labels for each trial scorer : callable The function to compute over the sliding window. Must take two arguments, the data and the labels. window_size : int The size of the sliding window axis : int The axis to compute the sliding window over n_jobs : int The number of jobs to run in parallel Returns ------- np.ndarray The output of the function, shape (..., time - window_size + 1) Examples -------- >>> def square(x, labels): ... return np.mean(x ** 2, where=labels == 1) >>> x_data = np.arange(40).reshape(4, 10) >>> labels = np.array([0, 1, 1]) >>> sliding_window(x_data, labels, square, window_size=3) array([397.5, 431.5, 467.5, 505.5, 545.5, 587.5, 631.5]) """ # make windowing generator axis = x_data.ndim + axis if axis < 0 else axis slices = (slice(start, start + window_size) for start in range(0, x_data.shape[axis] - window_size)) idxs = (tuple(slice(None) if i != axis else sl for i in range(x_data.ndim)) for sl in slices) # Use joblib to parallelize the computation gen = Parallel(n_jobs=n_jobs, return_as='generator', verbose=40)( delayed(scorer)(x_data[idx], labels, **kwargs) for idx in idxs) # initialize output array by running 1 job and get the shape mat = next(gen) out = np.zeros((x_data.shape[axis] - window_size, *mat.shape), dtype=mat.dtype) out[0] = mat # fill in the rest of the output array for i, mat in enumerate(gen): out[i + 1] = mat return out
############################################################################### # Constant overlap-add processing class def _check_store(store): if isinstance(store, np.ndarray): store = [store] if isinstance(store, (list, tuple)) and all(isinstance(s, np.ndarray) for s in store): store = _Storer(*store) if not callable(store): raise TypeError('store must be callable, got type %s' % (type(store),)) return store
[docs] class COLA: """Constant overlap-add processing helper. Parameters ---------- process : callable A function that takes a chunk of input data with shape `(n_channels, n_samples)` and processes it. store : callable | ndarray A function that takes a completed chunk of output data. Can also be an `ndarray`, in which case it is treated as the output data in which to store the results. n_total : int The total number of samples. n_samples : int The number of samples per window. n_overlap : int The overlap between windows. window : str The window to use. Default is "hann". tol : float The tolerance for COLA checking. n_jobs : int The number of jobs to run in parallel. verbose : bool If True, print a message when the COLA condition is not met. Notes ----- This will process data using overlapping windows to achieve a constant output value. For example, for ``n_total=27``, ``n_samples=10``, ``n_overlap=5`` and ``window='triang'`` :: 1 _____ _______ | \\ /\\ /\\ / | \\ / \\ / \\ / | x x x | / \\ / \\ / \\ | / \\/ \\/ \\ 0 +----|----|----|----|----|- 0 5 10 15 20 25 This produces four windows: the first three are the requested length (10 samples) and the last one is longer (12 samples). The first and last window are asymmetric. """ def __init__(self, process, store, n_total, n_samples, n_overlap, sfreq, window='hann', tol=1e-10, *, verbose=None): n_samples = ensure_int(n_samples, 'n_samples') n_overlap = ensure_int(n_overlap, 'n_overlap') n_total = ensure_int(n_total, 'n_total') if n_samples <= 0: raise ValueError('n_samples must be > 0, got %s' % (n_samples,)) if n_overlap < 0: raise ValueError('n_overlap must be >= 0, got %s' % (n_overlap,)) if n_total < 0: raise ValueError('n_total must be >= 0, got %s' % (n_total,)) self._n_samples = int(n_samples) self._n_overlap = int(n_overlap) del n_samples, n_overlap if n_total < self._n_samples: raise ValueError('Number of samples per window (%d) must be at ' 'most the total number of samples (%s)' % (self._n_samples, n_total)) if not callable(process): raise TypeError('process must be callable, got type %s' % (type(process),)) self._process = process self._step = self._n_samples - self._n_overlap self._store = _check_store(store) self._idx = 0 self._in_buffers = self._out_buffers = None # Create our window boundaries window_name = window if isinstance(window, str) else 'custom' self._window = get_window(window, self._n_samples, fftbins=(self._n_samples - 1) % 2) self._window /= _check_cola(self._window, self._n_samples, self._step, window_name, tol=tol) self.starts = np.arange(0, n_total - self._n_samples + 1, self._step) self.stops = self.starts + self._n_samples delta = n_total - self.stops[-1] self.stops[-1] = n_total sfreq = float(sfreq) pl = 's' if len(self.starts) != 1 else '' if verbose: logger.info(' Processing %4d data chunk%s of (at least) %0.1f ' 'sec with %0.1f sec overlap and %s windowing' % (len(self.starts), pl, self._n_samples / sfreq, self._n_overlap / sfreq, window_name)) del window, window_name if delta > 0 and verbose: logger.info(' The final %0.3f sec will be lumped into the ' 'final window' % (delta / sfreq,)) @property def _in_offset(self): """Compute from current processing window start and buffer len.""" return self.starts[self._idx] + self._in_buffers[0].shape[-1]
[docs] def feed(self, *datas, verbose=None, **kwargs): """Pass in a chunk of data.""" # Append to our input buffer if self._in_buffers is None: self._in_buffers = [None] * len(datas) if len(datas) != len(self._in_buffers): raise ValueError('Got %d array(s), needed %d' % (len(datas), len(self._in_buffers))) for di, data in enumerate(datas): if not isinstance(data, np.ndarray) or data.ndim < 1: raise TypeError('data entry %d must be an 2D ndarray, got %s' % (di, type(data),)) if self._in_buffers[di] is None: # In practice, users can give large chunks, so we use # dynamic allocation of the in buffer. We could save some # memory allocation by only ever processing max_len at once, # but this would increase code complexity. self._in_buffers[di] = np.empty( data.shape[:-1] + (0,), data.dtype) if data.shape[:-1] != self._in_buffers[di].shape[:-1] or \ self._in_buffers[di].dtype != data.dtype: raise TypeError('data must dtype %s and shape[:-1]==%s, ' 'got dtype %s shape[:-1]=%s' % (self._in_buffers[di].dtype, self._in_buffers[di].shape[:-1], data.dtype, data.shape[:-1])) # logger.debug(' + Appending %d->%d' # % (self._in_offset, self._in_offset + data.shape[ # -1])) self._in_buffers[di] = np.concatenate( [self._in_buffers[di], data], -1) if self._in_offset > self.stops[-1]: raise ValueError('data (shape %s) exceeded expected total ' 'buffer size (%s > %s)' % (data.shape, self._in_offset, self.stops[-1])) # preallocate data to chunks data_chunks = map(lambda x, y: data[x:y], self.starts, self.stops) out_chunks = map(lambda d: self._process(d, **kwargs), data_chunks) # overlap add to buffer while self._idx < len(self.starts) and \ self._in_offset >= self.stops[self._idx]: start, stop = self.starts[self._idx], self.stops[self._idx] this_len = stop - start this_window = self._window.copy() if self._idx == len(self.starts) - 1: this_window = np.pad( self._window, (0, this_len - len(this_window)), 'constant') for offset in range(self._step, len(this_window), self._step): n_use = len(this_window) - offset this_window[offset:] += self._window[:n_use] if self._idx == 0: for offset in range(self._n_samples - self._step, 0, -self._step): this_window[:offset] += self._window[-offset:] # logger.debug(' * Processing %d->%d' % (start, stop)) this_proc = [in_[..., :this_len].copy() for in_ in self._in_buffers] if not all(proc.shape[-1] == this_len == this_window.size for proc in this_proc): raise RuntimeError('internal indexing error') outs = next(out_chunks) if self._out_buffers is None: max_len = np.max(self.stops - self.starts) self._out_buffers = [np.zeros(o.shape[:-1] + (max_len,), o.dtype) for o in outs] for oi, out in enumerate(outs): out *= this_window self._out_buffers[oi][..., :stop - start] += out self._idx += 1 if self._idx < len(self.starts): next_start = self.starts[self._idx] else: next_start = self.stops[-1] delta = next_start - self.starts[self._idx - 1] for di in range(len(self._in_buffers)): self._in_buffers[di] = self._in_buffers[di][..., delta:] # logger.debug(' - Shifting input/output buffers by %d samples' # % (delta,)) self._store(*[o[..., :delta] for o in self._out_buffers]) for ob in self._out_buffers: ob[..., :-delta] = ob[..., delta:] ob[..., -delta:] = 0.
def _check_cola(win, nperseg, step, window_name, tol=1e-10): """Check whether the Constant OverLap Add (COLA) constraint is met.""" # adapted from SciPy binsums = np.sum([win[ii * step:(ii + 1) * step] for ii in range(nperseg // step)], axis=0) if nperseg % step != 0: binsums[:nperseg % step] += win[-(nperseg % step):] const = np.median(binsums) deviation = np.max(np.abs(binsums - const)) if deviation > tol: raise ValueError('segment length %d with step %d for %s window ' 'type does not provide a constant output ' '(%g%% deviation)' % (nperseg, step, window_name, 100 * deviation / const)) return const class _Storer(object): """Store data in chunks.""" def __init__(self, *outs, picks=None): for oi, out in enumerate(outs): if not isinstance(out, np.ndarray) or out.ndim < 1: raise TypeError('outs[oi] must be >= 1D ndarray, got %s' % (out,)) self.outs = outs self.idx = 0 self.picks = picks def __call__(self, *outs): if (len(outs) != len(self.outs) or not all(out.shape[-1] == outs[0].shape[-1] for out in outs)): raise ValueError('Bad outs') idx = (Ellipsis,) if self.picks is not None: idx += (self.picks,) stop = self.idx + outs[0].shape[-1] idx += (slice(self.idx, stop),) for o1, o2 in zip(self.outs, outs): o1[idx] = o2 self.idx = stop