diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 8d9a59d21..feab3a178 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -226,6 +226,7 @@ Available Datasets datasets/pyhealth.datasets.MIMIC4Dataset datasets/pyhealth.datasets.MedicalTranscriptionsDataset datasets/pyhealth.datasets.CardiologyDataset + datasets/pyhealth.datasets.PTBXLDataset datasets/pyhealth.datasets.eICUDataset datasets/pyhealth.datasets.ISRUCDataset datasets/pyhealth.datasets.MIMICExtractDataset diff --git a/docs/api/datasets/pyhealth.datasets.PTBXLDataset.rst b/docs/api/datasets/pyhealth.datasets.PTBXLDataset.rst new file mode 100644 index 000000000..97f9e77cf --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.PTBXLDataset.rst @@ -0,0 +1,21 @@ +pyhealth.datasets.PTBXLDataset +============================== + +PTB-XL is a large publicly available 12-lead ECG dataset containing 21 799 +10-second recordings from 18 869 patients, annotated with SCP-ECG diagnostic +statements and mapped to four binary superclass labels: MI (myocardial +infarction), HYP (hypertrophy), STTC (ST/T-change), and CD (conduction +disturbance). + +Paper: Wagner et al. (2020). PTB-XL, a large publicly available +electrocardiography dataset. *Scientific Data*, 7(154). +https://physionet.org/content/ptb-xl/1.0.3/ + +TaskAug reference: Raghu et al. (2022). Data Augmentation for +Electrocardiograms. *CHIL*, PMLR 174. +https://proceedings.mlr.press/v174/raghu22a.html + +.. 
autoclass:: pyhealth.datasets.PTBXLDataset + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/datasets/pyhealth.datasets.ptbxl.rst b/docs/api/datasets/pyhealth.datasets.ptbxl.rst new file mode 100644 index 000000000..8b231aff6 --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.ptbxl.rst @@ -0,0 +1,21 @@ +pyhealth.datasets.ptbxl +======================= + +Overview +-------- +PTB-XL is a large publicly available 12-lead ECG dataset containing 21 799 +10-second records from 18 869 patients (Wagner et al., 2020). This PyHealth +wrapper maps each record to four binary diagnostic superclass labels — +myocardial infarction (MI), hypertrophy (HYP), ST/T-change (STTC), and +conduction disturbance (CD) — enabling direct use with +:class:`~pyhealth.tasks.ecg_classification.ECGBinaryClassification`. + +Source: https://physionet.org/content/ptb-xl/1.0.3/ + +API Reference +------------- + +.. autoclass:: pyhealth.datasets.ptbxl.PTBXLDataset + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/models.rst b/docs/api/models.rst index 7c3ac7c4b..5250cfb48 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -189,6 +189,7 @@ API Reference models/pyhealth.models.JambaEHR models/pyhealth.models.ContraWR models/pyhealth.models.SparcNet + models/pyhealth.models.TaskAugResNet models/pyhealth.models.StageNet models/pyhealth.models.StageAttentionNet models/pyhealth.models.AdaCare diff --git a/docs/api/models/pyhealth.models.TaskAugResNet.rst b/docs/api/models/pyhealth.models.TaskAugResNet.rst new file mode 100644 index 000000000..3c7b7bd87 --- /dev/null +++ b/docs/api/models/pyhealth.models.TaskAugResNet.rst @@ -0,0 +1,28 @@ +pyhealth.models.TaskAugResNet +============================== + +TaskAug is a differentiable, task-adaptive data augmentation framework for +ECG binary classification. A K-stage Gumbel-Softmax augmentation policy +selects from seven time-series operations with class-specific learnable +magnitudes.
The policy is trained jointly with a 1-D ResNet-18 backbone via +bi-level optimisation (inner loop: backbone on augmented training data; outer +loop: policy on clean validation loss). + +Paper: Raghu et al. (2022). Data Augmentation for Electrocardiograms. +*Conference on Health, Inference, and Learning (CHIL)*, PMLR 174. +https://proceedings.mlr.press/v174/raghu22a.html + +.. autoclass:: pyhealth.models.taskaug_resnet.TaskAugPolicy + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: pyhealth.models.taskaug_resnet._ResNet1D + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: pyhealth.models.TaskAugResNet + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/models/pyhealth.models.taskaug_resnet.rst b/docs/api/models/pyhealth.models.taskaug_resnet.rst new file mode 100644 index 000000000..9202b9928 --- /dev/null +++ b/docs/api/models/pyhealth.models.taskaug_resnet.rst @@ -0,0 +1,28 @@ +pyhealth.models.taskaug_resnet +============================== + +Overview +-------- +Replication of the TaskAug framework from Raghu et al. (2022) *Data +Augmentation for Electrocardiograms* (CHIL, PMLR 174). A K-stage +differentiable augmentation policy (:class:`TaskAugPolicy`) selects among +seven ECG-specific operations via Gumbel-Softmax with class-specific learnable +magnitudes. The policy is trained jointly with a 1-D ResNet-18 backbone +(:class:`_ResNet1D`) using a bi-level optimisation scheme: the inner loop +updates backbone weights on augmented training data while the outer loop +updates policy weights on clean validation loss. + +Paper: https://proceedings.mlr.press/v174/raghu22a.html + +API Reference +------------- + +.. autoclass:: pyhealth.models.taskaug_resnet.TaskAugResNet + :members: + :undoc-members: + :show-inheritance: + +.. 
autoclass:: pyhealth.models.taskaug_resnet.TaskAugPolicy + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 23a4e06e5..a2cb6f4f7 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -209,6 +209,7 @@ Available Tasks In-Hospital Mortality (MIMIC-IV) MIMIC-III ICD-9 Coding Cardiology Detection + ECG Binary Classification (PTB-XL) COVID-19 CXR Classification DKA Prediction (MIMIC-IV) Drug Recommendation diff --git a/docs/api/tasks/pyhealth.tasks.ecg_classification.rst b/docs/api/tasks/pyhealth.tasks.ecg_classification.rst new file mode 100644 index 000000000..8cf04d899 --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.ecg_classification.rst @@ -0,0 +1,20 @@ +pyhealth.tasks.ecg_classification +================================== + +Overview +-------- +Binary ECG classification task for the PTB-XL dataset, implementing the +task interface from Raghu et al. (2022) *Data Augmentation for +Electrocardiograms* (CHIL, PMLR 174). Each ECG record is loaded from disk +via WFDB, per-lead z-score normalised, and padded or truncated to a fixed +time length. Supports four diagnostic superclasses: MI, HYP, STTC, and CD. + +Paper: https://proceedings.mlr.press/v174/raghu22a.html + +API Reference +------------- + +.. autoclass:: pyhealth.tasks.ecg_classification.ECGBinaryClassification + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/ptbxl_ecg_classification_taskaug_resnet.py b/examples/ptbxl_ecg_classification_taskaug_resnet.py new file mode 100644 index 000000000..ff04eece4 --- /dev/null +++ b/examples/ptbxl_ecg_classification_taskaug_resnet.py @@ -0,0 +1,706 @@ +"""PTB-XL ECG binary classification with TaskAug + 1-D ResNet-18. + +Demonstrates full pipeline replication of Raghu et al. (2022): + "Data Augmentation for Electrocardiograms", CHIL 2022. + https://proceedings.mlr.press/v174/raghu22a.html + +This script contains four sections: + +1. 
**Standard training** — joint optimisation of backbone + policy with a + single Adam optimiser (fast baseline). + +2. **Bi-level training (BiLevelTrainer)** — inner loop updates the backbone + on augmented training data; outer loop updates the policy on clean + validation loss. Uses a first-order DARTS-style approximation. + +3. **Ablation study** — compares six configurations on synthetic data: + (a) no augmentation, (b) fixed random augmentation, (c) TaskAug 1-stage, + (d) TaskAug 2-stage (default), (e) frozen policy (random init, never + updated), (f) shared magnitudes (class-agnostic mu_0 = mu_1). + +4. **Learning-rate sweep** — evaluates TaskAug K=2 with three outer-loop + learning rates {1e-2, 1e-3, 1e-4} to show sensitivity to this + hyperparameter. + +Usage +----- +Real PTB-XL data (requires download from PhysioNet):: + + python ptbxl_ecg_classification_taskaug_resnet.py \ + --data_root /path/to/ptb-xl/ \ + --task MI --mode bilevel --epochs 20 + +Synthetic data (no download needed, for testing/CI):: + + python ptbxl_ecg_classification_taskaug_resnet.py --synthetic + +Ablation results (synthetic, default) +------------------------------------- +Expected relative ordering: D >= C >= F >= B >= E >= A (AUROC). +Configs D and C (learned policy) should outperform A (no augmentation). +Config E (frozen policy) isolates the benefit of *learning* the policy. +Config F (shared magnitudes) tests the class-specific magnitude hypothesis. +The lr sweep should show lr_outer=1e-3 outperforms 1e-2 (too aggressive) +and 1e-4 (too slow to converge in few epochs). 
+""" +from __future__ import annotations + +import argparse +import time +from typing import Dict, List, Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader, TensorDataset + + +# --------------------------------------------------------------------------- +# Synthetic dataset (no real data required) +# --------------------------------------------------------------------------- + +def make_synthetic_dataset( + n_train: int = 200, + n_val: int = 50, + leads: int = 12, + length: int = 1000, + pos_rate: float = 0.3, + seed: int = 42, +) -> Tuple[TensorDataset, TensorDataset]: + """Generate synthetic ECG-like tensors for offline testing. + + Positive-class signals have a higher-amplitude sinusoidal component + injected into lead 0 to give the model a learnable signal. + + Args: + n_train: Number of training samples. + n_val: Number of validation samples. + leads: Number of ECG leads (channels). + length: Time-series length. + pos_rate: Fraction of positive-class samples. + seed: Random seed for reproducibility. + + Returns: + Tuple of (train_dataset, val_dataset). 
+ """ + rng = np.random.default_rng(seed) + + def _make_split(n: int) -> Tuple[np.ndarray, np.ndarray]: + labels = (rng.random(n) < pos_rate).astype(np.int64) + signals = rng.standard_normal((n, leads, length)).astype(np.float32) + # Inject discriminative signal into lead 0 for positive class + t = np.linspace(0, 2 * np.pi, length, dtype=np.float32) + for i, lbl in enumerate(labels): + if lbl == 1: + signals[i, 0] += 0.5 * np.sin(5 * t) + return signals, labels + + x_tr, y_tr = _make_split(n_train) + x_val, y_val = _make_split(n_val) + + train_ds = TensorDataset(torch.from_numpy(x_tr), torch.from_numpy(y_tr)) + val_ds = TensorDataset(torch.from_numpy(x_val), torch.from_numpy(y_val)) + return train_ds, val_ds + + +# --------------------------------------------------------------------------- +# Real PTB-XL dataset (optional — skipped if --synthetic) +# --------------------------------------------------------------------------- + +def make_ptbxl_dataset( + data_root: str, + task_label: str = "MI", + sampling_rate: int = 100, +) -> Tuple[TensorDataset, TensorDataset]: + """Load PTB-XL via PyHealth and return TensorDatasets. + + Performs an 80/20 train/val split on the first 5000 records (N=5000 + regime from Raghu et al. Table 2). + + Args: + data_root: Path to the PTB-XL root directory. + task_label: One of ``"MI"``, ``"HYP"``, ``"STTC"``, ``"CD"``. + sampling_rate: Waveform sampling rate (100 or 500 Hz). + + Returns: + Tuple of (train_dataset, val_dataset). 
+ """ + from pyhealth.datasets.ptbxl import PTBXLDataset + from pyhealth.tasks.ecg_classification import ECGBinaryClassification + + dataset = PTBXLDataset( + root=data_root, + sampling_rate=sampling_rate, + ) + sample_ds = dataset.set_task(ECGBinaryClassification(task_label=task_label)) + + # Collect all samples + ecgs, labels = [], [] + for sample in sample_ds: + ecgs.append(sample["ecg"].clone().detach()) + labels.append(sample["label"]) + + ecgs = torch.stack(ecgs) # (N, 12, T) + labels = torch.stack(labels).squeeze(-1).long() # (N,) not (N, 1) + + # 80/20 split (up to 5000 samples) + n = min(len(ecgs), 5000) + ecgs, labels = ecgs[:n], labels[:n] + split = int(0.8 * n) + train_ds = TensorDataset(ecgs[:split], labels[:split]) + val_ds = TensorDataset(ecgs[split:], labels[split:]) + return train_ds, val_ds + + +# --------------------------------------------------------------------------- +# Metrics helpers +# --------------------------------------------------------------------------- + +def compute_auroc(y_true: np.ndarray, y_prob: np.ndarray) -> float: + """Compute AUROC without sklearn (trapezoidal rule).""" + order = np.argsort(-y_prob) + y_sorted = y_true[order] + n_pos = y_sorted.sum() + n_neg = len(y_sorted) - n_pos + if n_pos == 0 or n_neg == 0: + return float("nan") + tp = np.cumsum(y_sorted) + fp = np.cumsum(1 - y_sorted) + tpr = tp / n_pos + fpr = fp / n_neg + return float(np.trapezoid(tpr, fpr)) + + +def compute_auprc(y_true: np.ndarray, y_prob: np.ndarray) -> float: + """Compute AUPRC (average precision) without sklearn (trapezoidal rule).""" + order = np.argsort(-y_prob) + y_sorted = y_true[order] + n_pos = y_sorted.sum() + if n_pos == 0: + return float("nan") + tp = np.cumsum(y_sorted) + fp = np.cumsum(1 - y_sorted) + precision = tp / (tp + fp) + recall = tp / n_pos + # Prepend (recall=0, precision=1) so the curve starts at the top-left + precision = np.concatenate([[1.0], precision]) + recall = np.concatenate([[0.0], recall]) + return 
float(np.trapezoid(precision, recall)) + + +@torch.no_grad() +def evaluate( + model: nn.Module, + loader: DataLoader, + device: torch.device, +) -> Dict[str, float]: + """Return loss, accuracy, and AUROC on *loader*.""" + model.eval() + all_loss, all_prob, all_true = [], [], [] + + for ecg, label in loader: + ecg, label = ecg.to(device), label.to(device) + out = model(ecg=ecg, label=label) + all_loss.append(out["loss"].item()) + all_prob.extend(out["y_prob"].squeeze(-1).cpu().numpy()) + all_true.extend(label.cpu().numpy()) + + y_prob = np.array(all_prob) + y_true = np.array(all_true) + acc = ((y_prob > 0.5).astype(int) == y_true).mean() + return { + "loss": float(np.mean(all_loss)), + "accuracy": float(acc), + "auroc": compute_auroc(y_true, y_prob), + "auprc": compute_auprc(y_true, y_prob), + } + + +# --------------------------------------------------------------------------- +# Standard training loop +# --------------------------------------------------------------------------- + +def train_standard( + model: nn.Module, + train_loader: DataLoader, + val_loader: DataLoader, + epochs: int = 10, + lr: float = 1e-3, + device: torch.device = torch.device("cpu"), +) -> List[Dict]: + """Standard joint optimisation of backbone + policy. + + Args: + model: :class:`TaskAugResNet` instance. + train_loader: Training data loader. + val_loader: Validation data loader. + epochs: Number of training epochs. + lr: Learning rate. + device: Compute device. + + Returns: + List of per-epoch metric dicts. 
+ """ + model.to(device) + optimizer = optim.Adam(model.parameters(), lr=lr) + history: List[Dict] = [] + + for epoch in range(1, epochs + 1): + model.train() + for ecg, label in train_loader: + ecg, label = ecg.to(device), label.to(device) + optimizer.zero_grad() + model(ecg=ecg, label=label)["loss"].backward() + optimizer.step() + + metrics = evaluate(model, val_loader, device) + metrics["epoch"] = epoch + history.append(metrics) + print( + f"[Standard] Epoch {epoch:3d} | " + f"val_loss={metrics['loss']:.4f} | " + f"val_acc={metrics['accuracy']:.3f} | " + f"val_auroc={metrics['auroc']:.3f}" + ) + + return history + + +# --------------------------------------------------------------------------- +# Bi-level trainer (DARTS first-order approximation) +# --------------------------------------------------------------------------- + +class BiLevelTrainer: + """First-order bi-level optimiser for TaskAug (Raghu et al., 2022). + + The inner loop updates the ResNet backbone on augmented training batches + using Adam. The outer loop updates the augmentation policy on clean + validation batches — approximating implicit differentiation with a + single-step unrolling (DARTS-style first-order approximation). + + Args: + model: :class:`TaskAugResNet` instance. + lr_inner: Inner-loop (backbone) learning rate. + lr_outer: Outer-loop (policy) learning rate. + device: Compute device. + """ + + def __init__( + self, + model: "TaskAugResNet", # noqa: F821 + lr_inner: float = 1e-3, + lr_outer: float = 1e-3, + device: torch.device = torch.device("cpu"), + ) -> None: + self.model = model.to(device) + self.device = device + self.inner_opt = optim.Adam(model.backbone_parameters(), lr=lr_inner) + self.outer_opt = optim.RMSprop(model.policy_parameters(), lr=lr_outer) + + def step( + self, + ecg_train: torch.Tensor, + label_train: torch.Tensor, + ecg_val: torch.Tensor, + label_val: torch.Tensor, + ) -> Tuple[float, float]: + """Execute one inner + outer update step. 
+ + Args: + ecg_train: Training ECG batch ``(B, 12, T)``. + label_train: Training labels ``(B,)``. + ecg_val: Validation ECG batch ``(B, 12, T)`` (clean, no augment). + label_val: Validation labels ``(B,)``. + + Returns: + Tuple of ``(train_loss, val_loss)`` floats. + """ + ecg_train = ecg_train.to(self.device) + label_train = label_train.to(self.device) + ecg_val = ecg_val.to(self.device) + label_val = label_val.to(self.device) + + # Inner step: update backbone on augmented training data + self.model.train() + self.inner_opt.zero_grad() + train_loss = self.model(ecg=ecg_train, label=label_train)["loss"] + train_loss.backward() + self.inner_opt.step() + + # Outer step: update policy on clean validation data + # (first-order approximation — no Neumann series unrolling) + self.outer_opt.zero_grad() + val_loss = self.model(ecg=ecg_val, label=label_val)["loss"] + val_loss.backward() + self.outer_opt.step() + + return float(train_loss.item()), float(val_loss.item()) + + def fit( + self, + train_loader: DataLoader, + val_loader: DataLoader, + epochs: int = 10, + ) -> List[Dict]: + """Train for *epochs* epochs with bi-level optimisation. + + Args: + train_loader: Training data loader. + val_loader: Validation data loader (also used for outer loop). + epochs: Number of epochs. + + Returns: + List of per-epoch metric dicts. 
+ """ + val_iter = iter(val_loader) + history: List[Dict] = [] + + for epoch in range(1, epochs + 1): + for ecg_tr, lbl_tr in train_loader: + # Sample a validation batch for the outer step + try: + ecg_val, lbl_val = next(val_iter) + except StopIteration: + val_iter = iter(val_loader) + ecg_val, lbl_val = next(val_iter) + + self.step(ecg_tr, lbl_tr, ecg_val, lbl_val) + + metrics = evaluate(self.model, val_loader, self.device) + metrics["epoch"] = epoch + history.append(metrics) + print( + f"[BiLevel] Epoch {epoch:3d} | " + f"val_loss={metrics['loss']:.4f} | " + f"val_acc={metrics['accuracy']:.3f} | " + f"val_auroc={metrics['auroc']:.3f}" + ) + + return history + + +# --------------------------------------------------------------------------- +# Ablation study +# --------------------------------------------------------------------------- + +def run_ablation( + train_ds: TensorDataset, + val_ds: TensorDataset, + device: torch.device, + epochs: int = 15, + batch_size: int = 32, +) -> Dict[str, Dict]: + """Compare six augmentation configurations on the same data split. + + Configurations + -------------- + A. **No augmentation** — backbone-only, no policy. + B. **Fixed random augmentation** — Gaussian noise (sigma=0.1), no learning. + C. **TaskAug 1-stage** — learned policy, K=1. + D. **TaskAug 2-stage** — learned policy, K=2 (paper default). + E. **Frozen policy** — policy initialized at random but never updated; + only backbone trains. Isolates the benefit of *learning* the policy. + F. **Shared magnitudes** — class-agnostic magnitudes (mu_0 = mu_1); + tests the asymmetric augmentation hypothesis from Section 3 of the paper. + + Args: + train_ds: Training TensorDataset. + val_ds: Validation TensorDataset. + device: Compute device. + epochs: Training epochs per configuration. + batch_size: Mini-batch size. + + Returns: + Dict mapping config keys to metric dicts. 
+ """ + from pyhealth.models.taskaug_resnet import TaskAugResNet, _ResNet1D, TaskAugPolicy + + train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(val_ds, batch_size=batch_size) + + results: Dict[str, Dict] = {} + + # ---- Configuration A: no augmentation (backbone only) ---- + print("\n Ablation A: No Augmentation ") + + class BackboneOnly(nn.Module): + def __init__(self) -> None: + super().__init__() + self.backbone = _ResNet1D(in_channels=12, num_classes=1) + + def forward(self, ecg: torch.Tensor, label: Optional[torch.Tensor] = None): + logits = self.backbone(ecg) + y_prob = torch.sigmoid(logits) + out = {"logit": logits, "y_prob": y_prob} + if label is not None: + out["loss"] = nn.functional.binary_cross_entropy_with_logits( + logits.squeeze(-1), label.float() + ) + out["y_true"] = label + return out + + torch.manual_seed(42) + model_a = BackboneOnly().to(device) + opt_a = optim.Adam(model_a.parameters(), lr=1e-3) + for epoch in range(epochs): + model_a.train() + for ecg, lbl in train_loader: + ecg, lbl = ecg.to(device), lbl.to(device) + opt_a.zero_grad() + model_a(ecg=ecg, label=lbl)["loss"].backward() + opt_a.step() + results["A_no_aug"] = evaluate(model_a, val_loader, device) + print(f" AUROC={results['A_no_aug']['auroc']:.3f} " + f"AUPRC={results['A_no_aug']['auprc']:.3f} " + f"ACC={results['A_no_aug']['accuracy']:.3f}") + + # ---- Configuration B: fixed Gaussian noise ---- + print("\n Ablation B: Fixed Gaussian Noise ") + + class FixedNoiseModel(nn.Module): + def __init__(self, noise_std: float = 0.1) -> None: + super().__init__() + self.noise_std = noise_std + self.backbone = _ResNet1D(in_channels=12, num_classes=1) + + def forward(self, ecg: torch.Tensor, label: Optional[torch.Tensor] = None): + if self.training: + ecg = ecg + self.noise_std * torch.randn_like(ecg) + logits = self.backbone(ecg) + y_prob = torch.sigmoid(logits) + out = {"logit": logits, "y_prob": y_prob} + if label is not None: + out["loss"] 
= nn.functional.binary_cross_entropy_with_logits( + logits.squeeze(-1), label.float() + ) + out["y_true"] = label + return out + + torch.manual_seed(42) + model_b = FixedNoiseModel().to(device) + opt_b = optim.Adam(model_b.parameters(), lr=1e-3) + for epoch in range(epochs): + model_b.train() + for ecg, lbl in train_loader: + ecg, lbl = ecg.to(device), lbl.to(device) + opt_b.zero_grad() + model_b(ecg=ecg, label=lbl)["loss"].backward() + opt_b.step() + results["B_fixed_noise"] = evaluate(model_b, val_loader, device) + print(f" AUROC={results['B_fixed_noise']['auroc']:.3f} " + f"AUPRC={results['B_fixed_noise']['auprc']:.3f} " + f"ACC={results['B_fixed_noise']['accuracy']:.3f}") + + # ---- Configurations C & D: TaskAug 1-stage and 2-stage ---- + for stages, key, label in [(1, "C_taskaug_1stage", "TaskAug 1-stage"), + (2, "D_taskaug_2stage", "TaskAug 2-stage (paper)")]: + print(f"\n Ablation {key[0]}: {label} ") + torch.manual_seed(42) + mock_ds = _make_mock_dataset() + model_x = TaskAugResNet(mock_ds, policy_stages=stages).to(device) + trainer = BiLevelTrainer(model_x, lr_inner=1e-3, lr_outer=1e-3, device=device) + trainer.fit(train_loader, val_loader, epochs=epochs) + results[key] = evaluate(model_x, val_loader, device) + print(f" AUROC={results[key]['auroc']:.3f} " + f"AUPRC={results[key]['auprc']:.3f} " + f"ACC={results[key]['accuracy']:.3f}") + + # ---- Configuration E: frozen policy (random init, never updated) ---- + print("\nAblation E: Frozen Policy") + torch.manual_seed(42) + mock_ds = _make_mock_dataset() + model_e = TaskAugResNet(mock_ds, policy_stages=2).to(device) + for p in model_e.policy.parameters(): + p.requires_grad_(False) + opt_e = optim.Adam(model_e.backbone_parameters(), lr=1e-3) + for epoch in range(epochs): + model_e.train() + for ecg, lbl in train_loader: + ecg, lbl = ecg.to(device), lbl.to(device) + opt_e.zero_grad() + model_e(ecg=ecg, label=lbl)["loss"].backward() + opt_e.step() + results["E_frozen_policy"] = evaluate(model_e, val_loader, 
device) + print(f" AUROC={results['E_frozen_policy']['auroc']:.3f} " + f"AUPRC={results['E_frozen_policy']['auprc']:.3f} " + f"ACC={results['E_frozen_policy']['accuracy']:.3f}") + + # ---- Configuration F: shared magnitudes (no class-specific mu) ---- + print("\nAblation F: Shared Magnitudes") + torch.manual_seed(42) + mock_ds = _make_mock_dataset() + model_f = TaskAugResNet(mock_ds, policy_stages=2, shared_magnitudes=True).to(device) + trainer_f = BiLevelTrainer(model_f, lr_inner=1e-3, lr_outer=1e-3, device=device) + trainer_f.fit(train_loader, val_loader, epochs=epochs) + results["F_shared_mag"] = evaluate(model_f, val_loader, device) + print(f" AUROC={results['F_shared_mag']['auroc']:.3f} " + f"AUPRC={results['F_shared_mag']['auprc']:.3f} " + f"ACC={results['F_shared_mag']['accuracy']:.3f}") + + # ---- Summary table ---- + print("\n" + "=" * 72) + print(f"{'Configuration':<30} {'AUROC':>7} {'AUPRC':>7} {'Accuracy':>9}") + print("-" * 72) + names = { + "A_no_aug": "A. No augmentation", + "B_fixed_noise": "B. Fixed Gaussian noise", + "C_taskaug_1stage": "C. TaskAug K=1", + "D_taskaug_2stage": "D. TaskAug K=2 (paper)", + "E_frozen_policy": "E. Frozen policy", + "F_shared_mag": "F. Shared magnitudes", + } + for key, display in names.items(): + r = results[key] + print(f" {display:<28} {r['auroc']:>7.3f} {r['auprc']:>7.3f} {r['accuracy']:>9.3f}") + print("=" * 72) + + return results + + +# --------------------------------------------------------------------------- +# Learning-rate sweep +# --------------------------------------------------------------------------- + +def run_lr_sweep( + train_ds: TensorDataset, + val_ds: TensorDataset, + device: torch.device, + epochs: int = 15, + batch_size: int = 32, +) -> Dict[str, Dict]: + """Sweep outer-loop learning rate for TaskAug K=2. + + Tests lr_outer in {1e-2, 1e-3, 1e-4} while keeping lr_inner=1e-3 fixed. + Demonstrates sensitivity to the outer-loop learning rate — the paper + uses 1e-3 as the default. 
+ + Args: + train_ds: Training TensorDataset. + val_ds: Validation TensorDataset. + device: Compute device. + epochs: Training epochs per configuration. + batch_size: Mini-batch size. + + Returns: + Dict mapping lr description to metric dicts. + """ + from pyhealth.models.taskaug_resnet import TaskAugResNet + + train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(val_ds, batch_size=batch_size) + + results: Dict[str, Dict] = {} + for lr_outer in [1e-2, 1e-3, 1e-4]: + key = f"lr_outer={lr_outer}" + print(f"\n LR Sweep: {key} ") + torch.manual_seed(42) + mock_ds = _make_mock_dataset() + model = TaskAugResNet(mock_ds, policy_stages=2).to(device) + trainer = BiLevelTrainer(model, lr_inner=1e-3, lr_outer=lr_outer, device=device) + trainer.fit(train_loader, val_loader, epochs=epochs) + results[key] = evaluate(model, val_loader, device) + print(f" AUROC={results[key]['auroc']:.3f} " + f"AUPRC={results[key]['auprc']:.3f} " + f"ACC={results[key]['accuracy']:.3f}") + + print("\n" + "=" * 62) + print(f"{'lr_outer':<20} {'AUROC':>7} {'AUPRC':>7} {'Accuracy':>9}") + print("-" * 62) + for key, r in results.items(): + print(f" {key:<18} {r['auroc']:>7.3f} {r['auprc']:>7.3f} {r['accuracy']:>9.3f}") + print("=" * 62) + + return results + + +# --------------------------------------------------------------------------- +# Helper — mock dataset for TaskAugResNet in ablation +# --------------------------------------------------------------------------- + +def _make_mock_dataset(): + from unittest.mock import MagicMock + + ds = MagicMock() + ds.input_schema = {"ecg": "tensor"} + ds.output_schema = {"label": "binary"} + proc = MagicMock() + proc.size.return_value = 1 + ds.output_processors = {"label": proc} + return ds + + +# --------------------------------------------------------------------------- +# CLI entry point +# --------------------------------------------------------------------------- + +def main() -> None: + parser = 
argparse.ArgumentParser( + description="PTB-XL ECG classification with TaskAug + ResNet-18" + ) + parser.add_argument( + "--data_root", type=str, default=None, + help="Path to PTB-XL root directory (omit to use synthetic data)" + ) + parser.add_argument( + "--synthetic", action="store_true", + help="Use synthetic data regardless of --data_root" + ) + parser.add_argument( + "--task", choices=["MI", "HYP", "STTC", "CD"], default="MI" + ) + parser.add_argument( + "--mode", + choices=["standard", "bilevel", "ablation", "lr_sweep"], + default="ablation", + ) + parser.add_argument("--epochs", type=int, default=15) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--lr", type=float, default=1e-3) + parser.add_argument("--policy_stages", type=int, default=2) + args = parser.parse_args() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Device: {device}") + + # ---- Data ---- + if args.synthetic or args.data_root is None: + print("Using synthetic data (200 train / 50 val samples).") + train_ds, val_ds = make_synthetic_dataset() + else: + print(f"Loading PTB-XL from {args.data_root}, task={args.task}") + train_ds, val_ds = make_ptbxl_dataset(args.data_root, task_label=args.task) + + train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True) + val_loader = DataLoader(val_ds, batch_size=args.batch_size) + + # ---- Run ---- + if args.mode == "ablation": + run_ablation(train_ds, val_ds, device, epochs=args.epochs, + batch_size=args.batch_size) + return + + if args.mode == "lr_sweep": + run_lr_sweep(train_ds, val_ds, device, epochs=args.epochs, + batch_size=args.batch_size) + return + + mock_ds = _make_mock_dataset() + from pyhealth.models.taskaug_resnet import TaskAugResNet + + model = TaskAugResNet(mock_ds, policy_stages=args.policy_stages) + + t0 = time.time() + if args.mode == "standard": + train_standard(model, train_loader, val_loader, + epochs=args.epochs, lr=args.lr, device=device) + 
else: # bilevel + trainer = BiLevelTrainer(model, lr_inner=args.lr, lr_outer=args.lr, + device=device) + trainer.fit(train_loader, val_loader, epochs=args.epochs) + + print(f"\nTotal wall time: {time.time() - t0:.1f}s") + + +if __name__ == "__main__": + main() diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 50b1b3887..a4853deed 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -62,6 +62,7 @@ def __init__(self, *args, **kwargs): from .mimicextract import MIMICExtractDataset from .omop import OMOPDataset from .physionet_deid import PhysioNetDeIDDataset +from .ptbxl import PTBXLDataset from .sample_dataset import SampleBuilder, SampleDataset, create_sample_dataset from .shhs import SHHSDataset from .sleepedf import SleepEDFDataset diff --git a/pyhealth/datasets/configs/ptbxl.yaml b/pyhealth/datasets/configs/ptbxl.yaml new file mode 100644 index 000000000..38e11ef5d --- /dev/null +++ b/pyhealth/datasets/configs/ptbxl.yaml @@ -0,0 +1,14 @@ +version: "1.0.0" +tables: + ecg_records: + file_path: "ptbxl_metadata.csv" + patient_id: "patient_id" + timestamp: "recording_date" + timestamp_format: "%Y-%m-%d %H:%M:%S" + attributes: + - "ecg_id" + - "filename" + - "mi_label" + - "hyp_label" + - "sttc_label" + - "cd_label" diff --git a/pyhealth/datasets/ptbxl.py b/pyhealth/datasets/ptbxl.py new file mode 100644 index 000000000..cc5615236 --- /dev/null +++ b/pyhealth/datasets/ptbxl.py @@ -0,0 +1,186 @@ +# Authors: Paul Garcia (alanpg2), Rogelio Medina (orm9), Cesar Nava (can14) +# Paper: PTB-XL, a large publicly available electrocardiography dataset +# Link: https://physionet.org/content/ptb-xl/1.0.3/ +# Description: PyHealth dataset wrapper for PTB-XL 12-lead ECG records +# with four binary diagnostic superclass labels (MI, HYP, STTC, CD). + +"""PTB-XL ECG dataset for PyHealth. + +Reference: + Wagner et al. (2020). PTB-XL, a large publicly available electrocardiography + dataset. Scientific Data, 7(154). 
    https://physionet.org/content/ptb-xl/1.0.3/
