"""Source code for ieeg.mt_filter: multitaper line-noise filtering."""

import argparse
from typing import Union

import numpy as np
from mne.io import pick
from mne.utils import fill_doc, logger, verbose

from ieeg import ListNum
from ieeg.process import proc_array
from ieeg.timefreq import utils as mt_utils
from ieeg.timefreq.multitaper import WindowingRemover


@fill_doc
@verbose
def line_filter(raw: mt_utils.Signal, fs: float = None,
                freqs: ListNum = 60.,
                filter_length: Union[str, int, None] = '10s',
                notch_widths: ListNum = 10.,
                mt_bandwidth: float = None, p_value: float = 0.05,
                picks: list[Union[int, str]] = None, n_jobs: int = None,
                adaptive: bool = True, low_bias: bool = True,
                copy: bool = True, *, verbose: Union[int, bool, str] = None
                ) -> mt_utils.Signal:
    """Apply a multitaper line noise notch filter for the signal instance.

    Applies a multitaper power line noise notch filter to the signal,
    operating on the last dimension. An F-test is used to find significant
    sinusoidal components to remove; the test and the way its threshold is
    corrected for multiple comparisons are implemented by
    :class:`ieeg.timefreq.multitaper.WindowingRemover`.

    Parameters
    ----------
    raw : mt_utils.Signal
        Signal to filter. If its data is not loaded in memory (``preload``
        is False) it is loaded before filtering.
    fs : float, optional
        Sampling rate in Hz. Default is taken from ``raw.info["sfreq"]``.
    freqs : float | array-like of float, optional
        Frequencies to notch filter in Hz, e.g. np.arange(60, 241, 60).
        If None, ``freqs`` and ``notch_widths`` are handed to the remover
        unchanged (see WindowingRemover for how it handles that case).
    filter_length : str | int | None, optional
        Length of the filter window. If str, assumed to be human-readable
        time in units of "s" or "ms" (e.g., "10s" or "5500ms"). If an int,
        it is assumed to be in samples and used directly. If None, the
        whole signal length is used. The result is capped at the signal
        length.
    notch_widths : float | array-like of float | None, optional
        Width of the stop band (centred at each freq in freqs) in Hz. A
        single value is used for every frequency; None uses
        ``freqs / 200``. Default is 10.
    mt_bandwidth : float, optional
        The bandwidth of the multitaper windowing function in Hz. Default
        will set the half frequency bandwidth to 4 Hz.
    p_value : float, optional
        P-value to use in F-test thresholding to determine significant
        sinusoidal components to remove. The threshold is corrected for
        multiple comparisons by the remover, so large p-values may be
        justified.
    %(picks_all)s
    %(n_jobs)s
    adaptive : bool, optional
        Use adaptive weights to combine the tapered spectra into PSD.
        Default is True.
    low_bias : bool, optional
        Only use tapers with more than 90 percent spectral concentration
        within bandwidth. Default is True.
    copy : bool, optional
        If True, a filtered copy of ``raw`` is returned. Otherwise ``raw``
        is modified in place (and loaded into memory if necessary).
    %(verbose)s

    Returns
    -------
    filt : mt_utils.Signal
        The signal instance with the filtered data.

    Raises
    ------
    ValueError
        If ``notch_widths`` contains a negative value or its length is
        neither 1 nor ``len(freqs)``, or if ``picks`` selects no data
        channels.

    See Also
    --------
    <https://mne.tools/stable/generated/mne.filter.notch_filter.html>

    Notes
    -----
    Only data channels (those whose type appears in
    ``get_channel_types(only_data_chs=True)``) are filtered; all other
    channels are left untouched. ``picks``, when given, further restricts
    filtering to the data channels it selects.

    The frequency response is (approximately) given by::

        1-|----------         -----------
          |          \\       /
          |           \\     /
          |            \\   /
          |             \\ /
        0-|              -
          |         |    |    |         |
          0        Fp1 freq  Fp2       Nyq

    For each freq in freqs, where ``Fp1 = freq - notch_widths / 2`` and
    ``Fp2 = freq + notch_widths / 2``.

    References
    ----------
    Multi-taper removal is inspired by code from the Chronux toolbox, see
    www.chronux.org and the book "Observed Brain Dynamics" by Partha Mitra
    & Hemant Bokil, Oxford University Press, New York, 2008. Please
    cite this in publications if this function is used.

    Examples
    --------
    >>> import mne
    >>> from bids import BIDSLayout
    >>> from ieeg.io import raw_from_layout
    >>> bids_root = mne.datasets.epilepsy_ecog.data_path(verbose=False)
    >>> layout = BIDSLayout(bids_root)
    >>> raw = raw_from_layout(layout, subject="pt1", preload=True,
    ... extension=".vhdr", verbose=False)
    Reading 0 ... 269079 = 0.000 ... 269.079 secs...
    >>> mne.set_log_level("WARNING")
    >>> filt = line_filter(raw, freqs=[60, 120, 180])
    >>> mne.set_log_level("INFO")
    """
    if fs is None:
        fs = raw.info["sfreq"]
    filt = raw.copy() if copy else raw

    # The result is written back into ``filt._data`` below, which only holds
    # the samples once they are loaded in memory.
    if not getattr(filt, "preload", True):
        filt.load_data()

    # Indices (in channel order) of the data channels; build the type set
    # once instead of once per channel.
    data_types = set(filt.get_channel_types(only_data_chs=True))
    data_idx = np.flatnonzero(
        [ch_t in data_types for ch_t in filt.get_channel_types()])

    # Resolve user picks (names, types or indices into the instance) and
    # keep only the data channels they select.
    if picks is not None:
        picked = pick._picks_to_idx(filt.info, picks, none="all",
                                    exclude=())
        data_idx = data_idx[np.isin(data_idx, picked)]
        if data_idx.size == 0:
            raise ValueError('picks did not select any data channels')

    x = filt.get_data(picks=data_idx)
    x = mt_utils._check_filterable(x, 'notch filtered', 'notch_filter')

    if freqs is not None:
        freqs = np.atleast_1d(freqs)
        # Only have to deal with notch_widths for non-autodetect
        if notch_widths is None:
            notch_widths = freqs / 200.0
        else:
            # Convert before comparing so list inputs work with ``< 0``.
            notch_widths = np.atleast_1d(notch_widths)
            if np.any(notch_widths < 0):
                raise ValueError('notch_widths must be >= 0')
            if len(notch_widths) == 1:
                notch_widths = notch_widths[0] * np.ones_like(freqs)
            elif len(notch_widths) != len(freqs):
                raise ValueError('notch_widths must be None, scalar, or the '
                                 'same length as freqs')

    # convert filter length to samples, never longer than the signal
    if filter_length is None:
        filter_length = x.shape[-1]
    filter_length = min(mt_utils.to_samples(filter_length, fs), x.shape[-1])

    process = WindowingRemover(fs, freqs, notch_widths, filter_length,
                               adaptive, low_bias, mt_bandwidth, p_value)
    # Channels sit on the second-to-last axis for both 2D (channels, times)
    # and 3D (epochs, channels, times) data, so index that axis explicitly.
    filt._data[..., data_idx, :] = mt_spectrum_proc(x, process, None, n_jobs)
    return filt
