# Source code for ieeg.timefreq.utils

from typing import Union

from mne.epochs import BaseEpochs
from mne.io import base
from mne.time_frequency import EpochsTFRArray
from mne.utils import fill_doc, verbose
from mne import Evoked

from ieeg import Signal
from ieeg.process import ensure_int, validate_type
from ieeg.calc.oversample import resample
from joblib import delayed, Parallel
import numpy as np
# from ieeg.process import parallelize
# from scipy.fft import fft, ifft, fftfreq
# from functools import partial
# from mne.evoked import Evoked
# from mne.time_frequency import tfr_array_stockwell as tas


def to_samples(time_length: Union[str, int], sfreq: float) -> int:
    """Convert a time length to a number of samples.

    Parameters
    ----------
    time_length : str | int
        The time length to convert. If a string, it must be a human-readable
        time made of a number followed by a unit suffix: ``"us"``
        (microseconds), ``"ms"`` (milliseconds), ``"s"``/``"sec"``
        (seconds) or ``"m"``/``"min"`` (minutes), e.g. ``"10s"`` or
        ``"700ms"``. The string is lower-cased before parsing. An int is
        treated as a number of samples and is passed to ``ensure_int``.
    sfreq : float
        The sampling frequency.

    Returns
    -------
    samples : int
        The number of samples. For strings, the value is rounded up and is
        never smaller than 1.

    Raises
    ------
    ValueError
        If ``time_length`` is a string that has no recognised unit suffix
        (including the empty string), or whose numeric part is not a finite
        number.
    """
    validate_type(time_length, (str, int))
    if isinstance(time_length, str):
        time_length = time_length.lower()
        err_msg = ('filter_length, if a string, must be a '
                   'human-readable time, e.g. "0.7s", or "700ms", not '
                   '"%s"' % time_length)
        # (suffix, seconds per unit). "us"/"ms" are tested before the bare
        # "s" so that e.g. "700ms" is read as milliseconds; "sec"/"min" are
        # listed explicitly because they do not end in "s"/"m".
        units = (('us', 1e-6), ('ms', 1e-3), ('sec', 1), ('s', 1),
                 ('min', 60), ('m', 60))
        for suffix, mult_fact in units:
            if time_length.endswith(suffix):
                number = time_length[:-len(suffix)]
                break
        else:
            # No recognised unit; this also covers the empty string.
            raise ValueError(err_msg)

        # now get the number
        try:
            value = float(number)
        except ValueError as err:
            raise ValueError(err_msg) from err
        # nan/inf would otherwise fail later inside int() with an unrelated
        # error, so reject them with the same message as other bad input.
        if not np.isfinite(value):
            raise ValueError(err_msg)
        # Round up so the requested duration is fully covered, and never
        # return zero samples.
        time_length = max(int(np.ceil(value * mult_fact * sfreq)), 1)
    return ensure_int(time_length, 'filter_length')
@fill_doc
def crop_pad(inst: Signal, pad: str, copy: bool = False) -> Signal:
    """Crop the same amount of time from both ends of an instance.

    Parameters
    ----------
    inst : instance of Raw, Epochs, Evoked or TFR
        The instance to crop. It must provide ``info['sfreq']``, ``times``
        and a ``crop(tmin, tmax)`` method.
    pad : str
        The amount of time to remove from *each* end, as a human-readable
        time accepted by ``to_samples``, e.g. "0.5s". It is rounded up to a
        whole number of samples at ``inst.info['sfreq']``.
    copy : bool, optional
        If True, a cropped copy of ``inst`` is returned and ``inst`` is left
        untouched. Otherwise ``inst`` is cropped in place and returned.
        Defaults to False.

    Returns
    -------
    inst : instance of Raw, Epochs, Evoked or TFR
        The cropped instance.
    """
    out = inst.copy() if copy else inst
    sfreq = inst.info['sfreq']
    pad_sec = to_samples(pad, sfreq) / sfreq
    # times[0]/times[-1] equal tmin/tmax for Epochs/Evoked/TFR; using the
    # time vector avoids relying on tmin/tmax attributes, which Raw objects
    # do not provide.
    out.crop(tmin=out.times[0] + pad_sec, tmax=out.times[-1] - pad_sec)
    return out
