Source code for ieeg.timefreq.superlets
# Time-frequency analysis with superlets
# Based on 'Time-frequency super-resolution with superlets'
# by Moca et al., 2021 Nature Communications
#
# Implementation by Harald Bârzan and Richard Eugen Ardelean
#
# Note: when processing multiple batches of data, the SuperletTransform class
# can be instantiated just once and reused; this saves the time and memory
# needed to allocate the wavelets and buffers.
#
import numpy as np
from scipy.signal import fftconvolve
from ieeg import Signal
import mne
from joblib import Parallel, delayed
# Total width, in Gaussian standard deviations, of the sampled Morlet
# wavelet: computeWaveletSize() makes the wavelet about
# MORLET_SD_SPREAD * sd * fs samples long, and morlet() evaluates the
# Gaussian envelope over +/- MORLET_SD_SPREAD / 2 standard deviations.
MORLET_SD_SPREAD = 6
# Number of Gaussian standard deviations covered by half of the wavelet's
# cycles: computeWaveletSize() uses sd = (nc / 2) / |fc| / MORLET_SD_FACTOR,
# i.e. the nc cycles at fc span 2 * MORLET_SD_FACTOR standard deviations.
MORLET_SD_FACTOR = 2.5
[docs]
def computeWaveletSize(fc, nc, fs):
    """
    Return the length, in samples, of a Morlet wavelet.

    The Gaussian standard deviation is chosen so that half of the ``nc``
    cycles at ``fc`` span ``MORLET_SD_FACTOR`` standard deviations. The
    wavelet then covers ``MORLET_SD_SPREAD`` standard deviations, rounded to
    an odd number of samples (``2 * floor(n / 2) + 1``).

    Parameters
    ----------
    fc : float
        Center frequency in Hz (only its magnitude is used).
    nc : float
        Number of cycles.
    fs : float
        Sampling rate in Hz.

    Returns
    -------
    int
        Odd size of the wavelet in samples.
    """
    cycleHalfSpan = (nc / 2) * (1 / np.abs(fc))
    sigma = cycleHalfSpan / MORLET_SD_FACTOR
    supportSamples = np.round(sigma * fs * MORLET_SD_SPREAD)
    return int(2 * np.floor(supportSamples / 2) + 1)
[docs]
def computeLongestWaveletSize(fs, foi, c1, ord):
    """
    Estimate the size of the longest wavelet a superlet transform would build.

    Orders are interpolated linearly over ``foi`` exactly as in
    ``SuperletTransform`` (from ``ord[0]`` at ``min(foi)`` to ``ord[1]`` at
    ``max(foi)``). For each frequency, ``ceil(order)`` wavelets with
    ``(k + 1) * c1`` cycles are sized.

    Parameters
    ----------
    fs : float
        Sampling rate in Hz.
    foi : array_like
        Frequencies of interest in Hz.
    c1 : float
        Base number of cycles parameter.
    ord : tuple, list or scalar
        The order range ``(lo, hi)`` for superlets. A scalar or a
        single-element sequence is used as a constant order.

    Returns
    -------
    int
        Size of the longest wavelet in samples (0 if no wavelet is built).
    """
    foi = np.asarray(foi, dtype=np.float64)
    # Accept a scalar, a single-element sequence, or a (lo, hi) pair; the
    # previous `(ord, ord)` produced a tuple of lists that np.interp rejects.
    orderRange = np.atleast_1d(np.asarray(ord, dtype=np.float64)).ravel()
    if orderRange.size == 1:
        orderRange = np.repeat(orderRange, 2)
    orders = np.interp(foi, [foi.min(), foi.max()], orderRange)

    # `longest` instead of `max`: rebinding `max` made it a local name and
    # turned the `max(foi)` call above into an UnboundLocalError.
    longest = 0
    for centerFreq, order in zip(foi, orders):
        for iWave in range(int(np.ceil(order))):
            # computeWaveletSize takes (fc, nc, fs) - keep that order.
            wlen = computeWaveletSize(centerFreq, (iWave + 1) * c1, fs)
            longest = max(longest, wlen)
    return longest
