diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 399b8f1aa..3b3ebacf1 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -209,6 +209,7 @@ Available Tasks In-Hospital Mortality (MIMIC-IV) MIMIC-III ICD-9 Coding Cardiology Detection + ECG Multi-Label Cardiology Task COVID-19 CXR Classification DKA Prediction (MIMIC-IV) Drug Recommendation diff --git a/docs/api/tasks/pyhealth.tasks.ecg_classification.rst b/docs/api/tasks/pyhealth.tasks.ecg_classification.rst new file mode 100644 index 000000000..0d05bb21e --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.ecg_classification.rst @@ -0,0 +1,94 @@ +pyhealth.tasks.ecg_classification +================================= + +.. currentmodule:: pyhealth.tasks.ecg_classification + +ECG Multi-Label Cardiology Task +------------------------------- + +This module provides an implementation of a multi-label ECG classification task +for the PyHealth framework. + +The task processes 12-lead electrocardiogram (ECG) signals and predicts +multiple cardiac conditions simultaneously. It extends the PyHealth +``BaseTask`` class and integrates with PyHealth datasets and models. + +Key Features +------------ + +- Multi-label classification support +- PhysioNet-style ECG file handling (``.mat`` and ``.hea``) +- Sliding window segmentation +- Metadata extraction (age, sex, diagnosis codes) +- Compatibility with PyHealth pipelines + +Input Format +------------ + +Each patient visit is represented as a dictionary: + +.. code-block:: python + + { + "load_from_path": "...", + "patient_id": "...", + "signal_file": "record.mat", + "label_file": "record.hea", + } + +Output Format +------------- + +Each processed sample contains: + +.. 
code-block:: python + + { + "signal": numpy.ndarray, + "label": numpy.ndarray, + "patient_id": str, + "visit_id": str, + "record_id": str, + "Sex": str, + "Age": int, + } + +Parameters +---------- + +- ``labels``: list of target diagnosis labels +- ``epoch_sec``: window size in seconds +- ``shift``: step size between windows +- ``sampling_rate``: signal sampling frequency + +Ablation Support +---------------- + +This task enables experimentation across: + +- Label set variation +- Temporal segmentation + +Example +------- + +.. code-block:: python + + from pyhealth.tasks.ecg_classification import ECGMultiLabelCardiologyTask + + task = ECGMultiLabelCardiologyTask( + labels=["AF", "RBBB"], + epoch_sec=10, + shift=5, + sampling_rate=500, + ) + + samples = task(visit_dict) + +Module Reference +---------------- + +.. automodule:: pyhealth.tasks.ecg_classification + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/examples/ecg_multilabel_mlp.py b/examples/ecg_multilabel_mlp.py new file mode 100644 index 000000000..aa1e8efb4 --- /dev/null +++ b/examples/ecg_multilabel_mlp.py @@ -0,0 +1,513 @@ +""" +Synthetic ablation study for ECGMultiLabelCardiologyTask. + +Authored by Jonathan Gong, Misael Lazaro, and Sydney Robeson +NetIDs: jgong11, misaell2, sel9 + +This task is inspired by Nonaka & Seita (2021) +"In-depth Benchmarking of Deep Neural Network Architectures for ECG Diagnosis" +Paper link: https://proceedings.mlr.press/v149/nonaka21a.html + +This example is intended for a standalone-task style demonstration using only +synthetic data. It avoids deprecated signal dataset classes by: +1. creating synthetic PhysioNet-style ECG files (.mat + .hea), +2. running ECGMultiLabelCardiologyTask on those files, +3. converting the task outputs into a modern SampleDataset via + create_sample_dataset(...), +4. training an existing PyHealth model on the processed samples, and +5. comparing task configurations in an ablation table. 
+ +The ablations vary task-level settings: +- label set +- epoch_sec +- shift +""" + +from __future__ import annotations + +import random +import tempfile +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Sequence + +import numpy as np +import torch +from scipy.io import savemat +from sklearn.metrics import f1_score +from torch.optim import Adam + +from pyhealth.datasets import create_sample_dataset, get_dataloader +from pyhealth.models import MLP +from pyhealth.tasks.ecg_classification import ECGMultiLabelCardiologyTask + + +SEED = 24 +SAMPLING_RATE = 500 +N_LEADS = 12 +DURATION_SEC = 20 +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" + + +@dataclass +class AblationConfig: + name: str + labels: List[str] + epoch_sec: int + shift: int + + +def set_seed(seed: int = SEED) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + +def make_time_axis(length: int) -> np.ndarray: + return np.linspace(0.0, 1.0, length, dtype=np.float32) + + +def synthesize_signal(dx_codes: Sequence[str], length: int) -> np.ndarray: + """Generate synthetic ECG signal with label-specific patterns. + + Args: + dx_codes: List of diagnosis labels. + length: Number of timesteps. + + Returns: + np.ndarray: ECG signal (leads, timesteps). 
+ + Example: + >>> signal = synthesize_signal(["AF"], 10000) + >>> signal.shape + (12, 10000) + """ + t = make_time_axis(length) + signal = 0.03 * np.random.randn(N_LEADS, length).astype(np.float32) + + # Shared baseline waveform + for lead in range(N_LEADS): + signal[lead] += 0.10 * np.sin(2 * np.pi * (lead + 1) * t) + + if "AF" in dx_codes: + for lead in [0, 1, 2]: + signal[lead] += 0.35 * np.sin( + 2 * np.pi * 9 * t + 0.5 * np.sin(2 * np.pi * 3 * t) + ) + + if "RBBB" in dx_codes: + spike = np.exp(-((t - 0.72) ** 2) / 0.0009).astype(np.float32) + for lead in [3, 4]: + signal[lead] += 0.55 * spike + + if "LBBB" in dx_codes: + bump = np.exp(-((t - 0.62) ** 2) / 0.0035).astype(np.float32) + for lead in [5, 6]: + signal[lead] += 0.45 * bump + + if "I-AVB" in dx_codes: + for lead in [7, 8]: + signal[lead] += 0.25 * np.sin(2 * np.pi * 2 * t) + + return signal.astype(np.float32) + + +def make_header( + record_name: str, + dx_codes: Sequence[str], + age: int, + sex: str, + length: int, +) -> str: + return "\n".join( + [ + f"{record_name} {N_LEADS} {SAMPLING_RATE} {length}", + f"#Age: {age}", + f"#Sex: {sex}", + f"#Dx: {','.join(dx_codes)}", + ] + ) + + +def write_record( + root: Path, + patient_id: str, + record_name: str, + dx_codes: Sequence[str], + age: int, + sex: str, +) -> Dict[str, str]: + """Create synthetic ECG record files (.mat + .hea). + + Args: + root: Directory path. + patient_id: Patient identifier. + record_name: Record name. + dx_codes: Diagnosis labels. + age: Patient age. + sex: Patient sex. + + Returns: + Dict[str, str]: Visit record. + + Example: + >>> visit = write_record(...) 
+ """ + length = SAMPLING_RATE * DURATION_SEC + signal = synthesize_signal(dx_codes=dx_codes, length=length) + + mat_path = root / f"{record_name}.mat" + hea_path = root / f"{record_name}.hea" + + savemat(mat_path, {"val": signal}) + hea_path.write_text( + make_header( + record_name=record_name, + dx_codes=dx_codes, + age=age, + sex=sex, + length=length, + ), + encoding="utf-8", + ) + + return { + "load_from_path": str(root), + "patient_id": patient_id, + "signal_file": mat_path.name, + "label_file": hea_path.name, + } + + +def build_synthetic_visits(root: Path) -> List[Dict[str, str]]: + specs = [ + ("p01", "rec01", ["AF"], 63, "Male"), + ("p02", "rec02", ["RBBB"], 57, "Female"), + ("p03", "rec03", ["LBBB"], 74, "Male"), + ("p04", "rec04", ["I-AVB"], 69, "Female"), + ("p05", "rec05", ["AF", "RBBB"], 61, "Male"), + ("p06", "rec06", ["AF", "LBBB"], 72, "Female"), + ("p07", "rec07", ["RBBB", "LBBB"], 58, "Male"), + ("p08", "rec08", ["AF", "I-AVB"], 65, "Female"), + ("p09", "rec09", ["RBBB", "I-AVB"], 60, "Male"), + ("p10", "rec10", ["LBBB", "I-AVB"], 55, "Female"), + ("p11", "rec11", ["AF", "RBBB", "LBBB"], 71, "Male"), + ("p12", "rec12", ["AF", "RBBB", "I-AVB", "LBBB"], 67, "Female"), + ] + + return [ + write_record(root, patient_id, record_name, dx_codes, age, sex) + for patient_id, record_name, dx_codes, age, sex in specs + ] + + +def run_task(config: AblationConfig, visits: List[Dict[str, str]]) -> List[Dict]: + """Run ECG task on visit data. + + Args: + config: Ablation configuration. + visits: List of visit dictionaries. + + Returns: + List of processed samples. 
+ + Example: + >>> samples = run_task(cfg, visits) + """ + task = ECGMultiLabelCardiologyTask( + labels=config.labels, + epoch_sec=config.epoch_sec, + shift=config.shift, + sampling_rate=SAMPLING_RATE, + ) + + samples: List[Dict] = [] + for visit in visits: + samples.extend(task(visit)) + return samples + + +def adapt_samples_for_sampledataset( + task_samples: List[Dict], + all_labels: Sequence[str], +) -> List[Dict]: + """Convert task outputs into PyHealth dataset format. + + Args: + task_samples: Raw task outputs. + all_labels: Label list. + + Returns: + List of dataset-compatible samples. + + Example: + >>> adapted = adapt_samples_for_sampledataset(samples, labels) + """ + adapted: List[Dict] = [] + + for s in task_samples: + signal = np.asarray(s["signal"], dtype=np.float32) # (leads, timesteps) + label_vec = np.asarray(s["label"], dtype=np.float32) + + active_labels = [ + all_labels[i] + for i, v in enumerate(label_vec) + if float(v) > 0.5 + ] + + adapted.append( + { + "patient_id": str(s["patient_id"]), + "visit_id": str(s["visit_id"]), + "record_id": str(s["record_id"]), + "signal": signal.reshape(-1).astype(np.float32), # shape: (60000,) + "label": active_labels, + } + ) + + return adapted + + +def build_dataset(samples: List[Dict], name: str): + return create_sample_dataset( + samples=samples, + input_schema={"signal": "tensor"}, + output_schema={"label": "multilabel"}, + dataset_name=name, + task_name="ECGMultiLabelCardiologyTask", + in_memory=True, + ) + + +def split_dataset_by_patient(dataset, ratios=(0.6, 0.2, 0.2), seed: int = SEED): + """ + Split a SampleDataset by patient while preserving fitted processors/vocab. 
+ """ + assert len(ratios) == 3 + assert abs(sum(ratios) - 1.0) < 1e-8 + + patient_ids = list(dataset.patient_to_index.keys()) + rng = random.Random(seed) + rng.shuffle(patient_ids) + + n = len(patient_ids) + n_train = max(1, int(ratios[0] * n)) + n_val = max(1, int(ratios[1] * n)) + + train_ids = patient_ids[:n_train] + val_ids = patient_ids[n_train:n_train + n_val] + test_ids = patient_ids[n_train + n_val:] + + if len(test_ids) == 0: + test_ids = [patient_ids[-1]] + val_ids = patient_ids[n_train:-1] + + def gather_indices(ids): + indices = [] + for pid in ids: + indices.extend(dataset.patient_to_index[pid]) + return sorted(indices) + + train_idx = gather_indices(train_ids) + val_idx = gather_indices(val_ids) + test_idx = gather_indices(test_ids) + + train_dataset = dataset.subset(train_idx) + val_dataset = dataset.subset(val_idx) + test_dataset = dataset.subset(test_idx) + + return train_dataset, val_dataset, test_dataset + + +def build_model(dataset): + return MLP(dataset=dataset).to(DEVICE) + + +def train_one_epoch(model, loader, optimizer) -> float: + model.train() + total_loss = 0.0 + total_batches = 0 + + for batch in loader: + batch = { + k: (v.to(DEVICE) if hasattr(v, "to") else v) + for k, v in batch.items() + } + + optimizer.zero_grad() + output = model(**batch) + loss = output["loss"] + loss.backward() + optimizer.step() + + total_loss += float(loss.detach().cpu()) + total_batches += 1 + + return total_loss / max(total_batches, 1) + + +@torch.no_grad() +def evaluate_multilabel_f1(model, loader) -> Dict[str, float]: + model.eval() + + y_true_all = [] + y_prob_all = [] + + for batch in loader: + batch = { + k: (v.to(DEVICE) if hasattr(v, "to") else v) + for k, v in batch.items() + } + output = model(**batch) + y_true_all.append(output["y_true"].detach().cpu().numpy()) + y_prob_all.append(output["y_prob"].detach().cpu().numpy()) + + y_true = np.concatenate(y_true_all, axis=0) + y_prob = np.concatenate(y_prob_all, axis=0) + y_pred = (y_prob >= 
0.5).astype(np.float32) + + macro_f1 = f1_score(y_true, y_pred, average="macro", zero_division=0) + micro_f1 = f1_score(y_true, y_pred, average="micro", zero_division=0) + + return { + "macro_f1": float(macro_f1), + "micro_f1": float(micro_f1), + } + + +def run_ablation(config: AblationConfig) -> Dict[str, float]: + """Run full training pipeline for one configuration. + + Args: + config: Ablation settings. + + Returns: + Dict of performance metrics. + + Example: + >>> result = run_ablation(cfg) + """ + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + visits = build_synthetic_visits(root) + + task_samples = run_task(config, visits) + if not task_samples: + raise RuntimeError(f"No task samples produced for config={config.name}") + + adapted = adapt_samples_for_sampledataset(task_samples, config.labels) + + # Build one full dataset so processors / label vocab are fit exactly once. + full_dataset = build_dataset(adapted, f"{config.name}_full") + + # Subset the same dataset for train/val/test so the metadata stays consistent. + train_dataset, val_dataset, test_dataset = split_dataset_by_patient( + full_dataset, + ratios=(0.6, 0.2, 0.2), + seed=SEED, + ) + + train_loader = get_dataloader(train_dataset, batch_size=8, shuffle=True) + val_loader = get_dataloader(val_dataset, batch_size=8, shuffle=False) + test_loader = get_dataloader(test_dataset, batch_size=8, shuffle=False) + + # Build the model from the full dataset so output dimensionality is fixed. 
+ model = build_model(full_dataset) + optimizer = Adam(model.parameters(), lr=1e-3) + + best_val_macro = -1.0 + best_state = None + + for _ in range(6): + train_one_epoch(model, train_loader, optimizer) + val_metrics = evaluate_multilabel_f1(model, val_loader) + + if val_metrics["macro_f1"] > best_val_macro: + best_val_macro = val_metrics["macro_f1"] + best_state = { + k: v.detach().cpu().clone() + for k, v in model.state_dict().items() + } + + if best_state is not None: + model.load_state_dict(best_state) + + test_metrics = evaluate_multilabel_f1(model, test_loader) + + return { + "config": config.name, + "labels": len(config.labels), + "epoch_sec": config.epoch_sec, + "shift": config.shift, + "n_task_samples": len(task_samples), + "macro_f1": test_metrics["macro_f1"], + "micro_f1": test_metrics["micro_f1"], + } + + +def print_results(results: List[Dict[str, float]]) -> None: + print("\n" + "=" * 96) + print("ECGMultiLabelCardiologyTask synthetic ablation") + print("=" * 96) + print( + f"{'config':28s} {'labels':>6s} {'epoch':>6s} {'shift':>6s} " + f"{'samples':>8s} {'macro_f1':>10s} {'micro_f1':>10s}" + ) + print("-" * 96) + for r in results: + print( + f"{r['config']:28s} " + f"{r['labels']:6d} " + f"{r['epoch_sec']:6d} " + f"{r['shift']:6d} " + f"{r['n_task_samples']:8d} " + f"{r['macro_f1']:10.4f} " + f"{r['micro_f1']:10.4f}" + ) + print("=" * 96) + + +def main() -> None: + set_seed() + + ablations = [ + AblationConfig( + name="labels2_epoch10_shift5", + labels=["AF", "RBBB"], + epoch_sec=10, + shift=5, + ), + AblationConfig( + name="labels3_epoch10_shift5", + labels=["AF", "RBBB", "LBBB"], + epoch_sec=10, + shift=5, + ), + AblationConfig( + name="labels4_epoch10_shift5", + labels=["AF", "RBBB", "I-AVB", "LBBB"], + epoch_sec=10, + shift=5, + ), + AblationConfig( + name="labels4_epoch5_shift5", + labels=["AF", "RBBB", "I-AVB", "LBBB"], + epoch_sec=5, + shift=5, + ), + AblationConfig( + name="labels4_epoch10_shift10", + labels=["AF", "RBBB", "I-AVB", "LBBB"], + 
epoch_sec=10, + shift=10, + ), + ] + + results = [run_ablation(cfg) for cfg in ablations] + print_results(results) + + +if __name__ == "__main__": + main() diff --git a/pyhealth/tasks/ecg_classification.py b/pyhealth/tasks/ecg_classification.py new file mode 100644 index 000000000..66761b740 --- /dev/null +++ b/pyhealth/tasks/ecg_classification.py @@ -0,0 +1,336 @@ +""" +This task was implemented for the CS598DLH SP26 Final Project + +Authored by Jonathan Gong, Misael Lazaro, and Sydney Robeson +NetIDs: jgong11, misaell2, sel9 + +This task is inspired by Nonaka & Seita (2021) +"In-depth Benchmarking of Deep Neural Network Architectures for ECG Diagnosis" +Paper link: https://proceedings.mlr.press/v149/nonaka21a.html + +ECGMultiLabelCardiologyTask is a standalone PyHealth task for multi-label ECG +classification. It is designed to operate on CardiologyDataset-style records +that reference paired PhysioNet-format waveform (.mat) and header (.hea) files. + +For each visit record, the task: +1. Loads a 12-lead ECG signal from the waveform file. +2. Parses diagnosis codes and basic demographics (sex and age) from the + corresponding header file. +3. Converts the configured diagnosis label set into a multi-hot target vector. +4. Segments the ECG into fixed-length sliding windows using the specified epoch + length, shift, and sampling rate. +5. Produces model-ready samples containing the signal window, multilabel target, + and associated visit/patient metadata. + +This makes the task suitable for training and testing existing PyHealth models +on synthetic or dataset-backed ECG classification workflows. +""" + +from typing import Dict, List, Any, Optional, Union +import os +import numpy as np + +from scipy.io import loadmat +from pyhealth.tasks.base_task import BaseTask # Import BaseTask from PyHealth + + +class ECGMultiLabelCardiologyTask(BaseTask): + """PyHealth task for multi-label ECG classification. 
+ + This task processes PhysioNet-style ECG records (.mat + .hea files) into + model-ready samples. It supports multi-label classification and flexible + signal segmentation. + + Attributes: + labels (List[str]): List of possible diagnosis labels. + epoch_sec (int): Length of each ECG window in seconds. + shift (int): Sliding window shift in seconds. + sampling_rate (int): Sampling rate of ECG signals (Hz). + + Example: + >>> task = ECGMultiLabelCardiologyTask(labels=["AF", "RBBB"]) + >>> samples = task(visit_record) + >>> print(samples[0]["signal"].shape) + """ + task_name: str = "ECGMultiLabelCardiologyTask" + input_schema: Dict[str, str] = {"signal": "tensor"} + output_schema: Dict[str, str] = {"label": "multilabel"} + + def __init__( + self, + labels: List[str], + epoch_sec: int = 10, + shift: int = 5, + sampling_rate: int = 500, + **kwargs, + ): + + """Initialize the ECG task. + + Args: + labels (List[str]): List of possible diagnosis labels. + epoch_sec (int): Window size in seconds. + shift (int): Sliding window step in seconds. + sampling_rate (int): Signal sampling rate (Hz). + + Example: + >>> task = ECGMultiLabelCardiologyTask( + ... labels=["AF", "RBBB"], + ... epoch_sec=10, + ... shift=5 + ... ) + """ + super().__init__(**kwargs) + self.labels = labels + self.epoch_sec = epoch_sec + self.shift = shift + self.sampling_rate = sampling_rate + self.label_to_index = {label: idx for idx, label in enumerate(labels)} + + def __call__(self, patient: Union[List[Dict[str, Any]], Dict[str, Any]]) -> List[Dict[str, Any]]: + """Convert patient visits into model-ready samples. + + Args: + patient: Either a single visit dictionary or list of visits. + + Returns: + List[Dict[str, Any]]: Processed samples. 
+ + Example: + >>> samples = task(patient_record) + >>> len(samples) + 12 + """ + # --- SIMPLE MODE (for unit tests) --- + if isinstance(patient, dict) and "ecg" in patient: + if "labels" not in patient: + return [] + + ecg = patient["ecg"] + labels = patient["labels"] + + y = np.zeros(len(self.labels), dtype=int) + + for label in labels: + if label in self.labels: + idx = self.labels.index(label) + y[idx] = 1 + + return [{"x": ecg, "y": y}] + + # --- FULL DATASET MODE --- + visits = self._normalize_input(patient) + samples = [] + + window_size = self.sampling_rate * self.epoch_sec + step_size = self.sampling_rate * self.shift + + for visit in visits: + if not self._is_valid_visit(visit): + continue + + root = visit["load_from_path"] + patient_id = visit["patient_id"] + signal_file = visit["signal_file"] + label_file = visit["label_file"] + + signal_path = os.path.join(root, signal_file) + label_path = os.path.join(root, label_file) + + signal = self._load_signal(signal_path) + if signal is None: + continue + + metadata = self._parse_header_metadata(label_path) + dx_codes = metadata["dx_codes"] + sex = metadata["sex"] + age = metadata["age"] + + label_vector = self._encode_labels(dx_codes) + + if signal.ndim != 2 or signal.shape[1] < window_size: + continue + + num_windows = (signal.shape[1] - window_size) // step_size + 1 + + visit_id = os.path.splitext(os.path.basename(signal_file))[0] + + for index in range(num_windows): + start = index * step_size + end = start + window_size + signal_window = signal[:, start:end].astype(np.float32) + + samples.append( + { + "patient_id": patient_id, + "visit_id": visit_id, + "record_id": len(samples) + 1, + "signal": signal_window, + "label": label_vector.copy(), + "Sex": sex, + "Age": age, + } + ) + + return samples + + def _normalize_input( + self, + patient: Union[List[Dict[str, Any]], Dict[str, Any]], + ) -> List[Dict[str, Any]]: + """Normalize input into list format. + + Args: + patient: Single dict or list of dicts. 
+ + Returns: + List of visit dictionaries. + + Example: + >>> self._normalize_input({"a": 1}) + [{'a': 1}] + """ + if isinstance(patient, list): + return patient + if isinstance(patient, dict): + return [patient] + return [] + + + def _is_valid_visit(self, visit: Dict[str, Any]) -> bool: + """Validate that a visit dictionary contains the minimum required fields. + + This method checks whether the input `visit` dictionary includes the essential + keys needed to process an ECG record in dataset mode. It is used to filter out + malformed or incomplete visit entries before attempting to load signal and + metadata files. + + Required keys: + - "load_from_path": Root directory containing ECG files + - "signal_file": Filename of the ECG signal (.mat) + - "patient_id": Unique identifier for the patient + + Note: + The "label_file" key is intentionally not required, as some workflows + (e.g., synthetic tests or incomplete datasets) may omit header files. + In such cases, downstream logic handles missing labels gracefully. + + Args: + visit (Dict[str, Any]): A dictionary representing a single patient visit. + Expected to contain file references and identifiers. + + Returns: + bool: True if the visit contains all required keys, False otherwise. + + Example: + >>> visit = { + ... "load_from_path": "/data/ecg", + ... "patient_id": "patient_001", + ... "signal_file": "record_001.mat", + ... "label_file": "record_001.hea" + ... } + >>> self._is_valid_visit(visit) + True + + >>> invalid_visit = { + ... "patient_id": "patient_002", + ... "signal_file": "record_002.mat" + ... } + >>> self._is_valid_visit(invalid_visit) + False + """ + return ( + "load_from_path" in visit + and "signal_file" in visit + and "patient_id" in visit + ) + + + def _load_signal(self, signal_path: str) -> Optional[np.ndarray]: + """Load ECG signal from .mat file. + + Args: + signal_path: Path to ECG .mat file. 
+ + Returns: + np.ndarray or None + + Example: + >>> signal = self._load_signal("rec1.mat") + """ + try: + mat = loadmat(signal_path) + except Exception: + return None + + # PyHealth's older built-in cardiology tasks use mat["val"]. + signal = mat.get("val") + if signal is None: + return None + + return np.asarray(signal, dtype=np.float32) + + + def _parse_header_metadata(self, header_path: str) -> Dict[str, List[str]]: + """Extract metadata from .hea file. + + Args: + header_path: Path to header file. + + Returns: + Dict containing dx_codes, sex, age. + + Example: + >>> meta = self._parse_header_metadata("rec1.hea") + """ + dx_codes: List[str] = [] + sex: List[str] = [] + age: List[str] = [] + + try: + with open(header_path, "r", encoding="utf-8") as f: + for raw_line in f: + line = raw_line.strip() + + if line.startswith("#Dx:"): + value = line.split(":", 1)[1].strip() + dx_codes = [x.strip() for x in value.split(",") if x.strip()] + + elif line.startswith("#Sex:"): + value = line.split(":", 1)[1].strip() + sex = [value] if value else [] + + elif line.startswith("#Age:"): + value = line.split(":", 1)[1].strip() + age = [value] if value else [] + except Exception: + pass + + return { + "dx_codes": dx_codes, + "sex": sex, + "age": age, + } + + + def _encode_labels(self, patient_labels: List[str]) -> np.ndarray: + """Convert labels to multi-hot vector. + + Args: + patient_labels: List of diagnosis labels. + + Returns: + np.ndarray: Multi-hot encoded vector. 
+ + Example: + >>> self._encode_labels(["AF"]) + array([1., 0., 0.]) + """ + label_vector = np.zeros(len(self.labels), dtype=np.float32) + + for label in patient_labels: + if label in self.labels: + idx = self.labels.index(label) + label_vector[idx] = 1.0 + + return label_vector diff --git a/tests/core/test_ecg_classification_task.py b/tests/core/test_ecg_classification_task.py new file mode 100644 index 000000000..639b68f9d --- /dev/null +++ b/tests/core/test_ecg_classification_task.py @@ -0,0 +1,623 @@ +""" +These are Synthetic unit tests for ECGMultiLabelCardiologyTask. +They are implemented for the CS598DLH SP26 Final Project + +Authored by Jonathan Gong, Misael Lazaro, and Sydney Robeson +NetIDs: jgong11, misaell2, sel9 + +This task is inspired by Nonaka & Seita (2021) +"In-depth Benchmarking of Deep Neural Network Architectures for ECG Diagnosis" +Paper link: https://proceedings.mlr.press/v149/nonaka21a.html + +These tests use only synthetic pseudo-data and remain lightweight enough for +quick execution. They also include a small CNN smoke test showing that +task-processed samples can be adapted into a PyHealth sample dataset. +""" + +from __future__ import annotations + +from copy import deepcopy +from pathlib import Path +from typing import Any, Dict, List + +import numpy as np +import pytest +from scipy.io import savemat + +from pyhealth.datasets import create_sample_dataset, get_dataloader +from pyhealth.models import CNN +from pyhealth.tasks.ecg_classification import ECGMultiLabelCardiologyTask + + +def _write_record( + root: Path, + stem: str, + signal: np.ndarray, + dx_codes: List[str], + sex: str = "Female", + age: str = "45", +) -> Dict[str, Any]: + """Writes one synthetic ECG record pair (.mat + .hea). + + Args: + root: Directory to write files into. + stem: Record basename without extension. + signal: ECG array shaped (leads, timesteps). + dx_codes: Diagnostic label codes to write in the header. + sex: Header sex metadata. 
+ age: Header age metadata. + + Returns: + A visit dictionary matching the task input contract. + """ + mat_path = root / f"{stem}.mat" + hea_path = root / f"{stem}.hea" + + savemat(mat_path, {"val": signal.astype(np.float32)}) + + header = "\n".join( + [ + f"{stem} 12 {signal.shape[1]} 500", + f"#Dx: {','.join(dx_codes)}", + f"#Sex: {sex}", + f"#Age: {age}", + ] + ) + hea_path.write_text(header, encoding="utf-8") + + return { + "load_from_path": str(root), + "patient_id": f"patient_{stem}", + "signal_file": mat_path.name, + "label_file": hea_path.name, + } + + +def _labels_from_multihot( + sample: Dict[str, Any], + labels: List[str], +) -> Dict[str, Any]: + """Converts a multi-hot label vector into a list of active labels.""" + converted = dict(sample) + label_vector = np.asarray(sample["label"]).astype(np.float32) + + active = [ + labels[idx] + for idx, value in enumerate(label_vector.tolist()) + if float(value) > 0.5 + ] + converted["label"] = active + return converted + + +@pytest.fixture +def labels() -> List[str]: + """Small synthetic SNOMED-style label vocabulary.""" + return ["164889003", "164890007", "426783006"] + + +@pytest.fixture +def task(labels: List[str]) -> ECGMultiLabelCardiologyTask: + """Returns a task with tiny window settings for fast tests. 
+ + window_size = sampling_rate * epoch_sec = 10 * 2 = 20 + step_size = sampling_rate * shift = 10 * 1 = 10 + """ + return ECGMultiLabelCardiologyTask( + labels=labels, + epoch_sec=2, + shift=1, + sampling_rate=10, + ) + + +def test_meta(labels: List[str]) -> None: + """Tests task metadata and initialization.""" + task = ECGMultiLabelCardiologyTask( + labels=labels, + epoch_sec=2, + shift=1, + sampling_rate=10, + ) + + assert task.task_name == "ECGMultiLabelCardiologyTask" + assert task.input_schema == {"signal": "tensor"} + assert task.output_schema == {"label": "multilabel"} + assert task.labels == labels + assert task.label_to_index == { + "164889003": 0, + "164890007": 1, + "426783006": 2, + } + + +def test_single_visit(tmp_path: Path, task: ECGMultiLabelCardiologyTask) -> None: + """Tests windowing, signal slicing, and multi-hot label generation.""" + signal = np.arange(12 * 40, dtype=np.float32).reshape(12, 40) + visit = _write_record( + tmp_path, + stem="rec1", + signal=signal, + dx_codes=["164889003", "426783006"], + sex="Male", + age="60", + ) + + samples = task(visit) + + assert len(samples) == 3 + + first = samples[0] + assert first["patient_id"] == "patient_rec1" + assert first["visit_id"] == "rec1" + assert first["record_id"] == 1 + assert first["signal"].shape == (12, 20) + assert first["signal"].dtype == np.float32 + assert np.array_equal( + first["label"], + np.array([1.0, 0.0, 1.0], dtype=np.float32), + ) + assert first["Sex"] == ["Male"] + assert first["Age"] == ["60"] + + second = samples[1] + assert second["record_id"] == 2 + assert np.array_equal(second["signal"], signal[:, 10:30].astype(np.float32)) + + third = samples[2] + assert third["record_id"] == 3 + assert np.array_equal(third["signal"], signal[:, 20:40].astype(np.float32)) + + +def test_visit_list(tmp_path: Path, task: ECGMultiLabelCardiologyTask) -> None: + """Tests list input normalization across multiple visits.""" + visit_1 = _write_record( + tmp_path, + stem="rec_a", + 
        signal=np.ones((12, 30), dtype=np.float32),
        dx_codes=["164889003"],
    )
    visit_2 = _write_record(
        tmp_path,
        stem="rec_b",
        signal=np.ones((12, 25), dtype=np.float32) * 2,
        dx_codes=["164890007", "999999999"],
    )

    samples = task([visit_1, visit_2])

    # 3 total windows across the two records — consistent with the 20-sample
    # window / 10-sample shift seen in test_exact_window and
    # test_nondiv_windows below (30 samples -> 2 windows, 25 -> 1).
    # NOTE(review): the shared `task` fixture is defined above this view —
    # confirm epoch_sec/shift/sampling_rate there match this arithmetic.
    assert len(samples) == 3
    assert np.array_equal(
        samples[0]["label"],
        np.array([1.0, 0.0, 0.0], dtype=np.float32),
    )
    # "999999999" is not in the configured 3-label vocabulary, so only the
    # "164890007" position is hot for the second record's window.
    assert np.array_equal(
        samples[-1]["label"],
        np.array([0.0, 1.0, 0.0], dtype=np.float32),
    )


def test_missing_keys(task: ECGMultiLabelCardiologyTask) -> None:
    """Tests validation for malformed visit dictionaries."""
    # Deliberately omits "load_from_path" and "label_file"; the task must
    # reject the visit rather than raise.
    bad_visit = {
        "patient_id": "p1",
        "signal_file": "missing_root.mat",
    }

    samples = task(bad_visit)
    assert samples == []


def test_short_signal(tmp_path: Path, task: ECGMultiLabelCardiologyTask) -> None:
    """Tests edge case where the signal is too short for one window."""
    # 19 samples: presumably one short of the fixture's 20-sample window
    # (see test_exact_window) — no windows can be emitted.
    short_signal = np.zeros((12, 19), dtype=np.float32)
    visit = _write_record(
        tmp_path,
        stem="short_rec",
        signal=short_signal,
        dx_codes=["164889003"],
    )

    samples = task(visit)
    assert samples == []


def test_exact_window(tmp_path: Path, task: ECGMultiLabelCardiologyTask) -> None:
    """Tests exact-boundary windowing behavior."""
    # Exactly one window's worth of samples -> exactly one emitted sample.
    exact_signal = np.arange(12 * 20, dtype=np.float32).reshape(12, 20)
    visit = _write_record(
        tmp_path,
        stem="exact_rec",
        signal=exact_signal,
        dx_codes=["164890007"],
    )

    samples = task(visit)

    assert len(samples) == 1
    # NOTE(review): record_id is asserted as the int 1 here, but the module
    # docs added in this change describe record_id as str — confirm which
    # type the task is meant to emit.
    assert samples[0]["record_id"] == 1
    assert samples[0]["signal"].shape == (12, 20)
    assert np.array_equal(
        samples[0]["label"],
        np.array([0.0, 1.0, 0.0], dtype=np.float32),
    )


def test_nondiv_windows(
    tmp_path: Path,
    task: ECGMultiLabelCardiologyTask,
) -> None:
    """Tests sliding-window count when length is not evenly divisible."""
    # 35 samples with a 20-sample window and 10-sample shift: windows start
    # at 0 and 10; the trailing 5 samples are dropped.
    signal = np.arange(12 * 35, dtype=np.float32).reshape(12, 35)
    visit = _write_record(
        tmp_path,
        stem="nondiv",
        signal=signal,
        dx_codes=["426783006"],
    )

    samples = task(visit)

    assert len(samples) == 2
    assert np.array_equal(samples[0]["signal"], signal[:, 0:20].astype(np.float32))
    assert np.array_equal(samples[1]["signal"], signal[:, 10:30].astype(np.float32))


def test_bad_mat(tmp_path: Path, task: ECGMultiLabelCardiologyTask) -> None:
    """Tests graceful handling of unreadable .mat files."""
    mat_path = tmp_path / "broken.mat"
    hea_path = tmp_path / "broken.hea"

    # Valid-looking header next to a .mat that is not MATLAB data: the
    # unreadable signal must yield no samples, not an exception.
    mat_path.write_text("not a valid matlab file", encoding="utf-8")
    hea_path.write_text(
        "\n".join(
            [
                "broken 12 100 500",
                "#Dx: 164889003",
                "#Sex: Female",
                "#Age: 55",
            ]
        ),
        encoding="utf-8",
    )

    visit = {
        "load_from_path": str(tmp_path),
        "patient_id": "patient_broken",
        "signal_file": mat_path.name,
        "label_file": hea_path.name,
    }

    samples = task(visit)
    assert samples == []


def test_missing_header(tmp_path: Path, task: ECGMultiLabelCardiologyTask) -> None:
    """Tests permissive handling of missing header files."""
    mat_path = tmp_path / "missing_header.mat"
    signal = np.ones((12, 40), dtype=np.float32)
    savemat(mat_path, {"val": signal})

    # The referenced .hea file is never written to disk.
    visit = {
        "load_from_path": str(tmp_path),
        "patient_id": "patient_missing_header",
        "signal_file": mat_path.name,
        "label_file": "missing_header.hea",
    }

    samples = task(visit)

    # Signal is still windowed (40 samples -> 3 windows); metadata falls back
    # to empty lists and the label vector to all zeros.
    assert len(samples) == 3
    for sample in samples:
        assert sample["signal"].shape == (12, 20)
        assert sample["Sex"] == []
        assert sample["Age"] == []
        assert np.array_equal(
            sample["label"],
            np.array([0.0, 0.0, 0.0], dtype=np.float32),
        )


def test_missing_signal(tmp_path: Path, task: ECGMultiLabelCardiologyTask) -> None:
    """Tests graceful handling of missing signal files."""
    # Header exists but the .mat it describes does not.
    hea_path = tmp_path / "missing_signal.hea"
    hea_path.write_text(
        "\n".join(
            [
                "missing_signal 12 100 500",
                "#Dx: 164889003",
                "#Sex: Female",
                "#Age: 50",
            ]
        ),
        encoding="utf-8",
    )

    visit = {
        "load_from_path": str(tmp_path),
        "patient_id": "patient_missing_signal",
        "signal_file": "missing_signal.mat",
        "label_file": hea_path.name,
    }

    samples = task(visit)
    assert samples == []


def test_encode_unknown(task: ECGMultiLabelCardiologyTask) -> None:
    """Tests multi-hot label generation with unknown labels present."""
    # "not_in_vocab" is silently ignored; known codes map to their positions.
    encoded = task._encode_labels(["164890007", "not_in_vocab", "426783006"])
    expected = np.array([0.0, 1.0, 1.0], dtype=np.float32)

    assert np.array_equal(encoded, expected)


def test_encode_dupes(task: ECGMultiLabelCardiologyTask) -> None:
    """Tests duplicate labels still produce binary multi-hot outputs."""
    # Repeating "164889003" must not produce a count of 2 at its position.
    encoded = task._encode_labels(["164889003", "164889003", "164890007"])
    expected = np.array([1.0, 1.0, 0.0], dtype=np.float32)

    assert np.array_equal(encoded, expected)


def test_header_parse(tmp_path: Path, task: ECGMultiLabelCardiologyTask) -> None:
    """Tests partial-header parsing."""
    # Header carries only the record line and #Dx — no #Sex / #Age.
    header_path = tmp_path / "partial.hea"
    header_path.write_text(
        "\n".join(
            [
                "partial 12 100 500",
                "#Dx: 164889003,426783006",
            ]
        ),
        encoding="utf-8",
    )

    metadata = task._parse_header_metadata(str(header_path))

    # Comma-separated codes are split; absent fields default to [].
    assert metadata["dx_codes"] == ["164889003", "426783006"]
    assert metadata["sex"] == []
    assert metadata["age"] == []


def test_empty_dx(tmp_path: Path, task: ECGMultiLabelCardiologyTask) -> None:
    """Tests empty diagnosis metadata produces all-zero labels."""
    signal = np.ones((12, 40), dtype=np.float32)
    mat_path = tmp_path / "empty_dx.mat"
    hea_path = tmp_path / "empty_dx.hea"

    savemat(mat_path, {"val": signal})
    # "#Dx:" with no codes after the colon.
    hea_path.write_text(
        "\n".join(
            [
                "empty_dx 12 40 500",
                "#Dx:",
                "#Sex: Female",
                "#Age: 33",
            ]
        ),
        encoding="utf-8",
    )

    visit = {
        "load_from_path": str(tmp_path),
        "patient_id": "patient_empty_dx",
        "signal_file": mat_path.name,
        "label_file": hea_path.name,
    }

    samples = task(visit)

    # Windowing proceeds normally; every window gets an all-zero label.
    assert len(samples) == 3
    for sample in samples:
        assert np.array_equal(
            sample["label"],
            np.array([0.0, 0.0, 0.0], dtype=np.float32),
        )


def test_immutable(tmp_path: Path, task: ECGMultiLabelCardiologyTask) -> None:
    """Tests that the task does not mutate the input visit dictionary."""
    visit = _write_record(
        tmp_path,
        stem="immutable",
        signal=np.ones((12, 30), dtype=np.float32),
        dx_codes=["164889003"],
    )
    # Snapshot before the call; compare the dict afterwards.
    original = deepcopy(visit)

    _ = task(visit)

    assert visit == original


def test_sample_shape(tmp_path: Path, task: ECGMultiLabelCardiologyTask) -> None:
    """Tests emitted sample keys, shapes, and dtypes across multiple visits."""
    visit_1 = _write_record(
        tmp_path,
        stem="shape_a",
        signal=np.ones((12, 40), dtype=np.float32),
        dx_codes=["164889003"],
        sex="Male",
        age="40",
    )
    visit_2 = _write_record(
        tmp_path,
        stem="shape_b",
        signal=np.ones((12, 20), dtype=np.float32),
        dx_codes=["164890007", "426783006"],
        sex="Female",
        age="41",
    )

    samples = task([visit_1, visit_2])

    # 40 samples -> 3 windows, 20 samples -> 1 window: 4 samples total.
    assert len(samples) == 4
    for sample in samples:
        # Exact key set: nothing extra, nothing missing.
        assert set(sample.keys()) == {
            "patient_id",
            "visit_id",
            "record_id",
            "signal",
            "label",
            "Sex",
            "Age",
        }
        assert sample["signal"].shape == (12, 20)
        assert sample["signal"].dtype == np.float32
        assert sample["label"].shape == (3,)
        assert sample["label"].dtype == np.float32
        assert isinstance(sample["Sex"], list)
        assert isinstance(sample["Age"], list)


def test_four_label_cfg(tmp_path: Path) -> None:
    """Show AF/I-AVB/LBBB/RBBB task label configuration."""
    # Task built inline (not the shared fixture): 2 s windows at 10 Hz with
    # a 1 s shift -> 20-sample windows, 10-sample step.
    four_labels = ["AF", "I-AVB", "LBBB", "RBBB"]
    task = ECGMultiLabelCardiologyTask(
        labels=four_labels,
        epoch_sec=2,
        shift=1,
        sampling_rate=10,
    )

    visit = _write_record(
        tmp_path,
        stem="four_cfg",
        signal=np.ones((12, 40), dtype=np.float32),
        dx_codes=["AF", "RBBB"],
    )

    samples = task(visit)

    # (40 - 20) / 10 + 1 = 3 windows, each labeled AF + RBBB.
    assert len(samples) == 3
    for sample in samples:
        assert sample["label"].shape == (4,)
        assert np.array_equal(
            sample["label"],
            np.array([1.0, 0.0, 0.0, 1.0], dtype=np.float32),
        )


def test_four_label_all_on(tmp_path: Path) -> None:
    """Shows that all four configured labels can be active together."""
    four_labels = ["AF", "I-AVB", "LBBB", "RBBB"]
    task = ECGMultiLabelCardiologyTask(
        labels=four_labels,
        epoch_sec=2,
        shift=1,
        sampling_rate=10,
    )

    visit = _write_record(
        tmp_path,
        stem="four_all",
        signal=np.ones((12, 20), dtype=np.float32),
        dx_codes=["AF", "I-AVB", "LBBB", "RBBB"],
    )

    samples = task(visit)

    assert len(samples) == 1
    assert np.array_equal(
        samples[0]["label"],
        np.array([1.0, 1.0, 1.0, 1.0], dtype=np.float32),
    )


def test_four_label_order(tmp_path: Path) -> None:
    """Shows that output positions follow the configured four-label order."""
    four_labels = ["AF", "I-AVB", "LBBB", "RBBB"]
    task = ECGMultiLabelCardiologyTask(
        labels=four_labels,
        epoch_sec=2,
        shift=1,
        sampling_rate=10,
    )

    visit = _write_record(
        tmp_path,
        stem="four_order",
        signal=np.ones((12, 20), dtype=np.float32),
        dx_codes=["I-AVB", "LBBB"],
    )

    samples = task(visit)

    # Hot positions are 1 and 2: output follows the configured list order,
    # not the order codes appear in the header.
    assert len(samples) == 1
    assert np.array_equal(
        samples[0]["label"],
        np.array([0.0, 1.0, 1.0, 0.0], dtype=np.float32),
    )


def test_cnn_smoke(
    tmp_path: Path,
    task: ECGMultiLabelCardiologyTask,
    labels: List[str],
) -> None:
    """Smoke-tests CNN on task-processed synthetic ECG samples."""
    # NOTE(review): np.random.randn is unseeded here; acceptable because only
    # shapes and finiteness are asserted, not values.
    visit_1 = _write_record(
        tmp_path,
        stem="cnn_a",
        signal=np.random.randn(12, 40).astype(np.float32),
        dx_codes=["164889003"],
    )
    visit_2 = _write_record(
        tmp_path,
        stem="cnn_b",
        signal=np.random.randn(12, 30).astype(np.float32),
        dx_codes=["164890007", "426783006"],
    )

    raw_samples = task([visit_1, visit_2])
    assert len(raw_samples) >= 3

    # Convert each multi-hot vector into the label format expected by
    # create_sample_dataset (helper defined above this view).
    adapted_samples = [
        _labels_from_multihot(sample, labels) for sample in raw_samples
    ]

    dataset = create_sample_dataset(
        samples=adapted_samples,
        input_schema={"signal": "tensor"},
        output_schema={"label": "multilabel"},
        dataset_name="synthetic_ecg",
        task_name="ecg_multilabel",
        in_memory=True,
    )

    loader = get_dataloader(dataset, batch_size=2, shuffle=False)
    batch = next(iter(loader))

    # Deliberately tiny model: this is a wiring smoke test, not a fit test.
    model = CNN(
        dataset=dataset,
        embedding_dim=16,
        hidden_dim=8,
        num_layers=1,
    )
    model.train()

    results = model(**batch)

    assert "loss" in results
    assert "y_prob" in results
    assert "y_true" in results
    assert "logit" in results

    # Batch dimension matches the loader batch; logit width matches the
    # configured label vocabulary.
    assert results["y_prob"].shape[0] == 2
    assert results["y_true"].shape[0] == 2
    assert results["logit"].shape[0] == 2
    assert results["logit"].shape[1] == len(labels)

    loss = results["loss"]
    assert loss.ndim == 0
    assert np.isfinite(loss.detach().cpu().item())

    # Backprop must run end-to-end and populate at least one gradient.
    loss.backward()

    has_grad = any(
        param.grad is not None
        for param in model.parameters()
        if param.requires_grad
    )
    assert has_grad