# @verbose # def cwt(inst: BaseEpochs, f_low: float, f_high: float, n_fft: int, # width: float = 1.0, n_jobs: int = 1, decim: int = 1, verbose=10 # ) -> EpochsTFRArray: # """Compute the wavelet scaleogram. # # # # Parameters # ---------- # inst : instance of Raw, Epochs, or Evoked # The instance to compute the wavelet scaleogram for. # f_low : float # The lowest frequency to compute the scaleogram for. # f_high : float # The highest frequency to compute the scaleogram for. # k0 : int # The wavelet parameter. # n_jobs : int # The number of jobs to run in parallel. # decim : int # The decimation factor. # verbose : int # The verbosity level. # # Returns # ------- # scaleogram : instance of EpochsTFR # The wavelet scaleogram. # # Notes # ----- # Similar to https://www.mathworks.com/help/wavelet/ref/cwt.html # # Examples # -------- # >>> import mne # >>> from ieeg.io import raw_from_layout # >>> from ieeg.navigate import trial_ieeg # >>> from bids import BIDSLayout # # >>> with mne.use_log_level(0): # >>> bids_root = mne.datasets.epilepsy_ecog.data_path() # >>> layout = BIDSLayout(bids_root) # >>> raw = raw_from_layout(layout, subject="pt1", preload=True, # ... extension=".vhdr", verbose=False) # >>> epochs = trial_ieeg(raw, ['AST1,3', 'G16'], (-1, 2), verbose=False) # >>> cwt(epochs, n_jobs=1, decim=10, n_fft=40) # doctest: +ELLIPSIS # Using data from preloaded Raw for 2 events and 3001 original time... # Getting epoch for 85800-88801 # Getting epoch for 90760-93761 # 0 bad epochs dropped # Data is self data: False # <TFR from Epochs, unknown method | 2 epochs × 98 channels × 46 freqs ... 
# # """ # # ins = ((ti, chi, a[None, None]) for ti, b in enumerate(inst) for chi, # a in enumerate(b)) # proc = Parallel(n_jobs=n_jobs, verbose=verbose, require="sharedmem", # return_as="generator_unordered")( # delayed(tas_wrap)(i, j, x, inst.info['sfreq'], f_low, f_high, n_fft, # width, decim, False, 1) for i, j, x in ins) # out0, _, freqs = tas(inst[0, 0], inst.info['sfreq'], f_low, f_high, # n_fft, # width, decim, False, 1) # # out = np.empty((len(inst), len(inst.ch_names), *out0.data.shape), # dtype=np.float32) # out[0, 0] = out0 # for ti, chi, (x, _, _) in proc: # out[ti, chi] = x # # return EpochsTFRArray(inst.info, out, inst.times[::decim], freqs) # # data = inst.get_data(copy=False) # # wave, _, freqs = tfr_array_stockwell(data, inst.info['sfreq'], f_low, # f_high, # # n_jobs=n_jobs, decim=decim, # return_itc=False, # # average=False, width=width, # verbose=verbose, # # n_fft=n_fft) # # # # return EpochsTFRArray(inst.info, wave, inst.times[::decim], freqs) # # # def tas_wrap(i, j, *args, **kwargs): # return i, j, tas(*args, **kwargs)
@verbose
def wavelet_scaleogram(inst: BaseEpochs, f_low: float = 2,
                       f_high: float = 1000, k0: int = 6, n_jobs: int = 1,
                       decim: int = 1, verbose=10) -> EpochsTFRArray:
    """Compute the wavelet scaleogram.

    The data are demeaned along time and transformed with an FFT. They are
    then multiplied in the frequency domain by the Morlet wavelets returned
    by ``calculate_wavelets`` and transformed back. The magnitude of the
    result is returned.

    Parameters
    ----------
    inst : instance of Epochs
        The epochs to compute the wavelet scaleogram for. ``inst.get_data``
        must return a 3-D array (n_epochs, n_channels, n_times).
        ``inst.events`` and ``inst.event_id`` are copied to the output.
    f_low : float
        The lowest frequency to compute the scaleogram for.
    f_high : float
        The highest frequency to compute the scaleogram for.
    k0 : int
        The wavelet parameter.
    n_jobs : int
        The number of jobs to run in parallel (one task per channel).
    decim : int
        The decimation factor applied to the output time axis.
    verbose : int
        The verbosity level, also forwarded to joblib's ``Parallel``.

    Returns
    -------
    scaleogram : instance of EpochsTFRArray
        The wavelet amplitude, shaped (n_epochs, n_channels, n_freqs,
        n_decimated_times), with frequencies ``1 / period``.

    Raises
    ------
    ValueError
        If the data returned by ``inst.get_data`` is not 3-D.

    Notes
    -----
    Similar to https://www.mathworks.com/help/wavelet/ref/cwt.html

    Examples
    --------
    >>> import mne
    >>> from ieeg.io import raw_from_layout
    >>> from ieeg.navigate import trial_ieeg
    >>> from bids import BIDSLayout
    >>> bids_root = mne.datasets.epilepsy_ecog.data_path()
    >>> layout = BIDSLayout(bids_root)
    >>> raw = raw_from_layout(layout, subject="pt1", preload=True,
    ... extension=".vhdr", verbose=False)
    Reading 0 ... 269079 = 0.000 ... 269.079 secs...
    >>> epochs = trial_ieeg(raw, ['AST1,3', 'G16'], (-1, 2), verbose=False)
    >>> wavelet_scaleogram(epochs, n_jobs=1, decim=10) # doctest: +ELLIPSIS
    Using data from preloaded Raw for 2 events and 3001 original time points...
    Getting epoch for 85800-88801
    Getting epoch for 90760-93761
    0 bad epochs dropped
    Data is self data: False
    <TFR from Epochs, unknown method | 2 epochs × 98 channels × 46 freqs × ...
    """
    data = inst.get_data(copy=False)
    if data.ndim != 3:
        raise ValueError('wavelet_scaleogram expects epoched data shaped '
                         '(n_epochs, n_channels, n_times), got shape '
                         f'{data.shape}')
    n_epochs, n_channels, n_times = data.shape

    # Remove the per-trace mean before moving to the frequency domain.
    f = np.fft.fft(data - np.mean(data, axis=-1, keepdims=True))
    daughter, period = calculate_wavelets(inst.info['sfreq'], f_high, f_low,
                                          n_times, k0)

    # Output layout: epochs x channels x frequencies x decimated time.
    wave = np.empty((n_epochs, n_channels, len(period),
                     data[..., ::decim].shape[-1]), dtype=np.float64)

    def _ifft_abs(x, i):
        # x is (n_epochs, 1, n_times); broadcasting against daughter
        # (n_freqs, n_times) gives (n_epochs, n_freqs, n_times) without
        # materialising a tiled copy of the wavelets for every channel.
        # The result is written straight into the shared ``wave`` buffer.
        np.abs(np.fft.ifft(x * daughter)[..., ::decim], out=wave[:, i])

    ins = ((f[:, None, i], i) for i in range(n_channels))
    proc = Parallel(n_jobs=n_jobs, verbose=verbose, require="sharedmem",
                    return_as="generator_unordered")(
        delayed(_ifft_abs)(x, i) for x, i in ins)
    # Drain the generator; results land in ``wave`` as a side effect.
    for _ in proc:
        pass

    return EpochsTFRArray(inst.info, wave, inst.times[::decim], 1 / period,
                          events=inst.events, event_id=inst.event_id)