"""
from __future__ import annotations

import ast
import os
from typing import Dict, Optional

import pandas as pd

# NOTE(review): importing from the package namespace while this module is
# itself imported by ``pyhealth/datasets/__init__.py`` only works because
# ``BaseDataset`` is bound in the package before ``from .ptbxl import
# PTBXLDataset`` executes — confirm the import order, or prefer a relative
# module import (``from .base_dataset import BaseDataset``) to avoid the
# circular-import hazard.
from pyhealth.datasets import BaseDataset


class PTBXLDataset(BaseDataset):
    """PTB-XL large publicly available ECG dataset.

    PTB-XL contains 21799 10-second 12-lead ECG records from 18869 patients,
    annotated with SCP-ECG diagnostic statements. Records are mapped to four
    binary diagnostic superclass labels: MI (myocardial infarction), HYP
    (hypertrophy), STTC (ST/T-change), and CD (conduction disturbance).

    The dataset must be downloaded from PhysioNet before use::

        wget -r -N -c -np https://physionet.org/files/ptb-xl/1.0.3/

    On first instantiation a ``ptbxl_metadata.csv`` file is generated inside
    *root*. Subsequent instantiations reuse this cached file.

    Args:
        root: Path to the PTB-XL root directory (must contain
            ``ptbxl_database.csv``, ``scp_statements.csv``, and the
            ``records100/`` or ``records500/`` subdirectories).
        dataset_name: Optional name for this dataset instance.
        config_path: Path to a custom YAML schema config. Defaults to the
            built-in ``configs/ptbxl.yaml``.
        sampling_rate: Waveform sampling rate to load. Must be ``100`` or
            ``500`` (Hz). Default: ``500`` (matches Raghu et al. 2022, which
            resamples the 500 Hz records to 250 Hz before modelling).
        dev: If ``True``, restrict to the first 100 patients for fast
            development iterations.

    Examples:
        >>> dataset = PTBXLDataset(root="/data/ptb-xl/")
        >>> print(dataset.stats())
        >>> from pyhealth.tasks.ecg_classification import ECGBinaryClassification
        >>> sample_ds = dataset.set_task(ECGBinaryClassification(task_label="MI"))
    """

    def __init__(
        self,
        root: str,
        dataset_name: Optional[str] = None,
        config_path: Optional[str] = None,
        sampling_rate: int = 500,
        dev: bool = False,
    ) -> None:
        # Validate before any side effect so a bad rate never triggers
        # metadata generation.
        if sampling_rate not in (100, 500):
            raise ValueError("sampling_rate must be 100 or 500")
        self.sampling_rate = sampling_rate

        # NOTE(review): the cache is keyed only on file existence. If a
        # previous run used a *different* sampling_rate, the cached
        # ``filename`` column still points at the other records directory
        # (records100/ vs records500/). Delete ``ptbxl_metadata.csv`` when
        # changing sampling_rate — confirm whether the cache should embed
        # the rate in its name instead.
        metadata_path = os.path.join(root, "ptbxl_metadata.csv")
        if not os.path.exists(metadata_path):
            self._prepare_metadata(root)

        if config_path is None:
            config_path = os.path.join(
                os.path.dirname(__file__), "configs", "ptbxl.yaml"
            )

        super().__init__(
            root=root,
            tables=["ecg_records"],
            dataset_name=dataset_name or "ptbxl",
            config_path=config_path,
            dev=dev,
        )

    # ------------------------------------------------------------------
    # Metadata preparation
    # ------------------------------------------------------------------

    def _prepare_metadata(self, root: str) -> None:
        """Parse raw PTB-XL files and write ``ptbxl_metadata.csv``.

        Reads ``ptbxl_database.csv`` for record-level information and
        ``scp_statements.csv`` for the diagnostic superclass mapping. The
        resulting CSV contains one row per ECG record with columns required
        by ``configs/ptbxl.yaml``.

        Args:
            root: PTB-XL root directory.

        Returns:
            None. Writes ``ptbxl_metadata.csv`` to *root* as a side effect.

        Raises:
            FileNotFoundError: If ``ptbxl_database.csv`` is absent.
        """
        db_path = os.path.join(root, "ptbxl_database.csv")
        if not os.path.exists(db_path):
            raise FileNotFoundError(
                f"ptbxl_database.csv not found in {root}. "
                "Download PTB-XL from https://physionet.org/content/ptb-xl/1.0.3/"
            )

        df = pd.read_csv(db_path, index_col="ecg_id")
        # scp_codes is stored as a stringified dict (e.g. "{'NORM': 100.0}");
        # literal_eval is safe here because the file is trusted static data.
        df["scp_codes"] = df["scp_codes"].apply(ast.literal_eval)

        # Build {scp_code -> diagnostic_class} mapping from scp_statements.
        # In PTB-XL, diagnostic_class holds the 5 superclass labels used by
        # the TaskAug paper: MI, HYP, STTC, CD, NORM.
        superclass_map: Dict[str, str] = {}
        scp_path = os.path.join(root, "scp_statements.csv")
        if os.path.exists(scp_path):
            scp_df = pd.read_csv(scp_path, index_col=0)
            scp_df = scp_df[scp_df["diagnostic"] == 1]
            superclass_map = scp_df["diagnostic_class"].to_dict()

        # 500 Hz records live under filename_hr, 100 Hz under filename_lr.
        filename_col = "filename_hr" if self.sampling_rate == 500 else "filename_lr"

        def _binary_label(scp_codes: Dict[str, float], superclass: str) -> int:
            """Return 1 if any code in *scp_codes* maps to *superclass*.

            NOTE(review): ``likelihood > 0`` excludes codes whose recorded
            likelihood is 0, which in PTB-XL denotes "likelihood unknown",
            not "absent" — confirm against the reference benchmark before
            relying on these labels.

            Args:
                scp_codes: Mapping of SCP statement code to likelihood score.
                superclass: Diagnostic superclass string (e.g. ``"MI"``).

            Returns:
                ``1`` if the record belongs to *superclass*, ``0`` otherwise.
            """
            for code, likelihood in scp_codes.items():
                if likelihood > 0 and superclass_map.get(code) == superclass:
                    return 1
            return 0

        records = []
        for ecg_id, row in df.iterrows():
            # Some rows lack a recording date; substitute a fixed sentinel so
            # the timestamp column always parses under configs/ptbxl.yaml.
            rec_date = row.get("recording_date", "2000-01-01 00:00:00")
            if pd.isna(rec_date):
                rec_date = "2000-01-01 00:00:00"

            scp = row["scp_codes"]
            records.append(
                {
                    "patient_id": str(int(row["patient_id"])),
                    "ecg_id": str(ecg_id),
                    "recording_date": str(rec_date),
                    # Store absolute path so the task can load without root
                    "filename": os.path.join(root, str(row[filename_col])),
                    "mi_label": _binary_label(scp, "MI"),
                    "hyp_label": _binary_label(scp, "HYP"),
                    "sttc_label": _binary_label(scp, "STTC"),
                    "cd_label": _binary_label(scp, "CD"),
                }
            )

        pd.DataFrame(records).to_csv(
            os.path.join(root, "ptbxl_metadata.csv"), index=False
        )

    # ------------------------------------------------------------------
    # Default task
    # ------------------------------------------------------------------

    @property
    def default_task(self) -> "ECGBinaryClassification":  # noqa: F821
        """Return the default MI binary classification task.

        Returns:
            An :class:`~pyhealth.tasks.ecg_classification.ECGBinaryClassification`
            instance configured for MI (myocardial infarction) prediction.