[docs]
def gausswin(size, alpha):
    """
    Create a Gaussian window.

    Sample ``k`` (0-based) equals ``exp(-0.5 * t_k**2)``, where
    ``t_k = (k - size // 2) * alpha / (size // 2)``. The window therefore
    starts at ``-alpha`` standard deviations (and ends at ``+alpha`` for odd
    sizes).

    Parameters
    ----------
    size : int
        Size of the window in samples.
    alpha : float
        Number of standard deviations from the centre to the first sample.

    Returns
    -------
    ndarray
        Gaussian window of specified size (float64). Windows of size 0 or 1
        are returned as all-ones arrays.
    """
    halfSize = int(np.floor(size / 2))
    if halfSize == 0:
        # A 0- or 1-sample window has its only sample at the centre (t = 0,
        # value 1); computing alpha / halfSize would divide by zero.
        return np.ones(int(size), dtype=np.float64)
    idiv = alpha / halfSize
    t = (np.arange(size, dtype=np.float64) - halfSize) * idiv
    return np.exp(-(t * t) * 0.5)
[docs]
def morlet(fc, nc, fs):
    """
    Build a complex Morlet wavelet sampled at ``fs``.

    The wavelet is a complex exponential at ``fc`` Hz under a Gaussian
    envelope that spans +/- ``MORLET_SD_SPREAD / 2`` standard deviations.
    Its length comes from ``computeWaveletSize``, and it is scaled so that
    the envelope sums to one.

    Parameters
    ----------
    fc : float
        Center frequency in Hz.
    nc : float
        Number of cycles.
    fs : float
        Sampling rate in Hz.

    Returns
    -------
    ndarray
        Complex Morlet wavelet.
    """
    nSamples = computeWaveletSize(fc, nc, fs)
    centre = int(np.floor(nSamples / 2))
    envelope = gausswin(nSamples, MORLET_SD_SPREAD / 2)
    normaliser = 1 / envelope.sum()
    samplePeriod = 1 / fs
    timeAxis = (np.arange(nSamples, dtype=np.float64) - centre) * samplePeriod
    oscillation = np.exp(2 * np.pi * fc * timeAxis * 1j)
    return envelope * oscillation * normaliser
[docs]
def fractional(x):
    """
    Return the fractional part of the scalar ``x``.

    The integer part is obtained by truncation toward zero, so the result
    has the same sign as ``x`` (e.g. ``-1.5 -> -0.5``).

    Parameters
    ----------
    x : float
        Input scalar value.

    Returns
    -------
    float
        ``x`` minus its truncated integer part.
    """
    wholePart = int(x)
    return x - wholePart