def calculate_wavelets(sfreq: float, f_high: float, f_low: float,
                       n_samples: int, k0: int = 6):
    """Calculate Morlet wavelets for a range of frequencies.

    The wavelets are built directly in the frequency domain, following the
    Morlet daughter wavelet of Torrence & Compo (1998). The smallest scale is
    ``1 / (1.1 * f_high)``. Scales are spaced 0.2 octaves apart, and only
    those whose frequency (``1 / period``) is at least ``f_low`` are kept.

    Parameters
    ----------
    sfreq : float
        The sampling frequency.
    f_high : float
        The highest frequency to compute the scaleogram for.
    f_low : float
        The lowest frequency to compute the scaleogram for.
    n_samples : int
        The number of samples (FFT length) of the signal the wavelets will
        be applied to. Must be at least 2.
    k0 : int
        The wavelet parameter. Indicates the center frequency of the wavelet.

    Returns
    -------
    daughter : ndarray, shape (n_freqs, n_samples)
        The wavelets sampled at the frequencies of an ``n_samples``-point
        FFT, in the same order as ``numpy.fft.fft`` output (zero, positive,
        then negative frequencies). Rows go from lowest to highest frequency.
    period : ndarray, shape (n_freqs,)
        The Fourier period of each wavelet. ``1 / period`` are the
        corresponding frequencies, all ``>= f_low``.

    Examples
    --------
    >>> daughter, period = calculate_wavelets(1000, 100, 2, 1000)
    >>> daughter.shape
    (29, 1000)
    >>> period.shape
    (29,)
    """
    dt = 1 / sfreq
    n = n_samples
    dj = 0.2  # scale spacing in octaves
    s0 = 1 / (f_high + (0.1 * f_high))  # the smallest resolvable scale
    J1 = np.log2(n * dt / s0) / dj  # (J1 determines the largest scale)

    # Angular frequencies in FFT order. The positive half has fix(n / 2)
    # entries (including Nyquist for even n); the negative half has
    # fix((n - 1) / 2) entries, so len(k) == n for both even and odd n.
    n_pos = int(np.fix(n / 2))
    n_neg = int(np.fix((n - 1) / 2))
    k = np.arange(1, n_pos + 1) * ((2 * np.pi) / (n * dt))
    k = np.concatenate(([0.], k, -k[:n_neg][::-1]))

    scale = s0 * np.power(2., np.arange(0, J1) * dj)
    fourier_factor = (4 * np.pi) / (k0 + np.sqrt(2 + np.square(k0)))
    period = fourier_factor * scale

    # Frequencies decrease with scale: keep every scale before the first one
    # below f_low. If none is below f_low, keep them all rather than fail.
    too_low = np.flatnonzero((1. / period) < f_low)
    n_keep = too_low[0] if too_low.size else len(period)

    # Reverse so rows run from the largest scale (lowest frequency) up.
    scale = np.flip(scale[:n_keep])
    period = fourier_factor * scale

    # Morlet daughter in Fourier space; zero for k <= 0 (Heaviside step).
    expnt = -np.square(scale[:, None] * k[None, :] - k0) / 2. * (k > 0.)
    # Normalisation sqrt(s * dk) * pi^(-1/4) * sqrt(n) with dk the frequency
    # spacing 2*pi / (n * dt), which is k[1] (k[2] is twice the spacing).
    norm = np.sqrt(scale * k[1]) * (np.power(np.pi, (-0.25))) * np.sqrt(n)
    daughter = norm[:, None] * np.exp(expnt)
    daughter = daughter * (k > 0.)
    return daughter, period
