diff --git a/.gitignore b/.gitignore
index 9993737db..a73d698a3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -139,4 +139,13 @@ data/physionet.org/
.vscode/
# Model weight files (large binaries, distributed separately)
-weightfiles/
\ No newline at end of file
+weightfiles/
+
+# Auto-generated dataset metadata CSVs (contain absolute paths from the
+# machine that ran prepare_metadata). These are rebuilt on first use by
+# the dataset loaders, so checking them in leaks user-specific paths.
+test-resources/**/*-pyhealth.csv
+
+# Per-record preprocessing caches built by dataset loaders with
+# preprocess=True (e.g. LUDB/MIT-BIH/BIDMC .npz files).
+test-resources/**/processed/
\ No newline at end of file
diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst
index 8d9a59d21..a3efd3bf9 100644
--- a/docs/api/datasets.rst
+++ b/docs/api/datasets.rst
@@ -227,7 +227,10 @@ Available Datasets
datasets/pyhealth.datasets.MedicalTranscriptionsDataset
datasets/pyhealth.datasets.CardiologyDataset
datasets/pyhealth.datasets.eICUDataset
+ datasets/pyhealth.datasets.BIDMCDataset
datasets/pyhealth.datasets.ISRUCDataset
+ datasets/pyhealth.datasets.LUDBDataset
+ datasets/pyhealth.datasets.MITBIHDataset
datasets/pyhealth.datasets.MIMICExtractDataset
datasets/pyhealth.datasets.OMOPDataset
datasets/pyhealth.datasets.DREAMTDataset
diff --git a/docs/api/datasets/pyhealth.datasets.BIDMCDataset.rst b/docs/api/datasets/pyhealth.datasets.BIDMCDataset.rst
new file mode 100644
index 000000000..7b97dc424
--- /dev/null
+++ b/docs/api/datasets/pyhealth.datasets.BIDMCDataset.rst
@@ -0,0 +1,11 @@
+pyhealth.datasets.BIDMCDataset
+===============================
+
+BIDMC respiratory signal dataset — 53 ICU patients with 8-minute recordings
+of ECG, PPG, and respiratory signals at 125 Hz. Refer to
+`PhysioNet <https://physionet.org/content/bidmc/1.0.0/>`_ for more information.
+
+.. autoclass:: pyhealth.datasets.BIDMCDataset
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/api/datasets/pyhealth.datasets.LUDBDataset.rst b/docs/api/datasets/pyhealth.datasets.LUDBDataset.rst
new file mode 100644
index 000000000..afb9f4552
--- /dev/null
+++ b/docs/api/datasets/pyhealth.datasets.LUDBDataset.rst
@@ -0,0 +1,11 @@
+pyhealth.datasets.LUDBDataset
+==============================
+
+Lobachevsky University Database (LUDB) — 200 subjects with 12-lead ECG at 500 Hz,
+manually annotated with P wave, QRS complex, and T wave boundaries. Refer to
+`PhysioNet <https://physionet.org/content/ludb/1.0.1/>`_ for more information.
+
+.. autoclass:: pyhealth.datasets.LUDBDataset
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/api/datasets/pyhealth.datasets.MITBIHDataset.rst b/docs/api/datasets/pyhealth.datasets.MITBIHDataset.rst
new file mode 100644
index 000000000..e21ef039b
--- /dev/null
+++ b/docs/api/datasets/pyhealth.datasets.MITBIHDataset.rst
@@ -0,0 +1,11 @@
+pyhealth.datasets.MITBIHDataset
+================================
+
+MIT-BIH Arrhythmia Database — 48 half-hour excerpts of two-channel ambulatory
+ECG at 360 Hz. Refer to
+`PhysioNet <https://physionet.org/content/mitdb/1.0.0/>`_ for more information.
+
+.. autoclass:: pyhealth.datasets.MITBIHDataset
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/api/models.rst b/docs/api/models.rst
index 7c3ac7c4b..b24e6f45f 100644
--- a/docs/api/models.rst
+++ b/docs/api/models.rst
@@ -196,6 +196,7 @@ API Reference
models/pyhealth.models.Agent
models/pyhealth.models.GRASP
models/pyhealth.models.MedLink
+ models/pyhealth.models.MedTsLLM
models/pyhealth.models.TCN
models/pyhealth.models.TFMTokenizer
models/pyhealth.models.GAN
diff --git a/docs/api/models/pyhealth.models.MedTsLLM.rst b/docs/api/models/pyhealth.models.MedTsLLM.rst
new file mode 100644
index 000000000..ea6a1eea3
--- /dev/null
+++ b/docs/api/models/pyhealth.models.MedTsLLM.rst
@@ -0,0 +1,11 @@
+pyhealth.models.MedTsLLM
+=========================
+
+MedTsLLM: Leveraging LLMs for Multimodal Medical Time Series Analysis
+(Chan et al., MLHC 2024). Refer to the
+`paper <https://arxiv.org/abs/2408.07773>`_ for more information.
+
+.. autoclass:: pyhealth.models.MedTsLLM
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst
index 23a4e06e5..d20cefbbf 100644
--- a/docs/api/tasks.rst
+++ b/docs/api/tasks.rst
@@ -1,5 +1,5 @@
 Tasks
-===============
+=====
 We support various real-world healthcare predictive tasks defined by **function calls**. The following example tasks are collected from top AI/Medical venues, such as:
@@ -212,6 +211,10 @@ Available Tasks
COVID-19 CXR Classification
DKA Prediction (MIMIC-IV)
Drug Recommendation
+ ECG Anomaly Detection (MIT-BIH)
+ ECG Boundary Detection (MIT-BIH)
+ ECG Wave Segmentation (LUDB)
+ Respiratory Boundary Detection (BIDMC)
Length of Stay Prediction
Medical Transcriptions Classification
Mortality Prediction (Next Visit)
diff --git a/docs/api/tasks/pyhealth.tasks.ECGAnomalyDetection.rst b/docs/api/tasks/pyhealth.tasks.ECGAnomalyDetection.rst
new file mode 100644
index 000000000..d4ee1d97e
--- /dev/null
+++ b/docs/api/tasks/pyhealth.tasks.ECGAnomalyDetection.rst
@@ -0,0 +1,11 @@
+pyhealth.tasks.ECGAnomalyDetection
+====================================
+
+Reconstruction-based anomaly detection task for the MIT-BIH dataset.
+Trains the model to reconstruct normal 2-channel ECG; at eval time,
+elevated reconstruction error flags abnormal-rhythm beats.
+
+.. autoclass:: pyhealth.tasks.ECGAnomalyDetection
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/api/tasks/pyhealth.tasks.ECGBoundaryDetection.rst b/docs/api/tasks/pyhealth.tasks.ECGBoundaryDetection.rst
new file mode 100644
index 000000000..67bafcc8e
--- /dev/null
+++ b/docs/api/tasks/pyhealth.tasks.ECGBoundaryDetection.rst
@@ -0,0 +1,10 @@
+pyhealth.tasks.ECGBoundaryDetection
+=====================================
+
+R-peak boundary detection task for the MIT-BIH dataset. Detects beat
+boundaries in ECG signals.
+
+.. autoclass:: pyhealth.tasks.ECGBoundaryDetection
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/api/tasks/pyhealth.tasks.ECGWaveSegmentation.rst b/docs/api/tasks/pyhealth.tasks.ECGWaveSegmentation.rst
new file mode 100644
index 000000000..5b58a7e10
--- /dev/null
+++ b/docs/api/tasks/pyhealth.tasks.ECGWaveSegmentation.rst
@@ -0,0 +1,10 @@
+pyhealth.tasks.ECGWaveSegmentation
+====================================
+
+Per-timestep ECG wave segmentation task for the LUDB dataset. Classifies
+each time point as background, P wave, QRS complex, or T wave.
+
+.. autoclass:: pyhealth.tasks.ECGWaveSegmentation
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/api/tasks/pyhealth.tasks.RespiratoryBoundaryDetection.rst b/docs/api/tasks/pyhealth.tasks.RespiratoryBoundaryDetection.rst
new file mode 100644
index 000000000..49366d54d
--- /dev/null
+++ b/docs/api/tasks/pyhealth.tasks.RespiratoryBoundaryDetection.rst
@@ -0,0 +1,10 @@
+pyhealth.tasks.RespiratoryBoundaryDetection
+=============================================
+
+Breath boundary detection task for the BIDMC dataset. Detects breath
+boundaries in respiratory impedance signals.
+
+.. autoclass:: pyhealth.tasks.RespiratoryBoundaryDetection
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/examples/bidmc_respiratory_boundary_medtsllm.py b/examples/bidmc_respiratory_boundary_medtsllm.py
new file mode 100644
index 000000000..330c1668d
--- /dev/null
+++ b/examples/bidmc_respiratory_boundary_medtsllm.py
@@ -0,0 +1,177 @@
+"""BIDMC Respiratory Boundary Detection with MedTsLLM.
+
+Demonstrates the MedTsLLM model (Chan et al., MLHC 2024) on the BIDMC
+dataset for per-timestep breath boundary detection using 3-channel
+respiratory signals (RESP, PLETH, II). Binary segmentation task
+trained with BCE-with-logits loss.
+
+Paper: https://arxiv.org/abs/2408.07773
+Dataset: https://physionet.org/content/bidmc/1.0.0/
+
+Usage:
+ # Synthetic data, no downloads:
+ python examples/bidmc_respiratory_boundary_medtsllm.py --synthetic
+
+ # Real BIDMC data with GPT-2:
+ python examples/bidmc_respiratory_boundary_medtsllm.py \\
+ --root /path/to/bidmc --backbone openai-community/gpt2
+
+Ablation Study:
+ The script exposes the paper's two main ablation axes as CLI
+ flags so each run is a single ablation cell.
+
+    1. LLM backbone swap -- ``--backbone <model_id>``:
+ compare GPT-2 vs. GPT-2-medium vs. DistilGPT-2 etc.
+
+ 2. Prompt components -- each piece of the text prompt can be
+ disabled independently:
+ --no-prompt-dataset drops the dataset description
+ --no-prompt-task drops the task description
+ --no-prompt-patient drops the per-patient description
+ --prompt-stats adds the rolling signal stats
+
+ python examples/bidmc_respiratory_boundary_medtsllm.py \\
+ --root /path/to/bidmc --no-prompt-patient
+ python examples/bidmc_respiratory_boundary_medtsllm.py \\
+ --root /path/to/bidmc --no-prompt-dataset --prompt-stats
+"""
+
+import argparse
+
+import numpy as np
+import torch
+
+from pyhealth.datasets import (
+ create_sample_dataset,
+ get_dataloader,
+ split_by_patient,
+)
+from pyhealth.models import MedTsLLM
+from pyhealth.trainer import Trainer
+
+
+def make_synthetic_dataset(
+ n_patients: int = 10, seq_len: int = 256, n_channels: int = 3
+):
+ """Synthetic 3-channel respiratory data for smoke-testing."""
+ samples = []
+ for i in range(n_patients):
+ for w in range(5):
+ signal = np.random.randn(seq_len, n_channels).astype(np.float32)
+ label = np.random.randint(0, 2, size=seq_len).astype(np.float32)
+ samples.append({
+ "patient_id": f"p{i}",
+ "visit_id": f"v{w}",
+ "signal": signal,
+ "label": label,
+ "description": "",
+ })
+ return create_sample_dataset(
+ samples=samples,
+ input_schema={"signal": "tensor"},
+ output_schema={"label": "tensor"},
+ dataset_name="synthetic_bidmc",
+ )
+
+
+def make_real_dataset(root: str, seq_len: int = 256, step_size: int = 128):
+ """Load BIDMC and apply RespiratoryBoundaryDetection."""
+ from pyhealth.datasets import BIDMCDataset
+ from pyhealth.tasks import RespiratoryBoundaryDetection
+
+ dataset = BIDMCDataset(root=root)
+ task = RespiratoryBoundaryDetection(
+ window_size=seq_len, step_size=step_size
+ )
+ return dataset.set_task(task)
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="BIDMC respiratory boundary detection with MedTsLLM"
+ )
+ parser.add_argument("--root", type=str, default=None)
+ parser.add_argument("--synthetic", action="store_true")
+ parser.add_argument("--backbone", type=str, default=None)
+ parser.add_argument("--epochs", type=int, default=10)
+ parser.add_argument("--batch-size", type=int, default=16)
+ parser.add_argument("--lr", type=float, default=1e-4)
+ parser.add_argument("--device", type=str, default="auto")
+ # Ablation knobs — each flag disables one prompt component
+ parser.add_argument("--no-prompt-dataset", action="store_true")
+ parser.add_argument("--no-prompt-task", action="store_true")
+ parser.add_argument("--no-prompt-patient", action="store_true")
+ parser.add_argument("--prompt-stats", action="store_true")
+ args = parser.parse_args()
+
+ device = (
+ ("cuda" if torch.cuda.is_available() else "cpu")
+ if args.device == "auto"
+ else args.device
+ )
+
+ seq_len = 256
+ n_features = 3
+
+ if args.synthetic or args.root is None:
+ print("Using synthetic data")
+ sample_dataset = make_synthetic_dataset(
+ n_patients=10, seq_len=seq_len, n_channels=n_features
+ )
+ word_embeddings = torch.randn(100, 64)
+ backbone = None
+ epochs = 3
+ else:
+ print(f"Loading BIDMC from {args.root}")
+ sample_dataset = make_real_dataset(args.root, seq_len=seq_len)
+ word_embeddings = None
+ backbone = args.backbone or "openai-community/gpt2"
+ epochs = args.epochs
+
+ train_ds, _, test_ds = split_by_patient(
+ sample_dataset, ratios=[0.8, 0.0, 0.2]
+ )
+ train_loader = get_dataloader(
+ train_ds, batch_size=args.batch_size, shuffle=True
+ )
+ test_loader = get_dataloader(
+ test_ds, batch_size=args.batch_size, shuffle=False
+ )
+
+ model = MedTsLLM(
+ dataset=sample_dataset,
+ task="segmentation",
+ seq_len=seq_len,
+ n_features=n_features,
+ covariate_mode="concat",
+ d_model=32,
+ d_ff=64,
+ n_heads=8,
+ num_tokens=1024,
+ patch_len=16,
+ stride=8,
+ dataset_description=(
+ "The BIDMC dataset contains electrocardiogram (ECG), "
+ "pulse oximetry (PPG), and impedance pneumography "
+ "respiratory signals from intensive care patients."
+ ),
+ backbone=backbone,
+ word_embeddings=word_embeddings,
+ prompt_dataset=not args.no_prompt_dataset,
+ prompt_task=not args.no_prompt_task,
+ prompt_patient=not args.no_prompt_patient,
+ prompt_stats=args.prompt_stats,
+ )
+
+ trainer = Trainer(model=model, device=device, enable_logging=False)
+ trainer.train(
+ train_dataloader=train_loader,
+ test_dataloader=test_loader,
+ epochs=epochs,
+ optimizer_class=torch.optim.Adam,
+ optimizer_params={"lr": args.lr},
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/ludb_ecg_segmentation_medtsllm.py b/examples/ludb_ecg_segmentation_medtsllm.py
new file mode 100644
index 000000000..efae066ce
--- /dev/null
+++ b/examples/ludb_ecg_segmentation_medtsllm.py
@@ -0,0 +1,202 @@
+"""LUDB ECG Wave Segmentation with MedTsLLM.
+
+Demonstrates the MedTsLLM model (Chan et al., MLHC 2024) on the LUDB
+dataset for per-timestep ECG wave segmentation (P wave, QRS complex,
+T wave, background).
+
+Paper: https://arxiv.org/abs/2408.07773
+Dataset: https://physionet.org/content/ludb/1.0.1/
+
+Usage:
+ # With synthetic data (no downloads, fast):
+ python examples/ludb_ecg_segmentation_medtsllm.py --synthetic
+
+ # With real data + GPT-2:
+ python examples/ludb_ecg_segmentation_medtsllm.py \\
+ --root /path/to/ludb --backbone openai-community/gpt2
+
+Ablation Study:
+ The script exposes the paper's two main ablation axes as CLI
+ flags so each run is a single ablation cell.
+
+    1. LLM backbone swap -- ``--backbone <model_id>``:
+ compare GPT-2 vs. GPT-2-medium vs. DistilGPT-2 etc.
+
+ python examples/ludb_ecg_segmentation_medtsllm.py \\
+ --root /path/to/ludb --backbone openai-community/gpt2
+ python examples/ludb_ecg_segmentation_medtsllm.py \\
+ --root /path/to/ludb --backbone distilbert/distilgpt2
+
+ 2. Prompt components -- each piece of the text prompt can be
+ disabled independently:
+ --no-prompt-dataset drops the dataset description
+ --no-prompt-task drops the task description
+ --no-prompt-patient drops the per-patient description
+ --prompt-stats adds the rolling signal stats
+
+ python examples/ludb_ecg_segmentation_medtsllm.py \\
+ --root /path/to/ludb --no-prompt-patient
+ python examples/ludb_ecg_segmentation_medtsllm.py \\
+ --root /path/to/ludb --no-prompt-dataset --prompt-stats
+"""
+
+import argparse
+
+import numpy as np
+import torch
+
+from pyhealth.datasets import (
+ create_sample_dataset,
+ get_dataloader,
+ split_by_patient,
+)
+from pyhealth.models import MedTsLLM
+from pyhealth.trainer import Trainer
+
+
+def make_synthetic_dataset(n_patients=10, seq_len=512, n_classes=4):
+ """Create a synthetic dataset for testing without data downloads."""
+ samples = []
+ for i in range(n_patients):
+ for w in range(5):
+ signal = np.random.randn(seq_len).astype(np.float32)
+ label = np.random.randint(0, n_classes, size=seq_len).astype(
+ np.int64
+ )
+ samples.append({
+ "patient_id": f"p{i}",
+ "visit_id": f"v{w}",
+ "signal": signal,
+ "label": label,
+ })
+
+ return create_sample_dataset(
+ samples=samples,
+ input_schema={"signal": "tensor"},
+ output_schema={"label": "tensor"},
+ dataset_name="synthetic_ludb",
+ )
+
+
+def make_real_dataset(root, seq_len=512, step_size=256):
+ """Load LUDB dataset and apply ECG segmentation task.
+
+ ``preprocess=True`` caches decoded signals to ``{root}/processed/``
+ so wfdb decoding runs exactly once; ``trim=True`` drops the head
+ and tail of each record to match the paper's preprocessing recipe.
+ """
+ from pyhealth.datasets import LUDBDataset
+ from pyhealth.tasks import ECGWaveSegmentation
+
+ dataset = LUDBDataset(root=root, preprocess=True, trim=True)
+ task = ECGWaveSegmentation(window_size=seq_len, step_size=step_size)
+ return dataset.set_task(task)
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="LUDB ECG segmentation with MedTsLLM"
+ )
+ parser.add_argument(
+ "--root", type=str, default=None, help="Path to LUDB data"
+ )
+ parser.add_argument(
+ "--synthetic",
+ action="store_true",
+ help="Use synthetic data (no downloads)",
+ )
+ parser.add_argument(
+ "--backbone",
+ type=str,
+ default=None,
+ help="HuggingFace model ID (default: openai-community/gpt2)",
+ )
+ parser.add_argument("--epochs", type=int, default=10)
+ parser.add_argument("--batch-size", type=int, default=16)
+ parser.add_argument("--lr", type=float, default=1e-4)
+ parser.add_argument("--device", type=str, default="auto")
+ # Ablation knobs — each flag disables one prompt component
+ parser.add_argument("--no-prompt-dataset", action="store_true")
+ parser.add_argument("--no-prompt-task", action="store_true")
+ parser.add_argument("--no-prompt-patient", action="store_true")
+ parser.add_argument("--prompt-stats", action="store_true")
+ args = parser.parse_args()
+
+ if args.device == "auto":
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ else:
+ device = args.device
+
+ seq_len = 512
+ n_classes = 4
+
+ # Load data
+ if args.synthetic or args.root is None:
+ print("Using synthetic data (no downloads)")
+ sample_dataset = make_synthetic_dataset(
+ n_patients=10, seq_len=seq_len, n_classes=n_classes
+ )
+ word_embeddings = torch.randn(100, 64)
+ backbone = None
+ epochs = 3
+ else:
+ print(f"Loading LUDB from {args.root}")
+ sample_dataset = make_real_dataset(
+ args.root, seq_len=seq_len, step_size=256
+ )
+ word_embeddings = None
+ backbone = args.backbone or "openai-community/gpt2"
+ epochs = args.epochs
+
+ # Split by patient
+ train_ds, _, test_ds = split_by_patient(
+ sample_dataset, ratios=[0.8, 0.0, 0.2]
+ )
+ train_loader = get_dataloader(
+ train_ds, batch_size=args.batch_size, shuffle=True
+ )
+ test_loader = get_dataloader(
+ test_ds, batch_size=args.batch_size, shuffle=False
+ )
+
+ model = MedTsLLM(
+ dataset=sample_dataset,
+ seq_len=seq_len,
+ n_features=1,
+ n_classes=n_classes,
+ d_model=32,
+ d_ff=128,
+ n_heads=8,
+ num_tokens=1024,
+ patch_len=16,
+ stride=8,
+ dataset_description=(
+ "LUDB is an ECG signal database collected from subjects "
+ "with various cardiovascular diseases used for ECG "
+ "delineation."
+ ),
+ backbone=backbone,
+ word_embeddings=word_embeddings,
+ prompt_dataset=not args.no_prompt_dataset,
+ prompt_task=not args.no_prompt_task,
+ prompt_patient=not args.no_prompt_patient,
+ prompt_stats=args.prompt_stats,
+ )
+
+ # Train
+ trainer = Trainer(
+ model=model,
+ device=device,
+ enable_logging=False,
+ )
+ trainer.train(
+ train_dataloader=train_loader,
+ test_dataloader=test_loader,
+ epochs=epochs,
+ optimizer_class=torch.optim.Adam,
+ optimizer_params={"lr": args.lr},
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/mitbih_ecg_anomaly_medtsllm.py b/examples/mitbih_ecg_anomaly_medtsllm.py
new file mode 100644
index 000000000..2704ef27d
--- /dev/null
+++ b/examples/mitbih_ecg_anomaly_medtsllm.py
@@ -0,0 +1,205 @@
+"""MIT-BIH ECG Anomaly Detection with MedTsLLM.
+
+Demonstrates reconstruction-based anomaly detection (Chan et al.,
+MLHC 2024) on the MIT-BIH Arrhythmia Database. MedTsLLM is trained
+with MSE loss to reconstruct 2-channel ECG; at eval time, elevated
+reconstruction error flags abnormal-rhythm beats.
+
+Paper: https://arxiv.org/abs/2408.07773
+Dataset: https://physionet.org/content/mitdb/1.0.0/
+
+Usage:
+ python examples/mitbih_ecg_anomaly_medtsllm.py --synthetic
+ python examples/mitbih_ecg_anomaly_medtsllm.py \\
+ --root /path/to/mitdb --backbone openai-community/gpt2
+
+Ablation Study:
+ The script exposes the paper's two main ablation axes as CLI
+ flags so each run is a single ablation cell.
+
+    1. LLM backbone swap -- ``--backbone <model_id>``:
+ compare GPT-2 vs. GPT-2-medium vs. DistilGPT-2 etc.
+
+ 2. Prompt components -- each piece of the text prompt can be
+ disabled independently:
+ --no-prompt-dataset drops the dataset description
+ --no-prompt-task drops the task description
+ --no-prompt-patient drops the per-patient description
+ --prompt-stats adds the rolling signal stats
+
+ python examples/mitbih_ecg_anomaly_medtsllm.py \\
+ --root /path/to/mitdb --no-prompt-patient
+ python examples/mitbih_ecg_anomaly_medtsllm.py \\
+ --root /path/to/mitdb --no-prompt-dataset --prompt-stats
+"""
+
+import argparse
+
+import numpy as np
+import torch
+
+from pyhealth.datasets import (
+ create_sample_dataset,
+ get_dataloader,
+ split_by_patient,
+)
+from pyhealth.models import MedTsLLM
+from pyhealth.trainer import Trainer
+
+
+def make_synthetic_dataset(
+ n_patients: int = 10, seq_len: int = 128, n_channels: int = 2
+):
+ """Synthetic 2-channel ECG with sparse anomaly intervals."""
+ samples = []
+ for i in range(n_patients):
+ for w in range(5):
+ signal = np.random.randn(seq_len, n_channels).astype(np.float32)
+ label = np.zeros(seq_len, dtype=np.float32)
+ # 1-2 anomaly intervals per window
+ for _ in range(np.random.randint(0, 3)):
+ s = np.random.randint(0, seq_len - 10)
+ label[s : s + np.random.randint(3, 10)] = 1
+ samples.append({
+ "patient_id": f"p{i}",
+ "visit_id": f"v{w}",
+ "signal": signal,
+ "label": label,
+ "description": "",
+ })
+ return create_sample_dataset(
+ samples=samples,
+ input_schema={"signal": "tensor"},
+ output_schema={"label": "tensor"},
+ dataset_name="synthetic_mitbih_anomaly",
+ )
+
+
+def make_real_dataset(root: str, seq_len: int = 128, step_size: int = 128):
+ """Load MIT-BIH and apply ECGAnomalyDetection.
+
+ Uses ``preprocess=True`` so the 30-minute wfdb records are decoded
+ and downsampled once into ``{root}/processed/*.npz``. Subsequent
+ runs skip wfdb entirely. ``paper_split="abnormal_sorted"`` matches
+ the paper's least-/most-abnormal 80/20 train/test split.
+ """
+ from pyhealth.datasets import MITBIHDataset
+ from pyhealth.tasks import ECGAnomalyDetection
+
+ dataset = MITBIHDataset(
+ root=root,
+ preprocess=True,
+ downsample_factor=3,
+ trim=True,
+ paper_split="abnormal_sorted",
+ )
+ task = ECGAnomalyDetection(window_size=seq_len, step_size=step_size)
+ return dataset.set_task(task)
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="MIT-BIH arrhythmia anomaly detection with MedTsLLM"
+ )
+ parser.add_argument("--root", type=str, default=None)
+ parser.add_argument("--synthetic", action="store_true")
+ parser.add_argument("--backbone", type=str, default=None)
+ parser.add_argument("--epochs", type=int, default=10)
+ parser.add_argument("--batch-size", type=int, default=16)
+ parser.add_argument("--lr", type=float, default=1e-4)
+ parser.add_argument("--device", type=str, default="auto")
+ # Ablation knobs — each flag disables one prompt component
+ parser.add_argument("--no-prompt-dataset", action="store_true")
+ parser.add_argument("--no-prompt-task", action="store_true")
+ parser.add_argument("--no-prompt-patient", action="store_true")
+ parser.add_argument("--prompt-stats", action="store_true")
+ args = parser.parse_args()
+
+ device = (
+ ("cuda" if torch.cuda.is_available() else "cpu")
+ if args.device == "auto"
+ else args.device
+ )
+
+ seq_len = 128
+ n_features = 2
+
+ if args.synthetic or args.root is None:
+ print("Using synthetic data")
+ sample_dataset = make_synthetic_dataset(
+ n_patients=10, seq_len=seq_len, n_channels=n_features
+ )
+ word_embeddings = torch.randn(100, 64)
+ backbone = None
+ epochs = 3
+ else:
+ print(f"Loading MIT-BIH from {args.root}")
+ sample_dataset = make_real_dataset(args.root, seq_len=seq_len)
+ word_embeddings = None
+ backbone = args.backbone or "openai-community/gpt2"
+ epochs = args.epochs
+
+ # Synthetic path has no paper_split, so fall back to a random
+ # patient split. On real data, ``make_real_dataset`` sets
+ # ``paper_split="abnormal_sorted"`` which tags each sample with a
+ # ``"split"`` field; partition on that below to consume the
+ # baked-in train/test assignment without reshuffling.
+ if args.synthetic or args.root is None:
+ train_ds, _, test_ds = split_by_patient(
+ sample_dataset, ratios=[0.8, 0.0, 0.2]
+ )
+ else:
+ train_idx, test_idx = [], []
+ for i in range(len(sample_dataset)):
+ bucket = sample_dataset[i]["split"]
+ if bucket == "train":
+ train_idx.append(i)
+ elif bucket == "test":
+ test_idx.append(i)
+ train_ds = sample_dataset.subset(np.array(train_idx))
+ test_ds = sample_dataset.subset(np.array(test_idx))
+ train_loader = get_dataloader(
+ train_ds, batch_size=args.batch_size, shuffle=True
+ )
+ test_loader = get_dataloader(
+ test_ds, batch_size=args.batch_size, shuffle=False
+ )
+
+ model = MedTsLLM(
+ dataset=sample_dataset,
+ task="anomaly_detection",
+ seq_len=seq_len,
+ n_features=n_features,
+ covariate_mode="concat",
+ d_model=32,
+ d_ff=64,
+ n_heads=8,
+ num_tokens=1024,
+ patch_len=16,
+ stride=8,
+ dataset_description=(
+ "The MIT-BIH Arrhythmia Database contains excerpts of "
+ "two-channel ambulatory ECG from a mixed population of "
+ "inpatients and outpatients, digitized at 360 samples "
+ "per second per channel."
+ ),
+ backbone=backbone,
+ word_embeddings=word_embeddings,
+ prompt_dataset=not args.no_prompt_dataset,
+ prompt_task=not args.no_prompt_task,
+ prompt_patient=not args.no_prompt_patient,
+ prompt_stats=args.prompt_stats,
+ )
+
+ trainer = Trainer(model=model, device=device, enable_logging=False)
+ trainer.train(
+ train_dataloader=train_loader,
+ test_dataloader=test_loader,
+ epochs=epochs,
+ optimizer_class=torch.optim.Adam,
+ optimizer_params={"lr": args.lr},
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/mitbih_ecg_boundary_medtsllm.py b/examples/mitbih_ecg_boundary_medtsllm.py
new file mode 100644
index 000000000..0d02aee79
--- /dev/null
+++ b/examples/mitbih_ecg_boundary_medtsllm.py
@@ -0,0 +1,186 @@
+"""MIT-BIH ECG Boundary Detection with MedTsLLM.
+
+Demonstrates the MedTsLLM model (Chan et al., MLHC 2024) on the
+MIT-BIH Arrhythmia Database for R-peak boundary detection using
+2-channel ECG downsampled from 360 Hz to 120 Hz. Binary
+segmentation task trained with BCE-with-logits loss.
+
+Paper: https://arxiv.org/abs/2408.07773
+Dataset: https://physionet.org/content/mitdb/1.0.0/
+
+Usage:
+ python examples/mitbih_ecg_boundary_medtsllm.py --synthetic
+ python examples/mitbih_ecg_boundary_medtsllm.py \\
+ --root /path/to/mitdb --backbone openai-community/gpt2
+
+Ablation Study:
+ The script exposes the paper's two main ablation axes as CLI
+ flags so each run is a single ablation cell.
+
+    1. LLM backbone swap -- ``--backbone <model_id>``:
+ compare GPT-2 vs. GPT-2-medium vs. DistilGPT-2 etc.
+
+ 2. Prompt components -- each piece of the text prompt can be
+ disabled independently:
+ --no-prompt-dataset drops the dataset description
+ --no-prompt-task drops the task description
+ --no-prompt-patient drops the per-patient description
+ --prompt-stats adds the rolling signal stats
+
+ python examples/mitbih_ecg_boundary_medtsllm.py \\
+ --root /path/to/mitdb --no-prompt-patient
+ python examples/mitbih_ecg_boundary_medtsllm.py \\
+ --root /path/to/mitdb --no-prompt-dataset --prompt-stats
+"""
+
+import argparse
+
+import numpy as np
+import torch
+
+from pyhealth.datasets import (
+ create_sample_dataset,
+ get_dataloader,
+ split_by_patient,
+)
+from pyhealth.models import MedTsLLM
+from pyhealth.trainer import Trainer
+
+
+def make_synthetic_dataset(
+ n_patients: int = 10, seq_len: int = 256, n_channels: int = 2
+):
+ """Synthetic 2-channel ECG data."""
+ samples = []
+ for i in range(n_patients):
+ for w in range(5):
+ signal = np.random.randn(seq_len, n_channels).astype(np.float32)
+ label = np.zeros(seq_len, dtype=np.float32)
+ # Sparse boundary labels (~3% density) to mimic real R-peaks
+ idx = np.random.choice(seq_len, size=seq_len // 40, replace=False)
+ label[idx] = 1
+ samples.append({
+ "patient_id": f"p{i}",
+ "visit_id": f"v{w}",
+ "signal": signal,
+ "label": label,
+ "description": "",
+ })
+ return create_sample_dataset(
+ samples=samples,
+ input_schema={"signal": "tensor"},
+ output_schema={"label": "tensor"},
+ dataset_name="synthetic_mitbih",
+ )
+
+
+def make_real_dataset(root: str, seq_len: int = 256, step_size: int = 256):
+ """Load MIT-BIH and apply ECGBoundaryDetection.
+
+ Uses ``preprocess=True`` so the 30-minute wfdb records are decoded
+ and downsampled once into ``{root}/processed/*.npz``. Subsequent
+ runs skip wfdb entirely.
+ """
+ from pyhealth.datasets import MITBIHDataset
+ from pyhealth.tasks import ECGBoundaryDetection
+
+ dataset = MITBIHDataset(
+ root=root,
+ preprocess=True,
+ downsample_factor=3,
+ trim=True,
+ )
+ task = ECGBoundaryDetection(window_size=seq_len, step_size=step_size)
+ return dataset.set_task(task)
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="MIT-BIH R-peak boundary detection with MedTsLLM"
+ )
+ parser.add_argument("--root", type=str, default=None)
+ parser.add_argument("--synthetic", action="store_true")
+ parser.add_argument("--backbone", type=str, default=None)
+ parser.add_argument("--epochs", type=int, default=10)
+ parser.add_argument("--batch-size", type=int, default=16)
+ parser.add_argument("--lr", type=float, default=1e-4)
+ parser.add_argument("--device", type=str, default="auto")
+ # Ablation knobs — each flag disables one prompt component
+ parser.add_argument("--no-prompt-dataset", action="store_true")
+ parser.add_argument("--no-prompt-task", action="store_true")
+ parser.add_argument("--no-prompt-patient", action="store_true")
+ parser.add_argument("--prompt-stats", action="store_true")
+ args = parser.parse_args()
+
+ device = (
+ ("cuda" if torch.cuda.is_available() else "cpu")
+ if args.device == "auto"
+ else args.device
+ )
+
+ seq_len = 256
+ n_features = 2
+
+ if args.synthetic or args.root is None:
+ print("Using synthetic data")
+ sample_dataset = make_synthetic_dataset(
+ n_patients=10, seq_len=seq_len, n_channels=n_features
+ )
+ word_embeddings = torch.randn(100, 64)
+ backbone = None
+ epochs = 3
+ else:
+ print(f"Loading MIT-BIH from {args.root}")
+ sample_dataset = make_real_dataset(args.root, seq_len=seq_len)
+ word_embeddings = None
+ backbone = args.backbone or "openai-community/gpt2"
+ epochs = args.epochs
+
+ train_ds, _, test_ds = split_by_patient(
+ sample_dataset, ratios=[0.8, 0.0, 0.2]
+ )
+ train_loader = get_dataloader(
+ train_ds, batch_size=args.batch_size, shuffle=True
+ )
+ test_loader = get_dataloader(
+ test_ds, batch_size=args.batch_size, shuffle=False
+ )
+
+ model = MedTsLLM(
+ dataset=sample_dataset,
+ task="segmentation",
+ seq_len=seq_len,
+ n_features=n_features,
+ covariate_mode="concat",
+ d_model=32,
+ d_ff=64,
+ n_heads=8,
+ num_tokens=1024,
+ patch_len=16,
+ stride=8,
+ dataset_description=(
+ "The MIT-BIH Arrhythmia Database contains excerpts of "
+ "two-channel ambulatory ECG from a mixed population of "
+ "inpatients and outpatients, digitized at 360 samples "
+ "per second per channel."
+ ),
+ backbone=backbone,
+ word_embeddings=word_embeddings,
+ prompt_dataset=not args.no_prompt_dataset,
+ prompt_task=not args.no_prompt_task,
+ prompt_patient=not args.no_prompt_patient,
+ prompt_stats=args.prompt_stats,
+ )
+
+ trainer = Trainer(model=model, device=device, enable_logging=False)
+ trainer.train(
+ train_dataloader=train_loader,
+ test_dataloader=test_loader,
+ epochs=epochs,
+ optimizer_class=torch.optim.Adam,
+ optimizer_params={"lr": args.lr},
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py
index 50b1b3887..d9c696ef2 100644
--- a/pyhealth/datasets/__init__.py
+++ b/pyhealth/datasets/__init__.py
@@ -55,7 +55,10 @@ def __init__(self, *args, **kwargs):
from .dreamt import DREAMTDataset
from .ehrshot import EHRShotDataset
from .eicu import eICUDataset
+from .bidmc import BIDMCDataset
from .isruc import ISRUCDataset
+from .ludb import LUDBDataset
+from .mitbih import MITBIHDataset
from .medical_transcriptions import MedicalTranscriptionsDataset
from .mimic3 import MIMIC3Dataset
from .mimic4 import MIMIC4CXRDataset, MIMIC4Dataset, MIMIC4EHRDataset, MIMIC4NoteDataset
diff --git a/pyhealth/datasets/_medtsllm_cache.py b/pyhealth/datasets/_medtsllm_cache.py
new file mode 100644
index 000000000..289a8f2aa
--- /dev/null
+++ b/pyhealth/datasets/_medtsllm_cache.py
@@ -0,0 +1,105 @@
+"""Fingerprint-based ``.npz`` preprocessing cache for MedTsLLM datasets.
+
+Internal helper for the LUDB / MIT-BIH / BIDMC loaders. Provides
+``load_or_build`` so preprocessing runs the expensive wfdb-decode path
+exactly once per raw file. Subsequent calls load from an ``.npz``
+next to the raw file (or wherever the dataset chooses to cache).
+Invalidation is keyed on raw-file stats + preprocessing params via
+:func:`compute_fingerprint`.
+
+Author: Anton Barchukov
+"""
+
+import hashlib
+import json
+import os
+from typing import Callable
+
+import numpy as np
+
+_FINGERPRINT_KEY = "_cache_fingerprint"
+
+
+def compute_fingerprint(raw_paths: list[str], params: dict) -> str:
+ """Return a stable SHA-256 fingerprint for cache invalidation.
+
+ Combines raw-file ``(path, mtime_ns, size)`` tuples with the
+ preprocessing params (JSON-serializable) into a single hash. A
+ change in any input — including editing a raw file or flipping a
+ param — flips the fingerprint.
+
+ Args:
+ raw_paths: Absolute paths to the raw files whose decoded form
+ is being cached. Order-insensitive.
+ params: Preprocessing params that affect the cached arrays
+ (e.g. ``{"trim": True, "downsample_factor": 3}``).
+
+ Returns:
+ 64-char hex digest string.
+ """
+ h = hashlib.sha256()
+ for path in sorted(raw_paths):
+ st = os.stat(path)
+ h.update(path.encode("utf-8"))
+ h.update(b"|")
+ h.update(str(st.st_mtime_ns).encode("ascii"))
+ h.update(b"|")
+ h.update(str(st.st_size).encode("ascii"))
+ h.update(b"\n")
+ h.update(json.dumps(params, sort_keys=True, default=str).encode("utf-8"))
+ return h.hexdigest()
+
+
+def load_or_build(
+ cache_path: str,
+ fingerprint: str,
+ builder: Callable[[], dict[str, np.ndarray]],
+) -> dict[str, np.ndarray]:
+ """Load cached arrays if the fingerprint matches, else build + write.
+
+ Args:
+ cache_path: Target ``.npz`` path. Parent dirs are created.
+ fingerprint: Expected fingerprint string. Mismatch triggers a
+ rebuild.
+ builder: Zero-arg callable returning ``{name: ndarray}``. Only
+ invoked on cache miss.
+
+ Returns:
+ Dict of arrays — either freshly built or restored from disk.
+ """
+ cached = _try_load(cache_path, fingerprint)
+ if cached is not None:
+ return cached
+
+ arrays = builder()
+ parent = os.path.dirname(cache_path)
+ if parent:
+ os.makedirs(parent, exist_ok=True)
+
+ payload = dict(arrays)
+ payload[_FINGERPRINT_KEY] = np.array([fingerprint])
+ np.savez(cache_path, allow_pickle=False, **payload)
+ return arrays
+
+
+def _try_load(
+ cache_path: str, fingerprint: str
+) -> dict[str, np.ndarray] | None:
+ """Return cached arrays iff the file exists, parses, and matches."""
+ if not os.path.exists(cache_path):
+ return None
+ try:
+ npz = np.load(cache_path, allow_pickle=False)
+ except Exception:
+ return None
+ try:
+ stored = str(npz[_FINGERPRINT_KEY][0])
+ except (KeyError, IndexError):
+ return None
+ if stored != fingerprint:
+ return None
+ return {
+ key: np.array(npz[key])
+ for key in npz.files
+ if key != _FINGERPRINT_KEY
+ }
diff --git a/pyhealth/datasets/bidmc.py b/pyhealth/datasets/bidmc.py
new file mode 100644
index 000000000..6dcdc2bbf
--- /dev/null
+++ b/pyhealth/datasets/bidmc.py
@@ -0,0 +1,258 @@
+# Author: Anton Barchukov
+# Paper: Chan et al., "MedTsLLM: Leveraging LLMs for Multimodal
+# Medical Time Series Analysis", MLHC 2024
+# Paper link: https://arxiv.org/abs/2408.07773
+# Description: BIDMC dataset — 53 ICU patients with 8-minute
+# recordings of ECG, PPG, and respiratory signals at 125 Hz.
+# Two annotators manually annotated individual breaths.
+# Source: https://physionet.org/content/bidmc/1.0.0/
+
+import logging
+import os
+from typing import Optional
+
+import numpy as np
+import pandas as pd
+
+from pyhealth.datasets import BaseDataset
+from pyhealth.datasets._medtsllm_cache import (
+ compute_fingerprint,
+ load_or_build,
+)
+
+logger = logging.getLogger(__name__)
+
+# Paper's BIDMC split: 85/15 by patient, np.random.RandomState(0).
+_PAPER_SPLIT_RATIO = 0.85
+_PAPER_SPLIT_SEED = 0
+
+# Subdirectory under ``root`` for preprocessed ``.npz`` caches.
+_PROCESSED_SUBDIR = "processed"
+
+# RESP, PLETH, and ECG lead II: the 3 channels used in the paper.
+_TARGET_CHANNELS = ["RESP,", "PLETH,", "II,"]
+
+
class BIDMCDataset(BaseDataset):
    """BIDMC respiratory signal dataset for boundary detection.

    53 ICU patients with 8-minute recordings of ECG, PPG, and
    respiratory impedance signals at 125 Hz. Breath boundaries are
    manually annotated by two annotators.

    Dataset is available at https://physionet.org/content/bidmc/1.0.0/

    Paper: Pimentel, M.A.F. et al. "Towards a Robust Estimation of
    Respiratory Rate from Pulse Oximeters." IEEE TBME, 2016.

    Note:
        The metadata CSV is only built when ``{root}/bidmc-pyhealth.csv``
        is absent. Re-running with different ``paper_split`` /
        ``preprocess`` flags will NOT regenerate it — delete the CSV
        first.

    Args:
        root: Root directory of the raw BIDMC data. Should contain
            wfdb record files (bidmc01.dat, bidmc01.hea, etc.).
        dataset_name: Name of the dataset. Default is ``"bidmc"``.
        config_path: Path to the YAML config file.
        dev: Whether to enable dev mode (first 5 patients).
        paper_split: If True, populate a ``split`` column with an
            85/15 train/test assignment per patient, using the
            legacy NumPy seed=0 RNG from the paper. Otherwise the
            ``split`` column is left blank. Default False.
        preprocess: If True, decode each record once, extract the
            RESP/PLETH/II channels, and cache
            ``(signal, ann_sample, ann_aux)`` to
            ``{root}/processed/{record}.npz``. Subsequent runs skip
            wfdb. Default False.

    Examples:
        >>> from pyhealth.datasets import BIDMCDataset
        >>> dataset = BIDMCDataset(root="/path/to/bidmc/")
        >>> dataset.stat()
    """

    def __init__(
        self,
        root: str,
        dataset_name: Optional[str] = None,
        config_path: Optional[str] = None,
        dev: bool = False,
        paper_split: bool = False,
        preprocess: bool = False,
    ) -> None:
        # Fall back to the packaged config shipped next to this module.
        if config_path is None:
            config_path = os.path.join(
                os.path.dirname(__file__), "configs", "bidmc.yaml"
            )

        # Build the metadata CSV once; existence of the file is the
        # only freshness check (see class Note about stale flags).
        metadata_path = os.path.join(root, "bidmc-pyhealth.csv")
        if not os.path.exists(metadata_path):
            self.prepare_metadata(
                root,
                dev=dev,
                paper_split=paper_split,
                preprocess=preprocess,
            )

        super().__init__(
            root=root,
            tables=["respiratory"],
            dataset_name=dataset_name or "bidmc",
            config_path=config_path,
            dev=dev,
        )

    @staticmethod
    def prepare_metadata(
        root: str,
        dev: bool = False,
        paper_split: bool = False,
        preprocess: bool = False,
    ) -> None:
        """Prepare metadata CSV from raw wfdb files.

        Writes ``{root}/bidmc-pyhealth.csv`` with one row per record.
        Records that fail to decode are silently skipped (best-effort).

        Args:
            root: Root directory containing wfdb files.
            dev: If True, only process first 5 patients.
            paper_split: If True, assign records to train/test via the
                paper's 85/15 seed=0 split (written to the ``split``
                column). Otherwise the column is blank.
            preprocess: If True, build per-record ``.npz`` caches and
                populate the ``processed_file`` column.
        """
        # Imported lazily: wfdb is only needed for metadata generation.
        import wfdb

        # Find record files (exclude numerics *n.hea, which hold the
        # derived numeric trends rather than the waveforms).
        hea_files = sorted(
            f.replace(".hea", "")
            for f in os.listdir(root)
            if f.endswith(".hea")
            and not f.endswith("n.hea")
            and f.startswith("bidmc")
        )
        if dev:
            hea_files = hea_files[:5]

        split_by_record = (
            _assign_paper_split(hea_files) if paper_split else {}
        )
        processed_dir = os.path.join(root, _PROCESSED_SUBDIR)

        rows = []
        for rec_name in hea_files:
            rec_path = os.path.join(root, rec_name)
            try:
                record = wfdb.rdrecord(rec_path)
            except Exception:
                # Best-effort: unreadable records are dropped from the
                # metadata rather than aborting the whole scan.
                continue

            # e.g. "bidmc01" -> "01".
            patient_id = rec_name.replace("bidmc", "")

            # Parse header for demographics. Assumes tags appear in the
            # joined comment string as "<age>: ...", "<sex>: ...",
            # "<location>: ..." — TODO(review) confirm against actual
            # BIDMC .hea comments.
            age, sex, location = "", "", ""
            if record.comments:
                comment = " ".join(record.comments)
                # NOTE(review): ``var`` is unused; the tuple's second
                # element duplicates the first.
                for field, var in [("age", "age"), ("sex", "sex"),
                                   ("location", "location")]:
                    tag = f"<{field}>:"
                    if tag in comment:
                        idx = comment.index(tag) + len(tag)
                        # Value runs until the next "<tag>" or the end.
                        end = comment.find("<", idx)
                        val = (comment[idx:end].strip()
                               if end > 0 else comment[idx:].strip())
                        if field == "age":
                            age = val
                        elif field == "sex":
                            sex = val
                        elif field == "location":
                            location = val

            processed_file = ""
            if preprocess:
                processed_file = _build_record_cache(
                    processed_dir=processed_dir,
                    rec_path=os.path.join(root, rec_name),
                    rec_name=rec_name,
                )

            rows.append({
                "patient_id": patient_id,
                "signal_file": os.path.join(root, rec_name),
                "annotation_file": "breath",
                "age": age,
                "sex": sex,
                "location": location,
                "split": split_by_record.get(rec_name, ""),
                "processed_file": processed_file,
            })

        df = pd.DataFrame(rows)
        out_path = os.path.join(root, "bidmc-pyhealth.csv")
        df.to_csv(out_path, index=False)
        logger.info(
            "BIDMC metadata: %d patients -> %s", len(df), out_path
        )

    @property
    def default_task(self):
        """Returns the default task for this dataset."""
        # Imported lazily to avoid a circular import at module load.
        from pyhealth.tasks.respiratory_boundary_detection import (
            RespiratoryBoundaryDetection,
        )

        return RespiratoryBoundaryDetection()
+
+
def _build_record_cache(
    processed_dir: str,
    rec_path: str,
    rec_name: str,
) -> str:
    """Cache (signal, ann_sample, ann_aux) for one BIDMC record.

    Signal is the 3-channel (RESP, PLETH, II) array at native 125 Hz
    — no downsampling, no trim, matching the paper's recipe.
    ``ann_sample`` / ``ann_aux`` preserve the raw annotation stream
    so the task can pick annotator 1 or 2 at windowing time.

    Returns:
        The absolute ``.npz`` cache path (written to the metadata
        CSV's ``processed_file`` column).
    """
    cache_path = os.path.join(processed_dir, f"{rec_name}.npz")
    # Fingerprint over the three raw files; note os.stat raises
    # FileNotFoundError if the .breath annotation file is missing —
    # TODO(review): confirm every BIDMC record ships one.
    raw_paths = [
        rec_path + ".dat",
        rec_path + ".hea",
        rec_path + ".breath",
    ]
    fingerprint = compute_fingerprint(raw_paths, {"channels": _TARGET_CHANNELS})

    def _build() -> dict[str, np.ndarray]:
        # wfdb only imported/used on an actual cache miss.
        import wfdb

        record = wfdb.rdrecord(rec_path)
        ann = wfdb.rdann(rec_path, extension="breath")

        # _TARGET_CHANNELS carry trailing commas ("RESP," etc.) to
        # match the literal sig_name strings in BIDMC headers.
        # Channels missing from a record are silently dropped, so the
        # cached signal may have fewer than 3 columns.
        col_idx = [
            record.sig_name.index(ch)
            for ch in _TARGET_CHANNELS
            if ch in record.sig_name
        ]
        signal = record.p_signal[:, col_idx].astype(np.float32)

        return {
            "signal": signal,
            "ann_sample": np.asarray(ann.sample, dtype=np.int64),
            # Fixed-width unicode so np.savez stays pickle-free.
            "ann_aux": np.asarray(ann.aux_note).astype("U8"),
        }

    load_or_build(cache_path, fingerprint, _build)
    return cache_path
+
+
+def _assign_paper_split(records: list[str]) -> dict[str, str]:
+ """Assign each record to train/test per the paper's 85/15 seed=0 split.
+
+ Mirrors the cs598 BIDMCSegmentationDataset recipe: shuffle record
+ names with ``np.random.RandomState(0)``, take the first 85% as
+ train and the remainder as test.
+ """
+ rng = np.random.RandomState(_PAPER_SPLIT_SEED)
+ order = rng.permutation(len(records))
+ cutoff = int(len(records) * _PAPER_SPLIT_RATIO)
+ assignment: dict[str, str] = {}
+ for rank, idx in enumerate(order):
+ assignment[records[idx]] = "train" if rank < cutoff else "test"
+ return assignment
diff --git a/pyhealth/datasets/configs/bidmc.yaml b/pyhealth/datasets/configs/bidmc.yaml
new file mode 100644
index 000000000..a0a7ede5c
--- /dev/null
+++ b/pyhealth/datasets/configs/bidmc.yaml
@@ -0,0 +1,14 @@
+version: "1.0.0"
+tables:
+ respiratory:
+ file_path: "bidmc-pyhealth.csv"
+ patient_id: "patient_id"
+ timestamp: null
+ attributes:
+ - "signal_file"
+ - "annotation_file"
+ - "age"
+ - "sex"
+ - "location"
+ - "split"
+ - "processed_file"
diff --git a/pyhealth/datasets/configs/ludb.yaml b/pyhealth/datasets/configs/ludb.yaml
new file mode 100644
index 000000000..4ba371f1a
--- /dev/null
+++ b/pyhealth/datasets/configs/ludb.yaml
@@ -0,0 +1,16 @@
+version: "1.0.0"
+tables:
+ ecg:
+ file_path: "ludb-pyhealth.csv"
+ patient_id: "patient_id"
+ timestamp: null
+ attributes:
+ - "lead"
+ - "clip_id"
+ - "signal_file"
+ - "label_file"
+ - "age"
+ - "sex"
+ - "diagnoses"
+ - "split"
+ - "processed_file"
diff --git a/pyhealth/datasets/configs/mitbih.yaml b/pyhealth/datasets/configs/mitbih.yaml
new file mode 100644
index 000000000..21dd118ba
--- /dev/null
+++ b/pyhealth/datasets/configs/mitbih.yaml
@@ -0,0 +1,16 @@
+version: "1.0.0"
+tables:
+ ecg:
+ file_path: "mitbih-pyhealth.csv"
+ patient_id: "patient_id"
+ timestamp: null
+ attributes:
+ - "signal_file"
+ - "annotation_file"
+ - "age"
+ - "sex"
+ - "medications"
+ - "n_abnormal"
+ - "n_beats"
+ - "split"
+ - "processed_file"
diff --git a/pyhealth/datasets/ludb.py b/pyhealth/datasets/ludb.py
new file mode 100644
index 000000000..05f7cff7f
--- /dev/null
+++ b/pyhealth/datasets/ludb.py
@@ -0,0 +1,331 @@
+# Author: Anton Barchukov
+# Paper: Chan et al., "MedTsLLM: Leveraging LLMs for Multimodal
+# Medical Time Series Analysis", MLHC 2024
+# Paper link: https://arxiv.org/abs/2408.07773
+# Description: Lobachevsky University Database (LUDB) — 200 subjects
+# with 12-lead ECG at 500 Hz, manually annotated with P wave,
+# T wave, QRS complex, and background classes.
+# Source: https://physionet.org/content/ludb/1.0.1/
+
+import logging
+import os
+from typing import Optional
+
+import numpy as np
+import pandas as pd
+
+from pyhealth.datasets import BaseDataset
+from pyhealth.datasets._medtsllm_cache import (
+ compute_fingerprint,
+ load_or_build,
+)
+
+logger = logging.getLogger(__name__)
+
+# Label mapping: wfdb annotation symbol -> class index
+# 0 = background, 1 = P wave, 2 = QRS complex, 3 = T wave
+WAVE_CLASSES = ["background", "P wave", "QRS complex", "T wave"]
+_WAVE_LABELS = {"p": 1, "N": 2, "t": 3}
+
+# Paper's LUDB split: 80/20 by patient, seeded with NumPy legacy RNG
+# to match the cs598 reference implementation exactly.
+_PAPER_SPLIT_RATIO = 0.8
+_PAPER_SPLIT_SEED = 0
+
+# Subdirectory under ``root`` where preprocessed ``.npz`` files live.
+_PROCESSED_SUBDIR = "processed"
+
+
+def _parse_ludb_header(record) -> tuple[str, str, str]:
+ """Parse age, sex, and diagnoses from LUDB wfdb header comments.
+
+ LUDB headers look like::
+
+ #: 51
+ #: F
+ #:
+ #Rhythm: Sinus bradycardia.
+ #Left ventricular hypertrophy.
+
+ Args:
+ record: A ``wfdb.Record`` with a ``comments`` attribute.
+
+ Returns:
+ Tuple of (age, sex, diagnoses). Diagnoses are joined with
+ ``"; "`` separators. Missing fields return empty strings.
+ """
+ age = ""
+ sex = ""
+ diagnoses_lines: list[str] = []
+ if not record.comments:
+ return age, sex, ""
+
+ in_diagnoses = False
+ for line in record.comments:
+ line = line.strip()
+ if line.startswith(":"):
+ age = line.split(":", 1)[1].strip()
+ elif line.startswith(":"):
+ sex = line.split(":", 1)[1].strip()
+ elif line.startswith(":"):
+ in_diagnoses = True
+ elif in_diagnoses and line:
+ diagnoses_lines.append(line.rstrip(". "))
+
+ return age, sex, "; ".join(diagnoses_lines)
+
+
class LUDBDataset(BaseDataset):
    """Lobachevsky University Database (LUDB) for ECG delineation.

    Dataset of 200 subjects with 12-lead ECG recordings at 500 Hz
    (10 seconds each). Each lead is manually annotated by cardiologists
    with P wave, QRS complex, and T wave boundaries.

    Dataset is available at https://physionet.org/content/ludb/1.0.1/

    Paper: Kalyakulina, A. et al. "LUDB: A New Open-Access Validation
    Database for Electrocardiogram Delineation Algorithms."

    Note:
        The metadata CSV is built only when ``{root}/ludb-pyhealth.csv``
        is absent; re-running with different flags will not regenerate
        it — delete the CSV first.

    Args:
        root: Root directory of the raw LUDB data. Should contain a
            ``data/`` subdirectory with wfdb record files (.dat, .hea).
        dataset_name: Name of the dataset. Default is ``"ludb"``.
        config_path: Path to the YAML config file. Default uses the
            built-in config.
        dev: Whether to enable dev mode (only use first 5 patients).
            Default is False.
        paper_split: If True, populate a ``split`` column with an
            80/20 train/test assignment per patient, using the
            legacy NumPy seed=0 RNG from the paper. Otherwise the
            ``split`` column is left blank. Default False.
        preprocess: If True, decode each ``(patient, lead)`` wfdb
            record once and cache the resulting ``(signal, labels)``
            arrays to ``{root}/processed/{record}_{lead}.npz``.
            Subsequent runs skip wfdb entirely. Default False.
        trim: Only consulted when ``preprocess=True``. Crop each
            lead to the region between the first and last wave
            annotation before caching. Matches the paper's
            preprocessing. Default True.

    Examples:
        >>> from pyhealth.datasets import LUDBDataset
        >>> dataset = LUDBDataset(root="/path/to/ludb/")
        >>> dataset.stat()
    """

    def __init__(
        self,
        root: str,
        dataset_name: Optional[str] = None,
        config_path: Optional[str] = None,
        dev: bool = False,
        paper_split: bool = False,
        preprocess: bool = False,
        trim: bool = True,
    ) -> None:
        # Fall back to the packaged config shipped next to this module.
        if config_path is None:
            config_path = os.path.join(
                os.path.dirname(__file__), "configs", "ludb.yaml"
            )

        # One-time metadata build; file existence is the only check.
        metadata_path = os.path.join(root, "ludb-pyhealth.csv")
        if not os.path.exists(metadata_path):
            self.prepare_metadata(
                root,
                dev=dev,
                paper_split=paper_split,
                preprocess=preprocess,
                trim=trim,
            )

        super().__init__(
            root=root,
            tables=["ecg"],
            dataset_name=dataset_name or "ludb",
            config_path=config_path,
            dev=dev,
        )

    @staticmethod
    def prepare_metadata(
        root: str,
        dev: bool = False,
        paper_split: bool = False,
        preprocess: bool = False,
        trim: bool = True,
    ) -> None:
        """Prepare metadata CSV from raw wfdb files.

        Scans the ``data/`` subdirectory for wfdb records and creates
        a CSV with one row per (patient, lead) pair, pointing to the
        signal and annotation files. When ``preprocess=True`` also
        writes per-lead ``.npz`` caches into ``{root}/processed/``.

        Args:
            root: Root directory containing ``data/`` with wfdb files.
            dev: If True, only process first 5 patients.
            paper_split: If True, assign patients to train/test via the
                paper's 80/20 seed=0 split and write values to the
                ``split`` column. Otherwise the column is blank.
            preprocess: If True, build per-lead ``.npz`` signal+label
                caches and point each row's ``processed_file`` at
                its cache.
            trim: When ``preprocess=True``, crop each lead to the
                region between the first and last wave annotation.

        Raises:
            FileNotFoundError: If ``{root}/data`` does not exist.
        """
        # Imported lazily: wfdb is only needed for metadata generation.
        import wfdb

        data_dir = os.path.join(root, "data")
        if not os.path.isdir(data_dir):
            raise FileNotFoundError(
                f"LUDB data directory not found at {data_dir}. "
                "Download from https://physionet.org/content/ludb/1.0.1/"
            )

        records = sorted(
            f.replace(".dat", "")
            for f in os.listdir(data_dir)
            if f.endswith(".dat")
        )
        if dev:
            records = records[:5]

        split_by_record = _assign_paper_split(records) if paper_split else {}
        processed_dir = os.path.join(root, _PROCESSED_SUBDIR)

        rows = []
        for rec_name in records:
            rec_path = os.path.join(data_dir, rec_name)
            record = wfdb.rdrecord(rec_path)
            # LUDB record files are numeric ("1".."200"); int() raises
            # on anything else, which would surface a bad download.
            patient_id = int(rec_name)

            age, sex, diagnoses = _parse_ludb_header(record)

            for lead_idx in range(record.n_sig):
                lead_name = record.sig_name[lead_idx]
                # Check if annotation file exists for this lead
                # (annotation extension is the lowercase lead name,
                # e.g. "1.ii"); leads without annotations are skipped.
                ann_ext = lead_name.lower()
                ann_path = os.path.join(data_dir, f"{rec_name}.{ann_ext}")
                if not os.path.exists(ann_path):
                    continue

                # Unique per (patient, lead); assumes < 100 leads.
                clip_id = patient_id * 100 + lead_idx
                processed_file = ""
                if preprocess:
                    processed_file = _build_lead_cache(
                        processed_dir=processed_dir,
                        rec_path=rec_path,
                        rec_name=rec_name,
                        lead_name=lead_name,
                        lead_idx=lead_idx,
                        ann_path=ann_path,
                        ann_ext=ann_ext,
                        trim=trim,
                    )
                rows.append({
                    "patient_id": str(patient_id),
                    "lead": lead_name,
                    "clip_id": clip_id,
                    "signal_file": os.path.join(data_dir, rec_name),
                    "label_file": ann_ext,
                    "age": age,
                    "sex": sex,
                    "diagnoses": diagnoses,
                    "split": split_by_record.get(rec_name, ""),
                    "processed_file": processed_file,
                })

        df = pd.DataFrame(rows)
        out_path = os.path.join(root, "ludb-pyhealth.csv")
        df.to_csv(out_path, index=False)
        logger.info(
            "LUDB metadata: %d records from %d patients -> %s",
            len(df),
            df["patient_id"].nunique(),
            out_path,
        )

    @property
    def default_task(self):
        """Returns the default task for this dataset."""
        # Imported lazily to avoid a circular import at module load.
        from pyhealth.tasks.ecg_wave_segmentation import ECGWaveSegmentation

        return ECGWaveSegmentation()
+
+
def _build_lead_cache(
    processed_dir: str,
    rec_path: str,
    rec_name: str,
    lead_name: str,
    lead_idx: int,
    ann_path: str,
    ann_ext: str,
    trim: bool,
) -> str:
    """Cache (signal, labels) for one (record, lead) pair.

    Labels are per-sample class indices (0=background, 1=P, 2=QRS,
    3=T per ``_WAVE_LABELS``) derived from the wfdb annotation stream.

    Returns the absolute cache path, which is written into the
    metadata CSV's ``processed_file`` column.
    """
    cache_path = os.path.join(
        processed_dir, f"{rec_name}_{lead_name}.npz"
    )
    raw_paths = [rec_path + ".dat", rec_path + ".hea", ann_path]
    params = {"trim": bool(trim)}
    fingerprint = compute_fingerprint(raw_paths, params)

    def _build() -> dict[str, np.ndarray]:
        # wfdb only imported/used on an actual cache miss.
        import wfdb

        record = wfdb.rdrecord(rec_path)
        signal = record.p_signal[:, lead_idx].astype(np.float32)
        ann = wfdb.rdann(rec_path, extension=ann_ext)

        labels = np.zeros(len(signal), dtype=np.int64)
        # Walk the annotation stream expecting "(", peak-symbol, ")"
        # triples (LUDB delineation convention — onset, peak, offset).
        i = 0
        while i < len(ann.symbol):
            sym = ann.symbol[i]
            if sym == "(" and i + 2 < len(ann.symbol):
                wave_type = ann.symbol[i + 1]
                onset = ann.sample[i]
                # Fall back to the peak sample when the closing ")"
                # is missing (malformed triple).
                offset = (
                    ann.sample[i + 2]
                    if ann.symbol[i + 2] == ")"
                    else ann.sample[i + 1]
                )
                if wave_type in _WAVE_LABELS:
                    # Inclusive of the offset sample.
                    labels[onset : offset + 1] = _WAVE_LABELS[wave_type]
                i += 3
            else:
                # Stray symbol (or truncated trailing triple): skip it.
                i += 1

        if trim:
            # Crop to [first, last] annotated sample, per the paper.
            wave_mask = labels > 0
            if wave_mask.any():
                first = int(np.argmax(wave_mask))
                last = len(wave_mask) - 1 - int(np.argmax(wave_mask[::-1]))
                signal = signal[first : last + 1]
                labels = labels[first : last + 1]

        return {"signal": signal, "labels": labels}

    load_or_build(cache_path, fingerprint, _build)
    return cache_path
+
+
+def _assign_paper_split(records: list[str]) -> dict[str, str]:
+ """Assign each record to train/test per the paper's 80/20 seed=0 split.
+
+ Mirrors the cs598 preprocess_ludb.py recipe: shuffle record names
+ with ``np.random.RandomState(0)``, take the first 80% as train and
+ the remainder as test.
+ """
+ rng = np.random.RandomState(_PAPER_SPLIT_SEED)
+ order = rng.permutation(len(records))
+ cutoff = int(len(records) * _PAPER_SPLIT_RATIO)
+ assignment: dict[str, str] = {}
+ for rank, idx in enumerate(order):
+ assignment[records[idx]] = "train" if rank < cutoff else "test"
+ return assignment
diff --git a/pyhealth/datasets/mitbih.py b/pyhealth/datasets/mitbih.py
new file mode 100644
index 000000000..d9fc2947d
--- /dev/null
+++ b/pyhealth/datasets/mitbih.py
@@ -0,0 +1,340 @@
+# Author: Anton Barchukov
+# Paper: Chan et al., "MedTsLLM: Leveraging LLMs for Multimodal
+# Medical Time Series Analysis", MLHC 2024
+# Paper link: https://arxiv.org/abs/2408.07773
+# Description: MIT-BIH Arrhythmia Database — 48 half-hour excerpts
+# of two-channel ambulatory ECG, 360 Hz. Used for boundary
+# detection (R-peaks) and anomaly detection (arrhythmia).
+# Source: https://physionet.org/content/mitdb/1.0.0/
+
+import logging
+import os
+from typing import Optional
+
+import numpy as np
+import pandas as pd
+
+from pyhealth.datasets import BaseDataset
+from pyhealth.datasets._medtsllm_cache import (
+ compute_fingerprint,
+ load_or_build,
+)
+
+logger = logging.getLogger(__name__)
+
+# Normal beat types (all others are anomalies)
+_NORMAL_BEATS = {"N", "L", "R", "e", "j"}
+
+# Paced rhythm records — excluded per paper
+_PACED_RECORDS = {"102", "104", "107", "217"}
+
+# Paper's MIT-BIH split: 80/20 by patient, seeded with NumPy legacy RNG.
+_PAPER_SPLIT_RATIO = 0.8
+_PAPER_SPLIT_SEED = 0
+_VALID_SPLIT_MODES = {None, "random", "abnormal_sorted"}
+
+# Subdirectory under ``root`` for preprocessed ``.npz`` caches.
+_PROCESSED_SUBDIR = "processed"
+
+
class MITBIHDataset(BaseDataset):
    """MIT-BIH Arrhythmia Database for ECG analysis.

    48 half-hour excerpts of two-channel ambulatory ECG from a mixed
    population of inpatients and outpatients, digitized at 360 Hz.
    4 paced-rhythm records are excluded.

    Supports two tasks:
    - Boundary detection (R-peak localization)
    - Anomaly detection (arrhythmia via reconstruction error)

    Dataset is available at https://physionet.org/content/mitdb/1.0.0/

    Paper: Moody, G.B. & Mark, R.G. "The impact of the MIT-BIH
    Arrhythmia Database." IEEE EMB Magazine, 2001.

    Note:
        The metadata CSV is built only when
        ``{root}/mitbih-pyhealth.csv`` is absent; re-running with
        different flags will not regenerate it — delete the CSV first.

    Args:
        root: Root directory of the raw MIT-BIH data. Should contain
            wfdb record files (100.dat, 100.hea, etc.).
        dataset_name: Name of the dataset. Default is ``"mitbih"``.
        config_path: Path to the YAML config file.
        dev: Whether to enable dev mode (first 5 patients).
        paper_split: Split assignment strategy:

            - ``None`` (default): leave the ``split`` column blank.
            - ``"random"``: 80/20 patient split via ``RandomState(0)``,
              matching the paper's segmentation/boundary task setup.
            - ``"abnormal_sorted"``: patients sorted by ``n_abnormal``
              ascending, with all-abnormal patients excluded. The
              least-abnormal 80% become train and the most-abnormal
              20% become test. Matches the paper's anomaly setup.
        preprocess: If True, decode each record once, downsample,
            trim, and cache ``(signal, ann_sample, ann_symbol)`` to
            ``{root}/processed/{record}.npz``. Subsequent runs skip
            wfdb. Default False.
        downsample_factor: Decimation factor applied to the 360 Hz
            raw signal when ``preprocess=True``. Default 3 (120 Hz).
        trim: When ``preprocess=True``, crop each record to the
            region between its first and last beat annotation.
            Matches the paper's preprocessing. Default True.

    Examples:
        >>> from pyhealth.datasets import MITBIHDataset
        >>> dataset = MITBIHDataset(root="/path/to/mitdb/")
        >>> dataset.stat()
    """

    def __init__(
        self,
        root: str,
        dataset_name: Optional[str] = None,
        config_path: Optional[str] = None,
        dev: bool = False,
        paper_split: Optional[str] = None,
        preprocess: bool = False,
        downsample_factor: int = 3,
        trim: bool = True,
    ) -> None:
        # Validate args up front, before any filesystem work.
        if paper_split not in _VALID_SPLIT_MODES:
            raise ValueError(
                f"paper_split must be one of {_VALID_SPLIT_MODES}, "
                f"got {paper_split!r}"
            )
        if downsample_factor < 1:
            raise ValueError(
                f"downsample_factor must be >= 1, got {downsample_factor}"
            )

        # Fall back to the packaged config shipped next to this module.
        if config_path is None:
            config_path = os.path.join(
                os.path.dirname(__file__), "configs", "mitbih.yaml"
            )

        # One-time metadata build; file existence is the only check.
        metadata_path = os.path.join(root, "mitbih-pyhealth.csv")
        if not os.path.exists(metadata_path):
            self.prepare_metadata(
                root,
                dev=dev,
                paper_split=paper_split,
                preprocess=preprocess,
                downsample_factor=downsample_factor,
                trim=trim,
            )

        super().__init__(
            root=root,
            tables=["ecg"],
            dataset_name=dataset_name or "mitbih",
            config_path=config_path,
            dev=dev,
        )

    @staticmethod
    def prepare_metadata(
        root: str,
        dev: bool = False,
        paper_split: Optional[str] = None,
        preprocess: bool = False,
        downsample_factor: int = 3,
        trim: bool = True,
    ) -> None:
        """Prepare metadata CSV from raw wfdb files.

        NOTE(review): when called directly (not via ``__init__``), an
        invalid ``paper_split`` is only rejected by
        ``_apply_paper_split`` *after* the full record scan.

        Args:
            root: Root directory containing wfdb files.
            dev: If True, only process first 5 patients.
            paper_split: ``None``, ``"random"``, or ``"abnormal_sorted"``.
                See class docstring for semantics.
            preprocess: If True, build per-record ``.npz`` caches and
                populate the ``processed_file`` column.
            downsample_factor: Decimation factor when ``preprocess=True``.
            trim: When ``preprocess=True``, crop to first/last beat
                annotation.
        """
        # Imported lazily: wfdb is only needed for metadata generation.
        import wfdb

        # Paced-rhythm records are excluded per the paper.
        records = sorted(
            f.replace(".dat", "")
            for f in os.listdir(root)
            if f.endswith(".dat") and f.replace(".dat", "") not in _PACED_RECORDS
        )
        if dev:
            records = records[:5]

        processed_dir = os.path.join(root, _PROCESSED_SUBDIR)

        rows = []
        for rec_name in records:
            rec_path = os.path.join(root, rec_name)
            try:
                record = wfdb.rdrecord(rec_path)
                ann = wfdb.rdann(rec_path, extension="atr")
            except Exception:
                # Best-effort: unreadable records are dropped rather
                # than aborting the whole scan.
                continue

            patient_id = rec_name

            # Parse header for demographics. Assumes the first comment
            # line starts "AGE SEX ..." and the second lists
            # medications — TODO(review) confirm against mitdb headers.
            age, sex, medications = "", "", ""
            if record.comments:
                first = record.comments[0].strip()
                tokens = first.split()
                if len(tokens) >= 2:
                    age = tokens[0]
                    sex = tokens[1]
                if len(record.comments) > 1:
                    medications = record.comments[1].strip()

            # Count beats excluding rhythm-change markers ("+")
            beat_symbols = [s for s in ann.symbol if s != "+"]
            n_beats = len(beat_symbols)
            n_abnormal = sum(
                1 for s in beat_symbols if s not in _NORMAL_BEATS
            )

            processed_file = ""
            if preprocess:
                processed_file = _build_record_cache(
                    processed_dir=processed_dir,
                    rec_path=rec_path,
                    rec_name=rec_name,
                    downsample_factor=downsample_factor,
                    trim=trim,
                )

            rows.append({
                "patient_id": patient_id,
                "signal_file": os.path.join(root, rec_name),
                "annotation_file": "atr",
                "age": age,
                "sex": sex,
                "medications": medications,
                "n_abnormal": n_abnormal,
                "n_beats": n_beats,
                "processed_file": processed_file,
            })

        # Adds the "split" column; may drop all-abnormal rows in
        # "abnormal_sorted" mode.
        rows = _apply_paper_split(rows, paper_split)

        df = pd.DataFrame(rows)
        out_path = os.path.join(root, "mitbih-pyhealth.csv")
        df.to_csv(out_path, index=False)
        logger.info(
            "MIT-BIH metadata: %d records -> %s", len(df), out_path
        )

    @property
    def default_task(self):
        """Returns the default task (boundary detection)."""
        # Imported lazily to avoid a circular import at module load.
        from pyhealth.tasks.ecg_boundary_detection import (
            ECGBoundaryDetection,
        )

        return ECGBoundaryDetection()
+
+
def _build_record_cache(
    processed_dir: str,
    rec_path: str,
    rec_name: str,
    downsample_factor: int,
    trim: bool,
) -> str:
    """Cache downsampled + trimmed signal and annotations for one record.

    Returns the absolute cache path, which is written into the
    metadata CSV's ``processed_file`` column. The cache stores:

    - ``signal``: ``(n_timesteps, n_channels)`` post-downsample + trim
    - ``ann_sample``: beat sample indices **relative to the trimmed
      signal** (drops annotations outside the trim range)
    - ``ann_symbol``: beat symbols aligned with ``ann_sample``
    """
    cache_path = os.path.join(processed_dir, f"{rec_name}.npz")
    raw_paths = [rec_path + ".dat", rec_path + ".hea", rec_path + ".atr"]
    params = {
        "downsample_factor": int(downsample_factor),
        "trim": bool(trim),
    }
    fingerprint = compute_fingerprint(raw_paths, params)

    def _build() -> dict[str, np.ndarray]:
        # wfdb only imported/used on an actual cache miss.
        import wfdb

        record = wfdb.rdrecord(rec_path)
        ann = wfdb.rdann(rec_path, extension="atr")

        signal = record.p_signal.astype(np.float32)
        # Plain stride decimation (no anti-alias filter) — matches the
        # paper's recipe.
        if downsample_factor > 1:
            signal = signal[::downsample_factor]

        # Map annotation sample indices into downsampled coordinates.
        ds_samples = np.asarray(ann.sample) // downsample_factor
        ann_symbols = np.asarray(ann.symbol)

        # Drop annotations outside the signal bounds.
        in_bounds = (ds_samples >= 0) & (ds_samples < len(signal))
        ds_samples = ds_samples[in_bounds]
        ann_symbols = ann_symbols[in_bounds]

        # Crop to [first, last] annotated sample and shift annotation
        # indices so they stay aligned with the cropped signal.
        # ann.sample is monotonically increasing, so the first/last
        # entries are the min/max and no annotation falls outside.
        if trim and len(ds_samples) > 0:
            first = int(ds_samples[0])
            last = int(ds_samples[-1])
            if first <= last:
                signal = signal[first : last + 1]
                ds_samples = ds_samples - first
                # After shifting, last kept index is last-first.

        return {
            "signal": signal,
            "ann_sample": ds_samples.astype(np.int64),
            # Fixed-width unicode so np.savez stays pickle-free.
            "ann_symbol": ann_symbols.astype("U4"),
        }

    load_or_build(cache_path, fingerprint, _build)
    return cache_path
+
+
+def _apply_paper_split(
+ rows: list[dict], paper_split: Optional[str]
+) -> list[dict]:
+ """Assign each row to train/test per the paper's split strategy.
+
+ Mutates rows in place by adding a ``split`` key. For
+ ``"abnormal_sorted"``, patients with every beat marked abnormal
+ are dropped entirely (the paper excludes them from anomaly
+ training). Returns the possibly filtered list.
+ """
+ if not rows:
+ return rows
+
+ if paper_split is None:
+ for row in rows:
+ row["split"] = ""
+ return rows
+
+ if paper_split == "random":
+ rng = np.random.RandomState(_PAPER_SPLIT_SEED)
+ order = rng.permutation(len(rows))
+ cutoff = int(len(rows) * _PAPER_SPLIT_RATIO)
+ split_by_rank = [
+ "train" if rank < cutoff else "test"
+ for rank in range(len(rows))
+ ]
+ for rank, idx in enumerate(order):
+ rows[idx]["split"] = split_by_rank[rank]
+ return rows
+
+ if paper_split == "abnormal_sorted":
+ # Drop all-abnormal patients before splitting.
+ kept = [
+ row for row in rows
+ if row.get("n_beats", 0) == 0
+ or row["n_abnormal"] < row["n_beats"]
+ ]
+ kept.sort(key=lambda r: r["n_abnormal"])
+ cutoff = int(len(kept) * _PAPER_SPLIT_RATIO)
+ for rank, row in enumerate(kept):
+ row["split"] = "train" if rank < cutoff else "test"
+ return kept
+
+ raise ValueError(f"Unknown paper_split mode: {paper_split!r}")
diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py
index 4c168d3e3..528adf549 100644
--- a/pyhealth/models/__init__.py
+++ b/pyhealth/models/__init__.py
@@ -17,6 +17,7 @@
from .graphcare import GraphCare
from .grasp import GRASP, GRASPLayer
from .medlink import MedLink
+from .medtsllm import MedTsLLM
from .micron import MICRON, MICRONLayer
from .mlp import MLP
from .molerec import MoleRec, MoleRecLayer
diff --git a/pyhealth/models/_medtsllm/__init__.py b/pyhealth/models/_medtsllm/__init__.py
new file mode 100644
index 000000000..13f5b692e
--- /dev/null
+++ b/pyhealth/models/_medtsllm/__init__.py
@@ -0,0 +1,8 @@
+from .layers import (
+ FlattenHead,
+ LinearProjection,
+ PatchEmbedding,
+ ReprogrammingLayer,
+ RevIN,
+)
+from .prompt import build_prompt, compute_lags, encode_prompts
diff --git a/pyhealth/models/_medtsllm/layers.py b/pyhealth/models/_medtsllm/layers.py
new file mode 100644
index 000000000..540a67ff8
--- /dev/null
+++ b/pyhealth/models/_medtsllm/layers.py
@@ -0,0 +1,269 @@
+"""Neural network layers for MedTsLLM.
+
+Includes: RevIN, PatchEmbedding, ReprogrammingLayer, LinearProjection,
+FlattenHead. These are the lightweight trainable components that wrap
+around a frozen LLM backbone.
+"""
+
+import math
+
+import torch
+from torch import Tensor, nn
+
+
class RevIN(nn.Module):
    """Reversible instance normalization for time series.

    Standardizes each feature with per-instance statistics (mean and
    standard deviation over the time axis) and remembers those
    statistics so a later call can undo the transform — the property
    reconstruction-style tasks rely on.

    Args:
        num_features: Number of input features/channels.
        eps: Small value for numerical stability.
        affine: If True, learns per-feature scale and bias.
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        affine: bool = False,
    ) -> None:
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine

        if self.affine:
            # Learnable per-feature rescaling applied after normalization.
            self.affine_weight = nn.Parameter(torch.ones(self.num_features))
            self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

        # Statistics captured during the most recent "norm" call.
        self._mean: Tensor | None = None
        self._stdev: Tensor | None = None

    def forward(self, x: Tensor, mode: str) -> Tensor:
        """Apply normalization (``"norm"``) or its inverse (``"denorm"``).

        Args:
            x: (batch, seq_len, num_features).
            mode: ``"norm"`` or ``"denorm"``.
        """
        if mode == "norm":
            mean = x.mean(dim=1, keepdim=True).detach()
            var = x.var(dim=1, keepdim=True, unbiased=False)
            stdev = (var + self.eps).sqrt().detach()
            self._mean, self._stdev = mean, stdev
            out = (x - mean) / stdev
            if self.affine:
                out = out * self.affine_weight + self.affine_bias
            return out
        if mode == "denorm":
            if self._mean is None or self._stdev is None:
                raise RuntimeError("Call forward(x, 'norm') first.")
            out = x
            if self.affine:
                out = (out - self.affine_bias) / self.affine_weight
            return out * self._stdev + self._mean
        raise ValueError(f"mode must be 'norm' or 'denorm', got '{mode}'")
+
+
class _TokenEmbedding(nn.Module):
    """1D convolution over a single patch (maps patch_len -> d_model).

    Kept as a standalone module (rather than folded into
    ``PatchEmbedding``) so checkpoint keys follow the upstream
    TIME-LLM / original-MedTsLLM naming (``value_embedding.conv.weight``),
    letting externally trained weights load without key renames.

    Args:
        patch_len: Length of each patch (input channels).
        d_model: Output embedding dimension.
    """

    def __init__(self, patch_len: int, d_model: int) -> None:
        super().__init__()
        conv = nn.Conv1d(
            in_channels=patch_len,
            out_channels=d_model,
            kernel_size=3,
            padding=1,
            padding_mode="circular",
            bias=False,
        )
        nn.init.kaiming_normal_(conv.weight)
        self.conv = conv

    def forward(self, x: Tensor) -> Tensor:
        """Embed patches via 1D convolution.

        Args:
            x: (batch, seq_len, patch_len).

        Returns:
            Embedded patches: (batch, seq_len, d_model).
        """
        # Conv1d expects channels in dim 1; restore time-major layout after.
        out = self.conv(x.permute(0, 2, 1))
        return out.transpose(1, 2)
+
+
class PatchEmbedding(nn.Module):
    """Unfolds a time series into overlapping patches and embeds via Conv1d.

    Args:
        d_model: Embedding dimension for each patch.
        patch_len: Length of each patch in timesteps.
        stride: Stride between consecutive patches.
        dropout: Dropout probability.
    """

    def __init__(
        self,
        d_model: int,
        patch_len: int,
        stride: int,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self.patch_len = patch_len
        self.stride = stride

        # Replicate `stride` trailing steps so one extra window fits.
        self.padding = nn.ReplicationPad1d((0, stride))
        self.value_embedding = _TokenEmbedding(patch_len, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor) -> tuple[Tensor, int]:
        """Embed time series into patch representations.

        Args:
            x: (batch, n_features, seq_len).

        Returns:
            Tuple of (patch_embeddings, n_features).
            patch_embeddings: (batch * n_features, n_patches, d_model).
        """
        n_features = x.shape[1]
        padded = self.padding(x)
        patches = padded.unfold(
            dimension=-1, size=self.patch_len, step=self.stride
        )
        # Fold the feature axis into the batch axis: every channel is
        # patched and embedded independently.
        bsz, feats, n_patches, plen = patches.shape
        patches = patches.reshape(bsz * feats, n_patches, plen)
        embedded = self.value_embedding(patches)
        return self.dropout(embedded), n_features
+
+
class ReprogrammingLayer(nn.Module):
    """Cross-attention reprogramming layer.

    Multi-head cross-attention in which patch embeddings act as
    queries and word-prototype embeddings supply the keys and values —
    the core MedTsLLM / Time-LLM mechanism for mapping time series
    into an LLM's input space.

    Args:
        d_model: Dimension of input patch embeddings.
        n_heads: Number of attention heads.
        d_keys: Dimension per head for keys.
        d_llm: Dimension of LLM embeddings.
        attention_dropout: Dropout on attention weights.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        d_keys: int,
        d_llm: int,
        attention_dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self.n_heads = n_heads
        self.d_keys = d_keys

        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_llm, d_keys * n_heads)
        self.value_projection = nn.Linear(d_llm, d_keys * n_heads)
        self.out_projection = nn.Linear(d_keys * n_heads, d_llm)
        self.dropout = nn.Dropout(attention_dropout)

    def forward(
        self,
        target_embedding: Tensor,
        source_embedding: Tensor,
        value_embedding: Tensor,
    ) -> Tensor:
        """Reprogram patch embeddings via cross-attention.

        Args:
            target_embedding: Queries, (batch, n_patches, d_model).
            source_embedding: Keys, (n_tokens, d_llm).
            value_embedding: Values, (n_tokens, d_llm).

        Returns:
            Reprogrammed embeddings: (batch, n_patches, d_llm).
        """
        batch, n_patches, _ = target_embedding.shape
        n_tokens = source_embedding.shape[0]
        heads = self.n_heads

        queries = self.query_projection(target_embedding)
        queries = queries.view(batch, n_patches, heads, -1)
        keys = self.key_projection(source_embedding).view(n_tokens, heads, -1)
        values = self.value_projection(value_embedding).view(
            n_tokens, heads, -1
        )

        # Scaled dot-product attention over the prototype (token) axis.
        scale = 1.0 / math.sqrt(queries.shape[-1])
        scores = torch.einsum("blhe,she->bhls", queries, keys)
        weights = self.dropout(torch.softmax(scale * scores, dim=-1))
        mixed = torch.einsum("bhls,she->blhe", weights, values)
        return self.out_projection(mixed.reshape(batch, n_patches, -1))
+
+
class LinearProjection(nn.Module):
    """Simple linear projection as ablation alternative.

    Drop-in replacement for ``ReprogrammingLayer`` that maps d_model
    straight to d_llm with a single linear layer; the word-embedding
    arguments are accepted for interface compatibility but unused.

    Args:
        d_model: Input dimension.
        d_llm: Output dimension.
    """

    def __init__(self, d_model: int, d_llm: int) -> None:
        super().__init__()
        self.linear = nn.Linear(d_model, d_llm)

    def forward(
        self,
        target_embedding: Tensor,
        source_embedding: Tensor,
        value_embedding: Tensor,
    ) -> Tensor:
        """Project patch embeddings linearly to LLM dimension."""
        del source_embedding, value_embedding  # unused by design
        return self.linear(target_embedding)
+
+
class FlattenHead(nn.Module):
    """Flatten patch dimension and project to output size.

    Args:
        n_features_in: Total input features (d_ff * n_patches).
        n_outputs: Total output size (pred_len * n_outputs_per_step).
        head_dropout: Dropout probability.
    """

    def __init__(
        self,
        n_features_in: int,
        n_outputs: int,
        head_dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.flatten = nn.Flatten(start_dim=-2)
        self.linear = nn.Linear(n_features_in, n_outputs)
        self.dropout = nn.Dropout(head_dropout)

    def forward(self, x: Tensor) -> Tensor:
        """Args: x: (batch, d_ff, n_patches). Returns: (batch, n_outputs)."""
        flat = self.flatten(x)
        projected = self.linear(flat)
        return self.dropout(projected)
diff --git a/pyhealth/models/_medtsllm/prompt.py b/pyhealth/models/_medtsllm/prompt.py
new file mode 100644
index 000000000..f923aca78
--- /dev/null
+++ b/pyhealth/models/_medtsllm/prompt.py
@@ -0,0 +1,204 @@
+"""Text prompt construction and encoding for MedTsLLM.
+
+Mirrors the prompt structure from the original paper implementation
+(Chan et al., MLHC 2024, https://github.com/flixpar/med-ts-llm):
+
+ [BOS] [dataset desc] [clip desc] [input stats] [task desc] [Time series:]
+
+The four content segments (dataset, clip, stats, task) are toggled
+independently by the caller.
+"""
+
+from typing import Any
+
+import torch
+from torch import Tensor
+
+
def compute_lags(x: Tensor, n_lags: int = 5) -> Tensor:
    """Compute top-N autocorrelation lags via FFT.

    Reproduces the original paper's ``calcute_lags``: the power
    spectrum ``rfft(x) * conj(rfft(x))`` is inverse-transformed to
    obtain the circular autocorrelation, which is averaged over
    features before taking the ``n_lags`` strongest lag indices.

    Args:
        x: (batch, seq_len) or (batch, seq_len, n_features).
        n_lags: Number of top lags to return.

    Returns:
        (batch, n_lags) long tensor of lag indices.
    """
    # Bring the time axis last: (batch, n_features, seq_len).
    series = (
        x.permute(0, 2, 1).contiguous() if x.ndim == 3 else x.unsqueeze(1)
    )
    spectrum = torch.fft.rfft(series, dim=-1)
    autocorr = torch.fft.irfft(spectrum * torch.conj(spectrum), dim=-1)
    return torch.topk(autocorr.mean(dim=1), n_lags, dim=-1).indices
+
+
def build_prompt(
    inputs: dict[str, Any],
    *,
    dataset_description: str = "",
    task_description: str = "",
    include_dataset: bool = True,
    include_task: bool = True,
    include_clip: bool = False,
    include_stats: bool = False,
    n_lags: int = 5,
    bos_token: str = "",
) -> list[list[str]]:
    """Build text prompts for a batch of inputs.

    Args:
        inputs: Dict with ``"x_enc"`` (batch, seq_len, n_features). May
            include ``"descriptions"`` (per-sample strings) when
            ``include_clip`` is True.
        dataset_description: Dataset-level description string.
        task_description: Task-level description string.
        include_dataset: Whether to include dataset description.
        include_task: Whether to include task description.
        include_clip: Whether to include per-sample descriptions.
        include_stats: Whether to include per-sample input statistics
            (min/max/median/trend/top-N autocorr lags).
        n_lags: Number of autocorrelation lags for the stats prompt.
        bos_token: Beginning-of-sequence token string.

    Returns:
        List of string lists, one per batch element.
    """
    x_enc = inputs["x_enc"]
    batch_size = x_enc.shape[0]

    dataset_part = (
        f"Dataset: {dataset_description}" if include_dataset else ""
    )
    task_part = f"Task: {task_description}" if include_task else ""

    if include_clip:
        clip_parts = inputs.get("descriptions", [""] * batch_size)
    else:
        clip_parts = [""] * batch_size

    stats_parts = (
        _build_stats_prompts(x_enc, n_lags)
        if include_stats
        else [""] * batch_size
    )

    batch_prompts: list[list[str]] = []
    for idx in range(batch_size):
        segments = [
            bos_token,
            dataset_part,
            clip_parts[idx],
            stats_parts[idx],
            task_part,
            "Time series:",
        ]
        # Drop empty segments, then append a separating space to every
        # remaining segment except the first.
        segments = [seg for seg in segments if seg != ""]
        segments = [
            seg + " " if pos > 0 and isinstance(seg, str) else seg
            for pos, seg in enumerate(segments)
        ]
        batch_prompts.append(segments)

    return batch_prompts
+
+
def encode_prompts(
    prompts: list[list[str]],
    tokenizer,
    embedding_layer,
    device: torch.device,
) -> Tensor:
    """Encode text prompts into embedding tensors.

    Each prompt segment is tokenized and embedded separately, segment
    embeddings are concatenated per sample, and samples are left-padded
    with the pad-token embedding up to a common length.

    Args:
        prompts: List of string lists from ``build_prompt``.
        tokenizer: HuggingFace tokenizer.
        embedding_layer: LLM's input embedding layer.
        device: Device for tensors.

    Returns:
        Padded prompt embeddings: (batch, max_tokens, d_llm).
    """
    per_sample: list[Tensor] = []
    for segments in prompts:
        embedded_segments = []
        for segment in segments:
            token_ids = tokenizer(
                segment,
                return_tensors="pt",
                padding=False,
                truncation=False,
            ).input_ids.to(device)
            embedded_segments.append(embedding_layer(token_ids))
        per_sample.append(torch.cat(embedded_segments, dim=1))

    max_tokens = max(e.shape[1] for e in per_sample)
    d_llm = per_sample[0].shape[2]

    # Fall back to EOS (then 0) when the tokenizer defines no pad token.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id or 0
    pad_embedding = embedding_layer(torch.tensor([pad_id], device=device))

    aligned = []
    for emb in per_sample:
        shortfall = max_tokens - emb.shape[1]
        if shortfall > 0:
            left_pad = pad_embedding.expand(1, shortfall, d_llm)
            emb = torch.cat([left_pad, emb], dim=1)
        aligned.append(emb)

    return torch.cat(aligned, dim=0)
+
+
def _build_stats_prompts(x_enc: Tensor, n_lags: int) -> list[str]:
    """Build per-sample input statistics prompt strings.

    Reports per-feature min, max, median, trend direction, and the
    top-N autocorrelation lags, matching the format used by the
    original paper implementation.
    """
    xs = x_enc.detach()
    if xs.ndim == 2:
        xs = xs.unsqueeze(-1)

    with torch.no_grad():
        per_feature_min = torch.min(xs, dim=1).values.tolist()
        per_feature_max = torch.max(xs, dim=1).values.tolist()
        per_feature_median = torch.median(xs.float(), dim=1).values.tolist()
        # Net change over time decides "upward" vs "downward".
        upward = (xs.diff(dim=1).sum(dim=1) > 0).tolist()
        top_lags = compute_lags(xs.float(), n_lags).tolist()

    def fmt(v: Any) -> str:
        # The bool branch must precede the generic str(v) fallback so
        # trend flags render as words rather than True/False.
        if isinstance(v, list):
            return "[" + ", ".join(fmt(item) for item in v) + "]"
        if isinstance(v, bool):
            return "upward" if v else "downward"
        if isinstance(v, float):
            return f"{v:.3f}"
        return str(v)

    return [
        (
            f"Input statistics (per feature): "
            f"min value = {fmt(per_feature_min[b])}, "
            f"max value = {fmt(per_feature_max[b])}, "
            f"median value = {fmt(per_feature_median[b])}, "
            f"the trend of input is {fmt(upward[b])}, "
            f"the top {n_lags} lags are {top_lags[b]}."
        )
        for b in range(xs.shape[0])
    ]
diff --git a/pyhealth/models/medtsllm.py b/pyhealth/models/medtsllm.py
new file mode 100644
index 000000000..7afa18730
--- /dev/null
+++ b/pyhealth/models/medtsllm.py
@@ -0,0 +1,586 @@
+# Author: Anton Barchukov
+# Paper: Chan et al., "MedTsLLM: Leveraging LLMs for Multimodal
+# Medical Time Series Analysis", MLHC 2024
+# Paper link: https://arxiv.org/abs/2408.07773
+# Original repo: https://github.com/flixpar/med-ts-llm
+# Description: Repurposes a frozen pretrained LLM as a feature
+# extractor for medical time series. Raw signals are patched,
+# projected into the LLM's embedding space via cross-attention
+# (reprogramming layer), and decoded by a lightweight task head.
+
+import warnings
+from typing import Dict, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+
+from pyhealth.datasets import SampleDataset
+from pyhealth.models import BaseModel
+from pyhealth.models._medtsllm import (
+ FlattenHead,
+ PatchEmbedding,
+ ReprogrammingLayer,
+ RevIN,
+ build_prompt,
+ encode_prompts,
+)
+
# Task modes accepted by MedTsLLM. Validated in __init__ and used to
# pick the output-head size (_compute_n_outputs_per_step) and the
# default task-description prompt (_default_task_description).
_SUPPORTED_TASKS = {
    "semantic_segmentation",
    "segmentation",
    "anomaly_detection",
    "reconstruction",
    "forecasting",
    "pretraining",
}
+
+
class MedTsLLM(BaseModel):
    """MedTsLLM: LLM-based medical time series model.

    Pipeline: RevIN -> PatchEmbedding -> Reprogramming -> LLM -> OutputHead

    Repurposes a frozen pretrained LLM (e.g., GPT-2, Qwen2.5) as a
    feature extractor for medical time series tasks. The LLM weights
    are never updated — only the reprogramming layer, patch embedding,
    and output head are trained (~1-2M parameters).

    Paper: Chan, N. et al. "MedTsLLM: Leveraging LLMs for Multimodal
    Medical Time Series Analysis." MLHC 2024.

    Note: forward-pass task branching (binary segmentation, anomaly
    reconstruction) is not yet wired — only ``semantic_segmentation``
    is fully supported end-to-end. The ``task`` argument currently
    drives the default task-description prompt only.

    Args:
        dataset: PyHealth SampleDataset with ``signal`` input and
            ``label`` output.
        task: One of ``"semantic_segmentation"``, ``"segmentation"``,
            ``"anomaly_detection"``, ``"reconstruction"``,
            ``"forecasting"``, ``"pretraining"``. Drives the default
            task-description prompt. Default
            ``"semantic_segmentation"``.
        seq_len: Input sequence length. Default 512.
        n_features: Number of input channels. Default 1.
        n_classes: Number of output classes for segmentation.
            Default 4.
        backbone: HuggingFace model ID for the LLM backbone.
            Set to ``None`` to use a lightweight replacement
            (for testing without model downloads). Default
            ``"openai-community/gpt2"``.
        d_model: Patch embedding dimension. Default 32.
        d_ff: Feedforward / output head hidden dimension. Default 128.
        n_heads: Attention heads in reprogramming layer. Default 8.
        num_tokens: Number of word prototype tokens. Default 1024.
        patch_len: Length of each patch. Default 16.
        stride: Stride between patches. Default 8.
        dropout: Dropout probability. Default 0.1.
        covariate_mode: ``"univariate"`` or ``"concat"``. Default
            ``"univariate"``.
        reprogramming_layer: Optional module to replace the default
            ``ReprogrammingLayer``. Must accept ``(target, source,
            value)`` and return ``(batch, n_patches, d_llm)``. Use
            ``LinearProjection`` for the no-reprogramming ablation.
        dataset_description: Text description for prompting.
        task_description: Text description for prompting. If empty,
            a task-appropriate default is generated.
        prompt_dataset: Include dataset prompt. Default True.
        prompt_task: Include task prompt. Default True.
        prompt_patient: Include per-patient description (age, sex,
            diagnoses, medications) in prompt. Requires samples to
            include a ``description`` field. Default True.
        prompt_stats: Include per-sample input statistics
            (min/max/median/trend/top-N autocorr lags). Default False
            to match the cs598-pyhealth reference's dtp config; the
            paper's ``input_stats`` prompt is an optional extra.
        n_lags: Number of autocorrelation lags in the stats prompt.
            Default 5.
        llm_dtype: Torch dtype for LLM weights. Default float32.
        word_embeddings: Pre-loaded word embeddings tensor. Required
            when ``backbone`` is None.

    Examples:
        >>> from pyhealth.models import MedTsLLM
        >>> model = MedTsLLM(
        ...     dataset=sample_dataset,
        ...     backbone="openai-community/gpt2",
        ...     seq_len=512,
        ...     n_classes=4,
        ... )
    """

    def __init__(
        self,
        dataset: SampleDataset,
        task: str = "semantic_segmentation",
        seq_len: int = 512,
        n_features: int = 1,
        n_classes: int = 4,
        backbone: Optional[str] = "openai-community/gpt2",
        d_model: int = 32,
        d_ff: int = 128,
        n_heads: int = 8,
        num_tokens: int = 1024,
        patch_len: int = 16,
        stride: int = 8,
        dropout: float = 0.1,
        covariate_mode: str = "univariate",
        reprogramming_layer: Optional[nn.Module] = None,
        dataset_description: str = "",
        task_description: str = "",
        prompt_dataset: bool = True,
        prompt_task: bool = True,
        prompt_patient: bool = True,
        prompt_stats: bool = False,
        n_lags: int = 5,
        llm_dtype: torch.dtype = torch.float32,
        word_embeddings: Optional[Tensor] = None,
    ):
        super(MedTsLLM, self).__init__(dataset=dataset)

        if task not in _SUPPORTED_TASKS:
            raise ValueError(
                f"task must be one of {sorted(_SUPPORTED_TASKS)}, "
                f"got {task!r}"
            )

        self.task = task
        self.seq_len = seq_len
        self.pred_len = seq_len
        self.n_features = n_features
        self.n_classes = n_classes
        self.d_model = d_model
        self.d_ff = d_ff
        self.covariate_mode = covariate_mode
        self.n_lags = n_lags

        # Compute patch count
        # (+2 = the usual (seq_len - patch_len) // stride + 1 windows,
        # plus one extra window enabled by PatchEmbedding's
        # ReplicationPad1d((0, stride)) right padding.)
        self.n_patches = (seq_len - patch_len) // stride + 2

        # Effective d_model for concat covariate mode
        d_model_effective = d_model
        if covariate_mode == "concat":
            d_model_effective = d_model * n_features

        # Setup LLM backbone or replacement
        if backbone is not None:
            self._setup_llm(backbone, llm_dtype)
        elif word_embeddings is not None:
            self._setup_replacement(word_embeddings)
        else:
            raise ValueError(
                "Either backbone or word_embeddings must be provided."
            )

        # Trainable layers
        self.normalize_layers = RevIN(n_features, affine=False)
        self.patch_embedding = PatchEmbedding(
            d_model=d_model,
            patch_len=patch_len,
            stride=stride,
            dropout=dropout,
        )
        self.mapping_layer = nn.Linear(self.vocab_size, num_tokens)

        if reprogramming_layer is not None:
            self.reprogramming_layer = reprogramming_layer
        else:
            self.reprogramming_layer = ReprogrammingLayer(
                d_model=d_model_effective,
                n_heads=n_heads,
                d_keys=d_ff,
                d_llm=self.d_llm,
                attention_dropout=dropout,
            )

        self.embedding_downsample = nn.Linear(self.d_llm, d_ff)

        # Output head size depends on task:
        #   semantic_segmentation => one logit per class per step
        #   segmentation          => a single binary logit per step
        #   anomaly_detection /
        #   reconstruction        => one value per feature per step
        #   forecasting /
        #   pretraining           => same as reconstruction
        self.n_outputs_per_step = self._compute_n_outputs_per_step(
            task, n_classes, n_features
        )
        self.output_projection = FlattenHead(
            n_features_in=d_ff * self.n_patches,
            n_outputs=self.n_outputs_per_step * self.pred_len,
        )

        # Prompting config
        self.dataset_description = dataset_description
        self.task_description = (
            task_description
            or self._default_task_description(task, seq_len, self.pred_len)
        )
        self.prompt_config = {
            "dataset": prompt_dataset,
            "task": prompt_task,
            "patient": prompt_patient,
            "stats": prompt_stats,
        }

    @staticmethod
    def _compute_n_outputs_per_step(
        task: str, n_classes: int, n_features: int
    ) -> int:
        """Resolve the per-timestep output dimension from the task."""
        if task == "semantic_segmentation":
            return n_classes
        if task == "segmentation":
            return 1
        if task in (
            "anomaly_detection",
            "reconstruction",
            "forecasting",
            "pretraining",
        ):
            return n_features
        raise ValueError(f"Unsupported task: {task!r}")

    @staticmethod
    def _default_task_description(
        task: str, seq_len: int, pred_len: int
    ) -> str:
        """Generate a task-appropriate default task description.

        Mirrors the original paper implementation's
        ``get_task_description``.
        """
        if task in ("forecasting", "pretraining"):
            return (
                f"Forecast the next {pred_len} steps given the "
                f"previous {seq_len} steps of data."
            )
        if task in ("anomaly_detection", "reconstruction"):
            return (
                f"Reconstruct the past {seq_len} steps of data as "
                "accurately as possible using the following "
                "information."
            )
        if task == "semantic_segmentation":
            return (
                f"Classify the past {seq_len} steps of data as "
                "accurately as possible using the following "
                "information."
            )
        if task == "segmentation":
            return (
                f"Identify the change points in the past {seq_len} "
                "steps of data to segment the sequence."
            )
        return ""

    def parameters(self, recurse: bool = True):
        """Yield only trainable parameters.

        Overrides nn.Module.parameters() to exclude the frozen LLM
        backbone. This prevents PyHealth's Trainer from allocating
        optimizer state (momentum, variance) for frozen parameters,
        saving ~2x the frozen param count in memory.

        For GPT-2 (137M) this saves ~1GB. For Qwen2.5-1.5B it
        saves ~12GB.
        """
        for p in super().parameters(recurse=recurse):
            if p.requires_grad:
                yield p

    def named_parameters(
        self,
        prefix: str = "",
        recurse: bool = True,
        remove_duplicate: bool = True,
    ):
        """Yield only trainable named parameters.

        See parameters() for rationale.
        """
        for name, p in super().named_parameters(
            prefix=prefix,
            recurse=recurse,
            remove_duplicate=remove_duplicate,
        ):
            if p.requires_grad:
                yield name, p

    def _setup_llm(self, backbone: str, dtype: torch.dtype) -> None:
        """Load a frozen HuggingFace LLM as backbone."""
        from transformers import AutoConfig, AutoModel, AutoTokenizer

        llm_config = AutoConfig.from_pretrained(backbone)
        llm_config.output_hidden_states = True

        self.llm = AutoModel.from_pretrained(
            backbone,
            config=llm_config,
            torch_dtype=dtype,
            attn_implementation="sdpa",
        )
        self.tokenizer = AutoTokenizer.from_pretrained(backbone)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Freeze LLM
        for param in self.llm.parameters():
            param.requires_grad = False

        # Extract word embeddings
        we = self.llm.get_input_embeddings().weight.detach().cpu()
        # Subsample very large vocabularies (evenly spaced rows) so the
        # mapping layer's input size stays bounded at 100k prototypes.
        if we.shape[0] > 100_000:
            inds = torch.linspace(
                0, we.shape[0] - 1, 100_000, dtype=torch.long
            )
            we = we[inds].clone()
        self.word_embeddings = nn.Parameter(we, requires_grad=False)
        self.vocab_size = self.word_embeddings.shape[0]
        self.d_llm = self.word_embeddings.shape[1]
        self._use_llm = True

    def _setup_replacement(self, word_embeddings: Tensor) -> None:
        """Setup a small feedforward network replacing the LLM."""
        self.word_embeddings = nn.Parameter(
            word_embeddings.detach().clone(), requires_grad=False
        )
        self.vocab_size = word_embeddings.shape[0]
        self.d_llm = word_embeddings.shape[1]
        self.llm_replacement = nn.Sequential(
            nn.Linear(self.d_llm, self.d_llm),
            nn.GELU(),
            nn.Linear(self.d_llm, self.d_llm),
        )
        self.tokenizer = None
        self._use_llm = False

    def forward(self, **kwargs) -> Dict[str, Tensor]:
        """Forward pass following PyHealth's BaseModel contract.

        Args:
            **kwargs: Must include ``signal`` tensor of shape
                ``(batch, seq_len)`` or ``(batch, seq_len, n_features)``.
                May include ``label`` tensor for loss computation and
                ``description`` list of per-sample strings for the
                patient prompt.

        Returns:
            Dict with keys: ``logit``, ``y_prob``, and optionally
            ``loss``, ``y_true``.
        """
        signal = kwargs[self.feature_keys[0]].to(self.device)
        if signal.ndim == 2:
            signal = signal.unsqueeze(-1)
        bs = signal.shape[0]

        # Encode time series
        enc_out = self._encode_ts(signal)

        # Build + prepend prompt (only when a real LLM is attached)
        if self._use_llm and any(self.prompt_config.values()):
            prompt_enc = self._build_prompt_embeddings(signal, bs, kwargs)
            # NOTE(review): in "univariate" mode _encode_ts returns
            # bs * n_features rows while prompt_enc has bs rows, so this
            # cat only lines up when n_features == 1 — confirm intended
            # behavior before using multi-channel inputs here.
            enc_out = torch.cat([prompt_enc, enc_out], dim=1)

        # LLM or replacement forward
        if self._use_llm:
            dec_out = self.llm(inputs_embeds=enc_out).last_hidden_state
            dec_out = dec_out.to(device=signal.device, dtype=signal.dtype)
        else:
            dec_out = self.llm_replacement(enc_out)

        # Keep last n_patches outputs (prompt positions are discarded;
        # only the trailing time-series patch outputs feed the head)
        dec_out = dec_out[:, -self.n_patches :, :]

        # Downsample and project
        dec_out = self.embedding_downsample(dec_out)
        dec_out = dec_out.permute(0, 2, 1).contiguous()
        dec_out = self.output_projection(dec_out)

        # Reshape to (bs, pred_len, n_outputs_per_step)
        dec_out = dec_out.view(
            bs, self.pred_len, self.n_outputs_per_step
        )

        label_key = self.label_keys[0] if self.label_keys else None
        y_true = (
            kwargs[label_key].to(self.device)
            if label_key and label_key in kwargs
            else None
        )

        return self._task_head(dec_out, signal, y_true)

    def _task_head(
        self,
        dec_out: Tensor,
        signal: Tensor,
        y_true: Optional[Tensor],
    ) -> Dict[str, Tensor]:
        """Apply task-specific post-processing and loss.

        Branches on ``self.task``:

        - ``semantic_segmentation``: softmax probs, cross-entropy loss.
        - ``segmentation``: binary logit per step, BCE-with-logits loss.
        - ``anomaly_detection`` / ``reconstruction``: RevIN-denormalized
          reconstruction in the original signal space, MSE loss against
          the input signal.
        - ``forecasting`` / ``pretraining``: denormalized prediction,
          MSE loss against the input signal (placeholder — true
          forecasting needs a future-signal label).
        """
        output: Dict[str, Tensor] = {}
        if y_true is not None:
            output["y_true"] = y_true

        if self.task == "semantic_segmentation":
            # dec_out: (bs, pred_len, n_classes)
            logit = dec_out
            output["logit"] = logit
            output["y_prob"] = F.softmax(logit, dim=-1)
            if y_true is not None:
                output["loss"] = F.cross_entropy(
                    logit.view(-1, self.n_classes),
                    y_true.view(-1).long(),
                )
            return output

        if self.task == "segmentation":
            # dec_out: (bs, pred_len, 1) -> (bs, pred_len)
            logit = dec_out.squeeze(-1)
            output["logit"] = logit
            output["y_prob"] = torch.sigmoid(logit)
            if y_true is not None:
                output["loss"] = F.binary_cross_entropy_with_logits(
                    logit, y_true.float()
                )
            return output

        if self.task in (
            "anomaly_detection",
            "reconstruction",
            "forecasting",
            "pretraining",
        ):
            # dec_out: (bs, pred_len, n_features) — denormalize to
            # recover original signal space before computing loss.
            prediction = self.normalize_layers(dec_out, "denorm")
            output["logit"] = prediction
            output["y_prob"] = prediction
            # Train to reconstruct the input signal. For
            # anomaly_detection, labels (anomaly masks) are used at
            # eval time for scoring — not during training.
            output["loss"] = F.mse_loss(prediction, signal)
            return output

        raise ValueError(f"Unsupported task in forward: {self.task!r}")

    def _build_prompt_embeddings(
        self, signal: Tensor, bs: int, kwargs: Dict
    ) -> Tensor:
        """Construct and encode the prompt for the current batch.

        No caching: prompts are rebuilt every forward pass to match
        the original paper implementation. Dataset/task/stats prompts
        are cheap to re-tokenize; patient prompts depend on batch
        contents and can't be cached anyway.
        """
        include_patient = self.prompt_config.get("patient", False)
        bos = getattr(self.tokenizer, "bos_token", None) or ""

        prompt_inputs: Dict = {"x_enc": signal}

        if include_patient:
            description = kwargs.get("description")
            if description is None:
                warnings.warn(
                    "prompt_patient=True but no 'description' field "
                    "provided in batch. Patient prompt will be empty. "
                    "Ensure your task emits 'description' per sample.",
                    RuntimeWarning,
                    stacklevel=3,
                )
                descriptions = [""] * bs
            else:
                descriptions = _coerce_descriptions(description, bs)
            prompt_inputs["descriptions"] = descriptions

        prompts = build_prompt(
            prompt_inputs,
            dataset_description=self.dataset_description,
            task_description=self.task_description,
            include_dataset=self.prompt_config.get("dataset", False),
            include_task=self.prompt_config.get("task", False),
            include_clip=include_patient,
            include_stats=self.prompt_config.get("stats", False),
            n_lags=self.n_lags,
            bos_token=bos,
        )

        # Prompt embeddings come from the frozen LLM embedding table;
        # no gradients are needed through them.
        with torch.no_grad():
            return encode_prompts(
                prompts,
                self.tokenizer,
                self.llm.get_input_embeddings(),
                signal.device,
            )

    def _encode_ts(self, x_enc: Tensor) -> Tensor:
        """Encode time series: normalize -> patch -> reprogram.

        Args:
            x_enc: (batch, seq_len, n_features).

        Returns:
            Encoded representation: (batch, n_patches, d_llm).
        """
        bs, seq_len, n_features = x_enc.shape

        x_enc = self.normalize_layers(x_enc, "norm")
        x_enc = x_enc.permute(0, 2, 1).contiguous()

        enc_out, _ = self.patch_embedding(x_enc)

        # Project word embeddings to prototypes
        we = self.word_embeddings.to(self.mapping_layer.weight.dtype)
        source_embeddings = self.mapping_layer(
            we.permute(1, 0)
        ).permute(1, 0)

        # Handle covariate modes
        if self.covariate_mode == "concat":
            enc_out = enc_out.reshape(
                bs, n_features, self.n_patches, self.d_model
            )
            enc_out = enc_out.permute(0, 2, 1, 3)
            enc_out = enc_out.reshape(
                bs, self.n_patches, n_features * self.d_model
            )

        enc_out = self.reprogramming_layer(
            enc_out, source_embeddings, source_embeddings
        )

        return enc_out
+
+
def _coerce_descriptions(
    description, bs: int
) -> list[str]:
    """Normalize a batch description field into a list of length ``bs``.

    Handles the common shapes produced by PyHealth's default collate
    (a list of strings) as well as scalar-string edge cases. The
    result is always exactly ``bs`` elements long — shorter inputs are
    right-padded with empty strings and longer inputs are truncated —
    so downstream per-sample indexing (``clip_prompts[b]`` in
    ``build_prompt``) can never raise ``IndexError``.
    """
    if isinstance(description, str):
        items = [description] * bs
    elif isinstance(description, list):
        items = list(description)
    else:
        try:
            items = list(description)
        except TypeError:
            # Non-iterable scalar: broadcast it across the batch.
            items = [description] * bs
    # Enforce the documented length contract.
    if len(items) < bs:
        items = items + [""] * (bs - len(items))
    return items[:bs]
diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py
index a32618f9c..fd3682a27 100644
--- a/pyhealth/tasks/__init__.py
+++ b/pyhealth/tasks/__init__.py
@@ -22,6 +22,10 @@
drug_recommendation_mimic4_fn,
drug_recommendation_omop_fn,
)
+from .ecg_anomaly_detection import ECGAnomalyDetection
+from .ecg_boundary_detection import ECGBoundaryDetection
+from .ecg_wave_segmentation import ECGWaveSegmentation
+from .respiratory_boundary_detection import RespiratoryBoundaryDetection
from .in_hospital_mortality_mimic4 import InHospitalMortalityMIMIC4
from .length_of_stay_prediction import (
LengthOfStayPredictioneICU,
diff --git a/pyhealth/tasks/ecg_anomaly_detection.py b/pyhealth/tasks/ecg_anomaly_detection.py
new file mode 100644
index 000000000..52e1e3163
--- /dev/null
+++ b/pyhealth/tasks/ecg_anomaly_detection.py
@@ -0,0 +1,152 @@
+# Author: Anton Barchukov
+# Paper: Chan et al., "MedTsLLM: Leveraging LLMs for Multimodal
+# Medical Time Series Analysis", MLHC 2024
+# Description: Beat-level anomaly detection task for the MIT-BIH
+# Arrhythmia Database. Each timestep is labeled 1 if it falls
+# within a beat interval annotated as an abnormal rhythm type,
+# and 0 otherwise (including non-beat gaps). Training is
+# reconstruction-style: MedTsLLM (task="anomaly_detection")
+# learns to reconstruct the ECG and anomalous beats are flagged
+# at eval time by elevated reconstruction error.
+
+from typing import Any, Dict
+
+import numpy as np
+
+from pyhealth.tasks import BaseTask
+from pyhealth.tasks.ecg_boundary_detection import (
+ _load_signal_and_annotations,
+)
+
# Rhythm annotation symbols considered "normal" — everything else
# is an abnormal beat in the MedTsLLM paper's setup.
# NOTE: identical copy lives in pyhealth.tasks.ecg_boundary_detection;
# keep the two sets in sync.
_NORMAL_BEATS = {"N", "L", "R", "e", "j"}
+
+
class ECGAnomalyDetection(BaseTask):
    """Beat-level arrhythmia anomaly detection on MIT-BIH ECG.

    Emits per-timestep binary labels over a downsampled 2-channel
    ECG. A timestep is labeled ``1`` when it falls inside the
    interval of an abnormal-type beat annotation (all symbols other
    than ``{"N", "L", "R", "e", "j"}`` and the rhythm-change marker
    ``"+"``) and ``0`` otherwise.

    Signal decoding, downsampling, and optional trimming are handled
    at the dataset level via :class:`MITBIHDataset`'s ``preprocess``,
    ``downsample_factor``, and ``trim`` kwargs.

    Args:
        window_size: Number of time points per window. Default 128.
        step_size: Stride between consecutive windows. Default 128.

    Attributes:
        task_name (str): ``"ECGAnomalyDetection"``
        input_schema (Dict[str, str]): ``{"signal": "tensor"}``
        output_schema (Dict[str, str]): ``{"label": "tensor"}``

    Examples:
        >>> from pyhealth.datasets import MITBIHDataset
        >>> from pyhealth.tasks import ECGAnomalyDetection
        >>> dataset = MITBIHDataset(
        ...     root="/path/to/mitdb/",
        ...     preprocess=True,
        ...     paper_split="abnormal_sorted",
        ... )
        >>> sample_ds = dataset.set_task(ECGAnomalyDetection())
    """

    task_name: str = "ECGAnomalyDetection"
    input_schema: Dict[str, str] = {"signal": "tensor"}
    output_schema: Dict[str, str] = {"label": "tensor"}

    def __init__(self, window_size: int = 128, step_size: int = 128):
        self.window_size = window_size
        self.step_size = step_size
        super().__init__()

    def __call__(self, patient: Any) -> list[dict[str, Any]]:
        """Process a single patient into windowed samples."""
        samples: list[dict[str, Any]] = []

        for event in patient.get_events():
            loaded = _load_signal_and_annotations(event)
            if loaded is None:
                continue
            signal, ann_sample, ann_symbol = loaded

            mask = _build_anomaly_mask(ann_sample, ann_symbol, len(signal))
            desc = _build_description(event)
            split = getattr(event, "split", "") or ""

            for lo in range(
                0, len(signal) - self.window_size + 1, self.step_size
            ):
                hi = lo + self.window_size
                samples.append(
                    {
                        "patient_id": patient.patient_id,
                        "signal": signal[lo:hi],
                        "label": mask[lo:hi].astype(np.float32),
                        "description": desc,
                        "split": split,
                    }
                )

        return samples
+
+
def _build_anomaly_mask(
    ann_sample: np.ndarray,
    ann_symbol: np.ndarray,
    signal_len: int,
) -> np.ndarray:
    """Label each timestep 1 if inside an abnormal beat interval.

    Rhythm-change markers (``"+"``) are ignored. For beat ``i`` with
    sample ``s_i``, the interval ``[s_i, s_{i+1})`` is marked 1 when
    ``symbol_i`` is an abnormal beat type. The last beat extends to
    ``signal_len``.
    """
    mask = np.zeros(signal_len, dtype=np.int64)

    # Drop rhythm-change markers so each interval runs beat-to-beat.
    symbols = [str(s) for s in np.asarray(ann_symbol)]
    keep = np.array([s != "+" for s in symbols], dtype=bool)
    beat_samples = np.asarray(ann_sample, dtype=np.int64)[keep]
    beat_symbols = [s for s, k in zip(symbols, keep) if k]

    n_beats = len(beat_samples)
    for i, sym in enumerate(beat_symbols):
        if sym in _NORMAL_BEATS:
            continue
        lo = max(0, int(beat_samples[i]))
        hi = int(beat_samples[i + 1]) if i + 1 < n_beats else signal_len
        hi = min(hi, signal_len)
        if lo < hi:
            mask[lo:hi] = 1

    return mask
+
+
+def _build_description(event) -> str:
+ """Compose per-patient description from event demographics."""
+ parts: list[str] = []
+ for attr in ("age", "sex", "medications"):
+ value = getattr(event, attr, "")
+ if value is None:
+ continue
+ text = str(value).strip()
+ if not text or text.lower() == "nan":
+ continue
+ parts.append(f"{attr}: {text}")
+ return ", ".join(parts)
diff --git a/pyhealth/tasks/ecg_boundary_detection.py b/pyhealth/tasks/ecg_boundary_detection.py
new file mode 100644
index 000000000..0af98273d
--- /dev/null
+++ b/pyhealth/tasks/ecg_boundary_detection.py
@@ -0,0 +1,151 @@
+# Author: Anton Barchukov
+# Paper: Chan et al., "MedTsLLM: Leveraging LLMs for Multimodal
+# Medical Time Series Analysis", MLHC 2024
+# Description: R-peak boundary detection task for the MIT-BIH dataset.
+# Detects beat boundaries (R-peaks) in ECG signals.
+
+from typing import Any, Dict
+
+import numpy as np
+
+from pyhealth.tasks import BaseTask
+
# Normal beat types (all others are anomalies)
# NOTE(review): not referenced elsewhere in this module — the anomaly
# task keeps its own identical copy in ecg_anomaly_detection.py.
_NORMAL_BEATS = {"N", "L", "R", "e", "j"}
+
+
class ECGBoundaryDetection(BaseTask):
    """R-peak boundary detection on MIT-BIH ECG signals.

    Labels each time point as a beat boundary (R-peak, ``1``) or not
    (``0``). Signal decoding, downsampling, and optional trimming are
    handled at the dataset level via :class:`MITBIHDataset`'s
    ``preprocess``, ``downsample_factor``, and ``trim`` kwargs.

    Args:
        window_size: Number of time points per window. Default 256.
        step_size: Stride between consecutive windows. Default 256.

    Attributes:
        task_name (str): ``"ECGBoundaryDetection"``
        input_schema (Dict[str, str]): ``{"signal": "tensor"}``
        output_schema (Dict[str, str]): ``{"label": "tensor"}``

    Examples:
        >>> from pyhealth.datasets import MITBIHDataset
        >>> dataset = MITBIHDataset(
        ...     root="/path/to/mitdb/", preprocess=True
        ... )
        >>> sample_dataset = dataset.set_task(ECGBoundaryDetection())
    """

    task_name: str = "ECGBoundaryDetection"
    input_schema: Dict[str, str] = {"signal": "tensor"}
    output_schema: Dict[str, str] = {"label": "tensor"}

    def __init__(self, window_size: int = 256, step_size: int = 256):
        self.window_size = window_size
        self.step_size = step_size
        super().__init__()

    def __call__(self, patient: Any) -> list[dict[str, Any]]:
        """Process a single patient into windowed samples.

        Args:
            patient: A Patient object from MITBIHDataset.

        Returns:
            List of sample dicts with signal and binary label arrays.
        """
        samples: list[dict[str, Any]] = []

        for event in patient.get_events():
            loaded = _load_signal_and_annotations(event)
            if loaded is None:
                continue
            signal, ann_sample, _ = loaded

            # One-hot R-peak mask over the (possibly downsampled) signal.
            peaks = np.zeros(len(signal), dtype=np.int64)
            for s in ann_sample:
                pos = int(s)
                if 0 <= pos < len(signal):
                    peaks[pos] = 1

            desc = _build_description(event)
            split = getattr(event, "split", "") or ""

            for lo in range(
                0, len(signal) - self.window_size + 1, self.step_size
            ):
                hi = lo + self.window_size
                samples.append(
                    {
                        "patient_id": patient.patient_id,
                        "signal": signal[lo:hi],
                        "label": peaks[lo:hi].astype(np.float32),
                        "description": desc,
                        "split": split,
                    }
                )

        return samples
+
+
+def _load_signal_and_annotations(
+ event,
+ default_downsample: int = 3,
+) -> tuple[np.ndarray, np.ndarray, np.ndarray] | None:
+ """Return (signal, ann_sample, ann_symbol) from cache or wfdb.
+
+ Cache path is preferred. Wfdb fallback downsamples by
+ ``default_downsample`` but does not trim — trim only applies to
+ cached arrays built via ``MITBIHDataset(preprocess=True, trim=...)``.
+ """
+ import os
+
+ processed_file = getattr(event, "processed_file", "") or ""
+ if processed_file and os.path.exists(processed_file):
+ with np.load(processed_file, allow_pickle=False) as npz:
+ return (
+ np.asarray(npz["signal"], dtype=np.float32),
+ np.asarray(npz["ann_sample"], dtype=np.int64),
+ np.asarray(npz["ann_symbol"]),
+ )
+
+ import wfdb
+
+ try:
+ record = wfdb.rdrecord(event.signal_file)
+ ann = wfdb.rdann(event.signal_file, extension=event.annotation_file)
+ except FileNotFoundError:
+ return None
+
+ signal = record.p_signal.astype(np.float32)
+ if default_downsample > 1:
+ signal = signal[::default_downsample]
+ ds_samples = np.asarray(ann.sample) // default_downsample
+ ann_symbols = np.asarray(ann.symbol)
+ in_bounds = (ds_samples >= 0) & (ds_samples < len(signal))
+ return (
+ signal,
+ ds_samples[in_bounds].astype(np.int64),
+ ann_symbols[in_bounds],
+ )
+
+
+def _build_description(event) -> str:
+ """Compose per-patient description from event demographics."""
+ parts: list[str] = []
+ if getattr(event, "age", ""):
+ parts.append(f"age: {event.age}")
+ if getattr(event, "sex", ""):
+ parts.append(f"sex: {event.sex}")
+ if getattr(event, "medications", ""):
+ parts.append(f"medications: {event.medications}")
+ return ", ".join(parts)
diff --git a/pyhealth/tasks/ecg_wave_segmentation.py b/pyhealth/tasks/ecg_wave_segmentation.py
new file mode 100644
index 000000000..0d07dce32
--- /dev/null
+++ b/pyhealth/tasks/ecg_wave_segmentation.py
@@ -0,0 +1,180 @@
+# Author: Anton Barchukov
+# Paper: Chan et al., "MedTsLLM: Leveraging LLMs for Multimodal
+# Medical Time Series Analysis", MLHC 2024
+# Paper link: https://arxiv.org/abs/2408.07773
+# Description: Per-timestep ECG wave segmentation task for the LUDB
+# dataset. Classifies each sample as background (0), P wave (1),
+# QRS complex (2), or T wave (3).
+
+from typing import Any, Dict
+
+import numpy as np
+
+from pyhealth.tasks import BaseTask
+
+
# wfdb annotation symbol -> class index
# ("p" = P wave -> 1, "N" = QRS complex -> 2, "t" = T wave -> 3;
# class 0 is background, per the ECGWaveSegmentation docstring).
_WAVE_LABELS = {"p": 1, "N": 2, "t": 3}
+
+
class ECGWaveSegmentation(BaseTask):
    """Per-timestep ECG wave segmentation on LUDB.

    Classifies each time point in a 12-lead ECG as background (0),
    P wave (1), QRS complex (2), or T wave (3). The raw signal is
    windowed into fixed-length chunks; each chunk yields a sample
    with a signal tensor and a per-timestep label array of the same
    length.

    Trim/decode are controlled at the dataset level via
    :class:`LUDBDataset`'s ``preprocess`` and ``trim`` kwargs; this
    task only handles windowing and emission.

    Args:
        window_size: Number of time points per window. Default 512.
        step_size: Stride between consecutive windows. Default 256.

    Attributes:
        task_name (str): ``"ECGWaveSegmentation"``
        input_schema (Dict[str, str]): ``{"signal": "tensor"}``
        output_schema (Dict[str, str]): ``{"label": "tensor"}``

    Examples:
        >>> from pyhealth.datasets import LUDBDataset
        >>> dataset = LUDBDataset(root="/path/to/ludb/", preprocess=True)
        >>> sample_dataset = dataset.set_task(ECGWaveSegmentation())
        >>> sample_dataset.samples[0].keys()
        dict_keys(['patient_id', 'lead', 'signal', 'label', ...])
    """

    task_name: str = "ECGWaveSegmentation"
    input_schema: Dict[str, str] = {"signal": "tensor"}
    output_schema: Dict[str, str] = {"label": "tensor"}

    def __init__(self, window_size: int = 512, step_size: int = 256):
        self.window_size = window_size
        self.step_size = step_size
        super().__init__()

    def __call__(self, patient: Any) -> list[dict[str, Any]]:
        """Process a single patient into windowed samples.

        Args:
            patient: A Patient object from LUDBDataset. Each event
                represents one ECG lead with ``signal_file`` (wfdb
                record path), ``label_file`` (annotation extension),
                and optionally ``processed_file`` (cached ``.npz``).

        Returns:
            List of sample dicts, each containing ``patient_id``,
            ``lead``, a ``signal`` window of shape ``(window_size,)``,
            a ``label`` window of the same shape, ``description``,
            and ``split``.
        """
        samples: list[dict[str, Any]] = []

        for event in patient.get_events():
            loaded = _load_signal_and_labels(event)
            if loaded is None:
                continue
            signal, labels = loaded

            desc = _build_description(event)
            split = getattr(event, "split", "") or ""

            # Slide a fixed-length window over the full record.
            for lo in range(
                0, len(signal) - self.window_size + 1, self.step_size
            ):
                hi = lo + self.window_size
                samples.append(
                    {
                        "patient_id": patient.patient_id,
                        "lead": event.lead,
                        "signal": signal[lo:hi],
                        "label": labels[lo:hi],
                        "description": desc,
                        "split": split,
                    }
                )

        return samples
+
+
+def _load_signal_and_labels(event) -> tuple[np.ndarray, np.ndarray] | None:
+ """Return (signal, labels) for one event, preferring the ``.npz`` cache."""
+ processed_file = getattr(event, "processed_file", "") or ""
+ if processed_file:
+ import os
+
+ if os.path.exists(processed_file):
+ with np.load(processed_file, allow_pickle=False) as npz:
+ return (
+ np.asarray(npz["signal"], dtype=np.float32),
+ np.asarray(npz["labels"], dtype=np.int64),
+ )
+
+ # Fallback: decode wfdb on demand.
+ import wfdb
+
+ try:
+ record = wfdb.rdrecord(event.signal_file)
+ except FileNotFoundError:
+ return None
+ try:
+ lead_idx = record.sig_name.index(event.lead)
+ except ValueError:
+ return None
+ signal = record.p_signal[:, lead_idx].astype(np.float32)
+
+ try:
+ ann = wfdb.rdann(event.signal_file, extension=event.label_file)
+ except FileNotFoundError:
+ return None
+
+ labels = np.zeros(len(signal), dtype=np.int64)
+ i = 0
+ while i < len(ann.symbol):
+ sym = ann.symbol[i]
+ if sym == "(" and i + 2 < len(ann.symbol):
+ wave_type = ann.symbol[i + 1]
+ onset = ann.sample[i]
+ offset = (
+ ann.sample[i + 2]
+ if ann.symbol[i + 2] == ")"
+ else ann.sample[i + 1]
+ )
+ if wave_type in _WAVE_LABELS:
+ labels[onset : offset + 1] = _WAVE_LABELS[wave_type]
+ i += 3
+ else:
+ i += 1
+
+ return signal, labels
+
+
+def _build_description(event) -> str:
+ """Compose a per-patient description string from event attributes.
+
+ Returns a comma-separated ``"age: X, sex: Y, diagnoses: Z"``
+ string built from the demographics attached to each LUDB event.
+ NaN values (from pandas reading empty-string cells) and missing
+ attributes are skipped.
+ """
+ parts: list[str] = []
+ for attr in ("age", "sex", "diagnoses"):
+ value = getattr(event, attr, "")
+ if value is None:
+ continue
+ text = str(value).strip()
+ if not text or text.lower() == "nan":
+ continue
+ parts.append(f"{attr}: {text}")
+ return ", ".join(parts)
diff --git a/pyhealth/tasks/respiratory_boundary_detection.py b/pyhealth/tasks/respiratory_boundary_detection.py
new file mode 100644
index 000000000..0557fdd83
--- /dev/null
+++ b/pyhealth/tasks/respiratory_boundary_detection.py
@@ -0,0 +1,145 @@
+# Author: Anton Barchukov
+# Paper: Chan et al., "MedTsLLM: Leveraging LLMs for Multimodal
+# Medical Time Series Analysis", MLHC 2024
+# Description: Breath boundary detection task for the BIDMC dataset.
+# Detects breath boundaries in respiratory impedance signals.
+
+from typing import Any, Dict
+
+import numpy as np
+
+from pyhealth.tasks import BaseTask
+
# RESP, PLETH, and ECG lead II — the 3 channels used in the paper.
# NOTE: BIDMC ``.hea`` sig_name entries carry a trailing comma
# (e.g. ``RESP,``), so these lookup strings include it deliberately.
_TARGET_CHANNELS = ["RESP,", "PLETH,", "II,"]
+
+
class RespiratoryBoundaryDetection(BaseTask):
    """Breath boundary detection on BIDMC respiratory signals.

    Binary classification of each time point as a breath boundary
    or not. The model is trained on 3 channels (RESP, PLETH, II)
    and predicts boundary locations.

    Args:
        window_size: Number of time points per window. Default 256.
        step_size: Stride between consecutive windows. Default 128.
        annotator: Which annotator's labels to use (1 or 2).
            Default 1.

    Attributes:
        task_name (str): ``"RespiratoryBoundaryDetection"``
        input_schema (Dict[str, str]): ``{"signal": "tensor"}``
        output_schema (Dict[str, str]): ``{"label": "tensor"}``

    Examples:
        >>> from pyhealth.datasets import BIDMCDataset
        >>> dataset = BIDMCDataset(root="/path/to/bidmc/")
        >>> sample_dataset = dataset.set_task(
        ...     RespiratoryBoundaryDetection()
        ... )
    """

    task_name: str = "RespiratoryBoundaryDetection"
    input_schema: Dict[str, str] = {"signal": "tensor"}
    output_schema: Dict[str, str] = {"label": "tensor"}

    def __init__(
        self,
        window_size: int = 256,
        step_size: int = 128,
        annotator: int = 1,
    ):
        self.window_size = window_size
        self.step_size = step_size
        self.annotator = annotator
        super().__init__()

    def __call__(self, patient: Any) -> list[dict[str, Any]]:
        """Process a single patient into windowed samples.

        Args:
            patient: A Patient object from BIDMCDataset.

        Returns:
            List of sample dicts with signal and binary label arrays.
        """
        pid = patient.patient_id
        samples: list[dict[str, Any]] = []

        for event in patient.get_events():
            result = _load_signal_and_annotations(event)
            if result is None:
                continue
            signal, ann_sample, ann_aux = result
            # The paper's setup needs exactly RESP/PLETH/II; skip
            # records where any of the 3 channels is missing.
            if signal.shape[1] != 3:
                continue

            labels = self._build_labels(signal, ann_sample, ann_aux)
            description = self._build_description(event)
            split = getattr(event, "split", "") or ""

            for start in range(
                0, len(signal) - self.window_size + 1, self.step_size
            ):
                end = start + self.window_size
                samples.append({
                    "patient_id": pid,
                    "signal": signal[start:end],
                    "label": labels[start:end].astype(np.float32),
                    "description": description,
                    "split": split,
                })

        return samples

    def _build_labels(self, signal, ann_sample, ann_aux) -> np.ndarray:
        """Mark the selected annotator's in-bounds breath samples with 1."""
        labels = np.zeros(len(signal), dtype=np.int64)
        ann_tag = f"ann{self.annotator}"
        for s, aux in zip(ann_sample, ann_aux):
            if str(aux) == ann_tag and 0 <= int(s) < len(signal):
                labels[int(s)] = 1
        return labels

    @staticmethod
    def _build_description(event) -> str:
        """Compose per-patient description from event demographics.

        Uses ``getattr`` with defaults so events lacking demographic
        attributes don't raise AttributeError, and skips pandas'
        stringified ``"nan"`` placeholders — consistent with the
        ``_build_description`` helpers of the sibling ECG tasks.
        """
        parts: list[str] = []
        for attr in ("age", "sex", "location"):
            value = getattr(event, attr, "")
            if value is None:
                continue
            text = str(value).strip()
            if not text or text.lower() == "nan":
                continue
            parts.append(f"{attr}: {text}")
        return ", ".join(parts)
+
+
+def _load_signal_and_annotations(
+ event,
+) -> tuple[np.ndarray, np.ndarray, np.ndarray] | None:
+ """Return (signal, ann_sample, ann_aux) from ``.npz`` cache or wfdb."""
+ import os
+
+ processed_file = getattr(event, "processed_file", "") or ""
+ if processed_file and os.path.exists(processed_file):
+ with np.load(processed_file, allow_pickle=False) as npz:
+ return (
+ np.asarray(npz["signal"], dtype=np.float32),
+ np.asarray(npz["ann_sample"], dtype=np.int64),
+ np.asarray(npz["ann_aux"]),
+ )
+
+ import wfdb
+
+ try:
+ record = wfdb.rdrecord(event.signal_file)
+ ann = wfdb.rdann(event.signal_file, extension=event.annotation_file)
+ except FileNotFoundError:
+ return None
+
+ col_idx = [
+ record.sig_name.index(ch)
+ for ch in _TARGET_CHANNELS
+ if ch in record.sig_name
+ ]
+ signal = record.p_signal[:, col_idx].astype(np.float32)
+ return (
+ signal,
+ np.asarray(ann.sample, dtype=np.int64),
+ np.asarray(ann.aux_note),
+ )
diff --git a/pyproject.toml b/pyproject.toml
index 98f88d47b..a57d42a16 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -43,6 +43,7 @@ dependencies = [
"pyarrow~=22.0.0",
"narwhals~=2.13.0",
"more-itertools~=10.8.0",
+ "wfdb>=4.1.0",
"einops>=0.8.0",
"linear-attention-transformer>=0.19.1",
]
diff --git a/test-resources/core/bidmc/bidmc01.breath b/test-resources/core/bidmc/bidmc01.breath
new file mode 100644
index 000000000..5ce0a71c6
Binary files /dev/null and b/test-resources/core/bidmc/bidmc01.breath differ
diff --git a/test-resources/core/bidmc/bidmc01.dat b/test-resources/core/bidmc/bidmc01.dat
new file mode 100644
index 000000000..6ffff1030
Binary files /dev/null and b/test-resources/core/bidmc/bidmc01.dat differ
diff --git a/test-resources/core/bidmc/bidmc01.hea b/test-resources/core/bidmc/bidmc01.hea
new file mode 100644
index 000000000..ad3c35859
--- /dev/null
+++ b/test-resources/core/bidmc/bidmc01.hea
@@ -0,0 +1,8 @@
+bidmc01 5 125 60001
+bidmc01.dat 16 57446.965(18)/pm 16 0 162 53972 0 RESP,
+bidmc01.dat 16 88262.76(133)/NU 16 0 42 42240 0 PLETH,
+bidmc01.dat 16 68600.47(108)/mV 16 0 -952 6385 0 V,
+bidmc01.dat 16 70302.98(-146)/mV 16 0 -566 36056 0 AVR,
+bidmc01.dat 16 69816.54(304)/mV 16 0 1866 49819 0 II,
+# <age>: 55
+# <sex>: M
diff --git a/test-resources/core/bidmc/bidmc02.breath b/test-resources/core/bidmc/bidmc02.breath
new file mode 100644
index 000000000..5ce0a71c6
Binary files /dev/null and b/test-resources/core/bidmc/bidmc02.breath differ
diff --git a/test-resources/core/bidmc/bidmc02.dat b/test-resources/core/bidmc/bidmc02.dat
new file mode 100644
index 000000000..004e8b6ba
Binary files /dev/null and b/test-resources/core/bidmc/bidmc02.dat differ
diff --git a/test-resources/core/bidmc/bidmc02.hea b/test-resources/core/bidmc/bidmc02.hea
new file mode 100644
index 000000000..80fc961aa
--- /dev/null
+++ b/test-resources/core/bidmc/bidmc02.hea
@@ -0,0 +1,8 @@
+bidmc02 5 125 60001
+bidmc02.dat 16 57636.965(126)/pm 16 0 297 37315 0 RESP,
+bidmc02.dat 16 89899.484(-105)/NU 16 0 -1492 16716 0 PLETH,
+bidmc02.dat 16 68559.59(920)/mV 16 0 -187 58542 0 V,
+bidmc02.dat 16 69553.68(-494)/mV 16 0 199 253 0 AVR,
+bidmc02.dat 16 68740.055(84)/mV 16 0 99 24415 0 II,
+# <age>: 58
+# <sex>: F
diff --git a/test-resources/core/ludb/data/1.avf b/test-resources/core/ludb/data/1.avf
new file mode 100644
index 000000000..9c3a63c8c
Binary files /dev/null and b/test-resources/core/ludb/data/1.avf differ
diff --git a/test-resources/core/ludb/data/1.avl b/test-resources/core/ludb/data/1.avl
new file mode 100644
index 000000000..9c3a63c8c
Binary files /dev/null and b/test-resources/core/ludb/data/1.avl differ
diff --git a/test-resources/core/ludb/data/1.avr b/test-resources/core/ludb/data/1.avr
new file mode 100644
index 000000000..9c3a63c8c
Binary files /dev/null and b/test-resources/core/ludb/data/1.avr differ
diff --git a/test-resources/core/ludb/data/1.dat b/test-resources/core/ludb/data/1.dat
new file mode 100644
index 000000000..958126295
Binary files /dev/null and b/test-resources/core/ludb/data/1.dat differ
diff --git a/test-resources/core/ludb/data/1.hea b/test-resources/core/ludb/data/1.hea
new file mode 100644
index 000000000..dff2c9da4
--- /dev/null
+++ b/test-resources/core/ludb/data/1.hea
@@ -0,0 +1,19 @@
+1 12 500 5000
+1.dat 16 62239.465(-23190)/mV 16 0 -23033 33766 0 i
+1.dat 16 55949.375(-23524)/mV 16 0 -23725 5480 0 ii
+1.dat 16 51311.14(-24306)/mV 16 0 -23804 45137 0 iii
+1.dat 16 47775.176(-25170)/mV 16 0 -25008 18423 0 avr
+1.dat 16 44748.46(-25615)/mV 16 0 -25325 28490 0 avl
+1.dat 16 41369.31(-26581)/mV 16 0 -26117 20648 0 avf
+1.dat 16 39463.5(-26298)/mV 16 0 -27171 54031 0 v1
+1.dat 16 37131.64(-27143)/mV 16 0 -26799 25262 0 v2
+1.dat 16 35097.883(-27041)/mV 16 0 -26918 37315 0 v3
+1.dat 16 33396.85(-27664)/mV 16 0 -27398 3336 0 v4
+1.dat 16 31306.604(-27419)/mV 16 0 -27598 13770 0 v5
+1.dat 16 30131.424(-27905)/mV 16 0 -27941 56298 0 v6
+# <age>: 40
+# <sex>: M
+# <diagnoses>:
+# Rhythm: Sinus rhythm.
+# Left ventricular hypertrophy.
+# Non-specific repolarization abnormalities.
diff --git a/test-resources/core/ludb/data/1.i b/test-resources/core/ludb/data/1.i
new file mode 100644
index 000000000..9c3a63c8c
Binary files /dev/null and b/test-resources/core/ludb/data/1.i differ
diff --git a/test-resources/core/ludb/data/1.ii b/test-resources/core/ludb/data/1.ii
new file mode 100644
index 000000000..9c3a63c8c
Binary files /dev/null and b/test-resources/core/ludb/data/1.ii differ
diff --git a/test-resources/core/ludb/data/1.iii b/test-resources/core/ludb/data/1.iii
new file mode 100644
index 000000000..9c3a63c8c
Binary files /dev/null and b/test-resources/core/ludb/data/1.iii differ
diff --git a/test-resources/core/ludb/data/1.v1 b/test-resources/core/ludb/data/1.v1
new file mode 100644
index 000000000..9c3a63c8c
Binary files /dev/null and b/test-resources/core/ludb/data/1.v1 differ
diff --git a/test-resources/core/ludb/data/1.v2 b/test-resources/core/ludb/data/1.v2
new file mode 100644
index 000000000..9c3a63c8c
Binary files /dev/null and b/test-resources/core/ludb/data/1.v2 differ
diff --git a/test-resources/core/ludb/data/1.v3 b/test-resources/core/ludb/data/1.v3
new file mode 100644
index 000000000..9c3a63c8c
Binary files /dev/null and b/test-resources/core/ludb/data/1.v3 differ
diff --git a/test-resources/core/ludb/data/1.v4 b/test-resources/core/ludb/data/1.v4
new file mode 100644
index 000000000..9c3a63c8c
Binary files /dev/null and b/test-resources/core/ludb/data/1.v4 differ
diff --git a/test-resources/core/ludb/data/1.v5 b/test-resources/core/ludb/data/1.v5
new file mode 100644
index 000000000..9c3a63c8c
Binary files /dev/null and b/test-resources/core/ludb/data/1.v5 differ
diff --git a/test-resources/core/ludb/data/1.v6 b/test-resources/core/ludb/data/1.v6
new file mode 100644
index 000000000..9c3a63c8c
Binary files /dev/null and b/test-resources/core/ludb/data/1.v6 differ
diff --git a/test-resources/core/ludb/data/2.avf b/test-resources/core/ludb/data/2.avf
new file mode 100644
index 000000000..9c3a63c8c
Binary files /dev/null and b/test-resources/core/ludb/data/2.avf differ
diff --git a/test-resources/core/ludb/data/2.avl b/test-resources/core/ludb/data/2.avl
new file mode 100644
index 000000000..9c3a63c8c
Binary files /dev/null and b/test-resources/core/ludb/data/2.avl differ
diff --git a/test-resources/core/ludb/data/2.avr b/test-resources/core/ludb/data/2.avr
new file mode 100644
index 000000000..9c3a63c8c
Binary files /dev/null and b/test-resources/core/ludb/data/2.avr differ
diff --git a/test-resources/core/ludb/data/2.dat b/test-resources/core/ludb/data/2.dat
new file mode 100644
index 000000000..a7251e472
Binary files /dev/null and b/test-resources/core/ludb/data/2.dat differ
diff --git a/test-resources/core/ludb/data/2.hea b/test-resources/core/ludb/data/2.hea
new file mode 100644
index 000000000..cfa3a04c5
--- /dev/null
+++ b/test-resources/core/ludb/data/2.hea
@@ -0,0 +1,19 @@
+2 12 500 5000
+2.dat 16 61262.926(-23165)/mV 16 0 -24100 9855 0 i
+2.dat 16 56641.953(-23737)/mV 16 0 -24570 45290 0 ii
+2.dat 16 51504.383(-24241)/mV 16 0 -25406 50166 0 iii
+2.dat 16 48346.41(-25206)/mV 16 0 -24033 57309 0 avr
+2.dat 16 44399.043(-25462)/mV 16 0 -25828 53239 0 avl
+2.dat 16 41577.14(-25641)/mV 16 0 -25734 21095 0 avf
+2.dat 16 38740.266(-26479)/mV 16 0 -27922 18472 0 v1
+2.dat 16 37059.348(-26572)/mV 16 0 -26685 21959 0 v2
+2.dat 16 35257.223(-27349)/mV 16 0 -26520 61948 0 v3
+2.dat 16 33735.87(-27532)/mV 16 0 -28077 48552 0 v4
+2.dat 16 31750.809(-27608)/mV 16 0 -28287 56113 0 v5
+2.dat 16 30051.408(-28146)/mV 16 0 -27743 12377 0 v6
+# <age>: 45
+# <sex>: F
+# <diagnoses>:
+# Rhythm: Sinus rhythm.
+# Left ventricular hypertrophy.
+# Non-specific repolarization abnormalities.
diff --git a/test-resources/core/ludb/data/2.i b/test-resources/core/ludb/data/2.i
new file mode 100644
index 000000000..9c3a63c8c
Binary files /dev/null and b/test-resources/core/ludb/data/2.i differ
diff --git a/test-resources/core/ludb/data/2.ii b/test-resources/core/ludb/data/2.ii
new file mode 100644
index 000000000..9c3a63c8c
Binary files /dev/null and b/test-resources/core/ludb/data/2.ii differ
diff --git a/test-resources/core/ludb/data/2.iii b/test-resources/core/ludb/data/2.iii
new file mode 100644
index 000000000..9c3a63c8c
Binary files /dev/null and b/test-resources/core/ludb/data/2.iii differ
diff --git a/test-resources/core/ludb/data/2.v1 b/test-resources/core/ludb/data/2.v1
new file mode 100644
index 000000000..9c3a63c8c
Binary files /dev/null and b/test-resources/core/ludb/data/2.v1 differ
diff --git a/test-resources/core/ludb/data/2.v2 b/test-resources/core/ludb/data/2.v2
new file mode 100644
index 000000000..9c3a63c8c
Binary files /dev/null and b/test-resources/core/ludb/data/2.v2 differ
diff --git a/test-resources/core/ludb/data/2.v3 b/test-resources/core/ludb/data/2.v3
new file mode 100644
index 000000000..9c3a63c8c
Binary files /dev/null and b/test-resources/core/ludb/data/2.v3 differ
diff --git a/test-resources/core/ludb/data/2.v4 b/test-resources/core/ludb/data/2.v4
new file mode 100644
index 000000000..9c3a63c8c
Binary files /dev/null and b/test-resources/core/ludb/data/2.v4 differ
diff --git a/test-resources/core/ludb/data/2.v5 b/test-resources/core/ludb/data/2.v5
new file mode 100644
index 000000000..9c3a63c8c
Binary files /dev/null and b/test-resources/core/ludb/data/2.v5 differ
diff --git a/test-resources/core/ludb/data/2.v6 b/test-resources/core/ludb/data/2.v6
new file mode 100644
index 000000000..9c3a63c8c
Binary files /dev/null and b/test-resources/core/ludb/data/2.v6 differ
diff --git a/test-resources/core/mitbih/100.atr b/test-resources/core/mitbih/100.atr
new file mode 100644
index 000000000..2a63ba95e
Binary files /dev/null and b/test-resources/core/mitbih/100.atr differ
diff --git a/test-resources/core/mitbih/100.dat b/test-resources/core/mitbih/100.dat
new file mode 100644
index 000000000..929240569
Binary files /dev/null and b/test-resources/core/mitbih/100.dat differ
diff --git a/test-resources/core/mitbih/100.hea b/test-resources/core/mitbih/100.hea
new file mode 100644
index 000000000..406294f87
--- /dev/null
+++ b/test-resources/core/mitbih/100.hea
@@ -0,0 +1,4 @@
+100 2 360 65000
+100.dat 16 47871.453(-386)/mV 16 0 -266 25690 0 MLII
+100.dat 16 57050.91(-246)/mV 16 0 8033 32572 0 V5
+# # 60 M
diff --git a/test-resources/core/mitbih/101.atr b/test-resources/core/mitbih/101.atr
new file mode 100644
index 000000000..2a63ba95e
Binary files /dev/null and b/test-resources/core/mitbih/101.atr differ
diff --git a/test-resources/core/mitbih/101.dat b/test-resources/core/mitbih/101.dat
new file mode 100644
index 000000000..0cc0b8934
Binary files /dev/null and b/test-resources/core/mitbih/101.dat differ
diff --git a/test-resources/core/mitbih/101.hea b/test-resources/core/mitbih/101.hea
new file mode 100644
index 000000000..dc06edc52
--- /dev/null
+++ b/test-resources/core/mitbih/101.hea
@@ -0,0 +1,4 @@
+101 2 360 65000
+101.dat 16 48609.28(109)/mV 16 0 350 28443 0 MLII
+101.dat 16 57637.297(92)/mV 16 0 9894 54587 0 V5
+# # 63 F
diff --git a/tests/core/_synthetic_wfdb.py b/tests/core/_synthetic_wfdb.py
new file mode 100644
index 000000000..d078d5d2f
--- /dev/null
+++ b/tests/core/_synthetic_wfdb.py
@@ -0,0 +1,319 @@
+"""Synthetic wfdb record generators for ECG/respiratory dataset tests.
+
+Generates minimal in-repo wfdb records for LUDB, MIT-BIH, and BIDMC so
+tests never ship real patient data. Each generator writes the same
+files the dataset loader expects (``.dat``, ``.hea``, per-lead/``.atr``
+/``.breath`` annotations), with seeded RNG for reproducibility.
+
+All signals are sine/noise mixes — not realistic ECGs, but valid wfdb
+records that satisfy the dataset parsers and task annotations. Keep
+these generators in sync with the real record schemas when the
+dataset parsers change.
+"""
+
+from __future__ import annotations
+
+import os
+from typing import Iterable
+
+import numpy as np
+
+
# --------------------------------------------------------------------- #
# LUDB: 12-lead ECG at 500 Hz, 10 s per record, wave annotations per lead
# --------------------------------------------------------------------- #

# Standard 12-lead names; LUDB also uses them as the per-lead
# annotation-file extensions (``1.i``, ``1.avr``, ``1.v6``, ...).
_LUDB_LEADS = ("i", "ii", "iii", "avr", "avl", "avf",
               "v1", "v2", "v3", "v4", "v5", "v6")
# Sampling rate (Hz) of the synthetic records.
_LUDB_FS = 500
_LUDB_LEN = 5000  # 10 s
# Placeholder diagnosis lines written into each synthetic header.
_LUDB_DIAGNOSES = (
    "Rhythm: Sinus rhythm.",
    "Left ventricular hypertrophy.",
    "Non-specific repolarization abnormalities.",
)
+
+
def synthesize_ludb(
    dest_root: str,
    n_records: int = 2,
    seed: int = 0,
) -> None:
    """Write ``n_records`` synthetic LUDB records into ``{dest_root}/data/``.

    Each record has:
      * 12-lead signal with a weak sinusoidal "ECG" shape + noise
      * Header comments with ``<age>``, ``<sex>``, ``<diagnoses>`` tags
      * Per-lead wave annotation file (``.i``, ``.ii``, ...) with evenly
        spaced P / N (QRS) / T wave triplets via the LUDB ``( sym )``
        encoding.

    Args:
        dest_root: Fixture root; records land in ``{dest_root}/data/``.
        n_records: Number of records to write (named "1", "2", ...).
        seed: RNG seed so committed fixtures are reproducible.
    """
    import wfdb

    data_dir = os.path.join(dest_root, "data")
    os.makedirs(data_dir, exist_ok=True)
    rng = np.random.default_rng(seed)

    for rec_idx in range(n_records):
        rec_name = str(rec_idx + 1)
        signal = _synthesize_12lead_ecg(rng)
        # LUDB headers carry demographics/diagnoses as tagged comment
        # lines (``<age>: 52``, ``<sex>: F``, ``<diagnoses>:`` followed
        # by free-text diagnosis lines); the dataset parser keys on
        # these tags, so they must appear verbatim.
        comments = [
            f"<age>: {40 + 5 * rec_idx}",
            f"<sex>: {'M' if rec_idx % 2 == 0 else 'F'}",
            "<diagnoses>:",
            *_LUDB_DIAGNOSES,
        ]
        wfdb.wrsamp(
            record_name=rec_name,
            fs=_LUDB_FS,
            units=["mV"] * len(_LUDB_LEADS),
            sig_name=list(_LUDB_LEADS),
            p_signal=signal,
            fmt=["16"] * len(_LUDB_LEADS),
            write_dir=data_dir,
            comments=comments,
        )
        _write_ludb_annotations(data_dir, rec_name, rng)
+
+
def _synthesize_12lead_ecg(rng: np.random.Generator) -> np.ndarray:
    """Return (samples, 12) float32 array with ECG-ish waveforms."""
    t = np.arange(_LUDB_LEN) / _LUDB_FS
    # Slow ~1 Hz sinusoidal baseline shared by every lead.
    baseline = 0.1 * np.sin(2 * np.pi * 1.0 * t)
    # Impulse train at 0.8 s spacing, later smoothed into QRS-like bumps.
    impulses = np.zeros_like(t)
    for beat_time in np.arange(0.4, 10.0, 0.8):
        sample_idx = int(beat_time * _LUDB_FS)
        if sample_idx < _LUDB_LEN:
            impulses[sample_idx] = 1.0
    # Gaussian kernel turns each impulse into a smooth QRS-shaped bump.
    gauss = np.exp(-((np.arange(-20, 21)) ** 2) / 40.0)
    qrs_train = np.convolve(impulses, gauss, mode="same")
    # Each lead gets a slightly different QRS gain plus fresh noise.
    leads = []
    for lead_idx in range(len(_LUDB_LEADS)):
        lead_gain = 0.8 + 0.1 * lead_idx
        leads.append(
            baseline + qrs_train * lead_gain + rng.normal(0, 0.02, _LUDB_LEN)
        )
    return np.stack(leads, axis=1).astype(np.float32)
+
+
def _write_ludb_annotations(
    data_dir: str, rec_name: str, rng: np.random.Generator,
) -> None:
    """Write one annotation file per lead with evenly spaced P/N/T waves.

    LUDB encodes each wave as a triplet of symbols ``(``, wave-type,
    ``)`` at onset, peak, offset sample indices, where wave-type is
    ``p`` (P wave), ``N`` (QRS complex), or ``t`` (T wave).

    Note: ``rng`` is accepted for signature symmetry with the signal
    generator but the annotation positions are fully deterministic.
    """
    import wfdb

    # ~10 beats across the 10 s clip; each beat contributes a P wave,
    # a QRS complex, and a T wave, each written as the LUDB
    # onset / peak / offset triple ``(``, sym, ``)``.
    beat_centers = np.linspace(300, _LUDB_LEN - 300, 10, dtype=int)
    samples = []
    symbols = []
    for center in beat_centers:
        for delta, sym in (
            (-120, "("), (-90, "p"), (-60, ")"),
            (-20, "("), (0, "N"), (20, ")"),
            (80, "("), (120, "t"), (160, ")"),
        ):
            s = int(center + delta)
            if 0 <= s < _LUDB_LEN:
                samples.append(s)
                symbols.append(sym)

    samples_arr = np.array(samples, dtype=np.int64)
    # wrann rejects extensions with digits (e.g., "v1", "v2"), but
    # LUDB's lead-name extensions include six of those. Write each
    # annotation with a letters-only placeholder extension, then
    # rename to the real lead extension afterwards.
    placeholder_stem = "xyzlead"
    for i, lead in enumerate(_LUDB_LEADS):
        placeholder = f"{placeholder_stem}{chr(ord('a') + i)}"
        wfdb.wrann(
            record_name=rec_name,
            extension=placeholder,
            sample=samples_arr,
            symbol=symbols,
            write_dir=data_dir,
        )
        src = os.path.join(data_dir, f"{rec_name}.{placeholder}")
        dst = os.path.join(data_dir, f"{rec_name}.{lead}")
        os.replace(src, dst)
+
+
# --------------------------------------------------------------------- #
# BIDMC: respiratory + ECG, 125 Hz, ~8 min, breath annotations
# --------------------------------------------------------------------- #

# Sampling rate (Hz) of the synthetic records.
_BIDMC_FS = 125
_BIDMC_LEN = 60_001  # matches the real format's length
# BIDMC's real headers carry a trailing comma in each signal name
# (e.g., ``RESP,``). The dataset parser matches on the comma-suffixed
# form, so preserve it here.
_BIDMC_SIGS = ("RESP,", "PLETH,", "V,", "AVR,", "II,")
+
+
def synthesize_bidmc(
    dest_root: str,
    n_records: int = 2,
    seed: int = 0,
) -> None:
    """Write ``n_records`` synthetic BIDMC records into ``dest_root/``.

    Each record includes a RESP signal + 2 ECG leads + PLETH + AVR,
    plus a ``.breath`` annotation file with periodic breath markers.

    Args:
        dest_root: Directory receiving the ``bidmcNN`` record files.
        n_records: Number of records (named "bidmc01", "bidmc02", ...).
        seed: RNG seed so committed fixtures are reproducible.
    """
    import wfdb

    os.makedirs(dest_root, exist_ok=True)
    rng = np.random.default_rng(seed)

    for rec_idx in range(n_records):
        rec_name = f"bidmc{rec_idx + 1:02d}"
        t = np.arange(_BIDMC_LEN) / _BIDMC_FS
        resp = 0.5 * np.sin(2 * np.pi * 0.25 * t)  # 15 breaths/min
        pleth = 0.3 * np.sin(2 * np.pi * 1.2 * t)
        ecg = 0.4 * np.sin(2 * np.pi * 1.2 * t)

        def _noise() -> np.ndarray:
            # Fresh measurement noise for each channel draw.
            return rng.normal(0, 0.02, _BIDMC_LEN)

        sigs = np.stack(
            [resp + _noise(), pleth + _noise(), ecg + _noise(),
             -ecg + _noise(), ecg + _noise()],
            axis=1,
        ).astype(np.float32)
        # BIDMC headers tag demographics as comment lines the same way
        # LUDB does (``<age>: 55`` etc.); the dataset parser keys on
        # these tags, so they must appear verbatim. Real BIDMC headers
        # additionally carry a care-unit ``<location>`` tag.
        comments = [
            f"<age>: {55 + 3 * rec_idx}",
            f"<sex>: {'M' if rec_idx % 2 == 0 else 'F'}",
            "<location>: micu",
        ]
        wfdb.wrsamp(
            record_name=rec_name,
            fs=_BIDMC_FS,
            units=["pm", "NU", "mV", "mV", "mV"],
            sig_name=list(_BIDMC_SIGS),
            p_signal=sigs,
            fmt=["16"] * len(_BIDMC_SIGS),
            write_dir=dest_root,
            comments=comments,
        )
        # Breath annotations: one mark per breath (0.25 Hz => 4 s apart)
        breath_samples = np.arange(
            2 * _BIDMC_FS, _BIDMC_LEN, 4 * _BIDMC_FS, dtype=np.int64
        )
        wfdb.wrann(
            record_name=rec_name,
            extension="breath",
            sample=breath_samples,
            symbol=["+"] * len(breath_samples),
            write_dir=dest_root,
        )
+
+
# --------------------------------------------------------------------- #
# MIT-BIH: 2-lead ECG at 360 Hz, beat annotations
# --------------------------------------------------------------------- #

# Sampling rate (Hz) of the synthetic records.
_MITBIH_FS = 360
_MITBIH_LEN = 65_000  # short clip ≈ 3 min
# Lead names mirroring the real MIT-BIH records 100/101.
_MITBIH_SIGS = ("MLII", "V5")
+
+
def synthesize_mitbih(
    dest_root: str,
    record_names: Iterable[str] = ("100", "101"),
    seed: int = 0,
) -> None:
    """Write synthetic MIT-BIH records with beat annotations.

    Each record has 2 ECG leads + ``.atr`` annotations containing a
    mix of normal (``N``) and abnormal (``V``, ``A``) beat symbols so
    ECGAnomalyDetection / ECGBoundaryDetection tasks see both classes.
    """
    import wfdb

    os.makedirs(dest_root, exist_ok=True)
    rng = np.random.default_rng(seed)

    for rec_idx, rec_name in enumerate(record_names):
        time_axis = np.arange(_MITBIH_LEN) / _MITBIH_FS
        # Two phase-shifted sinusoids stand in for the MLII / V5 leads.
        lead_mlii = 0.6 * np.sin(2 * np.pi * 1.2 * time_axis)
        lead_v5 = 0.5 * np.sin(2 * np.pi * 1.2 * time_axis + 0.3)
        measurement_noise = rng.normal(0, 0.02, (_MITBIH_LEN, 2)).astype(np.float32)
        sigs = np.stack([lead_mlii, lead_v5], axis=1).astype(np.float32)
        sigs = sigs + measurement_noise
        # Age/sex encoded the way MIT-BIH headers do it: "# 69 M ..."
        comments = [f"# {60 + 3 * rec_idx} {'M' if rec_idx % 2 == 0 else 'F'}"]
        wfdb.wrsamp(
            record_name=rec_name,
            fs=_MITBIH_FS,
            units=["mV", "mV"],
            sig_name=list(_MITBIH_SIGS),
            p_signal=sigs,
            fmt=["16", "16"],
            write_dir=dest_root,
            comments=comments,
        )
        # Beat markers every ~1 s; every third beat is a PVC ("V") so
        # downstream tasks see both normal and abnormal classes.
        beat_samples = np.arange(_MITBIH_FS, _MITBIH_LEN, _MITBIH_FS, dtype=np.int64)
        symbols = ["V" if i % 3 == 0 else "N" for i in range(len(beat_samples))]
        wfdb.wrann(
            record_name=rec_name,
            extension="atr",
            sample=beat_samples,
            symbol=symbols,
            write_dir=dest_root,
        )
+
+
# --------------------------------------------------------------------- #
# Entry point — regenerate the committed test-resources fixtures
# --------------------------------------------------------------------- #
#
# Run with ``python -m tests.core._synthetic_wfdb`` from the repo root.
# Regenerates the wfdb records under ``test-resources/core/{ludb,
# bidmc,mitbih}/`` from a fixed seed. Checked-in outputs are fully
# synthetic — never ship real patient data in tests.

# tests/core/_synthetic_wfdb.py -> repo root is three dirname() hops up.
_REPO_ROOT = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
+
+
def _regenerate_all(repo_root: str = _REPO_ROOT) -> None:
    """Rewrite all three dataset fixtures under ``test-resources/core/``."""
    import shutil

    def _fresh_dir(path: str) -> str:
        # Start from an empty directory so stale fixture files never
        # survive a regeneration.
        if os.path.isdir(path):
            shutil.rmtree(path)
        os.makedirs(path, exist_ok=True)
        return path

    core_root = os.path.join(repo_root, "test-resources", "core")

    ludb_root = os.path.join(core_root, "ludb")
    _fresh_dir(os.path.join(ludb_root, "data"))
    synthesize_ludb(ludb_root, n_records=2)

    bidmc_root = _fresh_dir(os.path.join(core_root, "bidmc"))
    synthesize_bidmc(bidmc_root, n_records=2)

    mitbih_root = _fresh_dir(os.path.join(core_root, "mitbih"))
    synthesize_mitbih(mitbih_root, record_names=["100", "101"])
+
+
if __name__ == "__main__":
    # Manual entry point: rebuild every committed fixture in place.
    _regenerate_all()
    print(f"Regenerated synthetic test fixtures under {_REPO_ROOT}/test-resources/core/")
diff --git a/tests/core/test_bidmc.py b/tests/core/test_bidmc.py
new file mode 100644
index 000000000..e834be825
--- /dev/null
+++ b/tests/core/test_bidmc.py
@@ -0,0 +1,412 @@
+"""Unit tests for BIDMC dataset and RespiratoryBoundaryDetection task.
+
+Author: Anton Barchukov
+
+Tests cover:
+ - BIDMCDataset instantiation from wfdb files
+ - Metadata preparation
+ - Patient header parsing (age, sex, location)
+ - RespiratoryBoundaryDetection task (windowing, annotator, labels)
+ - Full pipeline: dataset -> task -> samples
+ - MedTsLLM model with the BIDMC data structure
+
+Test data under ``test-resources/core/bidmc/`` is fully synthetic —
+records generated by ``tests/core/_synthetic_wfdb.synthesize_bidmc``
+with placeholder respiratory/ECG signals. No real patient records
+are committed to the repo.
+"""
+
+import os
+import unittest
+
+import numpy as np
+import torch
+
# Repo-relative location of the synthetic BIDMC fixtures.
TEST_DATA_DIR = os.path.join(
    os.path.dirname(__file__), "..", "..", "test-resources", "core", "bidmc"
)
# Skip the whole module unless at least one wfdb ``.dat`` record exists.
HAS_TEST_DATA = os.path.isdir(TEST_DATA_DIR) and any(
    f.endswith(".dat") for f in os.listdir(TEST_DATA_DIR)
)
+
+
def setUpModule():
    """Drop the committed CSV + cache so tests use the current schema."""
    if not HAS_TEST_DATA:
        return
    _force_regenerate_bidmc_metadata()
+
+
@unittest.skipUnless(HAS_TEST_DATA, "BIDMC test data not found")
class TestBIDMCDataset(unittest.TestCase):
    """Tests for BIDMCDataset with real wfdb files."""

    @classmethod
    def setUpClass(cls):
        from pyhealth.datasets.bidmc import BIDMCDataset

        cls.dataset = BIDMCDataset(root=TEST_DATA_DIR, dev=True)

    def _first_event(self):
        # First event of the first patient — enough for attribute checks.
        first_pid = self.dataset.unique_patient_ids[0]
        return self.dataset.get_patient(first_pid).get_events()[0]

    def test_dataset_loads(self):
        """Dataset initializes without error."""
        self.assertIsNotNone(self.dataset)

    def test_metadata_csv_created(self):
        """prepare_metadata creates bidmc-pyhealth.csv."""
        self.assertTrue(
            os.path.exists(os.path.join(TEST_DATA_DIR, "bidmc-pyhealth.csv"))
        )

    def test_has_patients(self):
        """Dataset has at least 1 patient."""
        self.assertGreater(len(self.dataset.unique_patient_ids), 0)

    def test_patient_has_events(self):
        """Each patient has respiratory events."""
        first_pid = self.dataset.unique_patient_ids[0]
        all_events = self.dataset.get_patient(first_pid).get_events()
        self.assertGreater(len(all_events), 0)

    def test_event_has_demographics(self):
        """Events contain age and sex attributes."""
        event = self._first_event()
        for attr in ("age", "sex", "location"):
            self.assertTrue(hasattr(event, attr))
+
+
@unittest.skipUnless(HAS_TEST_DATA, "BIDMC test data not found")
class TestBIDMCHeaderParsing(unittest.TestCase):
    """Tests for BIDMC patient metadata parsing."""

    def test_parse_header(self):
        """Header parsing extracts age, sex, location."""
        import wfdb

        rec = wfdb.rdrecord(os.path.join(TEST_DATA_DIR, "bidmc01"))
        # Verify the wfdb header carries comment lines at all; the
        # dataset-level demographic parsing is covered elsewhere.
        # (Dropped the unused BIDMCDataset import the original carried.)
        self.assertIsNotNone(rec.comments)
        self.assertGreater(len(rec.comments), 0)
+
+
@unittest.skipUnless(HAS_TEST_DATA, "BIDMC test data not found")
class TestRespiratoryBoundaryDetectionTask(unittest.TestCase):
    """Tests for RespiratoryBoundaryDetection with real data."""

    @classmethod
    def setUpClass(cls):
        from pyhealth.datasets.bidmc import BIDMCDataset
        from pyhealth.tasks.respiratory_boundary_detection import (
            RespiratoryBoundaryDetection,
        )

        cls.dataset = BIDMCDataset(root=TEST_DATA_DIR, dev=True)
        cls.task = RespiratoryBoundaryDetection(
            window_size=256, step_size=128
        )

    def _task_samples(self):
        # Run the task over the first patient and return its samples.
        first_pid = self.dataset.unique_patient_ids[0]
        return self.task(self.dataset.get_patient(first_pid))

    def test_task_attributes(self):
        """Task has correct name and schema."""
        self.assertEqual(
            self.task.task_name, "RespiratoryBoundaryDetection"
        )
        self.assertEqual(self.task.input_schema, {"signal": "tensor"})
        self.assertEqual(self.task.output_schema, {"label": "tensor"})

    def test_task_produces_samples(self):
        """Task generates samples from a patient."""
        self.assertGreater(len(self._task_samples()), 0)

    def test_signal_is_3_channel(self):
        """Signal has 3 channels (RESP, PLETH, II)."""
        first_sample = self._task_samples()[0]
        self.assertEqual(first_sample["signal"].shape, (256, 3))

    def test_label_is_binary(self):
        """Labels are 0 or 1 (boundary or not)."""
        label_values = set(self._task_samples()[0]["label"].tolist())
        self.assertTrue(label_values.issubset({0.0, 1.0}))

    def test_sample_has_description(self):
        """Samples include patient description."""
        first_sample = self._task_samples()[0]
        self.assertIn("description", first_sample)
        self.assertIn("age:", first_sample["description"])

    def test_set_task_produces_sample_dataset(self):
        """dataset.set_task() returns a usable SampleDataset."""
        self.assertGreater(len(self.dataset.set_task(self.task)), 0)
+
+
@unittest.skipUnless(HAS_TEST_DATA, "BIDMC test data not found")
class TestMedTsLLMWithBIDMC(unittest.TestCase):
    """Test MedTsLLM model with real BIDMC data structure."""

    @classmethod
    def setUpClass(cls):
        from pyhealth.datasets.bidmc import BIDMCDataset
        from pyhealth.tasks.respiratory_boundary_detection import (
            RespiratoryBoundaryDetection,
        )
        from pyhealth.datasets import get_dataloader
        from pyhealth.models.medtsllm import MedTsLLM

        # Smaller window than the other task tests to keep the model tiny.
        dataset = BIDMCDataset(root=TEST_DATA_DIR, dev=True)
        task = RespiratoryBoundaryDetection(
            window_size=128, step_size=64
        )
        cls.sample_ds = dataset.set_task(task)

        # Minimal configuration for a fast CPU-only test.
        # NOTE(review): backbone=None with random word_embeddings
        # presumably puts MedTsLLM in a stub mode that avoids loading a
        # real LLM checkpoint — confirm against the model's constructor.
        cls.model = MedTsLLM(
            dataset=cls.sample_ds,
            seq_len=128,  # matches the task's window_size
            n_features=3,  # RESP, PLETH, II channels
            n_classes=2,  # boundary / non-boundary
            backbone=None,
            word_embeddings=torch.randn(50, 32),
            d_model=16,
            d_ff=32,
            n_heads=4,
            num_tokens=50,
            covariate_mode="concat",
        )
        loader = get_dataloader(cls.sample_ds, batch_size=2, shuffle=False)
        cls.batch = next(iter(loader))

    def test_forward(self):
        """Model forward pass works with real BIDMC samples."""
        out = self.model(**self.batch)
        self.assertIn("logit", out)
        self.assertIn("loss", out)
        self.assertTrue(out["loss"].isfinite())

    def test_backward(self):
        """Backward pass works with BIDMC data."""
        out = self.model(**self.batch)
        out["loss"].backward()
+
+
+# ------------------------------------------------------------------ #
+# Phase 5: paper-match split column (BIDMC 85/15 seed=0)
+# ------------------------------------------------------------------ #
+
+
def _force_regenerate_bidmc_metadata():
    """Reset BIDMC test state: delete CSV and PyHealth cache."""
    import shutil

    import platformdirs

    # Remove the generated metadata CSV, if present.
    metadata_csv = os.path.join(TEST_DATA_DIR, "bidmc-pyhealth.csv")
    if os.path.exists(metadata_csv):
        os.remove(metadata_csv)

    # Wipe PyHealth's on-disk cache directory as well.
    pyhealth_cache = platformdirs.user_cache_dir(appname="pyhealth")
    if os.path.isdir(pyhealth_cache):
        shutil.rmtree(pyhealth_cache)
+
+
@unittest.skipUnless(HAS_TEST_DATA, "BIDMC test data not found")
class TestBIDMCPaperSplit(unittest.TestCase):
    """paper_split=True writes an 85/15 seed=0 split column."""

    @classmethod
    def setUpClass(cls):
        from pyhealth.datasets.bidmc import BIDMCDataset

        _force_regenerate_bidmc_metadata()
        cls.dataset = BIDMCDataset(
            root=TEST_DATA_DIR, dev=True, paper_split=True
        )

    @staticmethod
    def _read_metadata():
        # Load the CSV that BIDMCDataset regenerated in setUpClass.
        import pandas as pd

        return pd.read_csv(os.path.join(TEST_DATA_DIR, "bidmc-pyhealth.csv"))

    def test_constructor_accepts_paper_split(self):
        """BIDMCDataset accepts a paper_split kwarg."""
        import inspect

        from pyhealth.datasets.bidmc import BIDMCDataset

        params = inspect.signature(BIDMCDataset.__init__).parameters
        self.assertIn("paper_split", params)

    def test_csv_has_split_column(self):
        """Regenerated CSV contains a ``split`` column."""
        self.assertIn("split", self._read_metadata().columns)

    def test_split_values_are_train_or_test(self):
        """Every split value is either 'train' or 'test'."""
        split_values = set(self._read_metadata()["split"].unique())
        self.assertTrue(split_values.issubset({"train", "test"}))

    def test_event_has_split_attr(self):
        """Events expose a ``split`` attribute."""
        first_pid = self.dataset.unique_patient_ids[0]
        event = self.dataset.get_patient(first_pid).get_events()[0]
        self.assertTrue(hasattr(event, "split"))
        self.assertIn(event.split, ("train", "test"))

    def test_task_emits_split_field(self):
        """RespiratoryBoundaryDetection propagates event.split to samples."""
        from pyhealth.tasks.respiratory_boundary_detection import (
            RespiratoryBoundaryDetection,
        )

        first_pid = self.dataset.unique_patient_ids[0]
        patient = self.dataset.get_patient(first_pid)
        task = RespiratoryBoundaryDetection(
            window_size=256, step_size=128
        )
        first_sample = task(patient)[0]
        self.assertIn("split", first_sample)
        self.assertIn(first_sample["split"], ("train", "test"))
+
+
@unittest.skipUnless(HAS_TEST_DATA, "BIDMC test data not found")
class TestBIDMCPaperSplitDeterministic(unittest.TestCase):
    """Seed=0 BIDMC split is reproducible across runs."""

    def test_split_is_deterministic(self):
        """Regenerating twice yields the same patient -> split mapping."""
        import pandas as pd
        from pyhealth.datasets.bidmc import BIDMCDataset

        csv_path = os.path.join(TEST_DATA_DIR, "bidmc-pyhealth.csv")

        def build_mapping():
            # Rebuild metadata from scratch, then read back the split.
            _force_regenerate_bidmc_metadata()
            BIDMCDataset(root=TEST_DATA_DIR, dev=True, paper_split=True)
            frame = pd.read_csv(csv_path)
            return frame.set_index("patient_id")["split"].to_dict()

        self.assertEqual(build_mapping(), build_mapping())
+
+
@unittest.skipUnless(HAS_TEST_DATA, "BIDMC test data not found")
class TestBIDMCPaperSplitDisabled(unittest.TestCase):
    """paper_split=False leaves the split column blank."""

    def test_split_column_blank_when_disabled(self):
        """Without paper_split, split column is empty."""
        import pandas as pd
        from pyhealth.datasets.bidmc import BIDMCDataset

        _force_regenerate_bidmc_metadata()
        BIDMCDataset(root=TEST_DATA_DIR, dev=True, paper_split=False)
        frame = pd.read_csv(os.path.join(TEST_DATA_DIR, "bidmc-pyhealth.csv"))
        self.assertIn("split", frame.columns)
        # Blank may surface as NaN or whitespace depending on the writer.
        normalized = frame["split"].fillna("").astype(str).str.strip()
        self.assertTrue((normalized == "").all())
+
+
+# ------------------------------------------------------------------ #
+# Phase 6: preprocess=True -> .npz cache for BIDMC
+# ------------------------------------------------------------------ #
+
+
@unittest.skipUnless(HAS_TEST_DATA, "BIDMC test data not found")
class TestBIDMCPreprocessCache(unittest.TestCase):
    """preprocess=True writes per-record .npz with signal + annotations."""

    @classmethod
    def setUpClass(cls):
        import shutil

        from pyhealth.datasets.bidmc import BIDMCDataset

        # Start from a clean slate: no stale cache dir, no stale CSV.
        processed_dir = os.path.join(TEST_DATA_DIR, "processed")
        if os.path.isdir(processed_dir):
            shutil.rmtree(processed_dir)
        _force_regenerate_bidmc_metadata()

        cls.dataset = BIDMCDataset(
            root=TEST_DATA_DIR, dev=True, preprocess=True
        )

    def _first_event(self):
        # First event of the first patient — carries processed_file.
        first_pid = self.dataset.unique_patient_ids[0]
        return self.dataset.get_patient(first_pid).get_events()[0]

    def test_processed_dir_created(self):
        self.assertTrue(
            os.path.isdir(os.path.join(TEST_DATA_DIR, "processed"))
        )

    def test_event_has_processed_file_attr(self):
        event = self._first_event()
        self.assertTrue(hasattr(event, "processed_file"))
        self.assertTrue(event.processed_file.endswith(".npz"))
        self.assertTrue(os.path.exists(event.processed_file))

    def test_cached_arrays_have_expected_keys(self):
        cache_path = self._first_event().processed_file
        with np.load(cache_path, allow_pickle=False) as npz:
            for key in ("signal", "ann_sample", "ann_aux"):
                self.assertIn(key, npz.files)

    def test_cached_signal_is_3_channel(self):
        cache_path = self._first_event().processed_file
        with np.load(cache_path, allow_pickle=False) as npz:
            self.assertEqual(int(npz["signal"].shape[1]), 3)

    def test_task_uses_cache(self):
        from pyhealth.tasks.respiratory_boundary_detection import (
            RespiratoryBoundaryDetection,
        )

        task = RespiratoryBoundaryDetection(window_size=256, step_size=128)
        first_pid = self.dataset.unique_patient_ids[0]
        samples = task(self.dataset.get_patient(first_pid))
        self.assertGreater(len(samples), 0)
        self.assertEqual(samples[0]["signal"].shape, (256, 3))
+
+
@unittest.skipUnless(HAS_TEST_DATA, "BIDMC test data not found")
class TestBIDMCPreprocessDisabled(unittest.TestCase):
    """preprocess=False leaves processed_file blank."""

    def test_processed_file_empty_when_disabled(self):
        import shutil

        from pyhealth.datasets.bidmc import BIDMCDataset

        processed_dir = os.path.join(TEST_DATA_DIR, "processed")
        if os.path.isdir(processed_dir):
            shutil.rmtree(processed_dir)
        _force_regenerate_bidmc_metadata()

        dataset = BIDMCDataset(root=TEST_DATA_DIR, dev=True, preprocess=False)
        first_pid = dataset.unique_patient_ids[0]
        first_event = dataset.get_patient(first_pid).get_events()[0]
        # Missing attribute, None, and "" all count as blank here.
        value = getattr(first_event, "processed_file", "") or ""
        self.assertEqual(str(value).strip(), "")
+
+
if __name__ == "__main__":
    # Allow running directly: python tests/core/test_bidmc.py
    unittest.main()
diff --git a/tests/core/test_ludb.py b/tests/core/test_ludb.py
new file mode 100644
index 000000000..ce30ecf2a
--- /dev/null
+++ b/tests/core/test_ludb.py
@@ -0,0 +1,628 @@
+"""Unit tests for LUDB dataset and ECGWaveSegmentation task.
+
+Author: Anton Barchukov
+
+Tests cover:
+ - LUDBDataset instantiation from wfdb files
+ - Metadata preparation (prepare_metadata)
+ - Patient header parsing (age, sex, diagnoses)
+ - ECGWaveSegmentation task (windowing, trimming, labels)
+ - Full pipeline: dataset -> task -> samples
+ - MedTsLLM model with the LUDB data structure
+
+Test data under ``test-resources/core/ludb/data/`` is fully synthetic
+— records generated by ``tests/core/_synthetic_wfdb.synthesize_ludb``
+with random ECG-shaped signals and placeholder demographics. No real
+patient records are committed to the repo.
+"""
+
+import os
+import unittest
+
+import numpy as np
+import torch
+
# Repo-relative location of the synthetic LUDB fixtures.
TEST_DATA_DIR = os.path.join(
    os.path.dirname(__file__), "..", "..", "test-resources", "core", "ludb"
)
# Skip the whole module unless the wfdb ``data/`` directory exists.
HAS_TEST_DATA = os.path.isdir(os.path.join(TEST_DATA_DIR, "data"))
+
+
def setUpModule():
    """Drop the committed CSV + cache so tests use the current schema."""
    if not HAS_TEST_DATA:
        return
    _force_regenerate_metadata()
+
+
@unittest.skipUnless(HAS_TEST_DATA, "LUDB test data not found")
class TestLUDBDataset(unittest.TestCase):
    """Tests for LUDBDataset with real wfdb files."""

    @classmethod
    def setUpClass(cls):
        from pyhealth.datasets.ludb import LUDBDataset

        cls.dataset = LUDBDataset(root=TEST_DATA_DIR, dev=True)

    def _first_event(self):
        # First event of the first patient — enough for attribute checks.
        first_pid = self.dataset.unique_patient_ids[0]
        return self.dataset.get_patient(first_pid).get_events()[0]

    def test_dataset_loads(self):
        """Dataset initializes without error."""
        self.assertIsNotNone(self.dataset)

    def test_metadata_csv_created(self):
        """prepare_metadata creates ludb-pyhealth.csv."""
        self.assertTrue(
            os.path.exists(os.path.join(TEST_DATA_DIR, "ludb-pyhealth.csv"))
        )

    def test_has_patients(self):
        """Dataset has at least 1 patient."""
        self.assertGreater(len(self.dataset.unique_patient_ids), 0)

    def test_patient_has_events(self):
        """Each patient has ECG events."""
        first_pid = self.dataset.unique_patient_ids[0]
        all_events = self.dataset.get_patient(first_pid).get_events()
        self.assertGreater(len(all_events), 0)

    def test_event_has_signal_file(self):
        """Events contain signal_file attribute."""
        self.assertTrue(hasattr(self._first_event(), "signal_file"))

    def test_event_has_lead(self):
        """Events contain lead attribute."""
        self.assertTrue(hasattr(self._first_event(), "lead"))
+
+
@unittest.skipUnless(HAS_TEST_DATA, "LUDB test data not found")
class TestLUDBHeaderParsing(unittest.TestCase):
    """Tests for LUDB patient metadata in wfdb headers."""

    @staticmethod
    def _record_comments():
        # Load record "1"; LUDB stores demographics and diagnoses as
        # tagged header comment lines.
        import wfdb

        data_dir = os.path.join(TEST_DATA_DIR, "data")
        return wfdb.rdrecord(os.path.join(data_dir, "1")).comments

    def test_header_has_comments(self):
        """LUDB wfdb records contain header comments."""
        comments = self._record_comments()
        self.assertIsNotNone(comments)
        self.assertGreater(len(comments), 0)

    def test_header_has_age_sex(self):
        """Header comments contain age and sex fields."""
        comments = "\n".join(self._record_comments())
        # Assert the actual LUDB tags — a bare ":" check is trivially
        # true and proves nothing about demographic parsing.
        self.assertIn("<age>:", comments)
        self.assertIn("<sex>:", comments)

    def test_header_has_diagnoses(self):
        """Header comments contain diagnoses section."""
        comments = "\n".join(self._record_comments())
        self.assertIn("<diagnoses>:", comments)
+
+
@unittest.skipUnless(HAS_TEST_DATA, "LUDB test data not found")
class TestECGWaveSegmentationTask(unittest.TestCase):
    """Tests for ECGWaveSegmentation with real LUDB data."""

    @classmethod
    def setUpClass(cls):
        from pyhealth.datasets.ludb import LUDBDataset
        from pyhealth.tasks.ecg_wave_segmentation import (
            ECGWaveSegmentation,
        )

        cls.dataset = LUDBDataset(root=TEST_DATA_DIR, dev=True)
        cls.task = ECGWaveSegmentation(window_size=512, step_size=256)

    def _task_samples(self):
        # Run the task over the first patient and return its samples.
        first_pid = self.dataset.unique_patient_ids[0]
        return self.task(self.dataset.get_patient(first_pid))

    def test_task_attributes(self):
        """Task has correct name and schema."""
        self.assertEqual(self.task.task_name, "ECGWaveSegmentation")
        self.assertEqual(self.task.input_schema, {"signal": "tensor"})
        self.assertEqual(self.task.output_schema, {"label": "tensor"})

    def test_task_produces_samples(self):
        """Task generates samples from a patient."""
        self.assertGreater(len(self._task_samples()), 0)

    def test_sample_has_required_keys(self):
        """Each sample has signal, label, patient_id."""
        first_sample = self._task_samples()[0]
        for key in ("signal", "label", "patient_id"):
            self.assertIn(key, first_sample)

    def test_signal_shape(self):
        """Signal has shape (window_size,)."""
        self.assertEqual(self._task_samples()[0]["signal"].shape, (512,))

    def test_label_shape(self):
        """Label has shape (window_size,) matching signal."""
        self.assertEqual(self._task_samples()[0]["label"].shape, (512,))

    def test_labels_are_valid_classes(self):
        """Labels are 0 (bg), 1 (P), 2 (QRS), or 3 (T)."""
        label_values = set(self._task_samples()[0]["label"].tolist())
        self.assertTrue(label_values.issubset({0, 1, 2, 3}))

    def test_samples_contain_waves(self):
        """Default (no preprocess) produces samples containing wave labels."""
        samples = self._task_samples()
        self.assertTrue(any((s["label"] > 0).any() for s in samples))

    def test_set_task_produces_sample_dataset(self):
        """dataset.set_task() returns a usable SampleDataset."""
        sample_ds = self.dataset.set_task(self.task)
        self.assertGreater(len(sample_ds), 0)
        first_sample = sample_ds[0]
        self.assertIn("signal", first_sample)
        self.assertIn("label", first_sample)
+
+
+# ------------------------------------------------------------------ #
+# Phase 2: patient-prompt plumbing (LUDB headers -> description)
+# ------------------------------------------------------------------ #
+
+
def _force_regenerate_metadata():
    """Reset LUDB test state: delete committed CSV and PyHealth cache.

    The repo ships a pre-built ``ludb-pyhealth.csv`` and PyHealth
    caches the parsed event dataframe under ``~/.cache/pyhealth/``.
    Both need to be cleared so tests exercise the current
    ``prepare_metadata`` implementation and event schema from the
    YAML config, not stale artifacts from a prior run.
    """
    import shutil

    import platformdirs

    # Remove the generated metadata CSV, if present.
    metadata_csv = os.path.join(TEST_DATA_DIR, "ludb-pyhealth.csv")
    if os.path.exists(metadata_csv):
        os.remove(metadata_csv)

    # Wipe PyHealth's on-disk cache directory as well.
    pyhealth_cache = platformdirs.user_cache_dir(appname="pyhealth")
    if os.path.isdir(pyhealth_cache):
        shutil.rmtree(pyhealth_cache)
+
+
@unittest.skipUnless(HAS_TEST_DATA, "LUDB test data not found")
class TestLUDBMetadataDemographics(unittest.TestCase):
    """Metadata CSV + events expose patient demographics."""

    @classmethod
    def setUpClass(cls):
        from pyhealth.datasets.ludb import LUDBDataset

        _force_regenerate_metadata()
        cls.dataset = LUDBDataset(root=TEST_DATA_DIR, dev=True)

    @staticmethod
    def _read_metadata():
        # Load the CSV that LUDBDataset regenerated in setUpClass.
        import pandas as pd

        return pd.read_csv(os.path.join(TEST_DATA_DIR, "ludb-pyhealth.csv"))

    def _first_event(self):
        # First event of the first patient — enough for attribute checks.
        first_pid = self.dataset.unique_patient_ids[0]
        return self.dataset.get_patient(first_pid).get_events()[0]

    def test_metadata_csv_has_age_column(self):
        """Regenerated metadata CSV includes an ``age`` column."""
        self.assertIn("age", self._read_metadata().columns)

    def test_metadata_csv_has_sex_column(self):
        """Regenerated metadata CSV includes a ``sex`` column."""
        self.assertIn("sex", self._read_metadata().columns)

    def test_metadata_csv_has_diagnoses_column(self):
        """Regenerated metadata CSV includes a ``diagnoses`` column."""
        self.assertIn("diagnoses", self._read_metadata().columns)

    def test_metadata_age_non_empty(self):
        """At least one row has a non-empty age."""
        ages = self._read_metadata()["age"].dropna().astype(str).str.strip()
        self.assertTrue((ages != "").any())

    def test_event_has_age_attr(self):
        """Events expose an ``age`` attribute."""
        self.assertTrue(hasattr(self._first_event(), "age"))

    def test_event_has_sex_attr(self):
        """Events expose a ``sex`` attribute."""
        self.assertTrue(hasattr(self._first_event(), "sex"))

    def test_event_has_diagnoses_attr(self):
        """Events expose a ``diagnoses`` attribute."""
        self.assertTrue(hasattr(self._first_event(), "diagnoses"))
+
+
+@unittest.skipUnless(HAS_TEST_DATA, "LUDB test data not found")
+class TestECGWaveSegmentationDescription(unittest.TestCase):
+ """ECGWaveSegmentation emits per-sample ``description``."""
+
+ @classmethod
+ def setUpClass(cls):
+ from pyhealth.datasets.ludb import LUDBDataset
+ from pyhealth.tasks.ecg_wave_segmentation import (
+ ECGWaveSegmentation,
+ )
+
+ _force_regenerate_metadata()
+ cls.dataset = LUDBDataset(root=TEST_DATA_DIR, dev=True)
+ cls.task = ECGWaveSegmentation(window_size=512, step_size=256)
+
+ def test_sample_has_description_key(self):
+ """Samples include a ``description`` field."""
+ pid = self.dataset.unique_patient_ids[0]
+ patient = self.dataset.get_patient(pid)
+ sample = self.task(patient)[0]
+ self.assertIn("description", sample)
+
+ def test_description_is_string(self):
+ """Description is a string."""
+ pid = self.dataset.unique_patient_ids[0]
+ patient = self.dataset.get_patient(pid)
+ sample = self.task(patient)[0]
+ self.assertIsInstance(sample["description"], str)
+
+ def test_description_populated(self):
+ """Description is non-empty for a patient with header data."""
+ pid = self.dataset.unique_patient_ids[0]
+ patient = self.dataset.get_patient(pid)
+ sample = self.task(patient)[0]
+ self.assertGreater(len(sample["description"]), 0)
+
+ def test_description_contains_age(self):
+ """Description mentions age for patient 1."""
+ pid = self.dataset.unique_patient_ids[0]
+ patient = self.dataset.get_patient(pid)
+ sample = self.task(patient)[0]
+ self.assertIn("age", sample["description"].lower())
+
+
+@unittest.skipUnless(HAS_TEST_DATA, "LUDB test data not found")
+class TestMedTsLLMWithLUDB(unittest.TestCase):
+ """Test MedTsLLM model with real LUDB data structure."""
+
+ @classmethod
+ def setUpClass(cls):
+ from pyhealth.datasets.ludb import LUDBDataset
+ from pyhealth.tasks.ecg_wave_segmentation import (
+ ECGWaveSegmentation,
+ )
+ from pyhealth.datasets import get_dataloader
+ from pyhealth.models.medtsllm import MedTsLLM
+
+ dataset = LUDBDataset(root=TEST_DATA_DIR, dev=True)
+ task = ECGWaveSegmentation(window_size=128, step_size=64)
+ cls.sample_ds = dataset.set_task(task)
+
+ cls.model = MedTsLLM(
+ dataset=cls.sample_ds,
+ seq_len=128,
+ n_features=1,
+ n_classes=4,
+ backbone=None,
+ word_embeddings=torch.randn(50, 32),
+ d_model=16,
+ d_ff=32,
+ n_heads=4,
+ num_tokens=50,
+ )
+ loader = get_dataloader(cls.sample_ds, batch_size=2, shuffle=False)
+ cls.batch = next(iter(loader))
+
+ def test_forward(self):
+ """Model forward pass works with real LUDB samples."""
+ out = self.model(**self.batch)
+ self.assertIn("logit", out)
+ self.assertIn("loss", out)
+ self.assertTrue(out["loss"].isfinite())
+
+ def test_output_shape(self):
+ """Output has correct shape for 4-class segmentation."""
+ out = self.model(**self.batch)
+ bs = self.batch["label"].shape[0]
+ self.assertEqual(out["logit"].shape, (bs, 128, 4))
+
+
+# ------------------------------------------------------------------ #
+# Phase 5: paper-match split column (LUDB 80/20 seed=0)
+# ------------------------------------------------------------------ #
+
+
+@unittest.skipUnless(HAS_TEST_DATA, "LUDB test data not found")
+class TestLUDBPaperSplit(unittest.TestCase):
+ """paper_split=True writes an 80/20 seed=0 split column."""
+
+ @classmethod
+ def setUpClass(cls):
+ from pyhealth.datasets.ludb import LUDBDataset
+
+ _force_regenerate_metadata()
+ cls.dataset = LUDBDataset(
+ root=TEST_DATA_DIR, dev=True, paper_split=True
+ )
+
+ def test_constructor_accepts_paper_split(self):
+ """LUDBDataset accepts a paper_split kwarg."""
+ import inspect
+
+ from pyhealth.datasets.ludb import LUDBDataset
+
+ sig = inspect.signature(LUDBDataset.__init__)
+ self.assertIn("paper_split", sig.parameters)
+
+ def test_csv_has_split_column(self):
+ """Regenerated CSV contains a ``split`` column."""
+ import pandas as pd
+
+ csv_path = os.path.join(TEST_DATA_DIR, "ludb-pyhealth.csv")
+ df = pd.read_csv(csv_path)
+ self.assertIn("split", df.columns)
+
+ def test_split_values_are_train_or_test(self):
+ """Every split value is either 'train' or 'test'."""
+ import pandas as pd
+
+ csv_path = os.path.join(TEST_DATA_DIR, "ludb-pyhealth.csv")
+ df = pd.read_csv(csv_path)
+ unique = set(df["split"].unique())
+ self.assertTrue(unique.issubset({"train", "test"}))
+
+ def test_split_consistent_per_patient(self):
+ """Every lead of a patient shares one split assignment."""
+ import pandas as pd
+
+ csv_path = os.path.join(TEST_DATA_DIR, "ludb-pyhealth.csv")
+ df = pd.read_csv(csv_path)
+ for _, group in df.groupby("patient_id"):
+ self.assertEqual(group["split"].nunique(), 1)
+
+ def test_event_has_split_attr(self):
+ """Events expose a ``split`` attribute."""
+ pid = self.dataset.unique_patient_ids[0]
+ patient = self.dataset.get_patient(pid)
+ event = patient.get_events()[0]
+ self.assertTrue(hasattr(event, "split"))
+ self.assertIn(event.split, ("train", "test"))
+
+ def test_task_emits_split_field(self):
+ """ECGWaveSegmentation propagates event.split to each sample."""
+ from pyhealth.tasks.ecg_wave_segmentation import (
+ ECGWaveSegmentation,
+ )
+
+ pid = self.dataset.unique_patient_ids[0]
+ patient = self.dataset.get_patient(pid)
+ task = ECGWaveSegmentation(window_size=512, step_size=256)
+ sample = task(patient)[0]
+ self.assertIn("split", sample)
+ self.assertIn(sample["split"], ("train", "test"))
+
+
+@unittest.skipUnless(HAS_TEST_DATA, "LUDB test data not found")
+class TestLUDBPaperSplitDeterministic(unittest.TestCase):
+ """Seed=0 split is reproducible across runs."""
+
+ def test_split_is_deterministic(self):
+ """Regenerating twice yields the same patient -> split mapping."""
+ import pandas as pd
+ from pyhealth.datasets.ludb import LUDBDataset
+
+ _force_regenerate_metadata()
+ LUDBDataset(root=TEST_DATA_DIR, dev=True, paper_split=True)
+ df1 = pd.read_csv(os.path.join(TEST_DATA_DIR, "ludb-pyhealth.csv"))
+ mapping1 = df1.groupby("patient_id")["split"].first().to_dict()
+
+ _force_regenerate_metadata()
+ LUDBDataset(root=TEST_DATA_DIR, dev=True, paper_split=True)
+ df2 = pd.read_csv(os.path.join(TEST_DATA_DIR, "ludb-pyhealth.csv"))
+ mapping2 = df2.groupby("patient_id")["split"].first().to_dict()
+
+ self.assertEqual(mapping1, mapping2)
+
+
+@unittest.skipUnless(HAS_TEST_DATA, "LUDB test data not found")
+class TestLUDBPaperSplitDisabled(unittest.TestCase):
+ """paper_split=False does not assign train/test."""
+
+ def test_split_column_blank_when_disabled(self):
+ """Without paper_split, split column is empty."""
+ import pandas as pd
+ from pyhealth.datasets.ludb import LUDBDataset
+
+ _force_regenerate_metadata()
+ LUDBDataset(root=TEST_DATA_DIR, dev=True, paper_split=False)
+ df = pd.read_csv(os.path.join(TEST_DATA_DIR, "ludb-pyhealth.csv"))
+ self.assertIn("split", df.columns)
+ cleaned = df["split"].fillna("").astype(str).str.strip()
+ self.assertTrue((cleaned == "").all())
+
+
+# ------------------------------------------------------------------ #
+# Phase 6: preprocess=True -> .npz cache for LUDB
+# ------------------------------------------------------------------ #
+
+
+@unittest.skipUnless(HAS_TEST_DATA, "LUDB test data not found")
+class TestLUDBPreprocessCache(unittest.TestCase):
+ """preprocess=True writes .npz per (patient, lead) and wires events."""
+
+ @classmethod
+ def setUpClass(cls):
+ import shutil
+
+ from pyhealth.datasets.ludb import LUDBDataset
+
+ # Clear any prior run's cache so paths/fingerprints are fresh.
+ processed_dir = os.path.join(TEST_DATA_DIR, "processed")
+ if os.path.isdir(processed_dir):
+ shutil.rmtree(processed_dir)
+ _force_regenerate_metadata()
+
+ cls.dataset = LUDBDataset(
+ root=TEST_DATA_DIR, dev=True, preprocess=True
+ )
+
+ def test_processed_dir_created(self):
+ """preprocess=True creates {root}/processed/."""
+ self.assertTrue(
+ os.path.isdir(os.path.join(TEST_DATA_DIR, "processed"))
+ )
+
+ def test_npz_written_per_lead(self):
+ """Every (record, lead) pair has a corresponding .npz."""
+ processed_dir = os.path.join(TEST_DATA_DIR, "processed")
+ files = os.listdir(processed_dir)
+ npz_files = [f for f in files if f.endswith(".npz")]
+ # 2 records * 12 leads = 24 expected
+ self.assertGreaterEqual(len(npz_files), 20)
+
+ def test_event_has_processed_file_attr(self):
+ """Events expose the cache path via ``processed_file``."""
+ pid = self.dataset.unique_patient_ids[0]
+ patient = self.dataset.get_patient(pid)
+ event = patient.get_events()[0]
+ self.assertTrue(hasattr(event, "processed_file"))
+ self.assertTrue(event.processed_file.endswith(".npz"))
+ self.assertTrue(os.path.exists(event.processed_file))
+
+ def test_cached_arrays_have_expected_keys(self):
+ """Each .npz stores signal + labels arrays."""
+ pid = self.dataset.unique_patient_ids[0]
+ patient = self.dataset.get_patient(pid)
+ event = patient.get_events()[0]
+ with np.load(event.processed_file, allow_pickle=False) as npz:
+ self.assertIn("signal", npz.files)
+ self.assertIn("labels", npz.files)
+ self.assertEqual(npz["signal"].shape, npz["labels"].shape)
+
+ def test_cached_labels_contain_waves(self):
+ """Cached labels include non-background classes after trim."""
+ pid = self.dataset.unique_patient_ids[0]
+ patient = self.dataset.get_patient(pid)
+ event = patient.get_events()[0]
+ with np.load(event.processed_file, allow_pickle=False) as npz:
+ self.assertTrue((npz["labels"] > 0).any())
+
+ def test_task_loads_from_cache(self):
+ """ECGWaveSegmentation returns samples using cached arrays."""
+ from pyhealth.tasks.ecg_wave_segmentation import (
+ ECGWaveSegmentation,
+ )
+
+ pid = self.dataset.unique_patient_ids[0]
+ patient = self.dataset.get_patient(pid)
+ task = ECGWaveSegmentation(window_size=128, step_size=64)
+ samples = task(patient)
+ self.assertGreater(len(samples), 0)
+ self.assertEqual(samples[0]["signal"].shape, (128,))
+
+
+@unittest.skipUnless(HAS_TEST_DATA, "LUDB test data not found")
+class TestLUDBPreprocessDisabled(unittest.TestCase):
+ """preprocess=False leaves processed_file blank."""
+
+ def test_processed_file_empty_when_disabled(self):
+ import shutil
+
+ from pyhealth.datasets.ludb import LUDBDataset
+
+ processed_dir = os.path.join(TEST_DATA_DIR, "processed")
+ if os.path.isdir(processed_dir):
+ shutil.rmtree(processed_dir)
+ _force_regenerate_metadata()
+
+ ds = LUDBDataset(root=TEST_DATA_DIR, dev=True, preprocess=False)
+ pid = ds.unique_patient_ids[0]
+ event = ds.get_patient(pid).get_events()[0]
+ value = getattr(event, "processed_file", "") or ""
+ self.assertEqual(str(value).strip(), "")
+
+
+@unittest.skipUnless(HAS_TEST_DATA, "LUDB test data not found")
+class TestLUDBPreprocessTrimOff(unittest.TestCase):
+ """preprocess=True, trim=False caches the full (untrimmed) lead."""
+
+ def test_untrimmed_cache_is_longer(self):
+ import shutil
+
+ from pyhealth.datasets.ludb import LUDBDataset
+
+ processed_dir = os.path.join(TEST_DATA_DIR, "processed")
+
+ # Build trimmed cache first.
+ if os.path.isdir(processed_dir):
+ shutil.rmtree(processed_dir)
+ _force_regenerate_metadata()
+ ds_trim = LUDBDataset(
+ root=TEST_DATA_DIR, dev=True, preprocess=True, trim=True
+ )
+ pid_trim = ds_trim.unique_patient_ids[0]
+ ev_trim = ds_trim.get_patient(pid_trim).get_events()[0]
+ with np.load(ev_trim.processed_file, allow_pickle=False) as npz:
+ trimmed_len = int(npz["signal"].shape[0])
+
+ # Rebuild with trim=False.
+ if os.path.isdir(processed_dir):
+ shutil.rmtree(processed_dir)
+ _force_regenerate_metadata()
+ ds_full = LUDBDataset(
+ root=TEST_DATA_DIR, dev=True, preprocess=True, trim=False
+ )
+ pid_full = ds_full.unique_patient_ids[0]
+ ev_full = ds_full.get_patient(pid_full).get_events()[0]
+ with np.load(ev_full.processed_file, allow_pickle=False) as npz:
+ full_len = int(npz["signal"].shape[0])
+
+ self.assertGreaterEqual(full_len, trimmed_len)
+
+
+# Allow running this test module directly (e.g. `python <module>`).
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/core/test_medtsllm.py b/tests/core/test_medtsllm.py
new file mode 100644
index 000000000..1a59d4b28
--- /dev/null
+++ b/tests/core/test_medtsllm.py
@@ -0,0 +1,1358 @@
+"""Unit tests for MedTsLLM model and its preprocessing cache helper.
+
+Author: Anton Barchukov
+
+Tests cover:
+ - Model instantiation with synthetic data (no LLM download)
+ - Forward pass output keys and shapes
+ - Backward pass (gradient flow)
+ - Loss computation when labels provided
+ - Parameter filtering (frozen params excluded from optimizer)
+ - State dict save/load with frozen params
+ - Internal layers (RevIN, PatchEmbedding, ReprogrammingLayer)
+ - Different sequence lengths, feature counts, covariate modes
+ - Edge cases: no labels, invalid inputs
+ - Compatibility with PyHealth pipeline
+ - ``_medtsllm_cache`` fingerprint + load_or_build helpers used by
+ LUDB / MIT-BIH / BIDMC preprocessing
+"""
+
+import os
+import tempfile
+import unittest
+
+import torch
+import numpy as np
+
+
+def _make_dataset(
+ n_samples: int = 4,
+ seq_len: int = 128,
+ n_classes: int = 4,
+):
+ """Create a minimal SampleDataset with signal + label."""
+ from pyhealth.datasets import create_sample_dataset
+
+ samples = []
+ for i in range(n_samples):
+ signal = np.random.randn(seq_len).astype(np.float32)
+ label = np.random.randint(0, n_classes, size=seq_len).astype(
+ np.int64
+ )
+ samples.append({
+ "patient_id": f"p{i}",
+ "visit_id": "v0",
+ "signal": signal,
+ "label": label,
+ })
+
+ return create_sample_dataset(
+ samples=samples,
+ input_schema={"signal": "tensor"},
+ output_schema={"label": "tensor"},
+ dataset_name="test_medtsllm",
+ )
+
+
+def _make_model(dataset, **kwargs):
+ """Create a MedTsLLM with synthetic word embeddings (no LLM)."""
+ from pyhealth.models.medtsllm import MedTsLLM
+
+ defaults = dict(
+ dataset=dataset,
+ seq_len=128,
+ n_features=1,
+ n_classes=4,
+ backbone=None,
+ word_embeddings=torch.randn(100, 64),
+ d_model=16,
+ d_ff=32,
+ n_heads=4,
+ num_tokens=50,
+ patch_len=16,
+ stride=8,
+ dropout=0.1,
+ )
+ defaults.update(kwargs)
+ return MedTsLLM(**defaults)
+
+
+def _make_batch(dataset, batch_size=2):
+ """Get a single batch from the dataset."""
+ from pyhealth.datasets import get_dataloader
+
+ loader = get_dataloader(dataset, batch_size=batch_size, shuffle=False)
+ return next(iter(loader))
+
+
+# ------------------------------------------------------------------ #
+# Forward pass tests
+# ------------------------------------------------------------------ #
+
+
+class TestMedTsLLMForward(unittest.TestCase):
+ """Tests for MedTsLLM forward pass."""
+
+ @classmethod
+ def setUpClass(cls):
+ cls.dataset = _make_dataset()
+ cls.model = _make_model(cls.dataset)
+ cls.batch = _make_batch(cls.dataset)
+
+ def test_forward_keys(self):
+ """Output dict has expected keys when labels provided."""
+ out = self.model(**self.batch)
+ self.assertIn("logit", out)
+ self.assertIn("y_prob", out)
+ self.assertIn("loss", out)
+ self.assertIn("y_true", out)
+
+ def test_logit_shape(self):
+ """Logit has shape (batch, seq_len, n_classes)."""
+ out = self.model(**self.batch)
+ bs = self.batch["label"].shape[0]
+ self.assertEqual(out["logit"].shape, (bs, 128, 4))
+
+ def test_y_prob_shape(self):
+ """y_prob has same shape as logit."""
+ out = self.model(**self.batch)
+ self.assertEqual(out["y_prob"].shape, out["logit"].shape)
+
+ def test_y_prob_sums_to_one(self):
+ """Softmax probabilities sum to ~1 along class dim."""
+ out = self.model(**self.batch)
+ sums = out["y_prob"].sum(dim=-1)
+ self.assertTrue(
+ torch.allclose(sums, torch.ones_like(sums), atol=1e-5)
+ )
+
+ def test_y_prob_non_negative(self):
+ """All probabilities are non-negative."""
+ out = self.model(**self.batch)
+ self.assertTrue((out["y_prob"] >= 0).all())
+
+ def test_loss_is_scalar(self):
+ """Loss is a scalar tensor."""
+ out = self.model(**self.batch)
+ self.assertEqual(out["loss"].shape, ())
+
+ def test_loss_is_finite(self):
+ """Loss is not NaN or Inf."""
+ out = self.model(**self.batch)
+ self.assertTrue(out["loss"].isfinite())
+
+ def test_2d_signal_input(self):
+ """Forward works with 2D signal (batch, seq_len) — auto-unsqueeze."""
+ signal = torch.randn(2, 128)
+ label = torch.randint(0, 4, (2, 128))
+ out = self.model(signal=signal, label=label)
+ self.assertEqual(out["logit"].shape, (2, 128, 4))
+
+ def test_3d_signal_input(self):
+ """Forward works with 3D signal (batch, seq_len, features)."""
+ signal = torch.randn(2, 128, 1)
+ label = torch.randint(0, 4, (2, 128))
+ out = self.model(signal=signal, label=label)
+ self.assertEqual(out["logit"].shape, (2, 128, 4))
+
+
+# ------------------------------------------------------------------ #
+# Backward pass tests
+# ------------------------------------------------------------------ #
+
+
+class TestMedTsLLMBackward(unittest.TestCase):
+ """Tests for MedTsLLM backward pass."""
+
+ @classmethod
+ def setUpClass(cls):
+ cls.dataset = _make_dataset()
+ cls.model = _make_model(cls.dataset)
+ cls.batch = _make_batch(cls.dataset)
+
+ def test_backward(self):
+ """Loss backward pass succeeds."""
+ out = self.model(**self.batch)
+ out["loss"].backward()
+
+ def test_gradients_flow(self):
+ """Trainable parameters receive gradients."""
+ self.model.zero_grad()
+ out = self.model(**self.batch)
+ out["loss"].backward()
+
+ has_grad = False
+ for name, param in self.model.named_parameters():
+ if param.requires_grad and param.grad is not None:
+ if param.grad.abs().sum() > 0:
+ has_grad = True
+ break
+ self.assertTrue(has_grad, "No gradients found in trainable params")
+
+ def test_word_embeddings_frozen(self):
+ """Word embeddings should not receive gradients."""
+ self.assertFalse(self.model.word_embeddings.requires_grad)
+
+ def test_parameters_only_trainable(self):
+ """parameters() yields only trainable params (no frozen)."""
+ for p in self.model.parameters():
+ self.assertTrue(
+ p.requires_grad,
+ "parameters() yielded a frozen parameter",
+ )
+
+ def test_named_parameters_only_trainable(self):
+ """named_parameters() excludes frozen params."""
+ for name, p in self.model.named_parameters():
+ self.assertTrue(
+ p.requires_grad,
+ f"named_parameters() yielded frozen param: {name}",
+ )
+
+ def test_frozen_params_excluded_from_count(self):
+ """Frozen word_embeddings not in parameters() count."""
+ param_names = {n for n, _ in self.model.named_parameters()}
+ self.assertNotIn("word_embeddings", param_names)
+
+
+# ------------------------------------------------------------------ #
+# State dict tests
+# ------------------------------------------------------------------ #
+
+
+class TestMedTsLLMStateDict(unittest.TestCase):
+ """Tests for checkpoint save/load with frozen params."""
+
+ @classmethod
+ def setUpClass(cls):
+ cls.dataset = _make_dataset()
+ cls.model = _make_model(cls.dataset)
+
+ def test_state_dict_includes_all_params(self):
+ """state_dict() includes frozen params (needed for full reload)."""
+ sd = self.model.state_dict()
+ self.assertIn("word_embeddings", sd)
+ self.assertIn("patch_embedding.value_embedding.conv.weight", sd)
+
+ def test_load_state_dict_roundtrip(self):
+ """Save and reload state dict without errors."""
+ sd = self.model.state_dict()
+ model2 = _make_model(self.dataset)
+ model2.load_state_dict(sd)
+
+ # Verify weights match
+ for (n1, p1), (n2, p2) in zip(
+ self.model.state_dict().items(),
+ model2.state_dict().items(),
+ ):
+ self.assertEqual(n1, n2)
+ self.assertTrue(torch.equal(p1, p2), f"Mismatch in {n1}")
+
+
+# ------------------------------------------------------------------ #
+# No labels tests
+# ------------------------------------------------------------------ #
+
+
+class TestMedTsLLMNoLabels(unittest.TestCase):
+ """Tests for forward pass without labels."""
+
+ @classmethod
+ def setUpClass(cls):
+ cls.dataset = _make_dataset()
+ cls.model = _make_model(cls.dataset)
+
+ def test_no_loss_without_labels(self):
+ """No loss key when labels not in kwargs."""
+ signal = torch.randn(2, 128)
+ out = self.model(signal=signal)
+ self.assertNotIn("loss", out)
+ self.assertNotIn("y_true", out)
+ self.assertIn("logit", out)
+ self.assertIn("y_prob", out)
+
+
+# ------------------------------------------------------------------ #
+# Internal layers tests
+# ------------------------------------------------------------------ #
+
+
+class TestRevIN(unittest.TestCase):
+ """Tests for RevIN normalization layer."""
+
+ def test_normalize_denormalize_roundtrip(self):
+ """RevIN(x, 'norm') -> RevIN(result, 'denorm') ≈ x."""
+ from pyhealth.models._medtsllm.layers import RevIN
+
+ revin = RevIN(num_features=3)
+ x = torch.randn(2, 64, 3) * 10 + 5 # non-zero mean, large std
+ normed = revin(x, "norm")
+ recovered = revin(normed, "denorm")
+ self.assertTrue(
+ torch.allclose(x, recovered, atol=1e-5),
+ "RevIN roundtrip failed",
+ )
+
+ def test_normalize_zero_mean_unit_var(self):
+ """After normalization, each feature has ~0 mean and ~1 std."""
+ from pyhealth.models._medtsllm.layers import RevIN
+
+ revin = RevIN(num_features=2)
+ x = torch.randn(4, 100, 2) * 5 + 3
+ normed = revin(x, "norm")
+ # Check per-sample, per-feature
+ means = normed.mean(dim=1)
+ stds = normed.std(dim=1, unbiased=False)
+ self.assertTrue(
+ torch.allclose(means, torch.zeros_like(means), atol=1e-4)
+ )
+ self.assertTrue(
+ torch.allclose(stds, torch.ones_like(stds), atol=0.1)
+ )
+
+ def test_denorm_before_norm_raises(self):
+ """Calling denorm before norm raises RuntimeError."""
+ from pyhealth.models._medtsllm.layers import RevIN
+
+ revin = RevIN(num_features=1)
+ x = torch.randn(1, 10, 1)
+ with self.assertRaises(RuntimeError):
+ revin(x, "denorm")
+
+ def test_invalid_mode_raises(self):
+ """Invalid mode raises ValueError."""
+ from pyhealth.models._medtsllm.layers import RevIN
+
+ revin = RevIN(num_features=1)
+ x = torch.randn(1, 10, 1)
+ with self.assertRaises(ValueError):
+ revin(x, "invalid")
+
+
+class TestPatchEmbedding(unittest.TestCase):
+ """Tests for PatchEmbedding layer."""
+
+ def test_output_shape(self):
+ """Output has correct shape (batch*features, n_patches, d_model)."""
+ from pyhealth.models._medtsllm.layers import PatchEmbedding
+
+ pe = PatchEmbedding(d_model=32, patch_len=16, stride=8)
+ x = torch.randn(2, 1, 128) # (batch, features, seq_len)
+ out, n_vars = pe(x)
+ self.assertEqual(n_vars, 1)
+ # n_patches = (128 - 16) // 8 + 2 = 16
+ expected_patches = (128 - 16) // 8 + 2
+ self.assertEqual(out.shape, (2, expected_patches, 32))
+
+ def test_multivariate_output_shape(self):
+ """Multivariate input merges batch and feature dims."""
+ from pyhealth.models._medtsllm.layers import PatchEmbedding
+
+ pe = PatchEmbedding(d_model=16, patch_len=8, stride=4)
+ x = torch.randn(3, 5, 64) # 3 batches, 5 features
+ out, n_vars = pe(x)
+ self.assertEqual(n_vars, 5)
+ self.assertEqual(out.shape[0], 15) # 3 * 5
+ self.assertEqual(out.shape[2], 16) # d_model
+
+ def test_gradient_flow(self):
+ """Gradients flow through patch embedding."""
+ from pyhealth.models._medtsllm.layers import PatchEmbedding
+
+ pe = PatchEmbedding(d_model=16, patch_len=8, stride=4)
+ x = torch.randn(2, 1, 32, requires_grad=True)
+ out, _ = pe(x)
+ out.sum().backward()
+ self.assertIsNotNone(x.grad)
+
+
+class TestReprogrammingLayer(unittest.TestCase):
+ """Tests for ReprogrammingLayer (cross-attention)."""
+
+ def test_output_shape(self):
+ """Output shape matches (batch, n_patches, d_llm)."""
+ from pyhealth.models._medtsllm.layers import ReprogrammingLayer
+
+ layer = ReprogrammingLayer(
+ d_model=32, n_heads=4, d_keys=16, d_llm=64
+ )
+ target = torch.randn(2, 10, 32) # patches
+ source = torch.randn(50, 64) # word prototypes
+ out = layer(target, source, source)
+ self.assertEqual(out.shape, (2, 10, 64))
+
+ def test_gradient_flow(self):
+ """Gradients flow through reprogramming layer."""
+ from pyhealth.models._medtsllm.layers import ReprogrammingLayer
+
+ layer = ReprogrammingLayer(
+ d_model=16, n_heads=2, d_keys=8, d_llm=32
+ )
+ target = torch.randn(2, 5, 16, requires_grad=True)
+ source = torch.randn(20, 32)
+ out = layer(target, source, source)
+ out.sum().backward()
+ self.assertIsNotNone(target.grad)
+
+
+class TestFlattenHead(unittest.TestCase):
+ """Tests for FlattenHead output projection."""
+
+ def test_output_shape(self):
+ """Flattens and projects to correct output size."""
+ from pyhealth.models._medtsllm.layers import FlattenHead
+
+ head = FlattenHead(n_features_in=64, n_outputs=128)
+ x = torch.randn(2, 8, 8) # (batch, d_ff, n_patches)
+ out = head(x)
+ self.assertEqual(out.shape, (2, 128))
+
+
+# ------------------------------------------------------------------ #
+# Configuration tests
+# ------------------------------------------------------------------ #
+
+
+class TestMedTsLLMConfigs(unittest.TestCase):
+ """Tests for different model configurations."""
+
+ @classmethod
+ def setUpClass(cls):
+ cls.dataset = _make_dataset(seq_len=64, n_classes=2)
+
+ def test_different_seq_len(self):
+ """Model works with non-default sequence length."""
+ model = _make_model(
+ self.dataset,
+ seq_len=64,
+ n_classes=2,
+ word_embeddings=torch.randn(50, 32),
+ )
+ signal = torch.randn(2, 64)
+ label = torch.randint(0, 2, (2, 64))
+ out = model(signal=signal, label=label)
+ self.assertEqual(out["logit"].shape, (2, 64, 2))
+ self.assertTrue(out["loss"].isfinite())
+
+ def test_multivariate_concat(self):
+ """Model works with concat covariate mode."""
+ dataset = _make_dataset(seq_len=64, n_classes=2)
+ model = _make_model(
+ dataset,
+ seq_len=64,
+ n_features=3,
+ n_classes=2,
+ covariate_mode="concat",
+ word_embeddings=torch.randn(50, 32),
+ )
+ signal = torch.randn(2, 64, 3)
+ label = torch.randint(0, 2, (2, 64))
+ out = model(signal=signal, label=label)
+ self.assertEqual(out["logit"].shape, (2, 64, 2))
+ self.assertTrue(out["loss"].isfinite())
+ out["loss"].backward()
+
+ def test_univariate_mode(self):
+ """Model works with univariate covariate mode (default)."""
+ dataset = _make_dataset(seq_len=64, n_classes=3)
+ model = _make_model(
+ dataset,
+ seq_len=64,
+ n_features=1,
+ n_classes=3,
+ covariate_mode="univariate",
+ word_embeddings=torch.randn(50, 32),
+ )
+ signal = torch.randn(2, 64)
+ label = torch.randint(0, 3, (2, 64))
+ out = model(signal=signal, label=label)
+ self.assertEqual(out["logit"].shape, (2, 64, 3))
+
+ def test_small_d_model(self):
+ """Model works with minimal dimensions."""
+ dataset = _make_dataset(seq_len=32, n_classes=2)
+ model = _make_model(
+ dataset,
+ seq_len=32,
+ n_classes=2,
+ d_model=8,
+ d_ff=16,
+ n_heads=2,
+ num_tokens=10,
+ word_embeddings=torch.randn(20, 16),
+ )
+ signal = torch.randn(1, 32)
+ label = torch.randint(0, 2, (1, 32))
+ out = model(signal=signal, label=label)
+ self.assertTrue(out["loss"].isfinite())
+
+
+# ------------------------------------------------------------------ #
+# LinearProjection layer tests
+# ------------------------------------------------------------------ #
+
+
+class TestLinearProjection(unittest.TestCase):
+ """Tests for LinearProjection ablation layer."""
+
+ def test_output_shape(self):
+ """Output shape matches (batch, n_patches, d_llm)."""
+ from pyhealth.models._medtsllm.layers import LinearProjection
+
+ layer = LinearProjection(d_model=32, d_llm=64)
+ target = torch.randn(2, 10, 32)
+ source = torch.randn(50, 64) # ignored
+ out = layer(target, source, source)
+ self.assertEqual(out.shape, (2, 10, 64))
+
+ def test_ignores_source(self):
+ """Output is the same regardless of source embeddings."""
+ from pyhealth.models._medtsllm.layers import LinearProjection
+
+ layer = LinearProjection(d_model=16, d_llm=32)
+ target = torch.randn(2, 5, 16)
+ source_a = torch.randn(20, 32)
+ source_b = torch.randn(100, 32)
+ out_a = layer(target, source_a, source_a)
+ out_b = layer(target, source_b, source_b)
+ self.assertTrue(torch.equal(out_a, out_b))
+
+ def test_gradient_flow(self):
+ """Gradients flow through linear projection."""
+ from pyhealth.models._medtsllm.layers import LinearProjection
+
+ layer = LinearProjection(d_model=16, d_llm=32)
+ target = torch.randn(2, 5, 16, requires_grad=True)
+ source = torch.randn(20, 32)
+ out = layer(target, source, source)
+ out.sum().backward()
+ self.assertIsNotNone(target.grad)
+
+
+# ------------------------------------------------------------------ #
+# Prompt builder tests
+# ------------------------------------------------------------------ #
+
+
+class TestPromptBuilder(unittest.TestCase):
+    """Tests for build_prompt function.
+
+    ``build_prompt`` is imported function-locally in every test —
+    presumably so this module can still be collected when the optional
+    ``_medtsllm`` subpackage is unavailable; confirm before hoisting to
+    a module-level import.
+    """
+
+    def test_dataset_task_prompt(self):
+        """Builds prompt with dataset and task descriptions."""
+        from pyhealth.models._medtsllm.prompt import build_prompt
+
+        inputs = {"x_enc": torch.randn(2, 128, 1)}
+        prompts = build_prompt(
+            inputs,
+            dataset_description="Test dataset",
+            task_description="Test task",
+            include_dataset=True,
+            include_task=True,
+        )
+        # One prompt per batch element; each prompt is an iterable of
+        # string fragments, so join before substring checks.
+        self.assertEqual(len(prompts), 2)
+        flat = " ".join(prompts[0])
+        self.assertIn("Test dataset", flat)
+        self.assertIn("Test task", flat)
+
+    def test_no_prompt(self):
+        """Empty prompt when all flags are False."""
+        from pyhealth.models._medtsllm.prompt import build_prompt
+
+        inputs = {"x_enc": torch.randn(1, 64, 1)}
+        prompts = build_prompt(
+            inputs,
+            include_dataset=False,
+            include_task=False,
+            include_clip=False,
+            include_stats=False,
+        )
+        # Should still have "Time series:" at minimum
+        flat = " ".join(prompts[0])
+        self.assertIn("Time series:", flat)
+
+    def test_clip_prompt(self):
+        """Per-sample descriptions included when clip=True."""
+        from pyhealth.models._medtsllm.prompt import build_prompt
+
+        inputs = {
+            "x_enc": torch.randn(2, 64, 1),
+            "descriptions": ["age: 51, sex: F", "age: 64, sex: M"],
+        }
+        prompts = build_prompt(
+            inputs,
+            include_dataset=False,
+            include_task=False,
+            include_clip=True,
+        )
+        # Descriptions must stay aligned with their batch index.
+        flat_0 = " ".join(prompts[0])
+        flat_1 = " ".join(prompts[1])
+        self.assertIn("age: 51", flat_0)
+        self.assertIn("age: 64", flat_1)
+
+    def test_batch_size_matches(self):
+        """Number of prompts matches batch size."""
+        from pyhealth.models._medtsllm.prompt import build_prompt
+
+        inputs = {"x_enc": torch.randn(5, 32, 1)}
+        prompts = build_prompt(inputs)
+        self.assertEqual(len(prompts), 5)
+
+    def test_stats_prompt_included(self):
+        """Stats prompt appears when include_stats=True."""
+        from pyhealth.models._medtsllm.prompt import build_prompt
+
+        inputs = {"x_enc": torch.randn(2, 64, 1)}
+        prompts = build_prompt(
+            inputs,
+            include_dataset=False,
+            include_task=False,
+            include_stats=True,
+        )
+        # Spot-check every statistic the stats section advertises.
+        flat = " ".join(prompts[0])
+        self.assertIn("Input statistics", flat)
+        self.assertIn("min value", flat)
+        self.assertIn("max value", flat)
+        self.assertIn("median value", flat)
+        self.assertIn("trend", flat)
+        self.assertIn("lags", flat)
+
+    def test_stats_prompt_omitted_by_default(self):
+        """Stats prompt absent when include_stats=False."""
+        from pyhealth.models._medtsllm.prompt import build_prompt
+
+        inputs = {"x_enc": torch.randn(2, 64, 1)}
+        prompts = build_prompt(
+            inputs,
+            include_dataset=False,
+            include_task=False,
+            include_stats=False,
+        )
+        flat = " ".join(prompts[0])
+        self.assertNotIn("Input statistics", flat)
+
+
+# ------------------------------------------------------------------ #
+# compute_lags tests
+# ------------------------------------------------------------------ #
+
+
+class TestComputeLags(unittest.TestCase):
+ """Tests for the FFT-based autocorrelation lag helper."""
+
+ def test_shape_univariate(self):
+ """Shape is (batch, n_lags) for univariate input."""
+ from pyhealth.models._medtsllm.prompt import compute_lags
+
+ x = torch.randn(4, 128)
+ lags = compute_lags(x, n_lags=5)
+ self.assertEqual(lags.shape, (4, 5))
+
+ def test_shape_multivariate(self):
+ """Shape is (batch, n_lags) for 3D input."""
+ from pyhealth.models._medtsllm.prompt import compute_lags
+
+ x = torch.randn(4, 128, 3)
+ lags = compute_lags(x, n_lags=3)
+ self.assertEqual(lags.shape, (4, 3))
+
+ def test_lags_are_long(self):
+ """Lag indices are integer-valued."""
+ from pyhealth.models._medtsllm.prompt import compute_lags
+
+ x = torch.randn(2, 64, 1)
+ lags = compute_lags(x, n_lags=5)
+ self.assertEqual(lags.dtype, torch.int64)
+
+ def test_lags_in_range(self):
+ """Lag indices fall within sequence length."""
+ from pyhealth.models._medtsllm.prompt import compute_lags
+
+ seq_len = 128
+ x = torch.randn(2, seq_len)
+ lags = compute_lags(x, n_lags=5)
+ self.assertTrue((lags >= 0).all())
+ self.assertTrue((lags < seq_len).all())
+
+
+# ------------------------------------------------------------------ #
+# Task parameter + default task_description
+# ------------------------------------------------------------------ #
+
+
+class TestTaskParam(unittest.TestCase):
+    """Tests for the ``task`` constructor parameter."""
+
+    @classmethod
+    def setUpClass(cls):
+        # Built once and shared read-only; each test constructs its own
+        # cheap model on top of it.
+        cls.dataset = _make_dataset()
+
+    def test_default_task_is_semantic_segmentation(self):
+        """Default task is semantic_segmentation."""
+        model = _make_model(self.dataset)
+        self.assertEqual(model.task, "semantic_segmentation")
+
+    def test_invalid_task_raises(self):
+        """Unknown task string raises ValueError."""
+        with self.assertRaises(ValueError):
+            _make_model(self.dataset, task="not_a_real_task")
+
+    def test_semseg_description(self):
+        """semantic_segmentation task_description says 'Classify'."""
+        model = _make_model(self.dataset, task="semantic_segmentation")
+        self.assertIn("Classify", model.task_description)
+
+    def test_segmentation_description(self):
+        """segmentation task_description says 'change points'."""
+        model = _make_model(self.dataset, task="segmentation")
+        self.assertIn("change points", model.task_description)
+
+    def test_anomaly_description(self):
+        """anomaly_detection task_description says 'Reconstruct'."""
+        model = _make_model(self.dataset, task="anomaly_detection")
+        self.assertIn("Reconstruct", model.task_description)
+
+    def test_reconstruction_description(self):
+        """reconstruction task_description says 'Reconstruct'."""
+        model = _make_model(self.dataset, task="reconstruction")
+        self.assertIn("Reconstruct", model.task_description)
+
+    def test_forecasting_description(self):
+        """forecasting task_description says 'Forecast'."""
+        model = _make_model(self.dataset, task="forecasting")
+        self.assertIn("Forecast", model.task_description)
+
+    def test_explicit_task_description_wins(self):
+        """Explicit task_description is not overwritten by default."""
+        # A caller-supplied description must take precedence over the
+        # per-task default template.
+        model = _make_model(
+            self.dataset,
+            task="semantic_segmentation",
+            task_description="custom description here",
+        )
+        self.assertEqual(
+            model.task_description, "custom description here"
+        )
+
+
+# ------------------------------------------------------------------ #
+# Prompt config knobs
+# ------------------------------------------------------------------ #
+
+
+class TestPromptConfigKnobs(unittest.TestCase):
+ """All four prompt knobs are exposed and default True."""
+
+ @classmethod
+ def setUpClass(cls):
+ cls.dataset = _make_dataset()
+
+ def test_four_knobs_present(self):
+ """prompt_config has dataset/task/patient/stats keys."""
+ model = _make_model(self.dataset)
+ for key in ("dataset", "task", "patient", "stats"):
+ self.assertIn(key, model.prompt_config)
+
+ def test_prompt_defaults(self):
+ """dataset/task/patient default True; stats defaults False to
+ match the cs598-pyhealth dtp reference config."""
+ model = _make_model(self.dataset)
+ self.assertTrue(model.prompt_config["dataset"])
+ self.assertTrue(model.prompt_config["task"])
+ self.assertTrue(model.prompt_config["patient"])
+ self.assertFalse(model.prompt_config["stats"])
+
+ def test_individual_toggles(self):
+ """Each knob can be toggled independently."""
+ model = _make_model(
+ self.dataset,
+ prompt_dataset=False,
+ prompt_task=True,
+ prompt_patient=False,
+ prompt_stats=False,
+ )
+ self.assertFalse(model.prompt_config["dataset"])
+ self.assertTrue(model.prompt_config["task"])
+ self.assertFalse(model.prompt_config["patient"])
+ self.assertFalse(model.prompt_config["stats"])
+
+
+# ------------------------------------------------------------------ #
+# reprogramming_layer override (LinearProjection ablation)
+# ------------------------------------------------------------------ #
+
+
+class TestReprogrammingOverride(unittest.TestCase):
+    """Override ReprogrammingLayer with LinearProjection."""
+
+    @classmethod
+    def setUpClass(cls):
+        cls.dataset = _make_dataset()
+
+    def test_linear_projection_override(self):
+        """LinearProjection slots into the model."""
+        from pyhealth.models._medtsllm.layers import LinearProjection
+
+        # d_model=16, d_llm=64 in _make_model defaults
+        linear = LinearProjection(d_model=16, d_llm=64)
+        model = _make_model(self.dataset, reprogramming_layer=linear)
+        # Identity check: the model must hold the exact module instance
+        # passed in (not a copy) so ablation studies stay controllable.
+        self.assertIs(model.reprogramming_layer, linear)
+
+    def test_forward_with_linear_projection(self):
+        """Forward pass works with LinearProjection ablation."""
+        from pyhealth.models._medtsllm.layers import LinearProjection
+
+        linear = LinearProjection(d_model=16, d_llm=64)
+        model = _make_model(self.dataset, reprogramming_layer=linear)
+        batch = _make_batch(self.dataset)
+        out = model(**batch)
+        self.assertIn("logit", out)
+        # Finite loss guards against NaN/Inf from the ablated projection.
+        self.assertTrue(out["loss"].isfinite())
+
+
+# ------------------------------------------------------------------ #
+# Description coercion
+# ------------------------------------------------------------------ #
+
+
+class TestDescriptionCoercion(unittest.TestCase):
+ """_coerce_descriptions handles list, tuple, scalar-string."""
+
+ def test_list_passthrough(self):
+ from pyhealth.models.medtsllm import _coerce_descriptions
+
+ desc = ["a", "b"]
+ self.assertEqual(_coerce_descriptions(desc, bs=2), ["a", "b"])
+
+ def test_tuple_coerced(self):
+ from pyhealth.models.medtsllm import _coerce_descriptions
+
+ desc = ("a", "b", "c")
+ self.assertEqual(
+ _coerce_descriptions(desc, bs=3), ["a", "b", "c"]
+ )
+
+ def test_scalar_string_broadcast(self):
+ """Single string broadcasts to batch size, not list-of-chars."""
+ from pyhealth.models.medtsllm import _coerce_descriptions
+
+ desc = "hello"
+ out = _coerce_descriptions(desc, bs=3)
+ self.assertEqual(out, ["hello", "hello", "hello"])
+
+
+# ------------------------------------------------------------------ #
+# Patient prompt integration test
+# ------------------------------------------------------------------ #
+
+
+class TestMedTsLLMPatientPrompt(unittest.TestCase):
+    """Tests for patient prompt path in model forward."""
+
+    @classmethod
+    def setUpClass(cls):
+        cls.dataset = _make_dataset()
+        # Model with prompt_patient=True but no LLM
+        # (synthetic replacement skips prompting, so we test the config)
+        cls.model = _make_model(cls.dataset, prompt_patient=True)
+
+    def test_prompt_config_has_patient(self):
+        """prompt_config includes patient key."""
+        self.assertIn("patient", self.model.prompt_config)
+        self.assertTrue(self.model.prompt_config["patient"])
+
+    def test_forward_with_description_kwarg(self):
+        """Forward pass works when description is in kwargs."""
+        signal = torch.randn(2, 128)
+        label = torch.randint(0, 4, (2, 128))
+        # One free-text description per batch element, mirroring what a
+        # dataloader would supply for datasets carrying demographics.
+        out = self.model(
+            signal=signal,
+            label=label,
+            description=["age: 51, sex: F", "age: 64, sex: M"],
+        )
+        self.assertIn("logit", out)
+        self.assertTrue(out["loss"].isfinite())
+
+    def test_forward_without_description(self):
+        """Forward pass works even without description kwarg."""
+        # The patient-prompt path must degrade gracefully when no
+        # description is provided.
+        signal = torch.randn(2, 128)
+        label = torch.randint(0, 4, (2, 128))
+        out = self.model(signal=signal, label=label)
+        self.assertIn("logit", out)
+        self.assertTrue(out["loss"].isfinite())
+
+
+# ------------------------------------------------------------------ #
+# Task branching in forward (Phase 3)
+# ------------------------------------------------------------------ #
+
+
+def _make_binary_dataset(
+ n_samples: int = 4,
+ seq_len: int = 128,
+ n_features: int = 1,
+):
+ """Dataset with binary float labels (boundary detection style)."""
+ from pyhealth.datasets import create_sample_dataset
+
+ samples = []
+ for i in range(n_samples):
+ if n_features == 1:
+ signal = np.random.randn(seq_len).astype(np.float32)
+ else:
+ signal = np.random.randn(seq_len, n_features).astype(
+ np.float32
+ )
+ label = np.random.randint(0, 2, size=seq_len).astype(np.float32)
+ samples.append({
+ "patient_id": f"p{i}",
+ "visit_id": "v0",
+ "signal": signal,
+ "label": label,
+ })
+
+ return create_sample_dataset(
+ samples=samples,
+ input_schema={"signal": "tensor"},
+ output_schema={"label": "tensor"},
+ dataset_name="test_medtsllm_binary",
+ )
+
+
+def _make_unlabeled_dataset(
+ n_samples: int = 4,
+ seq_len: int = 128,
+ n_features: int = 2,
+):
+ """Multivariate dataset without labels (reconstruction style)."""
+ from pyhealth.datasets import create_sample_dataset
+
+ samples = []
+ for i in range(n_samples):
+ signal = np.random.randn(seq_len, n_features).astype(np.float32)
+ samples.append({
+ "patient_id": f"p{i}",
+ "visit_id": "v0",
+ "signal": signal,
+ })
+
+ return create_sample_dataset(
+ samples=samples,
+ input_schema={"signal": "tensor"},
+ output_schema={},
+ dataset_name="test_medtsllm_unlabeled",
+ )
+
+
+class TestTaskOutputShapes(unittest.TestCase):
+    """Output head shape and n_outputs_per_step adapt to task.
+
+    Head widths pinned by the tests below:
+    semantic_segmentation -> n_classes; segmentation -> 1 (binary
+    logit, squeezed); anomaly_detection / reconstruction -> n_features.
+    """
+
+    def test_semseg_n_outputs_per_step(self):
+        """semantic_segmentation => n_outputs_per_step == n_classes."""
+        dataset = _make_dataset(n_classes=4)
+        model = _make_model(
+            dataset, task="semantic_segmentation", n_classes=4
+        )
+        self.assertEqual(model.n_outputs_per_step, 4)
+
+    def test_segmentation_n_outputs_per_step(self):
+        """segmentation => n_outputs_per_step == 1 (binary logit)."""
+        dataset = _make_binary_dataset()
+        model = _make_model(dataset, task="segmentation")
+        self.assertEqual(model.n_outputs_per_step, 1)
+
+    def test_anomaly_n_outputs_per_step(self):
+        """anomaly_detection => n_outputs_per_step == n_features."""
+        dataset = _make_unlabeled_dataset(n_features=2)
+        model = _make_model(
+            dataset,
+            task="anomaly_detection",
+            n_features=2,
+            covariate_mode="concat",
+        )
+        self.assertEqual(model.n_outputs_per_step, 2)
+
+    def test_reconstruction_n_outputs_per_step(self):
+        """reconstruction => n_outputs_per_step == n_features."""
+        dataset = _make_unlabeled_dataset(n_features=3)
+        model = _make_model(
+            dataset,
+            task="reconstruction",
+            n_features=3,
+            covariate_mode="concat",
+        )
+        self.assertEqual(model.n_outputs_per_step, 3)
+
+    def test_semseg_logit_shape(self):
+        """semantic_segmentation logit is (bs, pred_len, n_classes)."""
+        dataset = _make_dataset(n_classes=4)
+        model = _make_model(
+            dataset, task="semantic_segmentation", n_classes=4
+        )
+        batch = _make_batch(dataset)
+        out = model(**batch)
+        bs = batch["label"].shape[0]
+        self.assertEqual(out["logit"].shape, (bs, 128, 4))
+
+    def test_segmentation_logit_shape(self):
+        """segmentation logit is (bs, pred_len) after squeeze."""
+        dataset = _make_binary_dataset()
+        model = _make_model(dataset, task="segmentation")
+        batch = _make_batch(dataset)
+        out = model(**batch)
+        bs = batch["label"].shape[0]
+        self.assertEqual(out["logit"].shape, (bs, 128))
+
+    def test_anomaly_reconstruction_shape(self):
+        """anomaly_detection prediction is (bs, pred_len, n_features)."""
+        dataset = _make_unlabeled_dataset(n_features=2)
+        model = _make_model(
+            dataset,
+            task="anomaly_detection",
+            n_features=2,
+            covariate_mode="concat",
+        )
+        batch = _make_batch(dataset)
+        out = model(**batch)
+        # Unlabeled batch: size comes from the signal, not a label.
+        bs = batch["signal"].shape[0]
+        self.assertEqual(out["logit"].shape, (bs, 128, 2))
+
+
+class TestTaskLosses(unittest.TestCase):
+    """Task-specific loss computation.
+
+    Classification-style tasks (semseg / segmentation) consume a
+    ``label``; reconstruction-style tasks (anomaly / reconstruction)
+    use the input signal itself as the target, so no label is needed.
+    """
+
+    def test_semseg_uses_cross_entropy(self):
+        """semseg loss is finite and bounded for random inputs."""
+        dataset = _make_dataset(n_classes=4)
+        model = _make_model(
+            dataset, task="semantic_segmentation", n_classes=4
+        )
+        batch = _make_batch(dataset)
+        out = model(**batch)
+        self.assertIn("loss", out)
+        self.assertTrue(out["loss"].isfinite())
+
+    def test_segmentation_uses_bce(self):
+        """segmentation loss is finite under BCE-with-logits."""
+        dataset = _make_binary_dataset()
+        model = _make_model(dataset, task="segmentation")
+        batch = _make_batch(dataset)
+        out = model(**batch)
+        self.assertIn("loss", out)
+        self.assertTrue(out["loss"].isfinite())
+
+    def test_anomaly_uses_mse_no_label(self):
+        """anomaly_detection computes MSE against signal, no label needed."""
+        dataset = _make_unlabeled_dataset(n_features=2)
+        model = _make_model(
+            dataset,
+            task="anomaly_detection",
+            n_features=2,
+            covariate_mode="concat",
+        )
+        batch = _make_batch(dataset)
+        out = model(**batch)
+        self.assertIn("loss", out)
+        self.assertTrue(out["loss"].isfinite())
+
+    def test_reconstruction_uses_mse(self):
+        """reconstruction computes MSE against signal."""
+        dataset = _make_unlabeled_dataset(n_features=2)
+        model = _make_model(
+            dataset,
+            task="reconstruction",
+            n_features=2,
+            covariate_mode="concat",
+        )
+        batch = _make_batch(dataset)
+        out = model(**batch)
+        self.assertIn("loss", out)
+        self.assertTrue(out["loss"].isfinite())
+
+
+class TestTaskGradients(unittest.TestCase):
+ """Backward pass works across all task types."""
+
+ def _backward_flows(self, model, batch):
+ out = model(**batch)
+ loss = out["loss"]
+ loss.backward()
+ has_grad = any(
+ p.grad is not None and p.grad.abs().sum() > 0
+ for p in model.parameters()
+ if p.requires_grad
+ )
+ return has_grad
+
+ def test_semseg_backward(self):
+ dataset = _make_dataset(n_classes=4)
+ model = _make_model(
+ dataset, task="semantic_segmentation", n_classes=4
+ )
+ self.assertTrue(
+ self._backward_flows(model, _make_batch(dataset))
+ )
+
+ def test_segmentation_backward(self):
+ dataset = _make_binary_dataset()
+ model = _make_model(dataset, task="segmentation")
+ self.assertTrue(
+ self._backward_flows(model, _make_batch(dataset))
+ )
+
+ def test_anomaly_backward(self):
+ dataset = _make_unlabeled_dataset(n_features=2)
+ model = _make_model(
+ dataset,
+ task="anomaly_detection",
+ n_features=2,
+ covariate_mode="concat",
+ )
+ self.assertTrue(
+ self._backward_flows(model, _make_batch(dataset))
+ )
+
+
+# ------------------------------------------------------------------ #
+# Synthetic example end-to-end test
+# ------------------------------------------------------------------ #
+
+
+class TestSyntheticEndToEnd(unittest.TestCase):
+    """End-to-end test with synthetic data through PyHealth pipeline."""
+
+    def test_full_pipeline(self):
+        """Create dataset -> split -> model -> train step -> eval."""
+        from pyhealth.datasets import (
+            create_sample_dataset,
+            get_dataloader,
+            split_by_patient,
+        )
+        from pyhealth.models.medtsllm import MedTsLLM
+
+        # Create synthetic samples
+        samples = []
+        for i in range(6):
+            samples.append({
+                "patient_id": f"p{i}",
+                "visit_id": "v0",
+                "signal": np.random.randn(128).astype(np.float32),
+                "label": np.random.randint(0, 4, 128).astype(np.int64),
+            })
+
+        dataset = create_sample_dataset(
+            samples=samples,
+            input_schema={"signal": "tensor"},
+            output_schema={"label": "tensor"},
+            dataset_name="e2e_test",
+        )
+
+        # Split: validation ratio is deliberately 0.0 — only the
+        # train/test legs are exercised here.
+        train_ds, _, test_ds = split_by_patient(
+            dataset, ratios=[0.7, 0.0, 0.3]
+        )
+
+        # Model: tiny dims plus random word_embeddings keep the test
+        # fast and avoid any LLM backbone download.
+        model = MedTsLLM(
+            dataset=dataset,
+            seq_len=128,
+            n_classes=4,
+            backbone=None,
+            word_embeddings=torch.randn(50, 32),
+            d_model=8,
+            d_ff=16,
+            n_heads=2,
+            num_tokens=10,
+        )
+
+        # Train one step
+        loader = get_dataloader(train_ds, batch_size=2, shuffle=True)
+        batch = next(iter(loader))
+        out = model(**batch)
+        self.assertIn("loss", out)
+        self.assertTrue(out["loss"].isfinite())
+        out["loss"].backward()
+
+        # Eval: one batch is enough to exercise the inference path.
+        model.eval()
+        test_loader = get_dataloader(test_ds, batch_size=2, shuffle=False)
+        with torch.no_grad():
+            for batch in test_loader:
+                out = model(**batch)
+                self.assertIn("y_prob", out)
+                break
+
+
+# ------------------------------------------------------------------ #
+# Error handling tests
+# ------------------------------------------------------------------ #
+
+
+class TestMedTsLLMInvalidInputs(unittest.TestCase):
+ """Tests for error handling."""
+
+ def test_no_backbone_or_embeddings(self):
+ """Raises ValueError when neither backbone nor embeddings given."""
+ from pyhealth.models.medtsllm import MedTsLLM
+
+ dataset = _make_dataset()
+ with self.assertRaises(ValueError):
+ MedTsLLM(
+ dataset=dataset,
+ backbone=None,
+ word_embeddings=None,
+ )
+
+
+# ------------------------------------------------------------------ #
+# Preprocessing cache (_medtsllm_cache) — shared by LUDB / MIT-BIH /
+# BIDMC loaders. Lives here rather than a separate file because it
+# only exists to serve the MedTsLLM port.
+# ------------------------------------------------------------------ #
+
+
+class TestComputeFingerprint(unittest.TestCase):
+ """Fingerprint is a stable hash of (raw file stats, params)."""
+
+ def test_same_inputs_same_fingerprint(self):
+ from pyhealth.datasets._medtsllm_cache import compute_fingerprint
+
+ with tempfile.TemporaryDirectory() as tmp:
+ raw = os.path.join(tmp, "a.dat")
+ open(raw, "w").write("hi")
+ fp1 = compute_fingerprint([raw], {"trim": True})
+ fp2 = compute_fingerprint([raw], {"trim": True})
+ self.assertEqual(fp1, fp2)
+
+ def test_different_params_different_fingerprint(self):
+ from pyhealth.datasets._medtsllm_cache import compute_fingerprint
+
+ with tempfile.TemporaryDirectory() as tmp:
+ raw = os.path.join(tmp, "a.dat")
+ open(raw, "w").write("hi")
+ fp1 = compute_fingerprint([raw], {"trim": True})
+ fp2 = compute_fingerprint([raw], {"trim": False})
+ self.assertNotEqual(fp1, fp2)
+
+ def test_changed_file_changes_fingerprint(self):
+ from pyhealth.datasets._medtsllm_cache import compute_fingerprint
+
+ with tempfile.TemporaryDirectory() as tmp:
+ raw = os.path.join(tmp, "a.dat")
+ open(raw, "w").write("hi")
+ fp1 = compute_fingerprint([raw], {})
+ open(raw, "w").write("hello world now longer")
+ fp2 = compute_fingerprint([raw], {})
+ self.assertNotEqual(fp1, fp2)
+
+ def test_returns_string(self):
+ from pyhealth.datasets._medtsllm_cache import compute_fingerprint
+
+ with tempfile.TemporaryDirectory() as tmp:
+ raw = os.path.join(tmp, "a.dat")
+ open(raw, "w").write("hi")
+ fp = compute_fingerprint([raw], {})
+ self.assertIsInstance(fp, str)
+ self.assertGreater(len(fp), 16)
+
+
+class TestLoadOrBuild(unittest.TestCase):
+    """load_or_build skips the builder when the cache is warm."""
+
+    def test_first_call_invokes_builder(self):
+        from pyhealth.datasets._medtsllm_cache import load_or_build
+
+        with tempfile.TemporaryDirectory() as tmp:
+            cache = os.path.join(tmp, "c.npz")
+            # Mutable counter lets the builder closure record calls.
+            calls = {"n": 0}
+
+            def builder():
+                calls["n"] += 1
+                return {"x": np.arange(5, dtype=np.int64)}
+
+            result = load_or_build(cache, "fp1", builder)
+            self.assertEqual(calls["n"], 1)
+            np.testing.assert_array_equal(result["x"], np.arange(5))
+            self.assertTrue(os.path.exists(cache))
+
+    def test_second_call_skips_builder(self):
+        from pyhealth.datasets._medtsllm_cache import load_or_build
+
+        with tempfile.TemporaryDirectory() as tmp:
+            cache = os.path.join(tmp, "c.npz")
+            calls = {"n": 0}
+
+            def builder():
+                calls["n"] += 1
+                return {"x": np.arange(5, dtype=np.int64)}
+
+            # Same fingerprint twice: the second call must be a pure
+            # cache read.
+            load_or_build(cache, "fp1", builder)
+            load_or_build(cache, "fp1", builder)
+            self.assertEqual(calls["n"], 1)
+
+    def test_fingerprint_mismatch_rebuilds(self):
+        from pyhealth.datasets._medtsllm_cache import load_or_build
+
+        with tempfile.TemporaryDirectory() as tmp:
+            cache = os.path.join(tmp, "c.npz")
+            calls = {"n": 0}
+
+            def builder():
+                calls["n"] += 1
+                # Fill value encodes the invocation count so the final
+                # assertion can tell rebuild #2 apart from cached #1.
+                return {"x": np.full(3, calls["n"], dtype=np.int64)}
+
+            load_or_build(cache, "fp-old", builder)
+            second = load_or_build(cache, "fp-new", builder)
+            self.assertEqual(calls["n"], 2)
+            np.testing.assert_array_equal(second["x"], np.full(3, 2))
+
+    def test_creates_parent_dirs(self):
+        from pyhealth.datasets._medtsllm_cache import load_or_build
+
+        with tempfile.TemporaryDirectory() as tmp:
+            # Neither "nested" nor "dir" exists yet.
+            cache = os.path.join(tmp, "nested", "dir", "c.npz")
+
+            def builder():
+                return {"x": np.zeros(1, dtype=np.int64)}
+
+            load_or_build(cache, "fp", builder)
+            self.assertTrue(os.path.exists(cache))
+
+    def test_preserves_string_arrays(self):
+        """Cache round-trips unicode arrays (wfdb annotation symbols)."""
+        from pyhealth.datasets._medtsllm_cache import load_or_build
+
+        with tempfile.TemporaryDirectory() as tmp:
+            cache = os.path.join(tmp, "c.npz")
+
+            def builder():
+                return {
+                    "signal": np.zeros((4, 2), dtype=np.float32),
+                    "symbols": np.array(["N", "V", "L", "A"]),
+                }
+
+            load_or_build(cache, "fp", builder)
+
+            # Second call hits cache. The None-returning builder would
+            # make result["symbols"] raise TypeError if it ran, so this
+            # doubles as a builder-not-called check.
+            result = load_or_build(cache, "fp", lambda: None)  # type: ignore[arg-type]
+            self.assertEqual(list(result["symbols"]), ["N", "V", "L", "A"])
+
+    def test_corrupt_cache_rebuilds(self):
+        from pyhealth.datasets._medtsllm_cache import load_or_build
+
+        with tempfile.TemporaryDirectory() as tmp:
+            cache = os.path.join(tmp, "c.npz")
+            # Garbage bytes must be treated as a cache miss, not an error.
+            with open(cache, "w") as f:
+                f.write("not a real npz")
+
+            calls = {"n": 0}
+
+            def builder():
+                calls["n"] += 1
+                return {"x": np.arange(2, dtype=np.int64)}
+
+            result = load_or_build(cache, "fp", builder)
+            self.assertEqual(calls["n"], 1)
+            np.testing.assert_array_equal(result["x"], np.arange(2))
+
+
+if __name__ == "__main__":
+    # Standard entry point so the file can also be run directly.
+    unittest.main()
diff --git a/tests/core/test_mitbih.py b/tests/core/test_mitbih.py
new file mode 100644
index 000000000..4413ec893
--- /dev/null
+++ b/tests/core/test_mitbih.py
@@ -0,0 +1,668 @@
+"""Unit tests for MIT-BIH dataset and ECGBoundaryDetection task.
+
+Author: Anton Barchukov
+
+Tests cover:
+ - MITBIHDataset instantiation from real wfdb files
+ - Metadata preparation (paced record exclusion)
+ - Patient header parsing (age, sex, medications)
+ - ECGBoundaryDetection task (windowing, downsampling, labels)
+ - Full pipeline: dataset -> task -> samples
+ - MedTsLLM model with real MIT-BIH data structure
+"""
+
+import os
+import unittest
+
+import numpy as np
+import torch
+
+# Repo-relative location of the committed MIT-BIH excerpt used by these
+# tests (two directory levels up from tests/core/).
+TEST_DATA_DIR = os.path.join(
+    os.path.dirname(__file__), "..", "..", "test-resources", "core", "mitbih"
+)
+# True only when at least one wfdb signal file (*.dat) is present; every
+# test class below is skipped otherwise, so environments without the
+# PhysioNet download still pass.
+HAS_TEST_DATA = os.path.isdir(TEST_DATA_DIR) and any(
+    f.endswith(".dat") for f in os.listdir(TEST_DATA_DIR)
+)
+
+
+def setUpModule():
+    """Drop the committed CSV + cache so tests use the current schema."""
+    # NOTE(review): _force_regenerate_mitbih_metadata is presumably
+    # defined elsewhere in this module (not visible here) — verify it
+    # exists, since unittest calls setUpModule before any test runs.
+    if HAS_TEST_DATA:
+        _force_regenerate_mitbih_metadata()
+
+
+@unittest.skipUnless(HAS_TEST_DATA, "MIT-BIH test data not found")
+class TestMITBIHDataset(unittest.TestCase):
+    """Tests for MITBIHDataset with real wfdb files."""
+
+    @classmethod
+    def setUpClass(cls):
+        from pyhealth.datasets.mitbih import MITBIHDataset
+
+        # NOTE(review): dev=True presumably limits how much is loaded to
+        # keep the suite fast — confirm against MITBIHDataset.
+        cls.dataset = MITBIHDataset(root=TEST_DATA_DIR, dev=True)
+
+    def test_dataset_loads(self):
+        """Dataset initializes without error."""
+        self.assertIsNotNone(self.dataset)
+
+    def test_metadata_csv_created(self):
+        """prepare_metadata creates mitbih-pyhealth.csv."""
+        csv_path = os.path.join(TEST_DATA_DIR, "mitbih-pyhealth.csv")
+        self.assertTrue(os.path.exists(csv_path))
+
+    def test_has_patients(self):
+        """Dataset has at least 1 patient."""
+        pids = self.dataset.unique_patient_ids
+        self.assertGreater(len(pids), 0)
+
+    def test_paced_records_excluded(self):
+        """Paced records (102, 104, 107, 217) are not in dataset."""
+        from pyhealth.datasets.mitbih import _PACED_RECORDS
+
+        # Compare against the loader's own exclusion list rather than
+        # hard-coding record ids, so the two cannot drift apart.
+        pids = set(self.dataset.unique_patient_ids)
+        for paced in _PACED_RECORDS:
+            self.assertNotIn(paced, pids)
+
+    def test_patient_has_events(self):
+        """Each patient has ECG events."""
+        pid = self.dataset.unique_patient_ids[0]
+        patient = self.dataset.get_patient(pid)
+        events = patient.get_events()
+        self.assertGreater(len(events), 0)
+
+    def test_event_has_demographics(self):
+        """Events contain age, sex, medications."""
+        pid = self.dataset.unique_patient_ids[0]
+        patient = self.dataset.get_patient(pid)
+        event = patient.get_events()[0]
+        self.assertTrue(hasattr(event, "age"))
+        self.assertTrue(hasattr(event, "sex"))
+        self.assertTrue(hasattr(event, "medications"))
+
+    def test_event_has_abnormal_count(self):
+        """Events track number of abnormal beats."""
+        pid = self.dataset.unique_patient_ids[0]
+        patient = self.dataset.get_patient(pid)
+        event = patient.get_events()[0]
+        self.assertTrue(hasattr(event, "n_abnormal"))
+
+
+@unittest.skipUnless(HAS_TEST_DATA, "MIT-BIH test data not found")
+class TestMITBIHHeaderParsing(unittest.TestCase):
+ """Tests for MIT-BIH patient metadata parsing."""
+
+ def test_parse_header(self):
+ """Header contains age and sex info."""
+ import wfdb
+
+ rec = wfdb.rdrecord(os.path.join(TEST_DATA_DIR, "100"))
+ self.assertIsNotNone(rec.comments)
+ self.assertGreater(len(rec.comments), 0)
+ # First comment line should have age and sex
+ first = rec.comments[0].strip()
+ tokens = first.split()
+ self.assertGreaterEqual(len(tokens), 2)
+
+
+@unittest.skipUnless(HAS_TEST_DATA, "MIT-BIH test data not found")
+class TestECGBoundaryDetectionTask(unittest.TestCase):
+    """Tests for ECGBoundaryDetection with real MIT-BIH data."""
+
+    @classmethod
+    def setUpClass(cls):
+        from pyhealth.datasets.mitbih import MITBIHDataset
+        from pyhealth.tasks.ecg_boundary_detection import (
+            ECGBoundaryDetection,
+        )
+
+        cls.dataset = MITBIHDataset(root=TEST_DATA_DIR, dev=True)
+        # step_size == window_size -> non-overlapping 256-sample windows.
+        cls.task = ECGBoundaryDetection(window_size=256, step_size=256)
+
+    def test_task_attributes(self):
+        """Task has correct name and schema."""
+        self.assertEqual(self.task.task_name, "ECGBoundaryDetection")
+        self.assertEqual(self.task.input_schema, {"signal": "tensor"})
+        self.assertEqual(self.task.output_schema, {"label": "tensor"})
+
+    def test_task_produces_samples(self):
+        """Task generates samples from a patient."""
+        pid = self.dataset.unique_patient_ids[0]
+        patient = self.dataset.get_patient(pid)
+        samples = self.task(patient)
+        self.assertGreater(len(samples), 0)
+
+    def test_signal_is_2_channel(self):
+        """Signal has 2 channels (MLII, V1 or similar)."""
+        pid = self.dataset.unique_patient_ids[0]
+        patient = self.dataset.get_patient(pid)
+        sample = self.task(patient)[0]
+        self.assertEqual(sample["signal"].shape, (256, 2))
+
+    def test_label_is_binary(self):
+        """Labels are 0 or 1 (R-peak or not)."""
+        pid = self.dataset.unique_patient_ids[0]
+        patient = self.dataset.get_patient(pid)
+        sample = self.task(patient)[0]
+        unique = set(sample["label"].tolist())
+        # Subset (not equality): a window may legitimately contain no
+        # R-peak, leaving only 0.0 labels.
+        self.assertTrue(unique.issubset({0.0, 1.0}))
+
+    def test_sample_has_description(self):
+        """Samples include patient description with medications."""
+        pid = self.dataset.unique_patient_ids[0]
+        patient = self.dataset.get_patient(pid)
+        sample = self.task(patient)[0]
+        self.assertIn("description", sample)
+        self.assertIn("age:", sample["description"])
+
+    def test_set_task_produces_sample_dataset(self):
+        """dataset.set_task() returns a usable SampleDataset."""
+        sample_ds = self.dataset.set_task(self.task)
+        self.assertGreater(len(sample_ds), 0)
+
+
+# ------------------------------------------------------------------ #
+# Phase 4: anomaly-detection task
+# ------------------------------------------------------------------ #
+
+
+@unittest.skipUnless(HAS_TEST_DATA, "MIT-BIH test data not found")
+class TestECGAnomalyDetectionTask(unittest.TestCase):
+    """Tests for ECGAnomalyDetection with real MIT-BIH data."""
+
+    @classmethod
+    def setUpClass(cls):
+        from pyhealth.datasets.mitbih import MITBIHDataset
+        from pyhealth.tasks.ecg_anomaly_detection import (
+            ECGAnomalyDetection,
+        )
+
+        cls.dataset = MITBIHDataset(root=TEST_DATA_DIR, dev=True)
+        # stride == window -> disjoint 128-sample windows.
+        cls.task = ECGAnomalyDetection(window_size=128, step_size=128)
+
+    def test_task_attributes(self):
+        """Task has correct name and schema."""
+        self.assertEqual(self.task.task_name, "ECGAnomalyDetection")
+        self.assertEqual(self.task.input_schema, {"signal": "tensor"})
+        self.assertEqual(self.task.output_schema, {"label": "tensor"})
+
+    def test_task_produces_samples(self):
+        """Task generates samples from a patient."""
+        pid = self.dataset.unique_patient_ids[0]
+        patient = self.dataset.get_patient(pid)
+        samples = self.task(patient)
+        self.assertGreater(len(samples), 0)
+
+    def test_signal_is_2_channel(self):
+        """Signal has 2 channels (MLII, V1 or similar)."""
+        pid = self.dataset.unique_patient_ids[0]
+        patient = self.dataset.get_patient(pid)
+        sample = self.task(patient)[0]
+        self.assertEqual(sample["signal"].shape, (128, 2))
+
+    def test_label_is_binary_mask(self):
+        """Labels are a per-timestep 0/1 anomaly mask."""
+        pid = self.dataset.unique_patient_ids[0]
+        patient = self.dataset.get_patient(pid)
+        sample = self.task(patient)[0]
+        self.assertEqual(sample["label"].shape, (128,))
+        # Subset (not equality): an all-normal window has only 0.0.
+        unique = set(sample["label"].tolist())
+        self.assertTrue(unique.issubset({0.0, 1.0}))
+
+    def test_sample_has_description(self):
+        """Samples include per-patient description."""
+        pid = self.dataset.unique_patient_ids[0]
+        patient = self.dataset.get_patient(pid)
+        sample = self.task(patient)[0]
+        self.assertIn("description", sample)
+
+    def test_set_task_produces_sample_dataset(self):
+        """dataset.set_task() works with anomaly detection task."""
+        sample_ds = self.dataset.set_task(self.task)
+        self.assertGreater(len(sample_ds), 0)
+
+
+@unittest.skipUnless(HAS_TEST_DATA, "MIT-BIH test data not found")
+class TestAnomalyTaskExportedFromPackage(unittest.TestCase):
+ """ECGAnomalyDetection is importable from pyhealth.tasks."""
+
+ def test_import_from_tasks(self):
+ from pyhealth.tasks import ECGAnomalyDetection # noqa: F401
+
+
+@unittest.skipUnless(HAS_TEST_DATA, "MIT-BIH test data not found")
+class TestMedTsLLMAnomalyOnMITBIH(unittest.TestCase):
+    """End-to-end anomaly-detection training step on MIT-BIH."""
+
+    @classmethod
+    def setUpClass(cls):
+        from pyhealth.datasets import get_dataloader
+        from pyhealth.datasets.mitbih import MITBIHDataset
+        from pyhealth.models.medtsllm import MedTsLLM
+        from pyhealth.tasks.ecg_anomaly_detection import (
+            ECGAnomalyDetection,
+        )
+
+        dataset = MITBIHDataset(root=TEST_DATA_DIR, dev=True)
+        task = ECGAnomalyDetection(window_size=128, step_size=128)
+        cls.sample_ds = dataset.set_task(task)
+
+        # backbone=None plus random word_embeddings acts as a synthetic
+        # stand-in for an LLM, so no checkpoint download is required.
+        cls.model = MedTsLLM(
+            dataset=cls.sample_ds,
+            task="anomaly_detection",
+            seq_len=128,
+            n_features=2,
+            backbone=None,
+            word_embeddings=torch.randn(50, 32),
+            d_model=16,
+            d_ff=32,
+            n_heads=4,
+            num_tokens=50,
+            covariate_mode="concat",
+        )
+        loader = get_dataloader(cls.sample_ds, batch_size=2, shuffle=False)
+        cls.batch = next(iter(loader))
+
+    def test_forward(self):
+        """Anomaly-detection forward returns finite MSE loss."""
+        out = self.model(**self.batch)
+        self.assertIn("logit", out)
+        self.assertIn("loss", out)
+        self.assertTrue(out["loss"].isfinite())
+
+    def test_prediction_shape(self):
+        """Reconstruction shape matches (bs, seq_len, n_features)."""
+        out = self.model(**self.batch)
+        bs = self.batch["signal"].shape[0]
+        self.assertEqual(out["logit"].shape, (bs, 128, 2))
+
+    def test_backward(self):
+        """Backward pass works for anomaly detection on MIT-BIH."""
+        # NOTE(review): gradients accumulate on the shared cls.model
+        # across tests; harmless for this smoke test, but zero them if
+        # grad-value assertions are ever added.
+        out = self.model(**self.batch)
+        out["loss"].backward()
+
+
+@unittest.skipUnless(HAS_TEST_DATA, "MIT-BIH test data not found")
+class TestMedTsLLMWithMITBIH(unittest.TestCase):
+ """Test MedTsLLM model with real MIT-BIH data structure."""
+
+ @classmethod
+ def setUpClass(cls):
+ from pyhealth.datasets.mitbih import MITBIHDataset
+ from pyhealth.tasks.ecg_boundary_detection import (
+ ECGBoundaryDetection,
+ )
+ from pyhealth.datasets import get_dataloader
+ from pyhealth.models.medtsllm import MedTsLLM
+
+ dataset = MITBIHDataset(root=TEST_DATA_DIR, dev=True)
+ task = ECGBoundaryDetection(window_size=128, step_size=64)
+ cls.sample_ds = dataset.set_task(task)
+
+ cls.model = MedTsLLM(
+ dataset=cls.sample_ds,
+ seq_len=128,
+ n_features=2,
+ n_classes=2,
+ backbone=None,
+ word_embeddings=torch.randn(50, 32),
+ d_model=16,
+ d_ff=32,
+ n_heads=4,
+ num_tokens=50,
+ covariate_mode="concat",
+ )
+ loader = get_dataloader(cls.sample_ds, batch_size=2, shuffle=False)
+ cls.batch = next(iter(loader))
+
+ def test_forward(self):
+ """Model forward pass works with real MIT-BIH samples."""
+ out = self.model(**self.batch)
+ self.assertIn("logit", out)
+ self.assertIn("loss", out)
+ self.assertTrue(out["loss"].isfinite())
+
+ def test_backward(self):
+ """Backward pass works with MIT-BIH data."""
+ out = self.model(**self.batch)
+ out["loss"].backward()
+
+
+# ------------------------------------------------------------------ #
+# Phase 6: preprocess=True -> .npz cache for MIT-BIH
+# ------------------------------------------------------------------ #
+
+
+@unittest.skipUnless(HAS_TEST_DATA, "MIT-BIH test data not found")
+class TestMITBIHPreprocessCache(unittest.TestCase):
+ """preprocess=True writes per-record .npz with signal + annotations."""
+
+ @classmethod
+ def setUpClass(cls):
+ import shutil
+
+ from pyhealth.datasets.mitbih import MITBIHDataset
+
+ processed_dir = os.path.join(TEST_DATA_DIR, "processed")
+ if os.path.isdir(processed_dir):
+ shutil.rmtree(processed_dir)
+ _force_regenerate_mitbih_metadata()
+
+ cls.dataset = MITBIHDataset(
+ root=TEST_DATA_DIR,
+ dev=True,
+ preprocess=True,
+ downsample_factor=3,
+ trim=True,
+ )
+
+ def test_processed_dir_created(self):
+ self.assertTrue(
+ os.path.isdir(os.path.join(TEST_DATA_DIR, "processed"))
+ )
+
+ def test_event_has_processed_file_attr(self):
+ pid = self.dataset.unique_patient_ids[0]
+ event = self.dataset.get_patient(pid).get_events()[0]
+ self.assertTrue(hasattr(event, "processed_file"))
+ self.assertTrue(event.processed_file.endswith(".npz"))
+ self.assertTrue(os.path.exists(event.processed_file))
+
+ def test_cached_arrays_have_expected_keys(self):
+ pid = self.dataset.unique_patient_ids[0]
+ event = self.dataset.get_patient(pid).get_events()[0]
+ with np.load(event.processed_file, allow_pickle=False) as npz:
+ self.assertIn("signal", npz.files)
+ self.assertIn("ann_sample", npz.files)
+ self.assertIn("ann_symbol", npz.files)
+
+ def test_trim_applied_to_cache(self):
+ """After trim, first and last ann_sample are within [0, len-1]."""
+ pid = self.dataset.unique_patient_ids[0]
+ event = self.dataset.get_patient(pid).get_events()[0]
+ with np.load(event.processed_file, allow_pickle=False) as npz:
+ signal_len = int(npz["signal"].shape[0])
+ ann_sample = np.asarray(npz["ann_sample"])
+ self.assertGreaterEqual(int(ann_sample[0]), 0)
+ self.assertLess(int(ann_sample[-1]), signal_len)
+ # Trim: first annotation should be at or near sample 0.
+ self.assertLessEqual(int(ann_sample[0]), 5)
+
+ def test_boundary_task_uses_cache(self):
+ from pyhealth.tasks.ecg_boundary_detection import (
+ ECGBoundaryDetection,
+ )
+
+ pid = self.dataset.unique_patient_ids[0]
+ patient = self.dataset.get_patient(pid)
+ samples = ECGBoundaryDetection(
+ window_size=128, step_size=128
+ )(patient)
+ self.assertGreater(len(samples), 0)
+ self.assertEqual(samples[0]["signal"].shape, (128, 2))
+
+ def test_anomaly_task_uses_cache(self):
+ from pyhealth.tasks.ecg_anomaly_detection import (
+ ECGAnomalyDetection,
+ )
+
+ pid = self.dataset.unique_patient_ids[0]
+ patient = self.dataset.get_patient(pid)
+ samples = ECGAnomalyDetection(
+ window_size=128, step_size=128
+ )(patient)
+ self.assertGreater(len(samples), 0)
+ self.assertEqual(samples[0]["signal"].shape, (128, 2))
+
+
+@unittest.skipUnless(HAS_TEST_DATA, "MIT-BIH test data not found")
+class TestMITBIHPreprocessTrimOff(unittest.TestCase):
+ """trim=False caches the full downsampled signal."""
+
+ def test_untrimmed_cache_is_longer(self):
+ import shutil
+
+ from pyhealth.datasets.mitbih import MITBIHDataset
+
+ processed_dir = os.path.join(TEST_DATA_DIR, "processed")
+
+ if os.path.isdir(processed_dir):
+ shutil.rmtree(processed_dir)
+ _force_regenerate_mitbih_metadata()
+ ds_trim = MITBIHDataset(
+ root=TEST_DATA_DIR,
+ dev=True,
+ preprocess=True,
+ trim=True,
+ )
+ ev_trim = ds_trim.get_patient(
+ ds_trim.unique_patient_ids[0]
+ ).get_events()[0]
+ with np.load(ev_trim.processed_file, allow_pickle=False) as npz:
+ trimmed_len = int(npz["signal"].shape[0])
+
+ if os.path.isdir(processed_dir):
+ shutil.rmtree(processed_dir)
+ _force_regenerate_mitbih_metadata()
+ ds_full = MITBIHDataset(
+ root=TEST_DATA_DIR,
+ dev=True,
+ preprocess=True,
+ trim=False,
+ )
+ ev_full = ds_full.get_patient(
+ ds_full.unique_patient_ids[0]
+ ).get_events()[0]
+ with np.load(ev_full.processed_file, allow_pickle=False) as npz:
+ full_len = int(npz["signal"].shape[0])
+
+ self.assertGreaterEqual(full_len, trimmed_len)
+
+
+@unittest.skipUnless(HAS_TEST_DATA, "MIT-BIH test data not found")
+class TestMITBIHPreprocessDisabled(unittest.TestCase):
+ """preprocess=False leaves processed_file blank."""
+
+ def test_processed_file_empty_when_disabled(self):
+ import shutil
+
+ from pyhealth.datasets.mitbih import MITBIHDataset
+
+ processed_dir = os.path.join(TEST_DATA_DIR, "processed")
+ if os.path.isdir(processed_dir):
+ shutil.rmtree(processed_dir)
+ _force_regenerate_mitbih_metadata()
+
+ ds = MITBIHDataset(root=TEST_DATA_DIR, dev=True, preprocess=False)
+ pid = ds.unique_patient_ids[0]
+ event = ds.get_patient(pid).get_events()[0]
+ value = getattr(event, "processed_file", "") or ""
+ self.assertEqual(str(value).strip(), "")
+
+
+@unittest.skipUnless(HAS_TEST_DATA, "MIT-BIH test data not found")
+class TestMITBIHInvalidDownsample(unittest.TestCase):
+ """downsample_factor < 1 raises ValueError."""
+
+ def test_invalid_downsample_raises(self):
+ from pyhealth.datasets.mitbih import MITBIHDataset
+
+ _force_regenerate_mitbih_metadata()
+ with self.assertRaises(ValueError):
+ MITBIHDataset(root=TEST_DATA_DIR, dev=True, downsample_factor=0)
+
+
+# ------------------------------------------------------------------ #
+# Phase 5: paper-match split column (MIT-BIH 80/20 seed=0)
+# ------------------------------------------------------------------ #
+
+
+def _reset_mitbih_csv():
+ """Delete the committed metadata CSV so the next ``MITBIHDataset``
+ rebuilds it with the current ``prepare_metadata`` implementation.
+
+ Leaves the PyHealth cache warm — safe when the next dataset uses
+ the same fingerprint as a prior class. For config changes (e.g.
+ paper_split), use ``_force_regenerate_mitbih_metadata`` instead.
+ """
+ csv_path = os.path.join(TEST_DATA_DIR, "mitbih-pyhealth.csv")
+ if os.path.exists(csv_path):
+ os.remove(csv_path)
+
+
+def _force_regenerate_mitbih_metadata():
+ """Nuclear reset: CSV + entire PyHealth cache. Use in setUpModule
+ once per file, or when a test wipes ``processed_dir`` / changes
+ config and needs the event-dataframe cache to be rebuilt.
+ """
+ import shutil
+
+ import platformdirs
+
+ _reset_mitbih_csv()
+ cache_root = platformdirs.user_cache_dir(appname="pyhealth")
+ if os.path.isdir(cache_root):
+ shutil.rmtree(cache_root)
+
+
+@unittest.skipUnless(HAS_TEST_DATA, "MIT-BIH test data not found")
+class TestMITBIHPaperSplitRandom(unittest.TestCase):
+ """paper_split='random' writes an 80/20 seed=0 split column."""
+
+ @classmethod
+ def setUpClass(cls):
+ from pyhealth.datasets.mitbih import MITBIHDataset
+
+ _force_regenerate_mitbih_metadata()
+ cls.dataset = MITBIHDataset(
+ root=TEST_DATA_DIR, dev=True, paper_split="random"
+ )
+
+ def test_constructor_accepts_paper_split(self):
+ """MITBIHDataset accepts a paper_split kwarg."""
+ import inspect
+
+ from pyhealth.datasets.mitbih import MITBIHDataset
+
+ sig = inspect.signature(MITBIHDataset.__init__)
+ self.assertIn("paper_split", sig.parameters)
+
+ def test_csv_has_split_column(self):
+ """Regenerated CSV contains a ``split`` column."""
+ import pandas as pd
+
+ csv_path = os.path.join(TEST_DATA_DIR, "mitbih-pyhealth.csv")
+ df = pd.read_csv(csv_path)
+ self.assertIn("split", df.columns)
+
+ def test_split_values_are_train_or_test(self):
+ """Every split value is either 'train' or 'test'."""
+ import pandas as pd
+
+ csv_path = os.path.join(TEST_DATA_DIR, "mitbih-pyhealth.csv")
+ df = pd.read_csv(csv_path)
+ unique = set(df["split"].unique())
+ self.assertTrue(unique.issubset({"train", "test"}))
+
+ def test_event_has_split_attr(self):
+ """Events expose a ``split`` attribute."""
+ pid = self.dataset.unique_patient_ids[0]
+ patient = self.dataset.get_patient(pid)
+ event = patient.get_events()[0]
+ self.assertTrue(hasattr(event, "split"))
+ self.assertIn(event.split, ("train", "test"))
+
+ def test_boundary_task_emits_split_field(self):
+ """ECGBoundaryDetection propagates event.split to each sample."""
+ from pyhealth.tasks.ecg_boundary_detection import (
+ ECGBoundaryDetection,
+ )
+
+ pid = self.dataset.unique_patient_ids[0]
+ patient = self.dataset.get_patient(pid)
+ task = ECGBoundaryDetection(window_size=256, step_size=256)
+ sample = task(patient)[0]
+ self.assertIn("split", sample)
+ self.assertIn(sample["split"], ("train", "test"))
+
+ def test_anomaly_task_emits_split_field(self):
+ """ECGAnomalyDetection propagates event.split to each sample."""
+ from pyhealth.tasks.ecg_anomaly_detection import (
+ ECGAnomalyDetection,
+ )
+
+ pid = self.dataset.unique_patient_ids[0]
+ patient = self.dataset.get_patient(pid)
+ task = ECGAnomalyDetection(window_size=128, step_size=128)
+ sample = task(patient)[0]
+ self.assertIn("split", sample)
+ self.assertIn(sample["split"], ("train", "test"))
+
+
+@unittest.skipUnless(HAS_TEST_DATA, "MIT-BIH test data not found")
+class TestMITBIHPaperSplitAbnormalSorted(unittest.TestCase):
+ """paper_split='abnormal_sorted' orders patients by n_abnormal asc."""
+
+ @classmethod
+ def setUpClass(cls):
+ from pyhealth.datasets.mitbih import MITBIHDataset
+
+ _force_regenerate_mitbih_metadata()
+ cls.dataset = MITBIHDataset(
+ root=TEST_DATA_DIR, dev=True, paper_split="abnormal_sorted"
+ )
+
+ def test_csv_has_split_column(self):
+ """Regenerated CSV contains a populated ``split`` column."""
+ import pandas as pd
+
+ csv_path = os.path.join(TEST_DATA_DIR, "mitbih-pyhealth.csv")
+ df = pd.read_csv(csv_path)
+ self.assertIn("split", df.columns)
+ unique = set(df["split"].unique())
+ self.assertTrue(unique.issubset({"train", "test"}))
+
+ def test_train_has_lower_or_equal_n_abnormal(self):
+ """All train patients have n_abnormal <= all test patients."""
+ import pandas as pd
+
+ csv_path = os.path.join(TEST_DATA_DIR, "mitbih-pyhealth.csv")
+ df = pd.read_csv(csv_path)
+ train = df[df["split"] == "train"]["n_abnormal"]
+ test = df[df["split"] == "test"]["n_abnormal"]
+ if len(train) > 0 and len(test) > 0:
+ self.assertLessEqual(train.max(), test.min())
+
+
+@unittest.skipUnless(HAS_TEST_DATA, "MIT-BIH test data not found")
+class TestMITBIHPaperSplitDisabled(unittest.TestCase):
+ """paper_split=None leaves the split column blank."""
+
+ def test_split_column_blank_when_disabled(self):
+ """Without paper_split, split column is empty."""
+ import pandas as pd
+ from pyhealth.datasets.mitbih import MITBIHDataset
+
+ _force_regenerate_mitbih_metadata()
+ MITBIHDataset(root=TEST_DATA_DIR, dev=True, paper_split=None)
+ df = pd.read_csv(os.path.join(TEST_DATA_DIR, "mitbih-pyhealth.csv"))
+ self.assertIn("split", df.columns)
+ cleaned = df["split"].fillna("").astype(str).str.strip()
+ self.assertTrue((cleaned == "").all())
+
+
+@unittest.skipUnless(HAS_TEST_DATA, "MIT-BIH test data not found")
+class TestMITBIHPaperSplitInvalid(unittest.TestCase):
+ """Unknown paper_split mode raises ValueError."""
+
+ def test_invalid_mode_raises(self):
+ from pyhealth.datasets.mitbih import MITBIHDataset
+
+ _force_regenerate_mitbih_metadata()
+ with self.assertRaises(ValueError):
+ MITBIHDataset(
+ root=TEST_DATA_DIR, dev=True, paper_split="nonsense"
+ )
+
+
+# Allow running this test file directly with ``python <file>``; test
+# runners (pytest/unittest discovery) import the module instead.
+if __name__ == "__main__":
+    unittest.main()