[docs]
class SuperletTransform:
    """
    Class used to compute the Superlet Transform of input data.

    This class implements the superlet transform algorithm for time-frequency
    analysis as described in Moca et al., 2021. All wavelets are built once
    in the constructor, so one instance can be reused for any number of
    buffers of length ``inputSize``.
    """

    def __init__(self,
                 inputSize,
                 samplingRate,
                 baseCycles,
                 superletOrders,
                 frequencyRange=None,
                 frequencyBins=None,
                 frequencies=None):
        """
        Initialize the superlet transform.

        Parameters
        ----------
        inputSize : int
            Size of the input in samples.
        samplingRate : float
            The sampling rate of the input signal in Hz.
        baseCycles : float
            Number of cycles of the smallest wavelet (c1 in the paper).
        superletOrders : tuple or scalar
            The (first, last) superlet order, linearly distributed along the
            frequency range. A scalar or single-element sequence is used as
            a constant order.
        frequencyRange : tuple, optional
            Tuple of ascending frequency points, in Hz. Required when
            ``frequencies`` is not given.
        frequencyBins : int, optional
            Number of frequency bins to sample in the interval
            frequencyRange. Required when ``frequencies`` is not given.
        frequencies : array_like, optional
            Specific list of frequencies - can be provided instead of
            frequencyRange (which is ignored in this case).

        Raises
        ------
        ValueError
            If neither ``frequencies`` nor both ``frequencyRange`` and
            ``frequencyBins`` are provided.
        """
        # clear to reinit
        self.clear()

        if frequencies is not None:
            self.frequencies = np.asarray(frequencies, dtype=np.float64)
            frequencyRange = [np.min(self.frequencies),
                              np.max(self.frequencies)]
        elif frequencyRange is not None and frequencyBins is not None:
            self.frequencies = np.linspace(start=frequencyRange[0],
                                           stop=frequencyRange[1],
                                           num=frequencyBins)
        else:
            raise ValueError("Provide either `frequencies` or both "
                             "`frequencyRange` and `frequencyBins`.")

        # np.interp needs exactly two order values; expand a constant order.
        orderRange = np.atleast_1d(
            np.asarray(superletOrders, dtype=np.float64)).ravel()
        if orderRange.size == 1:
            orderRange = np.repeat(orderRange, 2)

        self.inputSize = inputSize
        # orders go linearly from orderRange[0] at frequencyRange[0] to
        # orderRange[1] at frequencyRange[1]
        self.orders = np.interp(self.frequencies, frequencyRange, orderRange)
        self.convBuffer = np.zeros(inputSize, dtype=np.complex128)
        self.poolBuffer = np.zeros(inputSize, dtype=np.float64)

        # Wavelet k (0-based) of a superlet has (k + 1) * baseCycles cycles.
        # ceil(order) wavelets are built so that a fractional order can use
        # the last one with a fractional exponent in transformOne.
        self.superlets = [
            [morlet(centerFreq, (iWave + 1) * baseCycles, samplingRate)
             for iWave in range(int(np.ceil(order)))]
            for centerFreq, order in zip(self.frequencies, self.orders)]

    def __del__(self):
        """
        Destructor.

        Releases the wavelets and buffers when the object is deleted.
        """
        self.clear()

    def clear(self):
        """
        Clear the transform.

        Resets all internal variables to None, freeing memory.
        """
        # fields
        self.inputSize = None
        self.superlets = None
        self.poolBuffer = None
        self.convBuffer = None
        self.frequencies = None
        self.orders = None

    def longestWaveletSize(self):
        """
        Return the size of the longest wavelet.

        Returns
        -------
        int
            Size of the longest wavelet in samples (0 if there are none).
        """
        return max((wavelet.shape[0]
                    for superlet in self.superlets
                    for wavelet in superlet), default=0)

    def validTimeRegion(self):
        """
        Compute the start and end of the valid spectrum region.

        Half of the longest wavelet is excluded at each end of the buffer,
        matching what ``cropSpectrum`` removes.

        Returns
        -------
        tuple
            A tuple containing:
            - start : int
                The start of the valid time region.
            - end : int
                The end (exclusive) of the valid time region.
        """
        pad = self.longestWaveletSize() // 2
        # previously `start` was inputSize + pad, i.e. past the end
        return pad, self.inputSize - pad

    def transform(self, inputData):
        """
        Apply the transform to a buffer or list of buffers.

        Parameters
        ----------
        inputData : array_like
            A single buffer of length ``inputSize``, or an N-D array whose
            last dimension is ``inputSize``. All leading dimensions are
            flattened into a list of buffers.

        Returns
        -------
        ndarray
            A ``(len(frequencies), inputSize)`` spectrum. For multiple
            buffers, it is the average over all of them.

        Raises
        ------
        ValueError
            If input data size doesn't match the defined input size for this
            transform, or if the input contains no buffers.
        """
        inputData = np.asarray(inputData)
        nFreqs = len(self.frequencies)

        if inputData.ndim == 1:
            if inputData.shape[0] != self.inputSize:
                raise ValueError("Input data must meet the defined input size"
                                 " for this transform.")
            result = np.zeros((nFreqs, self.inputSize), dtype=np.float64)
            self.transformOne(inputData, result)
            return result

        insize = int(inputData.shape[-1])
        if insize != self.inputSize:
            raise ValueError("Input data must meet the defined input size"
                             " for this transform.")
        # The number of buffers is the PRODUCT of the leading dimensions
        # (a sum only happens to be right for 2-D input).
        n = int(np.prod(inputData.shape[:-1]))
        if n == 0:
            raise ValueError("Input data contains no buffers to transform.")

        # reshape to data list
        datalist = np.reshape(inputData, (n, insize), 'C')
        result = np.zeros((nFreqs, self.inputSize), dtype=np.float64)
        for buffer in datalist:
            self.transformOne(buffer, result)
        return result / n

    def transformOne(self, inputData, accumulator):
        """
        Apply the superlet transform on a single data buffer.

        Parameters
        ----------
        inputData : ndarray
            A 1xInputSize array containing the signal to be transformed.
        accumulator : ndarray
            A ``(len(frequencies), inputSize)`` spectrum to accumulate the
            resulting superlet transform into.

        Notes
        -----
        This method modifies the accumulator array in-place. A differently
        shaped accumulator of the same total size is still reinterpreted
        in place with ``ndarray.resize`` for backward compatibility.
        """
        expected = (len(self.frequencies), self.inputSize)
        if accumulator.shape != expected:
            # Legacy behaviour. Only resize when needed: resize rewrites the
            # strides, which would corrupt a correctly-shaped strided view.
            accumulator.resize(expected)

        for iFreq, (wavelets, order) in enumerate(
                zip(self.superlets, self.orders)):
            # init pooling buffer
            self.poolBuffer.fill(1)
            if len(wavelets) > 1:
                # superlet: multiply the power of the integer-order wavelets
                nWavelets = int(np.floor(order))
                rfactor = 1.0 / nWavelets
                for wavelet in wavelets[:nWavelets]:
                    self.convBuffer = fftconvolve(inputData, wavelet, "same")
                    self.poolBuffer *= 2 * np.abs(self.convBuffer) ** 2
                if fractional(order) != 0 and len(wavelets) == nWavelets + 1:
                    # apply the fractional wavelet with a partial exponent
                    exponent = order - nWavelets
                    rfactor = 1 / (nWavelets + exponent)
                    self.convBuffer = fftconvolve(inputData,
                                                  wavelets[nWavelets], "same")
                    self.poolBuffer *= (
                        2 * np.abs(self.convBuffer) ** 2) ** exponent
                # perform geometric mean
                accumulator[iFreq, :] += self.poolBuffer ** rfactor
            else:
                # plain wavelet transform
                accumulator[iFreq, :] += (2 * np.abs(
                    fftconvolve(inputData, wavelets[0], "same")) ** 2
                ).astype(np.float64)
