TaskAug ECG replication: PTBXLDataset, ECGBinaryClassification, TaskAugResNet#1123
Open
RogeCS wants to merge 3 commits intosunlabuiuc:masterfrom
Open
TaskAug ECG replication: PTBXLDataset, ECGBinaryClassification, TaskAugResNet#1123RogeCS wants to merge 3 commits intosunlabuiuc:masterfrom
RogeCS wants to merge 3 commits intosunlabuiuc:masterfrom
Conversation
…to 500→250 Hz (Raghu et al. 2022) (#4) * fix bugs * fix bugs * fix bugs * more enhancements * more enhancements * ablation --------- Co-authored-by: Rogelio Medina <rogelio.medina@c3.ai>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Contributors
Type of Contribution
Full pipeline reproducibility — Dataset + Task + Model
Paper
Raghu, A., Raghu, M., Kornblith, S., Duvenaud, D., & Ghahramani, Z. (2022).
Data Augmentation for Electrocardiograms.
Conference on Health, Inference, and Learning (CHIL), PMLR 174.
https://proceedings.mlr.press/v174/raghu22a.html
Description
This PR implements the TaskAug framework from Raghu et al. (2022), which
learns a task-adaptive, differentiable augmentation policy jointly with an ECG
classifier via bi-level optimisation. Three new PyHealth components are
contributed:
1.
PTBXLDataset— Dataset wrapperLoads the PTB-XL 12-lead ECG corpus (21,799 records, 18,869 patients).
Parses
ptbxl_database.csvandscp_statements.csvto assign four binarydiagnostic superclass labels: MI (myocardial infarction), HYP
(hypertrophy), STTC (ST/T-wave change), and CD (conduction
disturbance).
2.
ECGBinaryClassification— Task classLazy-loads WFDB waveform files, applies per-lead z-score normalisation, and
pads/truncates signals to a fixed length (default 2500 samples @ 500 Hz).
Supports all four superclasses via a
task_labelargument.3.
TaskAugResNet— ModelBaseModelsubclass combining:TaskAugPolicy— K-stage differentiable augmentation policy usingGumbel-Softmax over 8 ECG-specific operations (Gaussian noise, magnitude
scale, time mask, baseline wander, temporal warp, temporal displacement,
no-op, and a novel LeadDropout extension) with class-specific learnable
magnitudes (
mag_neg,mag_pos)._ResNet1D— 1-D ResNet-18 backbone adapted for multi-lead ECG(kernel size 7, 12-channel input).
policy_parameters()/backbone_parameters()helpers enabling theinner/outer loop split required for bi-level optimisation.
Extensions beyond the paper
probability proportional to the learned magnitude, simulating electrode
failure in clinical settings.
shared_magnitudesflag: ablation option that forces identicalaugmentation strength across classes, directly testing the paper's
class-asymmetric magnitude hypothesis.
Example script
examples/ptbxl_ecg_classification_taskaug_resnet.pyprovides:BiLevelTrainer(first-order DARTS approximation of the outer loop)TaskAug K=1/K=2, frozen policy, and shared magnitudes
--syntheticflag for dependency-free testing with no download requiredFile Guide
pyhealth/datasets/ptbxl.pyPTBXLDataset— metadata parsing, label generationpyhealth/tasks/ecg_classification.pyECGBinaryClassification— waveform loading, normalisation, task schemapyhealth/models/taskaug_resnet.pyTaskAugPolicy,_ResNet1D,TaskAugResNet,_lead_dropoutexamples/ptbxl_ecg_classification_taskaug_resnet.pyBiLevelTrainer, ablation study, CLITaskAug ECG DLH.ipynbdocs/api/datasets/pyhealth.datasets.PTBXLDataset.rstdocs/api/models/pyhealth.models.TaskAugResNet.rstTesting
All components can be exercised without downloading PTB-XL:
python examples/ptbxl_ecg_classification_taskaug_resnet.py \ --synthetic --mode ablation --epochs 5