from functools import singledispatch
import numpy as np
from mne import Epochs, Evoked
from mne.time_frequency import EpochsTFRArray
from mne.io import Raw, base
from tqdm import tqdm
from joblib import Parallel, delayed
from ieeg.process import COLA, cpu_count, get_mem, parallelize
from ieeg.timefreq.utils import BaseEpochs, Signal
from ieeg.timefreq.hilbert import (filterbank_hilbert_first_half_wrapper,
extract_channel_wrapper, get_centers)
@singledispatch
def hilbert_spectrogram(data: np.ndarray, fs: int, Wn=(1, 150),
                        decim: int = 1, spacing: float = 1/7, n_jobs=-1):
    """
    Compute a spectrogram of Hilbert envelopes over a bank of narrow bands.

    Center frequencies are obtained from ``get_centers(Wn, spacing)``. For
    every center ``c`` the narrow band ``(c - 0.01, c + 0.01)`` is handed to
    ``extract`` (one joblib task per band) and the resulting envelope is
    decimated along the last (time) axis. The filter bank follows
    [#edwards]_.

    See Also
    --------
    filter_hilbert

    Parameters
    ----------
    data : np.ndarray, shape (..., n_times)
        Signal to decompose; the last axis is time. The Epochs wrapper
        passes ``(n_epochs, n_channels, n_times)`` — other ranks are
        presumably fine as long as ``extract`` accepts them.
    fs : int
        Sampling rate of ``data``.
    Wn : tuple of 2 floats, default=(1, 150)
        Range passed to ``get_centers`` to choose the center frequencies.
    decim : int, default=1
        Keep every ``decim``-th time sample of each envelope. Must be >= 1.
    spacing : float, default=1/7
        Spacing forwarded to both ``get_centers`` and ``extract``.
    n_jobs : int, default=-1
        Number of joblib workers (bands are processed in parallel).

    Returns
    -------
    hilb_amp : np.ndarray, float32, shape (..., n_bands, n_out)
        Envelope of every band, with the band axis placed just before the
        decimated time axis. ``n_out == len(range(0, n_times, decim))``,
        i.e. the length of ``times[::decim]``.
    freqs : list of float
        Mean of each band's two edges (the band centers).

    Raises
    ------
    ValueError
        If ``decim`` is smaller than 1.
    """
    if decim < 1:
        raise ValueError(f'decim must be a positive integer, got {decim}')
    centers = get_centers(Wn, spacing)
    bands = [(c - 0.01, c + 0.01) for c in centers]

    # pre-allocate. ``out[..., ::decim]`` has ceil(n_times / decim) samples;
    # ``n_times // decim + 1`` would be one too many whenever decim divides
    # n_times (always, for decim=1) and the assignment would not broadcast.
    n_out = len(range(0, data.shape[-1], decim))
    hilb_amp = np.zeros(data.shape[:-1] + (n_out, len(bands)),
                        dtype='float32')

    # run in parallel, consuming results as they arrive
    proc = Parallel(n_jobs, verbose=10, return_as="generator")(
        delayed(extract)(data, fs, band, True, spacing, 1, False)
        for band in bands)
    for i, out in enumerate(proc):
        hilb_amp[..., i] = out[..., ::decim]

    freqs = [np.mean(band) for band in bands]
    # (..., n_out, n_bands) -> (..., n_bands, n_out); identical to
    # transpose(0, 1, 3, 2) for 3-D input but valid for any rank
    return np.moveaxis(hilb_amp, -1, -2), freqs
@hilbert_spectrogram.register
def _(inst: Epochs, Wn=(1, 150), decim: int = 1, spacing: float = 1/7,
      n_jobs=-1):
    """Compute a Hilbert-filterbank spectrogram of an Epochs object.

    The epoch data are handed to the ndarray implementation of
    ``hilbert_spectrogram``. The result is wrapped in an
    ``EpochsTFRArray`` together with the decimated time axis and the
    epochs' ``events`` / ``event_id``.

    Parameters
    ----------
    inst : Epochs
        Epochs whose data are decomposed.
    Wn, decim, spacing, n_jobs
        Same meaning as for the ndarray implementation.

    Returns
    -------
    EpochsTFRArray
        Built from ``inst.info``, the spectrogram, ``inst.times[::decim]``
        and the band frequencies.

    Notes
    -----
    NOTE(review): when ``decim > 1`` the times are decimated but
    ``inst.info`` (and its ``sfreq``) is passed through unchanged — confirm
    that ``EpochsTFRArray`` is fine with this.
    """
    epoch_data = inst.get_data(copy=False)
    amplitude, freqs = hilbert_spectrogram(
        epoch_data, inst.info['sfreq'], Wn=Wn, decim=decim,
        spacing=spacing, n_jobs=n_jobs)
    decimated_times = inst.times[::decim]
    return EpochsTFRArray(inst.info, amplitude, decimated_times, freqs,
                          events=inst.events, event_id=inst.event_id)