+ """ + from pyhealth.tasks.ecg_classification import ECGBinaryClassification + + return ECGBinaryClassification() diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index 4c168d3e3..1397072f6 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -26,6 +26,7 @@ from .sparcnet import DenseBlock, DenseLayer, SparcNet, TransitionLayer from .stagenet import StageNet, StageNetLayer from .stagenet_mha import StageAttentionNet, StageNetAttentionLayer +from .taskaug_resnet import TaskAugResNet from .tcn import TCN, TCNLayer from .tfm_tokenizer import ( TFMTokenizer, @@ -45,4 +46,4 @@ from .sdoh import SdohClassifier from .medlink import MedLink from .unified_embedding import UnifiedMultimodalEmbeddingModel, SinusoidalTimeEmbedding -from .califorest import CaliForest \ No newline at end of file +from .califorest import CaliForest diff --git a/pyhealth/models/taskaug_resnet.py b/pyhealth/models/taskaug_resnet.py new file mode 100644 index 000000000..291c76a6c --- /dev/null +++ b/pyhealth/models/taskaug_resnet.py @@ -0,0 +1,522 @@ +# Authors: Paul Garcia (alanpg2), Rogelio Medina (orm9), Cesar Nava (can14) +# Paper: Data Augmentation for Electrocardiograms (Raghu et al., CHIL 2022) +# Link: https://proceedings.mlr.press/v174/raghu22a.html +# Description: TaskAug differentiable augmentation policy with 1-D ResNet-18 +# backbone for binary ECG classification on PTB-XL. + +"""TaskAug: Task-Adaptive Data Augmentation with a 1-D ResNet-18 backbone. + +Implements the TaskAug framework from: + + Raghu et al. (2022). Data Augmentation for Electrocardiograms. + Conference on Health, Inference, and Learning (CHIL), PMLR 174. + https://proceedings.mlr.press/v174/raghu22a.html + +Architecture +------------ +* :class:`TaskAugPolicy` — K-stage differentiable augmentation policy. 
  At each stage one of eight operations is selected via the Gumbel-Softmax
  trick and applied with a class-specific magnitude drawn from learnable
  parameters (``mag_neg`` for label=0, ``mag_pos`` for label=1).

