Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -245,3 +245,4 @@ Available Datasets
datasets/pyhealth.datasets.TCGAPRADDataset
datasets/pyhealth.datasets.splitter
datasets/pyhealth.datasets.utils
datasets/pyhealth.datasets.ptbxl
7 changes: 7 additions & 0 deletions docs/api/datasets/pyhealth.datasets.ptbxl.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
pyhealth.datasets.ptbxl
=======================

.. autoclass:: pyhealth.datasets.PTBXLDataset
:members:
:undoc-members:
:show-inheritance:
1 change: 1 addition & 0 deletions docs/api/tasks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -229,3 +229,4 @@ Available Tasks
Mutation Pathogenicity (COSMIC) <tasks/pyhealth.tasks.MutationPathogenicityPrediction>
Cancer Survival Prediction (TCGA) <tasks/pyhealth.tasks.CancerSurvivalPrediction>
Cancer Mutation Burden (TCGA) <tasks/pyhealth.tasks.CancerMutationBurden>
PTB-XL MI Classification <tasks/pyhealth.tasks.ptbxl_mi_classification>
7 changes: 7 additions & 0 deletions docs/api/tasks/pyhealth.tasks.ptbxl_mi_classification.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
pyhealth.tasks.ptbxl_mi_classification
======================================

.. autoclass:: pyhealth.tasks.PTBXLMIClassificationTask
:members:
:undoc-members:
:show-inheritance:
113 changes: 113 additions & 0 deletions examples/ptbxl_mi_classification_cnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from pyhealth.datasets import PTBXLDataset
from pyhealth.tasks import PTBXLMIClassificationTask
import os
import argparse


def run_once(root: str, normalize: bool):
    """Load the PTB-XL MI task once and print basic sample statistics.

    When the real dataset is missing under ``root``, a tiny synthetic
    torch dataset is built instead so the script still demonstrates the
    expected tensor shapes.

    Args:
        root: PTB-XL root directory (must contain ptbxl_database.csv).
        normalize: Whether per-channel z-score normalization is enabled.
    """
    metadata_file = os.path.join(root, "ptbxl_database.csv")

    if not os.path.exists(metadata_file):
        print("=" * 60)
        print("PTB-XL dataset not found. Running synthetic demo mode.")
        print(f"Expected file: {metadata_file}")

        import torch
        from torch.utils.data import DataLoader, TensorDataset

        # Fake 12-lead ECGs: 10 records, 1000 samples each.
        fake_signals = torch.randn(10, 12, 1000)
        fake_labels = torch.randint(0, 2, (10,)).float()

        demo_dataset = TensorDataset(fake_signals, fake_labels)
        demo_loader = DataLoader(demo_dataset, batch_size=2, shuffle=False)

        demo_signal, demo_label = next(iter(demo_loader))

        print(f"normalize={normalize}")
        print(f"demo batch signal shape: {demo_signal.shape}")
        print(f"demo batch labels: {demo_label}")
        print(f"demo signal mean/std: {demo_signal.mean():.4f} / {demo_signal.std():.4f}")
        print(f"Number of demo samples: {len(demo_dataset)}")
        return

    dataset = PTBXLDataset(
        root=root,
        dev=True,
        use_high_resolution=False,  # False -> records100, True -> records500
    )

    task = PTBXLMIClassificationTask(
        root=root,
        signal_length=1000,  # 10 seconds at 100 Hz
        normalize=normalize,
    )

    task_dataset = dataset.set_task(task)

    first_sample = task_dataset[0]
    raw_signal = first_sample["signal"]

    # The stored signal may be a torch tensor or a plain list; fall back
    # to the raw value when tensor conversion is not applicable.
    try:
        signal_np = raw_signal.detach().cpu().numpy()
    except Exception:
        signal_np = raw_signal

    print("=" * 60)
    print(f"normalize={normalize}")
    print(f"sample label: {first_sample['label']}")
    print(f"signal shape: {signal_np.shape}")
    print(f"signal mean/std: {signal_np.mean():.4f} / {signal_np.std():.4f}")
    print(f"Number of samples: {len(task_dataset)}")


def main():
    """Parse CLI options and run the PTB-XL example once or as an ablation."""
    # Default root: PTBXL_ROOT env var, else the conventional download path.
    default_root = os.getenv(
        "PTBXL_ROOT",
        os.path.expanduser(
            "~/Downloads/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3"
        ),
    )

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--root",
        type=str,
        default=default_root,
        help=(
            "Path to PTB-XL root folder (contains ptbxl_database.csv, "
            "scp_statements.csv, records100/records500/). "
            "You can also set PTBXL_ROOT environment variable instead of passing --root."
        ),
    )
    # --normalize / --no-normalize share one dest; normalization defaults on.
    parser.add_argument(
        "--normalize",
        dest="normalize",
        action="store_true",
        default=True,
        help="Enable per-channel z-score normalization (default: True).",
    )
    parser.add_argument(
        "--no-normalize",
        dest="normalize",
        action="store_false",
        help="Disable normalization.",
    )
    parser.add_argument(
        "--ablation-normalize",
        action="store_true",
        help="Run a tiny ablation: compare normalize=True vs normalize=False.",
    )

    opts = parser.parse_args()

    if opts.ablation_normalize:
        # Run both settings back to back so their stats can be compared.
        run_once(opts.root, normalize=True)
        run_once(opts.root, normalize=False)
    else:
        run_once(opts.root, normalize=opts.normalize)


