diff --git a/.gitignore b/.gitignore index 9993737db..a73d698a3 100644 --- a/.gitignore +++ b/.gitignore @@ -139,4 +139,13 @@ data/physionet.org/ .vscode/ # Model weight files (large binaries, distributed separately) -weightfiles/ \ No newline at end of file +weightfiles/ + +# Auto-generated dataset metadata CSVs (contain absolute paths from the +# machine that ran prepare_metadata). These are rebuilt on first use by +# the dataset loaders, so checking them in leaks user-specific paths. +test-resources/**/*-pyhealth.csv + +# Per-record preprocessing caches built by dataset loaders with +# preprocess=True (e.g. LUDB/MIT-BIH/BIDMC .npz files). +test-resources/**/processed/ \ No newline at end of file diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 8d9a59d21..a3efd3bf9 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -227,7 +227,10 @@ Available Datasets datasets/pyhealth.datasets.MedicalTranscriptionsDataset datasets/pyhealth.datasets.CardiologyDataset datasets/pyhealth.datasets.eICUDataset + datasets/pyhealth.datasets.BIDMCDataset datasets/pyhealth.datasets.ISRUCDataset + datasets/pyhealth.datasets.LUDBDataset + datasets/pyhealth.datasets.MITBIHDataset datasets/pyhealth.datasets.MIMICExtractDataset datasets/pyhealth.datasets.OMOPDataset datasets/pyhealth.datasets.DREAMTDataset diff --git a/docs/api/datasets/pyhealth.datasets.BIDMCDataset.rst b/docs/api/datasets/pyhealth.datasets.BIDMCDataset.rst new file mode 100644 index 000000000..7b97dc424 --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.BIDMCDataset.rst @@ -0,0 +1,11 @@ +pyhealth.datasets.BIDMCDataset +=============================== + +BIDMC respiratory signal dataset — 53 ICU patients with 8-minute recordings +of ECG, PPG, and respiratory signals at 125 Hz. Refer to +`PhysioNet `_ for more information. + +.. 
autoclass:: pyhealth.datasets.BIDMCDataset + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/datasets/pyhealth.datasets.LUDBDataset.rst b/docs/api/datasets/pyhealth.datasets.LUDBDataset.rst new file mode 100644 index 000000000..afb9f4552 --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.LUDBDataset.rst @@ -0,0 +1,11 @@ +pyhealth.datasets.LUDBDataset +============================== + +Lobachevsky University Database (LUDB) — 200 subjects with 12-lead ECG at 500 Hz, +manually annotated with P wave, QRS complex, and T wave boundaries. Refer to +`PhysioNet `_ for more information. + +.. autoclass:: pyhealth.datasets.LUDBDataset + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/datasets/pyhealth.datasets.MITBIHDataset.rst b/docs/api/datasets/pyhealth.datasets.MITBIHDataset.rst new file mode 100644 index 000000000..e21ef039b --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.MITBIHDataset.rst @@ -0,0 +1,11 @@ +pyhealth.datasets.MITBIHDataset +================================ + +MIT-BIH Arrhythmia Database — 48 half-hour excerpts of two-channel ambulatory +ECG at 360 Hz. Refer to +`PhysioNet `_ for more information. + +.. 
autoclass:: pyhealth.datasets.MITBIHDataset + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/models.rst b/docs/api/models.rst index 7c3ac7c4b..b24e6f45f 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -196,6 +196,7 @@ API Reference models/pyhealth.models.Agent models/pyhealth.models.GRASP models/pyhealth.models.MedLink + models/pyhealth.models.MedTsLLM models/pyhealth.models.TCN models/pyhealth.models.TFMTokenizer models/pyhealth.models.GAN diff --git a/docs/api/models/pyhealth.models.MedTsLLM.rst b/docs/api/models/pyhealth.models.MedTsLLM.rst new file mode 100644 index 000000000..ea6a1eea3 --- /dev/null +++ b/docs/api/models/pyhealth.models.MedTsLLM.rst @@ -0,0 +1,11 @@ +pyhealth.models.MedTsLLM +========================= + +MedTsLLM: Leveraging LLMs for Multimodal Medical Time Series Analysis +(Chan et al., MLHC 2024). Refer to the +`paper `_ for more information. + +.. autoclass:: pyhealth.models.MedTsLLM + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 23a4e06e5..d20cefbbf 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -1,5 +1,4 @@ Tasks -=============== We support various real-world healthcare predictive tasks defined by **function calls**. 
The following example tasks are collected from top AI/Medical venues, such as: @@ -212,6 +211,10 @@ Available Tasks COVID-19 CXR Classification DKA Prediction (MIMIC-IV) Drug Recommendation + ECG Anomaly Detection (MIT-BIH) + ECG Boundary Detection (MIT-BIH) + ECG Wave Segmentation (LUDB) + Respiratory Boundary Detection (BIDMC) Length of Stay Prediction Medical Transcriptions Classification Mortality Prediction (Next Visit) diff --git a/docs/api/tasks/pyhealth.tasks.ECGAnomalyDetection.rst b/docs/api/tasks/pyhealth.tasks.ECGAnomalyDetection.rst new file mode 100644 index 000000000..d4ee1d97e --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.ECGAnomalyDetection.rst @@ -0,0 +1,11 @@ +pyhealth.tasks.ECGAnomalyDetection +==================================== + +Reconstruction-based anomaly detection task for the MIT-BIH dataset. +Trains the model to reconstruct normal 2-channel ECG; at eval time, +elevated reconstruction error flags abnormal-rhythm beats. + +.. autoclass:: pyhealth.tasks.ECGAnomalyDetection + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/tasks/pyhealth.tasks.ECGBoundaryDetection.rst b/docs/api/tasks/pyhealth.tasks.ECGBoundaryDetection.rst new file mode 100644 index 000000000..67bafcc8e --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.ECGBoundaryDetection.rst @@ -0,0 +1,10 @@ +pyhealth.tasks.ECGBoundaryDetection +===================================== + +R-peak boundary detection task for the MIT-BIH dataset. Detects beat +boundaries in ECG signals. + +.. 
autoclass:: pyhealth.tasks.ECGBoundaryDetection + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/tasks/pyhealth.tasks.ECGWaveSegmentation.rst b/docs/api/tasks/pyhealth.tasks.ECGWaveSegmentation.rst new file mode 100644 index 000000000..5b58a7e10 --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.ECGWaveSegmentation.rst @@ -0,0 +1,10 @@ +pyhealth.tasks.ECGWaveSegmentation +==================================== + +Per-timestep ECG wave segmentation task for the LUDB dataset. Classifies +each time point as background, P wave, QRS complex, or T wave. + +.. autoclass:: pyhealth.tasks.ECGWaveSegmentation + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/tasks/pyhealth.tasks.RespiratoryBoundaryDetection.rst b/docs/api/tasks/pyhealth.tasks.RespiratoryBoundaryDetection.rst new file mode 100644 index 000000000..49366d54d --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.RespiratoryBoundaryDetection.rst @@ -0,0 +1,10 @@ +pyhealth.tasks.RespiratoryBoundaryDetection +============================================= + +Breath boundary detection task for the BIDMC dataset. Detects breath +boundaries in respiratory impedance signals. + +.. autoclass:: pyhealth.tasks.RespiratoryBoundaryDetection + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/bidmc_respiratory_boundary_medtsllm.py b/examples/bidmc_respiratory_boundary_medtsllm.py new file mode 100644 index 000000000..330c1668d --- /dev/null +++ b/examples/bidmc_respiratory_boundary_medtsllm.py @@ -0,0 +1,177 @@ +"""BIDMC Respiratory Boundary Detection with MedTsLLM. + +Demonstrates the MedTsLLM model (Chan et al., MLHC 2024) on the BIDMC +dataset for per-timestep breath boundary detection using 3-channel +respiratory signals (RESP, PLETH, II). Binary segmentation task +trained with BCE-with-logits loss. 
+ +Paper: https://arxiv.org/abs/2408.07773 +Dataset: https://physionet.org/content/bidmc/1.0.0/ + +Usage: + # Synthetic data, no downloads: + python examples/bidmc_respiratory_boundary_medtsllm.py --synthetic + + # Real BIDMC data with GPT-2: + python examples/bidmc_respiratory_boundary_medtsllm.py \\ + --root /path/to/bidmc --backbone openai-community/gpt2 + +Ablation Study: + The script exposes the paper's two main ablation axes as CLI + flags so each run is a single ablation cell. + + 1. LLM backbone swap -- ``--backbone ``: + compare GPT-2 vs. GPT-2-medium vs. DistilGPT-2 etc. + + 2. Prompt components -- each piece of the text prompt can be + disabled independently: + --no-prompt-dataset drops the dataset description + --no-prompt-task drops the task description + --no-prompt-patient drops the per-patient description + --prompt-stats adds the rolling signal stats + + python examples/bidmc_respiratory_boundary_medtsllm.py \\ + --root /path/to/bidmc --no-prompt-patient + python examples/bidmc_respiratory_boundary_medtsllm.py \\ + --root /path/to/bidmc --no-prompt-dataset --prompt-stats +""" + +import argparse + +import numpy as np +import torch + +from pyhealth.datasets import ( + create_sample_dataset, + get_dataloader, + split_by_patient, +) +from pyhealth.models import MedTsLLM +from pyhealth.trainer import Trainer + + +def make_synthetic_dataset( + n_patients: int = 10, seq_len: int = 256, n_channels: int = 3 +): + """Synthetic 3-channel respiratory data for smoke-testing.""" + samples = [] + for i in range(n_patients): + for w in range(5): + signal = np.random.randn(seq_len, n_channels).astype(np.float32) + label = np.random.randint(0, 2, size=seq_len).astype(np.float32) + samples.append({ + "patient_id": f"p{i}", + "visit_id": f"v{w}", + "signal": signal, + "label": label, + "description": "", + }) + return create_sample_dataset( + samples=samples, + input_schema={"signal": "tensor"}, + output_schema={"label": "tensor"}, + dataset_name="synthetic_bidmc", + 
) + + +def make_real_dataset(root: str, seq_len: int = 256, step_size: int = 128): + """Load BIDMC and apply RespiratoryBoundaryDetection.""" + from pyhealth.datasets import BIDMCDataset + from pyhealth.tasks import RespiratoryBoundaryDetection + + dataset = BIDMCDataset(root=root) + task = RespiratoryBoundaryDetection( + window_size=seq_len, step_size=step_size + ) + return dataset.set_task(task) + + +def main(): + parser = argparse.ArgumentParser( + description="BIDMC respiratory boundary detection with MedTsLLM" + ) + parser.add_argument("--root", type=str, default=None) + parser.add_argument("--synthetic", action="store_true") + parser.add_argument("--backbone", type=str, default=None) + parser.add_argument("--epochs", type=int, default=10) + parser.add_argument("--batch-size", type=int, default=16) + parser.add_argument("--lr", type=float, default=1e-4) + parser.add_argument("--device", type=str, default="auto") + # Ablation knobs — each flag disables one prompt component + parser.add_argument("--no-prompt-dataset", action="store_true") + parser.add_argument("--no-prompt-task", action="store_true") + parser.add_argument("--no-prompt-patient", action="store_true") + parser.add_argument("--prompt-stats", action="store_true") + args = parser.parse_args() + + device = ( + ("cuda" if torch.cuda.is_available() else "cpu") + if args.device == "auto" + else args.device + ) + + seq_len = 256 + n_features = 3 + + if args.synthetic or args.root is None: + print("Using synthetic data") + sample_dataset = make_synthetic_dataset( + n_patients=10, seq_len=seq_len, n_channels=n_features + ) + word_embeddings = torch.randn(100, 64) + backbone = None + epochs = 3 + else: + print(f"Loading BIDMC from {args.root}") + sample_dataset = make_real_dataset(args.root, seq_len=seq_len) + word_embeddings = None + backbone = args.backbone or "openai-community/gpt2" + epochs = args.epochs + + train_ds, _, test_ds = split_by_patient( + sample_dataset, ratios=[0.8, 0.0, 0.2] + ) + 
train_loader = get_dataloader( + train_ds, batch_size=args.batch_size, shuffle=True + ) + test_loader = get_dataloader( + test_ds, batch_size=args.batch_size, shuffle=False + ) + + model = MedTsLLM( + dataset=sample_dataset, + task="segmentation", + seq_len=seq_len, + n_features=n_features, + covariate_mode="concat", + d_model=32, + d_ff=64, + n_heads=8, + num_tokens=1024, + patch_len=16, + stride=8, + dataset_description=( + "The BIDMC dataset contains electrocardiogram (ECG), " + "pulse oximetry (PPG), and impedance pneumography " + "respiratory signals from intensive care patients." + ), + backbone=backbone, + word_embeddings=word_embeddings, + prompt_dataset=not args.no_prompt_dataset, + prompt_task=not args.no_prompt_task, + prompt_patient=not args.no_prompt_patient, + prompt_stats=args.prompt_stats, + ) + + trainer = Trainer(model=model, device=device, enable_logging=False) + trainer.train( + train_dataloader=train_loader, + test_dataloader=test_loader, + epochs=epochs, + optimizer_class=torch.optim.Adam, + optimizer_params={"lr": args.lr}, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/ludb_ecg_segmentation_medtsllm.py b/examples/ludb_ecg_segmentation_medtsllm.py new file mode 100644 index 000000000..efae066ce --- /dev/null +++ b/examples/ludb_ecg_segmentation_medtsllm.py @@ -0,0 +1,202 @@ +"""LUDB ECG Wave Segmentation with MedTsLLM. + +Demonstrates the MedTsLLM model (Chan et al., MLHC 2024) on the LUDB +dataset for per-timestep ECG wave segmentation (P wave, QRS complex, +T wave, background). 
+ +Paper: https://arxiv.org/abs/2408.07773 +Dataset: https://physionet.org/content/ludb/1.0.1/ + +Usage: + # With synthetic data (no downloads, fast): + python examples/ludb_ecg_segmentation_medtsllm.py --synthetic + + # With real data + GPT-2: + python examples/ludb_ecg_segmentation_medtsllm.py \\ + --root /path/to/ludb --backbone openai-community/gpt2 + +Ablation Study: + The script exposes the paper's two main ablation axes as CLI + flags so each run is a single ablation cell. + + 1. LLM backbone swap -- ``--backbone ``: + compare GPT-2 vs. GPT-2-medium vs. DistilGPT-2 etc. + + python examples/ludb_ecg_segmentation_medtsllm.py \\ + --root /path/to/ludb --backbone openai-community/gpt2 + python examples/ludb_ecg_segmentation_medtsllm.py \\ + --root /path/to/ludb --backbone distilbert/distilgpt2 + + 2. Prompt components -- each piece of the text prompt can be + disabled independently: + --no-prompt-dataset drops the dataset description + --no-prompt-task drops the task description + --no-prompt-patient drops the per-patient description + --prompt-stats adds the rolling signal stats + + python examples/ludb_ecg_segmentation_medtsllm.py \\ + --root /path/to/ludb --no-prompt-patient + python examples/ludb_ecg_segmentation_medtsllm.py \\ + --root /path/to/ludb --no-prompt-dataset --prompt-stats +""" + +import argparse + +import numpy as np +import torch + +from pyhealth.datasets import ( + create_sample_dataset, + get_dataloader, + split_by_patient, +) +from pyhealth.models import MedTsLLM +from pyhealth.trainer import Trainer + + +def make_synthetic_dataset(n_patients=10, seq_len=512, n_classes=4): + """Create a synthetic dataset for testing without data downloads.""" + samples = [] + for i in range(n_patients): + for w in range(5): + signal = np.random.randn(seq_len).astype(np.float32) + label = np.random.randint(0, n_classes, size=seq_len).astype( + np.int64 + ) + samples.append({ + "patient_id": f"p{i}", + "visit_id": f"v{w}", + "signal": signal, + "label": label, 
+ }) + + return create_sample_dataset( + samples=samples, + input_schema={"signal": "tensor"}, + output_schema={"label": "tensor"}, + dataset_name="synthetic_ludb", + ) + + +def make_real_dataset(root, seq_len=512, step_size=256): + """Load LUDB dataset and apply ECG segmentation task. + + ``preprocess=True`` caches decoded signals to ``{root}/processed/`` + so wfdb decoding runs exactly once; ``trim=True`` drops the head + and tail of each record to match the paper's preprocessing recipe. + """ + from pyhealth.datasets import LUDBDataset + from pyhealth.tasks import ECGWaveSegmentation + + dataset = LUDBDataset(root=root, preprocess=True, trim=True) + task = ECGWaveSegmentation(window_size=seq_len, step_size=step_size) + return dataset.set_task(task) + + +def main(): + parser = argparse.ArgumentParser( + description="LUDB ECG segmentation with MedTsLLM" + ) + parser.add_argument( + "--root", type=str, default=None, help="Path to LUDB data" + ) + parser.add_argument( + "--synthetic", + action="store_true", + help="Use synthetic data (no downloads)", + ) + parser.add_argument( + "--backbone", + type=str, + default=None, + help="HuggingFace model ID (default: openai-community/gpt2)", + ) + parser.add_argument("--epochs", type=int, default=10) + parser.add_argument("--batch-size", type=int, default=16) + parser.add_argument("--lr", type=float, default=1e-4) + parser.add_argument("--device", type=str, default="auto") + # Ablation knobs — each flag disables one prompt component + parser.add_argument("--no-prompt-dataset", action="store_true") + parser.add_argument("--no-prompt-task", action="store_true") + parser.add_argument("--no-prompt-patient", action="store_true") + parser.add_argument("--prompt-stats", action="store_true") + args = parser.parse_args() + + if args.device == "auto": + device = "cuda" if torch.cuda.is_available() else "cpu" + else: + device = args.device + + seq_len = 512 + n_classes = 4 + + # Load data + if args.synthetic or args.root is None: + 
print("Using synthetic data (no downloads)") + sample_dataset = make_synthetic_dataset( + n_patients=10, seq_len=seq_len, n_classes=n_classes + ) + word_embeddings = torch.randn(100, 64) + backbone = None + epochs = 3 + else: + print(f"Loading LUDB from {args.root}") + sample_dataset = make_real_dataset( + args.root, seq_len=seq_len, step_size=256 + ) + word_embeddings = None + backbone = args.backbone or "openai-community/gpt2" + epochs = args.epochs + + # Split by patient + train_ds, _, test_ds = split_by_patient( + sample_dataset, ratios=[0.8, 0.0, 0.2] + ) + train_loader = get_dataloader( + train_ds, batch_size=args.batch_size, shuffle=True + ) + test_loader = get_dataloader( + test_ds, batch_size=args.batch_size, shuffle=False + ) + + model = MedTsLLM( + dataset=sample_dataset, + seq_len=seq_len, + n_features=1, + n_classes=n_classes, + d_model=32, + d_ff=128, + n_heads=8, + num_tokens=1024, + patch_len=16, + stride=8, + dataset_description=( + "LUDB is an ECG signal database collected from subjects " + "with various cardiovascular diseases used for ECG " + "delineation." + ), + backbone=backbone, + word_embeddings=word_embeddings, + prompt_dataset=not args.no_prompt_dataset, + prompt_task=not args.no_prompt_task, + prompt_patient=not args.no_prompt_patient, + prompt_stats=args.prompt_stats, + ) + + # Train + trainer = Trainer( + model=model, + device=device, + enable_logging=False, + ) + trainer.train( + train_dataloader=train_loader, + test_dataloader=test_loader, + epochs=epochs, + optimizer_class=torch.optim.Adam, + optimizer_params={"lr": args.lr}, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/mitbih_ecg_anomaly_medtsllm.py b/examples/mitbih_ecg_anomaly_medtsllm.py new file mode 100644 index 000000000..2704ef27d --- /dev/null +++ b/examples/mitbih_ecg_anomaly_medtsllm.py @@ -0,0 +1,205 @@ +"""MIT-BIH ECG Anomaly Detection with MedTsLLM. 
+ +Demonstrates reconstruction-based anomaly detection (Chan et al., +MLHC 2024) on the MIT-BIH Arrhythmia Database. MedTsLLM is trained +with MSE loss to reconstruct 2-channel ECG; at eval time, elevated +reconstruction error flags abnormal-rhythm beats. + +Paper: https://arxiv.org/abs/2408.07773 +Dataset: https://physionet.org/content/mitdb/1.0.0/ + +Usage: + python examples/mitbih_ecg_anomaly_medtsllm.py --synthetic + python examples/mitbih_ecg_anomaly_medtsllm.py \\ + --root /path/to/mitdb --backbone openai-community/gpt2 + +Ablation Study: + The script exposes the paper's two main ablation axes as CLI + flags so each run is a single ablation cell. + + 1. LLM backbone swap -- ``--backbone ``: + compare GPT-2 vs. GPT-2-medium vs. DistilGPT-2 etc. + + 2. Prompt components -- each piece of the text prompt can be + disabled independently: + --no-prompt-dataset drops the dataset description + --no-prompt-task drops the task description + --no-prompt-patient drops the per-patient description + --prompt-stats adds the rolling signal stats + + python examples/mitbih_ecg_anomaly_medtsllm.py \\ + --root /path/to/mitdb --no-prompt-patient + python examples/mitbih_ecg_anomaly_medtsllm.py \\ + --root /path/to/mitdb --no-prompt-dataset --prompt-stats +""" + +import argparse + +import numpy as np +import torch + +from pyhealth.datasets import ( + create_sample_dataset, + get_dataloader, + split_by_patient, +) +from pyhealth.models import MedTsLLM +from pyhealth.trainer import Trainer + + +def make_synthetic_dataset( + n_patients: int = 10, seq_len: int = 128, n_channels: int = 2 +): + """Synthetic 2-channel ECG with sparse anomaly intervals.""" + samples = [] + for i in range(n_patients): + for w in range(5): + signal = np.random.randn(seq_len, n_channels).astype(np.float32) + label = np.zeros(seq_len, dtype=np.float32) + # 1-2 anomaly intervals per window + for _ in range(np.random.randint(0, 3)): + s = np.random.randint(0, seq_len - 10) + label[s : s + np.random.randint(3, 
10)] = 1 + samples.append({ + "patient_id": f"p{i}", + "visit_id": f"v{w}", + "signal": signal, + "label": label, + "description": "", + }) + return create_sample_dataset( + samples=samples, + input_schema={"signal": "tensor"}, + output_schema={"label": "tensor"}, + dataset_name="synthetic_mitbih_anomaly", + ) + + +def make_real_dataset(root: str, seq_len: int = 128, step_size: int = 128): + """Load MIT-BIH and apply ECGAnomalyDetection. + + Uses ``preprocess=True`` so the 30-minute wfdb records are decoded + and downsampled once into ``{root}/processed/*.npz``. Subsequent + runs skip wfdb entirely. ``paper_split="abnormal_sorted"`` matches + the paper's least-/most-abnormal 80/20 train/test split. + """ + from pyhealth.datasets import MITBIHDataset + from pyhealth.tasks import ECGAnomalyDetection + + dataset = MITBIHDataset( + root=root, + preprocess=True, + downsample_factor=3, + trim=True, + paper_split="abnormal_sorted", + ) + task = ECGAnomalyDetection(window_size=seq_len, step_size=step_size) + return dataset.set_task(task) + + +def main(): + parser = argparse.ArgumentParser( + description="MIT-BIH arrhythmia anomaly detection with MedTsLLM" + ) + parser.add_argument("--root", type=str, default=None) + parser.add_argument("--synthetic", action="store_true") + parser.add_argument("--backbone", type=str, default=None) + parser.add_argument("--epochs", type=int, default=10) + parser.add_argument("--batch-size", type=int, default=16) + parser.add_argument("--lr", type=float, default=1e-4) + parser.add_argument("--device", type=str, default="auto") + # Ablation knobs — each flag disables one prompt component + parser.add_argument("--no-prompt-dataset", action="store_true") + parser.add_argument("--no-prompt-task", action="store_true") + parser.add_argument("--no-prompt-patient", action="store_true") + parser.add_argument("--prompt-stats", action="store_true") + args = parser.parse_args() + + device = ( + ("cuda" if torch.cuda.is_available() else "cpu") + if 
args.device == "auto" + else args.device + ) + + seq_len = 128 + n_features = 2 + + if args.synthetic or args.root is None: + print("Using synthetic data") + sample_dataset = make_synthetic_dataset( + n_patients=10, seq_len=seq_len, n_channels=n_features + ) + word_embeddings = torch.randn(100, 64) + backbone = None + epochs = 3 + else: + print(f"Loading MIT-BIH from {args.root}") + sample_dataset = make_real_dataset(args.root, seq_len=seq_len) + word_embeddings = None + backbone = args.backbone or "openai-community/gpt2" + epochs = args.epochs + + # Synthetic path has no paper_split, so fall back to a random + # patient split. On real data, ``make_real_dataset`` sets + # ``paper_split="abnormal_sorted"`` which tags each sample with a + # ``"split"`` field; partition on that below to consume the + # baked-in train/test assignment without reshuffling. + if args.synthetic or args.root is None: + train_ds, _, test_ds = split_by_patient( + sample_dataset, ratios=[0.8, 0.0, 0.2] + ) + else: + train_idx, test_idx = [], [] + for i in range(len(sample_dataset)): + bucket = sample_dataset[i]["split"] + if bucket == "train": + train_idx.append(i) + elif bucket == "test": + test_idx.append(i) + train_ds = sample_dataset.subset(np.array(train_idx)) + test_ds = sample_dataset.subset(np.array(test_idx)) + train_loader = get_dataloader( + train_ds, batch_size=args.batch_size, shuffle=True + ) + test_loader = get_dataloader( + test_ds, batch_size=args.batch_size, shuffle=False + ) + + model = MedTsLLM( + dataset=sample_dataset, + task="anomaly_detection", + seq_len=seq_len, + n_features=n_features, + covariate_mode="concat", + d_model=32, + d_ff=64, + n_heads=8, + num_tokens=1024, + patch_len=16, + stride=8, + dataset_description=( + "The MIT-BIH Arrhythmia Database contains excerpts of " + "two-channel ambulatory ECG from a mixed population of " + "inpatients and outpatients, digitized at 360 samples " + "per second per channel." 
+ ), + backbone=backbone, + word_embeddings=word_embeddings, + prompt_dataset=not args.no_prompt_dataset, + prompt_task=not args.no_prompt_task, + prompt_patient=not args.no_prompt_patient, + prompt_stats=args.prompt_stats, + ) + + trainer = Trainer(model=model, device=device, enable_logging=False) + trainer.train( + train_dataloader=train_loader, + test_dataloader=test_loader, + epochs=epochs, + optimizer_class=torch.optim.Adam, + optimizer_params={"lr": args.lr}, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/mitbih_ecg_boundary_medtsllm.py b/examples/mitbih_ecg_boundary_medtsllm.py new file mode 100644 index 000000000..0d02aee79 --- /dev/null +++ b/examples/mitbih_ecg_boundary_medtsllm.py @@ -0,0 +1,186 @@ +"""MIT-BIH ECG Boundary Detection with MedTsLLM. + +Demonstrates the MedTsLLM model (Chan et al., MLHC 2024) on the +MIT-BIH Arrhythmia Database for R-peak boundary detection using +2-channel ECG downsampled from 360 Hz to 120 Hz. Binary +segmentation task trained with BCE-with-logits loss. + +Paper: https://arxiv.org/abs/2408.07773 +Dataset: https://physionet.org/content/mitdb/1.0.0/ + +Usage: + python examples/mitbih_ecg_boundary_medtsllm.py --synthetic + python examples/mitbih_ecg_boundary_medtsllm.py \\ + --root /path/to/mitdb --backbone openai-community/gpt2 + +Ablation Study: + The script exposes the paper's two main ablation axes as CLI + flags so each run is a single ablation cell. + + 1. LLM backbone swap -- ``--backbone ``: + compare GPT-2 vs. GPT-2-medium vs. DistilGPT-2 etc. + + 2. 
Prompt components -- each piece of the text prompt can be + disabled independently: + --no-prompt-dataset drops the dataset description + --no-prompt-task drops the task description + --no-prompt-patient drops the per-patient description + --prompt-stats adds the rolling signal stats + + python examples/mitbih_ecg_boundary_medtsllm.py \\ + --root /path/to/mitdb --no-prompt-patient + python examples/mitbih_ecg_boundary_medtsllm.py \\ + --root /path/to/mitdb --no-prompt-dataset --prompt-stats +""" + +import argparse + +import numpy as np +import torch + +from pyhealth.datasets import ( + create_sample_dataset, + get_dataloader, + split_by_patient, +) +from pyhealth.models import MedTsLLM +from pyhealth.trainer import Trainer + + +def make_synthetic_dataset( + n_patients: int = 10, seq_len: int = 256, n_channels: int = 2 +): + """Synthetic 2-channel ECG data.""" + samples = [] + for i in range(n_patients): + for w in range(5): + signal = np.random.randn(seq_len, n_channels).astype(np.float32) + label = np.zeros(seq_len, dtype=np.float32) + # Sparse boundary labels (~3% density) to mimic real R-peaks + idx = np.random.choice(seq_len, size=seq_len // 40, replace=False) + label[idx] = 1 + samples.append({ + "patient_id": f"p{i}", + "visit_id": f"v{w}", + "signal": signal, + "label": label, + "description": "", + }) + return create_sample_dataset( + samples=samples, + input_schema={"signal": "tensor"}, + output_schema={"label": "tensor"}, + dataset_name="synthetic_mitbih", + ) + + +def make_real_dataset(root: str, seq_len: int = 256, step_size: int = 256): + """Load MIT-BIH and apply ECGBoundaryDetection. + + Uses ``preprocess=True`` so the 30-minute wfdb records are decoded + and downsampled once into ``{root}/processed/*.npz``. Subsequent + runs skip wfdb entirely. 
+ """ + from pyhealth.datasets import MITBIHDataset + from pyhealth.tasks import ECGBoundaryDetection + + dataset = MITBIHDataset( + root=root, + preprocess=True, + downsample_factor=3, + trim=True, + ) + task = ECGBoundaryDetection(window_size=seq_len, step_size=step_size) + return dataset.set_task(task) + + +def main(): + parser = argparse.ArgumentParser( + description="MIT-BIH R-peak boundary detection with MedTsLLM" + ) + parser.add_argument("--root", type=str, default=None) + parser.add_argument("--synthetic", action="store_true") + parser.add_argument("--backbone", type=str, default=None) + parser.add_argument("--epochs", type=int, default=10) + parser.add_argument("--batch-size", type=int, default=16) + parser.add_argument("--lr", type=float, default=1e-4) + parser.add_argument("--device", type=str, default="auto") + # Ablation knobs — each flag disables one prompt component + parser.add_argument("--no-prompt-dataset", action="store_true") + parser.add_argument("--no-prompt-task", action="store_true") + parser.add_argument("--no-prompt-patient", action="store_true") + parser.add_argument("--prompt-stats", action="store_true") + args = parser.parse_args() + + device = ( + ("cuda" if torch.cuda.is_available() else "cpu") + if args.device == "auto" + else args.device + ) + + seq_len = 256 + n_features = 2 + + if args.synthetic or args.root is None: + print("Using synthetic data") + sample_dataset = make_synthetic_dataset( + n_patients=10, seq_len=seq_len, n_channels=n_features + ) + word_embeddings = torch.randn(100, 64) + backbone = None + epochs = 3 + else: + print(f"Loading MIT-BIH from {args.root}") + sample_dataset = make_real_dataset(args.root, seq_len=seq_len) + word_embeddings = None + backbone = args.backbone or "openai-community/gpt2" + epochs = args.epochs + + train_ds, _, test_ds = split_by_patient( + sample_dataset, ratios=[0.8, 0.0, 0.2] + ) + train_loader = get_dataloader( + train_ds, batch_size=args.batch_size, shuffle=True + ) + test_loader = 
get_dataloader( + test_ds, batch_size=args.batch_size, shuffle=False + ) + + model = MedTsLLM( + dataset=sample_dataset, + task="segmentation", + seq_len=seq_len, + n_features=n_features, + covariate_mode="concat", + d_model=32, + d_ff=64, + n_heads=8, + num_tokens=1024, + patch_len=16, + stride=8, + dataset_description=( + "The MIT-BIH Arrhythmia Database contains excerpts of " + "two-channel ambulatory ECG from a mixed population of " + "inpatients and outpatients, digitized at 360 samples " + "per second per channel." + ), + backbone=backbone, + word_embeddings=word_embeddings, + prompt_dataset=not args.no_prompt_dataset, + prompt_task=not args.no_prompt_task, + prompt_patient=not args.no_prompt_patient, + prompt_stats=args.prompt_stats, + ) + + trainer = Trainer(model=model, device=device, enable_logging=False) + trainer.train( + train_dataloader=train_loader, + test_dataloader=test_loader, + epochs=epochs, + optimizer_class=torch.optim.Adam, + optimizer_params={"lr": args.lr}, + ) + + +if __name__ == "__main__": + main() diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 50b1b3887..d9c696ef2 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -55,7 +55,10 @@ def __init__(self, *args, **kwargs): from .dreamt import DREAMTDataset from .ehrshot import EHRShotDataset from .eicu import eICUDataset +from .bidmc import BIDMCDataset from .isruc import ISRUCDataset +from .ludb import LUDBDataset +from .mitbih import MITBIHDataset from .medical_transcriptions import MedicalTranscriptionsDataset from .mimic3 import MIMIC3Dataset from .mimic4 import MIMIC4CXRDataset, MIMIC4Dataset, MIMIC4EHRDataset, MIMIC4NoteDataset diff --git a/pyhealth/datasets/_medtsllm_cache.py b/pyhealth/datasets/_medtsllm_cache.py new file mode 100644 index 000000000..289a8f2aa --- /dev/null +++ b/pyhealth/datasets/_medtsllm_cache.py @@ -0,0 +1,105 @@ +"""Fingerprint-based ``.npz`` preprocessing cache for MedTsLLM datasets. 
+ +Internal helper for the LUDB / MIT-BIH / BIDMC loaders. Provides +``load_or_build`` so preprocessing runs the expensive wfdb-decode path +exactly once per raw file. Subsequent calls load from an ``.npz`` +next to the raw file (or wherever the dataset chooses to cache). +Invalidation is keyed on raw-file stats + preprocessing params via +:func:`compute_fingerprint`. + +Author: Anton Barchukov +""" + +import hashlib +import json +import os +from typing import Callable + +import numpy as np + +_FINGERPRINT_KEY = "_cache_fingerprint" + + +def compute_fingerprint(raw_paths: list[str], params: dict) -> str: + """Return a stable SHA-256 fingerprint for cache invalidation. + + Combines raw-file ``(path, mtime_ns, size)`` tuples with the + preprocessing params (JSON-serializable) into a single hash. A + change in any input — including editing a raw file or flipping a + param — flips the fingerprint. + + Args: + raw_paths: Absolute paths to the raw files whose decoded form + is being cached. Order-insensitive. + params: Preprocessing params that affect the cached arrays + (e.g. ``{"trim": True, "downsample_factor": 3}``). + + Returns: + 64-char hex digest string. + """ + h = hashlib.sha256() + for path in sorted(raw_paths): + st = os.stat(path) + h.update(path.encode("utf-8")) + h.update(b"|") + h.update(str(st.st_mtime_ns).encode("ascii")) + h.update(b"|") + h.update(str(st.st_size).encode("ascii")) + h.update(b"\n") + h.update(json.dumps(params, sort_keys=True, default=str).encode("utf-8")) + return h.hexdigest() + + +def load_or_build( + cache_path: str, + fingerprint: str, + builder: Callable[[], dict[str, np.ndarray]], +) -> dict[str, np.ndarray]: + """Load cached arrays if the fingerprint matches, else build + write. + + Args: + cache_path: Target ``.npz`` path. Parent dirs are created. + fingerprint: Expected fingerprint string. Mismatch triggers a + rebuild. + builder: Zero-arg callable returning ``{name: ndarray}``. Only + invoked on cache miss. 
+ + Returns: + Dict of arrays — either freshly built or restored from disk. + """ + cached = _try_load(cache_path, fingerprint) + if cached is not None: + return cached + + arrays = builder() + parent = os.path.dirname(cache_path) + if parent: + os.makedirs(parent, exist_ok=True) + + payload = dict(arrays) + payload[_FINGERPRINT_KEY] = np.array([fingerprint]) + np.savez(cache_path, allow_pickle=False, **payload) + return arrays + + +def _try_load( + cache_path: str, fingerprint: str +) -> dict[str, np.ndarray] | None: + """Return cached arrays iff the file exists, parses, and matches.""" + if not os.path.exists(cache_path): + return None + try: + npz = np.load(cache_path, allow_pickle=False) + except Exception: + return None + try: + stored = str(npz[_FINGERPRINT_KEY][0]) + except (KeyError, IndexError): + return None + if stored != fingerprint: + return None + return { + key: np.array(npz[key]) + for key in npz.files + if key != _FINGERPRINT_KEY + } diff --git a/pyhealth/datasets/bidmc.py b/pyhealth/datasets/bidmc.py new file mode 100644 index 000000000..6dcdc2bbf --- /dev/null +++ b/pyhealth/datasets/bidmc.py @@ -0,0 +1,258 @@ +# Author: Anton Barchukov +# Paper: Chan et al., "MedTsLLM: Leveraging LLMs for Multimodal +# Medical Time Series Analysis", MLHC 2024 +# Paper link: https://arxiv.org/abs/2408.07773 +# Description: BIDMC dataset — 53 ICU patients with 8-minute +# recordings of ECG, PPG, and respiratory signals at 125 Hz. +# Two annotators manually annotated individual breaths. +# Source: https://physionet.org/content/bidmc/1.0.0/ + +import logging +import os +from typing import Optional + +import numpy as np +import pandas as pd + +from pyhealth.datasets import BaseDataset +from pyhealth.datasets._medtsllm_cache import ( + compute_fingerprint, + load_or_build, +) + +logger = logging.getLogger(__name__) + +# Paper's BIDMC split: 85/15 by patient, np.random.RandomState(0). 
+_PAPER_SPLIT_RATIO = 0.85 +_PAPER_SPLIT_SEED = 0 + +# Subdirectory under ``root`` for preprocessed ``.npz`` caches. +_PROCESSED_SUBDIR = "processed" + +# RESP, PLETH, and ECG lead II: the 3 channels used in the paper. +_TARGET_CHANNELS = ["RESP,", "PLETH,", "II,"] + + +class BIDMCDataset(BaseDataset): + """BIDMC respiratory signal dataset for boundary detection. + + 53 ICU patients with 8-minute recordings of ECG, PPG, and + respiratory impedance signals at 125 Hz. Breath boundaries are + manually annotated by two annotators. + + Dataset is available at https://physionet.org/content/bidmc/1.0.0/ + + Paper: Pimentel, M.A.F. et al. "Towards a Robust Estimation of + Respiratory Rate from Pulse Oximeters." IEEE TBME, 2016. + + Args: + root: Root directory of the raw BIDMC data. Should contain + wfdb record files (bidmc01.dat, bidmc01.hea, etc.). + dataset_name: Name of the dataset. Default is ``"bidmc"``. + config_path: Path to the YAML config file. + dev: Whether to enable dev mode (first 5 patients). + paper_split: If True, populate a ``split`` column with an + 85/15 train/test assignment per patient, using the + legacy NumPy seed=0 RNG from the paper. Otherwise the + ``split`` column is left blank. Default False. + preprocess: If True, decode each record once, extract the + RESP/PLETH/II channels, and cache + ``(signal, ann_sample, ann_aux)`` to + ``{root}/processed/{record}.npz``. Subsequent runs skip + wfdb. Default False. 
+ + Examples: + >>> from pyhealth.datasets import BIDMCDataset + >>> dataset = BIDMCDataset(root="/path/to/bidmc/") + >>> dataset.stat() + """ + + def __init__( + self, + root: str, + dataset_name: Optional[str] = None, + config_path: Optional[str] = None, + dev: bool = False, + paper_split: bool = False, + preprocess: bool = False, + ) -> None: + if config_path is None: + config_path = os.path.join( + os.path.dirname(__file__), "configs", "bidmc.yaml" + ) + + metadata_path = os.path.join(root, "bidmc-pyhealth.csv") + if not os.path.exists(metadata_path): + self.prepare_metadata( + root, + dev=dev, + paper_split=paper_split, + preprocess=preprocess, + ) + + super().__init__( + root=root, + tables=["respiratory"], + dataset_name=dataset_name or "bidmc", + config_path=config_path, + dev=dev, + ) + + @staticmethod + def prepare_metadata( + root: str, + dev: bool = False, + paper_split: bool = False, + preprocess: bool = False, + ) -> None: + """Prepare metadata CSV from raw wfdb files. + + Args: + root: Root directory containing wfdb files. + dev: If True, only process first 5 patients. + paper_split: If True, assign records to train/test via the + paper's 85/15 seed=0 split (written to the ``split`` + column). Otherwise the column is blank. + preprocess: If True, build per-record ``.npz`` caches and + populate the ``processed_file`` column. 
+ """ + import wfdb + + # Find record files (exclude numerics *n.hea) + hea_files = sorted( + f.replace(".hea", "") + for f in os.listdir(root) + if f.endswith(".hea") + and not f.endswith("n.hea") + and f.startswith("bidmc") + ) + if dev: + hea_files = hea_files[:5] + + split_by_record = ( + _assign_paper_split(hea_files) if paper_split else {} + ) + processed_dir = os.path.join(root, _PROCESSED_SUBDIR) + + rows = [] + for rec_name in hea_files: + rec_path = os.path.join(root, rec_name) + try: + record = wfdb.rdrecord(rec_path) + except Exception: + continue + + patient_id = rec_name.replace("bidmc", "") + + # Parse header for demographics + age, sex, location = "", "", "" + if record.comments: + comment = " ".join(record.comments) + for field, var in [("age", "age"), ("sex", "sex"), + ("location", "location")]: + tag = f"<{field}>:" + if tag in comment: + idx = comment.index(tag) + len(tag) + end = comment.find("<", idx) + val = (comment[idx:end].strip() + if end > 0 else comment[idx:].strip()) + if field == "age": + age = val + elif field == "sex": + sex = val + elif field == "location": + location = val + + processed_file = "" + if preprocess: + processed_file = _build_record_cache( + processed_dir=processed_dir, + rec_path=os.path.join(root, rec_name), + rec_name=rec_name, + ) + + rows.append({ + "patient_id": patient_id, + "signal_file": os.path.join(root, rec_name), + "annotation_file": "breath", + "age": age, + "sex": sex, + "location": location, + "split": split_by_record.get(rec_name, ""), + "processed_file": processed_file, + }) + + df = pd.DataFrame(rows) + out_path = os.path.join(root, "bidmc-pyhealth.csv") + df.to_csv(out_path, index=False) + logger.info( + "BIDMC metadata: %d patients -> %s", len(df), out_path + ) + + @property + def default_task(self): + """Returns the default task for this dataset.""" + from pyhealth.tasks.respiratory_boundary_detection import ( + RespiratoryBoundaryDetection, + ) + + return RespiratoryBoundaryDetection() + + +def 
_build_record_cache( + processed_dir: str, + rec_path: str, + rec_name: str, +) -> str: + """Cache (signal, ann_sample, ann_aux) for one BIDMC record. + + Signal is the 3-channel (RESP, PLETH, II) array at native 125 Hz + — no downsampling, no trim, matching the paper's recipe. + ``ann_sample`` / ``ann_aux`` preserve the raw annotation stream + so the task can pick annotator 1 or 2 at windowing time. + """ + cache_path = os.path.join(processed_dir, f"{rec_name}.npz") + raw_paths = [ + rec_path + ".dat", + rec_path + ".hea", + rec_path + ".breath", + ] + fingerprint = compute_fingerprint(raw_paths, {"channels": _TARGET_CHANNELS}) + + def _build() -> dict[str, np.ndarray]: + import wfdb + + record = wfdb.rdrecord(rec_path) + ann = wfdb.rdann(rec_path, extension="breath") + + col_idx = [ + record.sig_name.index(ch) + for ch in _TARGET_CHANNELS + if ch in record.sig_name + ] + signal = record.p_signal[:, col_idx].astype(np.float32) + + return { + "signal": signal, + "ann_sample": np.asarray(ann.sample, dtype=np.int64), + "ann_aux": np.asarray(ann.aux_note).astype("U8"), + } + + load_or_build(cache_path, fingerprint, _build) + return cache_path + + +def _assign_paper_split(records: list[str]) -> dict[str, str]: + """Assign each record to train/test per the paper's 85/15 seed=0 split. + + Mirrors the cs598 BIDMCSegmentationDataset recipe: shuffle record + names with ``np.random.RandomState(0)``, take the first 85% as + train and the remainder as test. 
+ """ + rng = np.random.RandomState(_PAPER_SPLIT_SEED) + order = rng.permutation(len(records)) + cutoff = int(len(records) * _PAPER_SPLIT_RATIO) + assignment: dict[str, str] = {} + for rank, idx in enumerate(order): + assignment[records[idx]] = "train" if rank < cutoff else "test" + return assignment diff --git a/pyhealth/datasets/configs/bidmc.yaml b/pyhealth/datasets/configs/bidmc.yaml new file mode 100644 index 000000000..a0a7ede5c --- /dev/null +++ b/pyhealth/datasets/configs/bidmc.yaml @@ -0,0 +1,14 @@ +version: "1.0.0" +tables: + respiratory: + file_path: "bidmc-pyhealth.csv" + patient_id: "patient_id" + timestamp: null + attributes: + - "signal_file" + - "annotation_file" + - "age" + - "sex" + - "location" + - "split" + - "processed_file" diff --git a/pyhealth/datasets/configs/ludb.yaml b/pyhealth/datasets/configs/ludb.yaml new file mode 100644 index 000000000..4ba371f1a --- /dev/null +++ b/pyhealth/datasets/configs/ludb.yaml @@ -0,0 +1,16 @@ +version: "1.0.0" +tables: + ecg: + file_path: "ludb-pyhealth.csv" + patient_id: "patient_id" + timestamp: null + attributes: + - "lead" + - "clip_id" + - "signal_file" + - "label_file" + - "age" + - "sex" + - "diagnoses" + - "split" + - "processed_file" diff --git a/pyhealth/datasets/configs/mitbih.yaml b/pyhealth/datasets/configs/mitbih.yaml new file mode 100644 index 000000000..21dd118ba --- /dev/null +++ b/pyhealth/datasets/configs/mitbih.yaml @@ -0,0 +1,16 @@ +version: "1.0.0" +tables: + ecg: + file_path: "mitbih-pyhealth.csv" + patient_id: "patient_id" + timestamp: null + attributes: + - "signal_file" + - "annotation_file" + - "age" + - "sex" + - "medications" + - "n_abnormal" + - "n_beats" + - "split" + - "processed_file" diff --git a/pyhealth/datasets/ludb.py b/pyhealth/datasets/ludb.py new file mode 100644 index 000000000..05f7cff7f --- /dev/null +++ b/pyhealth/datasets/ludb.py @@ -0,0 +1,331 @@ +# Author: Anton Barchukov +# Paper: Chan et al., "MedTsLLM: Leveraging LLMs for Multimodal +# Medical Time Series 
Analysis", MLHC 2024 +# Paper link: https://arxiv.org/abs/2408.07773 +# Description: Lobachevsky University Database (LUDB) — 200 subjects +# with 12-lead ECG at 500 Hz, manually annotated with P wave, +# T wave, QRS complex, and background classes. +# Source: https://physionet.org/content/ludb/1.0.1/ + +import logging +import os +from typing import Optional + +import numpy as np +import pandas as pd + +from pyhealth.datasets import BaseDataset +from pyhealth.datasets._medtsllm_cache import ( + compute_fingerprint, + load_or_build, +) + +logger = logging.getLogger(__name__) + +# Label mapping: wfdb annotation symbol -> class index +# 0 = background, 1 = P wave, 2 = QRS complex, 3 = T wave +WAVE_CLASSES = ["background", "P wave", "QRS complex", "T wave"] +_WAVE_LABELS = {"p": 1, "N": 2, "t": 3} + +# Paper's LUDB split: 80/20 by patient, seeded with NumPy legacy RNG +# to match the cs598 reference implementation exactly. +_PAPER_SPLIT_RATIO = 0.8 +_PAPER_SPLIT_SEED = 0 + +# Subdirectory under ``root`` where preprocessed ``.npz`` files live. +_PROCESSED_SUBDIR = "processed" + + +def _parse_ludb_header(record) -> tuple[str, str, str]: + """Parse age, sex, and diagnoses from LUDB wfdb header comments. + + LUDB headers look like:: + + #: 51 + #: F + #: + #Rhythm: Sinus bradycardia. + #Left ventricular hypertrophy. + + Args: + record: A ``wfdb.Record`` with a ``comments`` attribute. + + Returns: + Tuple of (age, sex, diagnoses). Diagnoses are joined with + ``"; "`` separators. Missing fields return empty strings. + """ + age = "" + sex = "" + diagnoses_lines: list[str] = [] + if not record.comments: + return age, sex, "" + + in_diagnoses = False + for line in record.comments: + line = line.strip() + if line.startswith(":"): + age = line.split(":", 1)[1].strip() + elif line.startswith(":"): + sex = line.split(":", 1)[1].strip() + elif line.startswith(":"): + in_diagnoses = True + elif in_diagnoses and line: + diagnoses_lines.append(line.rstrip(". 
")) + + return age, sex, "; ".join(diagnoses_lines) + + +class LUDBDataset(BaseDataset): + """Lobachevsky University Database (LUDB) for ECG delineation. + + Dataset of 200 subjects with 12-lead ECG recordings at 500 Hz + (10 seconds each). Each lead is manually annotated by cardiologists + with P wave, QRS complex, and T wave boundaries. + + Dataset is available at https://physionet.org/content/ludb/1.0.1/ + + Paper: Kalyakulina, A. et al. "LUDB: A New Open-Access Validation + Database for Electrocardiogram Delineation Algorithms." + + Args: + root: Root directory of the raw LUDB data. Should contain a + ``data/`` subdirectory with wfdb record files (.dat, .hea). + dataset_name: Name of the dataset. Default is ``"ludb"``. + config_path: Path to the YAML config file. Default uses the + built-in config. + dev: Whether to enable dev mode (only use first 5 patients). + Default is False. + paper_split: If True, populate a ``split`` column with an + 80/20 train/test assignment per patient, using the + legacy NumPy seed=0 RNG from the paper. Otherwise the + ``split`` column is left blank. Default False. + preprocess: If True, decode each ``(patient, lead)`` wfdb + record once and cache the resulting ``(signal, labels)`` + arrays to ``{root}/processed/{record}_{lead}.npz``. + Subsequent runs skip wfdb entirely. Default False. + trim: Only consulted when ``preprocess=True``. Crop each + lead to the region between the first and last wave + annotation before caching. Matches the paper's + preprocessing. Default True. 
+ + Examples: + >>> from pyhealth.datasets import LUDBDataset + >>> dataset = LUDBDataset(root="/path/to/ludb/") + >>> dataset.stat() + """ + + def __init__( + self, + root: str, + dataset_name: Optional[str] = None, + config_path: Optional[str] = None, + dev: bool = False, + paper_split: bool = False, + preprocess: bool = False, + trim: bool = True, + ) -> None: + if config_path is None: + config_path = os.path.join( + os.path.dirname(__file__), "configs", "ludb.yaml" + ) + + metadata_path = os.path.join(root, "ludb-pyhealth.csv") + if not os.path.exists(metadata_path): + self.prepare_metadata( + root, + dev=dev, + paper_split=paper_split, + preprocess=preprocess, + trim=trim, + ) + + super().__init__( + root=root, + tables=["ecg"], + dataset_name=dataset_name or "ludb", + config_path=config_path, + dev=dev, + ) + + @staticmethod + def prepare_metadata( + root: str, + dev: bool = False, + paper_split: bool = False, + preprocess: bool = False, + trim: bool = True, + ) -> None: + """Prepare metadata CSV from raw wfdb files. + + Scans the ``data/`` subdirectory for wfdb records and creates + a CSV with one row per (patient, lead) pair, pointing to the + signal and annotation files. When ``preprocess=True`` also + writes per-lead ``.npz`` caches into ``{root}/processed/``. + + Args: + root: Root directory containing ``data/`` with wfdb files. + dev: If True, only process first 5 patients. + paper_split: If True, assign patients to train/test via the + paper's 80/20 seed=0 split and write values to the + ``split`` column. Otherwise the column is blank. + preprocess: If True, build per-lead ``.npz`` signal+label + caches and point each row's ``processed_file`` at + its cache. + trim: When ``preprocess=True``, crop each lead to the + region between the first and last wave annotation. + """ + import wfdb + + data_dir = os.path.join(root, "data") + if not os.path.isdir(data_dir): + raise FileNotFoundError( + f"LUDB data directory not found at {data_dir}. 
" + "Download from https://physionet.org/content/ludb/1.0.1/" + ) + + records = sorted( + f.replace(".dat", "") + for f in os.listdir(data_dir) + if f.endswith(".dat") + ) + if dev: + records = records[:5] + + split_by_record = _assign_paper_split(records) if paper_split else {} + processed_dir = os.path.join(root, _PROCESSED_SUBDIR) + + rows = [] + for rec_name in records: + rec_path = os.path.join(data_dir, rec_name) + record = wfdb.rdrecord(rec_path) + patient_id = int(rec_name) + + age, sex, diagnoses = _parse_ludb_header(record) + + for lead_idx in range(record.n_sig): + lead_name = record.sig_name[lead_idx] + # Check if annotation file exists for this lead + ann_ext = lead_name.lower() + ann_path = os.path.join(data_dir, f"{rec_name}.{ann_ext}") + if not os.path.exists(ann_path): + continue + + clip_id = patient_id * 100 + lead_idx + processed_file = "" + if preprocess: + processed_file = _build_lead_cache( + processed_dir=processed_dir, + rec_path=rec_path, + rec_name=rec_name, + lead_name=lead_name, + lead_idx=lead_idx, + ann_path=ann_path, + ann_ext=ann_ext, + trim=trim, + ) + rows.append({ + "patient_id": str(patient_id), + "lead": lead_name, + "clip_id": clip_id, + "signal_file": os.path.join(data_dir, rec_name), + "label_file": ann_ext, + "age": age, + "sex": sex, + "diagnoses": diagnoses, + "split": split_by_record.get(rec_name, ""), + "processed_file": processed_file, + }) + + df = pd.DataFrame(rows) + out_path = os.path.join(root, "ludb-pyhealth.csv") + df.to_csv(out_path, index=False) + logger.info( + "LUDB metadata: %d records from %d patients -> %s", + len(df), + df["patient_id"].nunique(), + out_path, + ) + + @property + def default_task(self): + """Returns the default task for this dataset.""" + from pyhealth.tasks.ecg_wave_segmentation import ECGWaveSegmentation + + return ECGWaveSegmentation() + + +def _build_lead_cache( + processed_dir: str, + rec_path: str, + rec_name: str, + lead_name: str, + lead_idx: int, + ann_path: str, + ann_ext: str, 
+ trim: bool, +) -> str: + """Cache (signal, labels) for one (record, lead) pair. + + Returns the absolute cache path, which is written into the + metadata CSV's ``processed_file`` column. + """ + cache_path = os.path.join( + processed_dir, f"{rec_name}_{lead_name}.npz" + ) + raw_paths = [rec_path + ".dat", rec_path + ".hea", ann_path] + params = {"trim": bool(trim)} + fingerprint = compute_fingerprint(raw_paths, params) + + def _build() -> dict[str, np.ndarray]: + import wfdb + + record = wfdb.rdrecord(rec_path) + signal = record.p_signal[:, lead_idx].astype(np.float32) + ann = wfdb.rdann(rec_path, extension=ann_ext) + + labels = np.zeros(len(signal), dtype=np.int64) + i = 0 + while i < len(ann.symbol): + sym = ann.symbol[i] + if sym == "(" and i + 2 < len(ann.symbol): + wave_type = ann.symbol[i + 1] + onset = ann.sample[i] + offset = ( + ann.sample[i + 2] + if ann.symbol[i + 2] == ")" + else ann.sample[i + 1] + ) + if wave_type in _WAVE_LABELS: + labels[onset : offset + 1] = _WAVE_LABELS[wave_type] + i += 3 + else: + i += 1 + + if trim: + wave_mask = labels > 0 + if wave_mask.any(): + first = int(np.argmax(wave_mask)) + last = len(wave_mask) - 1 - int(np.argmax(wave_mask[::-1])) + signal = signal[first : last + 1] + labels = labels[first : last + 1] + + return {"signal": signal, "labels": labels} + + load_or_build(cache_path, fingerprint, _build) + return cache_path + + +def _assign_paper_split(records: list[str]) -> dict[str, str]: + """Assign each record to train/test per the paper's 80/20 seed=0 split. + + Mirrors the cs598 preprocess_ludb.py recipe: shuffle record names + with ``np.random.RandomState(0)``, take the first 80% as train and + the remainder as test. 
+ """ + rng = np.random.RandomState(_PAPER_SPLIT_SEED) + order = rng.permutation(len(records)) + cutoff = int(len(records) * _PAPER_SPLIT_RATIO) + assignment: dict[str, str] = {} + for rank, idx in enumerate(order): + assignment[records[idx]] = "train" if rank < cutoff else "test" + return assignment diff --git a/pyhealth/datasets/mitbih.py b/pyhealth/datasets/mitbih.py new file mode 100644 index 000000000..d9fc2947d --- /dev/null +++ b/pyhealth/datasets/mitbih.py @@ -0,0 +1,340 @@ +# Author: Anton Barchukov +# Paper: Chan et al., "MedTsLLM: Leveraging LLMs for Multimodal +# Medical Time Series Analysis", MLHC 2024 +# Paper link: https://arxiv.org/abs/2408.07773 +# Description: MIT-BIH Arrhythmia Database — 48 half-hour excerpts +# of two-channel ambulatory ECG, 360 Hz. Used for boundary +# detection (R-peaks) and anomaly detection (arrhythmia). +# Source: https://physionet.org/content/mitdb/1.0.0/ + +import logging +import os +from typing import Optional + +import numpy as np +import pandas as pd + +from pyhealth.datasets import BaseDataset +from pyhealth.datasets._medtsllm_cache import ( + compute_fingerprint, + load_or_build, +) + +logger = logging.getLogger(__name__) + +# Normal beat types (all others are anomalies) +_NORMAL_BEATS = {"N", "L", "R", "e", "j"} + +# Paced rhythm records — excluded per paper +_PACED_RECORDS = {"102", "104", "107", "217"} + +# Paper's MIT-BIH split: 80/20 by patient, seeded with NumPy legacy RNG. +_PAPER_SPLIT_RATIO = 0.8 +_PAPER_SPLIT_SEED = 0 +_VALID_SPLIT_MODES = {None, "random", "abnormal_sorted"} + +# Subdirectory under ``root`` for preprocessed ``.npz`` caches. +_PROCESSED_SUBDIR = "processed" + + +class MITBIHDataset(BaseDataset): + """MIT-BIH Arrhythmia Database for ECG analysis. + + 48 half-hour excerpts of two-channel ambulatory ECG from a mixed + population of inpatients and outpatients, digitized at 360 Hz. + 4 paced-rhythm records are excluded. 
+ + Supports two tasks: + - Boundary detection (R-peak localization) + - Anomaly detection (arrhythmia via reconstruction error) + + Dataset is available at https://physionet.org/content/mitdb/1.0.0/ + + Paper: Moody, G.B. & Mark, R.G. "The impact of the MIT-BIH + Arrhythmia Database." IEEE EMB Magazine, 2001. + + Args: + root: Root directory of the raw MIT-BIH data. Should contain + wfdb record files (100.dat, 100.hea, etc.). + dataset_name: Name of the dataset. Default is ``"mitbih"``. + config_path: Path to the YAML config file. + dev: Whether to enable dev mode (first 5 patients). + paper_split: Split assignment strategy: + + - ``None`` (default): leave the ``split`` column blank. + - ``"random"``: 80/20 patient split via ``RandomState(0)``, + matching the paper's segmentation/boundary task setup. + - ``"abnormal_sorted"``: patients sorted by ``n_abnormal`` + ascending, with all-abnormal patients excluded. The + least-abnormal 80% become train and the most-abnormal + 20% become test. Matches the paper's anomaly setup. + preprocess: If True, decode each record once, downsample, + trim, and cache ``(signal, ann_sample, ann_symbol)`` to + ``{root}/processed/{record}.npz``. Subsequent runs skip + wfdb. Default False. + downsample_factor: Decimation factor applied to the 360 Hz + raw signal when ``preprocess=True``. Default 3 (120 Hz). + trim: When ``preprocess=True``, crop each record to the + region between its first and last beat annotation. + Matches the paper's preprocessing. Default True. 
+ + Examples: + >>> from pyhealth.datasets import MITBIHDataset + >>> dataset = MITBIHDataset(root="/path/to/mitdb/") + >>> dataset.stat() + """ + + def __init__( + self, + root: str, + dataset_name: Optional[str] = None, + config_path: Optional[str] = None, + dev: bool = False, + paper_split: Optional[str] = None, + preprocess: bool = False, + downsample_factor: int = 3, + trim: bool = True, + ) -> None: + if paper_split not in _VALID_SPLIT_MODES: + raise ValueError( + f"paper_split must be one of {_VALID_SPLIT_MODES}, " + f"got {paper_split!r}" + ) + if downsample_factor < 1: + raise ValueError( + f"downsample_factor must be >= 1, got {downsample_factor}" + ) + + if config_path is None: + config_path = os.path.join( + os.path.dirname(__file__), "configs", "mitbih.yaml" + ) + + metadata_path = os.path.join(root, "mitbih-pyhealth.csv") + if not os.path.exists(metadata_path): + self.prepare_metadata( + root, + dev=dev, + paper_split=paper_split, + preprocess=preprocess, + downsample_factor=downsample_factor, + trim=trim, + ) + + super().__init__( + root=root, + tables=["ecg"], + dataset_name=dataset_name or "mitbih", + config_path=config_path, + dev=dev, + ) + + @staticmethod + def prepare_metadata( + root: str, + dev: bool = False, + paper_split: Optional[str] = None, + preprocess: bool = False, + downsample_factor: int = 3, + trim: bool = True, + ) -> None: + """Prepare metadata CSV from raw wfdb files. + + Args: + root: Root directory containing wfdb files. + dev: If True, only process first 5 patients. + paper_split: ``None``, ``"random"``, or ``"abnormal_sorted"``. + See class docstring for semantics. + preprocess: If True, build per-record ``.npz`` caches and + populate the ``processed_file`` column. + downsample_factor: Decimation factor when ``preprocess=True``. + trim: When ``preprocess=True``, crop to first/last beat + annotation. 
+ """ + import wfdb + + records = sorted( + f.replace(".dat", "") + for f in os.listdir(root) + if f.endswith(".dat") and f.replace(".dat", "") not in _PACED_RECORDS + ) + if dev: + records = records[:5] + + processed_dir = os.path.join(root, _PROCESSED_SUBDIR) + + rows = [] + for rec_name in records: + rec_path = os.path.join(root, rec_name) + try: + record = wfdb.rdrecord(rec_path) + ann = wfdb.rdann(rec_path, extension="atr") + except Exception: + continue + + patient_id = rec_name + + # Parse header for demographics + age, sex, medications = "", "", "" + if record.comments: + first = record.comments[0].strip() + tokens = first.split() + if len(tokens) >= 2: + age = tokens[0] + sex = tokens[1] + if len(record.comments) > 1: + medications = record.comments[1].strip() + + # Count beats excluding rhythm-change markers ("+") + beat_symbols = [s for s in ann.symbol if s != "+"] + n_beats = len(beat_symbols) + n_abnormal = sum( + 1 for s in beat_symbols if s not in _NORMAL_BEATS + ) + + processed_file = "" + if preprocess: + processed_file = _build_record_cache( + processed_dir=processed_dir, + rec_path=rec_path, + rec_name=rec_name, + downsample_factor=downsample_factor, + trim=trim, + ) + + rows.append({ + "patient_id": patient_id, + "signal_file": os.path.join(root, rec_name), + "annotation_file": "atr", + "age": age, + "sex": sex, + "medications": medications, + "n_abnormal": n_abnormal, + "n_beats": n_beats, + "processed_file": processed_file, + }) + + rows = _apply_paper_split(rows, paper_split) + + df = pd.DataFrame(rows) + out_path = os.path.join(root, "mitbih-pyhealth.csv") + df.to_csv(out_path, index=False) + logger.info( + "MIT-BIH metadata: %d records -> %s", len(df), out_path + ) + + @property + def default_task(self): + """Returns the default task (boundary detection).""" + from pyhealth.tasks.ecg_boundary_detection import ( + ECGBoundaryDetection, + ) + + return ECGBoundaryDetection() + + +def _build_record_cache( + processed_dir: str, + rec_path: str, 
+ rec_name: str, + downsample_factor: int, + trim: bool, +) -> str: + """Cache downsampled + trimmed signal and annotations for one record. + + Returns the absolute cache path, which is written into the + metadata CSV's ``processed_file`` column. The cache stores: + + - ``signal``: ``(n_timesteps, n_channels)`` post-downsample + trim + - ``ann_sample``: beat sample indices **relative to the trimmed + signal** (drops annotations outside the trim range) + - ``ann_symbol``: beat symbols aligned with ``ann_sample`` + """ + cache_path = os.path.join(processed_dir, f"{rec_name}.npz") + raw_paths = [rec_path + ".dat", rec_path + ".hea", rec_path + ".atr"] + params = { + "downsample_factor": int(downsample_factor), + "trim": bool(trim), + } + fingerprint = compute_fingerprint(raw_paths, params) + + def _build() -> dict[str, np.ndarray]: + import wfdb + + record = wfdb.rdrecord(rec_path) + ann = wfdb.rdann(rec_path, extension="atr") + + signal = record.p_signal.astype(np.float32) + if downsample_factor > 1: + signal = signal[::downsample_factor] + + ds_samples = np.asarray(ann.sample) // downsample_factor + ann_symbols = np.asarray(ann.symbol) + + # Drop annotations outside the signal bounds. + in_bounds = (ds_samples >= 0) & (ds_samples < len(signal)) + ds_samples = ds_samples[in_bounds] + ann_symbols = ann_symbols[in_bounds] + + if trim and len(ds_samples) > 0: + first = int(ds_samples[0]) + last = int(ds_samples[-1]) + if first <= last: + signal = signal[first : last + 1] + ds_samples = ds_samples - first + # After shifting, last kept index is last-first. + + return { + "signal": signal, + "ann_sample": ds_samples.astype(np.int64), + "ann_symbol": ann_symbols.astype("U4"), + } + + load_or_build(cache_path, fingerprint, _build) + return cache_path + + +def _apply_paper_split( + rows: list[dict], paper_split: Optional[str] +) -> list[dict]: + """Assign each row to train/test per the paper's split strategy. + + Mutates rows in place by adding a ``split`` key. 
For + ``"abnormal_sorted"``, patients with every beat marked abnormal + are dropped entirely (the paper excludes them from anomaly + training). Returns the possibly filtered list. + """ + if not rows: + return rows + + if paper_split is None: + for row in rows: + row["split"] = "" + return rows + + if paper_split == "random": + rng = np.random.RandomState(_PAPER_SPLIT_SEED) + order = rng.permutation(len(rows)) + cutoff = int(len(rows) * _PAPER_SPLIT_RATIO) + split_by_rank = [ + "train" if rank < cutoff else "test" + for rank in range(len(rows)) + ] + for rank, idx in enumerate(order): + rows[idx]["split"] = split_by_rank[rank] + return rows + + if paper_split == "abnormal_sorted": + # Drop all-abnormal patients before splitting. + kept = [ + row for row in rows + if row.get("n_beats", 0) == 0 + or row["n_abnormal"] < row["n_beats"] + ] + kept.sort(key=lambda r: r["n_abnormal"]) + cutoff = int(len(kept) * _PAPER_SPLIT_RATIO) + for rank, row in enumerate(kept): + row["split"] = "train" if rank < cutoff else "test" + return kept + + raise ValueError(f"Unknown paper_split mode: {paper_split!r}") diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index 4c168d3e3..528adf549 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -17,6 +17,7 @@ from .graphcare import GraphCare from .grasp import GRASP, GRASPLayer from .medlink import MedLink +from .medtsllm import MedTsLLM from .micron import MICRON, MICRONLayer from .mlp import MLP from .molerec import MoleRec, MoleRecLayer diff --git a/pyhealth/models/_medtsllm/__init__.py b/pyhealth/models/_medtsllm/__init__.py new file mode 100644 index 000000000..13f5b692e --- /dev/null +++ b/pyhealth/models/_medtsllm/__init__.py @@ -0,0 +1,8 @@ +from .layers import ( + FlattenHead, + LinearProjection, + PatchEmbedding, + ReprogrammingLayer, + RevIN, +) +from .prompt import build_prompt, compute_lags, encode_prompts diff --git a/pyhealth/models/_medtsllm/layers.py 
b/pyhealth/models/_medtsllm/layers.py new file mode 100644 index 000000000..540a67ff8 --- /dev/null +++ b/pyhealth/models/_medtsllm/layers.py @@ -0,0 +1,269 @@ +"""Neural network layers for MedTsLLM. + +Includes: RevIN, PatchEmbedding, ReprogrammingLayer, LinearProjection, +FlattenHead. These are the lightweight trainable components that wrap +around a frozen LLM backbone. +""" + +import math + +import torch +from torch import Tensor, nn + + +class RevIN(nn.Module): + """Reversible instance normalization for time series. + + Normalizes each feature by subtracting the mean and dividing by the + standard deviation. Stores statistics so the operation can be + reversed for reconstruction tasks. + + Args: + num_features: Number of input features/channels. + eps: Small value for numerical stability. + affine: If True, learns per-feature scale and bias. + """ + + def __init__( + self, + num_features: int, + eps: float = 1e-5, + affine: bool = False, + ) -> None: + super().__init__() + self.num_features = num_features + self.eps = eps + self.affine = affine + + if self.affine: + self.affine_weight = nn.Parameter(torch.ones(self.num_features)) + self.affine_bias = nn.Parameter(torch.zeros(self.num_features)) + + self._mean: Tensor | None = None + self._stdev: Tensor | None = None + + def forward(self, x: Tensor, mode: str) -> Tensor: + """Apply normalization or denormalization. + + Args: + x: (batch, seq_len, num_features). + mode: ``"norm"`` or ``"denorm"``. 
+ """ + if mode == "norm": + self._mean = x.mean(dim=1, keepdim=True).detach() + self._stdev = ( + (x.var(dim=1, keepdim=True, unbiased=False) + self.eps) + .sqrt() + .detach() + ) + x = (x - self._mean) / self._stdev + if self.affine: + x = x * self.affine_weight + self.affine_bias + return x + elif mode == "denorm": + if self._mean is None or self._stdev is None: + raise RuntimeError("Call forward(x, 'norm') first.") + if self.affine: + x = (x - self.affine_bias) / self.affine_weight + return x * self._stdev + self._mean + else: + raise ValueError(f"mode must be 'norm' or 'denorm', got '{mode}'") + + +class _TokenEmbedding(nn.Module): + """1D convolution over a single patch (maps patch_len -> d_model). + + Separated from ``PatchEmbedding`` as its own module so the state + dict keys (``value_embedding.conv.weight``) match the TIME-LLM / + original-MedTsLLM upstream naming. This lets you load checkpoints + trained outside of PyHealth without renaming keys. + + Args: + patch_len: Length of each patch (input channels). + d_model: Output embedding dimension. + """ + + def __init__(self, patch_len: int, d_model: int) -> None: + super().__init__() + self.conv = nn.Conv1d( + in_channels=patch_len, + out_channels=d_model, + kernel_size=3, + padding=1, + padding_mode="circular", + bias=False, + ) + nn.init.kaiming_normal_(self.conv.weight) + + def forward(self, x: Tensor) -> Tensor: + """Embed patches via 1D convolution. + + Args: + x: (batch, seq_len, patch_len). + + Returns: + Embedded patches: (batch, seq_len, d_model). + """ + return self.conv(x.permute(0, 2, 1)).transpose(1, 2) + + +class PatchEmbedding(nn.Module): + """Unfolds a time series into overlapping patches and embeds via Conv1d. + + Args: + d_model: Embedding dimension for each patch. + patch_len: Length of each patch in timesteps. + stride: Stride between consecutive patches. + dropout: Dropout probability. 
+ """ + + def __init__( + self, + d_model: int, + patch_len: int, + stride: int, + dropout: float = 0.1, + ) -> None: + super().__init__() + self.patch_len = patch_len + self.stride = stride + + self.padding = nn.ReplicationPad1d((0, stride)) + self.value_embedding = _TokenEmbedding(patch_len, d_model) + self.dropout = nn.Dropout(dropout) + + def forward(self, x: Tensor) -> tuple[Tensor, int]: + """Embed time series into patch representations. + + Args: + x: (batch, n_features, seq_len). + + Returns: + Tuple of (patch_embeddings, n_features). + patch_embeddings: (batch * n_features, n_patches, d_model). + """ + n_vars = x.shape[1] + x = self.padding(x) + x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride) + x = torch.reshape( + x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]) + ) + x = self.value_embedding(x) + return self.dropout(x), n_vars + + +class ReprogrammingLayer(nn.Module): + """Cross-attention reprogramming layer. + + Projects patch embeddings (queries) against word prototype + embeddings (keys/values) using multi-head attention. This is + the core mechanism of MedTsLLM / Time-LLM. + + Args: + d_model: Dimension of input patch embeddings. + n_heads: Number of attention heads. + d_keys: Dimension per head for keys. + d_llm: Dimension of LLM embeddings. + attention_dropout: Dropout on attention weights. 
+ """ + + def __init__( + self, + d_model: int, + n_heads: int, + d_keys: int, + d_llm: int, + attention_dropout: float = 0.1, + ) -> None: + super().__init__() + self.n_heads = n_heads + self.d_keys = d_keys + + self.query_projection = nn.Linear(d_model, d_keys * n_heads) + self.key_projection = nn.Linear(d_llm, d_keys * n_heads) + self.value_projection = nn.Linear(d_llm, d_keys * n_heads) + self.out_projection = nn.Linear(d_keys * n_heads, d_llm) + self.dropout = nn.Dropout(attention_dropout) + + def forward( + self, + target_embedding: Tensor, + source_embedding: Tensor, + value_embedding: Tensor, + ) -> Tensor: + """Reprogram patch embeddings via cross-attention. + + Args: + target_embedding: Queries, (batch, n_patches, d_model). + source_embedding: Keys, (n_tokens, d_llm). + value_embedding: Values, (n_tokens, d_llm). + + Returns: + Reprogrammed embeddings: (batch, n_patches, d_llm). + """ + b, seq_len, _ = target_embedding.shape + s, _ = source_embedding.shape + h = self.n_heads + + queries = self.query_projection(target_embedding).view( + b, seq_len, h, -1 + ) + keys = self.key_projection(source_embedding).view(s, h, -1) + values = self.value_projection(value_embedding).view(s, h, -1) + + scale = 1.0 / math.sqrt(queries.shape[-1]) + scores = torch.einsum("blhe,she->bhls", queries, keys) + attn = self.dropout(torch.softmax(scale * scores, dim=-1)) + out = torch.einsum("bhls,she->blhe", attn, values) + out = out.reshape(b, seq_len, -1) + return self.out_projection(out) + + +class LinearProjection(nn.Module): + """Simple linear projection as ablation alternative. + + Replaces ReprogrammingLayer with a plain linear map from + d_model to d_llm, ignoring word embeddings. + + Args: + d_model: Input dimension. + d_llm: Output dimension. 
+ """ + + def __init__(self, d_model: int, d_llm: int) -> None: + super().__init__() + self.linear = nn.Linear(d_model, d_llm) + + def forward( + self, + target_embedding: Tensor, + source_embedding: Tensor, + value_embedding: Tensor, + ) -> Tensor: + """Project patch embeddings linearly to LLM dimension.""" + return self.linear(target_embedding) + + +class FlattenHead(nn.Module): + """Flatten patch dimension and project to output size. + + Args: + n_features_in: Total input features (d_ff * n_patches). + n_outputs: Total output size (pred_len * n_outputs_per_step). + head_dropout: Dropout probability. + """ + + def __init__( + self, + n_features_in: int, + n_outputs: int, + head_dropout: float = 0.0, + ) -> None: + super().__init__() + self.flatten = nn.Flatten(start_dim=-2) + self.linear = nn.Linear(n_features_in, n_outputs) + self.dropout = nn.Dropout(head_dropout) + + def forward(self, x: Tensor) -> Tensor: + """Args: x: (batch, d_ff, n_patches). Returns: (batch, n_outputs).""" + return self.dropout(self.linear(self.flatten(x))) diff --git a/pyhealth/models/_medtsllm/prompt.py b/pyhealth/models/_medtsllm/prompt.py new file mode 100644 index 000000000..f923aca78 --- /dev/null +++ b/pyhealth/models/_medtsllm/prompt.py @@ -0,0 +1,204 @@ +"""Text prompt construction and encoding for MedTsLLM. + +Mirrors the prompt structure from the original paper implementation +(Chan et al., MLHC 2024, https://github.com/flixpar/med-ts-llm): + + [BOS] [dataset desc] [clip desc] [input stats] [task desc] [Time series:] + +The four content segments (dataset, clip, stats, task) are toggled +independently by the caller. +""" + +from typing import Any + +import torch +from torch import Tensor + + +def compute_lags(x: Tensor, n_lags: int = 5) -> Tensor: + """Compute top-N autocorrelation lags via FFT. + + Matches the original paper's ``calcute_lags``. 
    Power spectral
    density is computed as ``rfft(x) * conj(rfft(x))``, inverse-
    transformed back to autocorrelation, averaged across features,
    and the top-N lag indices are selected.

    Args:
        x: (batch, seq_len) or (batch, seq_len, n_features).
        n_lags: Number of top lags to return.

    Returns:
        (batch, n_lags) long tensor of lag indices.
    """
    # Bring the time axis last: (batch, n_features, seq_len).
    if x.ndim == 3:
        x = x.permute(0, 2, 1).contiguous()
    else:
        x = x.unsqueeze(1)
    # Wiener–Khinchin: the inverse FFT of the power spectrum is the
    # autocorrelation of the signal.
    q_fft = torch.fft.rfft(x, dim=-1)
    corr = torch.fft.irfft(q_fft * torch.conj(q_fft), dim=-1)
    mean_corr = corr.mean(dim=1)
    # topk returns long indices, so the result needs no explicit cast.
    _, lags = torch.topk(mean_corr, n_lags, dim=-1)
    return lags


def build_prompt(
    inputs: dict[str, Any],
    *,
    dataset_description: str = "",
    task_description: str = "",
    include_dataset: bool = True,
    include_task: bool = True,
    include_clip: bool = False,
    include_stats: bool = False,
    n_lags: int = 5,
    bos_token: str = "",
) -> list[list[str]]:
    """Build text prompts for a batch of inputs.

    Args:
        inputs: Dict with ``"x_enc"`` (batch, seq_len, n_features). May
            include ``"descriptions"`` (per-sample strings) when
            ``include_clip`` is True.
        dataset_description: Dataset-level description string.
        task_description: Task-level description string.
        include_dataset: Whether to include dataset description.
        include_task: Whether to include task description.
        include_clip: Whether to include per-sample descriptions.
        include_stats: Whether to include per-sample input statistics
            (min/max/median/trend/top-N autocorr lags).
        n_lags: Number of autocorrelation lags for the stats prompt.
        bos_token: Beginning-of-sequence token string.

    Returns:
        List of string lists, one per batch element.
    """
    x_enc = inputs["x_enc"]
    bs = x_enc.shape[0]

    dataset_prompt = (
        f"Dataset: {dataset_description}" if include_dataset else ""
    )
    task_prompt = f"Task: {task_description}" if include_task else ""

    clip_prompts = (
        inputs.get("descriptions", [""] * bs) if include_clip else [""] * bs
    )

    if include_stats:
        stats_prompts = _build_stats_prompts(x_enc, n_lags)
    else:
        stats_prompts = [""] * bs

    prompts = []
    for b in range(bs):
        parts = [
            bos_token,
            dataset_prompt,
            clip_prompts[b],
            stats_prompts[b],
            task_prompt,
            "Time series:",
        ]
        # Drop empty segments, then append a separating space to every
        # segment except the first; segments are tokenized independently
        # in encode_prompts, so the space keeps them from running
        # together once concatenated. (The trailing space on the final
        # "Time series:" segment separates it from the patch tokens.)
        parts = [p for p in parts if p != ""]
        parts = [
            (p + " " if isinstance(p, str) and i > 0 else p)
            for i, p in enumerate(parts)
        ]
        prompts.append(parts)

    return prompts


def encode_prompts(
    prompts: list[list[str]],
    tokenizer,
    embedding_layer,
    device: torch.device,
) -> Tensor:
    """Encode text prompts into embedding tensors.

    Args:
        prompts: List of string lists from ``build_prompt``.
        tokenizer: HuggingFace tokenizer.
        embedding_layer: LLM's input embedding layer.
        device: Device for tensors.

    Returns:
        Padded prompt embeddings: (batch, max_tokens, d_llm).
    """
    batch_embeddings = []

    for parts in prompts:
        part_embeddings = []
        for part in parts:
            # NOTE(review): called with the tokenizer's default
            # add_special_tokens — for tokenizers that insert BOS/EOS this
            # would add specials per segment; confirm acceptable for the
            # chosen backbone (GPT-2 adds none by default).
            ids = tokenizer(
                part,
                return_tensors="pt",
                padding=False,
                truncation=False,
            ).input_ids.to(device)
            emb = embedding_layer(ids)
            part_embeddings.append(emb)

        combined = torch.cat(part_embeddings, dim=1)
        batch_embeddings.append(combined)

    max_len = max(e.shape[1] for e in batch_embeddings)
    d_llm = batch_embeddings[0].shape[2]

    # Fall back to EOS (or id 0) when the tokenizer defines no pad token.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id or 0
    pad_emb = embedding_layer(torch.tensor([pad_id], device=device))

    # Left-pad (prepend) so every prompt ends at the same position —
    # the real prompt text stays adjacent to the time-series patch
    # embeddings that the caller concatenates after it.
    padded = []
    for emb in batch_embeddings:
        if emb.shape[1] < max_len:
            pad_len = max_len - emb.shape[1]
            pad = pad_emb.expand(1, pad_len, d_llm)
            emb = torch.cat([pad, emb], dim=1)
        padded.append(emb)

    return torch.cat(padded, dim=0)


def _build_stats_prompts(x_enc: Tensor, n_lags: int) -> list[str]:
    """Build per-sample input statistics prompt strings.

    Reports per-feature min, max, median, trend direction, and the
    top-N autocorrelation lags. Matches the format in the original
    paper implementation.
+ """ + xs = x_enc.detach() + if xs.ndim == 2: + xs = xs.unsqueeze(-1) + + with torch.no_grad(): + mins = torch.min(xs, dim=1).values.tolist() + maxs = torch.max(xs, dim=1).values.tolist() + medians = torch.median(xs.float(), dim=1).values.tolist() + trends = (xs.diff(dim=1).sum(dim=1) > 0).tolist() + lags = compute_lags(xs.float(), n_lags).tolist() + + def fmt(v: Any) -> str: + if isinstance(v, list): + return "[" + ", ".join(fmt(x) for x in v) + "]" + if isinstance(v, bool): + return "upward" if v else "downward" + if isinstance(v, float): + return f"{v:.3f}" + return str(v) + + prompts = [] + for b in range(xs.shape[0]): + prompt = ( + f"Input statistics (per feature): " + f"min value = {fmt(mins[b])}, " + f"max value = {fmt(maxs[b])}, " + f"median value = {fmt(medians[b])}, " + f"the trend of input is {fmt(trends[b])}, " + f"the top {n_lags} lags are {lags[b]}." + ) + prompts.append(prompt) + + return prompts diff --git a/pyhealth/models/medtsllm.py b/pyhealth/models/medtsllm.py new file mode 100644 index 000000000..7afa18730 --- /dev/null +++ b/pyhealth/models/medtsllm.py @@ -0,0 +1,586 @@ +# Author: Anton Barchukov +# Paper: Chan et al., "MedTsLLM: Leveraging LLMs for Multimodal +# Medical Time Series Analysis", MLHC 2024 +# Paper link: https://arxiv.org/abs/2408.07773 +# Original repo: https://github.com/flixpar/med-ts-llm +# Description: Repurposes a frozen pretrained LLM as a feature +# extractor for medical time series. Raw signals are patched, +# projected into the LLM's embedding space via cross-attention +# (reprogramming layer), and decoded by a lightweight task head. 
+ +import warnings +from typing import Dict, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from pyhealth.datasets import SampleDataset +from pyhealth.models import BaseModel +from pyhealth.models._medtsllm import ( + FlattenHead, + PatchEmbedding, + ReprogrammingLayer, + RevIN, + build_prompt, + encode_prompts, +) + +_SUPPORTED_TASKS = { + "semantic_segmentation", + "segmentation", + "anomaly_detection", + "reconstruction", + "forecasting", + "pretraining", +} + + +class MedTsLLM(BaseModel): + """MedTsLLM: LLM-based medical time series model. + + Pipeline: RevIN -> PatchEmbedding -> Reprogramming -> LLM -> OutputHead + + Repurposes a frozen pretrained LLM (e.g., GPT-2, Qwen2.5) as a + feature extractor for medical time series tasks. The LLM weights + are never updated — only the reprogramming layer, patch embedding, + and output head are trained (~1-2M parameters). + + Paper: Chan, N. et al. "MedTsLLM: Leveraging LLMs for Multimodal + Medical Time Series Analysis." MLHC 2024. + + Note: forward-pass task branching (binary segmentation, anomaly + reconstruction) is not yet wired — only ``semantic_segmentation`` + is fully supported end-to-end. The ``task`` argument currently + drives the default task-description prompt only. + + Args: + dataset: PyHealth SampleDataset with ``signal`` input and + ``label`` output. + task: One of ``"semantic_segmentation"``, ``"segmentation"``, + ``"anomaly_detection"``, ``"reconstruction"``, + ``"forecasting"``, ``"pretraining"``. Drives the default + task-description prompt. Default + ``"semantic_segmentation"``. + seq_len: Input sequence length. Default 512. + n_features: Number of input channels. Default 1. + n_classes: Number of output classes for segmentation. + Default 4. + backbone: HuggingFace model ID for the LLM backbone. + Set to ``None`` to use a lightweight replacement + (for testing without model downloads). Default + ``"openai-community/gpt2"``. 
+ d_model: Patch embedding dimension. Default 32. + d_ff: Feedforward / output head hidden dimension. Default 128. + n_heads: Attention heads in reprogramming layer. Default 8. + num_tokens: Number of word prototype tokens. Default 1024. + patch_len: Length of each patch. Default 16. + stride: Stride between patches. Default 8. + dropout: Dropout probability. Default 0.1. + covariate_mode: ``"univariate"`` or ``"concat"``. Default + ``"univariate"``. + reprogramming_layer: Optional module to replace the default + ``ReprogrammingLayer``. Must accept ``(target, source, + value)`` and return ``(batch, n_patches, d_llm)``. Use + ``LinearProjection`` for the no-reprogramming ablation. + dataset_description: Text description for prompting. + task_description: Text description for prompting. If empty, + a task-appropriate default is generated. + prompt_dataset: Include dataset prompt. Default True. + prompt_task: Include task prompt. Default True. + prompt_patient: Include per-patient description (age, sex, + diagnoses, medications) in prompt. Requires samples to + include a ``description`` field. Default True. + prompt_stats: Include per-sample input statistics + (min/max/median/trend/top-N autocorr lags). Default False + to match the cs598-pyhealth reference's dtp config; the + paper's ``input_stats`` prompt is an optional extra. + n_lags: Number of autocorrelation lags in the stats prompt. + Default 5. + llm_dtype: Torch dtype for LLM weights. Default float32. + word_embeddings: Pre-loaded word embeddings tensor. Required + when ``backbone`` is None. + + Examples: + >>> from pyhealth.models import MedTsLLM + >>> model = MedTsLLM( + ... dataset=sample_dataset, + ... backbone="openai-community/gpt2", + ... seq_len=512, + ... n_classes=4, + ... 
) + """ + + def __init__( + self, + dataset: SampleDataset, + task: str = "semantic_segmentation", + seq_len: int = 512, + n_features: int = 1, + n_classes: int = 4, + backbone: Optional[str] = "openai-community/gpt2", + d_model: int = 32, + d_ff: int = 128, + n_heads: int = 8, + num_tokens: int = 1024, + patch_len: int = 16, + stride: int = 8, + dropout: float = 0.1, + covariate_mode: str = "univariate", + reprogramming_layer: Optional[nn.Module] = None, + dataset_description: str = "", + task_description: str = "", + prompt_dataset: bool = True, + prompt_task: bool = True, + prompt_patient: bool = True, + prompt_stats: bool = False, + n_lags: int = 5, + llm_dtype: torch.dtype = torch.float32, + word_embeddings: Optional[Tensor] = None, + ): + super(MedTsLLM, self).__init__(dataset=dataset) + + if task not in _SUPPORTED_TASKS: + raise ValueError( + f"task must be one of {sorted(_SUPPORTED_TASKS)}, " + f"got {task!r}" + ) + + self.task = task + self.seq_len = seq_len + self.pred_len = seq_len + self.n_features = n_features + self.n_classes = n_classes + self.d_model = d_model + self.d_ff = d_ff + self.covariate_mode = covariate_mode + self.n_lags = n_lags + + # Compute patch count + self.n_patches = (seq_len - patch_len) // stride + 2 + + # Effective d_model for concat covariate mode + d_model_effective = d_model + if covariate_mode == "concat": + d_model_effective = d_model * n_features + + # Setup LLM backbone or replacement + if backbone is not None: + self._setup_llm(backbone, llm_dtype) + elif word_embeddings is not None: + self._setup_replacement(word_embeddings) + else: + raise ValueError( + "Either backbone or word_embeddings must be provided." 
+ ) + + # Trainable layers + self.normalize_layers = RevIN(n_features, affine=False) + self.patch_embedding = PatchEmbedding( + d_model=d_model, + patch_len=patch_len, + stride=stride, + dropout=dropout, + ) + self.mapping_layer = nn.Linear(self.vocab_size, num_tokens) + + if reprogramming_layer is not None: + self.reprogramming_layer = reprogramming_layer + else: + self.reprogramming_layer = ReprogrammingLayer( + d_model=d_model_effective, + n_heads=n_heads, + d_keys=d_ff, + d_llm=self.d_llm, + attention_dropout=dropout, + ) + + self.embedding_downsample = nn.Linear(self.d_llm, d_ff) + + # Output head size depends on task: + # semantic_segmentation => one logit per class per step + # segmentation => a single binary logit per step + # anomaly_detection / + # reconstruction => one value per feature per step + # forecasting / + # pretraining => same as reconstruction + self.n_outputs_per_step = self._compute_n_outputs_per_step( + task, n_classes, n_features + ) + self.output_projection = FlattenHead( + n_features_in=d_ff * self.n_patches, + n_outputs=self.n_outputs_per_step * self.pred_len, + ) + + # Prompting config + self.dataset_description = dataset_description + self.task_description = ( + task_description + or self._default_task_description(task, seq_len, self.pred_len) + ) + self.prompt_config = { + "dataset": prompt_dataset, + "task": prompt_task, + "patient": prompt_patient, + "stats": prompt_stats, + } + + @staticmethod + def _compute_n_outputs_per_step( + task: str, n_classes: int, n_features: int + ) -> int: + """Resolve the per-timestep output dimension from the task.""" + if task == "semantic_segmentation": + return n_classes + if task == "segmentation": + return 1 + if task in ( + "anomaly_detection", + "reconstruction", + "forecasting", + "pretraining", + ): + return n_features + raise ValueError(f"Unsupported task: {task!r}") + + @staticmethod + def _default_task_description( + task: str, seq_len: int, pred_len: int + ) -> str: + """Generate a 
task-appropriate default task description. + + Mirrors the original paper implementation's + ``get_task_description``. + """ + if task in ("forecasting", "pretraining"): + return ( + f"Forecast the next {pred_len} steps given the " + f"previous {seq_len} steps of data." + ) + if task in ("anomaly_detection", "reconstruction"): + return ( + f"Reconstruct the past {seq_len} steps of data as " + "accurately as possible using the following " + "information." + ) + if task == "semantic_segmentation": + return ( + f"Classify the past {seq_len} steps of data as " + "accurately as possible using the following " + "information." + ) + if task == "segmentation": + return ( + f"Identify the change points in the past {seq_len} " + "steps of data to segment the sequence." + ) + return "" + + def parameters(self, recurse: bool = True): + """Yield only trainable parameters. + + Overrides nn.Module.parameters() to exclude the frozen LLM + backbone. This prevents PyHealth's Trainer from allocating + optimizer state (momentum, variance) for frozen parameters, + saving ~2x the frozen param count in memory. + + For GPT-2 (137M) this saves ~1GB. For Qwen2.5-1.5B it + saves ~12GB. + """ + for p in super().parameters(recurse=recurse): + if p.requires_grad: + yield p + + def named_parameters( + self, + prefix: str = "", + recurse: bool = True, + remove_duplicate: bool = True, + ): + """Yield only trainable named parameters. + + See parameters() for rationale. 
+ """ + for name, p in super().named_parameters( + prefix=prefix, + recurse=recurse, + remove_duplicate=remove_duplicate, + ): + if p.requires_grad: + yield name, p + + def _setup_llm(self, backbone: str, dtype: torch.dtype) -> None: + """Load a frozen HuggingFace LLM as backbone.""" + from transformers import AutoConfig, AutoModel, AutoTokenizer + + llm_config = AutoConfig.from_pretrained(backbone) + llm_config.output_hidden_states = True + + self.llm = AutoModel.from_pretrained( + backbone, + config=llm_config, + torch_dtype=dtype, + attn_implementation="sdpa", + ) + self.tokenizer = AutoTokenizer.from_pretrained(backbone) + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + # Freeze LLM + for param in self.llm.parameters(): + param.requires_grad = False + + # Extract word embeddings + we = self.llm.get_input_embeddings().weight.detach().cpu() + if we.shape[0] > 100_000: + inds = torch.linspace( + 0, we.shape[0] - 1, 100_000, dtype=torch.long + ) + we = we[inds].clone() + self.word_embeddings = nn.Parameter(we, requires_grad=False) + self.vocab_size = self.word_embeddings.shape[0] + self.d_llm = self.word_embeddings.shape[1] + self._use_llm = True + + def _setup_replacement(self, word_embeddings: Tensor) -> None: + """Setup a small feedforward network replacing the LLM.""" + self.word_embeddings = nn.Parameter( + word_embeddings.detach().clone(), requires_grad=False + ) + self.vocab_size = word_embeddings.shape[0] + self.d_llm = word_embeddings.shape[1] + self.llm_replacement = nn.Sequential( + nn.Linear(self.d_llm, self.d_llm), + nn.GELU(), + nn.Linear(self.d_llm, self.d_llm), + ) + self.tokenizer = None + self._use_llm = False + + def forward(self, **kwargs) -> Dict[str, Tensor]: + """Forward pass following PyHealth's BaseModel contract. + + Args: + **kwargs: Must include ``signal`` tensor of shape + ``(batch, seq_len)`` or ``(batch, seq_len, n_features)``. 
                May include ``label`` tensor for loss computation and
                ``description`` list of per-sample strings for the
                patient prompt.

        Returns:
            Dict with keys: ``logit``, ``y_prob``, and optionally
            ``loss``, ``y_true``.
        """
        signal = kwargs[self.feature_keys[0]].to(self.device)
        # Promote (batch, seq_len) to (batch, seq_len, 1).
        if signal.ndim == 2:
            signal = signal.unsqueeze(-1)
        bs = signal.shape[0]

        # Encode time series
        enc_out = self._encode_ts(signal)
        # NOTE(review): in the default "univariate" covariate mode,
        # _encode_ts returns a batch dimension of bs * n_features (only
        # "concat" folds features back). The prompt concat and the
        # view(bs, ...) below therefore assume n_features == 1 — confirm
        # before using multivariate inputs in this mode.

        # Build + prepend prompt (only when a real LLM is attached)
        if self._use_llm and any(self.prompt_config.values()):
            prompt_enc = self._build_prompt_embeddings(signal, bs, kwargs)
            enc_out = torch.cat([prompt_enc, enc_out], dim=1)

        # LLM or replacement forward
        if self._use_llm:
            dec_out = self.llm(inputs_embeds=enc_out).last_hidden_state
            dec_out = dec_out.to(device=signal.device, dtype=signal.dtype)
        else:
            dec_out = self.llm_replacement(enc_out)

        # Keep last n_patches outputs (discards the prompt positions).
        dec_out = dec_out[:, -self.n_patches :, :]

        # Downsample and project
        dec_out = self.embedding_downsample(dec_out)
        dec_out = dec_out.permute(0, 2, 1).contiguous()
        dec_out = self.output_projection(dec_out)

        # Reshape to (bs, pred_len, n_outputs_per_step)
        dec_out = dec_out.view(
            bs, self.pred_len, self.n_outputs_per_step
        )

        label_key = self.label_keys[0] if self.label_keys else None
        y_true = (
            kwargs[label_key].to(self.device)
            if label_key and label_key in kwargs
            else None
        )

        return self._task_head(dec_out, signal, y_true)

    def _task_head(
        self,
        dec_out: Tensor,
        signal: Tensor,
        y_true: Optional[Tensor],
    ) -> Dict[str, Tensor]:
        """Apply task-specific post-processing and loss.

        Branches on ``self.task``:

        - ``semantic_segmentation``: softmax probs, cross-entropy loss.
        - ``segmentation``: binary logit per step, BCE-with-logits loss.
        - ``anomaly_detection`` / ``reconstruction``: RevIN-denormalized
          reconstruction in the original signal space, MSE loss against
          the input signal.
        - ``forecasting`` / ``pretraining``: denormalized prediction,
          MSE loss against the input signal (placeholder — true
          forecasting needs a future-signal label).
        """
        output: Dict[str, Tensor] = {}
        if y_true is not None:
            output["y_true"] = y_true

        if self.task == "semantic_segmentation":
            # dec_out: (bs, pred_len, n_classes)
            logit = dec_out
            output["logit"] = logit
            output["y_prob"] = F.softmax(logit, dim=-1)
            if y_true is not None:
                # Flatten timesteps so cross_entropy sees (N, n_classes).
                output["loss"] = F.cross_entropy(
                    logit.view(-1, self.n_classes),
                    y_true.view(-1).long(),
                )
            return output

        if self.task == "segmentation":
            # dec_out: (bs, pred_len, 1) -> (bs, pred_len)
            logit = dec_out.squeeze(-1)
            output["logit"] = logit
            output["y_prob"] = torch.sigmoid(logit)
            if y_true is not None:
                output["loss"] = F.binary_cross_entropy_with_logits(
                    logit, y_true.float()
                )
            return output

        if self.task in (
            "anomaly_detection",
            "reconstruction",
            "forecasting",
            "pretraining",
        ):
            # dec_out: (bs, pred_len, n_features) — denormalize to
            # recover original signal space before computing loss.
            # "denorm" reuses the statistics stored by the "norm" call
            # in _encode_ts earlier in this same forward pass.
            prediction = self.normalize_layers(dec_out, "denorm")
            output["logit"] = prediction
            # For reconstruction-style tasks y_prob is the raw
            # prediction, not a probability.
            output["y_prob"] = prediction
            # Train to reconstruct the input signal. For
            # anomaly_detection, labels (anomaly masks) are used at
            # eval time for scoring — not during training.
            output["loss"] = F.mse_loss(prediction, signal)
            return output

        raise ValueError(f"Unsupported task in forward: {self.task!r}")

    def _build_prompt_embeddings(
        self, signal: Tensor, bs: int, kwargs: Dict
    ) -> Tensor:
        """Construct and encode the prompt for the current batch.

        No caching: prompts are rebuilt every forward pass to match
        the original paper implementation. Dataset/task/stats prompts
        are cheap to re-tokenize; patient prompts depend on batch
        contents and can't be cached anyway.
+ """ + include_patient = self.prompt_config.get("patient", False) + bos = getattr(self.tokenizer, "bos_token", None) or "" + + prompt_inputs: Dict = {"x_enc": signal} + + if include_patient: + description = kwargs.get("description") + if description is None: + warnings.warn( + "prompt_patient=True but no 'description' field " + "provided in batch. Patient prompt will be empty. " + "Ensure your task emits 'description' per sample.", + RuntimeWarning, + stacklevel=3, + ) + descriptions = [""] * bs + else: + descriptions = _coerce_descriptions(description, bs) + prompt_inputs["descriptions"] = descriptions + + prompts = build_prompt( + prompt_inputs, + dataset_description=self.dataset_description, + task_description=self.task_description, + include_dataset=self.prompt_config.get("dataset", False), + include_task=self.prompt_config.get("task", False), + include_clip=include_patient, + include_stats=self.prompt_config.get("stats", False), + n_lags=self.n_lags, + bos_token=bos, + ) + + with torch.no_grad(): + return encode_prompts( + prompts, + self.tokenizer, + self.llm.get_input_embeddings(), + signal.device, + ) + + def _encode_ts(self, x_enc: Tensor) -> Tensor: + """Encode time series: normalize -> patch -> reprogram. + + Args: + x_enc: (batch, seq_len, n_features). + + Returns: + Encoded representation: (batch, n_patches, d_llm). 
+ """ + bs, seq_len, n_features = x_enc.shape + + x_enc = self.normalize_layers(x_enc, "norm") + x_enc = x_enc.permute(0, 2, 1).contiguous() + + enc_out, _ = self.patch_embedding(x_enc) + + # Project word embeddings to prototypes + we = self.word_embeddings.to(self.mapping_layer.weight.dtype) + source_embeddings = self.mapping_layer( + we.permute(1, 0) + ).permute(1, 0) + + # Handle covariate modes + if self.covariate_mode == "concat": + enc_out = enc_out.reshape( + bs, n_features, self.n_patches, self.d_model + ) + enc_out = enc_out.permute(0, 2, 1, 3) + enc_out = enc_out.reshape( + bs, self.n_patches, n_features * self.d_model + ) + + enc_out = self.reprogramming_layer( + enc_out, source_embeddings, source_embeddings + ) + + return enc_out + + +def _coerce_descriptions( + description, bs: int +) -> list[str]: + """Normalize a batch description field into a list[str] of length bs. + + Handles the common shapes produced by PyHealth's default collate + (a list of strings) as well as scalar-string edge cases. 
+ """ + if isinstance(description, str): + return [description] * bs + if isinstance(description, list): + return description + try: + return list(description) + except TypeError: + return [description] * bs diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index a32618f9c..fd3682a27 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -22,6 +22,10 @@ drug_recommendation_mimic4_fn, drug_recommendation_omop_fn, ) +from .ecg_anomaly_detection import ECGAnomalyDetection +from .ecg_boundary_detection import ECGBoundaryDetection +from .ecg_wave_segmentation import ECGWaveSegmentation +from .respiratory_boundary_detection import RespiratoryBoundaryDetection from .in_hospital_mortality_mimic4 import InHospitalMortalityMIMIC4 from .length_of_stay_prediction import ( LengthOfStayPredictioneICU, diff --git a/pyhealth/tasks/ecg_anomaly_detection.py b/pyhealth/tasks/ecg_anomaly_detection.py new file mode 100644 index 000000000..52e1e3163 --- /dev/null +++ b/pyhealth/tasks/ecg_anomaly_detection.py @@ -0,0 +1,152 @@ +# Author: Anton Barchukov +# Paper: Chan et al., "MedTsLLM: Leveraging LLMs for Multimodal +# Medical Time Series Analysis", MLHC 2024 +# Description: Beat-level anomaly detection task for the MIT-BIH +# Arrhythmia Database. Each timestep is labeled 1 if it falls +# within a beat interval annotated as an abnormal rhythm type, +# and 0 otherwise (including non-beat gaps). Training is +# reconstruction-style: MedTsLLM (task="anomaly_detection") +# learns to reconstruct the ECG and anomalous beats are flagged +# at eval time by elevated reconstruction error. + +from typing import Any, Dict + +import numpy as np + +from pyhealth.tasks import BaseTask +from pyhealth.tasks.ecg_boundary_detection import ( + _load_signal_and_annotations, +) + +# Rhythm annotation symbols considered "normal" — everything else +# is an abnormal beat in the MedTsLLM paper's setup. 
# Author: Anton Barchukov
# Paper: Chan et al., "MedTsLLM: Leveraging LLMs for Multimodal
#        Medical Time Series Analysis", MLHC 2024
# Description: ECG tasks for the MedTsLLM benchmark — beat-level
#   anomaly detection and R-peak boundary detection on the MIT-BIH
#   Arrhythmia Database, and per-timestep P/QRS/T wave segmentation
#   on LUDB. Signal decoding, downsampling, and optional trimming
#   are handled at the dataset level; these tasks only label and
#   window the prepared signals.

import os
from typing import Any, Dict

import numpy as np

from pyhealth.tasks import BaseTask

# Rhythm annotation symbols considered "normal" — every other beat
# symbol counts as abnormal in the MedTsLLM paper's setup.
_NORMAL_BEATS = {"N", "L", "R", "e", "j"}

# wfdb annotation symbol -> wave segmentation class index
# (0 is background).
_WAVE_LABELS = {"p": 1, "N": 2, "t": 3}


def _build_description(
    event, attrs: tuple = ("age", "sex", "medications")
) -> str:
    """Compose a per-patient description from event demographics.

    Joins ``"attr: value"`` pairs for each attribute in ``attrs``.
    Missing attributes, ``None`` values, empty strings, and pandas
    NaN leftovers (the string ``"nan"``, produced when pandas reads
    an empty cell) are all skipped so they never leak into the text
    fed to the model. Shared by all three tasks so the filtering
    behavior stays consistent; the wave-segmentation task passes
    LUDB's ``("age", "sex", "diagnoses")`` attribute set.
    """
    parts: list[str] = []
    for attr in attrs:
        value = getattr(event, attr, "")
        if value is None:
            continue
        text = str(value).strip()
        if not text or text.lower() == "nan":
            continue
        parts.append(f"{attr}: {text}")
    return ", ".join(parts)


def _load_signal_and_annotations(
    event,
    default_downsample: int = 3,
) -> tuple[np.ndarray, np.ndarray, np.ndarray] | None:
    """Return ``(signal, ann_sample, ann_symbol)`` from cache or wfdb.

    The ``.npz`` cache written by ``MITBIHDataset(preprocess=True)``
    is preferred. The wfdb fallback downsamples by
    ``default_downsample`` but does not trim — trim only applies to
    cached arrays built via ``MITBIHDataset(preprocess=True,
    trim=...)``. Returns ``None`` when the record files are missing.
    """
    processed_file = getattr(event, "processed_file", "") or ""
    if processed_file and os.path.exists(processed_file):
        with np.load(processed_file, allow_pickle=False) as npz:
            return (
                np.asarray(npz["signal"], dtype=np.float32),
                np.asarray(npz["ann_sample"], dtype=np.int64),
                np.asarray(npz["ann_symbol"]),
            )

    # Lazy import: wfdb is only needed when the cache is absent.
    import wfdb

    try:
        record = wfdb.rdrecord(event.signal_file)
        ann = wfdb.rdann(event.signal_file, extension=event.annotation_file)
    except FileNotFoundError:
        return None

    signal = record.p_signal.astype(np.float32)
    if default_downsample > 1:
        signal = signal[::default_downsample]
    # Map annotation sample indices onto the downsampled grid and
    # drop any that fall outside the (possibly shortened) signal.
    ds_samples = np.asarray(ann.sample) // default_downsample
    ann_symbols = np.asarray(ann.symbol)
    in_bounds = (ds_samples >= 0) & (ds_samples < len(signal))
    return (
        signal,
        ds_samples[in_bounds].astype(np.int64),
        ann_symbols[in_bounds],
    )


def _build_anomaly_mask(
    ann_sample: np.ndarray,
    ann_symbol: np.ndarray,
    signal_len: int,
) -> np.ndarray:
    """Label each timestep 1 if inside an abnormal beat interval.

    Rhythm-change markers (``"+"``) are ignored. For beat ``i`` with
    sample ``s_i``, the interval ``[s_i, s_{i+1})`` is marked 1 when
    ``symbol_i`` is an abnormal beat type. The last beat extends to
    ``signal_len``.
    """
    labels = np.zeros(signal_len, dtype=np.int64)
    # Filter out rhythm-change markers to match the paper's labeling.
    beat_mask = np.array(
        [str(s) != "+" for s in ann_symbol], dtype=bool
    )
    beats_sample = np.asarray(ann_sample, dtype=np.int64)[beat_mask]
    beats_symbol = np.asarray(ann_symbol)[beat_mask]

    for i, symbol in enumerate(beats_symbol):
        if str(symbol) in _NORMAL_BEATS:
            continue
        start = max(0, int(beats_sample[i]))
        if i + 1 < len(beats_sample):
            end = int(beats_sample[i + 1])
        else:
            end = signal_len
        end = min(end, signal_len)
        if start < end:
            labels[start:end] = 1

    return labels


def _load_signal_and_labels(event) -> tuple[np.ndarray, np.ndarray] | None:
    """Return ``(signal, labels)`` for one LUDB event, preferring the cache.

    The ``.npz`` cache stores the decoded lead and its per-timestep
    wave labels. The wfdb fallback decodes the record, selects
    ``event.lead``, and parses the LUDB ``( sym )`` triplet encoding
    into labels. Returns ``None`` when files or the lead are missing.
    """
    processed_file = getattr(event, "processed_file", "") or ""
    if processed_file and os.path.exists(processed_file):
        with np.load(processed_file, allow_pickle=False) as npz:
            return (
                np.asarray(npz["signal"], dtype=np.float32),
                np.asarray(npz["labels"], dtype=np.int64),
            )

    # Fallback: decode wfdb on demand.
    import wfdb

    try:
        record = wfdb.rdrecord(event.signal_file)
    except FileNotFoundError:
        return None
    try:
        lead_idx = record.sig_name.index(event.lead)
    except ValueError:
        return None
    signal = record.p_signal[:, lead_idx].astype(np.float32)

    try:
        ann = wfdb.rdann(event.signal_file, extension=event.label_file)
    except FileNotFoundError:
        return None

    # LUDB marks each wave as '(' onset, wave-type peak, ')' offset.
    labels = np.zeros(len(signal), dtype=np.int64)
    i = 0
    while i < len(ann.symbol):
        sym = ann.symbol[i]
        if sym == "(" and i + 2 < len(ann.symbol):
            wave_type = ann.symbol[i + 1]
            onset = ann.sample[i]
            offset = (
                ann.sample[i + 2]
                if ann.symbol[i + 2] == ")"
                else ann.sample[i + 1]
            )
            if wave_type in _WAVE_LABELS:
                labels[onset : offset + 1] = _WAVE_LABELS[wave_type]
            i += 3
        else:
            i += 1

    return signal, labels


class ECGAnomalyDetection(BaseTask):
    """Beat-level arrhythmia anomaly detection on MIT-BIH ECG.

    Produces per-timestep binary labels over a downsampled 2-channel
    ECG. A timestep is labeled ``1`` when it falls inside the
    interval of an abnormal-type beat annotation (all symbols other
    than ``{"N", "L", "R", "e", "j"}`` and the rhythm-change marker
    ``"+"``) and ``0`` otherwise.

    Args:
        window_size: Number of time points per window. Default 128.
        step_size: Stride between consecutive windows. Default 128.

    Attributes:
        task_name (str): ``"ECGAnomalyDetection"``
        input_schema (Dict[str, str]): ``{"signal": "tensor"}``
        output_schema (Dict[str, str]): ``{"label": "tensor"}``

    Examples:
        >>> from pyhealth.datasets import MITBIHDataset
        >>> from pyhealth.tasks import ECGAnomalyDetection
        >>> dataset = MITBIHDataset(
        ...     root="/path/to/mitdb/",
        ...     preprocess=True,
        ...     paper_split="abnormal_sorted",
        ... )
        >>> sample_ds = dataset.set_task(ECGAnomalyDetection())
    """

    task_name: str = "ECGAnomalyDetection"
    input_schema: Dict[str, str] = {"signal": "tensor"}
    output_schema: Dict[str, str] = {"label": "tensor"}

    def __init__(
        self,
        window_size: int = 128,
        step_size: int = 128,
    ):
        self.window_size = window_size
        self.step_size = step_size
        super().__init__()

    def __call__(self, patient: Any) -> list[dict[str, Any]]:
        """Process a single patient into windowed samples."""
        pid = patient.patient_id
        events = patient.get_events()

        samples = []
        for event in events:
            result = _load_signal_and_annotations(event)
            if result is None:
                continue
            signal, ann_sample, ann_symbol = result

            labels = _build_anomaly_mask(
                ann_sample, ann_symbol, len(signal)
            )

            description = _build_description(event)
            split = getattr(event, "split", "") or ""

            for start in range(
                0, len(signal) - self.window_size + 1, self.step_size
            ):
                end = start + self.window_size
                samples.append({
                    "patient_id": pid,
                    "signal": signal[start:end],
                    "label": labels[start:end].astype(np.float32),
                    "description": description,
                    "split": split,
                })

        return samples


class ECGBoundaryDetection(BaseTask):
    """R-peak boundary detection on MIT-BIH ECG signals.

    Binary classification of each time point as a beat boundary
    (R-peak) or not. Signal decoding, downsampling, and optional
    trimming are handled at the dataset level via
    :class:`MITBIHDataset`'s ``preprocess``, ``downsample_factor``,
    and ``trim`` kwargs.

    Args:
        window_size: Number of time points per window. Default 256.
        step_size: Stride between consecutive windows. Default 256.

    Attributes:
        task_name (str): ``"ECGBoundaryDetection"``
        input_schema (Dict[str, str]): ``{"signal": "tensor"}``
        output_schema (Dict[str, str]): ``{"label": "tensor"}``

    Examples:
        >>> from pyhealth.datasets import MITBIHDataset
        >>> dataset = MITBIHDataset(
        ...     root="/path/to/mitdb/", preprocess=True
        ... )
        >>> sample_dataset = dataset.set_task(ECGBoundaryDetection())
    """

    task_name: str = "ECGBoundaryDetection"
    input_schema: Dict[str, str] = {"signal": "tensor"}
    output_schema: Dict[str, str] = {"label": "tensor"}

    def __init__(
        self,
        window_size: int = 256,
        step_size: int = 256,
    ):
        self.window_size = window_size
        self.step_size = step_size
        super().__init__()

    def __call__(self, patient: Any) -> list[dict[str, Any]]:
        """Process a single patient into windowed samples.

        Args:
            patient: A Patient object from MITBIHDataset.

        Returns:
            List of sample dicts with signal and binary label arrays.
        """
        pid = patient.patient_id
        events = patient.get_events()

        samples = []
        for event in events:
            result = _load_signal_and_annotations(event)
            if result is None:
                continue
            signal, ann_sample, _ = result

            # Build binary R-peak mask: 1 exactly at annotated samples.
            labels = np.zeros(len(signal), dtype=np.int64)
            for s in ann_sample:
                idx = int(s)
                if 0 <= idx < len(signal):
                    labels[idx] = 1

            description = _build_description(event)
            split = getattr(event, "split", "") or ""

            for start in range(
                0, len(signal) - self.window_size + 1, self.step_size
            ):
                end = start + self.window_size
                samples.append({
                    "patient_id": pid,
                    "signal": signal[start:end],
                    "label": labels[start:end].astype(np.float32),
                    "description": description,
                    "split": split,
                })

        return samples


class ECGWaveSegmentation(BaseTask):
    """Per-timestep ECG wave segmentation on LUDB.

    Classifies each time point in a 12-lead ECG as one of:
    background (0), P wave (1), QRS complex (2), or T wave (3).

    The raw signal is windowed into fixed-length chunks. Each chunk
    produces a sample with a signal tensor and a per-timestep label
    array of the same length. Trim/decode are controlled at the
    dataset level via :class:`LUDBDataset`'s ``preprocess`` and
    ``trim`` kwargs; this task only handles windowing and emission.

    Args:
        window_size: Number of time points per window. Default 512.
        step_size: Stride between consecutive windows. Default 256.

    Attributes:
        task_name (str): ``"ECGWaveSegmentation"``
        input_schema (Dict[str, str]): ``{"signal": "tensor"}``
        output_schema (Dict[str, str]): ``{"label": "tensor"}``

    Examples:
        >>> from pyhealth.datasets import LUDBDataset
        >>> dataset = LUDBDataset(root="/path/to/ludb/", preprocess=True)
        >>> sample_dataset = dataset.set_task(ECGWaveSegmentation())
        >>> sample_dataset.samples[0].keys()
        dict_keys(['patient_id', 'lead', 'signal', 'label', ...])
    """

    task_name: str = "ECGWaveSegmentation"
    input_schema: Dict[str, str] = {"signal": "tensor"}
    output_schema: Dict[str, str] = {"label": "tensor"}

    def __init__(
        self,
        window_size: int = 512,
        step_size: int = 256,
    ):
        self.window_size = window_size
        self.step_size = step_size
        super().__init__()

    def __call__(self, patient: Any) -> list[dict[str, Any]]:
        """Process a single patient into windowed samples.

        Args:
            patient: A Patient object from LUDBDataset. Each event
                represents one ECG lead with ``signal_file`` (wfdb
                record path), ``label_file`` (annotation extension),
                and optionally ``processed_file`` (cached ``.npz``).

        Returns:
            List of sample dicts, each containing:
                - ``patient_id``: str
                - ``lead``: str, ECG lead name
                - ``signal``: np.ndarray, shape ``(window_size,)``
                - ``label``: np.ndarray of int, shape ``(window_size,)``
        """
        pid = patient.patient_id
        events = patient.get_events()

        samples = []
        for event in events:
            result = _load_signal_and_labels(event)
            if result is None:
                continue
            signal, labels = result

            # LUDB events carry diagnoses rather than medications.
            description = _build_description(
                event, ("age", "sex", "diagnoses")
            )
            split = getattr(event, "split", "") or ""
            lead = event.lead

            # Window into fixed-length chunks
            for start in range(
                0, len(signal) - self.window_size + 1, self.step_size
            ):
                end = start + self.window_size
                samples.append({
                    "patient_id": pid,
                    "lead": lead,
                    "signal": signal[start:end],
                    "label": labels[start:end],
                    "description": description,
                    "split": split,
                })

        return samples
# Author: Anton Barchukov
# Paper: Chan et al., "MedTsLLM: Leveraging LLMs for Multimodal
#        Medical Time Series Analysis", MLHC 2024
# Description: Breath boundary detection task for the BIDMC dataset.
#   Detects breath boundaries in respiratory impedance signals.

import os
from typing import Any, Dict

import numpy as np

from pyhealth.tasks import BaseTask

# RESP, PLETH, and ECG lead II — the 3 channels used in the paper.
# BIDMC's real headers carry a trailing comma in each signal name,
# so the comma-suffixed form is matched here.
_TARGET_CHANNELS = ["RESP,", "PLETH,", "II,"]


def _build_description(event) -> str:
    """Compose a per-patient description from event demographics.

    Skips missing attributes, ``None``, empty strings, and pandas
    NaN leftovers (the string ``"nan"``), consistent with the other
    MedTsLLM tasks. Uses ``getattr`` so events lacking a field do
    not raise AttributeError.
    """
    parts: list[str] = []
    for attr in ("age", "sex", "location"):
        value = getattr(event, attr, "")
        if value is None:
            continue
        text = str(value).strip()
        if not text or text.lower() == "nan":
            continue
        parts.append(f"{attr}: {text}")
    return ", ".join(parts)


class RespiratoryBoundaryDetection(BaseTask):
    """Breath boundary detection on BIDMC respiratory signals.

    Binary classification of each time point as a breath boundary
    or not. The model is trained on 3 channels (RESP, PLETH, II)
    and predicts boundary locations.

    Args:
        window_size: Number of time points per window. Default 256.
        step_size: Stride between consecutive windows. Default 128.
        annotator: Which annotator's labels to use (1 or 2).
            Default 1.

    Raises:
        ValueError: If ``annotator`` is not 1 or 2 (any other value
            would silently match no annotations and produce all-zero
            labels).

    Attributes:
        task_name (str): ``"RespiratoryBoundaryDetection"``
        input_schema (Dict[str, str]): ``{"signal": "tensor"}``
        output_schema (Dict[str, str]): ``{"label": "tensor"}``

    Examples:
        >>> from pyhealth.datasets import BIDMCDataset
        >>> dataset = BIDMCDataset(root="/path/to/bidmc/")
        >>> sample_dataset = dataset.set_task(
        ...     RespiratoryBoundaryDetection()
        ... )
    """

    task_name: str = "RespiratoryBoundaryDetection"
    input_schema: Dict[str, str] = {"signal": "tensor"}
    output_schema: Dict[str, str] = {"label": "tensor"}

    def __init__(
        self,
        window_size: int = 256,
        step_size: int = 128,
        annotator: int = 1,
    ):
        if annotator not in (1, 2):
            raise ValueError(
                f"annotator must be 1 or 2, got {annotator!r}"
            )
        self.window_size = window_size
        self.step_size = step_size
        self.annotator = annotator
        super().__init__()

    def __call__(self, patient: Any) -> list[dict[str, Any]]:
        """Process a single patient into windowed samples.

        Args:
            patient: A Patient object from BIDMCDataset.

        Returns:
            List of sample dicts with signal and binary label arrays.
        """
        pid = patient.patient_id
        events = patient.get_events()

        samples = []
        for event in events:
            result = _load_signal_and_annotations(event)
            if result is None:
                continue
            signal, ann_sample, ann_aux = result
            # Require all 3 target channels; skip incomplete records.
            if signal.shape[1] != 3:
                continue

            # Mark 1 at each in-bounds breath annotated by the
            # selected annotator (aux notes are "ann1" / "ann2").
            labels = np.zeros(len(signal), dtype=np.int64)
            ann_tag = f"ann{self.annotator}"
            for s, aux in zip(ann_sample, ann_aux):
                if str(aux) == ann_tag and 0 <= int(s) < len(signal):
                    labels[int(s)] = 1

            description = _build_description(event)
            split = getattr(event, "split", "") or ""

            # Window into fixed-length chunks
            for start in range(
                0, len(signal) - self.window_size + 1, self.step_size
            ):
                end = start + self.window_size
                samples.append({
                    "patient_id": pid,
                    "signal": signal[start:end],
                    "label": labels[start:end].astype(np.float32),
                    "description": description,
                    "split": split,
                })

        return samples


def _load_signal_and_annotations(
    event,
) -> tuple[np.ndarray, np.ndarray, np.ndarray] | None:
    """Return ``(signal, ann_sample, ann_aux)`` from ``.npz`` cache or wfdb.

    The cache written by ``BIDMCDataset(preprocess=True)`` is
    preferred; otherwise the wfdb record is decoded and the target
    channels selected. Returns ``None`` when files are missing.
    """
    processed_file = getattr(event, "processed_file", "") or ""
    if processed_file and os.path.exists(processed_file):
        with np.load(processed_file, allow_pickle=False) as npz:
            return (
                np.asarray(npz["signal"], dtype=np.float32),
                np.asarray(npz["ann_sample"], dtype=np.int64),
                np.asarray(npz["ann_aux"]),
            )

    # Lazy import: wfdb is only needed when the cache is absent.
    import wfdb

    try:
        record = wfdb.rdrecord(event.signal_file)
        ann = wfdb.rdann(event.signal_file, extension=event.annotation_file)
    except FileNotFoundError:
        return None

    col_idx = [
        record.sig_name.index(ch)
        for ch in _TARGET_CHANNELS
        if ch in record.sig_name
    ]
    signal = record.p_signal[:, col_idx].astype(np.float32)
    return (
        signal,
        np.asarray(ann.sample, dtype=np.int64),
        np.asarray(ann.aux_note),
    )
125 60001 +bidmc02.dat 16 57636.965(126)/pm 16 0 297 37315 0 RESP, +bidmc02.dat 16 89899.484(-105)/NU 16 0 -1492 16716 0 PLETH, +bidmc02.dat 16 68559.59(920)/mV 16 0 -187 58542 0 V, +bidmc02.dat 16 69553.68(-494)/mV 16 0 199 253 0 AVR, +bidmc02.dat 16 68740.055(84)/mV 16 0 99 24415 0 II, +# : 58 +# : F diff --git a/test-resources/core/ludb/data/1.avf b/test-resources/core/ludb/data/1.avf new file mode 100644 index 000000000..9c3a63c8c Binary files /dev/null and b/test-resources/core/ludb/data/1.avf differ diff --git a/test-resources/core/ludb/data/1.avl b/test-resources/core/ludb/data/1.avl new file mode 100644 index 000000000..9c3a63c8c Binary files /dev/null and b/test-resources/core/ludb/data/1.avl differ diff --git a/test-resources/core/ludb/data/1.avr b/test-resources/core/ludb/data/1.avr new file mode 100644 index 000000000..9c3a63c8c Binary files /dev/null and b/test-resources/core/ludb/data/1.avr differ diff --git a/test-resources/core/ludb/data/1.dat b/test-resources/core/ludb/data/1.dat new file mode 100644 index 000000000..958126295 Binary files /dev/null and b/test-resources/core/ludb/data/1.dat differ diff --git a/test-resources/core/ludb/data/1.hea b/test-resources/core/ludb/data/1.hea new file mode 100644 index 000000000..dff2c9da4 --- /dev/null +++ b/test-resources/core/ludb/data/1.hea @@ -0,0 +1,19 @@ +1 12 500 5000 +1.dat 16 62239.465(-23190)/mV 16 0 -23033 33766 0 i +1.dat 16 55949.375(-23524)/mV 16 0 -23725 5480 0 ii +1.dat 16 51311.14(-24306)/mV 16 0 -23804 45137 0 iii +1.dat 16 47775.176(-25170)/mV 16 0 -25008 18423 0 avr +1.dat 16 44748.46(-25615)/mV 16 0 -25325 28490 0 avl +1.dat 16 41369.31(-26581)/mV 16 0 -26117 20648 0 avf +1.dat 16 39463.5(-26298)/mV 16 0 -27171 54031 0 v1 +1.dat 16 37131.64(-27143)/mV 16 0 -26799 25262 0 v2 +1.dat 16 35097.883(-27041)/mV 16 0 -26918 37315 0 v3 +1.dat 16 33396.85(-27664)/mV 16 0 -27398 3336 0 v4 +1.dat 16 31306.604(-27419)/mV 16 0 -27598 13770 0 v5 +1.dat 16 30131.424(-27905)/mV 16 0 -27941 56298 0 v6 +# 
: 40 +# : M +# : +# Rhythm: Sinus rhythm. +# Left ventricular hypertrophy. +# Non-specific repolarization abnormalities. diff --git a/test-resources/core/ludb/data/1.i b/test-resources/core/ludb/data/1.i new file mode 100644 index 000000000..9c3a63c8c Binary files /dev/null and b/test-resources/core/ludb/data/1.i differ diff --git a/test-resources/core/ludb/data/1.ii b/test-resources/core/ludb/data/1.ii new file mode 100644 index 000000000..9c3a63c8c Binary files /dev/null and b/test-resources/core/ludb/data/1.ii differ diff --git a/test-resources/core/ludb/data/1.iii b/test-resources/core/ludb/data/1.iii new file mode 100644 index 000000000..9c3a63c8c Binary files /dev/null and b/test-resources/core/ludb/data/1.iii differ diff --git a/test-resources/core/ludb/data/1.v1 b/test-resources/core/ludb/data/1.v1 new file mode 100644 index 000000000..9c3a63c8c Binary files /dev/null and b/test-resources/core/ludb/data/1.v1 differ diff --git a/test-resources/core/ludb/data/1.v2 b/test-resources/core/ludb/data/1.v2 new file mode 100644 index 000000000..9c3a63c8c Binary files /dev/null and b/test-resources/core/ludb/data/1.v2 differ diff --git a/test-resources/core/ludb/data/1.v3 b/test-resources/core/ludb/data/1.v3 new file mode 100644 index 000000000..9c3a63c8c Binary files /dev/null and b/test-resources/core/ludb/data/1.v3 differ diff --git a/test-resources/core/ludb/data/1.v4 b/test-resources/core/ludb/data/1.v4 new file mode 100644 index 000000000..9c3a63c8c Binary files /dev/null and b/test-resources/core/ludb/data/1.v4 differ diff --git a/test-resources/core/ludb/data/1.v5 b/test-resources/core/ludb/data/1.v5 new file mode 100644 index 000000000..9c3a63c8c Binary files /dev/null and b/test-resources/core/ludb/data/1.v5 differ diff --git a/test-resources/core/ludb/data/1.v6 b/test-resources/core/ludb/data/1.v6 new file mode 100644 index 000000000..9c3a63c8c Binary files /dev/null and b/test-resources/core/ludb/data/1.v6 differ diff --git 
a/test-resources/core/ludb/data/2.avf b/test-resources/core/ludb/data/2.avf new file mode 100644 index 000000000..9c3a63c8c Binary files /dev/null and b/test-resources/core/ludb/data/2.avf differ diff --git a/test-resources/core/ludb/data/2.avl b/test-resources/core/ludb/data/2.avl new file mode 100644 index 000000000..9c3a63c8c Binary files /dev/null and b/test-resources/core/ludb/data/2.avl differ diff --git a/test-resources/core/ludb/data/2.avr b/test-resources/core/ludb/data/2.avr new file mode 100644 index 000000000..9c3a63c8c Binary files /dev/null and b/test-resources/core/ludb/data/2.avr differ diff --git a/test-resources/core/ludb/data/2.dat b/test-resources/core/ludb/data/2.dat new file mode 100644 index 000000000..a7251e472 Binary files /dev/null and b/test-resources/core/ludb/data/2.dat differ diff --git a/test-resources/core/ludb/data/2.hea b/test-resources/core/ludb/data/2.hea new file mode 100644 index 000000000..cfa3a04c5 --- /dev/null +++ b/test-resources/core/ludb/data/2.hea @@ -0,0 +1,19 @@ +2 12 500 5000 +2.dat 16 61262.926(-23165)/mV 16 0 -24100 9855 0 i +2.dat 16 56641.953(-23737)/mV 16 0 -24570 45290 0 ii +2.dat 16 51504.383(-24241)/mV 16 0 -25406 50166 0 iii +2.dat 16 48346.41(-25206)/mV 16 0 -24033 57309 0 avr +2.dat 16 44399.043(-25462)/mV 16 0 -25828 53239 0 avl +2.dat 16 41577.14(-25641)/mV 16 0 -25734 21095 0 avf +2.dat 16 38740.266(-26479)/mV 16 0 -27922 18472 0 v1 +2.dat 16 37059.348(-26572)/mV 16 0 -26685 21959 0 v2 +2.dat 16 35257.223(-27349)/mV 16 0 -26520 61948 0 v3 +2.dat 16 33735.87(-27532)/mV 16 0 -28077 48552 0 v4 +2.dat 16 31750.809(-27608)/mV 16 0 -28287 56113 0 v5 +2.dat 16 30051.408(-28146)/mV 16 0 -27743 12377 0 v6 +# : 45 +# : F +# : +# Rhythm: Sinus rhythm. +# Left ventricular hypertrophy. +# Non-specific repolarization abnormalities. 
diff --git a/test-resources/core/ludb/data/2.i b/test-resources/core/ludb/data/2.i new file mode 100644 index 000000000..9c3a63c8c Binary files /dev/null and b/test-resources/core/ludb/data/2.i differ diff --git a/test-resources/core/ludb/data/2.ii b/test-resources/core/ludb/data/2.ii new file mode 100644 index 000000000..9c3a63c8c Binary files /dev/null and b/test-resources/core/ludb/data/2.ii differ diff --git a/test-resources/core/ludb/data/2.iii b/test-resources/core/ludb/data/2.iii new file mode 100644 index 000000000..9c3a63c8c Binary files /dev/null and b/test-resources/core/ludb/data/2.iii differ diff --git a/test-resources/core/ludb/data/2.v1 b/test-resources/core/ludb/data/2.v1 new file mode 100644 index 000000000..9c3a63c8c Binary files /dev/null and b/test-resources/core/ludb/data/2.v1 differ diff --git a/test-resources/core/ludb/data/2.v2 b/test-resources/core/ludb/data/2.v2 new file mode 100644 index 000000000..9c3a63c8c Binary files /dev/null and b/test-resources/core/ludb/data/2.v2 differ diff --git a/test-resources/core/ludb/data/2.v3 b/test-resources/core/ludb/data/2.v3 new file mode 100644 index 000000000..9c3a63c8c Binary files /dev/null and b/test-resources/core/ludb/data/2.v3 differ diff --git a/test-resources/core/ludb/data/2.v4 b/test-resources/core/ludb/data/2.v4 new file mode 100644 index 000000000..9c3a63c8c Binary files /dev/null and b/test-resources/core/ludb/data/2.v4 differ diff --git a/test-resources/core/ludb/data/2.v5 b/test-resources/core/ludb/data/2.v5 new file mode 100644 index 000000000..9c3a63c8c Binary files /dev/null and b/test-resources/core/ludb/data/2.v5 differ diff --git a/test-resources/core/ludb/data/2.v6 b/test-resources/core/ludb/data/2.v6 new file mode 100644 index 000000000..9c3a63c8c Binary files /dev/null and b/test-resources/core/ludb/data/2.v6 differ diff --git a/test-resources/core/mitbih/100.atr b/test-resources/core/mitbih/100.atr new file mode 100644 index 000000000..2a63ba95e Binary files /dev/null and 
b/test-resources/core/mitbih/100.atr differ diff --git a/test-resources/core/mitbih/100.dat b/test-resources/core/mitbih/100.dat new file mode 100644 index 000000000..929240569 Binary files /dev/null and b/test-resources/core/mitbih/100.dat differ diff --git a/test-resources/core/mitbih/100.hea b/test-resources/core/mitbih/100.hea new file mode 100644 index 000000000..406294f87 --- /dev/null +++ b/test-resources/core/mitbih/100.hea @@ -0,0 +1,4 @@ +100 2 360 65000 +100.dat 16 47871.453(-386)/mV 16 0 -266 25690 0 MLII +100.dat 16 57050.91(-246)/mV 16 0 8033 32572 0 V5 +# # 60 M diff --git a/test-resources/core/mitbih/101.atr b/test-resources/core/mitbih/101.atr new file mode 100644 index 000000000..2a63ba95e Binary files /dev/null and b/test-resources/core/mitbih/101.atr differ diff --git a/test-resources/core/mitbih/101.dat b/test-resources/core/mitbih/101.dat new file mode 100644 index 000000000..0cc0b8934 Binary files /dev/null and b/test-resources/core/mitbih/101.dat differ diff --git a/test-resources/core/mitbih/101.hea b/test-resources/core/mitbih/101.hea new file mode 100644 index 000000000..dc06edc52 --- /dev/null +++ b/test-resources/core/mitbih/101.hea @@ -0,0 +1,4 @@ +101 2 360 65000 +101.dat 16 48609.28(109)/mV 16 0 350 28443 0 MLII +101.dat 16 57637.297(92)/mV 16 0 9894 54587 0 V5 +# # 63 F diff --git a/tests/core/_synthetic_wfdb.py b/tests/core/_synthetic_wfdb.py new file mode 100644 index 000000000..d078d5d2f --- /dev/null +++ b/tests/core/_synthetic_wfdb.py @@ -0,0 +1,319 @@ +"""Synthetic wfdb record generators for ECG/respiratory dataset tests. + +Generates minimal in-repo wfdb records for LUDB, MIT-BIH, and BIDMC so +tests never ship real patient data. Each generator writes the same +files the dataset loader expects (``.dat``, ``.hea``, per-lead/``.atr`` +/``.breath`` annotations), with seeded RNG for reproducibility. 
import os
from typing import Iterable

import numpy as np


# --------------------------------------------------------------------- #
# LUDB: 12-lead ECG at 500 Hz, 10 s per record, wave annotations per lead
# --------------------------------------------------------------------- #

_LUDB_LEADS = ("i", "ii", "iii", "avr", "avl", "avf",
               "v1", "v2", "v3", "v4", "v5", "v6")
_LUDB_FS = 500
_LUDB_LEN = 5000  # 10 s at 500 Hz
_LUDB_DIAGNOSES = (
    "Rhythm: Sinus rhythm.",
    "Left ventricular hypertrophy.",
    "Non-specific repolarization abnormalities.",
)


def synthesize_ludb(
    dest_root: str,
    n_records: int = 2,
    seed: int = 0,
) -> None:
    """Write ``n_records`` synthetic LUDB records into ``{dest_root}/data/``.

    Each record has:
      * 12-lead signal with a weak sinusoidal "ECG" shape + noise
      * Header comments with ``<age>``, ``<sex>``, ``<diagnoses>``
      * Per-lead wave annotation file (``.i``, ``.ii``, ...) with evenly
        spaced P / N (QRS) / T wave triplets via the LUDB ``( sym )``
        encoding.
    """
    import wfdb

    data_dir = os.path.join(dest_root, "data")
    os.makedirs(data_dir, exist_ok=True)
    rng = np.random.default_rng(seed)

    for rec_idx in range(n_records):
        rec_name = str(rec_idx + 1)
        signal = _synthesize_12lead_ecg(rng)
        comments = [
            f": {40 + 5 * rec_idx}",
            f": {'M' if rec_idx % 2 == 0 else 'F'}",
            ":",
            *_LUDB_DIAGNOSES,
        ]
        wfdb.wrsamp(
            record_name=rec_name,
            fs=_LUDB_FS,
            units=["mV"] * len(_LUDB_LEADS),
            sig_name=list(_LUDB_LEADS),
            p_signal=signal,
            fmt=["16"] * len(_LUDB_LEADS),
            write_dir=data_dir,
            comments=comments,
        )
        _write_ludb_annotations(data_dir, rec_name, rng)


def _synthesize_12lead_ecg(rng: np.random.Generator) -> np.ndarray:
    """Return (samples, 12) float32 array with ECG-ish waveforms."""
    t = np.arange(_LUDB_LEN) / _LUDB_FS
    # Base rhythm at ~1 Hz with QRS-like spikes
    base = 0.1 * np.sin(2 * np.pi * 1.0 * t)
    spikes = np.zeros_like(t)
    for beat_t in np.arange(0.4, 10.0, 0.8):
        idx = int(beat_t * _LUDB_FS)
        if idx < _LUDB_LEN:
            spikes[idx] = 1.0
    # Small gaussian bumps around each spike for QRS shape
    kernel = np.exp(-((np.arange(-20, 21)) ** 2) / 40.0)
    qrs = np.convolve(spikes, kernel, mode="same")
    signal = np.stack([
        base + qrs * (0.8 + 0.1 * i) + rng.normal(0, 0.02, _LUDB_LEN)
        for i in range(len(_LUDB_LEADS))
    ], axis=1).astype(np.float32)
    return signal


def _write_ludb_annotations(
    data_dir: str, rec_name: str, rng: np.random.Generator,
) -> None:
    """Write one annotation file per lead with evenly spaced P/N/T waves.

    LUDB encodes each wave as a triplet of symbols ``(``, wave-type,
    ``)`` at onset, peak, offset sample indices, where wave-type is
    ``p`` (P wave), ``N`` (QRS complex), or ``t`` (T wave).

    ``rng`` is currently unused; it is kept in the signature so all
    generators share a uniform call shape.
    """
    import wfdb

    # Produce ~10 beats across the 10 s clip, each with a P/N/T
    # triplet, building samples and symbols in one aligned pass.
    beat_centers = np.linspace(300, _LUDB_LEN - 300, 10, dtype=int)
    samples: list[int] = []
    symbols: list[str] = []
    for center in beat_centers:
        for delta, sym in (
            (-120, "("), (-90, "p"), (-60, ")"),   # P wave
            (-20, "("), (0, "N"), (20, ")"),       # QRS complex
            (80, "("), (120, "t"), (160, ")"),     # T wave
        ):
            s = int(center + delta)
            if 0 <= s < _LUDB_LEN:
                samples.append(s)
                symbols.append(sym)

    samples_arr = np.array(samples, dtype=np.int64)
    # wrann rejects extensions with digits (e.g., "v1", "v2"), but
    # LUDB's lead-name extensions include six of those. Write each
    # annotation with a letters-only placeholder extension, then
    # rename to the real lead extension afterwards.
    placeholder_stem = "xyzlead"
    for i, lead in enumerate(_LUDB_LEADS):
        placeholder = f"{placeholder_stem}{chr(ord('a') + i)}"
        wfdb.wrann(
            record_name=rec_name,
            extension=placeholder,
            sample=samples_arr,
            symbol=symbols,
            write_dir=data_dir,
        )
        src = os.path.join(data_dir, f"{rec_name}.{placeholder}")
        dst = os.path.join(data_dir, f"{rec_name}.{lead}")
        os.replace(src, dst)
+ placeholder_stem = "xyzlead" + for i, lead in enumerate(_LUDB_LEADS): + placeholder = f"{placeholder_stem}{chr(ord('a') + i)}" + wfdb.wrann( + record_name=rec_name, + extension=placeholder, + sample=samples_arr, + symbol=symbols, + write_dir=data_dir, + ) + src = os.path.join(data_dir, f"{rec_name}.{placeholder}") + dst = os.path.join(data_dir, f"{rec_name}.{lead}") + os.replace(src, dst) + + +# --------------------------------------------------------------------- # +# BIDMC: respiratory + ECG, 125 Hz, ~8 min, breath annotations +# --------------------------------------------------------------------- # + +_BIDMC_FS = 125 +_BIDMC_LEN = 60_001 # matches the real format's length +# BIDMC's real headers carry a trailing comma in each signal name +# (e.g., ``RESP,``). The dataset parser matches on the comma-suffixed +# form, so preserve it here. +_BIDMC_SIGS = ("RESP,", "PLETH,", "V,", "AVR,", "II,") + + +def synthesize_bidmc( + dest_root: str, + n_records: int = 2, + seed: int = 0, +) -> None: + """Write ``n_records`` synthetic BIDMC records into ``dest_root/``. + + Each record includes a RESP signal + 2 ECG leads + PLETH + AVR, + plus a ``.breath`` annotation file with periodic breath markers. + """ + import wfdb + + os.makedirs(dest_root, exist_ok=True) + rng = np.random.default_rng(seed) + + for rec_idx in range(n_records): + rec_name = f"bidmc{rec_idx + 1:02d}" + t = np.arange(_BIDMC_LEN) / _BIDMC_FS + resp = 0.5 * np.sin(2 * np.pi * 0.25 * t) # 15 breaths/min + pleth = 0.3 * np.sin(2 * np.pi * 1.2 * t) + ecg = 0.4 * np.sin(2 * np.pi * 1.2 * t) + noise = lambda: rng.normal(0, 0.02, _BIDMC_LEN) + sigs = np.stack([resp + noise(), pleth + noise(), ecg + noise(), + -ecg + noise(), ecg + noise()], axis=1).astype(np.float32) + # Age + sex metadata in the first comment line — BIDMC parser + # reads demographics from the header just like LUDB. 
+        comments = [
+            f"<age>: {55 + 3 * rec_idx}",
+            f"<sex>: {'M' if rec_idx % 2 == 0 else 'F'}",
+        ]
+        wfdb.wrsamp(
+            record_name=rec_name,
+            fs=_BIDMC_FS,
+            units=["pm", "NU", "mV", "mV", "mV"],
+            sig_name=list(_BIDMC_SIGS),
+            p_signal=sigs,
+            fmt=["16"] * len(_BIDMC_SIGS),
+            write_dir=dest_root,
+            comments=comments,
+        )
+        # Breath annotations: one mark per breath (0.25 Hz => 4 s apart)
+        breath_samples = np.arange(2 * _BIDMC_FS, _BIDMC_LEN, 4 * _BIDMC_FS, dtype=np.int64)
+        wfdb.wrann(
+            record_name=rec_name,
+            extension="breath",
+            sample=breath_samples,
+            symbol=["+"] * len(breath_samples),
+            write_dir=dest_root,
+        )
+
+
+# --------------------------------------------------------------------- #
+# MIT-BIH: 2-lead ECG at 360 Hz, beat annotations
+# --------------------------------------------------------------------- #
+
+_MITBIH_FS = 360
+_MITBIH_LEN = 65_000  # short clip ≈ 3 min
+_MITBIH_SIGS = ("MLII", "V5")
+
+
+def synthesize_mitbih(
+    dest_root: str,
+    record_names: Iterable[str] = ("100", "101"),
+    seed: int = 0,
+) -> None:
+    """Write synthetic MIT-BIH records with beat annotations.
+
+    Each record has 2 ECG leads + ``.atr`` annotations containing a
+    mix of normal (``N``) and abnormal (``V``) beat symbols so
+    ECGAnomalyDetection / ECGBoundaryDetection tasks see both classes.
+    """
+    import wfdb
+
+    os.makedirs(dest_root, exist_ok=True)
+    rng = np.random.default_rng(seed)
+
+    for rec_idx, rec_name in enumerate(record_names):
+        t = np.arange(_MITBIH_LEN) / _MITBIH_FS
+        mlii = 0.6 * np.sin(2 * np.pi * 1.2 * t)
+        v5 = 0.5 * np.sin(2 * np.pi * 1.2 * t + 0.3)
+        noise = rng.normal(0, 0.02, (_MITBIH_LEN, 2)).astype(np.float32)
+        sigs = np.stack([mlii, v5], axis=1).astype(np.float32) + noise
+        # Age/sex encoded the way MIT-BIH headers do it: "# 69 M ..."
+ comments = [f"# {60 + 3 * rec_idx} {'M' if rec_idx % 2 == 0 else 'F'}"] + wfdb.wrsamp( + record_name=rec_name, + fs=_MITBIH_FS, + units=["mV", "mV"], + sig_name=list(_MITBIH_SIGS), + p_signal=sigs, + fmt=["16", "16"], + write_dir=dest_root, + comments=comments, + ) + # Beat markers every ~1 s — mix N and V so tests see both classes + beat_samples = np.arange(_MITBIH_FS, _MITBIH_LEN, _MITBIH_FS, dtype=np.int64) + symbols = ["N" if i % 3 else "V" for i in range(len(beat_samples))] + wfdb.wrann( + record_name=rec_name, + extension="atr", + sample=beat_samples, + symbol=symbols, + write_dir=dest_root, + ) + + +# --------------------------------------------------------------------- # +# Entry point — regenerate the committed test-resources fixtures +# --------------------------------------------------------------------- # +# +# Run with ``python -m tests.core._synthetic_wfdb`` from the repo root. +# Regenerates the wfdb records under ``test-resources/core/{ludb, +# bidmc,mitbih}/`` from a fixed seed. Checked-in outputs are fully +# synthetic — never ship real patient data in tests. 
+ +_REPO_ROOT = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) + + +def _regenerate_all(repo_root: str = _REPO_ROOT) -> None: + """Rewrite all three dataset fixtures under ``test-resources/core/``.""" + import shutil + + def _wipe_and_recreate(path: str) -> None: + if os.path.isdir(path): + shutil.rmtree(path) + os.makedirs(path, exist_ok=True) + + ludb_root = os.path.join(repo_root, "test-resources", "core", "ludb") + _wipe_and_recreate(os.path.join(ludb_root, "data")) + synthesize_ludb(ludb_root, n_records=2) + + bidmc_root = os.path.join(repo_root, "test-resources", "core", "bidmc") + _wipe_and_recreate(bidmc_root) + synthesize_bidmc(bidmc_root, n_records=2) + + mitbih_root = os.path.join(repo_root, "test-resources", "core", "mitbih") + _wipe_and_recreate(mitbih_root) + synthesize_mitbih(mitbih_root, record_names=["100", "101"]) + + +if __name__ == "__main__": + _regenerate_all() + print(f"Regenerated synthetic test fixtures under {_REPO_ROOT}/test-resources/core/") diff --git a/tests/core/test_bidmc.py b/tests/core/test_bidmc.py new file mode 100644 index 000000000..e834be825 --- /dev/null +++ b/tests/core/test_bidmc.py @@ -0,0 +1,412 @@ +"""Unit tests for BIDMC dataset and RespiratoryBoundaryDetection task. + +Author: Anton Barchukov + +Tests cover: + - BIDMCDataset instantiation from wfdb files + - Metadata preparation + - Patient header parsing (age, sex, location) + - RespiratoryBoundaryDetection task (windowing, annotator, labels) + - Full pipeline: dataset -> task -> samples + - MedTsLLM model with the BIDMC data structure + +Test data under ``test-resources/core/bidmc/`` is fully synthetic — +records generated by ``tests/core/_synthetic_wfdb.synthesize_bidmc`` +with placeholder respiratory/ECG signals. No real patient records +are committed to the repo. 
+""" + +import os +import unittest + +import numpy as np +import torch + +TEST_DATA_DIR = os.path.join( + os.path.dirname(__file__), "..", "..", "test-resources", "core", "bidmc" +) +HAS_TEST_DATA = os.path.isdir(TEST_DATA_DIR) and any( + f.endswith(".dat") for f in os.listdir(TEST_DATA_DIR) +) + + +def setUpModule(): + """Drop the committed CSV + cache so tests use the current schema.""" + if HAS_TEST_DATA: + _force_regenerate_bidmc_metadata() + + +@unittest.skipUnless(HAS_TEST_DATA, "BIDMC test data not found") +class TestBIDMCDataset(unittest.TestCase): + """Tests for BIDMCDataset with real wfdb files.""" + + @classmethod + def setUpClass(cls): + from pyhealth.datasets.bidmc import BIDMCDataset + + cls.dataset = BIDMCDataset(root=TEST_DATA_DIR, dev=True) + + def test_dataset_loads(self): + """Dataset initializes without error.""" + self.assertIsNotNone(self.dataset) + + def test_metadata_csv_created(self): + """prepare_metadata creates bidmc-pyhealth.csv.""" + csv_path = os.path.join(TEST_DATA_DIR, "bidmc-pyhealth.csv") + self.assertTrue(os.path.exists(csv_path)) + + def test_has_patients(self): + """Dataset has at least 1 patient.""" + pids = self.dataset.unique_patient_ids + self.assertGreater(len(pids), 0) + + def test_patient_has_events(self): + """Each patient has respiratory events.""" + pid = self.dataset.unique_patient_ids[0] + patient = self.dataset.get_patient(pid) + events = patient.get_events() + self.assertGreater(len(events), 0) + + def test_event_has_demographics(self): + """Events contain age and sex attributes.""" + pid = self.dataset.unique_patient_ids[0] + patient = self.dataset.get_patient(pid) + event = patient.get_events()[0] + self.assertTrue(hasattr(event, "age")) + self.assertTrue(hasattr(event, "sex")) + self.assertTrue(hasattr(event, "location")) + + +@unittest.skipUnless(HAS_TEST_DATA, "BIDMC test data not found") +class TestBIDMCHeaderParsing(unittest.TestCase): + """Tests for BIDMC patient metadata parsing.""" + + def 
test_parse_header(self): + """Header parsing extracts age, sex, location.""" + import wfdb + + rec = wfdb.rdrecord(os.path.join(TEST_DATA_DIR, "bidmc01")) + from pyhealth.datasets.bidmc import BIDMCDataset + + # Just verify the record has comments + self.assertIsNotNone(rec.comments) + self.assertGreater(len(rec.comments), 0) + + +@unittest.skipUnless(HAS_TEST_DATA, "BIDMC test data not found") +class TestRespiratoryBoundaryDetectionTask(unittest.TestCase): + """Tests for RespiratoryBoundaryDetection with real data.""" + + @classmethod + def setUpClass(cls): + from pyhealth.datasets.bidmc import BIDMCDataset + from pyhealth.tasks.respiratory_boundary_detection import ( + RespiratoryBoundaryDetection, + ) + + cls.dataset = BIDMCDataset(root=TEST_DATA_DIR, dev=True) + cls.task = RespiratoryBoundaryDetection( + window_size=256, step_size=128 + ) + + def test_task_attributes(self): + """Task has correct name and schema.""" + self.assertEqual( + self.task.task_name, "RespiratoryBoundaryDetection" + ) + self.assertEqual(self.task.input_schema, {"signal": "tensor"}) + self.assertEqual(self.task.output_schema, {"label": "tensor"}) + + def test_task_produces_samples(self): + """Task generates samples from a patient.""" + pid = self.dataset.unique_patient_ids[0] + patient = self.dataset.get_patient(pid) + samples = self.task(patient) + self.assertGreater(len(samples), 0) + + def test_signal_is_3_channel(self): + """Signal has 3 channels (RESP, PLETH, II).""" + pid = self.dataset.unique_patient_ids[0] + patient = self.dataset.get_patient(pid) + sample = self.task(patient)[0] + self.assertEqual(sample["signal"].shape, (256, 3)) + + def test_label_is_binary(self): + """Labels are 0 or 1 (boundary or not).""" + pid = self.dataset.unique_patient_ids[0] + patient = self.dataset.get_patient(pid) + sample = self.task(patient)[0] + unique = set(sample["label"].tolist()) + self.assertTrue(unique.issubset({0.0, 1.0})) + + def test_sample_has_description(self): + """Samples include 
patient description.""" + pid = self.dataset.unique_patient_ids[0] + patient = self.dataset.get_patient(pid) + sample = self.task(patient)[0] + self.assertIn("description", sample) + self.assertIn("age:", sample["description"]) + + def test_set_task_produces_sample_dataset(self): + """dataset.set_task() returns a usable SampleDataset.""" + sample_ds = self.dataset.set_task(self.task) + self.assertGreater(len(sample_ds), 0) + + +@unittest.skipUnless(HAS_TEST_DATA, "BIDMC test data not found") +class TestMedTsLLMWithBIDMC(unittest.TestCase): + """Test MedTsLLM model with real BIDMC data structure.""" + + @classmethod + def setUpClass(cls): + from pyhealth.datasets.bidmc import BIDMCDataset + from pyhealth.tasks.respiratory_boundary_detection import ( + RespiratoryBoundaryDetection, + ) + from pyhealth.datasets import get_dataloader + from pyhealth.models.medtsllm import MedTsLLM + + dataset = BIDMCDataset(root=TEST_DATA_DIR, dev=True) + task = RespiratoryBoundaryDetection( + window_size=128, step_size=64 + ) + cls.sample_ds = dataset.set_task(task) + + cls.model = MedTsLLM( + dataset=cls.sample_ds, + seq_len=128, + n_features=3, + n_classes=2, + backbone=None, + word_embeddings=torch.randn(50, 32), + d_model=16, + d_ff=32, + n_heads=4, + num_tokens=50, + covariate_mode="concat", + ) + loader = get_dataloader(cls.sample_ds, batch_size=2, shuffle=False) + cls.batch = next(iter(loader)) + + def test_forward(self): + """Model forward pass works with real BIDMC samples.""" + out = self.model(**self.batch) + self.assertIn("logit", out) + self.assertIn("loss", out) + self.assertTrue(out["loss"].isfinite()) + + def test_backward(self): + """Backward pass works with BIDMC data.""" + out = self.model(**self.batch) + out["loss"].backward() + + +# ------------------------------------------------------------------ # +# Phase 5: paper-match split column (BIDMC 85/15 seed=0) +# ------------------------------------------------------------------ # + + +def 
_force_regenerate_bidmc_metadata(): + """Reset BIDMC test state: delete CSV and PyHealth cache.""" + import shutil + + import platformdirs + + csv_path = os.path.join(TEST_DATA_DIR, "bidmc-pyhealth.csv") + if os.path.exists(csv_path): + os.remove(csv_path) + + cache_root = platformdirs.user_cache_dir(appname="pyhealth") + if os.path.isdir(cache_root): + shutil.rmtree(cache_root) + + +@unittest.skipUnless(HAS_TEST_DATA, "BIDMC test data not found") +class TestBIDMCPaperSplit(unittest.TestCase): + """paper_split=True writes an 85/15 seed=0 split column.""" + + @classmethod + def setUpClass(cls): + from pyhealth.datasets.bidmc import BIDMCDataset + + _force_regenerate_bidmc_metadata() + cls.dataset = BIDMCDataset( + root=TEST_DATA_DIR, dev=True, paper_split=True + ) + + def test_constructor_accepts_paper_split(self): + """BIDMCDataset accepts a paper_split kwarg.""" + import inspect + + from pyhealth.datasets.bidmc import BIDMCDataset + + sig = inspect.signature(BIDMCDataset.__init__) + self.assertIn("paper_split", sig.parameters) + + def test_csv_has_split_column(self): + """Regenerated CSV contains a ``split`` column.""" + import pandas as pd + + csv_path = os.path.join(TEST_DATA_DIR, "bidmc-pyhealth.csv") + df = pd.read_csv(csv_path) + self.assertIn("split", df.columns) + + def test_split_values_are_train_or_test(self): + """Every split value is either 'train' or 'test'.""" + import pandas as pd + + csv_path = os.path.join(TEST_DATA_DIR, "bidmc-pyhealth.csv") + df = pd.read_csv(csv_path) + unique = set(df["split"].unique()) + self.assertTrue(unique.issubset({"train", "test"})) + + def test_event_has_split_attr(self): + """Events expose a ``split`` attribute.""" + pid = self.dataset.unique_patient_ids[0] + patient = self.dataset.get_patient(pid) + event = patient.get_events()[0] + self.assertTrue(hasattr(event, "split")) + self.assertIn(event.split, ("train", "test")) + + def test_task_emits_split_field(self): + """RespiratoryBoundaryDetection propagates event.split 
to samples.""" + from pyhealth.tasks.respiratory_boundary_detection import ( + RespiratoryBoundaryDetection, + ) + + pid = self.dataset.unique_patient_ids[0] + patient = self.dataset.get_patient(pid) + task = RespiratoryBoundaryDetection( + window_size=256, step_size=128 + ) + sample = task(patient)[0] + self.assertIn("split", sample) + self.assertIn(sample["split"], ("train", "test")) + + +@unittest.skipUnless(HAS_TEST_DATA, "BIDMC test data not found") +class TestBIDMCPaperSplitDeterministic(unittest.TestCase): + """Seed=0 BIDMC split is reproducible across runs.""" + + def test_split_is_deterministic(self): + """Regenerating twice yields the same patient -> split mapping.""" + import pandas as pd + from pyhealth.datasets.bidmc import BIDMCDataset + + _force_regenerate_bidmc_metadata() + BIDMCDataset(root=TEST_DATA_DIR, dev=True, paper_split=True) + df1 = pd.read_csv(os.path.join(TEST_DATA_DIR, "bidmc-pyhealth.csv")) + mapping1 = df1.set_index("patient_id")["split"].to_dict() + + _force_regenerate_bidmc_metadata() + BIDMCDataset(root=TEST_DATA_DIR, dev=True, paper_split=True) + df2 = pd.read_csv(os.path.join(TEST_DATA_DIR, "bidmc-pyhealth.csv")) + mapping2 = df2.set_index("patient_id")["split"].to_dict() + + self.assertEqual(mapping1, mapping2) + + +@unittest.skipUnless(HAS_TEST_DATA, "BIDMC test data not found") +class TestBIDMCPaperSplitDisabled(unittest.TestCase): + """paper_split=False leaves the split column blank.""" + + def test_split_column_blank_when_disabled(self): + """Without paper_split, split column is empty.""" + import pandas as pd + from pyhealth.datasets.bidmc import BIDMCDataset + + _force_regenerate_bidmc_metadata() + BIDMCDataset(root=TEST_DATA_DIR, dev=True, paper_split=False) + df = pd.read_csv(os.path.join(TEST_DATA_DIR, "bidmc-pyhealth.csv")) + self.assertIn("split", df.columns) + cleaned = df["split"].fillna("").astype(str).str.strip() + self.assertTrue((cleaned == "").all()) + + +# 
------------------------------------------------------------------ # +# Phase 6: preprocess=True -> .npz cache for BIDMC +# ------------------------------------------------------------------ # + + +@unittest.skipUnless(HAS_TEST_DATA, "BIDMC test data not found") +class TestBIDMCPreprocessCache(unittest.TestCase): + """preprocess=True writes per-record .npz with signal + annotations.""" + + @classmethod + def setUpClass(cls): + import shutil + + from pyhealth.datasets.bidmc import BIDMCDataset + + processed_dir = os.path.join(TEST_DATA_DIR, "processed") + if os.path.isdir(processed_dir): + shutil.rmtree(processed_dir) + _force_regenerate_bidmc_metadata() + + cls.dataset = BIDMCDataset( + root=TEST_DATA_DIR, dev=True, preprocess=True + ) + + def test_processed_dir_created(self): + self.assertTrue( + os.path.isdir(os.path.join(TEST_DATA_DIR, "processed")) + ) + + def test_event_has_processed_file_attr(self): + pid = self.dataset.unique_patient_ids[0] + event = self.dataset.get_patient(pid).get_events()[0] + self.assertTrue(hasattr(event, "processed_file")) + self.assertTrue(event.processed_file.endswith(".npz")) + self.assertTrue(os.path.exists(event.processed_file)) + + def test_cached_arrays_have_expected_keys(self): + pid = self.dataset.unique_patient_ids[0] + event = self.dataset.get_patient(pid).get_events()[0] + with np.load(event.processed_file, allow_pickle=False) as npz: + self.assertIn("signal", npz.files) + self.assertIn("ann_sample", npz.files) + self.assertIn("ann_aux", npz.files) + + def test_cached_signal_is_3_channel(self): + pid = self.dataset.unique_patient_ids[0] + event = self.dataset.get_patient(pid).get_events()[0] + with np.load(event.processed_file, allow_pickle=False) as npz: + self.assertEqual(int(npz["signal"].shape[1]), 3) + + def test_task_uses_cache(self): + from pyhealth.tasks.respiratory_boundary_detection import ( + RespiratoryBoundaryDetection, + ) + + pid = self.dataset.unique_patient_ids[0] + patient = self.dataset.get_patient(pid) 
+ samples = RespiratoryBoundaryDetection( + window_size=256, step_size=128 + )(patient) + self.assertGreater(len(samples), 0) + self.assertEqual(samples[0]["signal"].shape, (256, 3)) + + +@unittest.skipUnless(HAS_TEST_DATA, "BIDMC test data not found") +class TestBIDMCPreprocessDisabled(unittest.TestCase): + """preprocess=False leaves processed_file blank.""" + + def test_processed_file_empty_when_disabled(self): + import shutil + + from pyhealth.datasets.bidmc import BIDMCDataset + + processed_dir = os.path.join(TEST_DATA_DIR, "processed") + if os.path.isdir(processed_dir): + shutil.rmtree(processed_dir) + _force_regenerate_bidmc_metadata() + + ds = BIDMCDataset(root=TEST_DATA_DIR, dev=True, preprocess=False) + pid = ds.unique_patient_ids[0] + event = ds.get_patient(pid).get_events()[0] + value = getattr(event, "processed_file", "") or "" + self.assertEqual(str(value).strip(), "") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/test_ludb.py b/tests/core/test_ludb.py new file mode 100644 index 000000000..ce30ecf2a --- /dev/null +++ b/tests/core/test_ludb.py @@ -0,0 +1,628 @@ +"""Unit tests for LUDB dataset and ECGWaveSegmentation task. + +Author: Anton Barchukov + +Tests cover: + - LUDBDataset instantiation from wfdb files + - Metadata preparation (prepare_metadata) + - Patient header parsing (age, sex, diagnoses) + - ECGWaveSegmentation task (windowing, trimming, labels) + - Full pipeline: dataset -> task -> samples + - MedTsLLM model with the LUDB data structure + +Test data under ``test-resources/core/ludb/data/`` is fully synthetic +— records generated by ``tests/core/_synthetic_wfdb.synthesize_ludb`` +with random ECG-shaped signals and placeholder demographics. No real +patient records are committed to the repo. 
+""" + +import os +import unittest + +import numpy as np +import torch + +TEST_DATA_DIR = os.path.join( + os.path.dirname(__file__), "..", "..", "test-resources", "core", "ludb" +) +HAS_TEST_DATA = os.path.isdir(os.path.join(TEST_DATA_DIR, "data")) + + +def setUpModule(): + """Drop the committed CSV + cache so tests use the current schema.""" + if HAS_TEST_DATA: + _force_regenerate_metadata() + + +@unittest.skipUnless(HAS_TEST_DATA, "LUDB test data not found") +class TestLUDBDataset(unittest.TestCase): + """Tests for LUDBDataset with real wfdb files.""" + + @classmethod + def setUpClass(cls): + from pyhealth.datasets.ludb import LUDBDataset + + cls.dataset = LUDBDataset(root=TEST_DATA_DIR, dev=True) + + def test_dataset_loads(self): + """Dataset initializes without error.""" + self.assertIsNotNone(self.dataset) + + def test_metadata_csv_created(self): + """prepare_metadata creates ludb-pyhealth.csv.""" + csv_path = os.path.join(TEST_DATA_DIR, "ludb-pyhealth.csv") + self.assertTrue(os.path.exists(csv_path)) + + def test_has_patients(self): + """Dataset has at least 1 patient.""" + pids = self.dataset.unique_patient_ids + self.assertGreater(len(pids), 0) + + def test_patient_has_events(self): + """Each patient has ECG events.""" + pid = self.dataset.unique_patient_ids[0] + patient = self.dataset.get_patient(pid) + events = patient.get_events() + self.assertGreater(len(events), 0) + + def test_event_has_signal_file(self): + """Events contain signal_file attribute.""" + pid = self.dataset.unique_patient_ids[0] + patient = self.dataset.get_patient(pid) + event = patient.get_events()[0] + self.assertTrue(hasattr(event, "signal_file")) + + def test_event_has_lead(self): + """Events contain lead attribute.""" + pid = self.dataset.unique_patient_ids[0] + patient = self.dataset.get_patient(pid) + event = patient.get_events()[0] + self.assertTrue(hasattr(event, "lead")) + + +@unittest.skipUnless(HAS_TEST_DATA, "LUDB test data not found") +class 
TestLUDBHeaderParsing(unittest.TestCase):
+    """Tests for LUDB patient metadata in wfdb headers."""
+
+    def test_header_has_comments(self):
+        """LUDB wfdb records contain header comments."""
+        import wfdb
+
+        data_dir = os.path.join(TEST_DATA_DIR, "data")
+        rec = wfdb.rdrecord(os.path.join(data_dir, "1"))
+        self.assertIsNotNone(rec.comments)
+        self.assertGreater(len(rec.comments), 0)
+
+    def test_header_has_age_sex(self):
+        """Header comments contain age and sex fields."""
+        import wfdb
+
+        data_dir = os.path.join(TEST_DATA_DIR, "data")
+        rec = wfdb.rdrecord(os.path.join(data_dir, "1"))
+        comments = "\n".join(rec.comments)
+        self.assertIn("<age>:", comments)
+        self.assertIn("<sex>:", comments)
+
+    def test_header_has_diagnoses(self):
+        """Header comments contain diagnoses section."""
+        import wfdb
+
+        data_dir = os.path.join(TEST_DATA_DIR, "data")
+        rec = wfdb.rdrecord(os.path.join(data_dir, "1"))
+        comments = "\n".join(rec.comments)
+        self.assertIn("<diagnoses>:", comments)
+
+
+@unittest.skipUnless(HAS_TEST_DATA, "LUDB test data not found")
+class TestECGWaveSegmentationTask(unittest.TestCase):
+    """Tests for ECGWaveSegmentation with real LUDB data."""
+
+    @classmethod
+    def setUpClass(cls):
+        from pyhealth.datasets.ludb import LUDBDataset
+        from pyhealth.tasks.ecg_wave_segmentation import (
+            ECGWaveSegmentation,
+        )
+
+        cls.dataset = LUDBDataset(root=TEST_DATA_DIR, dev=True)
+        cls.task = ECGWaveSegmentation(window_size=512, step_size=256)
+
+    def test_task_attributes(self):
+        """Task has correct name and schema."""
+        self.assertEqual(self.task.task_name, "ECGWaveSegmentation")
+        self.assertEqual(self.task.input_schema, {"signal": "tensor"})
+        self.assertEqual(self.task.output_schema, {"label": "tensor"})
+
+    def test_task_produces_samples(self):
+        """Task generates samples from a patient."""
+        pid = self.dataset.unique_patient_ids[0]
+        patient = self.dataset.get_patient(pid)
+        samples = self.task(patient)
+        self.assertGreater(len(samples), 0)
+
+    def
test_sample_has_required_keys(self): + """Each sample has signal, label, patient_id.""" + pid = self.dataset.unique_patient_ids[0] + patient = self.dataset.get_patient(pid) + sample = self.task(patient)[0] + self.assertIn("signal", sample) + self.assertIn("label", sample) + self.assertIn("patient_id", sample) + + def test_signal_shape(self): + """Signal has shape (window_size,).""" + pid = self.dataset.unique_patient_ids[0] + patient = self.dataset.get_patient(pid) + sample = self.task(patient)[0] + self.assertEqual(sample["signal"].shape, (512,)) + + def test_label_shape(self): + """Label has shape (window_size,) matching signal.""" + pid = self.dataset.unique_patient_ids[0] + patient = self.dataset.get_patient(pid) + sample = self.task(patient)[0] + self.assertEqual(sample["label"].shape, (512,)) + + def test_labels_are_valid_classes(self): + """Labels are 0 (bg), 1 (P), 2 (QRS), or 3 (T).""" + pid = self.dataset.unique_patient_ids[0] + patient = self.dataset.get_patient(pid) + sample = self.task(patient)[0] + unique = set(sample["label"].tolist()) + self.assertTrue(unique.issubset({0, 1, 2, 3})) + + def test_samples_contain_waves(self): + """Default (no preprocess) produces samples containing wave labels.""" + pid = self.dataset.unique_patient_ids[0] + patient = self.dataset.get_patient(pid) + samples = self.task(patient) + has_waves = any((s["label"] > 0).any() for s in samples) + self.assertTrue(has_waves) + + def test_set_task_produces_sample_dataset(self): + """dataset.set_task() returns a usable SampleDataset.""" + sample_ds = self.dataset.set_task(self.task) + self.assertGreater(len(sample_ds), 0) + sample = sample_ds[0] + self.assertIn("signal", sample) + self.assertIn("label", sample) + + +# ------------------------------------------------------------------ # +# Phase 2: patient-prompt plumbing (LUDB headers -> description) +# ------------------------------------------------------------------ # + + +def _force_regenerate_metadata(): + """Reset LUDB test 
state: delete committed CSV and PyHealth cache. + + The repo ships a pre-built ``ludb-pyhealth.csv`` and PyHealth + caches the parsed event dataframe under ``~/.cache/pyhealth/``. + Both need to be cleared so tests exercise the current + ``prepare_metadata`` implementation and event schema from the + YAML config, not stale artifacts from a prior run. + """ + import shutil + + import platformdirs + + csv_path = os.path.join(TEST_DATA_DIR, "ludb-pyhealth.csv") + if os.path.exists(csv_path): + os.remove(csv_path) + + cache_root = platformdirs.user_cache_dir(appname="pyhealth") + if os.path.isdir(cache_root): + shutil.rmtree(cache_root) + + +@unittest.skipUnless(HAS_TEST_DATA, "LUDB test data not found") +class TestLUDBMetadataDemographics(unittest.TestCase): + """Metadata CSV + events expose patient demographics.""" + + @classmethod + def setUpClass(cls): + from pyhealth.datasets.ludb import LUDBDataset + + _force_regenerate_metadata() + cls.dataset = LUDBDataset(root=TEST_DATA_DIR, dev=True) + + def test_metadata_csv_has_age_column(self): + """Regenerated metadata CSV includes an ``age`` column.""" + import pandas as pd + + csv_path = os.path.join(TEST_DATA_DIR, "ludb-pyhealth.csv") + df = pd.read_csv(csv_path) + self.assertIn("age", df.columns) + + def test_metadata_csv_has_sex_column(self): + """Regenerated metadata CSV includes a ``sex`` column.""" + import pandas as pd + + csv_path = os.path.join(TEST_DATA_DIR, "ludb-pyhealth.csv") + df = pd.read_csv(csv_path) + self.assertIn("sex", df.columns) + + def test_metadata_csv_has_diagnoses_column(self): + """Regenerated metadata CSV includes a ``diagnoses`` column.""" + import pandas as pd + + csv_path = os.path.join(TEST_DATA_DIR, "ludb-pyhealth.csv") + df = pd.read_csv(csv_path) + self.assertIn("diagnoses", df.columns) + + def test_metadata_age_non_empty(self): + """At least one row has a non-empty age.""" + import pandas as pd + + csv_path = os.path.join(TEST_DATA_DIR, "ludb-pyhealth.csv") + df = 
pd.read_csv(csv_path) + ages = df["age"].dropna().astype(str).str.strip() + self.assertTrue((ages != "").any()) + + def test_event_has_age_attr(self): + """Events expose an ``age`` attribute.""" + pid = self.dataset.unique_patient_ids[0] + patient = self.dataset.get_patient(pid) + event = patient.get_events()[0] + self.assertTrue(hasattr(event, "age")) + + def test_event_has_sex_attr(self): + """Events expose a ``sex`` attribute.""" + pid = self.dataset.unique_patient_ids[0] + patient = self.dataset.get_patient(pid) + event = patient.get_events()[0] + self.assertTrue(hasattr(event, "sex")) + + def test_event_has_diagnoses_attr(self): + """Events expose a ``diagnoses`` attribute.""" + pid = self.dataset.unique_patient_ids[0] + patient = self.dataset.get_patient(pid) + event = patient.get_events()[0] + self.assertTrue(hasattr(event, "diagnoses")) + + +@unittest.skipUnless(HAS_TEST_DATA, "LUDB test data not found") +class TestECGWaveSegmentationDescription(unittest.TestCase): + """ECGWaveSegmentation emits per-sample ``description``.""" + + @classmethod + def setUpClass(cls): + from pyhealth.datasets.ludb import LUDBDataset + from pyhealth.tasks.ecg_wave_segmentation import ( + ECGWaveSegmentation, + ) + + _force_regenerate_metadata() + cls.dataset = LUDBDataset(root=TEST_DATA_DIR, dev=True) + cls.task = ECGWaveSegmentation(window_size=512, step_size=256) + + def test_sample_has_description_key(self): + """Samples include a ``description`` field.""" + pid = self.dataset.unique_patient_ids[0] + patient = self.dataset.get_patient(pid) + sample = self.task(patient)[0] + self.assertIn("description", sample) + + def test_description_is_string(self): + """Description is a string.""" + pid = self.dataset.unique_patient_ids[0] + patient = self.dataset.get_patient(pid) + sample = self.task(patient)[0] + self.assertIsInstance(sample["description"], str) + + def test_description_populated(self): + """Description is non-empty for a patient with header data.""" + pid = 
self.dataset.unique_patient_ids[0] + patient = self.dataset.get_patient(pid) + sample = self.task(patient)[0] + self.assertGreater(len(sample["description"]), 0) + + def test_description_contains_age(self): + """Description mentions age for patient 1.""" + pid = self.dataset.unique_patient_ids[0] + patient = self.dataset.get_patient(pid) + sample = self.task(patient)[0] + self.assertIn("age", sample["description"].lower()) + + +@unittest.skipUnless(HAS_TEST_DATA, "LUDB test data not found") +class TestMedTsLLMWithLUDB(unittest.TestCase): + """Test MedTsLLM model with real LUDB data structure.""" + + @classmethod + def setUpClass(cls): + from pyhealth.datasets.ludb import LUDBDataset + from pyhealth.tasks.ecg_wave_segmentation import ( + ECGWaveSegmentation, + ) + from pyhealth.datasets import get_dataloader + from pyhealth.models.medtsllm import MedTsLLM + + dataset = LUDBDataset(root=TEST_DATA_DIR, dev=True) + task = ECGWaveSegmentation(window_size=128, step_size=64) + cls.sample_ds = dataset.set_task(task) + + cls.model = MedTsLLM( + dataset=cls.sample_ds, + seq_len=128, + n_features=1, + n_classes=4, + backbone=None, + word_embeddings=torch.randn(50, 32), + d_model=16, + d_ff=32, + n_heads=4, + num_tokens=50, + ) + loader = get_dataloader(cls.sample_ds, batch_size=2, shuffle=False) + cls.batch = next(iter(loader)) + + def test_forward(self): + """Model forward pass works with real LUDB samples.""" + out = self.model(**self.batch) + self.assertIn("logit", out) + self.assertIn("loss", out) + self.assertTrue(out["loss"].isfinite()) + + def test_output_shape(self): + """Output has correct shape for 4-class segmentation.""" + out = self.model(**self.batch) + bs = self.batch["label"].shape[0] + self.assertEqual(out["logit"].shape, (bs, 128, 4)) + + +# ------------------------------------------------------------------ # +# Phase 5: paper-match split column (LUDB 80/20 seed=0) +# ------------------------------------------------------------------ # + + 
@unittest.skipUnless(HAS_TEST_DATA, "LUDB test data not found")
class TestLUDBPaperSplit(unittest.TestCase):
    """paper_split=True writes an 80/20 seed=0 split column."""

    @classmethod
    def setUpClass(cls):
        from pyhealth.datasets.ludb import LUDBDataset

        _force_regenerate_metadata()
        cls.dataset = LUDBDataset(
            root=TEST_DATA_DIR, dev=True, paper_split=True
        )

    @staticmethod
    def _read_metadata():
        # The split assignments live in the regenerated metadata CSV.
        import pandas as pd

        return pd.read_csv(
            os.path.join(TEST_DATA_DIR, "ludb-pyhealth.csv")
        )

    def test_constructor_accepts_paper_split(self):
        """LUDBDataset accepts a paper_split kwarg."""
        import inspect

        from pyhealth.datasets.ludb import LUDBDataset

        params = inspect.signature(LUDBDataset.__init__).parameters
        self.assertIn("paper_split", params)

    def test_csv_has_split_column(self):
        """Regenerated CSV contains a ``split`` column."""
        self.assertIn("split", self._read_metadata().columns)

    def test_split_values_are_train_or_test(self):
        """Every split value is either 'train' or 'test'."""
        values = set(self._read_metadata()["split"].unique())
        self.assertTrue(values.issubset({"train", "test"}))

    def test_split_consistent_per_patient(self):
        """Every lead of a patient shares one split assignment."""
        for _, leads in self._read_metadata().groupby("patient_id"):
            self.assertEqual(leads["split"].nunique(), 1)

    def test_event_has_split_attr(self):
        """Events expose a ``split`` attribute."""
        patient_id = self.dataset.unique_patient_ids[0]
        event = self.dataset.get_patient(patient_id).get_events()[0]
        self.assertTrue(hasattr(event, "split"))
        self.assertIn(event.split, ("train", "test"))

    def test_task_emits_split_field(self):
        """ECGWaveSegmentation propagates event.split to each sample."""
        from pyhealth.tasks.ecg_wave_segmentation import (
            ECGWaveSegmentation,
        )

        patient_id = self.dataset.unique_patient_ids[0]
        patient = self.dataset.get_patient(patient_id)
        task = ECGWaveSegmentation(window_size=512, step_size=256)
        sample = task(patient)[0]
        self.assertIn("split", sample)
        self.assertIn(sample["split"], ("train", "test"))


@unittest.skipUnless(HAS_TEST_DATA, "LUDB test data not found")
class TestLUDBPaperSplitDeterministic(unittest.TestCase):
    """Seed=0 split is reproducible across runs."""

    def test_split_is_deterministic(self):
        """Regenerating twice yields the same patient -> split mapping."""
        import pandas as pd
        from pyhealth.datasets.ludb import LUDBDataset

        csv_path = os.path.join(TEST_DATA_DIR, "ludb-pyhealth.csv")

        def build_mapping():
            # Rebuild the metadata from scratch, then collapse the
            # per-lead rows to one split label per patient.
            _force_regenerate_metadata()
            LUDBDataset(root=TEST_DATA_DIR, dev=True, paper_split=True)
            df = pd.read_csv(csv_path)
            return df.groupby("patient_id")["split"].first().to_dict()

        self.assertEqual(build_mapping(), build_mapping())


@unittest.skipUnless(HAS_TEST_DATA, "LUDB test data not found")
class TestLUDBPaperSplitDisabled(unittest.TestCase):
    """paper_split=False does not assign train/test."""

    def test_split_column_blank_when_disabled(self):
        """Without paper_split, split column is empty."""
        import pandas as pd
        from pyhealth.datasets.ludb import LUDBDataset

        _force_regenerate_metadata()
        LUDBDataset(root=TEST_DATA_DIR, dev=True, paper_split=False)
        df = pd.read_csv(os.path.join(TEST_DATA_DIR, "ludb-pyhealth.csv"))
        self.assertIn("split", df.columns)
        blank = df["split"].fillna("").astype(str).str.strip()
        self.assertTrue((blank == "").all())


# ------------------------------------------------------------------ #
# Phase 6: preprocess=True -> .npz cache for LUDB
# ------------------------------------------------------------------ #
@unittest.skipUnless(HAS_TEST_DATA, "LUDB test data not found")
class TestLUDBPreprocessCache(unittest.TestCase):
    """preprocess=True writes .npz per (patient, lead) and wires events."""

    @classmethod
    def setUpClass(cls):
        import shutil

        from pyhealth.datasets.ludb import LUDBDataset

        # Clear any prior run's cache so paths/fingerprints are fresh.
        processed_dir = os.path.join(TEST_DATA_DIR, "processed")
        if os.path.isdir(processed_dir):
            shutil.rmtree(processed_dir)
        _force_regenerate_metadata()

        cls.dataset = LUDBDataset(
            root=TEST_DATA_DIR, dev=True, preprocess=True
        )

    def _first_event(self):
        # Per-event assertions all use the first event of the first
        # patient; fetch it through one code path.
        patient_id = self.dataset.unique_patient_ids[0]
        return self.dataset.get_patient(patient_id).get_events()[0]

    def test_processed_dir_created(self):
        """preprocess=True creates {root}/processed/."""
        self.assertTrue(
            os.path.isdir(os.path.join(TEST_DATA_DIR, "processed"))
        )

    def test_npz_written_per_lead(self):
        """Every (record, lead) pair has a corresponding .npz."""
        processed_dir = os.path.join(TEST_DATA_DIR, "processed")
        npz_files = [
            name
            for name in os.listdir(processed_dir)
            if name.endswith(".npz")
        ]
        # 2 records * 12 leads = 24 expected
        self.assertGreaterEqual(len(npz_files), 20)

    def test_event_has_processed_file_attr(self):
        """Events expose the cache path via ``processed_file``."""
        event = self._first_event()
        self.assertTrue(hasattr(event, "processed_file"))
        self.assertTrue(event.processed_file.endswith(".npz"))
        self.assertTrue(os.path.exists(event.processed_file))

    def test_cached_arrays_have_expected_keys(self):
        """Each .npz stores signal + labels arrays."""
        cache_path = self._first_event().processed_file
        with np.load(cache_path, allow_pickle=False) as npz:
            self.assertIn("signal", npz.files)
            self.assertIn("labels", npz.files)
            self.assertEqual(npz["signal"].shape, npz["labels"].shape)

    def test_cached_labels_contain_waves(self):
        """Cached labels include non-background classes after trim."""
        cache_path = self._first_event().processed_file
        with np.load(cache_path, allow_pickle=False) as npz:
            self.assertTrue((npz["labels"] > 0).any())

    def test_task_loads_from_cache(self):
        """ECGWaveSegmentation returns samples using cached arrays."""
        from pyhealth.tasks.ecg_wave_segmentation import (
            ECGWaveSegmentation,
        )

        patient_id = self.dataset.unique_patient_ids[0]
        patient = self.dataset.get_patient(patient_id)
        task = ECGWaveSegmentation(window_size=128, step_size=64)
        samples = task(patient)
        self.assertGreater(len(samples), 0)
        self.assertEqual(samples[0]["signal"].shape, (128,))


@unittest.skipUnless(HAS_TEST_DATA, "LUDB test data not found")
class TestLUDBPreprocessDisabled(unittest.TestCase):
    """preprocess=False leaves processed_file blank."""

    def test_processed_file_empty_when_disabled(self):
        import shutil

        from pyhealth.datasets.ludb import LUDBDataset

        processed_dir = os.path.join(TEST_DATA_DIR, "processed")
        if os.path.isdir(processed_dir):
            shutil.rmtree(processed_dir)
        _force_regenerate_metadata()

        ds = LUDBDataset(root=TEST_DATA_DIR, dev=True, preprocess=False)
        first_pid = ds.unique_patient_ids[0]
        event = ds.get_patient(first_pid).get_events()[0]
        value = getattr(event, "processed_file", "") or ""
        self.assertEqual(str(value).strip(), "")
+ if os.path.isdir(processed_dir): + shutil.rmtree(processed_dir) + _force_regenerate_metadata() + ds_trim = LUDBDataset( + root=TEST_DATA_DIR, dev=True, preprocess=True, trim=True + ) + pid_trim = ds_trim.unique_patient_ids[0] + ev_trim = ds_trim.get_patient(pid_trim).get_events()[0] + with np.load(ev_trim.processed_file, allow_pickle=False) as npz: + trimmed_len = int(npz["signal"].shape[0]) + + # Rebuild with trim=False. + if os.path.isdir(processed_dir): + shutil.rmtree(processed_dir) + _force_regenerate_metadata() + ds_full = LUDBDataset( + root=TEST_DATA_DIR, dev=True, preprocess=True, trim=False + ) + pid_full = ds_full.unique_patient_ids[0] + ev_full = ds_full.get_patient(pid_full).get_events()[0] + with np.load(ev_full.processed_file, allow_pickle=False) as npz: + full_len = int(npz["signal"].shape[0]) + + self.assertGreaterEqual(full_len, trimmed_len) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/test_medtsllm.py b/tests/core/test_medtsllm.py new file mode 100644 index 000000000..1a59d4b28 --- /dev/null +++ b/tests/core/test_medtsllm.py @@ -0,0 +1,1358 @@ +"""Unit tests for MedTsLLM model and its preprocessing cache helper. 
+ +Author: Anton Barchukov + +Tests cover: + - Model instantiation with synthetic data (no LLM download) + - Forward pass output keys and shapes + - Backward pass (gradient flow) + - Loss computation when labels provided + - Parameter filtering (frozen params excluded from optimizer) + - State dict save/load with frozen params + - Internal layers (RevIN, PatchEmbedding, ReprogrammingLayer) + - Different sequence lengths, feature counts, covariate modes + - Edge cases: no labels, invalid inputs + - Compatibility with PyHealth pipeline + - ``_medtsllm_cache`` fingerprint + load_or_build helpers used by + LUDB / MIT-BIH / BIDMC preprocessing +""" + +import os +import tempfile +import unittest + +import torch +import numpy as np + + +def _make_dataset( + n_samples: int = 4, + seq_len: int = 128, + n_classes: int = 4, +): + """Create a minimal SampleDataset with signal + label.""" + from pyhealth.datasets import create_sample_dataset + + samples = [] + for i in range(n_samples): + signal = np.random.randn(seq_len).astype(np.float32) + label = np.random.randint(0, n_classes, size=seq_len).astype( + np.int64 + ) + samples.append({ + "patient_id": f"p{i}", + "visit_id": "v0", + "signal": signal, + "label": label, + }) + + return create_sample_dataset( + samples=samples, + input_schema={"signal": "tensor"}, + output_schema={"label": "tensor"}, + dataset_name="test_medtsllm", + ) + + +def _make_model(dataset, **kwargs): + """Create a MedTsLLM with synthetic word embeddings (no LLM).""" + from pyhealth.models.medtsllm import MedTsLLM + + defaults = dict( + dataset=dataset, + seq_len=128, + n_features=1, + n_classes=4, + backbone=None, + word_embeddings=torch.randn(100, 64), + d_model=16, + d_ff=32, + n_heads=4, + num_tokens=50, + patch_len=16, + stride=8, + dropout=0.1, + ) + defaults.update(kwargs) + return MedTsLLM(**defaults) + + +def _make_batch(dataset, batch_size=2): + """Get a single batch from the dataset.""" + from pyhealth.datasets import get_dataloader + + loader = 
# ------------------------------------------------------------------ #
# Forward pass tests
# ------------------------------------------------------------------ #


class TestMedTsLLMForward(unittest.TestCase):
    """Tests for MedTsLLM forward pass."""

    @classmethod
    def setUpClass(cls):
        cls.dataset = _make_dataset()
        cls.model = _make_model(cls.dataset)
        cls.batch = _make_batch(cls.dataset)

    def _forward(self):
        # One forward pass on the shared fixture batch.
        return self.model(**self.batch)

    def test_forward_keys(self):
        """Output dict has expected keys when labels provided."""
        out = self._forward()
        for key in ("logit", "y_prob", "loss", "y_true"):
            self.assertIn(key, out)

    def test_logit_shape(self):
        """Logit has shape (batch, seq_len, n_classes)."""
        out = self._forward()
        batch_size = self.batch["label"].shape[0]
        self.assertEqual(out["logit"].shape, (batch_size, 128, 4))

    def test_y_prob_shape(self):
        """y_prob has same shape as logit."""
        out = self._forward()
        self.assertEqual(out["y_prob"].shape, out["logit"].shape)

    def test_y_prob_sums_to_one(self):
        """Softmax probabilities sum to ~1 along class dim."""
        totals = self._forward()["y_prob"].sum(dim=-1)
        self.assertTrue(
            torch.allclose(totals, torch.ones_like(totals), atol=1e-5)
        )

    def test_y_prob_non_negative(self):
        """All probabilities are non-negative."""
        self.assertTrue((self._forward()["y_prob"] >= 0).all())

    def test_loss_is_scalar(self):
        """Loss is a scalar tensor."""
        self.assertEqual(self._forward()["loss"].shape, ())

    def test_loss_is_finite(self):
        """Loss is not NaN or Inf."""
        self.assertTrue(self._forward()["loss"].isfinite())

    def test_2d_signal_input(self):
        """Forward works with 2D signal (batch, seq_len) — auto-unsqueeze."""
        out = self.model(
            signal=torch.randn(2, 128),
            label=torch.randint(0, 4, (2, 128)),
        )
        self.assertEqual(out["logit"].shape, (2, 128, 4))

    def test_3d_signal_input(self):
        """Forward works with 3D signal (batch, seq_len, features)."""
        out = self.model(
            signal=torch.randn(2, 128, 1),
            label=torch.randint(0, 4, (2, 128)),
        )
        self.assertEqual(out["logit"].shape, (2, 128, 4))


# ------------------------------------------------------------------ #
# Backward pass tests
# ------------------------------------------------------------------ #


class TestMedTsLLMBackward(unittest.TestCase):
    """Tests for MedTsLLM backward pass."""

    @classmethod
    def setUpClass(cls):
        cls.dataset = _make_dataset()
        cls.model = _make_model(cls.dataset)
        cls.batch = _make_batch(cls.dataset)

    def test_backward(self):
        """Loss backward pass succeeds."""
        self.model(**self.batch)["loss"].backward()

    def test_gradients_flow(self):
        """Trainable parameters receive gradients."""
        self.model.zero_grad()
        self.model(**self.batch)["loss"].backward()
        received_grad = any(
            param.requires_grad
            and param.grad is not None
            and param.grad.abs().sum() > 0
            for param in self.model.parameters()
        )
        self.assertTrue(
            received_grad, "No gradients found in trainable params"
        )

    def test_word_embeddings_frozen(self):
        """Word embeddings should not receive gradients."""
        self.assertFalse(self.model.word_embeddings.requires_grad)

    def test_parameters_only_trainable(self):
        """parameters() yields only trainable params (no frozen)."""
        for param in self.model.parameters():
            self.assertTrue(
                param.requires_grad,
                "parameters() yielded a frozen parameter",
            )

    def test_named_parameters_only_trainable(self):
        """named_parameters() excludes frozen params."""
        for name, param in self.model.named_parameters():
            self.assertTrue(
                param.requires_grad,
                f"named_parameters() yielded frozen param: {name}",
            )

    def test_frozen_params_excluded_from_count(self):
        """Frozen word_embeddings not in parameters() count."""
        names = {name for name, _ in self.model.named_parameters()}
        self.assertNotIn("word_embeddings", names)
test_frozen_params_excluded_from_count(self): + """Frozen word_embeddings not in parameters() count.""" + param_names = {n for n, _ in self.model.named_parameters()} + self.assertNotIn("word_embeddings", param_names) + + +# ------------------------------------------------------------------ # +# State dict tests +# ------------------------------------------------------------------ # + + +class TestMedTsLLMStateDict(unittest.TestCase): + """Tests for checkpoint save/load with frozen params.""" + + @classmethod + def setUpClass(cls): + cls.dataset = _make_dataset() + cls.model = _make_model(cls.dataset) + + def test_state_dict_includes_all_params(self): + """state_dict() includes frozen params (needed for full reload).""" + sd = self.model.state_dict() + self.assertIn("word_embeddings", sd) + self.assertIn("patch_embedding.value_embedding.conv.weight", sd) + + def test_load_state_dict_roundtrip(self): + """Save and reload state dict without errors.""" + sd = self.model.state_dict() + model2 = _make_model(self.dataset) + model2.load_state_dict(sd) + + # Verify weights match + for (n1, p1), (n2, p2) in zip( + self.model.state_dict().items(), + model2.state_dict().items(), + ): + self.assertEqual(n1, n2) + self.assertTrue(torch.equal(p1, p2), f"Mismatch in {n1}") + + +# ------------------------------------------------------------------ # +# No labels tests +# ------------------------------------------------------------------ # + + +class TestMedTsLLMNoLabels(unittest.TestCase): + """Tests for forward pass without labels.""" + + @classmethod + def setUpClass(cls): + cls.dataset = _make_dataset() + cls.model = _make_model(cls.dataset) + + def test_no_loss_without_labels(self): + """No loss key when labels not in kwargs.""" + signal = torch.randn(2, 128) + out = self.model(signal=signal) + self.assertNotIn("loss", out) + self.assertNotIn("y_true", out) + self.assertIn("logit", out) + self.assertIn("y_prob", out) + + +# 
# ------------------------------------------------------------------ #
# Internal layers tests
# ------------------------------------------------------------------ #


class TestRevIN(unittest.TestCase):
    """Tests for RevIN normalization layer."""

    def test_normalize_denormalize_roundtrip(self):
        """RevIN(x, 'norm') -> RevIN(result, 'denorm') ≈ x."""
        from pyhealth.models._medtsllm.layers import RevIN

        layer = RevIN(num_features=3)
        original = torch.randn(2, 64, 3) * 10 + 5  # non-zero mean, large std
        roundtrip = layer(layer(original, "norm"), "denorm")
        self.assertTrue(
            torch.allclose(original, roundtrip, atol=1e-5),
            "RevIN roundtrip failed",
        )

    def test_normalize_zero_mean_unit_var(self):
        """After normalization, each feature has ~0 mean and ~1 std."""
        from pyhealth.models._medtsllm.layers import RevIN

        layer = RevIN(num_features=2)
        normed = layer(torch.randn(4, 100, 2) * 5 + 3, "norm")
        # Check per-sample, per-feature statistics along the time axis.
        means = normed.mean(dim=1)
        stds = normed.std(dim=1, unbiased=False)
        self.assertTrue(
            torch.allclose(means, torch.zeros_like(means), atol=1e-4)
        )
        self.assertTrue(
            torch.allclose(stds, torch.ones_like(stds), atol=0.1)
        )

    def test_denorm_before_norm_raises(self):
        """Calling denorm before norm raises RuntimeError."""
        from pyhealth.models._medtsllm.layers import RevIN

        layer = RevIN(num_features=1)
        with self.assertRaises(RuntimeError):
            layer(torch.randn(1, 10, 1), "denorm")

    def test_invalid_mode_raises(self):
        """Invalid mode raises ValueError."""
        from pyhealth.models._medtsllm.layers import RevIN

        layer = RevIN(num_features=1)
        with self.assertRaises(ValueError):
            layer(torch.randn(1, 10, 1), "invalid")


class TestPatchEmbedding(unittest.TestCase):
    """Tests for PatchEmbedding layer."""

    def test_output_shape(self):
        """Output has correct shape (batch*features, n_patches, d_model)."""
        from pyhealth.models._medtsllm.layers import PatchEmbedding

        embed = PatchEmbedding(d_model=32, patch_len=16, stride=8)
        # Input layout is (batch, features, seq_len).
        out, n_vars = embed(torch.randn(2, 1, 128))
        self.assertEqual(n_vars, 1)
        # n_patches = (128 - 16) // 8 + 2 = 16
        expected_patches = (128 - 16) // 8 + 2
        self.assertEqual(out.shape, (2, expected_patches, 32))

    def test_multivariate_output_shape(self):
        """Multivariate input merges batch and feature dims."""
        from pyhealth.models._medtsllm.layers import PatchEmbedding

        embed = PatchEmbedding(d_model=16, patch_len=8, stride=4)
        out, n_vars = embed(torch.randn(3, 5, 64))  # 3 batches, 5 features
        self.assertEqual(n_vars, 5)
        self.assertEqual(out.shape[0], 15)  # 3 * 5
        self.assertEqual(out.shape[2], 16)  # d_model

    def test_gradient_flow(self):
        """Gradients flow through patch embedding."""
        from pyhealth.models._medtsllm.layers import PatchEmbedding

        embed = PatchEmbedding(d_model=16, patch_len=8, stride=4)
        inp = torch.randn(2, 1, 32, requires_grad=True)
        out, _ = embed(inp)
        out.sum().backward()
        self.assertIsNotNone(inp.grad)


class TestReprogrammingLayer(unittest.TestCase):
    """Tests for ReprogrammingLayer (cross-attention)."""

    def test_output_shape(self):
        """Output shape matches (batch, n_patches, d_llm)."""
        from pyhealth.models._medtsllm.layers import ReprogrammingLayer

        layer = ReprogrammingLayer(
            d_model=32, n_heads=4, d_keys=16, d_llm=64
        )
        patches = torch.randn(2, 10, 32)  # patches
        prototypes = torch.randn(50, 64)  # word prototypes
        out = layer(patches, prototypes, prototypes)
        self.assertEqual(out.shape, (2, 10, 64))

    def test_gradient_flow(self):
        """Gradients flow through reprogramming layer."""
        from pyhealth.models._medtsllm.layers import ReprogrammingLayer

        layer = ReprogrammingLayer(
            d_model=16, n_heads=2, d_keys=8, d_llm=32
        )
        patches = torch.randn(2, 5, 16, requires_grad=True)
        prototypes = torch.randn(20, 32)
        layer(patches, prototypes, prototypes).sum().backward()
        self.assertIsNotNone(patches.grad)
class TestFlattenHead(unittest.TestCase):
    """Tests for FlattenHead output projection."""

    def test_output_shape(self):
        """Flattens and projects to correct output size."""
        from pyhealth.models._medtsllm.layers import FlattenHead

        head = FlattenHead(n_features_in=64, n_outputs=128)
        projected = head(torch.randn(2, 8, 8))  # (batch, d_ff, n_patches)
        self.assertEqual(projected.shape, (2, 128))


# ------------------------------------------------------------------ #
# Configuration tests
# ------------------------------------------------------------------ #


class TestMedTsLLMConfigs(unittest.TestCase):
    """Tests for different model configurations."""

    @classmethod
    def setUpClass(cls):
        cls.dataset = _make_dataset(seq_len=64, n_classes=2)

    def test_different_seq_len(self):
        """Model works with non-default sequence length."""
        model = _make_model(
            self.dataset,
            seq_len=64,
            n_classes=2,
            word_embeddings=torch.randn(50, 32),
        )
        out = model(
            signal=torch.randn(2, 64),
            label=torch.randint(0, 2, (2, 64)),
        )
        self.assertEqual(out["logit"].shape, (2, 64, 2))
        self.assertTrue(out["loss"].isfinite())

    def test_multivariate_concat(self):
        """Model works with concat covariate mode."""
        model = _make_model(
            _make_dataset(seq_len=64, n_classes=2),
            seq_len=64,
            n_features=3,
            n_classes=2,
            covariate_mode="concat",
            word_embeddings=torch.randn(50, 32),
        )
        out = model(
            signal=torch.randn(2, 64, 3),
            label=torch.randint(0, 2, (2, 64)),
        )
        self.assertEqual(out["logit"].shape, (2, 64, 2))
        self.assertTrue(out["loss"].isfinite())
        out["loss"].backward()

    def test_univariate_mode(self):
        """Model works with univariate covariate mode (default)."""
        model = _make_model(
            _make_dataset(seq_len=64, n_classes=3),
            seq_len=64,
            n_features=1,
            n_classes=3,
            covariate_mode="univariate",
            word_embeddings=torch.randn(50, 32),
        )
        out = model(
            signal=torch.randn(2, 64),
            label=torch.randint(0, 3, (2, 64)),
        )
        self.assertEqual(out["logit"].shape, (2, 64, 3))

    def test_small_d_model(self):
        """Model works with minimal dimensions."""
        model = _make_model(
            _make_dataset(seq_len=32, n_classes=2),
            seq_len=32,
            n_classes=2,
            d_model=8,
            d_ff=16,
            n_heads=2,
            num_tokens=10,
            word_embeddings=torch.randn(20, 16),
        )
        out = model(
            signal=torch.randn(1, 32),
            label=torch.randint(0, 2, (1, 32)),
        )
        self.assertTrue(out["loss"].isfinite())


# ------------------------------------------------------------------ #
# LinearProjection layer tests
# ------------------------------------------------------------------ #


class TestLinearProjection(unittest.TestCase):
    """Tests for LinearProjection ablation layer."""

    def test_output_shape(self):
        """Output shape matches (batch, n_patches, d_llm)."""
        from pyhealth.models._medtsllm.layers import LinearProjection

        layer = LinearProjection(d_model=32, d_llm=64)
        prototypes = torch.randn(50, 64)  # ignored by the layer
        out = layer(torch.randn(2, 10, 32), prototypes, prototypes)
        self.assertEqual(out.shape, (2, 10, 64))

    def test_ignores_source(self):
        """Output is the same regardless of source embeddings."""
        from pyhealth.models._medtsllm.layers import LinearProjection

        layer = LinearProjection(d_model=16, d_llm=32)
        patches = torch.randn(2, 5, 16)
        small = torch.randn(20, 32)
        large = torch.randn(100, 32)
        self.assertTrue(
            torch.equal(
                layer(patches, small, small),
                layer(patches, large, large),
            )
        )

    def test_gradient_flow(self):
        """Gradients flow through linear projection."""
        from pyhealth.models._medtsllm.layers import LinearProjection

        layer = LinearProjection(d_model=16, d_llm=32)
        patches = torch.randn(2, 5, 16, requires_grad=True)
        prototypes = torch.randn(20, 32)
        layer(patches, prototypes, prototypes).sum().backward()
        self.assertIsNotNone(patches.grad)
# ------------------------------------------------------------------ #
# Prompt builder tests
# ------------------------------------------------------------------ #


class TestPromptBuilder(unittest.TestCase):
    """Tests for build_prompt function."""

    def test_dataset_task_prompt(self):
        """Builds prompt with dataset and task descriptions."""
        from pyhealth.models._medtsllm.prompt import build_prompt

        prompts = build_prompt(
            {"x_enc": torch.randn(2, 128, 1)},
            dataset_description="Test dataset",
            task_description="Test task",
            include_dataset=True,
            include_task=True,
        )
        self.assertEqual(len(prompts), 2)
        joined = " ".join(prompts[0])
        self.assertIn("Test dataset", joined)
        self.assertIn("Test task", joined)

    def test_no_prompt(self):
        """Empty prompt when all flags are False."""
        from pyhealth.models._medtsllm.prompt import build_prompt

        prompts = build_prompt(
            {"x_enc": torch.randn(1, 64, 1)},
            include_dataset=False,
            include_task=False,
            include_clip=False,
            include_stats=False,
        )
        # Should still have "Time series:" at minimum
        self.assertIn("Time series:", " ".join(prompts[0]))

    def test_clip_prompt(self):
        """Per-sample descriptions included when clip=True."""
        from pyhealth.models._medtsllm.prompt import build_prompt

        prompts = build_prompt(
            {
                "x_enc": torch.randn(2, 64, 1),
                "descriptions": ["age: 51, sex: F", "age: 64, sex: M"],
            },
            include_dataset=False,
            include_task=False,
            include_clip=True,
        )
        self.assertIn("age: 51", " ".join(prompts[0]))
        self.assertIn("age: 64", " ".join(prompts[1]))

    def test_batch_size_matches(self):
        """Number of prompts matches batch size."""
        from pyhealth.models._medtsllm.prompt import build_prompt

        prompts = build_prompt({"x_enc": torch.randn(5, 32, 1)})
        self.assertEqual(len(prompts), 5)

    def test_stats_prompt_included(self):
        """Stats prompt appears when include_stats=True."""
        from pyhealth.models._medtsllm.prompt import build_prompt

        prompts = build_prompt(
            {"x_enc": torch.randn(2, 64, 1)},
            include_dataset=False,
            include_task=False,
            include_stats=True,
        )
        joined = " ".join(prompts[0])
        for fragment in (
            "Input statistics",
            "min value",
            "max value",
            "median value",
            "trend",
            "lags",
        ):
            self.assertIn(fragment, joined)

    def test_stats_prompt_omitted_by_default(self):
        """Stats prompt absent when include_stats=False."""
        from pyhealth.models._medtsllm.prompt import build_prompt

        prompts = build_prompt(
            {"x_enc": torch.randn(2, 64, 1)},
            include_dataset=False,
            include_task=False,
            include_stats=False,
        )
        self.assertNotIn("Input statistics", " ".join(prompts[0]))


# ------------------------------------------------------------------ #
# compute_lags tests
# ------------------------------------------------------------------ #


class TestComputeLags(unittest.TestCase):
    """Tests for the FFT-based autocorrelation lag helper."""

    def test_shape_univariate(self):
        """Shape is (batch, n_lags) for univariate input."""
        from pyhealth.models._medtsllm.prompt import compute_lags

        lags = compute_lags(torch.randn(4, 128), n_lags=5)
        self.assertEqual(lags.shape, (4, 5))

    def test_shape_multivariate(self):
        """Shape is (batch, n_lags) for 3D input."""
        from pyhealth.models._medtsllm.prompt import compute_lags

        lags = compute_lags(torch.randn(4, 128, 3), n_lags=3)
        self.assertEqual(lags.shape, (4, 3))

    def test_lags_are_long(self):
        """Lag indices are integer-valued."""
        from pyhealth.models._medtsllm.prompt import compute_lags

        lags = compute_lags(torch.randn(2, 64, 1), n_lags=5)
        self.assertEqual(lags.dtype, torch.int64)

    def test_lags_in_range(self):
        """Lag indices fall within sequence length."""
        from pyhealth.models._medtsllm.prompt import compute_lags

        seq_len = 128
        lags = compute_lags(torch.randn(2, seq_len), n_lags=5)
        self.assertTrue((lags >= 0).all())
        self.assertTrue((lags < seq_len).all())
# ------------------------------------------------------------------ #
# Task parameter + default task_description
# ------------------------------------------------------------------ #


class TestTaskParam(unittest.TestCase):
    """Tests for the ``task`` constructor parameter."""

    @classmethod
    def setUpClass(cls):
        cls.dataset = _make_dataset()

    def test_default_task_is_semantic_segmentation(self):
        """Default task is semantic_segmentation."""
        self.assertEqual(
            _make_model(self.dataset).task, "semantic_segmentation"
        )

    def test_invalid_task_raises(self):
        """Unknown task string raises ValueError."""
        with self.assertRaises(ValueError):
            _make_model(self.dataset, task="not_a_real_task")

    def test_semseg_description(self):
        """semantic_segmentation task_description says 'Classify'."""
        model = _make_model(self.dataset, task="semantic_segmentation")
        self.assertIn("Classify", model.task_description)

    def test_segmentation_description(self):
        """segmentation task_description says 'change points'."""
        model = _make_model(self.dataset, task="segmentation")
        self.assertIn("change points", model.task_description)

    def test_anomaly_description(self):
        """anomaly_detection task_description says 'Reconstruct'."""
        model = _make_model(self.dataset, task="anomaly_detection")
        self.assertIn("Reconstruct", model.task_description)

    def test_reconstruction_description(self):
        """reconstruction task_description says 'Reconstruct'."""
        model = _make_model(self.dataset, task="reconstruction")
        self.assertIn("Reconstruct", model.task_description)

    def test_forecasting_description(self):
        """forecasting task_description says 'Forecast'."""
        model = _make_model(self.dataset, task="forecasting")
        self.assertIn("Forecast", model.task_description)

    def test_explicit_task_description_wins(self):
        """Explicit task_description is not overwritten by default."""
        model = _make_model(
            self.dataset,
            task="semantic_segmentation",
            task_description="custom description here",
        )
        self.assertEqual(
            model.task_description, "custom description here"
        )


# ------------------------------------------------------------------ #
# Prompt config knobs
# ------------------------------------------------------------------ #


class TestPromptConfigKnobs(unittest.TestCase):
    """All four prompt knobs are exposed and default True."""

    @classmethod
    def setUpClass(cls):
        cls.dataset = _make_dataset()

    def test_four_knobs_present(self):
        """prompt_config has dataset/task/patient/stats keys."""
        config = _make_model(self.dataset).prompt_config
        for knob in ("dataset", "task", "patient", "stats"):
            self.assertIn(knob, config)

    def test_prompt_defaults(self):
        """dataset/task/patient default True; stats defaults False to
        match the cs598-pyhealth dtp reference config."""
        config = _make_model(self.dataset).prompt_config
        self.assertTrue(config["dataset"])
        self.assertTrue(config["task"])
        self.assertTrue(config["patient"])
        self.assertFalse(config["stats"])

    def test_individual_toggles(self):
        """Each knob can be toggled independently."""
        config = _make_model(
            self.dataset,
            prompt_dataset=False,
            prompt_task=True,
            prompt_patient=False,
            prompt_stats=False,
        ).prompt_config
        self.assertFalse(config["dataset"])
        self.assertTrue(config["task"])
        self.assertFalse(config["patient"])
        self.assertFalse(config["stats"])
------------------------------------------------------------------ # + + +class TestReprogrammingOverride(unittest.TestCase): + """Override ReprogrammingLayer with LinearProjection.""" + + @classmethod + def setUpClass(cls): + cls.dataset = _make_dataset() + + def test_linear_projection_override(self): + """LinearProjection slots into the model.""" + from pyhealth.models._medtsllm.layers import LinearProjection + + # d_model=16, d_llm=64 in _make_model defaults + linear = LinearProjection(d_model=16, d_llm=64) + model = _make_model(self.dataset, reprogramming_layer=linear) + self.assertIs(model.reprogramming_layer, linear) + + def test_forward_with_linear_projection(self): + """Forward pass works with LinearProjection ablation.""" + from pyhealth.models._medtsllm.layers import LinearProjection + + linear = LinearProjection(d_model=16, d_llm=64) + model = _make_model(self.dataset, reprogramming_layer=linear) + batch = _make_batch(self.dataset) + out = model(**batch) + self.assertIn("logit", out) + self.assertTrue(out["loss"].isfinite()) + + +# ------------------------------------------------------------------ # +# Description coercion +# ------------------------------------------------------------------ # + + +class TestDescriptionCoercion(unittest.TestCase): + """_coerce_descriptions handles list, tuple, scalar-string.""" + + def test_list_passthrough(self): + from pyhealth.models.medtsllm import _coerce_descriptions + + desc = ["a", "b"] + self.assertEqual(_coerce_descriptions(desc, bs=2), ["a", "b"]) + + def test_tuple_coerced(self): + from pyhealth.models.medtsllm import _coerce_descriptions + + desc = ("a", "b", "c") + self.assertEqual( + _coerce_descriptions(desc, bs=3), ["a", "b", "c"] + ) + + def test_scalar_string_broadcast(self): + """Single string broadcasts to batch size, not list-of-chars.""" + from pyhealth.models.medtsllm import _coerce_descriptions + + desc = "hello" + out = _coerce_descriptions(desc, bs=3) + self.assertEqual(out, ["hello", "hello", 
"hello"]) + + +# ------------------------------------------------------------------ # +# Patient prompt integration test +# ------------------------------------------------------------------ # + + +class TestMedTsLLMPatientPrompt(unittest.TestCase): + """Tests for patient prompt path in model forward.""" + + @classmethod + def setUpClass(cls): + cls.dataset = _make_dataset() + # Model with prompt_patient=True but no LLM + # (synthetic replacement skips prompting, so we test the config) + cls.model = _make_model(cls.dataset, prompt_patient=True) + + def test_prompt_config_has_patient(self): + """prompt_config includes patient key.""" + self.assertIn("patient", self.model.prompt_config) + self.assertTrue(self.model.prompt_config["patient"]) + + def test_forward_with_description_kwarg(self): + """Forward pass works when description is in kwargs.""" + signal = torch.randn(2, 128) + label = torch.randint(0, 4, (2, 128)) + out = self.model( + signal=signal, + label=label, + description=["age: 51, sex: F", "age: 64, sex: M"], + ) + self.assertIn("logit", out) + self.assertTrue(out["loss"].isfinite()) + + def test_forward_without_description(self): + """Forward pass works even without description kwarg.""" + signal = torch.randn(2, 128) + label = torch.randint(0, 4, (2, 128)) + out = self.model(signal=signal, label=label) + self.assertIn("logit", out) + self.assertTrue(out["loss"].isfinite()) + + +# ------------------------------------------------------------------ # +# Task branching in forward (Phase 3) +# ------------------------------------------------------------------ # + + +def _make_binary_dataset( + n_samples: int = 4, + seq_len: int = 128, + n_features: int = 1, +): + """Dataset with binary float labels (boundary detection style).""" + from pyhealth.datasets import create_sample_dataset + + samples = [] + for i in range(n_samples): + if n_features == 1: + signal = np.random.randn(seq_len).astype(np.float32) + else: + signal = np.random.randn(seq_len, 
n_features).astype( + np.float32 + ) + label = np.random.randint(0, 2, size=seq_len).astype(np.float32) + samples.append({ + "patient_id": f"p{i}", + "visit_id": "v0", + "signal": signal, + "label": label, + }) + + return create_sample_dataset( + samples=samples, + input_schema={"signal": "tensor"}, + output_schema={"label": "tensor"}, + dataset_name="test_medtsllm_binary", + ) + + +def _make_unlabeled_dataset( + n_samples: int = 4, + seq_len: int = 128, + n_features: int = 2, +): + """Multivariate dataset without labels (reconstruction style).""" + from pyhealth.datasets import create_sample_dataset + + samples = [] + for i in range(n_samples): + signal = np.random.randn(seq_len, n_features).astype(np.float32) + samples.append({ + "patient_id": f"p{i}", + "visit_id": "v0", + "signal": signal, + }) + + return create_sample_dataset( + samples=samples, + input_schema={"signal": "tensor"}, + output_schema={}, + dataset_name="test_medtsllm_unlabeled", + ) + + +class TestTaskOutputShapes(unittest.TestCase): + """Output head shape and n_outputs_per_step adapt to task.""" + + def test_semseg_n_outputs_per_step(self): + """semantic_segmentation => n_outputs_per_step == n_classes.""" + dataset = _make_dataset(n_classes=4) + model = _make_model( + dataset, task="semantic_segmentation", n_classes=4 + ) + self.assertEqual(model.n_outputs_per_step, 4) + + def test_segmentation_n_outputs_per_step(self): + """segmentation => n_outputs_per_step == 1 (binary logit).""" + dataset = _make_binary_dataset() + model = _make_model(dataset, task="segmentation") + self.assertEqual(model.n_outputs_per_step, 1) + + def test_anomaly_n_outputs_per_step(self): + """anomaly_detection => n_outputs_per_step == n_features.""" + dataset = _make_unlabeled_dataset(n_features=2) + model = _make_model( + dataset, + task="anomaly_detection", + n_features=2, + covariate_mode="concat", + ) + self.assertEqual(model.n_outputs_per_step, 2) + + def test_reconstruction_n_outputs_per_step(self): + 
"""reconstruction => n_outputs_per_step == n_features.""" + dataset = _make_unlabeled_dataset(n_features=3) + model = _make_model( + dataset, + task="reconstruction", + n_features=3, + covariate_mode="concat", + ) + self.assertEqual(model.n_outputs_per_step, 3) + + def test_semseg_logit_shape(self): + """semantic_segmentation logit is (bs, pred_len, n_classes).""" + dataset = _make_dataset(n_classes=4) + model = _make_model( + dataset, task="semantic_segmentation", n_classes=4 + ) + batch = _make_batch(dataset) + out = model(**batch) + bs = batch["label"].shape[0] + self.assertEqual(out["logit"].shape, (bs, 128, 4)) + + def test_segmentation_logit_shape(self): + """segmentation logit is (bs, pred_len) after squeeze.""" + dataset = _make_binary_dataset() + model = _make_model(dataset, task="segmentation") + batch = _make_batch(dataset) + out = model(**batch) + bs = batch["label"].shape[0] + self.assertEqual(out["logit"].shape, (bs, 128)) + + def test_anomaly_reconstruction_shape(self): + """anomaly_detection prediction is (bs, pred_len, n_features).""" + dataset = _make_unlabeled_dataset(n_features=2) + model = _make_model( + dataset, + task="anomaly_detection", + n_features=2, + covariate_mode="concat", + ) + batch = _make_batch(dataset) + out = model(**batch) + bs = batch["signal"].shape[0] + self.assertEqual(out["logit"].shape, (bs, 128, 2)) + + +class TestTaskLosses(unittest.TestCase): + """Task-specific loss computation.""" + + def test_semseg_uses_cross_entropy(self): + """semseg loss is finite and bounded for random inputs.""" + dataset = _make_dataset(n_classes=4) + model = _make_model( + dataset, task="semantic_segmentation", n_classes=4 + ) + batch = _make_batch(dataset) + out = model(**batch) + self.assertIn("loss", out) + self.assertTrue(out["loss"].isfinite()) + + def test_segmentation_uses_bce(self): + """segmentation loss is finite under BCE-with-logits.""" + dataset = _make_binary_dataset() + model = _make_model(dataset, task="segmentation") + batch 
= _make_batch(dataset) + out = model(**batch) + self.assertIn("loss", out) + self.assertTrue(out["loss"].isfinite()) + + def test_anomaly_uses_mse_no_label(self): + """anomaly_detection computes MSE against signal, no label needed.""" + dataset = _make_unlabeled_dataset(n_features=2) + model = _make_model( + dataset, + task="anomaly_detection", + n_features=2, + covariate_mode="concat", + ) + batch = _make_batch(dataset) + out = model(**batch) + self.assertIn("loss", out) + self.assertTrue(out["loss"].isfinite()) + + def test_reconstruction_uses_mse(self): + """reconstruction computes MSE against signal.""" + dataset = _make_unlabeled_dataset(n_features=2) + model = _make_model( + dataset, + task="reconstruction", + n_features=2, + covariate_mode="concat", + ) + batch = _make_batch(dataset) + out = model(**batch) + self.assertIn("loss", out) + self.assertTrue(out["loss"].isfinite()) + + +class TestTaskGradients(unittest.TestCase): + """Backward pass works across all task types.""" + + def _backward_flows(self, model, batch): + out = model(**batch) + loss = out["loss"] + loss.backward() + has_grad = any( + p.grad is not None and p.grad.abs().sum() > 0 + for p in model.parameters() + if p.requires_grad + ) + return has_grad + + def test_semseg_backward(self): + dataset = _make_dataset(n_classes=4) + model = _make_model( + dataset, task="semantic_segmentation", n_classes=4 + ) + self.assertTrue( + self._backward_flows(model, _make_batch(dataset)) + ) + + def test_segmentation_backward(self): + dataset = _make_binary_dataset() + model = _make_model(dataset, task="segmentation") + self.assertTrue( + self._backward_flows(model, _make_batch(dataset)) + ) + + def test_anomaly_backward(self): + dataset = _make_unlabeled_dataset(n_features=2) + model = _make_model( + dataset, + task="anomaly_detection", + n_features=2, + covariate_mode="concat", + ) + self.assertTrue( + self._backward_flows(model, _make_batch(dataset)) + ) + + +# 
------------------------------------------------------------------ # +# Synthetic example end-to-end test +# ------------------------------------------------------------------ # + + +class TestSyntheticEndToEnd(unittest.TestCase): + """End-to-end test with synthetic data through PyHealth pipeline.""" + + def test_full_pipeline(self): + """Create dataset -> split -> model -> train step -> eval.""" + from pyhealth.datasets import ( + create_sample_dataset, + get_dataloader, + split_by_patient, + ) + from pyhealth.models.medtsllm import MedTsLLM + + # Create synthetic samples + samples = [] + for i in range(6): + samples.append({ + "patient_id": f"p{i}", + "visit_id": "v0", + "signal": np.random.randn(128).astype(np.float32), + "label": np.random.randint(0, 4, 128).astype(np.int64), + }) + + dataset = create_sample_dataset( + samples=samples, + input_schema={"signal": "tensor"}, + output_schema={"label": "tensor"}, + dataset_name="e2e_test", + ) + + # Split + train_ds, _, test_ds = split_by_patient( + dataset, ratios=[0.7, 0.0, 0.3] + ) + + # Model + model = MedTsLLM( + dataset=dataset, + seq_len=128, + n_classes=4, + backbone=None, + word_embeddings=torch.randn(50, 32), + d_model=8, + d_ff=16, + n_heads=2, + num_tokens=10, + ) + + # Train one step + loader = get_dataloader(train_ds, batch_size=2, shuffle=True) + batch = next(iter(loader)) + out = model(**batch) + self.assertIn("loss", out) + self.assertTrue(out["loss"].isfinite()) + out["loss"].backward() + + # Eval + model.eval() + test_loader = get_dataloader(test_ds, batch_size=2, shuffle=False) + with torch.no_grad(): + for batch in test_loader: + out = model(**batch) + self.assertIn("y_prob", out) + break + + +# ------------------------------------------------------------------ # +# Error handling tests +# ------------------------------------------------------------------ # + + +class TestMedTsLLMInvalidInputs(unittest.TestCase): + """Tests for error handling.""" + + def test_no_backbone_or_embeddings(self): + 
"""Raises ValueError when neither backbone nor embeddings given.""" + from pyhealth.models.medtsllm import MedTsLLM + + dataset = _make_dataset() + with self.assertRaises(ValueError): + MedTsLLM( + dataset=dataset, + backbone=None, + word_embeddings=None, + ) + + +# ------------------------------------------------------------------ # +# Preprocessing cache (_medtsllm_cache) — shared by LUDB / MIT-BIH / +# BIDMC loaders. Lives here rather than a separate file because it +# only exists to serve the MedTsLLM port. +# ------------------------------------------------------------------ # + + +class TestComputeFingerprint(unittest.TestCase): + """Fingerprint is a stable hash of (raw file stats, params).""" + + def test_same_inputs_same_fingerprint(self): + from pyhealth.datasets._medtsllm_cache import compute_fingerprint + + with tempfile.TemporaryDirectory() as tmp: + raw = os.path.join(tmp, "a.dat") + open(raw, "w").write("hi") + fp1 = compute_fingerprint([raw], {"trim": True}) + fp2 = compute_fingerprint([raw], {"trim": True}) + self.assertEqual(fp1, fp2) + + def test_different_params_different_fingerprint(self): + from pyhealth.datasets._medtsllm_cache import compute_fingerprint + + with tempfile.TemporaryDirectory() as tmp: + raw = os.path.join(tmp, "a.dat") + open(raw, "w").write("hi") + fp1 = compute_fingerprint([raw], {"trim": True}) + fp2 = compute_fingerprint([raw], {"trim": False}) + self.assertNotEqual(fp1, fp2) + + def test_changed_file_changes_fingerprint(self): + from pyhealth.datasets._medtsllm_cache import compute_fingerprint + + with tempfile.TemporaryDirectory() as tmp: + raw = os.path.join(tmp, "a.dat") + open(raw, "w").write("hi") + fp1 = compute_fingerprint([raw], {}) + open(raw, "w").write("hello world now longer") + fp2 = compute_fingerprint([raw], {}) + self.assertNotEqual(fp1, fp2) + + def test_returns_string(self): + from pyhealth.datasets._medtsllm_cache import compute_fingerprint + + with tempfile.TemporaryDirectory() as tmp: + raw = 
os.path.join(tmp, "a.dat") + open(raw, "w").write("hi") + fp = compute_fingerprint([raw], {}) + self.assertIsInstance(fp, str) + self.assertGreater(len(fp), 16) + + +class TestLoadOrBuild(unittest.TestCase): + """load_or_build skips the builder when the cache is warm.""" + + def test_first_call_invokes_builder(self): + from pyhealth.datasets._medtsllm_cache import load_or_build + + with tempfile.TemporaryDirectory() as tmp: + cache = os.path.join(tmp, "c.npz") + calls = {"n": 0} + + def builder(): + calls["n"] += 1 + return {"x": np.arange(5, dtype=np.int64)} + + result = load_or_build(cache, "fp1", builder) + self.assertEqual(calls["n"], 1) + np.testing.assert_array_equal(result["x"], np.arange(5)) + self.assertTrue(os.path.exists(cache)) + + def test_second_call_skips_builder(self): + from pyhealth.datasets._medtsllm_cache import load_or_build + + with tempfile.TemporaryDirectory() as tmp: + cache = os.path.join(tmp, "c.npz") + calls = {"n": 0} + + def builder(): + calls["n"] += 1 + return {"x": np.arange(5, dtype=np.int64)} + + load_or_build(cache, "fp1", builder) + load_or_build(cache, "fp1", builder) + self.assertEqual(calls["n"], 1) + + def test_fingerprint_mismatch_rebuilds(self): + from pyhealth.datasets._medtsllm_cache import load_or_build + + with tempfile.TemporaryDirectory() as tmp: + cache = os.path.join(tmp, "c.npz") + calls = {"n": 0} + + def builder(): + calls["n"] += 1 + return {"x": np.full(3, calls["n"], dtype=np.int64)} + + load_or_build(cache, "fp-old", builder) + second = load_or_build(cache, "fp-new", builder) + self.assertEqual(calls["n"], 2) + np.testing.assert_array_equal(second["x"], np.full(3, 2)) + + def test_creates_parent_dirs(self): + from pyhealth.datasets._medtsllm_cache import load_or_build + + with tempfile.TemporaryDirectory() as tmp: + cache = os.path.join(tmp, "nested", "dir", "c.npz") + + def builder(): + return {"x": np.zeros(1, dtype=np.int64)} + + load_or_build(cache, "fp", builder) + self.assertTrue(os.path.exists(cache)) 
+ + def test_preserves_string_arrays(self): + """Cache round-trips unicode arrays (wfdb annotation symbols).""" + from pyhealth.datasets._medtsllm_cache import load_or_build + + with tempfile.TemporaryDirectory() as tmp: + cache = os.path.join(tmp, "c.npz") + + def builder(): + return { + "signal": np.zeros((4, 2), dtype=np.float32), + "symbols": np.array(["N", "V", "L", "A"]), + } + + load_or_build(cache, "fp", builder) + + # Second call hits cache + result = load_or_build(cache, "fp", lambda: None) # type: ignore[arg-type] + self.assertEqual(list(result["symbols"]), ["N", "V", "L", "A"]) + + def test_corrupt_cache_rebuilds(self): + from pyhealth.datasets._medtsllm_cache import load_or_build + + with tempfile.TemporaryDirectory() as tmp: + cache = os.path.join(tmp, "c.npz") + with open(cache, "w") as f: + f.write("not a real npz") + + calls = {"n": 0} + + def builder(): + calls["n"] += 1 + return {"x": np.arange(2, dtype=np.int64)} + + result = load_or_build(cache, "fp", builder) + self.assertEqual(calls["n"], 1) + np.testing.assert_array_equal(result["x"], np.arange(2)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/test_mitbih.py b/tests/core/test_mitbih.py new file mode 100644 index 000000000..4413ec893 --- /dev/null +++ b/tests/core/test_mitbih.py @@ -0,0 +1,668 @@ +"""Unit tests for MIT-BIH dataset and ECGBoundaryDetection task. 
+ +Author: Anton Barchukov + +Tests cover: + - MITBIHDataset instantiation from real wfdb files + - Metadata preparation (paced record exclusion) + - Patient header parsing (age, sex, medications) + - ECGBoundaryDetection task (windowing, downsampling, labels) + - Full pipeline: dataset -> task -> samples + - MedTsLLM model with real MIT-BIH data structure +""" + +import os +import unittest + +import numpy as np +import torch + +TEST_DATA_DIR = os.path.join( + os.path.dirname(__file__), "..", "..", "test-resources", "core", "mitbih" +) +HAS_TEST_DATA = os.path.isdir(TEST_DATA_DIR) and any( + f.endswith(".dat") for f in os.listdir(TEST_DATA_DIR) +) + + +def setUpModule(): + """Drop the committed CSV + cache so tests use the current schema.""" + if HAS_TEST_DATA: + _force_regenerate_mitbih_metadata() + + +@unittest.skipUnless(HAS_TEST_DATA, "MIT-BIH test data not found") +class TestMITBIHDataset(unittest.TestCase): + """Tests for MITBIHDataset with real wfdb files.""" + + @classmethod + def setUpClass(cls): + from pyhealth.datasets.mitbih import MITBIHDataset + + cls.dataset = MITBIHDataset(root=TEST_DATA_DIR, dev=True) + + def test_dataset_loads(self): + """Dataset initializes without error.""" + self.assertIsNotNone(self.dataset) + + def test_metadata_csv_created(self): + """prepare_metadata creates mitbih-pyhealth.csv.""" + csv_path = os.path.join(TEST_DATA_DIR, "mitbih-pyhealth.csv") + self.assertTrue(os.path.exists(csv_path)) + + def test_has_patients(self): + """Dataset has at least 1 patient.""" + pids = self.dataset.unique_patient_ids + self.assertGreater(len(pids), 0) + + def test_paced_records_excluded(self): + """Paced records (102, 104, 107, 217) are not in dataset.""" + from pyhealth.datasets.mitbih import _PACED_RECORDS + + pids = set(self.dataset.unique_patient_ids) + for paced in _PACED_RECORDS: + self.assertNotIn(paced, pids) + + def test_patient_has_events(self): + """Each patient has ECG events.""" + pid = self.dataset.unique_patient_ids[0] + 
patient = self.dataset.get_patient(pid) + events = patient.get_events() + self.assertGreater(len(events), 0) + + def test_event_has_demographics(self): + """Events contain age, sex, medications.""" + pid = self.dataset.unique_patient_ids[0] + patient = self.dataset.get_patient(pid) + event = patient.get_events()[0] + self.assertTrue(hasattr(event, "age")) + self.assertTrue(hasattr(event, "sex")) + self.assertTrue(hasattr(event, "medications")) + + def test_event_has_abnormal_count(self): + """Events track number of abnormal beats.""" + pid = self.dataset.unique_patient_ids[0] + patient = self.dataset.get_patient(pid) + event = patient.get_events()[0] + self.assertTrue(hasattr(event, "n_abnormal")) + + +@unittest.skipUnless(HAS_TEST_DATA, "MIT-BIH test data not found") +class TestMITBIHHeaderParsing(unittest.TestCase): + """Tests for MIT-BIH patient metadata parsing.""" + + def test_parse_header(self): + """Header contains age and sex info.""" + import wfdb + + rec = wfdb.rdrecord(os.path.join(TEST_DATA_DIR, "100")) + self.assertIsNotNone(rec.comments) + self.assertGreater(len(rec.comments), 0) + # First comment line should have age and sex + first = rec.comments[0].strip() + tokens = first.split() + self.assertGreaterEqual(len(tokens), 2) + + +@unittest.skipUnless(HAS_TEST_DATA, "MIT-BIH test data not found") +class TestECGBoundaryDetectionTask(unittest.TestCase): + """Tests for ECGBoundaryDetection with real MIT-BIH data.""" + + @classmethod + def setUpClass(cls): + from pyhealth.datasets.mitbih import MITBIHDataset + from pyhealth.tasks.ecg_boundary_detection import ( + ECGBoundaryDetection, + ) + + cls.dataset = MITBIHDataset(root=TEST_DATA_DIR, dev=True) + cls.task = ECGBoundaryDetection(window_size=256, step_size=256) + + def test_task_attributes(self): + """Task has correct name and schema.""" + self.assertEqual(self.task.task_name, "ECGBoundaryDetection") + self.assertEqual(self.task.input_schema, {"signal": "tensor"}) + 
self.assertEqual(self.task.output_schema, {"label": "tensor"}) + + def test_task_produces_samples(self): + """Task generates samples from a patient.""" + pid = self.dataset.unique_patient_ids[0] + patient = self.dataset.get_patient(pid) + samples = self.task(patient) + self.assertGreater(len(samples), 0) + + def test_signal_is_2_channel(self): + """Signal has 2 channels (MLII, V1 or similar).""" + pid = self.dataset.unique_patient_ids[0] + patient = self.dataset.get_patient(pid) + sample = self.task(patient)[0] + self.assertEqual(sample["signal"].shape, (256, 2)) + + def test_label_is_binary(self): + """Labels are 0 or 1 (R-peak or not).""" + pid = self.dataset.unique_patient_ids[0] + patient = self.dataset.get_patient(pid) + sample = self.task(patient)[0] + unique = set(sample["label"].tolist()) + self.assertTrue(unique.issubset({0.0, 1.0})) + + def test_sample_has_description(self): + """Samples include patient description with medications.""" + pid = self.dataset.unique_patient_ids[0] + patient = self.dataset.get_patient(pid) + sample = self.task(patient)[0] + self.assertIn("description", sample) + self.assertIn("age:", sample["description"]) + + def test_set_task_produces_sample_dataset(self): + """dataset.set_task() returns a usable SampleDataset.""" + sample_ds = self.dataset.set_task(self.task) + self.assertGreater(len(sample_ds), 0) + + +# ------------------------------------------------------------------ # +# Phase 4: anomaly-detection task +# ------------------------------------------------------------------ # + + +@unittest.skipUnless(HAS_TEST_DATA, "MIT-BIH test data not found") +class TestECGAnomalyDetectionTask(unittest.TestCase): + """Tests for ECGAnomalyDetection with real MIT-BIH data.""" + + @classmethod + def setUpClass(cls): + from pyhealth.datasets.mitbih import MITBIHDataset + from pyhealth.tasks.ecg_anomaly_detection import ( + ECGAnomalyDetection, + ) + + cls.dataset = MITBIHDataset(root=TEST_DATA_DIR, dev=True) + cls.task = 
ECGAnomalyDetection(window_size=128, step_size=128) + + def test_task_attributes(self): + """Task has correct name and schema.""" + self.assertEqual(self.task.task_name, "ECGAnomalyDetection") + self.assertEqual(self.task.input_schema, {"signal": "tensor"}) + self.assertEqual(self.task.output_schema, {"label": "tensor"}) + + def test_task_produces_samples(self): + """Task generates samples from a patient.""" + pid = self.dataset.unique_patient_ids[0] + patient = self.dataset.get_patient(pid) + samples = self.task(patient) + self.assertGreater(len(samples), 0) + + def test_signal_is_2_channel(self): + """Signal has 2 channels (MLII, V1 or similar).""" + pid = self.dataset.unique_patient_ids[0] + patient = self.dataset.get_patient(pid) + sample = self.task(patient)[0] + self.assertEqual(sample["signal"].shape, (128, 2)) + + def test_label_is_binary_mask(self): + """Labels are a per-timestep 0/1 anomaly mask.""" + pid = self.dataset.unique_patient_ids[0] + patient = self.dataset.get_patient(pid) + sample = self.task(patient)[0] + self.assertEqual(sample["label"].shape, (128,)) + unique = set(sample["label"].tolist()) + self.assertTrue(unique.issubset({0.0, 1.0})) + + def test_sample_has_description(self): + """Samples include per-patient description.""" + pid = self.dataset.unique_patient_ids[0] + patient = self.dataset.get_patient(pid) + sample = self.task(patient)[0] + self.assertIn("description", sample) + + def test_set_task_produces_sample_dataset(self): + """dataset.set_task() works with anomaly detection task.""" + sample_ds = self.dataset.set_task(self.task) + self.assertGreater(len(sample_ds), 0) + + +@unittest.skipUnless(HAS_TEST_DATA, "MIT-BIH test data not found") +class TestAnomalyTaskExportedFromPackage(unittest.TestCase): + """ECGAnomalyDetection is importable from pyhealth.tasks.""" + + def test_import_from_tasks(self): + from pyhealth.tasks import ECGAnomalyDetection # noqa: F401 + + +@unittest.skipUnless(HAS_TEST_DATA, "MIT-BIH test data not found") 
+class TestMedTsLLMAnomalyOnMITBIH(unittest.TestCase): + """End-to-end anomaly-detection training step on MIT-BIH.""" + + @classmethod + def setUpClass(cls): + from pyhealth.datasets import get_dataloader + from pyhealth.datasets.mitbih import MITBIHDataset + from pyhealth.models.medtsllm import MedTsLLM + from pyhealth.tasks.ecg_anomaly_detection import ( + ECGAnomalyDetection, + ) + + dataset = MITBIHDataset(root=TEST_DATA_DIR, dev=True) + task = ECGAnomalyDetection(window_size=128, step_size=128) + cls.sample_ds = dataset.set_task(task) + + cls.model = MedTsLLM( + dataset=cls.sample_ds, + task="anomaly_detection", + seq_len=128, + n_features=2, + backbone=None, + word_embeddings=torch.randn(50, 32), + d_model=16, + d_ff=32, + n_heads=4, + num_tokens=50, + covariate_mode="concat", + ) + loader = get_dataloader(cls.sample_ds, batch_size=2, shuffle=False) + cls.batch = next(iter(loader)) + + def test_forward(self): + """Anomaly-detection forward returns finite MSE loss.""" + out = self.model(**self.batch) + self.assertIn("logit", out) + self.assertIn("loss", out) + self.assertTrue(out["loss"].isfinite()) + + def test_prediction_shape(self): + """Reconstruction shape matches (bs, seq_len, n_features).""" + out = self.model(**self.batch) + bs = self.batch["signal"].shape[0] + self.assertEqual(out["logit"].shape, (bs, 128, 2)) + + def test_backward(self): + """Backward pass works for anomaly detection on MIT-BIH.""" + out = self.model(**self.batch) + out["loss"].backward() + + +@unittest.skipUnless(HAS_TEST_DATA, "MIT-BIH test data not found") +class TestMedTsLLMWithMITBIH(unittest.TestCase): + """Test MedTsLLM model with real MIT-BIH data structure.""" + + @classmethod + def setUpClass(cls): + from pyhealth.datasets.mitbih import MITBIHDataset + from pyhealth.tasks.ecg_boundary_detection import ( + ECGBoundaryDetection, + ) + from pyhealth.datasets import get_dataloader + from pyhealth.models.medtsllm import MedTsLLM + + dataset = MITBIHDataset(root=TEST_DATA_DIR, 
dev=True) + task = ECGBoundaryDetection(window_size=128, step_size=64) + cls.sample_ds = dataset.set_task(task) + + cls.model = MedTsLLM( + dataset=cls.sample_ds, + seq_len=128, + n_features=2, + n_classes=2, + backbone=None, + word_embeddings=torch.randn(50, 32), + d_model=16, + d_ff=32, + n_heads=4, + num_tokens=50, + covariate_mode="concat", + ) + loader = get_dataloader(cls.sample_ds, batch_size=2, shuffle=False) + cls.batch = next(iter(loader)) + + def test_forward(self): + """Model forward pass works with real MIT-BIH samples.""" + out = self.model(**self.batch) + self.assertIn("logit", out) + self.assertIn("loss", out) + self.assertTrue(out["loss"].isfinite()) + + def test_backward(self): + """Backward pass works with MIT-BIH data.""" + out = self.model(**self.batch) + out["loss"].backward() + + +# ------------------------------------------------------------------ # +# Phase 6: preprocess=True -> .npz cache for MIT-BIH +# ------------------------------------------------------------------ # + + +@unittest.skipUnless(HAS_TEST_DATA, "MIT-BIH test data not found") +class TestMITBIHPreprocessCache(unittest.TestCase): + """preprocess=True writes per-record .npz with signal + annotations.""" + + @classmethod + def setUpClass(cls): + import shutil + + from pyhealth.datasets.mitbih import MITBIHDataset + + processed_dir = os.path.join(TEST_DATA_DIR, "processed") + if os.path.isdir(processed_dir): + shutil.rmtree(processed_dir) + _force_regenerate_mitbih_metadata() + + cls.dataset = MITBIHDataset( + root=TEST_DATA_DIR, + dev=True, + preprocess=True, + downsample_factor=3, + trim=True, + ) + + def test_processed_dir_created(self): + self.assertTrue( + os.path.isdir(os.path.join(TEST_DATA_DIR, "processed")) + ) + + def test_event_has_processed_file_attr(self): + pid = self.dataset.unique_patient_ids[0] + event = self.dataset.get_patient(pid).get_events()[0] + self.assertTrue(hasattr(event, "processed_file")) + self.assertTrue(event.processed_file.endswith(".npz")) + 
self.assertTrue(os.path.exists(event.processed_file)) + + def test_cached_arrays_have_expected_keys(self): + pid = self.dataset.unique_patient_ids[0] + event = self.dataset.get_patient(pid).get_events()[0] + with np.load(event.processed_file, allow_pickle=False) as npz: + self.assertIn("signal", npz.files) + self.assertIn("ann_sample", npz.files) + self.assertIn("ann_symbol", npz.files) + + def test_trim_applied_to_cache(self): + """After trim, first and last ann_sample are within [0, len-1].""" + pid = self.dataset.unique_patient_ids[0] + event = self.dataset.get_patient(pid).get_events()[0] + with np.load(event.processed_file, allow_pickle=False) as npz: + signal_len = int(npz["signal"].shape[0]) + ann_sample = np.asarray(npz["ann_sample"]) + self.assertGreaterEqual(int(ann_sample[0]), 0) + self.assertLess(int(ann_sample[-1]), signal_len) + # Trim: first annotation should be at or near sample 0. + self.assertLessEqual(int(ann_sample[0]), 5) + + def test_boundary_task_uses_cache(self): + from pyhealth.tasks.ecg_boundary_detection import ( + ECGBoundaryDetection, + ) + + pid = self.dataset.unique_patient_ids[0] + patient = self.dataset.get_patient(pid) + samples = ECGBoundaryDetection( + window_size=128, step_size=128 + )(patient) + self.assertGreater(len(samples), 0) + self.assertEqual(samples[0]["signal"].shape, (128, 2)) + + def test_anomaly_task_uses_cache(self): + from pyhealth.tasks.ecg_anomaly_detection import ( + ECGAnomalyDetection, + ) + + pid = self.dataset.unique_patient_ids[0] + patient = self.dataset.get_patient(pid) + samples = ECGAnomalyDetection( + window_size=128, step_size=128 + )(patient) + self.assertGreater(len(samples), 0) + self.assertEqual(samples[0]["signal"].shape, (128, 2)) + + +@unittest.skipUnless(HAS_TEST_DATA, "MIT-BIH test data not found") +class TestMITBIHPreprocessTrimOff(unittest.TestCase): + """trim=False caches the full downsampled signal.""" + + def test_untrimmed_cache_is_longer(self): + import shutil + + from 
pyhealth.datasets.mitbih import MITBIHDataset + + processed_dir = os.path.join(TEST_DATA_DIR, "processed") + + if os.path.isdir(processed_dir): + shutil.rmtree(processed_dir) + _force_regenerate_mitbih_metadata() + ds_trim = MITBIHDataset( + root=TEST_DATA_DIR, + dev=True, + preprocess=True, + trim=True, + ) + ev_trim = ds_trim.get_patient( + ds_trim.unique_patient_ids[0] + ).get_events()[0] + with np.load(ev_trim.processed_file, allow_pickle=False) as npz: + trimmed_len = int(npz["signal"].shape[0]) + + if os.path.isdir(processed_dir): + shutil.rmtree(processed_dir) + _force_regenerate_mitbih_metadata() + ds_full = MITBIHDataset( + root=TEST_DATA_DIR, + dev=True, + preprocess=True, + trim=False, + ) + ev_full = ds_full.get_patient( + ds_full.unique_patient_ids[0] + ).get_events()[0] + with np.load(ev_full.processed_file, allow_pickle=False) as npz: + full_len = int(npz["signal"].shape[0]) + + self.assertGreaterEqual(full_len, trimmed_len) + + +@unittest.skipUnless(HAS_TEST_DATA, "MIT-BIH test data not found") +class TestMITBIHPreprocessDisabled(unittest.TestCase): + """preprocess=False leaves processed_file blank.""" + + def test_processed_file_empty_when_disabled(self): + import shutil + + from pyhealth.datasets.mitbih import MITBIHDataset + + processed_dir = os.path.join(TEST_DATA_DIR, "processed") + if os.path.isdir(processed_dir): + shutil.rmtree(processed_dir) + _force_regenerate_mitbih_metadata() + + ds = MITBIHDataset(root=TEST_DATA_DIR, dev=True, preprocess=False) + pid = ds.unique_patient_ids[0] + event = ds.get_patient(pid).get_events()[0] + value = getattr(event, "processed_file", "") or "" + self.assertEqual(str(value).strip(), "") + + +@unittest.skipUnless(HAS_TEST_DATA, "MIT-BIH test data not found") +class TestMITBIHInvalidDownsample(unittest.TestCase): + """downsample_factor < 1 raises ValueError.""" + + def test_invalid_downsample_raises(self): + from pyhealth.datasets.mitbih import MITBIHDataset + + _force_regenerate_mitbih_metadata() + with 
self.assertRaises(ValueError): + MITBIHDataset(root=TEST_DATA_DIR, dev=True, downsample_factor=0) + + +# ------------------------------------------------------------------ # +# Phase 5: paper-match split column (MIT-BIH 80/20 seed=0) +# ------------------------------------------------------------------ # + + +def _reset_mitbih_csv(): + """Delete the committed metadata CSV so the next ``MITBIHDataset`` + rebuilds it with the current ``prepare_metadata`` implementation. + + Leaves the PyHealth cache warm — safe when the next dataset uses + the same fingerprint as a prior class. For config changes (e.g. + paper_split), use ``_force_regenerate_mitbih_metadata`` instead. + """ + csv_path = os.path.join(TEST_DATA_DIR, "mitbih-pyhealth.csv") + if os.path.exists(csv_path): + os.remove(csv_path) + + +def _force_regenerate_mitbih_metadata(): + """Nuclear reset: CSV + entire PyHealth cache. Use in setUpModule + once per file, or when a test wipes ``processed_dir`` / changes + config and needs the event-dataframe cache to be rebuilt. 
+ """ + import shutil + + import platformdirs + + _reset_mitbih_csv() + cache_root = platformdirs.user_cache_dir(appname="pyhealth") + if os.path.isdir(cache_root): + shutil.rmtree(cache_root) + + +@unittest.skipUnless(HAS_TEST_DATA, "MIT-BIH test data not found") +class TestMITBIHPaperSplitRandom(unittest.TestCase): + """paper_split='random' writes an 80/20 seed=0 split column.""" + + @classmethod + def setUpClass(cls): + from pyhealth.datasets.mitbih import MITBIHDataset + + _force_regenerate_mitbih_metadata() + cls.dataset = MITBIHDataset( + root=TEST_DATA_DIR, dev=True, paper_split="random" + ) + + def test_constructor_accepts_paper_split(self): + """MITBIHDataset accepts a paper_split kwarg.""" + import inspect + + from pyhealth.datasets.mitbih import MITBIHDataset + + sig = inspect.signature(MITBIHDataset.__init__) + self.assertIn("paper_split", sig.parameters) + + def test_csv_has_split_column(self): + """Regenerated CSV contains a ``split`` column.""" + import pandas as pd + + csv_path = os.path.join(TEST_DATA_DIR, "mitbih-pyhealth.csv") + df = pd.read_csv(csv_path) + self.assertIn("split", df.columns) + + def test_split_values_are_train_or_test(self): + """Every split value is either 'train' or 'test'.""" + import pandas as pd + + csv_path = os.path.join(TEST_DATA_DIR, "mitbih-pyhealth.csv") + df = pd.read_csv(csv_path) + unique = set(df["split"].unique()) + self.assertTrue(unique.issubset({"train", "test"})) + + def test_event_has_split_attr(self): + """Events expose a ``split`` attribute.""" + pid = self.dataset.unique_patient_ids[0] + patient = self.dataset.get_patient(pid) + event = patient.get_events()[0] + self.assertTrue(hasattr(event, "split")) + self.assertIn(event.split, ("train", "test")) + + def test_boundary_task_emits_split_field(self): + """ECGBoundaryDetection propagates event.split to each sample.""" + from pyhealth.tasks.ecg_boundary_detection import ( + ECGBoundaryDetection, + ) + + pid = self.dataset.unique_patient_ids[0] + patient = 
self.dataset.get_patient(pid) + task = ECGBoundaryDetection(window_size=256, step_size=256) + sample = task(patient)[0] + self.assertIn("split", sample) + self.assertIn(sample["split"], ("train", "test")) + + def test_anomaly_task_emits_split_field(self): + """ECGAnomalyDetection propagates event.split to each sample.""" + from pyhealth.tasks.ecg_anomaly_detection import ( + ECGAnomalyDetection, + ) + + pid = self.dataset.unique_patient_ids[0] + patient = self.dataset.get_patient(pid) + task = ECGAnomalyDetection(window_size=128, step_size=128) + sample = task(patient)[0] + self.assertIn("split", sample) + self.assertIn(sample["split"], ("train", "test")) + + +@unittest.skipUnless(HAS_TEST_DATA, "MIT-BIH test data not found") +class TestMITBIHPaperSplitAbnormalSorted(unittest.TestCase): + """paper_split='abnormal_sorted' orders patients by n_abnormal asc.""" + + @classmethod + def setUpClass(cls): + from pyhealth.datasets.mitbih import MITBIHDataset + + _force_regenerate_mitbih_metadata() + cls.dataset = MITBIHDataset( + root=TEST_DATA_DIR, dev=True, paper_split="abnormal_sorted" + ) + + def test_csv_has_split_column(self): + """Regenerated CSV contains a populated ``split`` column.""" + import pandas as pd + + csv_path = os.path.join(TEST_DATA_DIR, "mitbih-pyhealth.csv") + df = pd.read_csv(csv_path) + self.assertIn("split", df.columns) + unique = set(df["split"].unique()) + self.assertTrue(unique.issubset({"train", "test"})) + + def test_train_has_lower_or_equal_n_abnormal(self): + """All train patients have n_abnormal <= all test patients.""" + import pandas as pd + + csv_path = os.path.join(TEST_DATA_DIR, "mitbih-pyhealth.csv") + df = pd.read_csv(csv_path) + train = df[df["split"] == "train"]["n_abnormal"] + test = df[df["split"] == "test"]["n_abnormal"] + if len(train) > 0 and len(test) > 0: + self.assertLessEqual(train.max(), test.min()) + + +@unittest.skipUnless(HAS_TEST_DATA, "MIT-BIH test data not found") +class 
TestMITBIHPaperSplitDisabled(unittest.TestCase): + """paper_split=None leaves the split column blank.""" + + def test_split_column_blank_when_disabled(self): + """Without paper_split, split column is empty.""" + import pandas as pd + from pyhealth.datasets.mitbih import MITBIHDataset + + _force_regenerate_mitbih_metadata() + MITBIHDataset(root=TEST_DATA_DIR, dev=True, paper_split=None) + df = pd.read_csv(os.path.join(TEST_DATA_DIR, "mitbih-pyhealth.csv")) + self.assertIn("split", df.columns) + cleaned = df["split"].fillna("").astype(str).str.strip() + self.assertTrue((cleaned == "").all()) + + +@unittest.skipUnless(HAS_TEST_DATA, "MIT-BIH test data not found") +class TestMITBIHPaperSplitInvalid(unittest.TestCase): + """Unknown paper_split mode raises ValueError.""" + + def test_invalid_mode_raises(self): + from pyhealth.datasets.mitbih import MITBIHDataset + + _force_regenerate_mitbih_metadata() + with self.assertRaises(ValueError): + MITBIHDataset( + root=TEST_DATA_DIR, dev=True, paper_split="nonsense" + ) + + +if __name__ == "__main__": + unittest.main()