* :class:`_ResNet1D` — 1-D adaptation of ResNet-18 with kernel size 7.

* :class:`TaskAugResNet` — :class:`~pyhealth.models.BaseModel` subclass that
  wires the policy and backbone together. During training the policy
  augments the input before it reaches the backbone; at inference the raw
  signal is forwarded directly.

For bi-level optimisation (inner loop: backbone on augmented training data;
outer loop: policy on clean validation loss) use the ``BiLevelTrainer``
provided in ``examples/ptbxl_ecg_classification_taskaug_resnet.py``.
"""
from __future__ import annotations

import math
from typing import Dict, Iterator, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from pyhealth.datasets import SampleDataset
from pyhealth.models import BaseModel


# ---------------------------------------------------------------------------
# Eight augmentation operations
# Each accepts x: (B, C, T) and mag: (B,) and returns (B, C, T).
# ---------------------------------------------------------------------------

def _gaussian_noise(x: torch.Tensor, mag: torch.Tensor) -> torch.Tensor:
    """Add i.i.d. Gaussian noise scaled per sample by *mag*.

    Args:
        x: Input signal of shape ``(B, C, T)``.
        mag: Per-sample noise standard deviation of shape ``(B,)``.

    Returns:
        Noisy signal of shape ``(B, C, T)``.
    """
    return x + mag.view(-1, 1, 1) * torch.randn_like(x)


def _magnitude_scale(x: torch.Tensor, mag: torch.Tensor) -> torch.Tensor:
    """Scale signal amplitude by ``(1 + mag)`` per sample.

    Args:
        x: Input signal of shape ``(B, C, T)``.
        mag: Per-sample scaling offset of shape ``(B,)``.

    Returns:
        Scaled signal of shape ``(B, C, T)``.
    """
    return x * (1.0 + mag.view(-1, 1, 1))


def _time_mask(x: torch.Tensor, mag: torch.Tensor) -> torch.Tensor:
    """Zero-out a random contiguous time segment of length ``floor(|mag|*T)``.

    NOTE(review): the ``.long()``/``.item()`` conversions make the mask
    length non-differentiable w.r.t. *mag* — only the op-selection weight
    carries gradient through this op. Confirm this matches the intended
    relaxation from the paper.

    Args:
        x: Input signal of shape ``(B, C, T)``.
        mag: Per-sample mask fraction in ``[0, 1]`` of shape ``(B,)``.

    Returns:
        Masked signal of shape ``(B, C, T)``.
    """
    B, C, T = x.shape
    out = x.clone()
    mask_len = (mag.abs() * T).long().clamp(0, T)
    for b in range(B):
        ml = int(mask_len[b].item())
        if ml > 0:
            start = int(torch.randint(0, max(T - ml + 1, 1), (1,)).item())
            out[b, :, start : start + ml] = 0.0
    return out


def _baseline_wander(x: torch.Tensor, mag: torch.Tensor) -> torch.Tensor:
    """Add a low-frequency sinusoidal baseline drift to each sample.

    Args:
        x: Input signal of shape ``(B, C, T)``.
        mag: Per-sample drift amplitude of shape ``(B,)``.

    Returns:
        Signal with baseline drift added, shape ``(B, C, T)``.
    """
    B, C, T = x.shape
    device = x.device
    t = torch.linspace(0, 2 * math.pi, T, device=device)
    freq = torch.rand(B, device=device) * 0.3 + 0.05  # normalised 0.05–0.35
    phase = torch.rand(B, device=device) * 2 * math.pi
    wander = torch.sin(freq.unsqueeze(-1) * t + phase.unsqueeze(-1))  # (B, T)
    return x + mag.view(-1, 1, 1) * wander.unsqueeze(1)


def _temporal_warp(x: torch.Tensor, mag: torch.Tensor) -> torch.Tensor:
    """Apply differentiable non-linear temporal warping via ``grid_sample``.

    Args:
        x: Input signal of shape ``(B, C, T)``.
        mag: Per-sample warp strength of shape ``(B,)``.

    Returns:
        Warped signal of shape ``(B, C, T)``.
    """
    B, C, T = x.shape
    device = x.device
    base = torch.linspace(-1.0, 1.0, T, device=device).unsqueeze(0).expand(B, -1)
    freq = torch.rand(B, device=device) * 2.0 + 1.0  # 1–3 cycles
    phase = torch.rand(B, device=device) * 2 * math.pi
    t_norm = torch.linspace(0, 2 * math.pi, T, device=device)
    disp = 0.1 * mag.unsqueeze(-1) * torch.sin(
        freq.unsqueeze(-1) * t_norm + phase.unsqueeze(-1)
    )
    grid_x = (base + disp).clamp(-1.0, 1.0)  # (B, T)
    grid_y = torch.zeros_like(grid_x)
    # grid_sample expects (B, H_out, W_out, 2) — treat H=1
    grid = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(1)  # (B, 1, T, 2)
    warped = F.grid_sample(
        x.unsqueeze(2), grid,
        mode="bilinear", padding_mode="border", align_corners=True,
    )
    return warped.squeeze(2)  # (B, C, T)


def _temporal_displacement(x: torch.Tensor, mag: torch.Tensor) -> torch.Tensor:
    """Circularly shift each sample along the time axis by ``floor(|mag|*T)``.

    Args:
        x: Input signal of shape ``(B, C, T)``.
        mag: Per-sample shift fraction in ``[0, 1]`` of shape ``(B,)``.

    Returns:
        Shifted signal of shape ``(B, C, T)``.
    """
    B, C, T = x.shape
    shifts = (mag.abs() * T).long()
    return torch.stack(
        [torch.roll(x[b], int(shifts[b].item()), dims=-1) for b in range(B)]
    )


def _no_op(x: torch.Tensor, _mag: torch.Tensor) -> torch.Tensor:
    """Return the signal unchanged (identity operation).

    Args:
        x: Input signal of shape ``(B, C, T)``.
        _mag: Unused magnitude placeholder of shape ``(B,)``.

    Returns:
        The unmodified input signal.
    """
    return x


def _lead_dropout(x: torch.Tensor, mag: torch.Tensor) -> torch.Tensor:
    """Randomly zero out ECG leads with probability proportional to ``|mag|``.

    Each lead is independently zeroed with probability ``min(|mag|, 0.5)``,
    capped at 50 % so at least half the leads are always retained.

    Args:
        x: Input signal of shape ``(B, C, T)``.
        mag: Per-sample dropout probability of shape ``(B,)``.

    Returns:
        Signal with randomly dropped leads, shape ``(B, C, T)``.
    """
    B, C, T = x.shape
    out = x.clone()
    p_drop = mag.abs().clamp(0.0, 0.5)
    for b in range(B):
        mask = torch.bernoulli(
            torch.full((C,), float(p_drop[b].item()), device=x.device)
        ).bool()
        out[b, mask, :] = 0.0
    return out


# Registry consumed by TaskAugPolicy; order fixes the meaning of each
# selection logit / magnitude column.
_OPS: List = [
    _gaussian_noise,
    _magnitude_scale,
    _time_mask,
    _baseline_wander,
    _temporal_warp,
    _temporal_displacement,
    _no_op,
    _lead_dropout,
]
_NUM_OPS: int = len(_OPS)  # 8


# ---------------------------------------------------------------------------
# TaskAugPolicy
# ---------------------------------------------------------------------------

class TaskAugPolicy(nn.Module):
    """Differentiable task-adaptive augmentation policy (Raghu et al., 2022).

    Applies *num_stages* sequential augmentation stages. In each stage the
    operation weights are obtained via Gumbel-Softmax, then a soft weighted
    sum of all augmented versions is computed (differentiable at training
    time). Each operation has two learnable scalar magnitudes — one for the
    negative class (``mag_neg``) and one for the positive class
    (``mag_pos``) — enabling asymmetric augmentation intensities.

    Args:
        num_stages: Number of sequential augmentation stages ``K``. Default: 2.
        temperature: Gumbel-Softmax temperature ``τ``. Default: 1.0.

    Attributes:
        logits: ``(K, N_ops)`` selection logits.
        mag_neg: ``(K, N_ops)`` per-stage magnitudes for label=0 samples.
        mag_pos: ``(K, N_ops)`` per-stage magnitudes for label=1 samples.