def roundup(x: float) -> int:
    """Return the smallest integer that is not less than ``x`` (ceiling)."""
    whole, frac = divmod(x, 1)
    result = int(whole)
    # divmod(x, 1) yields a fractional part in [0, 1); any positive
    # remainder means x lies strictly above ``whole``.
    if frac > 0:
        result += 1
    return result
# def _check_input_st(x_in, n_fft): # """Aux function.""" # # flatten to 2 D and memorize original shape # n_times = x_in.shape[-1] # # def _is_power_of_two(n): # return not (n > 0 and (n & (n - 1))) # # if n_fft is None or (not _is_power_of_two(n_fft) and n_times > n_fft): # # Compute next power of 2 # n_fft = 2 ** int(np.ceil(np.log2(n_times))) # elif n_fft < n_times: # raise ValueError( # f"n_fft cannot be smaller than signal size. Got {n_fft} < # {n_times}." # ) # if n_times < n_fft: # # logger.info( # # f'The input signal is shorter ({x_in.shape[-1]}) than "n_fft" # ({n_fft}). ' # # "Applying zero padding." # # ) # zero_pad = n_fft - n_times # pad_array = np.zeros(x_in.shape[:-1] + (zero_pad,), x_in.dtype) # x_in = np.concatenate((x_in, pad_array), axis=-1) # else: # zero_pad = 0 # return x_in, n_fft, zero_pad # # # def _precompute_st_windows(n_samp, start_f, stop_f, sfreq, width): # """Precompute stockwell Gaussian windows (in the freq domain).""" # tw = fftfreq(n_samp, 1.0 / sfreq) / n_samp # tw = np.r_[tw[:1], tw[1:][::-1]] # # k = width # 1 for classical stowckwell transform # f_range = np.arange(start_f, stop_f, 1) # windows = np.empty((len(f_range), len(tw)), dtype=np.complex128) # for i_f, f in enumerate(f_range): # if f == 0.0: # window = np.ones(len(tw)) # else: # window = (f / (np.sqrt(2.0 * np.pi) * k)) * np.exp( # -0.5 * (1.0 / k**2.0) * (f**2.0) * tw**2.0 # ) # window /= window.sum() # normalisation # windows[i_f] = fft(window) # return windows # # # def _st_power_itc(x, start_f, compute_itc, zero_pad, decim, W, average): # """Aux function.""" # decim = slice(None, None, decim) # n_samp = x.shape[-1] # decim_indices = decim.indices(n_samp - zero_pad) # n_out = len(range(*decim_indices)) # out_shape = (len(W), n_out) if average else (x.shape[0], len(W), n_out) # psd = np.empty(out_shape) # itc = np.empty_like(psd) if compute_itc else None # X = fft(x) # XX = np.concatenate([X, X], axis=-1) # for i_f, window in enumerate(W): # f = start_f + i_f # ST 
= ifft(XX[:, f : f + n_samp] * window) # TFR = ST[:, slice(*decim_indices)] # TFR_abs = np.abs(TFR) # TFR_abs[TFR_abs == 0] = 1.0 # if compute_itc: # TFR /= TFR_abs # itc[i_f] = np.abs(np.mean(TFR, axis=0)) # TFR_abs *= TFR_abs # if average: # psd[i_f] = np.mean(TFR_abs, axis=0) # else: # psd[..., i_f, :] = TFR_abs # return psd, itc # # # def _compute_freqs_st(fmin, fmax, n_fft, sfreq): # """Compute the frequencies for the Stockwell transform. # # Parameters # ---------- # fmin # fmax # n_fft # sfreq # # Returns # ------- # start_f # stop_f # freqs # # Examples # -------- # >>> _compute_freqs_st(30, 500, 200, 2048) # # """ # from scipy.fft import fftfreq # # freqs = fftfreq(n_fft, 1.0 / sfreq) # if fmin is None: # fmin = freqs[freqs > 0][0] # if fmax is None: # fmax = freqs.max() # # start_f = np.abs(freqs - fmin).argmin() # stop_f = np.abs(freqs - fmax).argmin() # freqs = freqs[start_f:stop_f] # return start_f, stop_f, freqs # # # def compute_freqs_st(fmin, fmax, n_fft, sfreq): # """Compute the frequencies for the Stockwell transform. # # Parameters # ---------- # fmin # fmax # n_fft # sfreq # # Returns # ------- # start_f # stop_f # freqs # # Examples # -------- # >>> compute_freqs_st(30, 500, 40, 2048) # # """ # start_f, stop_f, freqs = _compute_freqs_st(fmin, fmax, n_fft, sfreq) # temp = n_fft # while stop_f - start_f < n_fft: # temp += 1 # start_f, stop_f, freqs = _compute_freqs_st(fmin, fmax, temp, sfreq) # return start_f, stop_f, freqs # # # @verbose # def tfr_array_stockwell( # data, # sfreq, # fmin=None, # fmax=None, # n_fft=None, # width=1.0, # decim=1, # return_itc=False, # n_jobs=None, # average=True, # *, # verbose=None, # ): # """Compute power and intertrial coherence using Stockwell (S) transform. # # Same computation as `~mne.time_frequency.tfr_stockwell`, but operates on # :class:`NumPy arrays <numpy.ndarray>` instead of `~mne.Epochs` objects. # # See :footcite:`Stockwell2007,MoukademEtAl2014,WheatEtAl2010,JonesEtAl2006 # for more information. 
# # Parameters # ---------- # data : ndarray, shape (n_epochs, n_channels, n_times) # The signal to transform. # sfreq : float # The sampling frequency. # fmin : None, float # The minimum frequency to include. If None defaults to the minimum fft # frequency greater than zero. # fmax : None, float # The maximum frequency to include. If None defaults to the maximum fft # n_fft : int | None # The length of the windows used for FFT. If None, it defaults to the # next power of 2 larger than the signal length. # width : float # The width of the Gaussian window. If < 1, increased temporal # resolution, if > 1, increased frequency resolution. Defaults to 1. # (classical S-Transform). # %(decim_tfr)s # return_itc : bool # Return intertrial coherence (ITC) as well as averaged power. # %(n_jobs)s # %(verbose)s # # Returns # ------- # st_power : ndarray # The multitaper power of the Stockwell transformed data. # The last two dimensions are frequency and time. # itc : ndarray # The intertrial coherence. Only returned if return_itc is True. # freqs : ndarray # The frequencies. # # See Also # -------- # mne.time_frequency.tfr_stockwell # mne.time_frequency.tfr_multitaper # mne.time_frequency.tfr_array_multitaper # mne.time_frequency.tfr_morlet # mne.time_frequency.tfr_array_morlet # # References # ---------- # .. 
# footbibliography::
#     """
#     if data.ndim != 3:
#         raise ValueError(
#             "data must be 3D with shape (n_epochs, n_channels, n_times), "
#             f"got {data.shape}"
#         )
#
#     trials, n_channels, n_out = data[..., ::decim].shape
#     start_f, stop_f, freqs = compute_freqs_st(fmin, fmax, n_fft, sfreq)
#
#     W = _precompute_st_windows(data.shape[-1], start_f, stop_f, sfreq, width)
#     n_freq = stop_f - start_f
#     out_shape = (n_channels, n_freq, n_out) if average else \
#         (trials, n_channels, n_freq, n_out)
#     psd = np.empty(out_shape)
#     itc = np.empty(out_shape) if return_itc else None
#
#     ins = (data[:, c, :] for c in range(n_channels))
#     myfunc = partial(_st_power_itc, start_f=start_f, compute_itc=return_itc,
#                      zero_pad=0, decim=decim, W=W, average=average)
#     tfrs = Parallel(n_jobs=n_jobs, verbose=verbose, return_as='generator')(
#         delayed(myfunc)(x) for x in ins)
#     for c, (this_psd, this_itc) in enumerate(iter(tfrs)):
#         psd[..., c, :, :] = this_psd
#         if this_itc is not None:
#             itc[c] = this_itc
#
#     return psd, itc, freqs
#