[docs]
def cropSpectrum(spectrum, paddingSize):
    """
    Trim ``paddingSize`` samples from both ends of a spectrum's axis 1.

    Axis 1 is the time axis of spectra produced by
    ``SuperletTransform.transform``, whose output is
    ``(frequencies, inputSize)``.

    Parameters
    ----------
    spectrum : ndarray
        A 2D numpy array representing the time-frequency spectrum.
    paddingSize : int
        Number of samples to remove at each end - equals
        longestWaveletSize() // 2 of the computing SuperletTransform object.

    Returns
    -------
    ndarray
        A view of the spectrum with the padding removed.
    """
    firstKept = paddingSize
    stopBefore = spectrum.shape[1] - paddingSize
    return spectrum[:, firstKept:stopBefore]
# main superlet function
[docs]
def superlets(data,
              fs,
              foi,
              c1,
              ord):
    """
    Perform fractional adaptive superlet transform (FASLT) on a list of trials.

    Parameters
    ----------
    data : array_like
        A numpy array of data. The rightmost dimension of the data is the
        trial size. The result will be the average over all the spectra.
    fs : float
        The sampling rate in Hz.
    foi : array_like
        List of frequencies of interest.
    c1 : float
        Base number of cycles parameter.
    ord : tuple, list or scalar
        The order (for SLT) or order range (for FASLT), spanned across the
        frequencies of interest. A scalar or single-element sequence is used
        as a constant order.

    Returns
    -------
    ndarray
        A ``(len(foi), data.shape[-1])`` matrix containing the average
        superlet spectrum.

    Notes
    -----
    This is the main function for computing the superlet transform.
    """
    data = np.asarray(data)
    # determine buffer size
    bufferSize = data.shape[-1]
    # Normalise the order to a (lo, hi) pair; the previous `(ord, ord)`
    # turned a one-element list into a tuple of lists.
    orderRange = np.atleast_1d(np.asarray(ord, dtype=np.float64)).ravel()
    if orderRange.size == 1:
        orderRange = np.repeat(orderRange, 2)

    # build the superlet analyzer
    faslt = SuperletTransform(inputSize=bufferSize,
                              frequencyRange=None,
                              frequencyBins=None,
                              samplingRate=fs,
                              frequencies=foi,
                              baseCycles=c1,
                              superletOrders=orderRange)
    try:
        # apply transform
        return faslt.transform(data)
    finally:
        # release wavelets/buffers even if the transform raises
        faslt.clear()
