from collections import Counter
from functools import cache, singledispatch
from typing import Union
import numpy as np
from mne import Epochs, event, events_from_annotations
from mne.epochs import BaseEpochs
from mne.io import Raw, base
from mne.time_frequency import AverageTFRArray, tfr_multitaper
from mne.utils import _pl, fill_doc, logger, verbose
from scipy import fft, signal, stats
from ieeg import ListNum
from ieeg.calc.scaling import rescale
from ieeg.calc.stats import sine_f_test
from ieeg.process import COLA, is_number
from ieeg.timefreq.utils import crop_pad, to_samples
[docs]
class WindowingRemover(object):
"""Removes windowing artifacts from data.
Parameters
----------
sfreq : float
The sampling frequency of the data.
line_freqs : list of float
The frequencies of the line noise.
notch_width : list of float
The notch widths for each line frequency.
filter_length : int
The length of the filter to use.
low_bias : bool
Whether to use a low bias filter.
adaptive : bool
Whether to use an adaptive filter.
bandwidth : float
The bandwidth of the multitaper windowing function.
p_value : float
The p-value to use in the F-test.
verbose : bool
Whether to print information.
"""
@verbose
def __init__(self, sfreq: float, line_freqs: ListNum,
notch_width: ListNum, filter_length: int, low_bias: bool,
adaptive: bool, bandwidth: float, p_value: float,
verbose: bool = None):
self.sfreq = sfreq
self.line_freqs = line_freqs
self.notch_width = notch_width
self.p_value = p_value
self.filter_length = filter_length
self.verbose = verbose
self.low_bias = low_bias
self.adaptive = adaptive
self.bandwidth = bandwidth
self.logger = logger
self.rm_freqs = list()
[docs]
def dpss_windows(self, N: int, half_nbw: float, Kmax: int, *,
sym: bool = True, norm: Union[int, str] = None
) -> tuple[np.ndarray, np.ndarray]:
"""Compute Discrete Prolate Spheroidal Sequences.
Will give of orders [0,Kmax-1] for a given frequency-spacing multiple
NW and sequence length N.
.. note:: Copied from NiTime.
Parameters
----------
N : int
Sequence length.
half_nbw : float
Standardized half bandwidth corresponding to 2 * half_bw = BW*f0
= BW*N/dt but with dt taken as 1.
Kmax : int
Number of DPSS windows to return is Kmax (orders 0 through Kmax-1).
sym : bool
Whether to generate a symmetric window (`True`, for filter design)
or a periodic window (`False`, for spectral analysis). Default is
`True`.
norm : 2 | 'approximate' | 'subsample' | None
Window normalization method. If ``'approximate'`` or
``'subsample'``, windows are normalized by the maximum, and a
correction scale-factor for even-length windows is applied either
using ``N**2/(N**2+half_nbw)`` ("approximate") or a FFT-based
subsample shift ("subsample"). ``2`` uses the L2 norm. ``None``
(the default) uses ``"approximate"`` when ``Kmax=None`` and ``2``
otherwise.
Returns
-------
v, e : tuple,
The v array contains DPSS windows shaped (Kmax, N).
e are the eigenvalues.
Notes
-----
Tridiagonal form of DPSS calculation (Slepian, 1978)
References
----------
David S. Slepian. Prolate spheroidal wave functions, fourier analysis,
and uncertainty-V: the discrete case. Bell System Technical Journal,
57(5):1371–1430, 1978. doi:10.1002/j.1538-7305.1978.tb02104.x.
"""
dpss, eigvals = signal.windows.dpss(
N, half_nbw, Kmax, sym=sym, norm=norm, return_ratios=True)
if self.low_bias:
idx = (eigvals > 0.9)
if not idx.any():
self.logger.warn('Could not properly use low_bias, keeping'
'lowest-bias taper')
idx = [np.argmax(eigvals)]
dpss, eigvals = dpss[idx], eigvals[idx]
assert len(dpss) > 0 # should never happen
assert dpss.shape[1] == N # old nitime bug
return dpss, eigvals
[docs]
@fill_doc
def params(self, n_times: int) -> tuple[np.ndarray, np.ndarray, bool]:
"""Triage windowing and multitaper parameters.
Parameters
----------
n_times : int
The number of time points.
Returns
-------
window_fun : array, shape=(n_tapers, n_times)
The window functions for each taper.
eigenvals : array, shape=(n_tapers,)
The eigenvalues for each taper.
adaptive : bool
Whether to use adaptive weights to combine the tapered spectra into
PSD
"""
# Compute standardized half-bandwidth
if isinstance(self.bandwidth, str):
self.logger.info(
' Using standard spectrum estimation with "%s" window'
% (self.bandwidth,))
window_fun = signal.get_window(self.bandwidth, n_times)[np.newaxis]
return window_fun, np.ones(1), False
if self.bandwidth is not None:
half_nbw = float(self.bandwidth) * n_times / self.sfreq
else:
half_nbw = 4.
if half_nbw < 0.5:
raise ValueError(
'bandwidth value %s yields a normalized bandwidth of %s < 0.5,'
' use a value of at least %s'
% (self.bandwidth, half_nbw, self.sfreq / n_times))
# Compute DPSS windows
n_tapers_max = int(np.floor(2 * half_nbw - 1))
window_fun, eigvals = self.dpss_windows(
n_times, half_nbw, n_tapers_max, sym=True)
self.logger.info(' Using multitaper spectrum estimation with %d DPSS '
'windows' % len(eigvals))
if self.adaptive and len(eigvals) < 3:
self.logger.warn('Not adaptively combining the spectral estimators'
' due to a low number of tapers (%s < 3).' % (
len(eigvals),))
self.adaptive = False
return window_fun, eigvals, self.adaptive
[docs]
@cache
def get_thresh(self, n_times: int = None) -> tuple[np.ndarray, float]:
"""Get the window function and threshold for given time points.
Parameters
----------
n_times : int | None
The number of time points. If None, the filter length will be used.
Returns
-------
window_fun : array, shape=(n_tapers, n_times)
The window functions for each taper
threshold : float
The threshold for the F-statistic.
"""
if n_times is None:
n_times = self.filter_length
# figure out what tapers to use
window_fun, _, _ = self.params(n_times)
# F-stat of 1-p point
threshold = stats.f.ppf(1 - self.p_value / n_times, 2,
2 * len(window_fun) - 2)
return window_fun, threshold
[docs]
def __call__(self, x: np.ndarray) -> np.ndarray:
"""Remove line frequencies from data using multitaper method."""
# Set default window function and threshold
window_fun, thresh = self.get_thresh()
n_times = x.shape[-1]
n_samples = window_fun.shape[1]
n_overlap = (n_samples + 1) // 2
x_out = np.zeros_like(x)
idx = [0]
# Define how to process a chunk of data
def process(x_):
window_fun, thresh = self.get_thresh()
out = _mt_remove(x_, self.sfreq, self.line_freqs, self.notch_width,
window_fun, thresh, self.get_thresh)
self.rm_freqs.append(out[1])
return (out[0],) # must return a tuple
# Define how to store a chunk of fully processed data (it's trivial)
def store(x_):
stop = idx[0] + x_.shape[-1]
x_out[..., idx[0]:stop] += x_
idx[0] = stop
COLA(process, store, n_times, n_samples, n_overlap, self.sfreq,
verbose=False).feed(x)
assert idx[0] == n_times
# report found frequencies, but do some sanitizing first by binning
# into 1 Hz bins
counts = Counter(sum((np.unique(np.round(ff)).tolist()
for f in self.rm_freqs for ff in f), list()))
kind = 'Detected' if self.line_freqs is None else 'Removed'
found_freqs = '\n'.join(f' {freq:6.2f} : '
f'{counts[freq]:4d} window{_pl(counts[freq])}'
for freq in sorted(counts)) or ' None'
self.logger.info(f'{kind} notch frequencies (Hz):\n{found_freqs}')
x = x_out
return x_out
def _mt_remove(x: np.ndarray, sfreq: float, line_freqs: ListNum,
notch_widths: ListNum, window_fun: np.ndarray,
threshold: float, get_thresh: callable,
) -> tuple[np.ndarray, list[float]]:
"""Use MT-spectrum to remove line frequencies.
Based on Chronux. If line_freqs is specified, all freqs within notch_width
of each line_freq is set to zero.
"""
assert x.ndim == 1
if x.shape[-1] != window_fun.shape[-1]:
window_fun, threshold = get_thresh(x.shape[-1])
# compute mt_spectrum (returning n_ch, n_tapers, n_freq)
x_p, freqs = spectra(x[np.newaxis, :], window_fun, sfreq)
f_stat, A = sine_f_test(window_fun, x_p)
# find frequencies to remove
indices = np.where(f_stat > threshold)[1]
# pdf = 1-stats.f.cdf(f_stat, 2, window_fun.shape[0]-2)
# indices = np.where(pdf < 1/x.shape[-1])[1]
# specify frequencies within indicated ranges
if line_freqs is not None and notch_widths is not None:
if not isinstance(notch_widths, (list, tuple)) and is_number(
notch_widths):
notch_widths = [notch_widths] * len(line_freqs)
ranges = [(freq - notch_width / 2, freq + notch_width / 2
) for freq, notch_width in zip(line_freqs, notch_widths)]
indices = [ind for ind in indices if any(
lower <= freqs[ind] <= upper for (lower, upper) in ranges)]
fits = list()
# make "time" vector
rads = 2 * np.pi * (np.arange(x.size) / float(sfreq))
for ind in indices:
c = 2 * A[0, ind]
fit = np.abs(c) * np.cos(freqs[ind] * rads + np.angle(c))
fits.append(fit)
if len(fits) == 0:
datafit = 0.0
else:
# fitted sinusoids are summed, and subtracted from data
datafit = np.sum(fits, axis=0)
return x - datafit, freqs[indices]
[docs]
def spectra(x: np.ndarray, dpss: np.ndarray, sfreq: float,
n_fft: int = None) -> tuple[np.ndarray, np.ndarray]:
"""Compute significant tapered spectra.
Parameters
----------
x : array, shape=(..., n_times)
Input signal
dpss : array, shape=(n_tapers, n_times)
The tapers
sfreq : float
The sampling frequency
n_fft : int | None
Length of the FFT. If None, the number of samples in the input signal
will be used.
Returns
-------
x_mt : array, shape=(..., n_tapers, n_times)
The tapered spectra
freqs : array
The frequency points in Hz of the spectra
"""
if n_fft is None:
n_fft = x.shape[-1] # round(sfreq * round(x.shape[-1]*pad_fact/sfreq))
# remove mean (do not use in-place subtraction as it may modify input x)
x = x - np.mean(x, axis=-1, keepdims=True)
# only keep positive frequencies
freqs = fft.rfftfreq(n_fft, 1. / sfreq)
# The following is equivalent to this, but uses less memory:
x_mt = fft.rfft(x[:, np.newaxis, :] * dpss, n=n_fft, workers=1)
# n_tapers = dpss.shape[0] if dpss.ndim > 1 else 1
# x_mt = np.zeros(x.shape[:-1] + (n_tapers, len(freqs)),
# dtype=np.complex128)
# for idx, sig in enumerate(x):
# x_mt[idx] = fft.rfft(sig[..., np.newaxis, :] * dpss, n=n_fft)
# Adjust DC and maybe Nyquist, depending on one-sided transform
x_mt[..., 0] /= np.sqrt(2.)
if n_fft % 2 == 0:
x_mt[..., -1] /= np.sqrt(2.)
return x_mt, freqs
[docs]
@fill_doc
@singledispatch
@verbose
def spectrogram(line: BaseEpochs, freqs: np.ndarray,
baseline: BaseEpochs = None, n_cycles: np.ndarray = None,
pad: str = "0s", correction: str = 'ratio',
verbose: int = None, **kwargs) -> AverageTFRArray:
"""Calculate the multitapered, baseline corrected spectrogram
Parameters
----------
line : Epochs
The data to be processed
%(freqs_tfr)s
baseline : Epochs
The baseline to be used for correction
%(n_cycles_tfr)s
pad : str
The amount of padding to be removed in the spectrogram
correction : str
The type of baseline correction to be used
%(verbose)s
Notes
-----
%(time_bandwidth_tfr_notes)s
Returns
-------
power : AverageTFR
The multitapered, baseline corrected spectrogram
"""
if n_cycles is None:
n_cycles = freqs / 2
power, itc = tfr_multitaper(line, freqs, n_cycles, verbose=verbose,
**kwargs)
# crop the padding off the spectral estimates
crop_pad(power, pad)
if baseline is None:
return power
# apply baseline correction
basepower, bitc = tfr_multitaper(baseline, freqs, n_cycles,
verbose=verbose, **kwargs)
crop_pad(basepower, pad)
# set output data
corrected_data = rescale(power._data, basepower._data, correction, axis=-1)
return AverageTFRArray(power.info, corrected_data, power.times, freqs,
nave=power.nave, comment=power.comment,
method=power.method)
@spectrogram.register
def _(line: base.BaseRaw, freqs: np.ndarray, line_event: str, tmin: float,
tmax: float, base_event: str = None, base_tmin: float = None,
base_tmax: float = None, n_cycles: np.ndarray = None, pad: str = "500ms",
correction: str = 'ratio', **kwargs) -> AverageTFRArray:
"""Calculate the multitapered, baseline corrected spectrogram
Parameters
----------
line : Raw
The data to be processed
freqs : array-like
The frequencies to be used in the spectrogram
line_event : str
The event to be used for the spectrogram
tmin : float
The start time of the spectrogram
tmax : float
The end time of the spectrogram
base_event : str
The event to be used for the baseline
base_tmin : float
The start time of the baseline
base_tmax : float
The end time of the baseline
n_cycles : array-like
The number of cycles to be used in the spectrogram
pad : str
The amount of padding to be used in the spectrogram
correction : str
The type of baseline correction to be used
Returns
-------
power : AverageTFR
The multitapered, baseline corrected spectrogram
"""
# determine the events
events, ids = events_from_annotations(line)
dat_ids = [ids[i] for i in event.match_event_names(ids, line_event)]
# pad the data
pad_secs = to_samples(pad, line.info['sfreq']) / line.info['sfreq']
# Epoch the data
data = Epochs(line, events, dat_ids, tmin - pad_secs,
tmax + pad_secs, baseline=None, preload=True)
# run baseline corrected version
if base_event is None:
return spectrogram(data, freqs, None, n_cycles, pad, correction,
**kwargs)
base_ids = [ids[i] for i in event.match_event_names(ids, base_event)]
baseline = Epochs(line, events, base_ids, base_tmin - pad_secs,
base_tmax + pad_secs, baseline=None, preload=True)
return spectrogram(data, freqs, baseline, n_cycles, pad, correction,
**kwargs)