def _check_filterable(x: Union[Signal, np.ndarray], kind: str = 'filtered',
                      alternative: str = 'filter') -> np.ndarray:
    """Validate input for low-level array routines and return it as ndarray.

    Parameters
    ----------
    x : ndarray | list | tuple
        The data to validate. MNE Raw, Epochs and Evoked instances are
        rejected with a hint towards the equivalent instance method.
    kind : str
        Past participle used in the error messages, e.g. ``'filtered'``.
    alternative : str
        Name of the instance method suggested in the TypeError message.

    Returns
    -------
    x : ndarray
        ``x`` converted with ``np.asanyarray``; its dtype is float64.

    Raises
    ------
    TypeError
        If ``x`` is an MNE Raw, Epochs or Evoked instance.
    ValueError
        If the converted array's dtype is not float64.
    """
    # Only plain array-likes are accepted here. Users holding an MNE
    # container can coerce to ndarray themselves; point them at the
    # matching instance method rather than silently converting the object.
    if isinstance(x, (base.BaseRaw, BaseEpochs, Evoked)):
        try:
            type_name = x.__class__.__name__
        except Exception:
            type_name = None
        if type_name is not None:
            raise TypeError(
                'This low-level function only operates on np.ndarray '
                f'instances. To get a {kind} {type_name} instance, use a '
                f'method like `inst_new = inst.copy().{alternative}(...)` '
                'instead.')
    validate_type(x, (np.ndarray, list, tuple))
    arr = np.asanyarray(x)
    if arr.dtype == np.float64:
        return arr
    raise ValueError('Data to be %s must be real floating, got %s'
                     % (kind, arr.dtype,))