+ """ + + def __init__( + self, + num_stages: int = 2, + temperature: float = 1.0, + shared_magnitudes: bool = False, + ) -> None: + super().__init__() + self.num_stages = num_stages + self.temperature = temperature + self.shared_magnitudes = shared_magnitudes + + self.logits = nn.Parameter(torch.zeros(num_stages, _NUM_OPS)) + self.mag_neg = nn.Parameter(0.1 * torch.ones(num_stages, _NUM_OPS)) + self.mag_pos = nn.Parameter(0.1 * torch.ones(num_stages, _NUM_OPS)) + + def forward(self, x: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + """Apply K augmentation stages to *x*. + + Args: + x: ECG tensor of shape ``(B, C, T)``. + labels: Binary class labels of shape ``(B,)`` (0 or 1). + + Returns: + Augmented tensor of the same shape as *x*. + """ + labels_f = labels.float() # (B,) + + for k in range(self.num_stages): + # Gumbel-Softmax operation weights: (N_ops,) + weights = F.gumbel_softmax( + self.logits[k], tau=self.temperature, hard=False + ) + + # Per-sample magnitude: class-specific or shared + if self.shared_magnitudes: + mag_k = self.mag_neg[k].unsqueeze(0).abs().expand(x.shape[0], -1) + else: + mag_k = ( + self.mag_neg[k].unsqueeze(0) + + labels_f.unsqueeze(-1) + * (self.mag_pos[k] - self.mag_neg[k]).unsqueeze(0) + ).abs() + + # Soft weighted combination of all augmented versions + augmented = torch.zeros_like(x) + for i, op in enumerate(_OPS): + augmented = augmented + weights[i] * op(x, mag_k[:, i]) + x = augmented + + return x + + +# --------------------------------------------------------------------------- +# 1-D ResNet-18 backbone +# --------------------------------------------------------------------------- + +class _BasicBlock1D(nn.Module): + """Residual block for 1-D signals (kernel size 7, BN + ReLU).""" + + expansion: int = 1 + + def __init__(self, in_ch: int, out_ch: int, stride: int = 1) -> None: + super().__init__() + self.conv1 = nn.Conv1d( + in_ch, out_ch, kernel_size=7, stride=stride, padding=3, bias=False + ) + self.bn1 = 
nn.BatchNorm1d(out_ch) + self.conv2 = nn.Conv1d(out_ch, out_ch, kernel_size=7, padding=3, bias=False) + self.bn2 = nn.BatchNorm1d(out_ch) + self.relu = nn.ReLU(inplace=True) + + self.downsample: nn.Module = nn.Identity() + if stride != 1 or in_ch != out_ch: + self.downsample = nn.Sequential( + nn.Conv1d(in_ch, out_ch, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm1d(out_ch), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Compute residual block output. + + Args: + x: Input feature map of shape ``(B, in_ch, T)``. + + Returns: + Output feature map of shape ``(B, out_ch, T//stride)``. + """ + identity = self.downsample(x) + out = self.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + return self.relu(out + identity) + + +class _ResNet1D(nn.Module): + """1-D ResNet-18 adapted for multi-lead ECG signals. + + Args: + in_channels: Number of input channels (ECG leads). Default: 12. + num_classes: Output dimension (1 for binary classification). + + Input shape: + ``(B, in_channels, T)`` — e.g. ``(B, 12, 1000)``. + + Output shape: + ``(B, num_classes)``. + """ + + def __init__(self, in_channels: int = 12, num_classes: int = 1) -> None: + super().__init__() + self.stem = nn.Sequential( + nn.Conv1d(in_channels, 64, kernel_size=15, stride=2, padding=7, bias=False), + nn.BatchNorm1d(64), + nn.ReLU(inplace=True), + nn.MaxPool1d(kernel_size=3, stride=2, padding=1), + ) + self.layer1 = self._make_layer(64, 64, n_blocks=2, stride=1) + self.layer2 = self._make_layer(64, 128, n_blocks=2, stride=2) + self.layer3 = self._make_layer(128, 256, n_blocks=2, stride=2) + self.layer4 = self._make_layer(256, 512, n_blocks=2, stride=2) + self.avgpool = nn.AdaptiveAvgPool1d(1) + self.fc = nn.Linear(512, num_classes) + self._init_weights() + + @staticmethod + def _make_layer( + in_ch: int, out_ch: int, n_blocks: int, stride: int + ) -> nn.Sequential: + """Build a stage of *n_blocks* residual blocks. 

        Args:
            in_ch: Input channel count for the first block.
            out_ch: Output channel count for all blocks in this stage.
            n_blocks: Number of :class:`_BasicBlock1D` blocks.
            stride: Stride for the first block (subsequent blocks use 1).

        Returns:
            Sequential module containing all blocks for this stage.
        """
        blocks: List[nn.Module] = [_BasicBlock1D(in_ch, out_ch, stride=stride)]
        blocks += [_BasicBlock1D(out_ch, out_ch) for _ in range(1, n_blocks)]
        return nn.Sequential(*blocks)

    def _init_weights(self) -> None:
        """Initialise Conv1d with Kaiming normal and BatchNorm with identity.

        Returns:
            None.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run a full ResNet-18 forward pass.

        Args:
            x: Input tensor of shape ``(B, in_channels, T)``.

        Returns:
            Logit tensor of shape ``(B, num_classes)``.
        """
        x = self.stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return self.fc(self.avgpool(x).squeeze(-1))


# ---------------------------------------------------------------------------
# TaskAugResNet — PyHealth BaseModel subclass
# ---------------------------------------------------------------------------

