- (1) provides detailed explanations for some of the latest work in the field,
- (2) incorporates some inspirations, although they may not be rigorously proven here.
📌 Research Themes: Recover speech/semantic information from invasive/non-invasive brain recordings.
Speech/semantic reconstruction of continuous language from brain recordings is an emerging research field that aims to decode speech, words, sentences, or even entire narratives from recorded neural activity patterns. It can potentially transform our understanding of how the brain processes speech.
Invasive devices have made significant breakthroughs in decoding brain signals, with deep learning algorithms trained on intracranial recordings now able to decode basic language features such as phonemes, words, and spectrograms. However, compared to invasive recording devices, non-invasive brain recordings have clear advantages for users. Currently, extending the modeling methods developed for invasive recordings to natural language and non-invasive brain recordings remains a major challenge.
Tang et al. proposed a novel approach to reconstruct continuous language from fMRI recordings of three participants who listened to spoken stories. They used GPT together with an encoding model to score the likelihood of candidate sequences, and applied beam search to decode the fMRI signals. They successfully reconstructed the auditory and semantic content of the stories. Despite the low temporal resolution of fMRI, this strategy clearly demonstrates the ability to capture language-related neural mechanisms.
Unlike traditional regression-based encoding models, Défossez et al. adopted an end-to-end framework to model the mapping between M/EEG data and speech representations. They introduce a model trained with contrastive learning to decode self-supervised representations of perceived speech from non-invasive M/EEG recordings. They evaluate their method on four public datasets covering 175 volunteers. The results show that the model can identify the corresponding speech segment from 3 seconds of MEG signals with up to 41% accuracy out of more than 1,000 distinct possibilities on average across participants, and more than 80% for the very best participants. This is an extremely exciting result. Imagine: if a language model (for example, GPT-4 or LLaMA 2) could contextualize these representations, it might well be able to recover the original auditory stimuli.
A recent preprint combined these two lines of work. First, a multi-subject decoding model was trained with contrastive learning to reconstruct continuous word embeddings from MEG data. A beam search algorithm was then used to generate text sequences from the reconstructed word embeddings: given a candidate sentence in the beam, a language model predicted the subsequent words.
Decoding brain signals is a significant challenge. From the perspective of computer science and artificial intelligence, I think these efforts provide excellent entry points into this field.
On this page, I will introduce some representative works in detail and retrace the decoding process for an M/EEG recording. Because large artifacts can be a problem in M/EEG data, many details matter, and I try to show them all with as much code as possible. Some further explorations, such as decoding sentences from the speech representations predicted from EEG data, are then presented. In summary, this page:
Formal research results are still in progress
Part One: Decoding speech perception from non-invasive brain recordings¶
In the previous section, we mentioned an end-to-end decoding framework, and that is where we will start. The many details of its data processing give a good feel for the characteristics of M/EEG data. The paper observes that current methods are limited to:
- training a model on a single patient
- aiming to decode a limited set of interpretable features (MEL spectrogram, letters, phonemes, small set of words).
The authors propose two approaches to address these issues:
- a single architecture trained across a large cohort of participants and
- deep representations of speech learnt with self-supervised learning on a large quantity of speech data.
Now, we will explore this work through inference.
1. Dataset¶
We start directly with the most appealing result. In Table 2 of the paper, we observe that the model (+wav2vec2.0) achieves a Top-10 accuracy of 70.7% on Gwilliams (MEG). This result is obtained on the MEG-MASC dataset; note the disk space requirement, as it is 100 GB in size.
"MEG-MASC" dataset provides a curated set of raw magnetoencephalography (MEG) recordings of 27 English speakers who listened to two hours of naturalistic stories. Each participant performed two identical sessions, involving listening to four fictional stories from the Manually Annotated Sub-Corpus (MASC) intermixed with random word lists and comprehension questions.
The authors time-stamp the onset and offset of each word and phoneme in the metadata of the recording, and organize the dataset according to the 'Brain Imaging Data Structure' (BIDS).
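As a quick orientation, loading a single recording through this BIDS layout could look like the following sketch; the repository wraps this logic inside bm.studies.gwilliams2022, and the root path below is only illustrative.
from mne_bids import BIDSPath, read_raw_bids

# Hedged sketch: read one MEG-MASC recording via mne-bids.
# The root path and entity values are illustrative; they mirror the paths
# that appear in the cell outputs further below.
bids_path = BIDSPath(subject="01", session="0", task="0", datatype="meg",
                     root="data/gwilliams2022/download")
raw = read_raw_bids(bids_path)
print(raw.info["sfreq"])  # 1000.0 Hz before any resampling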
import logging
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
from collections import defaultdict
from pathlib import Path
import flashy.logging
import flashy.utils
import mne
import numpy as np
import pandas as pd
import torch
import bm
from bm import play
from bm.losses import ClipLoss
from bm.train import main
from omegaconf import OmegaConf
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset
from scripts.run_eval_probs import _get_extra_info, _load_test_data, EvalJob, Evaluator, run_eval
logger = logging.getLogger(__name__)
selection = {'study': 'gwilliams2022'}
recording_lists = list(bm.studies.from_selection(selection))
sample_recording = recording_lists[0]
print("# of recording:",len(recording_lists))
recording_lists[:10]
# of recording: 196
[Gwilliams2022Recording('01_session0_story0'), Gwilliams2022Recording('01_session0_story1'), Gwilliams2022Recording('01_session0_story2'), Gwilliams2022Recording('01_session0_story3'), Gwilliams2022Recording('01_session1_story0'), Gwilliams2022Recording('01_session1_story1'), Gwilliams2022Recording('01_session1_story2'), Gwilliams2022Recording('01_session1_story3'), Gwilliams2022Recording('02_session0_story0'), Gwilliams2022Recording('02_session0_story1')]
bm.studies.from_selection has customized loading methods for the four datasets. We will focus on the first recording. The first 10 recording objects are printed above, each identified by subject, session, and story.
sample_recording.mne_info
bids_path: /ssd3/other/penglinkai01/brainmagick/data/gwilliams2022/download/sub-01/ses-0/meg/sub-01_ses-0_task-0_meg.con Extracting SQD Parameters from /ssd3/other/penglinkai01/brainmagick/data/gwilliams2022/download/sub-01/ses-0/meg/sub-01_ses-0_task-0_meg.con... Creating Raw.info structure... Setting channel info structure... Creating Info structure... Ready. Reading events from /ssd3/other/penglinkai01/brainmagick/data/gwilliams2022/download/sub-01/ses-0/meg/sub-01_ses-0_task-0_events.tsv. Reading channel info from /ssd3/other/penglinkai01/brainmagick/data/gwilliams2022/download/sub-01/ses-0/meg/sub-01_ses-0_task-0_channels.tsv. The stimulus channel "STI 014" is present in the raw data, but not included in channels.tsv. Removing the channel. NOTE: pick_types() is a legacy function. New code should use inst.pick(...). NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
/ssd3/other/penglinkai01/brainmagick/bm/studies/gwilliams2022.py:106: RuntimeWarning: The unit for channel(s) MISC 001, MISC 002, MISC 003, MISC 004, MISC 005, MISC 006, MISC 007, MISC 008, MISC 009, MISC 010, MISC 011, MISC 012, MISC 013, MISC 014, MISC 015, MISC 016, MISC 017, MISC 018, MISC 019, MISC 020, MISC 021, MISC 022, MISC 023, MISC 024, MISC 025, MISC 026, MISC 027, MISC 028, MISC 029, MISC 030, MISC 031, MISC 032 has changed from V to NA. raw = read_raw_bids(bids_path) # FIXME this is NOT a lazy read
Measurement date | January 01, 2000 00:00:00 GMT
---|---
Experimenter | mne_anonymize
Participant | sub-01
Digitized points | Not available
Good channels | 208 Magnetometers
Bad channels | None
EOG channels | Not available
ECG channels | Not available
Sampling frequency | 1000.00 Hz
Highpass | 0.03 Hz
Lowpass | 200.00 Hz
sample_recording.raw().get_data().shape
(208, 396000)
The information of sample 01_session0_story0
is printed (see mne_info
in mne-python). It has 208 channels and a sampling rate of 1000Hz. The recorded signal shape indicates a duration of 396 seconds (396000/1000).
sample_recording.events()[:5] # load when first access. see bm.studies.api.Recording method events()
story | story_uid | sound_id | kind | start | sound | duration | filepath | phoneme | sequence_id | ... | speech_rate | voice | pronounced | word | language | modality | word_sequence | phoneme_id | offset | uid | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | NaN | NaN | NaN | block | 23.506 | NaN | 6.250000 | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | english | audio | NaN | NaN | NaN | Tara stood stock still waiting for the first t... |
1 | lw1 | 0.0 | 0.0 | word | 23.506 | stimuli/audio/lw1_0.wav | 0.300000 | NaN | NaN | 0.0 | ... | 205.0 | Allison | 1.0 | Tara | english | audio | Tara stood stock still waiting for the first t... | NaN | NaN | NaN |
2 | lw1 | 0.0 | 0.0 | sound | 23.506 | stimuli/audio/lw1_0.0.wav | 95.881678 | /ssd3/other/penglinkai01/brainmagick/data/gwil... | NaN | NaN | ... | NaN | NaN | NaN | NaN | english | audio | NaN | NaN | 0.0 | NaN |
3 | lw1 | 0.0 | 0.0 | phoneme | 23.506 | stimuli/audio/lw1_0.wav | 0.080000 | NaN | t_B | 0.0 | ... | 205.0 | Allison | 1.0 | NaN | english | audio | NaN | 0.0 | NaN | NaN |
4 | lw1 | 0.0 | 0.0 | phoneme | 23.586 | stimuli/audio/lw1_0.wav | 0.090000 | NaN | eh_I | 0.0 | ... | 205.0 | Allison | 1.0 | NaN | english | audio | NaN | 1.0 | NaN | NaN |
5 rows × 22 columns
uid = (sample_recording.__class__.__name__, sample_recording.subject_uid)
uid
('Gwilliams2022Recording', '01')
The recording's events() include the start times, durations, and other information for sounds, words, and phonemes. Among them, kind indicates the type of event (block, sound, word, or phoneme). The uid constructed above identifies the subject, which will be a crucial attribute later.
sample_recording.events()[sample_recording.events().kind == 'word']
story | story_uid | sound_id | kind | start | sound | duration | filepath | phoneme | sequence_id | ... | speech_rate | voice | pronounced | word | language | modality | word_sequence | phoneme_id | offset | uid | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
1 | lw1 | 0.0 | 0.0 | word | 23.506 | stimuli/audio/lw1_0.wav | 0.30 | NaN | NaN | 0.0 | ... | 205.0 | Allison | 1.0 | Tara | english | audio | Tara stood stock still waiting for the first t... | NaN | NaN | NaN |
7 | lw1 | 0.0 | 0.0 | word | 23.816 | stimuli/audio/lw1_0.wav | 0.24 | NaN | NaN | 0.0 | ... | 205.0 | Allison | 1.0 | stood | english | audio | Tara stood stock still waiting for the first t... | NaN | NaN | NaN |
13 | lw1 | 0.0 | 0.0 | word | 24.056 | stimuli/audio/lw1_0.wav | 0.37 | NaN | NaN | 0.0 | ... | 205.0 | Allison | 1.0 | stock | english | audio | Tara stood stock still waiting for the first t... | NaN | NaN | NaN |
17 | lw1 | 0.0 | 0.0 | word | 24.586 | stimuli/audio/lw1_0.wav | 0.40 | NaN | NaN | 0.0 | ... | 205.0 | Allison | 1.0 | still | english | audio | Tara stood stock still waiting for the first t... | NaN | NaN | NaN |
23 | lw1 | 0.0 | 0.0 | word | 25.136 | stimuli/audio/lw1_0.wav | 0.41 | NaN | NaN | 0.0 | ... | 205.0 | Allison | 1.0 | waiting | english | audio | Tara stood stock still waiting for the first t... | NaN | NaN | NaN |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
3147 | lw1 | 0.0 | 3.0 | word | 361.097 | stimuli/audio/lw1_3.wav | 0.17 | NaN | NaN | 52.0 | ... | 205.0 | Allison | 1.0 | end | english | audio | Tara would not let giddy hopes drag them onto ... | NaN | NaN | NaN |
3150 | lw1 | 0.0 | 3.0 | word | 361.277 | stimuli/audio/lw1_3.wav | 0.14 | NaN | NaN | 52.0 | ... | 205.0 | Allison | 1.0 | for | english | audio | Tara would not let giddy hopes drag them onto ... | NaN | NaN | NaN |
3154 | lw1 | 0.0 | 3.0 | word | 361.487 | stimuli/audio/lw1_3.wav | 0.58 | NaN | NaN | 52.0 | ... | 205.0 | Allison | 1.0 | project | english | audio | Tara would not let giddy hopes drag them onto ... | NaN | NaN | NaN |
3162 | lw1 | 0.0 | 3.0 | word | 362.207 | stimuli/audio/lw1_3.wav | 0.15 | NaN | NaN | 52.0 | ... | 205.0 | Allison | 1.0 | and | english | audio | Tara would not let giddy hopes drag them onto ... | NaN | NaN | NaN |
3165 | lw1 | 0.0 | 3.0 | word | 362.817 | stimuli/audio/lw1_3.wav | 0.34 | NaN | NaN | 52.0 | ... | 205.0 | Allison | 1.0 | species | english | audio | Tara would not let giddy hopes drag them onto ... | NaN | NaN | NaN |
668 rows × 22 columns
The above code filters out all the word events; for sample_recording, there are 668 of them (the 3165 shown for the last row is its index in the full events table). The last word starts at 362.817 seconds and lasts 0.34 seconds.
During the preprocessing stage, the MEG recording is resampled to 120 Hz using julius.ResampleFrac(). The preprocessing results are cached to disk for efficiency, with _cache_folder specifying the storage location. After resampling, the shape is (208, 47520).
low_mne = bm.studies.api.preprocess_mne(sample_recording.raw(),sample_rate=120,highpass=0)
sample_recording._cache_folder = Path("cache/studies/gwilliams2022/01_session0_story0")
bm.dataset._preload(sample_recording, sample_rate=120, highpass=0)
low_mne
Creating RawArray with float64 data, n_channels=208, n_times=47520 Range : 0 ... 47519 = 0.000 ... 395.992 secs Ready. Opening raw data file cache/studies/gwilliams2022/01_session0_story0/meg-sr120-hp0-raw.fif... Isotrak not found Range : 0 ... 47519 = 0.000 ... 395.992 secs Ready.
Measurement date | January 01, 2000 00:00:00 GMT
---|---
Experimenter | mne_anonymize
Participant | sub-01
Digitized points | Not available
Good channels | 208 Magnetometers
Bad channels | None
EOG channels | Not available
ECG channels | Not available
Sampling frequency | 120.00 Hz
Highpass | 0.03 Hz
Lowpass | 200.00 Hz
Duration | 00:06:36 (HH:MM:SS)
low_mne.get_data().shape
(208, 47520)
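The resampling inside preprocess_mne boils down to a fractional resample of the (channels, time) array; here is a hedged sketch with julius, not the exact pipeline code:
import julius
import torch

# Hedged sketch of the 1000 Hz -> 120 Hz resampling step.
raw_1000hz = torch.from_numpy(sample_recording.raw().get_data()).float()  # (208, 396000)
resampler = julius.ResampleFrac(1000, 120)
meg_120hz = resampler(raw_1000hz)
print(meg_120hz.shape)  # expected to be (208, 47520), matching the cached result above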
Now we use this recording to construct the input for the model. Obviously, from a recording of nearly 400 seconds, multiple 3-second samples can be extracted. The event type block is used to segment the recording, and the merge_blocks function concatenates blocks shorter than min_block_duration. The blocks are then randomly assigned to train, test, and validation subsets in a 7:2:1 ratio. The preprocessed function implements the resampling step described above. For the test set, an additional step is applied: a 3-second segment (from -500 milliseconds to 2.5 seconds) is selected for each word in the test set; note that tmin is -500 ms and tmax is 2.5 s.
test_ratio, valid_ratio = 0.2, 0.1
min_block_duration = 6
min_n_blocks_per_split = 1
sample_rate = 120
highpass=0
tmin, tmax = -0.5, 2.5
meg_dimension = max(recording.meg_dimension for recording in [sample_recording])
factory_kwargs = {}
factory_kwargs.update(sample_rate=sample_rate, highpass=highpass, meg_dimension=meg_dimension,baseline=None)
fact = bm.dataset.SegmentDataset.Factory(**factory_kwargs)
# we only use one example for probe
for i, recording in enumerate([sample_recording]):
    events = recording.events()
    blocks = events[events.kind == 'block']
    blocks = blocks.event.merge_blocks(min_block_duration_s=min_block_duration)
    blocks = bm.dataset.assign_blocks(blocks, [test_ratio, valid_ratio], seed=12, min_n_blocks_per_split=min_n_blocks_per_split)
    # start-stops
    # [(0.0, 29.756), (29.756, 40.096), (40.096, 68.416), ...]
    start_stops = [(b.start, b.start + b.duration) for b in blocks.itertuples()]
    # following code are same as "fact.apply(recording, blocks=start_stops)"
    data = recording.preprocessed(sample_rate, highpass=highpass)
    sample_rate = bm.utils.Frequency(data.info["sfreq"])
    times = np.arange(0, data.times[-1], 3.0)
    events = recording.events().copy()
    events = events.sort_values('start')
    delta = 0.5 / sample_rate
    mask = np.logical_and(times + tmin >= 0, times + tmax < data.times[-1] + delta)
    print("Mask(raw):",mask.sum())
    in_any_split = False
    counter = 0
    for start, stop in start_stops:
        in_split = times + tmin >= start
        margin = tmax - delta
        in_split &= times + margin < stop
        in_any_split |= in_split
    mask &= in_any_split
    samples = sample_rate.to_ind(times[mask])
    unique_samples = np.unique(samples)
    print("Mask(processed):",mask.sum())
    print("# of unique_samples:",len(unique_samples))
Mask(raw): 131 Mask(processed): 104 # of unique_samples: 104
The selected recording has a duration of 396 seconds. With a 3-second stride, we obtain 131 possible starting positions (see times below). However, each starting point must fall inside a block interval so that the segment contains complete words; that is, the start time and the end time must both lie within the same block. In the end, 104 samples remain. The following output demonstrates this process:
For the first sample, it belongs to the start_stops interval (0.0, 29.756):
times=
[ 0. 3. 6. 9. 12. 15. 18. 21. 24. 27. 30. 33. 36. 39.
42. 45. 48. 51. 54. 57. 60. 63. 66. 69. 72. 75. 78. 81.
84. 87. 90. 93. 96. 99. 102. 105. 108. 111. 114. 117. 120. 123.
126. 129. 132. 135. 138. 141. 144. 147. 150. 153. 156. 159. 162. 165.
168. 171. 174. 177. 180. 183. 186. 189. 192. 195. 198. 201. 204. 207.
210. 213. 216. 219. 222. 225. 228. 231. 234. 237. 240. 243. 246. 249.
252. 255. 258. 261. 264. 267. 270. 273. 276. 279. 282. 285. 288. 291.
294. 297. 300. 303. 306. 309. 312. 315. 318. 321. 324. 327. 330. 333.
336. 339. 342. 345. 348. 351. 354. 357. 360. 363. 366. 369. 372. 375.
378. 381. 384. 387. 390. 393.]
times + tmin >= start
[False True True True True True True True True True True True
True True True True True True True True True True True True
True True True True True True True True True True True True
True True True True True True True True True True True True
True True True True True True True True True True True True
True True True True True True True True True True True True
True True True True True True True True True True True True
True True True True True True True True True True True True
True True True True True True True True True True True True
True True True True True True True True True True True True
True True True True True True True True True True True True]
in_split &= times + margin < stop
[False True True True True True True True True True False False
False False False False False False False False False False False False
False False False False False False False False False False False False
False False False False False False False False False False False False
False False False False False False False False False False False False
False False False False False False False False False False False False
False False False False False False False False False False False False
False False False False False False False False False False False False
False False False False False False False False False False False False
False False False False False False False False False False False False
False False False False False False False False False False False False]
mask
[False True True True True True True True True True False False
False False False False False False False False False False False False
False False False False False False False False False False False False
False False False False False False False False False False False False
False False False False False False False False False False False False
False False False False False False False False False False False False
False False False False False False False False False False False False
False False False False False False False False False False False False
False False False False False False False False False False False False
False False False False False False False False False False False False
False False False False False False False False False False False False]
Finally, for the interval (0.0, 29.756), the starting positions 3, 6, 9, 12, 15, 18, 21, 24, and 27 yield valid test segments. From the uid column of the blocks, we can get an idea of the approximate content of this block. The corresponding audio clip is also shown below.
blocks[:3]
start | duration | modality | language | uid | kind | split | |
---|---|---|---|---|---|---|---|
0 | 0.000 | 29.756 | audio | english | Tara stood stock still waiting for the first t... | block | 1 |
1 | 29.756 | 10.340 | audio | english | The gentle constant breeze of recycled air fro... | block | 2 |
2 | 40.096 | 28.320 | audio | english | Results Harmon she suppressed the surge of ann... | block | 0 |
from IPython.display import Audio, display
display(Audio(filename="lw1_0_clip.wav", rate=16000))
2. Method¶
The above image shows the model used, with the problem formalization and its contributions indicated on the right side.
To get a detailed understanding of these methods and the intent behind them, I ran inference on a sample from the dataset; specifically, $(X, Y)$ denotes an M/EEG segment paired with (the representation of) its corresponding speech audio. Extracting useful information from M/EEG signals can be challenging due to noise. The following code carefully traces the tensors and the operations applied to them, which helps explain why the method works.
Wav2vec 2.0¶
Wav2vec 2.0 is a self-supervised speech model that learns deep representations from unlabeled speech waveforms. Its working process can be described in the following steps:
Feature Extraction: First, the model takes in an original speech waveform $S \in \mathbb{R}^{T'}$, where $T'$ is the number of time steps. The CNN feature extractor $f_{\text{feat}}$ transforms this waveform into a series of latent representations $C \in \mathbb{R}^{F\times T}$, where $F$ is the hidden dimension and $T$ the number of frames.
Context Network: Then, the context network $f_{\text{context}}$ takes these latent representations and uses a self-attention mechanism to capture the relationships between them. The output of this network is a series of context representations $Y \in \mathbb{R}^{F\times T}$.
Contrastive Loss: During training, a contrastive loss $L_{\text{contrast}}$ encourages the context representation at each (masked) time step to identify the true latent for that step among distractors sampled from other time steps. This process can be represented as:
$$
\min_{f_{\text{feat}}, f_{\text{context}}} L_{\text{contrast}}(Y, C)
$$
After training, wav2vec 2.0 generates robust representations $Y$ that can be used for various downstream tasks such as speech recognition or speech translation. Here, the authors use wav2vec2-large-xlsr-53, which was pre-trained on 56k hours of speech from 53 different languages. During training of the decoder, the parameters of the wav2vec 2.0 model are not updated.
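To make the feature-extraction step concrete, here is a hedged sketch using the HuggingFace transformers API rather than the repository's wrapper; the checkpoint name comes from the paper, and the averaged layer indices follow the configuration printed later in this section.
import torch
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model

extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-large-xlsr-53")
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-xlsr-53").eval()

waveform = torch.randn(16000 * 3)  # stand-in for a 3-second clip at 16 kHz
inputs = extractor(waveform.numpy(), sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    outputs = model(inputs["input_values"], output_hidden_states=True)
# hidden_states holds the CNN output plus the 24 transformer layers (25 tensors);
# averaging a subset of layers yields the speech representation Y.
layers = [14, 15, 16, 17, 18]  # indices reported by the loaded checkpoint below
speech_repr = torch.stack(outputs.hidden_states)[layers].mean(0)  # (1, T, 1024)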
Brain module¶
For the brain module, the paper introduces a deep neural network ${f}_\mathrm{clip}$ that takes as input the raw M/EEG time series $X$ and a one-hot encoding of the corresponding subject $s$, and outputs a latent brain representation $Z$ with the same sampling rate as $X$.
Due to the complexity of the data processing, there is still some way to go before we reach familiar code such as a torch.nn.Module and its forward() function for the brain module. For this part, we will understand its structure through experiments. You can obtain the pre-trained model checkpoint from this GitHub issue and place it under the relative paths outputs/grids/base/97d170e1 and outputs/xps/97d170e1 (the default locations).
sig = "97d170e1"
logger.info(f"Loading solver {sig}")
mne.set_log_level(False)
flashy.logging.setup_logging(with_file_log=False)
solver = play.get_solver_from_sig(sig, override_cfg={})
solver.model.eval()
solver.loss.eval()
xp._argv_cache: /ssd3/other/penglinkai01/brainmagick/outputs/xps/97d170e1/.argv.json
[11-14 22:38:21][bm.play][INFO] - Loading solver from XP 97d170e1. Overrides used: ['dset.selections=[gwilliams2022]'] [11-14 22:38:21][bm._env][WARNING] - Hostname gpu226.corp.yodao.com not defined in /conf/study_paths/study_paths.yaml. Using default paths. [11-14 22:38:21][bm.dataset][WARNING] - Requested 1000 recordings but only found 196 [11-14 22:38:33][bm.dataset][INFO] - Loading Subjects | 39/196 | 5.48 it/sec [11-14 22:38:39][bm.dataset][INFO] - Loading Subjects | 78/196 | 5.71 it/sec [11-14 22:38:46][bm.dataset][INFO] - Loading Subjects | 117/196 | 5.68 it/sec [11-14 22:38:53][bm.dataset][INFO] - Loading Subjects | 156/196 | 5.75 it/sec [11-14 22:38:59][bm.dataset][INFO] - Loading Subjects | 195/196 | 5.74 it/sec [11-14 22:38:59][bm.dataset][INFO] - # Examples (train | valid | test): 203152 | 35156 | 69972 [11-14 22:39:01][bm.train][INFO] - Model hash: ed97b0fbdffe06faf696d2eecaacb57b143fd68d /ssd3/other/penglinkai01/miniconda3/envs/bm/lib/python3.8/site-packages/flashy/loggers/tensorboard.py:47: UserWarning: tensorboard package was not found: use pip install tensorboard warnings.warn("tensorboard package was not found: use pip install tensorboard")
ClipLoss()
In this framework, a signature uniquely identifies a training run. The solver loaded through it aggregates components such as the dataset, model, optimizer, and loss. Almost all of the experimental information is preserved (as seen in the cell above), including the dataset used (dset.selections) and the number of train, validation, and test samples. Next, we load the dataloader directly and run the model on one batch.
datasets = solver.datasets.test.datasets # [bm.dataset.SegmentDataset,bm.dataset.SegmentDataset,bm.dataset.SegmentDataset,...]
print("len(datasets):",len(datasets))
dataset = ConcatDataset(datasets)
loader = DataLoader(dataset, num_workers=0, batch_size=8, collate_fn=bm.dataset.SegmentBatch.collate_fn)
sample_rate = 120
len(datasets): 196
Using ConcatDataset, we concatenate the test datasets stored in the solver and wrap them with a PyTorch DataLoader.
batch = next(iter(loader))
# see bm.dataset.Segmentbatch
print("batch.meg.shape:",batch.meg.shape)
print("batch.subject_index:",batch.subject_index)
print("len of batch._event_lists:",len(batch._event_lists))
print("batch.features.shape:",batch.features.shape)
batch.meg.shape: torch.Size([8, 208, 361]) batch.subject_index: tensor([0, 0, 0, 0, 0, 0, 0, 0]) len of batch._event_lists: 8 batch.features.shape: torch.Size([8, 1025, 361])
meg and subject_index are the inputs to the brain module. subject_index identifies the participant each sample comes from (27 participants for Gwilliams2022, matching the SubjectLayers module shown later); here every entry is 0 because the whole batch comes from subject 01. Its length equals the batch_size. It plays a role somewhat similar to a speaker embedding in TTS, capturing overall characteristics of that participant's recordings. The shape of meg is [B, C, T], where T = 361 because each segment is 3 seconds long at a sampling rate of 120 Hz (3 s × 120 Hz, plus one boundary sample). _event_lists are also included, which helps in recovering the original information.
Then you will notice that the shape of features looks very familiar. Yes, it is the feature extracted by Wav2Vec 2.0 (even though it seems to have one extra dimension, 1025 vs 1024). I was curious where it came from. In fact, the authors cache this part of the computation: if you have run the training program locally, you will find npy files in paths like cache/Wav2VecEmbedding. These files store the features pre-computed during the first run, and SegmentDataset loads them back into memory in its __getitem__() function. To find out how wav2vec 2.0 is invoked, I deleted/moved the cache directory and traced the following code path:
SegmentDataset.__getitem__() >
self._get_feature(index) >
self.features(start, stop) >
feature_builder.__call__() >
val = feature.get_on_overlap(event, overlap) >
Wav2VecTransformer.get_on_overlap >
self._get_cached_tensor() >
self._compute_hidden_states >
Here we finally see familiar code: reading the speech file and running the model forward.
def _compute_hidden_states(
        self, name: str, filepath: Path, start: float, stop: float,
        layers: tp.Optional[tp.List[int]] = None) -> torch.Tensor:
    input_values = self._preprocess_wav(filepath=filepath, start=start, stop=stop)
    self.model.to(self.device)
    self.model.eval()  # needs to be in eval mode
    with torch.no_grad():
        outputs = self.model(input_values.to(self.device), output_hidden_states=True)
    out: tp.Any = outputs.get(name)
    if isinstance(out, tuple):
        out = torch.stack(out)
    if layers is not None:
        out = out[layers].mean(0)
    return out.detach().cpu().clone().numpy()
The layers
parameter specifies which layers of the Transformer will be averaged. In practice, Wav2VecTransformer
is considered as a feature extractor in the datasets
. This is a convenient design because the parameters of this part are not updated during training.
feature_builder = solver.datasets.test.datasets[0].features
features = feature_builder.extract_features(batch.features, solver.used_features.keys())
batch = batch.replace(features=features.to(solver.device))
print("feature_builder:\n\t",feature_builder)
print("solver.used_features.keys():\n\t",solver.used_features.keys())
print("batch.features.shape:",batch.features.shape)
print("Transformer layers used for average:",feature_builder["Wav2VecTransformer"].layers) # Paper said they used the last four but here is diff.
feature_builder: FeaturesBuilder([('Wav2VecTransformer', Wav2VecTransformer(120.0)), ('WordHash', WordHash(120.0))]) solver.used_features.keys(): odict_keys(['Wav2VecTransformer']) batch.features.shape: torch.Size([8, 1024, 361]) Transformer layers used for average: [14, 15, 16, 17, 18]
It can be seen that feature_builder stores both Wav2VecTransformer and WordHash, but only Wav2VecTransformer is used here. The shape of features is as expected (1024), and batch.features is updated via replace. However, the output layers used seem inconsistent with the paper, which states that the last four layers were averaged. The code below prints the Wav2Vec 2.0 model.
feature_builder["Wav2VecTransformer"].model
Wav2Vec2Model( (feature_extractor): Wav2Vec2FeatureEncoder( (conv_layers): ModuleList( (0): Wav2Vec2LayerNormConvLayer( (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,)) (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True) (activation): GELUActivation() ) (1-4): 4 x Wav2Vec2LayerNormConvLayer( (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,)) (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True) (activation): GELUActivation() ) (5-6): 2 x Wav2Vec2LayerNormConvLayer( (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,)) (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True) (activation): GELUActivation() ) ) ) (feature_projection): Wav2Vec2FeatureProjection( (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True) (projection): Linear(in_features=512, out_features=1024, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (encoder): Wav2Vec2EncoderStableLayerNorm( (pos_conv_embed): Wav2Vec2PositionalConvEmbedding( (conv): Conv1d(1024, 1024, kernel_size=(128,), stride=(1,), padding=(64,), groups=16) (padding): Wav2Vec2SamePadLayer() (activation): GELUActivation() ) (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) (layers): ModuleList( (0-23): 24 x Wav2Vec2EncoderLayerStableLayerNorm( (attention): Wav2Vec2Attention( (k_proj): Linear(in_features=1024, out_features=1024, bias=True) (v_proj): Linear(in_features=1024, out_features=1024, bias=True) (q_proj): Linear(in_features=1024, out_features=1024, bias=True) (out_proj): Linear(in_features=1024, out_features=1024, bias=True) ) (dropout): Dropout(p=0.1, inplace=False) (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) (feed_forward): Wav2Vec2FeedForward( (intermediate_dropout): Dropout(p=0.0, inplace=False) (intermediate_dense): Linear(in_features=1024, out_features=4096, bias=True) (intermediate_act_fn): GELUActivation() (output_dense): Linear(in_features=4096, out_features=1024, bias=True) (output_dropout): Dropout(p=0.1, inplace=False) ) (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) ) ) ) )
M/EEG data can suffer from large artifacts. Baseline correction and sklearn.preprocessing.RobustScaler are used.
- Baseline correction: performed when initializing mne.Epochs in bm.dataset._DatasetFactory.apply(). It is applied to each epoch and channel individually: the mean signal over the baseline period is computed and subtracted from the entire epoch.
- RobustScaler: this scaler removes the median and scales the data by the interquartile range (IQR). It is commonly used when outliers would distort mean- and variance-based standardization. The scaling is fitted per recording, so the number of scalers equals the number of recordings (196).
For the Wav2Vec2.0 features, there is no need to worry too much about the data range. Using StandardScaler
is sufficient, as its name suggests.
(See Section 2.4 Preprocessing and Section A.2 Impact of clamping)
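Conceptually, the MEG-side normalization amounts to the following hedged sketch (illustrative arrays, not the repository code; the real pipeline fits one RobustScaler per recording on all of its data):
import numpy as np
from sklearn.preprocessing import RobustScaler

rng = np.random.default_rng(0)
epoch = rng.normal(size=(208, 361))   # stand-in for one (channels, time) segment at 120 Hz
n_baseline = 60                       # the 0.5 s before onset (tmin = -0.5 s)

# 1) Baseline correction: subtract each channel's mean over the baseline window.
epoch = epoch - epoch[:, :n_baseline].mean(axis=1, keepdims=True)

# 2) Robust scaling: median / inter-quartile range instead of mean / variance.
scaler = RobustScaler().fit(epoch.T)  # in the real pipeline, fitted per recording
scaled = scaler.transform(epoch.T).T

# 3) Clamping: values beyond 20 (in scaled units) are clipped to limit outliers.
scaled = np.clip(scaled, -20, 20)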
# Rescales the input MEG and features. If the MEG after rescaling
# still contains large values (e.g. more than `limit`) rejects the offending item.
batch, reject_mask = solver.scale_reject(batch) # see bm.norm.ScaleReject
print("# of meg_scalers:", len(solver.scale_reject.scaler.meg_scalers))
print(" Type of meg_scalers:", solver.scale_reject.scaler.meg_scalers[0])
print(" Clamp of meg_scalers:",solver.scale_reject.limit)
print("# of feature_scalers:", len(solver.scale_reject.scaler.feature_scalers))
print(" Type of feature_scalers:", solver.scale_reject.scaler.feature_scalers['Wav2VecTransformer'])
# of meg_scalers: 196 Type of meg_scalers: <bm.norm.RobustScaler object at 0x7f25a7e9de20> Clamp of meg_scalers: 20.0 # of feature_scalers: 1 Type of feature_scalers: <bm.norm.StandardScaler object at 0x7f25a50e58b0>
The information about RobustScaler and StandardScaler is shown above. Note that the Clamp of meg_scalers value is 20, applied with the torch.Tensor.clamp_() method: values more than 20 (scaled) standard deviations away are clamped. This operation minimizes the impact of large outlier samples. The authors also studied this value carefully (see Section A.2, Impact of clamping); from the experimental results, clamping appears necessary for handling M/EEG data.
Clamping | Brennan (EEG) | Broderick (EEG) | Gwilliams (MEG) | Schoffelen (MEG) |
---|---|---|---|---|
20 | 25.7 ± 2.9 | 17.7 ± 0.6 | 70.7 ± 0.1 | 67.5 ± 0.4 |
100 | 27.1 ± 2.6 | 7.6 ± 0.0 | 70.6 ± 0.3 | 67.2 ± 0.9 |
None | 14.1 ± 1.0 | 0.5 ± 0.0 | 23.6 ± 24.6 | 1.5 ± 0.3 |
meg = batch.meg
features = batch.features
features_mask = torch.ones_like(batch.features_mask)
offset_meg_samples = int(150 / 1000 * sample_rate) # 150ms Residual See Sec 2.2.2 Brain module - Residual dilated convolutions.
meg = meg[..., offset_meg_samples:]
offset_features_samples = offset_meg_samples
features = features[..., :-offset_features_samples]
features_mask = features_mask[..., :-offset_features_samples]
inputs = dict(meg=meg.to(features))
output = features
# estimate = solver.model(inputs, batch)
length = next(iter(inputs.values())).shape[-1] # length of any of the inputs
Given the expected delay between a stimulus and its corresponding brain response, the code further shifts the input brain signal by 150 ms into the future to facilitate the alignment between $Y$ and $Z$.
I'm glad we've reached this point, and now let's take a look at several parts of the Brain module (Part E in the picture). Please don't forget about meg
and subject_index
, as they are the inputs to the module.
Spatial attention¶
# following code is the same as `inputs["meg"] = solver.model.merger(inputs["meg"], batch)`
meg = inputs["meg"]
B, C, T = meg.shape
positions = solver.model.merger.position_getter.get_positions(batch)
embedding = solver.model.merger.embedding(positions)
heads = solver.model.merger.heads[None].expand(B, -1, -1)
scores = torch.einsum("bcd,bod->boc", embedding.to(heads), heads)
weights = torch.softmax(scores, dim=2)
out = torch.einsum("bct,boc->bot", meg, weights)
inputs["meg"] = out
print("meg.shape",meg.shape)
print("positions",positions.shape)
print("solver.model.merger.embedding:", solver.model.merger.embedding)
print("out.shape", out.shape)
meg.shape torch.Size([8, 208, 343]) positions torch.Size([8, 208, 2]) solver.model.merger.embedding: FourierEmb() out.shape torch.Size([8, 270, 343])
First, let's start with the position_getter. The PositionGetter class relies on a projection function, layout = mne.find_layout(info), which uses a device-dependent surface designed to preserve channel distances. The 3D sensor locations are projected onto a 2D plane, and the resulting 2D positions are normalized to $[0, 1]$.
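A hedged sketch of that projection (the normalization used in bm differs in its details):
import mne
import numpy as np

layout = mne.find_layout(sample_recording.mne_info)  # device-dependent 2D layout
pos = layout.pos[:, :2]                              # keep only the x/y coordinates
pos = (pos - pos.min(axis=0)) / (pos.max(axis=0) - pos.min(axis=0))  # rescale to [0, 1]
print(pos.shape)  # roughly one 2D position per MEG channel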
After the projection, the brain data is remapped onto D1 = 270 output channels. Viewed on the anatomical plane, the sensor positions appear roughly "uniformly" distributed, yet clearly not every position matters equally for speech perception. It is therefore promising to use attention to automatically learn the correspondence between the recorded channels and the spatially remapped output channels.
With the help of spatial location information, attention is used to weight the input channels to obtain each output channel. Specifically, for each channel $i$, it has a position $(x_i, y_i)$. Let's assume that the current processed raw M/EEG time series is $X$, which has $C$ channels. For each output channel $j\in\{1, \ldots, D_1\}$, the spatial attention (SA) is defined as:
$$ \begin{equation} \mathrm{SA}(X)^{(j)} = \frac{1}{\sum_{i=1}^{C} \mathrm{e}^{a_j(x_i, y_i)}}\left( \sum_{i=1}^{C} \mathrm{e}^{a_j(x_i, y_i)} X^{(i)} \right) \end{equation} $$
That is, each output channel is a softmax-weighted combination of the input channels, with the attention logits $a_j(x, y)$ evaluated at each input position $(x_i, y_i)$ given by:
$$ \begin{equation} a_j(x, y) = \sum_{k=1}^K\sum_{l=1}^K \mathrm{Re}(z_j^{(k, l)}) \cos\left(2 \pi (k x + l y)\right) +\mathrm{Im}(z_j^{(k, l)}) \sin\left(2 \pi (k x + l y)\right). \end{equation} $$
where $z_j \in \mathbb{C}^{K \times K}$ defines the Fourier space which has $K{=}32$ harmonics. Re()
and Im()
respectively represent the real part and imaginary part of a complex number. The following section demonstrates how to implement SA using code:
FourierEmb()
is the implementation of SA. It takes positions
as input and outputs embedding
.
# following code is the same as `embedding = solver.model.merger.embedding(positions)`
import math
*O, D = positions.shape # O = [8, 208], D = 2 (needed below for the final reshape)
n_freqs = (2048 // 2)**0.5 # 32
freqs_y = torch.arange(n_freqs).to(positions) # torch.Size([32])
freqs_x = freqs_y[:, None] # torch.Size([32, 1])
width = 1 + 2 * 0.2
positions = positions + 0.2
p_x = 2 * math.pi * freqs_x / width # torch.Size([32, 1])
p_y = 2 * math.pi * freqs_y / width # torch.Size([32])
positions = positions[..., None, None, :]
loc = (positions[..., 0] * p_x + positions[..., 1] * p_y).view(*O, -1) # torch.Size([8, 208, 32, 32]) -> torch.Size([8, 208, 1024])
embedding = torch.cat([torch.cos(loc), torch.sin(loc)], dim=-1) # torch.Size([8, 208, 2048])
heads has shape [B, O, D], where $O = 270$ is the number of output channels and $D = 2048$ the Fourier embedding dimension. The embedding (of shape [B, C, D]) is multiplied with the spatial mapping heads to obtain scores of shape [B, O, C], which are passed through torch.softmax over the input-channel dimension to obtain the weights. The output out is the weighted sum of the meg feature map.
This image displays the distribution of attention weights in two-dimensional sensor space. Red indicates M/EEG sensors that receive, on average, a higher spatial attention weight. In the Gwilliams panel (best performance), the sensors on the left and right sides clearly carry more of the informative content.
inputs["meg"] = solver.model.initial_linear(inputs["meg"])
solver.model.initial_linear
Sequential( (0): Conv1d(270, 270, kernel_size=(1,), stride=(1,)) )
Subject Layer¶
This layer has a parameter tensor of size [270, 270, 27]: the first 270 is the number of input channels, the second 270 the number of output channels, and the third value (27 for the Gwilliams2022 dataset, as stated in Table 1) the number of subject_ids.
This structure serves as the foundation for achieving the following purpose:
"A single architecture trained across a large cohort of participants."
During the forward computation, subjects gives the participant index for each MEG sample, and the [270, 270] matrix corresponding to that index transforms the input features. This operation can be seen as incorporating participant-specific information.
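A hedged sketch of what such a subject-conditioned layer could look like (names and initialization are illustrative, not the exact bm.common.SubjectLayers code):
import torch

class SubjectLayersSketch(torch.nn.Module):
    def __init__(self, in_channels: int, out_channels: int, n_subjects: int):
        super().__init__()
        # one (out_channels, in_channels) matrix per subject
        self.weights = torch.nn.Parameter(
            torch.randn(n_subjects, out_channels, in_channels) / in_channels ** 0.5)

    def forward(self, x: torch.Tensor, subjects: torch.Tensor) -> torch.Tensor:
        # x: (B, C_in, T), subjects: (B,) integer subject ids
        w = self.weights[subjects]               # (B, C_out, C_in)
        return torch.einsum("boc,bct->bot", w, x)

layer = SubjectLayersSketch(270, 270, 27)
x = torch.randn(8, 270, 343)
subjects = torch.zeros(8, dtype=torch.long)
print(layer(x, subjects).shape)  # torch.Size([8, 270, 343])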
subjects = batch.subject_index.to(device=features.device,dtype=torch.int64)
inputs["meg"] = solver.model.subject_layers(inputs["meg"], subjects) # see bm.common.SubjectLayers
solver.model.subject_layers
SubjectLayers(270, 270, 27)
Residual dilated convolutions¶
Residual dilated convolutions are the main component of the brain module. The implementation combines residual skip connections, batch normalization, and GELU activations to extract M/EEG features. One important detail is that the output of this part must match the dimensionality of the speech representations so that the contrastive loss can be computed.
Below are the details of the convolutional layer:
encoded = {}
for name, x in inputs.items():
    encoded[name] = solver.model.encoders[name](x)
inputs = [x[1] for x in sorted(encoded.items())]
x = torch.cat(inputs, dim=1)
x = solver.model.final(x)
assert x.shape[-1] >= length
estimate = x[:, :, :length]
solver.model.encoders
ModuleDict( (meg): ConvSequence( (sequence): ModuleList( (0): Sequential( (0): Conv1d(270, 320, kernel_size=(3,), stride=(1,), padding=(1,)) (1): BatchNorm1d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): GELU(approximate='none') ) (1): Sequential( (0): Conv1d(320, 320, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,)) (1): BatchNorm1d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): GELU(approximate='none') ) (2): Sequential( (0): Conv1d(320, 320, kernel_size=(3,), stride=(1,), padding=(4,), dilation=(4,)) (1): BatchNorm1d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): GELU(approximate='none') ) (3): Sequential( (0): Conv1d(320, 320, kernel_size=(3,), stride=(1,), padding=(8,), dilation=(8,)) (1): BatchNorm1d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): GELU(approximate='none') ) (4): Sequential( (0): Conv1d(320, 320, kernel_size=(3,), stride=(1,), padding=(16,), dilation=(16,)) (1): BatchNorm1d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): GELU(approximate='none') ) (5): Sequential( (0): Conv1d(320, 320, kernel_size=(3,), stride=(1,), padding=(1,)) (1): BatchNorm1d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): GELU(approximate='none') ) (6): Sequential( (0): Conv1d(320, 320, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,)) (1): BatchNorm1d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): GELU(approximate='none') ) (7): Sequential( (0): Conv1d(320, 320, kernel_size=(3,), stride=(1,), padding=(4,), dilation=(4,)) (1): BatchNorm1d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): GELU(approximate='none') ) (8): Sequential( (0): Conv1d(320, 320, kernel_size=(3,), stride=(1,), padding=(8,), dilation=(8,)) (1): BatchNorm1d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): GELU(approximate='none') ) (9): Sequential( (0): Conv1d(320, 320, kernel_size=(3,), stride=(1,), padding=(16,), dilation=(16,)) (1): BatchNorm1d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): GELU(approximate='none') ) ) (glus): ModuleList( (0): None (1): Sequential( (0): Conv1d(320, 640, kernel_size=(3,), stride=(1,), padding=(1,)) (1): GLU(dim=1) ) (2): None (3): Sequential( (0): Conv1d(320, 640, kernel_size=(3,), stride=(1,), padding=(1,)) (1): GLU(dim=1) ) (4): None (5): Sequential( (0): Conv1d(320, 640, kernel_size=(3,), stride=(1,), padding=(1,)) (1): GLU(dim=1) ) (6): None (7): Sequential( (0): Conv1d(320, 640, kernel_size=(3,), stride=(1,), padding=(1,)) (1): GLU(dim=1) ) (8): None (9): Sequential( (0): Conv1d(320, 640, kernel_size=(3,), stride=(1,), padding=(1,)) (1): GLU(dim=1) ) ) ) )
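To make the printout above more readable, here is a hedged sketch of a single residual dilated block (the exact wiring, including how the GLU branches alternate with plain blocks, lives in bm's ConvSequence):
import torch

class DilatedBlockSketch(torch.nn.Module):
    def __init__(self, channels: int = 320, dilation: int = 2):
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv1d(channels, channels, kernel_size=3,
                            padding=dilation, dilation=dilation),
            torch.nn.BatchNorm1d(channels),
            torch.nn.GELU())
        self.glu = torch.nn.Sequential(
            torch.nn.Conv1d(channels, 2 * channels, kernel_size=3, padding=1),
            torch.nn.GLU(dim=1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # padding matches the dilation, so the time dimension is preserved
        return x + self.glu(self.conv(x))

x = torch.randn(8, 320, 343)
print(DilatedBlockSketch()(x).shape)  # torch.Size([8, 320, 343])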
pred, trues, features_mask, reject_mask = estimate, output, features_mask, reject_mask
CLIP loss¶
After obtaining representations of the speech and the corresponding brain signals, we can establish a mapping between them. A classical approach is regression. However, the authors speculate that regression may be an ineffective loss because it deviates from the actual goal, i.e. maximally discriminating between different speech segments. A regression loss also implicitly assumes that all dimensions of, say, a Mel spectrogram are equally important and similarly scaled: a mean-squared-error objective pushes the model to be equally good at predicting low and high frequencies, even though (1) some frequencies, such as very low ones, may be irrelevant to speech, and (2) some frequencies vary far less than others.
In the previous steps, extreme outliers have been removed and the normalization carefully designed. Finally, the authors chose a contrastive loss as the training objective, which naturally encourages the model to focus on the informative dimensions and scale them appropriately.
Let $X$ be a brain recording segment and $Y\in \mathbb{R}^{F\times T}$ the latent representation of its corresponding sound (a.k.a "positive sample"). We sample $N - 1 $ negative samples $\bar{Y}_{j\in \{1, \ldots, N-1\}}$ over our dataset and we add the positive sample as $\bar{Y}_N = Y$.
We want our model to predict the probabilities $ \forall j \in\{1, \ldots, N\}, p_j = \mathbb{P}{[\bar{Y_j} = Y]}. $ We thus train a model $f_{\mathrm{clip}}$ mapping the brain activity $X$ to a latent representation $Z = f_{\mathrm{clip}}(X) \in \mathbb{R}^{F\times T}$. The estimated probability can then be approximated by the dot product of $Z$ and the candidate speech latent representations $Y_j$, followed by a softmax:
$$ \begin{equation} \hat{p}_{j} = \frac{\mathrm{e}^{\langle Z,\bar{Y}_j \rangle }}{\sum_{j'=1}^N \mathrm{e}^{ \langle Z,\bar{Y}_{j'} \rangle }}, \end{equation} $$
with $\langle \cdot, \cdot \rangle$ the inner product over both dimensions of $Z$ and $\bar{Y}_j$.
During inference, a test sample is scored against the collected negative samples, and $\hat{p}_{N}$ denotes the probability assigned to the true (positive) segment. If the positive segment receives the highest probability among the $N$ candidates, the sample counts toward Top-1 accuracy; if it is ranked within the ten highest, it counts toward Top-10 accuracy.
We train $f_{\mathrm{clip}}$ with a cross-entropy between $p_{j}$ and $\hat{p}_{j}$. Note that for a large enough dataset, we can neglect the probability of sampling the same segment twice, so that $p_j = 1_{j = N}$ and the cross-entropy simplifies to
$$ \begin{equation} L_\mathrm{CLIP}(p, \hat{p}) = -\log(\hat{p}_{N}) = -\langle Z,Y\rangle + \log\Big(\sum_{j'=1}^{N}\mathrm{e}^{\langle Z, \bar{Y}_{j'}\rangle} \Big). \end{equation} $$
At training time, the other elements of the batch serve as negative samples; at inference time, the negative samples correspond to all segments of the test set except the positive one.
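Before calling the repository's ClipLoss, here is a hedged, self-contained sketch of the probability computation described above (the actual get_probabilities may differ in details such as normalization):
import torch

def clip_probabilities(z: torch.Tensor, candidates: torch.Tensor) -> torch.Tensor:
    # z: (B, F, T) predicted brain latents; candidates: (N, F, T) speech latents.
    # The inner product is taken over both the feature and time dimensions,
    # then turned into a softmax over the N candidates.
    scores = torch.einsum("bft,nft->bn", z, candidates)
    return torch.softmax(scores, dim=1)

z = torch.randn(8, 1024, 343)
y = torch.randn(8, 1024, 343)
print(clip_probabilities(z, y).shape)  # torch.Size([8, 8]), as in the cell below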
clip = solver.loss
candidates = trues.cuda() # Setup negatives
probs = clip.get_probabilities(pred.cuda(), candidates).cpu()
candidates.shape, probs.shape
(torch.Size([8, 1024, 343]), torch.Size([8, 8]))
Result¶
Finally, we can iterate through the entire test set to obtain the overall statistics. As shown below, we successfully reproduce a Top-10 accuracy of 70.60 on the gwilliams2022 dataset.
evaluator = Evaluator("97d170e1", shuffle_test_data=False)
evaluator.solver.args.num_workers = 20
# Load test data
evaluator.load_test_data(
n_recordings=None, # conf.n_recordings None
batch_size=1000, # conf.load_batch_size 1000
test_study="gwilliams2022")
[11-14 22:39:10][scripts.run_eval_probs][INFO] - Loading solver 97d170e1 [11-14 22:39:10][bm.play][INFO] - Loading solver from XP 97d170e1. Overrides used: ['dset.selections=[gwilliams2022]'] [11-14 22:39:10][bm._env][WARNING] - Hostname gpu226.corp.yodao.com not defined in /conf/study_paths/study_paths.yaml. Using default paths.
[11-14 22:39:10][bm.dataset][WARNING] - Requested 1000 recordings but only found 196 [11-14 22:39:21][bm.dataset][INFO] - Loading Subjects | 39/196 | 6.21 it/sec [11-14 22:39:27][bm.dataset][INFO] - Loading Subjects | 78/196 | 6.09 it/sec [11-14 22:39:34][bm.dataset][INFO] - Loading Subjects | 117/196 | 6.05 it/sec [11-14 22:39:41][bm.dataset][INFO] - Loading Subjects | 156/196 | 5.83 it/sec [11-14 22:39:48][bm.dataset][INFO] - Loading Subjects | 195/196 | 5.84 it/sec [11-14 22:39:48][bm.dataset][INFO] - # Examples (train | valid | test): 203152 | 35156 | 69972 [11-14 22:39:48][bm.train][INFO] - Model hash: ed97b0fbdffe06faf696d2eecaacb57b143fd68d /ssd3/other/penglinkai01/miniconda3/envs/bm/lib/python3.8/site-packages/flashy/loggers/tensorboard.py:47: UserWarning: tensorboard package was not found: use pip install tensorboard warnings.warn("tensorboard package was not found: use pip install tensorboard") [11-14 22:39:48][scripts.run_eval_probs][INFO] - Extracting test data
[11-14 22:40:33][scripts.run_eval_probs][INFO] - extract | 3/70 | 10.8 sec/it [11-14 22:40:41][scripts.run_eval_probs][INFO] - extract | 6/70 | 0.14 it/sec [11-14 22:40:49][scripts.run_eval_probs][INFO] - extract | 9/70 | 0.17 it/sec [11-14 22:40:58][scripts.run_eval_probs][INFO] - extract | 12/70 | 0.19 it/sec [11-14 22:41:07][scripts.run_eval_probs][INFO] - extract | 15/70 | 0.21 it/sec [11-14 22:41:19][scripts.run_eval_probs][INFO] - extract | 18/70 | 0.21 it/sec [11-14 22:41:30][scripts.run_eval_probs][INFO] - extract | 21/70 | 0.22 it/sec [11-14 22:41:39][scripts.run_eval_probs][INFO] - extract | 24/70 | 0.23 it/sec [11-14 22:41:48][scripts.run_eval_probs][INFO] - extract | 27/70 | 0.24 it/sec [11-14 22:41:57][scripts.run_eval_probs][INFO] - extract | 30/70 | 0.24 it/sec [11-14 22:42:07][scripts.run_eval_probs][INFO] - extract | 33/70 | 0.25 it/sec [11-14 22:42:16][scripts.run_eval_probs][INFO] - extract | 36/70 | 0.25 it/sec [11-14 22:42:26][scripts.run_eval_probs][INFO] - extract | 39/70 | 0.26 it/sec [11-14 22:42:36][scripts.run_eval_probs][INFO] - extract | 42/70 | 0.26 it/sec [11-14 22:42:45][scripts.run_eval_probs][INFO] - extract | 45/70 | 0.26 it/sec [11-14 22:42:54][scripts.run_eval_probs][INFO] - extract | 48/70 | 0.27 it/sec [11-14 22:43:04][scripts.run_eval_probs][INFO] - extract | 51/70 | 0.27 it/sec [11-14 22:43:14][scripts.run_eval_probs][INFO] - extract | 54/70 | 0.27 it/sec [11-14 22:43:24][scripts.run_eval_probs][INFO] - extract | 57/70 | 0.27 it/sec [11-14 22:43:34][scripts.run_eval_probs][INFO] - extract | 60/70 | 0.27 it/sec [11-14 22:43:44][scripts.run_eval_probs][INFO] - extract | 63/70 | 0.27 it/sec [11-14 22:43:54][scripts.run_eval_probs][INFO] - extract | 66/70 | 0.27 it/sec [11-14 22:44:04][scripts.run_eval_probs][INFO] - extract | 69/70 | 0.28 it/sec
preds, trues = evaluator.preds, evaluator.trues
clip = evaluator.solver.loss
# Setup negatives
candidates = trues.cuda()
# Loop over samples
loader = DataLoader(TensorDataset(preds, torch.arange(0, len(preds)),),batch_size=8)
probs = torch.zeros(len(preds), len(trues))
lp = flashy.logging.LogProgressBar(logger, loader, updates=20, name='probs')
for preds_, idx_ in lp:
    # Compute probabilities
    probs_ = clip.get_probabilities(preds_.cuda(), candidates).cpu()
    # Update
    probs[idx_] = probs_
[11-14 22:44:32][__main__][INFO] - probs | 437/8747 | 60.30 it/sec [11-14 22:44:39][__main__][INFO] - probs | 874/8747 | 61.85 it/sec [11-14 22:44:45][__main__][INFO] - probs | 1311/8747 | 62.29 it/sec [11-14 22:44:54][__main__][INFO] - probs | 1748/8747 | 58.44 it/sec [11-14 22:45:12][__main__][INFO] - probs | 2185/8747 | 45.84 it/sec [11-14 22:45:45][__main__][INFO] - probs | 2622/8747 | 32.59 it/sec [11-14 22:46:06][__main__][INFO] - probs | 3059/8747 | 30.19 it/sec [11-14 22:46:24][__main__][INFO] - probs | 3496/8747 | 29.15 it/sec [11-14 22:46:54][__main__][INFO] - probs | 3933/8747 | 26.34 it/sec [11-14 22:47:10][__main__][INFO] - probs | 4370/8747 | 26.31 it/sec [11-14 22:47:36][__main__][INFO] - probs | 4807/8747 | 25.09 it/sec [11-14 22:48:16][__main__][INFO] - probs | 5244/8747 | 22.61 it/sec [11-14 22:48:55][__main__][INFO] - probs | 5681/8747 | 21.00 it/sec [11-14 22:49:18][__main__][INFO] - probs | 6118/8747 | 20.86 it/sec [11-14 22:49:55][__main__][INFO] - probs | 6555/8747 | 19.85 it/sec [11-14 22:50:23][__main__][INFO] - probs | 6992/8747 | 19.48 it/sec [11-14 22:50:58][__main__][INFO] - probs | 7429/8747 | 18.87 it/sec [11-14 22:51:20][__main__][INFO] - probs | 7866/8747 | 18.94 it/sec [11-14 22:51:37][__main__][INFO] - probs | 8303/8747 | 19.19 it/sec [11-14 22:52:14][__main__][INFO] - probs | 8740/8747 | 18.61 it/sec
# probs, target_labels, vocab_labels
# probs_segment, segment_hashes, vocab_segment
target_labels = evaluator.metadata["segment_hashes"]
vocab_labels = evaluator.trues_segment_hashes
for topk in (1, 5, 10):
    # Extract topk indices
    idx = probs.topk(topk, dim=1).indices
    # Get the corresponding topk labels
    whs = vocab_labels[idx.view(-1)].reshape(idx.shape)
    # 1 if the labels matches with the targets
    correct = ((whs == target_labels[:, None]).any(1)).float()
    # Average across samples
    acc = correct.mean()
    logger.info("Top-%d acc: %.2f", topk, 100 * acc)
[11-14 22:52:14][__main__][INFO] - Top-1 acc: 41.16 [11-14 22:52:14][__main__][INFO] - Top-5 acc: 62.52 [11-14 22:52:14][__main__][INFO] - Top-10 acc: 70.60
Discussion¶
From the above experiments, the Top-10 accuracy is about 70%, which is an exciting result and indicates that many of the proposed methods and processing steps have a positive impact. However, the following points show that considerable work remains:
- Generalization: the subject embedding is a promising attempt, suggesting that future decoding systems may not require complex adaptation for each individual. However, the differences across datasets indicate that this problem is still challenging.
- Supervised EEG segmentation: the EEG signals are segmented in a "supervised" manner using speech annotations; in other words, we already know that "this part contains speech" and only then classify it.
- Precision: each classification unit is a 3 s segment, and the pretrained speech features are temporally compressed, making it difficult to decode continuous speech at fine temporal resolution.
Language models, especially large language models, may help alleviate the low temporal resolution to some extent. As mentioned above, Tang et al. leverage GPT and a creative beam search method to enhance decoding performance with contextual information.
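As a closing illustration, here is a hedged sketch of what such LM-guided beam search could look like. GPT-2 merely stands in for a stronger language model, and similarity_per_step is a hypothetical stand-in for scores derived from the decoded brain representations (e.g. the similarity between candidate word embeddings and the embedding predicted from M/EEG).
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
lm = GPT2LMHeadModel.from_pretrained("gpt2").eval()

def lm_logprob(text: str) -> float:
    ids = tokenizer(text, return_tensors="pt").input_ids
    with torch.no_grad():
        out = lm(ids, labels=ids)
    return -out.loss.item() * ids.shape[1]  # approximate total log-likelihood

def beam_search(prefixes, similarity_per_step, beam_size=4, alpha=1.0):
    """similarity_per_step: one {word: brain_similarity} dict per decoded step."""
    beams = [(p, 0.0) for p in prefixes]
    for sims in similarity_per_step:
        expanded = []
        for text, _ in beams:
            for word, sim in sims.items():
                candidate = f"{text} {word}".strip()
                # combine LM fluency with agreement to the decoded representation
                expanded.append((candidate, lm_logprob(candidate) + alpha * sim))
        beams = sorted(expanded, key=lambda b: b[1], reverse=True)[:beam_size]
    return beams

# e.g. beam_search(["Tara"], [{"stood": 0.9, "sat": 0.4}, {"stock": 0.8, "still": 0.7}])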