def resample_tfr(tfr, sfreq, o_sfreq=None, copy=False):
    """Resample a TFR object to a new sampling frequency.

    Parameters
    ----------
    tfr : instance of TFR
        The object to resample. Both its data array (``tfr._data``) and its
        time vector are resampled along the last axis.
    sfreq : float
        The target sampling frequency.
    o_sfreq : float | None
        The sampling frequency the data currently has. If None,
        ``tfr.info['sfreq']`` is used.
    copy : bool
        If True, a copy is resampled and returned. Otherwise ``tfr`` is
        modified in place and returned.

    Returns
    -------
    tfr : instance of TFR
        The resampled object. ``info['sfreq']`` is set to ``sfreq`` and
        ``info['lowpass']`` is capped at ``sfreq / 2``.
    """
    out = tfr.copy() if copy else tfr
    src_sfreq = out.info["sfreq"] if o_sfreq is None else o_sfreq

    out._data = resample(out._data, src_sfreq, sfreq, axis=-1)

    # A missing lowpass is treated as unbounded before applying the new
    # Nyquist limit.
    prev_lowpass = out.info.get("lowpass")
    if prev_lowpass is None:
        prev_lowpass = np.inf
    with out.info._unlock():
        out.info["lowpass"] = min(prev_lowpass, sfreq / 2)
        out.info["sfreq"] = sfreq

    # Bring the time axis (and the attributes derived from it) in line
    # with the resampled data.
    out._set_times(resample(out.times, src_sfreq, sfreq, axis=-1))
    out._raw_times = out.times
    out._update_first_last()
    return out