if __name__ == "__main__":
    main()
1 change: 1 addition & 0 deletions pyhealth/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,4 @@ def __init__(self, *args, **kwargs):
save_processors,
)
from .collate import collate_temporal
from .ptbxl import PTBXLDataset
4 changes: 4 additions & 0 deletions pyhealth/datasets/configs/ptbxl.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
dataset_name: PTBXL
tables:
- ptbxl
root: data/ptb-xl
49 changes: 49 additions & 0 deletions pyhealth/datasets/ptbxl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import os
from typing import Optional

import dask.dataframe as dd
import pandas as pd

from pyhealth.datasets import BaseDataset


class PTBXLDataset(BaseDataset):
    """PTB-XL ECG dataset represented as an event table.

    Each row of ``ptbxl_database.csv`` becomes one ``ptbxl`` event
    carrying the ECG id, the WFDB record path, and the raw SCP codes.
    """

    def __init__(
        self,
        root: str,
        dataset_name: Optional[str] = "PTBXL",
        dev: bool = False,
        cache_dir: Optional[str] = None,
        num_workers: int = 1,
        use_high_resolution: bool = False,
    ):
        """Initialize the PTB-XL dataset.

        Args:
            root: Folder containing ``ptbxl_database.csv`` and the
                ``records100``/``records500`` directories.
            dataset_name: Display name for the dataset.
            dev: Development-mode flag forwarded to the base class.
            cache_dir: Optional cache directory forwarded to the base class.
            num_workers: Worker count forwarded to the base class.
            use_high_resolution: If True, use the 500 Hz records
                (``filename_hr``); otherwise the 100 Hz records
                (``filename_lr``).
        """
        # Set before super().__init__(), which may trigger load_data().
        self.use_high_resolution = use_high_resolution
        super().__init__(
            root=root,
            tables=["ptbxl"],
            dataset_name=dataset_name,
            cache_dir=cache_dir,
            num_workers=num_workers,
            dev=dev,
        )

    def load_data(self) -> dd.DataFrame:
        """Read the PTB-XL metadata CSV into a single-partition Dask frame.

        Returns:
            A Dask DataFrame with one ``ptbxl`` event per ECG record.
            NOTE(review): ``dev`` is not applied here — presumably the
            base class subsamples; confirm.
        """
        metadata = pd.read_csv(os.path.join(self.root, "ptbxl_database.csv"))

        path_column = "filename_hr" if self.use_high_resolution else "filename_lr"

        events = pd.DataFrame(
            {
                "patient_id": metadata["patient_id"].astype(str),
                "event_type": "ptbxl",
                "timestamp": pd.NaT,
                "ptbxl/ecg_id": metadata["ecg_id"],
                "ptbxl/record_path": metadata[path_column],
                "ptbxl/scp_codes": metadata["scp_codes"],
            }
        )

        return dd.from_pandas(events, npartitions=1)
1 change: 1 addition & 0 deletions pyhealth/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,4 @@
VariantClassificationClinVar,
)
from .patient_linkage_mimic3 import PatientLinkageMIMIC3Task
from .ptbxl_mi_classification import PTBXLMIClassificationTask
135 changes: 135 additions & 0 deletions pyhealth/tasks/ptbxl_mi_classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
"""PTBXL MI classification task for PyHealth.

This module defines a task that loads PTB-XL ECG records, maps SCP
diagnostic codes to myocardial infarction (MI) labels, and returns one
binary-labeled sample per record.
"""

import ast
import os
from typing import Dict, List

import numpy as np
import pandas as pd
import wfdb

from pyhealth.tasks import BaseTask