[docs]
def superlet_tfr(inst: Signal,
                 foi: list[float],
                 c1: float,
                 ord: tuple[int, int] = (1, 1),
                 decim: int = 1,
                 n_jobs: int = 1) -> Signal:
    """
    Compute the superlet time-frequency representation of the input signal.

    Each channel of each epoch is transformed separately. Raw and Evoked
    inputs are treated as a single epoch.

    Parameters
    ----------
    inst : Signal
        The input signal (e.g., Raw, Epochs, Evoked).
    foi : list[float]
        List of frequencies of interest.
    c1 : float
        Base number of cycles parameter.
    ord : tuple[int, int], optional
        The order (for SLT) or order range (for FASLT), spanned across the
        frequencies of interest. Default is (1, 1). A scalar or
        single-element sequence is used as a constant order.
    decim : int, optional
        Decimation factor for the output. Default is 1.
    n_jobs : int, optional
        Number of parallel jobs passed to ``joblib.Parallel``. Default is 1.

    Returns
    -------
    Signal
        The time-frequency representation of the input signal, as an
        ``mne.time_frequency.EpochsTFRArray``.

    Raises
    ------
    ValueError
        If ``inst`` is not a Raw, Epochs or Evoked object.
    """
    times = inst.times[::decim]
    sfreq = inst.info['sfreq']
    if isinstance(inst, mne.io.BaseRaw):
        # continuous (channels, times) data -> one epoch, as the TFR container
        # expects (epochs, channels, freqs, times)
        data = inst.get_data()[np.newaxis]
    elif isinstance(inst, mne.BaseEpochs):
        # BaseEpochs also covers EpochsArray, which is not an mne.Epochs
        data = inst.get_data()
    elif isinstance(inst, mne.Evoked):
        data = inst.data[np.newaxis, :]
    else:
        raise ValueError("Input must be a Raw, Epochs or Evoked object.")

    # determine buffer size
    bufferSize = data.shape[-1]
    # normalise the order to a (lo, hi) pair
    orderRange = np.atleast_1d(np.asarray(ord, dtype=np.float64)).ravel()
    if orderRange.size == 1:
        orderRange = np.repeat(orderRange, 2)

    # build the superlet analyzer
    faslt = SuperletTransform(inputSize=bufferSize,
                              frequencies=foi,
                              samplingRate=sfreq,
                              baseCycles=c1,
                              superletOrders=orderRange)
    freqs = faslt.frequencies
    out = np.zeros(data.shape[:-1] + (len(freqs), len(times)),
                   dtype=data.dtype)

    def _apply_transform(buffer, idx):
        # The 1-D buffer is passed explicitly so that only it, not the whole
        # data array, is shipped to each worker.
        return faslt.transform(buffer)[..., ::decim], idx

    # apply transform in parallel
    par = Parallel(n_jobs=n_jobs, return_as='generator_unordered', verbose=10)(
        delayed(_apply_transform)(data[idx], idx)
        for idx in np.ndindex(data.shape[:-1]))
    for tout, idx in par:
        out[idx] = tout
    faslt.clear()

    # Raw and Evoked carry no events/event_id; fall back to MNE's defaults.
    tfr = mne.time_frequency.EpochsTFRArray(
        inst.info, out, times, freqs,
        events=getattr(inst, 'events', None),
        event_id=getattr(inst, 'event_id', None))
    return tfr