if __name__ == "__main__":
    # Description: Produce spectrograms for each subject
    from ieeg.io import get_data, raw_from_layout
    from ieeg.calc.scaling import rescale
    from ieeg.viz.ensemble import chan_grid
    from ieeg.viz.parula import parula_map
    from ieeg.navigate import trial_ieeg, crop_empty_data, \
        outliers_to_nan
    import os

    # check if currently running a slurm job
    HOME = os.path.expanduser("~")
    if 'SLURM_ARRAY_TASK_ID' in os.environ:
        LAB_root = os.path.join(HOME, "workspace", "CoganLab")
        layout = get_data("SentenceRep", root=LAB_root)
        # An int is not iterable, so this must build a one-element list.
        # NOTE(review): assumes the array task id is the numeric part of a
        # subject id formatted like "D0005" -- confirm against the job script.
        task_id = int(os.environ['SLURM_ARRAY_TASK_ID'])
        subjects = [f"D{task_id:04d}"]
    else:  # if not then set box directory
        LAB_root = os.path.join(HOME, "Box", "CoganLab")
        layout = get_data("SentenceRep", root=LAB_root)
        subjects = layout.get(return_type="id", target="subject")

    data = dict()
    for sub in subjects:
        # Only subject D0005 is processed; every other subject is skipped.
        if sub != "D0005":
            continue
        # Load the data
        filt = raw_from_layout(layout.derivatives['notch'], subject=sub,
                               extension='.edf', desc='notch', preload=False)

        # Crop raw data to minimize processing time
        good = crop_empty_data(filt, ).copy()

        # good.info['bads'] = channel_outlier_marker(good, 3, 2)
        good.drop_channels(good.info['bads'])
        good.load_data()

        ch_type = filt.get_channel_types(only_data_chs=True)[0]
        good.set_eeg_reference(ref_channels="average", ch_type=ch_type)

        # Remove intermediates from mem
        # good.plot()

        # epoching and trial outlier removal
        save_dir = os.path.join(layout.root, 'derivatives', 'spec',
                                'wavelet_test', sub)
        os.makedirs(save_dir, exist_ok=True)

        for epoch, t, name in zip(
                ("Start", "Word/Response/LS"), ((-0.5, 0), (-1, 1)),
                ("base", "resp")):
            trials = trial_ieeg(good, epoch, t, preload=False)
            n_samps = trials.times.shape[-1]
            n_fft = 2 ** int(np.ceil(np.log2(n_samps)))
            # Padding per side, in seconds, that brings the epoch up to
            # n_fft samples: half of the missing samples, converted to time.
            offset = (n_fft - n_samps) / 2 / good.info['sfreq']
            times = [t[0] - offset, t[1] + offset]
            trials = trial_ieeg(good, epoch, times, preload=True)
            outliers_to_nan(trials, outliers=10)
            spec = wavelet_scaleogram(trials, n_jobs=-2, decim=int(
                good.info['sfreq'] / 100))
            # spec = cwt(trials, 30, 500, n_fft, 1, -1, 4)
            # NOTE(review): this crop is a fixed 0.5 s while the padding
            # added above is ``offset`` -- confirm the two are meant to agree.
            crop_pad(spec, "0.5s")
            # spec = spec.decimate(2, 1)
            del trials
            # if spec.sfreq % 100 == 0:
            #     factor = spec.sfreq // 100
            #     offset = len(spec.times) % factor
            #     spec = spec.decimate(factor, offset)
            # Current sample rate of the (decimated) time axis: N samples
            # span N - 1 sampling intervals.
            resample_tfr(spec, 100,
                         (spec.times.shape[0] - 1) / (spec.tmax - spec.tmin))
            if name == "base":
                base = spec.copy().crop(-0.5, 0)
            else:
                data[name] = spec

        # Plot the spectrogram
        result = rescale(data['resp'], base, mode='ratio', copy=True)
        avg = result.average(lambda x: np.nanmean(x, axis=0))
        chan_grid(avg, size=(20, 10), vlim=(-.5, 3.), cmap=parula_map)