def _extract_inst(inst: Signal, fs: int, copy: bool, **kwargs) -> Signal:
    """Run ``extract`` on the ``_data`` buffer of an MNE-style container.

    Parameters
    ----------
    inst : Signal
        Container whose ``_data`` array is processed.
    fs : int | None
        Sampling rate; when None, ``inst.info['sfreq']`` is used.
    copy : bool
        If True, work on ``inst.copy()``; otherwise ``inst`` itself is
        modified in place.
    **kwargs
        Forwarded unchanged to ``extract``.

    Returns
    -------
    Signal
        The container that was processed (the copy, or ``inst`` itself),
        whose ``_data`` now holds the output of ``extract``.
    """
    sfreq = inst.info['sfreq'] if fs is None else fs
    target = inst.copy() if copy else inst
    # ``copy=False`` for the array: copying, if requested, already happened
    # at the container level above
    target._data = extract(target._data, sfreq, copy=False, **kwargs)
    return target
@extract.register
def _(inst: base.BaseRaw, fs: int = None,
      passband: tuple[int, int] = (70, 150),
      copy: bool = True, n_jobs=-1, verbose: bool = True) -> Raw:
    """Extract the gamma-band envelope of a Raw recording.

    Dispatch target of ``extract`` for ``BaseRaw``. The options are handed
    to ``_extract_inst``, which falls back to ``inst.info['sfreq']`` when
    ``fs`` is None and works on a copy when ``copy`` is True.
    """
    options = dict(passband=passband, n_jobs=n_jobs, verbose=verbose)
    return _extract_inst(inst, fs, copy, **options)
@extract.register
def _(inst: BaseEpochs, fs: int = None,
      passband: tuple[int, int] = (70, 150),
      copy: bool = True, n_jobs=-1, verbose: bool = True) -> Epochs:
    """Extract the gamma-band envelope of every epoch.

    Dispatch target of ``extract`` for ``BaseEpochs``. The options are
    handed to ``_extract_inst``, which falls back to ``inst.info['sfreq']``
    when ``fs`` is None and works on a copy when ``copy`` is True.
    """
    options = dict(passband=passband, n_jobs=n_jobs, verbose=verbose)
    return _extract_inst(inst, fs, copy, **options)
@extract.register
def _(inst: Evoked, fs: int = None,
      passband: tuple[int, int] = (70, 150),
      copy: bool = True, n_jobs=-1, verbose: bool = True) -> Evoked:
    """Extract the gamma-band envelope of an Evoked response.

    Dispatch target of ``extract`` for ``Evoked``. The options are handed
    to ``_extract_inst``, which falls back to ``inst.info['sfreq']`` when
    ``fs`` is None and works on a copy when ``copy`` is True.
    """
    options = dict(passband=passband, n_jobs=n_jobs, verbose=verbose)
    return _extract_inst(inst, fs, copy, **options)
def _my_hilt(x: np.ndarray, fs, Wn=(1, 150), n_jobs=-1, spacing=1./7):
    """Compute filterbank Hilbert envelopes chunk by chunk through ``COLA``.

    Each chunk is processed with ``filterbank_hilbert`` and the chunks are
    accumulated into a single output array.

    NOTE(review): this assumes ``COLA`` hands time-first chunks to
    ``process`` and ``store``, as implied by ``n_times = x.shape[0]`` and the
    transposes in ``store``. Confirm against ``ieeg.process.COLA``.

    Parameters
    ----------
    x : np.ndarray, shape (time, channels)
        Signal to process.
    fs : int
        Sampling rate.
    Wn : tuple of 2 floats, default=(1, 150)
        Range of filterbank center frequencies.
    n_jobs : int, default=-1
        Passed to ``COLA``.
    spacing : float, default=1/7
        Center-frequency spacing. It is used both for ``get_centers`` and
        for ``filterbank_hilbert`` so that ``cfs`` matches the filters
        actually applied.

    Returns
    -------
    x_out : np.ndarray, float32, shape (frequency_bins, channels, time)
        Accumulated envelopes, laid out like the transposed chunks.
    cfs
        Center frequencies from ``get_centers(Wn, spacing)``.
    """
    cfs = get_centers(Wn, spacing)
    n_times, n_chans = x.shape[0], x.shape[1]
    bytes_per = 16  # numpy complex128 is 16 bytes per num
    chunk_size = min(x.size * bytes_per * len(cfs), get_mem())
    # never ask COLA for an empty window when memory is scarce
    n_samples = max(1, int(chunk_size / (cpu_count() * n_chans * len(cfs))))
    n_overlap = (n_samples + 1) // 2
    # filterbank_hilbert yields (time, channels, freqs); store adds its
    # transpose, so the accumulator needs a leading frequency axis
    x_out = np.zeros((len(cfs), n_chans, n_times), dtype='float32')
    pos = 0

    def process(x_):
        # filterbank_hilbert returns a single envelope array. ``n_jobs`` must
        # be passed by keyword: the 4th positional parameter is ``spacing``.
        # Single-threaded here since COLA is given ``n_jobs`` itself.
        return filterbank_hilbert(x_, fs, Wn, spacing, n_jobs=1)

    def store(x_):
        nonlocal pos
        chunk = x_.T
        stop = pos + chunk.shape[-1]
        x_out[..., pos:stop] += chunk
        pos = stop

    COLA(process, store, n_times, n_samples, n_overlap, fs,
         n_jobs=n_jobs).feed(x)
    assert pos == n_times
    return x_out, cfs