def mt_spectrum_proc(x: np.ndarray, process: callable, picks: list,
                     n_jobs: int) -> np.ndarray:
    """Apply ``process`` to the picked channels of ``x``.

    Parameters
    ----------
    x : np.ndarray
        Data whose last axis is time: 1D, 2D (channels, times) or 3D
        (epochs, channels, times).
    process : callable
        Per-channel function handed to ``proc_array`` (e.g. a
        ``WindowingRemover`` instance).
    picks : list | None
        Indices along the channel (second-to-last) axis to process. None
        processes every channel. Unpicked channels are left unchanged.
    n_jobs : int
        Number of parallel jobs, passed to ``proc_array``.

    Returns
    -------
    np.ndarray
        The processed data, reshaped back to the original shape of ``x``.
        ``proc_array`` is relied on to modify its input in place, so the
        caller's array may be modified as well.
    """
    # set up array for filtering, reshape to 2D, operate on last axis;
    # ``rows`` are the flattened row indices corresponding to ``picks``.
    x, orig_shape, rows = _prep_for_filtering(x, picks)
    if np.array_equal(rows, np.arange(x.shape[0])):
        # Every row is picked (the ``picks=None`` case): process in place.
        proc_array(process, x, n_jobs=n_jobs, desc="Channels")
    else:
        # Fancy indexing copies, so process the subset and write it back.
        subset = x[rows]
        proc_array(process, subset, n_jobs=n_jobs, desc="Channels")
        x[rows] = subset
    x.shape = orig_shape
    return x
def _prep_for_filtering(x: np.ndarray, picks: list = None ) -> tuple[np.ndarray, tuple, int]: """Set up array as 2D for filtering ease.""" x = mt_utils._check_filterable(x) orig_shape = x.shape x = np.atleast_2d(x) picks = pick._picks_to_idx(x.shape[-2], picks) x.shape = (np.prod(x.shape[:-1]), x.shape[-1]) if len(orig_shape) == 3: n_epochs, n_channels, n_times = orig_shape offset = np.repeat(np.arange(0, n_channels * n_epochs, n_channels), len(picks)) picks = np.tile(picks, n_epochs) + offset elif len(orig_shape) > 3: raise ValueError('picks argument is not supported for data with more' ' than three dimensions') assert all(0 <= pick < x.shape[0] for pick in picks) # guaranteed by above return x, orig_shape, picks def _get_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( formatter_class=argparse.RawDescriptionHelpFormatter, description=""" """, epilog=""" Made by Aaron Earle-Richardson (ae166@duke.edu) """) parser.add_argument("-s", "--subject", required=False, default=None, help="data subject to clean") return parser def _main(subject: str = None, save: bool = False): import mne from bids import BIDSLayout from ieeg.io import raw_from_layout, save_derivative import os HOME = os.path.expanduser("~") LAB_root = os.path.join(HOME, "Box", "CoganLab") # %% Set up logging mne.set_log_file("output.log", "%(levelname)s: %(message)s - %(asctime)s", overwrite=True) mne.set_log_level("INFO") bids_root = LAB_root + "/BIDS-1.3_SentenceRep/BIDS" layout = BIDSLayout(bids_root) if subject is not None: do_subj = [subject] else: do_subj = layout.get(return_type="id", target="subject") do_subj.sort() if 'SLURM_ARRAY_TASK_ID' in os.environ.keys(): taskIDs = [int(os.environ['SLURM_ARRAY_TASK_ID'])] else: taskIDs = list(range(len(do_subj))) for id in taskIDs: subj = do_subj[id] try: raw = raw_from_layout(layout, subject=subj, extension=".edf", preload=False) # %% filter data filt = line_filter(raw, mt_bandwidth=10., n_jobs=-1, filter_length='700ms', verbose=10, 
freqs=[60], notch_widths=20) filt2 = line_filter(filt, mt_bandwidth=10., n_jobs=-1, filter_length='20s', verbose=10, freqs=[60, 120, 180, 240], notch_widths=20) # %% Save the data if save: save_derivative(filt2, layout, "clean") except Exception as e: logger.error(e) if __name__ == "__main__": args = _get_parser().parse_args() _main(**vars(args))