# Source code for ieeg.calc.scaling

from functools import singledispatch

import numpy as np
from mne import Epochs
from mne.epochs import BaseEpochs
from mne.time_frequency import AverageTFR, EpochsTFR
from mne.utils import logger, verbose
from ieeg.calc.stats import dist


def _log_rescale(baseline, mode='mean'):
    """Log the rescaling method."""
    if baseline is not None:
        msg = 'Applying baseline correction (mode: %s)' % mode
    else:
        msg = 'No baseline correction applied'
    return msg


[docs] @singledispatch def rescale(data: np.ndarray, basedata: np.ndarray, mode: str = 'mean', copy: bool = False, axis: tuple[int] | int = -1) -> np.ndarray: """Rescale (baseline correct) data. Implement a variety of baseline correction methods. The data is modified in place by default. Parameters ---------- data : array | mne.Epochs | mne.EpochsTFR It can be of any shape. The only constraint is that the last dimension should be time. basedata : array It can be of any shape. The last dimension should be time, and the other dimensions should be the same as data. mode : 'mean' | 'ratio' | 'logratio' | 'percent' | 'zscore' | 'zlogratio',\ default 'mean', optional Perform baseline correction by - subtracting the mean of baseline values ('mean') - dividing by the mean of baseline values ('ratio') - dividing by the mean of baseline values and taking the log ('logratio') - subtracting the mean of baseline values followed by dividing by the mean of baseline values ('percent') - subtracting the mean of baseline values and dividing by the standard deviation of baseline values ('zscore') - dividing by the mean of baseline values, taking the log, and dividing by the standard deviation of log baseline values ('zlogratio') copy : bool, optional Whether to return a new instance or modify in place. axis : int or tuple[int], optional Returns ------- data_scaled: array Array of same shape as data after rescaling. """ if copy: data = data.copy() match mode: case 'mean': def fun(d, m, s): d -= m case 'ratio': def fun(d, m, s): d /= m case 'logratio': def fun(d, m, s): d /= m np.log10(d, out=d) case 'percent': def fun(d, m, s): d -= m d /= m case 'zscore': def fun(d, m, s): d -= m d /= s case 'zlogratio': def fun(d, m, s): d /= m np.log10(d, out=d) d /= s case _: raise NotImplementedError() mean, std = dist(basedata, axis=axis, mode='std', ddof=1, keepdims=True) fun(data, mean, std) return data
def _rescale_mne_instance(line, baseline, mode, copy, picks, verbose,
                          axes=None):
    """Baseline-correct the ``_data`` of an MNE container (shared helper).

    Parameters
    ----------
    line : mne.Epochs | mne.time_frequency.EpochsTFR | mne.time_frequency.AverageTFR
        Container whose data is rescaled.
    baseline : same type as ``line``
        Container providing the baseline data. It is left unmodified.
    mode : str
        Baseline correction mode, forwarded to the array ``rescale``.
    copy : bool
        If True, work on and return a copy of ``line``; otherwise ``line``
        is modified in place (including the channel selection by ``picks``).
    picks : str | list
        Channel selection forwarded to ``pick`` on both containers.
    verbose : bool | str | int | None
        Only ``False`` suppresses the log message emitted here.
    axes : int | tuple of int | None
        Axes of the baseline data reduced when computing the statistics.
        ``None`` means every axis except axis 1 (the channel axis).

    Returns
    -------
    line : same type as ``line``
        The rescaled container.
    """
    if copy:
        line = line.copy()
    if verbose is not False:
        logger.info(_log_rescale(baseline, mode))

    # ``pick`` operates in place, so select channels on a copy: the caller's
    # ``baseline`` must not silently lose channels (previously it did, even
    # with ``copy=True``).
    basedata = baseline.copy().pick(picks)._data
    if axes is None:
        # One baseline statistic per channel: pool over all other axes.
        axes = tuple(ax for ax in range(basedata.ndim) if ax != 1)

    line.pick(picks)
    # The array overload corrects ``line._data`` in place (copy=False).
    line._data = rescale(line._data, basedata, mode, False, axes)
    return line


@rescale.register
@verbose
def _(line: BaseEpochs, baseline: BaseEpochs, mode: str = 'mean',
      copy: bool = False, picks: str | list = 'data', verbose=None
      ) -> BaseEpochs:
    """Rescale (baseline correct) Epochs.

    Baseline statistics are computed per channel, pooling every other axis
    of the baseline data. See the array overload of ``rescale`` for the
    available ``mode`` values.
    """
    return _rescale_mne_instance(line, baseline, mode, copy, picks, verbose)


@rescale.register
@verbose
def _(line: EpochsTFR, baseline: EpochsTFR, mode: str = 'mean',
      copy: bool = False, picks: str | list = 'data', verbose=None
      ) -> EpochsTFR:
    """Rescale (baseline correct) EpochsTFR.

    Baseline statistics are pooled over axes 0 and 3 of the baseline data,
    i.e. computed within each channel and frequency.
    """
    return _rescale_mne_instance(line, baseline, mode, copy, picks, verbose,
                                 axes=(0, 3))


@rescale.register
@verbose
def _(line: AverageTFR, baseline: AverageTFR, mode: str = 'mean',
      copy: bool = False, picks: str | list = 'data', verbose=None
      ) -> AverageTFR:
    """Rescale (baseline correct) AverageTFR.

    Baseline statistics are pooled over axis 2 of the baseline data, i.e.
    computed within each channel and frequency.
    """
    return _rescale_mne_instance(line, baseline, mode, copy, picks, verbose,
                                 axes=2)