def filterbank_hilbert(x, fs, Wn=(70, 150), spacing=1./7, n_jobs=1):
    """
    Compute the Hilbert envelope of a signal in every band of a filterbank.

    As in [#edwards]_, the signal is passed through a bank of gaussian
    shaped filters with center frequencies linearly spaced until 4Hz and
    then logarithmically spaced. The amplitude (envelope) of the Hilbert
    transform of each filter's output is returned. See [#edwards]_ for
    details on the filter bank used.

    See Also
    --------
    filter_hilbert

    Parameters
    ----------
    x : np.ndarray, shape (time, channels)
        Signal to filter. Filtering is performed on each channel
        independently. A float32 copy of ``x`` is filtered.
    fs : int
        Sampling rate.
    Wn : tuple or array-like, length 2, default=(70, 150)
        Lower and upper boundaries for filterbank center frequencies. A range
        of [1, 150] results in 42 filters. The lower bound must be strictly
        smaller than the upper bound.
    spacing : float, default=1/7
        Spacing of the filter center frequencies; smaller values give more
        filters (e.g. 1/10 yields 58 filters over [1, 150]).
    n_jobs : int, default=1
        Number of jobs to use to compute filterbank across channels in
        parallel. With ``1`` the channels are processed sequentially.

    Returns
    -------
    x_envelope : np.ndarray, float32, shape (time, channels, frequency_bins)
        Envelope of each frequency bin in the filter bank for each channel.

    Raises
    ------
    ValueError
        If ``Wn[0] >= Wn[1]``.

    Examples
    --------
    >>> import numpy as np
    >>> x = np.random.rand(1000,3) # 3 channels of signals
    >>> fs = 500
    >>> x_envelope = filterbank_hilbert(x, fs, Wn=[1, 150])
    ... # the output has the envelope for each channel and each
    ... # filter in the filterbank
    >>> x_envelope.shape # 3rd dimension is one for each filter in filterbank
    (1000, 3, 42)
    >>> filterbank_hilbert(x, fs, [1, 150], 1/10).shape
    (1000, 3, 58)
    """
    # validate before paying for the float32 copy of x
    minf, maxf = Wn
    if minf >= maxf:
        raise ValueError(
            f'Upper bound of frequency range must be greater than lower bound'
            f', but got lower bound of {minf} and upper bound of {maxf}')
    x = x.astype('float32')
    Xf, freqs, cfs, N, sds, h = filterbank_hilbert_first_half_wrapper(
        x, fs, minf, maxf, spacing)

    def extract_channel(Xf_chn):
        return extract_channel_wrapper(Xf_chn, freqs, cfs, N, sds, h,
                                       minf, maxf)

    n_chans = x.shape[1]
    # pre-allocate
    hilb_amp = np.zeros((*x.shape, len(cfs)), dtype='float32')
    if n_jobs == 1:
        # process channels sequentially
        for chn in range(n_chans):
            hilb_amp[:, chn] = extract_channel(Xf[:, chn])
    else:
        # process channels in parallel, streaming results into hilb_amp so
        # that all per-channel outputs are never held in memory at once
        results = Parallel(n_jobs, return_as="generator")(
            delayed(extract_channel)(Xf[:, chn]) for chn in range(n_chans))
        for chn, amp in enumerate(results):
            hilb_amp[:, chn] = amp
    return hilb_amp