class PTBXLMIClassificationTask(BaseTask):
    """Task for classifying myocardial infarction (MI) in PTB-XL ECG records.

    This task converts the PTB-XL SCP diagnostic codes into a binary MI label
    and loads the corresponding ECG signal for each record.

    Attributes:
        task_name (str): The name of the task.
        input_schema (Dict[str, str]): Input schema mapping signal to tensor.
        output_schema (Dict[str, str]): Output schema mapping label to binary.
    """

    task_name = "ptbxl_mi_classification"
    input_schema = {
        "signal": "tensor",
    }
    output_schema = {
        "label": "binary",
    }

    def __init__(
        self,
        root: str,
        signal_length: int = 1000,
        normalize: bool = True,
    ):
        """Initialize the PTBXL MI classification task.

        Args:
            root: PTB-XL dataset root directory containing `scp_statements.csv`.
            signal_length: Number of samples to use for each ECG signal.
            normalize: Whether to z-score normalize each ECG channel.
        """

        self.root = root
        self.signal_length = signal_length
        self.normalize = normalize

        # Codes whose diagnostic_class is "MI" define the positive label.
        scp_path = os.path.join(self.root, "scp_statements.csv")
        scp_df = pd.read_csv(scp_path, index_col=0)
        self.mi_codes = set(
            scp_df[scp_df["diagnostic_class"] == "MI"].index.astype(str).tolist()
        )

    @staticmethod
    def _parse_scp_codes(raw_label) -> Dict:
        """Parse one `scp_codes` cell into a dict of code -> likelihood.

        Returns an empty dict for any malformed value: an unparsable
        string literal, a pandas missing value (NaN float), or any other
        non-mapping payload. Previously a NaN cell crashed with
        `AttributeError` on `.keys()`.
        """
        if isinstance(raw_label, str):
            try:
                parsed = ast.literal_eval(raw_label)
            except (ValueError, SyntaxError):
                return {}
        else:
            parsed = raw_label
        # Guard: only a dict is usable downstream.
        return parsed if isinstance(parsed, dict) else {}

    def _load_ecg_signal(self, record_rel_path: str) -> np.ndarray:
        """Loads a PTB-XL WFDB record and returns shape (12, signal_length).

        The signal is optionally z-score normalized per channel, then
        truncated or zero-padded along the time axis to `signal_length`.
        """
        record_path = os.path.join(self.root, record_rel_path)

        # WFDB expects the record path without file extension.
        signal, _ = wfdb.rdsamp(record_path)

        # rdsamp returns shape (num_samples, num_channels)
        signal = signal.T.astype(np.float32)  # -> (channels, time)

        if self.normalize:
            mean = signal.mean(axis=1, keepdims=True)
            std = signal.std(axis=1, keepdims=True)
            # Avoid division by ~0 on flat channels.
            std = np.where(std < 1e-6, 1.0, std)
            signal = (signal - mean) / std

        current_len = signal.shape[1]
        if current_len >= self.signal_length:
            signal = signal[:, : self.signal_length]
        else:
            pad_width = self.signal_length - current_len
            signal = np.pad(signal, ((0, 0), (0, pad_width)), mode="constant")

        return signal

    def __call__(self, patient) -> List[Dict]:
        """Generate PTB-XL MI samples from a patient record.

        Args:
            patient: Patient object containing PTB-XL event data.

        Returns:
            A list of sample dictionaries with keys:
                - patient_id
                - visit_id
                - record_id
                - signal
                - label
        """

        samples = []

        rows = patient.data_source.to_dicts()

        for idx, row in enumerate(rows):
            raw_label = row["ptbxl/scp_codes"]
            record_rel_path = row["ptbxl/record_path"]

            # Robust parsing: malformed / missing scp_codes -> label 0.
            scp_codes = self._parse_scp_codes(raw_label)

            label = 1 if any(code in self.mi_codes for code in scp_codes) else 0
            signal = self._load_ecg_signal(record_rel_path)

            visit_id = str(row["ptbxl/ecg_id"])

            samples.append(
                {
                    "patient_id": patient.patient_id,
                    "visit_id": visit_id,
                    "record_id": idx + 1,  # 1-based record index within patient
                    "signal": signal.tolist(),
                    "label": label,
                }
            )

        return samples
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ dependencies = [
"more-itertools~=10.8.0",
"einops>=0.8.0",
"linear-attention-transformer>=0.19.1",
"wfdb>=4.0.0"
]
license = "BSD-3-Clause"
license-files = ["LICENSE.md"]
Expand Down
36 changes: 36 additions & 0 deletions tests/core/test_ptbxl_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import os
import tempfile
import unittest

from pyhealth.datasets import PTBXLDataset


class TestPTBXLDataset(unittest.TestCase):
    """Unit tests for the PTBXLDataset event-table loader."""

    def test_load_data_dev_mode(self):
        """load_data should emit one ptbxl event row per CSV record."""
        with tempfile.TemporaryDirectory() as tmpdir:
            csv_path = os.path.join(tmpdir, "ptbxl_database.csv")

            csv_lines = [
                "ecg_id,patient_id,filename_lr,filename_hr,scp_codes\n",
                '1,100,records100/00000/00001_lr,records500/00000/00001_hr,"{\'MI\': 1}"\n',
                '2,101,records100/00000/00002_lr,records500/00000/00002_hr,"{\'NORM\': 1}"\n',
            ]
            with open(csv_path, "w") as f:
                f.writelines(csv_lines)

            dataset = PTBXLDataset(root=tmpdir, dev=True)

            df = dataset.load_data().compute()

            self.assertEqual(len(df), 2)
            # Every event-table column must be present.
            for column in (
                "patient_id",
                "event_type",
                "ptbxl/ecg_id",
                "ptbxl/record_path",
                "ptbxl/scp_codes",
            ):
                self.assertIn(column, df.columns)
            self.assertEqual(str(df.iloc[0]["patient_id"]), "100")
            self.assertEqual(df.iloc[0]["event_type"], "ptbxl")


if __name__ == "__main__":
    unittest.main()
Loading