class TaskAugResNet(BaseModel):
    """1-D ResNet-18 with a learned TaskAug augmentation policy.

    Reproduces the TaskAug framework of Raghu et al. (2022) within the
    PyHealth ``BaseModel`` interface. The :class:`TaskAugPolicy` is applied
    to the input only during training; at inference the raw signal is passed
    directly to the backbone.

    For bi-level optimisation — inner loop updating backbone weights on
    augmented training data, outer loop updating policy weights on clean
    validation loss — use the ``BiLevelTrainer`` helper in the accompanying
    examples script.

    Args:
        dataset: A :class:`~pyhealth.datasets.SampleDataset` produced by
            :class:`~pyhealth.tasks.ecg_classification.ECGBinaryClassification`.
        num_leads: Number of ECG leads (input channels). Default: 12.
        policy_stages: Number of sequential augmentation stages *K*. Default: 2.
        temperature: Gumbel-Softmax temperature. Default: 1.0.

    Examples:
        >>> model = TaskAugResNet(sample_dataset)
        >>> out = model(ecg=ecg_tensor, label=label_tensor)
        >>> out.keys()
        dict_keys(['logit', 'y_prob', 'loss', 'y_true'])
    """

    def __init__(
        self,
        dataset: SampleDataset,
        num_leads: int = 12,
        policy_stages: int = 2,
        temperature: float = 1.0,
        shared_magnitudes: bool = False,
    ) -> None:
        super().__init__(dataset)
        self.mode = "binary"  # binary classification throughout

        self.policy = TaskAugPolicy(
            num_stages=policy_stages,
            temperature=temperature,
            shared_magnitudes=shared_magnitudes,
        )
        self.backbone = _ResNet1D(
            in_channels=num_leads,
            num_classes=self.get_output_size(),
        )

    # ------------------------------------------------------------------
    # Forward pass
    # ------------------------------------------------------------------

    def forward(
        self,
        ecg: torch.Tensor,
        label: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Augment (training only) and classify ECG signals.

        Args:
            ecg: Float tensor of shape ``(B, 12, T)``.
            label: Optional binary label tensor of shape ``(B,)``. Enables
                class-specific augmentation magnitudes and loss computation.

        Returns:
            Dict containing:

            * ``logit`` — raw logits ``(B, 1)``
            * ``y_prob`` — sigmoid probabilities ``(B, 1)``
            * ``loss`` — scalar BCE loss *(only when label is provided)*
            * ``y_true`` — label tensor *(only when label is provided)*
        """
        # Augmentation is gated on BOTH training mode and label presence, so
        # evaluation and unlabeled inference always see the raw signal.
        if self.training and label is not None:
            ecg = self.policy(ecg, label)

        logits = self.backbone(ecg)  # (B, 1)
        y_prob = torch.sigmoid(logits)

        output: Dict[str, torch.Tensor] = {"logit": logits, "y_prob": y_prob}
        if label is not None:
            output["loss"] = F.binary_cross_entropy_with_logits(
                logits.squeeze(-1), label.float()
            )
            output["y_true"] = label

        return output

    # ------------------------------------------------------------------
    # Parameter group helpers (used by BiLevelTrainer)
    # ------------------------------------------------------------------

    def policy_parameters(self) -> Iterator[nn.Parameter]:
        """Return an iterator over augmentation policy parameters.

        Returns:
            Iterator of :class:`torch.nn.Parameter` objects belonging to
            the :class:`TaskAugPolicy` (logits and magnitudes).
        """
        return self.policy.parameters()

    def backbone_parameters(self) -> Iterator[nn.Parameter]:
        """Return an iterator over ResNet backbone parameters.

        Returns:
            Iterator of :class:`torch.nn.Parameter` objects belonging to
            the :class:`_ResNet1D` backbone.
+ """ + return self.backbone.parameters() diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index a32618f9c..c09cb0a64 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -22,6 +22,7 @@ drug_recommendation_mimic4_fn, drug_recommendation_omop_fn, ) +from .ecg_classification import ECGBinaryClassification from .in_hospital_mortality_mimic4 import InHospitalMortalityMIMIC4 from .length_of_stay_prediction import ( LengthOfStayPredictioneICU, diff --git a/pyhealth/tasks/ecg_classification.py b/pyhealth/tasks/ecg_classification.py new file mode 100644 index 000000000..b9635f242 --- /dev/null +++ b/pyhealth/tasks/ecg_classification.py @@ -0,0 +1,204 @@ +# Authors: Paul Garcia (alanpg2), Rogelio Medina (orm9), Cesar Nava (can14) +# Paper: Data Augmentation for Electrocardiograms (Raghu et al., CHIL 2022) +# Link: https://proceedings.mlr.press/v174/raghu22a.html +# Description: Binary ECG classification task for PTB-XL supporting MI, HYP, +# STTC, and CD diagnostic superclass labels. + +"""Binary ECG classification tasks for PTB-XL. + +Reference: + Raghu et al. (2022). Data Augmentation for Electrocardiograms. + Conference on Health, Inference, and Learning (CHIL), PMLR 174. + https://proceedings.mlr.press/v174/raghu22a.html +""" +from __future__ import annotations + +from typing import Dict, List, Optional, Tuple + +import numpy as np + +from pyhealth.data import Patient +from pyhealth.tasks import BaseTask + +#: Supported diagnostic superclass labels. +SUPERCLASSES: Tuple[str, ...] = ("MI", "HYP", "STTC", "CD") +_LABEL_COL: Dict[str, str] = {s: f"{s.lower()}_label" for s in SUPERCLASSES} + + +class ECGBinaryClassification(BaseTask): + """Binary ECG classification task for the PTB-XL dataset. + + Each ECG record produces one labelled sample. 
The WFDB waveform is + loaded from the path stored in the event, optionally resampled from + *input_hz* to *output_hz* using anti-aliased Fourier resampling, + per-lead z-score normalised, and padded or truncated to a fixed length + along the time axis. + + The default configuration (500 Hz input, 250 Hz output, 2500-sample + window) reproduces the setup of Raghu et al. (2022), who load the 500 Hz + PTB-XL records and downsample to 250 Hz for a 10-second window. Pair + this task with ``PTBXLDataset(sampling_rate=500)`` (the dataset default). + + Args: + task_label: Diagnostic superclass used as the binary target. + One of ``"MI"``, ``"HYP"``, ``"STTC"``, or ``"CD"``. + Default: ``"MI"``. + target_length: Number of time steps per sample after resampling, + padding, or truncation. At 250 Hz the default of 2500 equals + 10 seconds (matching Raghu et al. 2022). + input_hz: Sampling rate of the raw waveforms loaded from disk. + Should match ``PTBXLDataset.sampling_rate``. Default: ``500``. + output_hz: Target sampling rate after resampling. Set equal to + *input_hz* to skip resampling. Default: ``250``. 
+ + Examples: + >>> task = ECGBinaryClassification(task_label="MI") + >>> samples = task(patient) + >>> print(samples[0].keys()) + dict_keys(['patient_id', 'ecg_id', 'ecg', 'label']) + """ + + task_name: str = "ECGBinaryClassification" + input_schema: Dict[str, str] = {"ecg": "tensor"} + output_schema: Dict[str, str] = {"label": "binary"} + + def __init__( + self, + task_label: str = "MI", + target_length: int = 2500, + input_hz: int = 500, + output_hz: int = 250, + ) -> None: + super().__init__() + if task_label not in SUPERCLASSES: + raise ValueError( + f"task_label must be one of {SUPERCLASSES}, got {task_label!r}" + ) + self.task_label = task_label + self._label_col = _LABEL_COL[task_label] + self.target_length = target_length + self.input_hz = input_hz + self.output_hz = output_hz + + def __call__(self, patient: Patient) -> List[Dict]: + """Process one patient into a list of ECG classification samples. + + Args: + patient: PyHealth ``Patient`` whose ``ecg_records`` events contain + the attributes defined in ``configs/ptbxl.yaml``. + + Returns: + List of sample dicts. 
Each dict contains: + + * ``patient_id`` – string patient identifier + * ``ecg_id`` – string ECG record identifier + * ``ecg`` – float32 ndarray of shape ``(12, target_length)`` + * ``label`` – int (0 or 1) + """ + samples: List[Dict] = [] + events = patient.get_events(event_type="ecg_records") + + for event in events: + filename: str = event["filename"] + label: int = int(event[self._label_col]) + + signal = self._load_signal(filename) + if signal is None: + continue + + if self.input_hz != self.output_hz: + signal = self._resample(signal, self.input_hz, self.output_hz) + signal = self._normalize(signal) + signal = self._pad_or_truncate(signal, self.target_length) + + samples.append( + { + "patient_id": patient.patient_id, + "ecg_id": str(event["ecg_id"]), + "ecg": signal, + "label": label, + } + ) + + return samples + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + @staticmethod + def _load_signal(filename: str) -> Optional[np.ndarray]: + """Load a WFDB record from disk and return a float32 signal array. + + Fails gracefully so that individual corrupt records are skipped + rather than crashing the entire pipeline. + + Args: + filename: Absolute path to the WFDB record *without* file + extension (e.g. ``/data/ptb-xl/records100/00000/00001_lr``). + + Returns: + Float32 ndarray of shape ``(leads, T)``, or ``None`` if the + record cannot be read. + """ + try: + import wfdb # lazy import — optional dependency + + record = wfdb.rdrecord(filename) + return record.p_signal.T.astype(np.float32) # (leads, T) + except Exception: + return None + + @staticmethod + def _resample(signal: np.ndarray, input_hz: int, output_hz: int) -> np.ndarray: + """Resample *signal* from *input_hz* to *output_hz* using Fourier resampling. + + Uses ``scipy.signal.resample`` which applies an anti-aliased FFT-based + method, matching the approach used in Raghu et al. 
(2022) for 500→250 Hz + downsampling. + + Args: + signal: Float32 array of shape ``(leads, T)``. + input_hz: Original sampling rate in Hz. + output_hz: Target sampling rate in Hz. + + Returns: + Resampled float32 array of shape ``(leads, T')``, where + ``T' = round(T * output_hz / input_hz)``. + """ + from scipy.signal import resample as scipy_resample # lazy import + + T = signal.shape[1] + target_samples = round(T * output_hz / input_hz) + return scipy_resample(signal, target_samples, axis=1).astype(signal.dtype) + + @staticmethod + def _normalize(signal: np.ndarray) -> np.ndarray: + """Apply per-lead z-score normalisation (zero mean, unit std). + + Args: + signal: Float32 array of shape ``(leads, T)``. + + Returns: + Normalised array of the same shape. + """ + mean = signal.mean(axis=1, keepdims=True) + std = signal.std(axis=1, keepdims=True) + 1e-8 + return (signal - mean) / std + + @staticmethod + def _pad_or_truncate(signal: np.ndarray, target: int) -> np.ndarray: + """Truncate or right-zero-pad a signal to exactly *target* steps. + + Args: + signal: Float32 array of shape ``(leads, T)``. + target: Desired number of time steps. + + Returns: + Float32 array of shape ``(leads, target)``. + """ + T = signal.shape[1] + if T >= target: + return signal[:, :target] + pad = np.zeros((signal.shape[0], target - T), dtype=signal.dtype) + return np.concatenate([signal, pad], axis=1) diff --git a/tests/core/test_ecg_classification.py b/tests/core/test_ecg_classification.py new file mode 100644 index 000000000..ec65ec052 --- /dev/null +++ b/tests/core/test_ecg_classification.py @@ -0,0 +1,196 @@ +"""Tests for ECGBinaryClassification task using synthetic patient objects. + +Synthetic signals are generated with numpy — no real PTB-XL data required. +All tests complete in milliseconds. 
+""" +from __future__ import annotations + +from typing import Any, Dict, List +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_fake_event( + filename: str = "/fake/record", + mi_label: int = 1, + hyp_label: int = 0, + sttc_label: int = 0, + cd_label: int = 0, + ecg_id: str = "1", +) -> MagicMock: + """Return a mock Event whose attributes match PTB-XL metadata columns.""" + event = MagicMock() + data: Dict[str, Any] = { + "filename": filename, + "ecg_id": ecg_id, + "mi_label": mi_label, + "hyp_label": hyp_label, + "sttc_label": sttc_label, + "cd_label": cd_label, + } + event.__getitem__ = lambda self, key: data[key] + return event + + +def _make_fake_patient(events: List[MagicMock], patient_id: str = "P001") -> MagicMock: + patient = MagicMock() + patient.patient_id = patient_id + patient.get_events.return_value = events + return patient + + +def _fake_signal(leads: int = 12, length: int = 1000) -> np.ndarray: + return np.random.randn(leads, length).astype(np.float32) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +class TestECGBinaryClassificationInit: + def test_valid_task_label(self): + from pyhealth.tasks.ecg_classification import ECGBinaryClassification + + for label in ("MI", "HYP", "STTC", "CD"): + task = ECGBinaryClassification(task_label=label) + assert task.task_label == label + + def test_invalid_task_label_raises(self): + from pyhealth.tasks.ecg_classification import ECGBinaryClassification + + with pytest.raises(ValueError, match="task_label must be one of"): + ECGBinaryClassification(task_label="INVALID") + + def test_default_task_label_is_mi(self): + from pyhealth.tasks.ecg_classification import 
ECGBinaryClassification + + task = ECGBinaryClassification() + assert task.task_label == "MI" + + def test_schemas(self): + from pyhealth.tasks.ecg_classification import ECGBinaryClassification + + task = ECGBinaryClassification() + assert task.input_schema == {"ecg": "tensor"} + assert task.output_schema == {"label": "binary"} + + +class TestECGBinaryClassificationCall: + def test_returns_list_of_dicts(self): + from pyhealth.tasks.ecg_classification import ECGBinaryClassification + + task = ECGBinaryClassification(task_label="MI", target_length=500) + events = [_make_fake_event(mi_label=1), _make_fake_event(mi_label=0, ecg_id="2")] + patient = _make_fake_patient(events) + + fake_signal = _fake_signal(12, 600) + with patch.object(task, "_load_signal", return_value=fake_signal): + samples = task(patient) + + assert isinstance(samples, list) + assert len(samples) == 2 + + def test_sample_keys(self): + from pyhealth.tasks.ecg_classification import ECGBinaryClassification + + task = ECGBinaryClassification() + events = [_make_fake_event()] + patient = _make_fake_patient(events) + + with patch.object(task, "_load_signal", return_value=_fake_signal()): + samples = task(patient) + + assert set(samples[0].keys()) >= {"patient_id", "ecg_id", "ecg", "label"} + + def test_ecg_shape_after_truncation(self): + from pyhealth.tasks.ecg_classification import ECGBinaryClassification + + task = ECGBinaryClassification(target_length=500) + events = [_make_fake_event()] + patient = _make_fake_patient(events) + + # Signal longer than target → should truncate + with patch.object(task, "_load_signal", return_value=_fake_signal(12, 800)): + samples = task(patient) + + assert samples[0]["ecg"].shape == (12, 500) + + def test_ecg_shape_after_padding(self): + from pyhealth.tasks.ecg_classification import ECGBinaryClassification + + task = ECGBinaryClassification(target_length=1000) + events = [_make_fake_event()] + patient = _make_fake_patient(events) + + # Signal shorter than target → should 
zero-pad + with patch.object(task, "_load_signal", return_value=_fake_signal(12, 600)): + samples = task(patient) + + assert samples[0]["ecg"].shape == (12, 1000) + # Padded region should be zeros + assert (samples[0]["ecg"][:, 600:] == 0).all() + + def test_label_correct_for_mi(self): + from pyhealth.tasks.ecg_classification import ECGBinaryClassification + + task = ECGBinaryClassification(task_label="MI") + events = [_make_fake_event(mi_label=1), _make_fake_event(mi_label=0, ecg_id="2")] + patient = _make_fake_patient(events) + + with patch.object(task, "_load_signal", return_value=_fake_signal()): + samples = task(patient) + + assert samples[0]["label"] == 1 + assert samples[1]["label"] == 0 + + def test_label_correct_for_hyp(self): + from pyhealth.tasks.ecg_classification import ECGBinaryClassification + + task = ECGBinaryClassification(task_label="HYP") + events = [_make_fake_event(hyp_label=1)] + patient = _make_fake_patient(events) + + with patch.object(task, "_load_signal", return_value=_fake_signal()): + samples = task(patient) + + assert samples[0]["label"] == 1 + + def test_failed_signal_load_skipped(self): + from pyhealth.tasks.ecg_classification import ECGBinaryClassification + + task = ECGBinaryClassification() + events = [_make_fake_event(ecg_id="1"), _make_fake_event(ecg_id="2")] + patient = _make_fake_patient(events) + + # First load fails, second succeeds + with patch.object( + task, "_load_signal", side_effect=[None, _fake_signal()] + ): + samples = task(patient) + + assert len(samples) == 1 + assert samples[0]["ecg_id"] == "2" + + def test_normalize_produces_unit_std(self): + from pyhealth.tasks.ecg_classification import ECGBinaryClassification + + rng = np.random.default_rng(0) + signal = rng.normal(loc=5.0, scale=3.0, size=(12, 500)).astype(np.float32) + normed = ECGBinaryClassification._normalize(signal) + + np.testing.assert_allclose(normed.mean(axis=1), 0.0, atol=1e-5) + np.testing.assert_allclose(normed.std(axis=1), 1.0, atol=1e-5) + + def 
test_empty_patient_returns_empty_list(self): + from pyhealth.tasks.ecg_classification import ECGBinaryClassification + + task = ECGBinaryClassification() + patient = _make_fake_patient([]) + samples = task(patient) + assert samples == [] diff --git a/tests/core/test_ptbxl.py b/tests/core/test_ptbxl.py new file mode 100644 index 000000000..721d9558a --- /dev/null +++ b/tests/core/test_ptbxl.py @@ -0,0 +1,386 @@ +"""Tests for PTBXLDataset using synthetic data. + +All tests are self-contained: they create temporary files in TemporaryDirectory +contexts (auto-cleaned on exit), never touch real PTB-XL data, and complete +in milliseconds. +""" +from __future__ import annotations + +import os +import tempfile +from typing import Any, Dict, List +from unittest.mock import MagicMock + +import pandas as pd +import pytest + + +# --------------------------------------------------------------------------- +# Synthetic data helpers +# --------------------------------------------------------------------------- + +def _write_fake_ptbxl( + root: str, + n_patients: int = 3, + n_records: int = 6, +) -> None: + """Write minimal fake PTB-XL CSV files to *root* (≤ 5 patients).""" + # diagnostic_class is the correct PTB-XL column holding the 5 superclass + # labels (MI, HYP, STTC, CD, NORM). diagnostic_subclass holds finer types. 
+ pd.DataFrame( + { + "Unnamed: 0": ["NORM", "MI", "ISCAL", "HYP", "STTC", "CD"], + "diagnostic": [1, 1, 1, 1, 1, 1], + "diagnostic_class": ["NORM", "MI", "MI", "HYP", "STTC", "CD"], + "diagnostic_subclass": ["NORM", "IMI", "ISCAL", "LVH", "ISC_", "CLBBB"], + } + ).to_csv(os.path.join(root, "scp_statements.csv"), index=False) + + rows = [] + for i in range(1, n_records + 1): + if i % 3 == 0: + scp = {"MI": 100.0} + elif i % 4 == 0: + scp = {"HYP": 80.0} + elif i % 5 == 0: + scp = {"STTC": 90.0} + else: + scp = {"NORM": 100.0} + + rows.append( + { + "ecg_id": i, + "patient_id": ((i - 1) % n_patients) + 1, + "recording_date": "2000-01-01 00:00:00", + "filename_lr": f"records100/00000/{i:05d}_lr", + "filename_hr": f"records500/00000/{i:05d}_hr", + "scp_codes": str(scp), + } + ) + pd.DataFrame(rows).to_csv( + os.path.join(root, "ptbxl_database.csv"), index=False + ) + + +def _make_fake_event(attr: Dict[str, Any]) -> MagicMock: + """Return a mock Event with ``__getitem__`` backed by *attr*.""" + event = MagicMock() + event.__getitem__ = lambda self, key: attr[key] + return event + + +def _make_fake_patient( + events: List[MagicMock], patient_id: str = "P001" +) -> MagicMock: + patient = MagicMock() + patient.patient_id = patient_id + patient.get_events.return_value = events + return patient + + +# --------------------------------------------------------------------------- +# Fast & Performant — synthetic data, tempdir, millisecond execution +# --------------------------------------------------------------------------- + +class TestDataLoading: + """Tests data loading: metadata CSV creation and content.""" + + def test_metadata_csv_created(self): + with tempfile.TemporaryDirectory() as root: + _write_fake_ptbxl(root) + from pyhealth.datasets.ptbxl import PTBXLDataset + + obj = PTBXLDataset.__new__(PTBXLDataset) + obj.sampling_rate = 100 + obj._prepare_metadata(root) + + assert os.path.exists(os.path.join(root, "ptbxl_metadata.csv")) + + def 
test_row_count_matches_input(self): + with tempfile.TemporaryDirectory() as root: + _write_fake_ptbxl(root, n_records=6) + from pyhealth.datasets.ptbxl import PTBXLDataset + + obj = PTBXLDataset.__new__(PTBXLDataset) + obj.sampling_rate = 100 + obj._prepare_metadata(root) + + meta = pd.read_csv(os.path.join(root, "ptbxl_metadata.csv")) + assert len(meta) == 6 + + def test_missing_scp_statements_graceful(self): + """Metadata preparation succeeds even without scp_statements.csv.""" + with tempfile.TemporaryDirectory() as root: + _write_fake_ptbxl(root) + os.remove(os.path.join(root, "scp_statements.csv")) + from pyhealth.datasets.ptbxl import PTBXLDataset + + obj = PTBXLDataset.__new__(PTBXLDataset) + obj.sampling_rate = 100 + obj._prepare_metadata(root) + + meta = pd.read_csv(os.path.join(root, "ptbxl_metadata.csv")) + assert (meta["mi_label"] == 0).all() + + def test_metadata_not_regenerated_if_exists(self): + """Second call to __init__ reuses existing ptbxl_metadata.csv.""" + with tempfile.TemporaryDirectory() as root: + _write_fake_ptbxl(root) + from pyhealth.datasets.ptbxl import PTBXLDataset + + obj = PTBXLDataset.__new__(PTBXLDataset) + obj.sampling_rate = 100 + obj._prepare_metadata(root) + + mtime_first = os.path.getmtime( + os.path.join(root, "ptbxl_metadata.csv") + ) + # Calling again should NOT re-write the file + obj2 = PTBXLDataset.__new__(PTBXLDataset) + obj2.sampling_rate = 100 + # Simulate __init__ guard: only call if file absent + if not os.path.exists(os.path.join(root, "ptbxl_metadata.csv")): + obj2._prepare_metadata(root) + + mtime_second = os.path.getmtime( + os.path.join(root, "ptbxl_metadata.csv") + ) + assert mtime_first == mtime_second + + +# --------------------------------------------------------------------------- +# Data Integrity +# --------------------------------------------------------------------------- + +class TestDataIntegrity: + """Tests schema, label correctness, and filename integrity.""" + + def 
test_required_columns_present(self): + with tempfile.TemporaryDirectory() as root: + _write_fake_ptbxl(root) + from pyhealth.datasets.ptbxl import PTBXLDataset + + obj = PTBXLDataset.__new__(PTBXLDataset) + obj.sampling_rate = 100 + obj._prepare_metadata(root) + + meta = pd.read_csv(os.path.join(root, "ptbxl_metadata.csv")) + for col in ( + "patient_id", "ecg_id", "recording_date", "filename", + "mi_label", "hyp_label", "sttc_label", "cd_label", + ): + assert col in meta.columns, f"Missing column: {col}" + + def test_mi_labels_correct(self): + with tempfile.TemporaryDirectory() as root: + _write_fake_ptbxl(root, n_records=6) + from pyhealth.datasets.ptbxl import PTBXLDataset + + obj = PTBXLDataset.__new__(PTBXLDataset) + obj.sampling_rate = 100 + obj._prepare_metadata(root) + + meta = pd.read_csv(os.path.join(root, "ptbxl_metadata.csv")) + # ecg_id is stored as digit-only strings; pandas infers int64 on read-back. + ecg_id = meta["ecg_id"].astype(int) + assert meta[ecg_id == 3]["mi_label"].iloc[0] == 1 + assert meta[ecg_id == 6]["mi_label"].iloc[0] == 1 + assert meta[ecg_id == 1]["mi_label"].iloc[0] == 0 + + def test_labels_are_binary(self): + with tempfile.TemporaryDirectory() as root: + _write_fake_ptbxl(root) + from pyhealth.datasets.ptbxl import PTBXLDataset + + obj = PTBXLDataset.__new__(PTBXLDataset) + obj.sampling_rate = 100 + obj._prepare_metadata(root) + + meta = pd.read_csv(os.path.join(root, "ptbxl_metadata.csv")) + for col in ("mi_label", "hyp_label", "sttc_label", "cd_label"): + assert set(meta[col].unique()).issubset({0, 1}), ( + f"{col} contains non-binary values" + ) + + def test_filename_contains_root(self): + with tempfile.TemporaryDirectory() as root: + _write_fake_ptbxl(root) + from pyhealth.datasets.ptbxl import PTBXLDataset + + obj = PTBXLDataset.__new__(PTBXLDataset) + obj.sampling_rate = 100 + obj._prepare_metadata(root) + + meta = pd.read_csv(os.path.join(root, "ptbxl_metadata.csv")) + assert meta["filename"].iloc[0].startswith(root) + + +# 
--------------------------------------------------------------------------- +# Patient Parsing +# --------------------------------------------------------------------------- + +class TestPatientParsing: + """Tests that records are correctly grouped by patient_id.""" + + def test_correct_patient_count(self): + with tempfile.TemporaryDirectory() as root: + _write_fake_ptbxl(root, n_patients=3, n_records=6) + from pyhealth.datasets.ptbxl import PTBXLDataset + + obj = PTBXLDataset.__new__(PTBXLDataset) + obj.sampling_rate = 100 + obj._prepare_metadata(root) + + meta = pd.read_csv(os.path.join(root, "ptbxl_metadata.csv")) + assert meta["patient_id"].nunique() == 3 + + def test_patient_ids_are_strings(self): + with tempfile.TemporaryDirectory() as root: + _write_fake_ptbxl(root) + from pyhealth.datasets.ptbxl import PTBXLDataset + + obj = PTBXLDataset.__new__(PTBXLDataset) + obj.sampling_rate = 100 + obj._prepare_metadata(root) + + meta = pd.read_csv(os.path.join(root, "ptbxl_metadata.csv")) + # patient_id is written as digit-only strings; pandas infers int64 on + # read-back. Verify the values are at least safely castable to str. 
+ assert meta["patient_id"].apply(lambda x: str(x)).dtype == object + + def test_records_per_patient_correct(self): + with tempfile.TemporaryDirectory() as root: + _write_fake_ptbxl(root, n_patients=3, n_records=6) + from pyhealth.datasets.ptbxl import PTBXLDataset + + obj = PTBXLDataset.__new__(PTBXLDataset) + obj.sampling_rate = 100 + obj._prepare_metadata(root) + + meta = pd.read_csv(os.path.join(root, "ptbxl_metadata.csv")) + counts = meta.groupby("patient_id").size() + # 6 records across 3 patients → 2 records each + assert (counts == 2).all() + + +# --------------------------------------------------------------------------- +# Event Parsing +# --------------------------------------------------------------------------- + +class TestEventParsing: + """Tests event attribute access via mock Patient/Event objects.""" + + def test_event_attribute_access(self): + """Events return correct attribute values via __getitem__.""" + attr = { + "filename": "/fake/path/00001_lr", + "ecg_id": "1", + "mi_label": 1, + "hyp_label": 0, + "sttc_label": 0, + "cd_label": 0, + } + event = _make_fake_event(attr) + assert event["filename"] == "/fake/path/00001_lr" + assert event["mi_label"] == 1 + assert event["hyp_label"] == 0 + + def test_patient_get_events_returns_list(self): + """patient.get_events() returns the expected event list.""" + events = [ + _make_fake_event({"filename": "/f/1", "ecg_id": "1", + "mi_label": 1, "hyp_label": 0, + "sttc_label": 0, "cd_label": 0}), + _make_fake_event({"filename": "/f/2", "ecg_id": "2", + "mi_label": 0, "hyp_label": 1, + "sttc_label": 0, "cd_label": 0}), + ] + patient = _make_fake_patient(events, patient_id="42") + returned = patient.get_events(event_type="ecg_records") + assert len(returned) == 2 + assert returned[0]["ecg_id"] == "1" + + +# --------------------------------------------------------------------------- +# Task Functionality (dataset + task integration) +# --------------------------------------------------------------------------- 
+
+class TestTaskFunctionality:
+    """Tests that ECGBinaryClassification works on mock PTB-XL patients."""
+
+    def test_task_produces_correct_sample_count(self):
+        """One sample per event; labels and ECG shape match the task config."""
+        from unittest.mock import patch
+        import numpy as np
+        from pyhealth.tasks.ecg_classification import ECGBinaryClassification
+
+        task = ECGBinaryClassification(task_label="MI", target_length=128)
+        events = [
+            _make_fake_event({"filename": "/f/1", "ecg_id": "1",
+                              "mi_label": 1, "hyp_label": 0,
+                              "sttc_label": 0, "cd_label": 0}),
+            _make_fake_event({"filename": "/f/2", "ecg_id": "2",
+                              "mi_label": 0, "hyp_label": 0,
+                              "sttc_label": 0, "cd_label": 0}),
+        ]
+        patient = _make_fake_patient(events, patient_id="1")
+        # 200-step signal is longer than target_length=128 → task must truncate.
+        fake_signal = np.random.randn(12, 200).astype(np.float32)
+
+        with patch.object(task, "_load_signal", return_value=fake_signal):
+            samples = task(patient)
+
+        assert len(samples) == 2
+        assert samples[0]["label"] == 1
+        assert samples[1]["label"] == 0
+        assert samples[0]["ecg"].shape == (12, 128)
+
+    def test_task_label_mi_vs_hyp(self):
+        """Switching task_label changes which column is used as the target."""
+        from unittest.mock import patch
+        import numpy as np
+        from pyhealth.tasks.ecg_classification import ECGBinaryClassification
+
+        attr = {"filename": "/f/1", "ecg_id": "1",
+                "mi_label": 0, "hyp_label": 1,
+                "sttc_label": 0, "cd_label": 0}
+        events = [_make_fake_event(attr)]
+        patient = _make_fake_patient(events)
+        fake_signal = np.random.randn(12, 128).astype(np.float32)
+
+        # First tuple element names the metadata column driving each label;
+        # it is documentation only, so it is unpacked into `_`.
+        for _, task_label, expected in [
+            ("mi_label", "MI", 0),
+            ("hyp_label", "HYP", 1),
+        ]:
+            task = ECGBinaryClassification(task_label=task_label,
+                                           target_length=128)
+            with patch.object(task, "_load_signal", return_value=fake_signal):
+                samples = task(patient)
+            assert samples[0]["label"] == expected, (
+                f"Expected label {expected} for {task_label}"
+            )
+
+
+# ---------------------------------------------------------------------------
+# Validation
+# ---------------------------------------------------------------------------
+
+class TestValidation:
+    """Constructor and metadata-preparation error paths of PTBXLDataset."""
+
+    def test_invalid_sampling_rate_raises(self):
+        """PTBXLDataset.__init__ rejects sampling rates other than 100 or 500."""
+        from pyhealth.datasets.ptbxl import PTBXLDataset
+
+        with tempfile.TemporaryDirectory() as root:
+            _write_fake_ptbxl(root)
+            # ValueError is raised before super().__init__(), so no patching needed.
+            obj = PTBXLDataset.__new__(PTBXLDataset)
+            with pytest.raises(ValueError, match="sampling_rate must be"):
+                PTBXLDataset.__init__(obj, root=root, sampling_rate=250)
+
+    def test_missing_database_raises(self):
+        """FileNotFoundError raised when ptbxl_database.csv is absent."""
+        with tempfile.TemporaryDirectory() as root:
+            from pyhealth.datasets.ptbxl import PTBXLDataset
+
+            obj = PTBXLDataset.__new__(PTBXLDataset)
+            obj.sampling_rate = 100
+            with pytest.raises(FileNotFoundError, match="ptbxl_database.csv"):
+                obj._prepare_metadata(root)
diff --git a/tests/core/test_taskaug_resnet.py b/tests/core/test_taskaug_resnet.py
new file mode 100644
index 000000000..70b6de8ec
+"""Tests for TaskAugResNet, TaskAugPolicy, and _ResNet1D.
+
+All tests use synthetic in-memory tensors — no dataset files required.
+Signal length is set to 128 throughout so each test completes in milliseconds
+on CPU (AdaptiveAvgPool1d handles any input length).
+"""
+from __future__ import annotations
+
+from unittest.mock import MagicMock
+
+import pytest
+import torch
+import torch.nn as nn
+
+# Short signal length so every test runs in milliseconds on CPU.
+
+_T = 128
+
+
+# ---------------------------------------------------------------------------
+# Mock dataset helper
+# ---------------------------------------------------------------------------
+
+def _mock_dataset(output_size: int = 1) -> MagicMock:
+    """Minimal SampleDataset mock for TaskAugResNet instantiation."""
+    ds = MagicMock()
+    ds.input_schema = {"ecg": "tensor"}
+    ds.output_schema = {"label": "binary"}
+    # Output processor reports the label dimensionality (1 for binary tasks).
+    proc = MagicMock()
+    proc.size.return_value = output_size
+    ds.output_processors = {"label": proc}
+    return ds
+
+
+# ---------------------------------------------------------------------------
+# _ResNet1D — instantiation, forward pass, output shapes, gradients
+# ---------------------------------------------------------------------------
+
+class TestResNet1D:
+    """Shape, input-length, gradient, and NaN checks for the 1-D ResNet backbone."""
+
+    def test_output_shape(self):
+        """Forward pass of a (B, 12, T) batch yields (B, num_classes) logits."""
+        from pyhealth.models.taskaug_resnet import _ResNet1D
+
+        model = _ResNet1D(in_channels=12, num_classes=1)
+        out = model(torch.randn(4, 12, _T))
+        assert out.shape == (4, 1), f"Expected (4, 1), got {out.shape}"
+
+    def test_variable_input_lengths(self):
+        """AdaptiveAvgPool1d should handle any T >= 1."""
+        from pyhealth.models.taskaug_resnet import _ResNet1D
+
+        model = _ResNet1D(in_channels=12, num_classes=1)
+        for length in (64, 128, 256):
+            assert model(torch.randn(2, 12, length)).shape == (2, 1)
+
+    def test_batch_size_one(self):
+        # eval() avoids batch-norm issues that a batch of one can trigger in train mode.
+        from pyhealth.models.taskaug_resnet import _ResNet1D
+
+        model = _ResNet1D(in_channels=12, num_classes=1)
+        model.eval()
+        assert model(torch.randn(1, 12, _T)).shape == (1, 1)
+
+    def test_gradient_flow_through_backbone(self):
+        """Gradients must flow from loss back to the input tensor."""
+        from pyhealth.models.taskaug_resnet import _ResNet1D
+
+        model = _ResNet1D(in_channels=12, num_classes=1)
+        x = torch.randn(2, 12, _T, requires_grad=True)
+        model(x).sum().backward()
+        assert x.grad is not None
+        assert not torch.isnan(x.grad).any()
+
+    def test_no_nan_in_output(self):
+        from pyhealth.models.taskaug_resnet import _ResNet1D
+
+        model = _ResNet1D(in_channels=12, num_classes=1)
+        out = model(torch.randn(4, 12, _T))
+        assert not torch.isnan(out).any()
+
+
+# ---------------------------------------------------------------------------
+# TaskAugPolicy — instantiation, forward pass, output shapes, gradients
+# ---------------------------------------------------------------------------
+
+class TestTaskAugPolicy:
+    """Shape preservation and learnable-parameter gradient checks for the policy."""
+
+    def test_output_shape_preserved(self):
+        """Augmentation must not change the (B, leads, T) signal shape."""
+        from pyhealth.models.taskaug_resnet import TaskAugPolicy
+
+        policy = TaskAugPolicy(num_stages=2)
+        x = torch.randn(4, 12, _T)
+        out = policy(x, torch.randint(0, 2, (4,)))
+        assert out.shape == x.shape
+
+    def test_gradient_flows_to_logits(self):
+        # The op-selection logits are learnable; backward must reach them.
+        from pyhealth.models.taskaug_resnet import TaskAugPolicy
+
+        policy = TaskAugPolicy(num_stages=2)
+        x = torch.randn(4, 12, _T)
+        policy(x, torch.randint(0, 2, (4,))).mean().backward()
+        assert policy.logits.grad is not None
+        assert not torch.isnan(policy.logits.grad).any()
+
+    def test_gradient_flows_to_magnitudes(self):
+        # Class-specific magnitudes (mag_neg / mag_pos) are learnable too.
+        from pyhealth.models.taskaug_resnet import TaskAugPolicy
+
+        policy = TaskAugPolicy(num_stages=2)
+        x = torch.randn(4, 12, _T)
+        policy(x, torch.randint(0, 2, (4,))).mean().backward()
+        assert policy.mag_neg.grad is not None
+        assert policy.mag_pos.grad is not None
+
+    def test_all_positive_labels(self):
+        from pyhealth.models.taskaug_resnet import TaskAugPolicy
+
+        policy = TaskAugPolicy()
+        x = torch.randn(3, 12, _T)
+        out = policy(x, torch.ones(3, dtype=torch.long))
+        assert out.shape == x.shape
+
+    def test_all_negative_labels(self):
+        from pyhealth.models.taskaug_resnet import TaskAugPolicy
+
+        policy = TaskAugPolicy()
+        x = torch.randn(3, 12, _T)
+        out = policy(x, torch.zeros(3, dtype=torch.long))
+        assert out.shape == x.shape
+
+    def test_single_stage(self):
+        from pyhealth.models.taskaug_resnet import TaskAugPolicy
+
+        policy = TaskAugPolicy(num_stages=1)
+        x = torch.randn(2, 12, _T)
+        assert policy(x, torch.randint(0, 2, (2,))).shape == x.shape
+
+    def test_learnable_parameter_count(self):
+        from pyhealth.models.taskaug_resnet import TaskAugPolicy, _NUM_OPS
+
+        policy = TaskAugPolicy(num_stages=2)
+        n_params = sum(p.numel() for p in policy.parameters())
+        # logits (2×N) + mag_neg (2×N) + mag_pos (2×N) = 3 × 2 × N_OPS
+        assert n_params == 3 * 2 * _NUM_OPS
+
+
+# ---------------------------------------------------------------------------
+# TaskAugResNet — instantiation, forward pass, output shapes, gradients
+# ---------------------------------------------------------------------------
+
+class TestTaskAugResNet:
+    """End-to-end checks for the combined policy + backbone model."""
+
+    def test_instantiation(self):
+        from pyhealth.models.taskaug_resnet import TaskAugResNet
+
+        model = TaskAugResNet(_mock_dataset())
+        assert hasattr(model, "policy")
+        assert hasattr(model, "backbone")
+
+    def test_forward_no_label_output_keys(self):
+        # Without a label the model predicts only — no loss is computed.
+        from pyhealth.models.taskaug_resnet import TaskAugResNet
+
+        model = TaskAugResNet(_mock_dataset())
+        model.eval()
+        out = model(ecg=torch.randn(4, 12, _T))
+        assert "logit" in out
+        assert "y_prob" in out
+        assert "loss" not in out
+
+    def test_forward_output_shapes(self):
+        from pyhealth.models.taskaug_resnet import TaskAugResNet
+
+        model = TaskAugResNet(_mock_dataset())
+        model.eval()
+        out = model(ecg=torch.randn(4, 12, _T))
+        assert out["logit"].shape == (4, 1)
+        assert out["y_prob"].shape == (4, 1)
+
+    def test_forward_with_label_output_keys(self):
+        from pyhealth.models.taskaug_resnet import TaskAugResNet
+
+        model = TaskAugResNet(_mock_dataset())
+        model.train()
+        out = model(ecg=torch.randn(4, 12, _T),
+                    label=torch.randint(0, 2, (4,)))
+        assert "loss" in out
+        assert "y_true" in out
+        assert out["loss"].ndim == 0  # scalar
+
+    def test_y_prob_bounded(self):
+        """y_prob must lie in [0, 1] (sigmoid output)."""
+        from pyhealth.models.taskaug_resnet import TaskAugResNet
+
+        model = TaskAugResNet(_mock_dataset())
+        model.eval()
+        out = model(ecg=torch.randn(8, 12, _T))
+        assert (out["y_prob"] >= 0).all() and (out["y_prob"] <= 1).all()
+
+    def test_no_augmentation_in_eval_mode(self):
+        """Identical inputs must produce identical outputs in eval mode."""
+        from pyhealth.models.taskaug_resnet import TaskAugResNet
+
+        model = TaskAugResNet(_mock_dataset())
+        model.eval()
+        x = torch.randn(2, 12, _T)
+        label = torch.zeros(2, dtype=torch.long)
+        out1 = model(ecg=x.clone(), label=label)
+        out2 = model(ecg=x.clone(), label=label)
+        torch.testing.assert_close(out1["logit"], out2["logit"])
+
+    def test_augmentation_stochastic_in_train_mode(self):
+        """Two forward passes in train mode should produce different outputs."""
+        from pyhealth.models.taskaug_resnet import TaskAugResNet
+
+        # Seeded so the sampled augmentations differ deterministically across runs.
+        torch.manual_seed(0)
+        model = TaskAugResNet(_mock_dataset())
+        model.train()
+        x = torch.randn(4, 12, _T)
+        labels = torch.randint(0, 2, (4,))
+        out1 = model(ecg=x.clone(), label=labels)
+        out2 = model(ecg=x.clone(), label=labels)
+        assert not torch.allclose(out1["logit"], out2["logit"])
+
+    def test_gradient_computation_full_model(self):
+        """Every trainable parameter must receive a gradient after backward."""
+        from pyhealth.models.taskaug_resnet import TaskAugResNet
+
+        model = TaskAugResNet(_mock_dataset())
+        model.train()
+        out = model(ecg=torch.randn(2, 12, _T),
+                    label=torch.randint(0, 2, (2,)))
+        out["loss"].backward()
+        for name, p in model.named_parameters():
+            if p.requires_grad and p.numel() > 0:
+                assert p.grad is not None, f"No gradient for {name}"
+
+    def test_policy_backbone_param_groups_disjoint(self):
+        # The two optimizer groups (policy vs backbone) must not share tensors.
+        from pyhealth.models.taskaug_resnet import TaskAugResNet
+
+        model = TaskAugResNet(_mock_dataset())
+        policy_params = set(model.policy_parameters())
+        backbone_params = set(model.backbone_parameters())
+        assert len(policy_params) > 0
+        assert len(backbone_params) > 0
+        assert policy_params.isdisjoint(backbone_params)