Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/tasks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -230,3 +230,4 @@ Available Tasks
Mutation Pathogenicity (COSMIC) <tasks/pyhealth.tasks.MutationPathogenicityPrediction>
Cancer Survival Prediction (TCGA) <tasks/pyhealth.tasks.CancerSurvivalPrediction>
Cancer Mutation Burden (TCGA) <tasks/pyhealth.tasks.CancerMutationBurden>
Temporal Mortality Prediction (eICU) <tasks/pyhealth.tasks.temporal_mortality>
147 changes: 147 additions & 0 deletions examples/eicu_temporal_mortality_rnn_vs_mlp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import random
from typing import Dict

import numpy as np
import torch
from pyhealth.datasets import eICUDataset, get_dataloader, split_by_sample
from pyhealth.models import MLP, RNN
from pyhealth.trainer import Trainer
from pyhealth.tasks.temporal_mortality import TemporalMortalityPredictionEICU

# Relative path to the eICU demo dataset directory (the directory name carries
# the 2.0.1 release tag). main() passes this to load_sample_dataset() as the
# dataset root.
DATA_ROOT = r"../dataset/eicu-collaborative-research-database-demo-2.0.1"


def set_seed(seed: int = 42) -> None:
    """Seed every random number generator this script relies on.

    Args:
        seed: Value applied to Python's ``random`` module, NumPy, and PyTorch;
            when CUDA is available it is also applied to all CUDA devices.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def load_sample_dataset(root: str):
    """Build the task-specific sample dataset from the eICU files under ``root``.

    Args:
        root: Directory handed to ``eICUDataset`` as its ``root`` argument.

    Returns:
        Whatever ``eICUDataset.set_task`` returns for a
        ``TemporalMortalityPredictionEICU`` task instance.
    """
    eicu_tables = ["diagnosis", "medication", "physicalexam"]
    base_dataset = eICUDataset(root=root, tables=eicu_tables, dev=True)
    return base_dataset.set_task(TemporalMortalityPredictionEICU())


def build_dataloaders(sample_dataset, batch_size: int = 8):
    """Split ``sample_dataset`` 60/20/20 and wrap each part in a dataloader.

    The size of every split is printed before the loaders are created.

    Args:
        sample_dataset: Dataset passed to ``split_by_sample``.
        batch_size: Batch size used for all three dataloaders.

    Returns:
        A ``(train_loader, val_loader, test_loader)`` tuple. Only the training
        loader is created with ``shuffle=True``.
    """
    train_ds, val_ds, test_ds = split_by_sample(sample_dataset, [0.6, 0.2, 0.2])
    named_splits = {"train": train_ds, "val": val_ds, "test": test_ds}

    for split_name, subset in named_splits.items():
        print(f"{split_name}: {len(subset)}")

    return tuple(
        get_dataloader(
            subset,
            batch_size=batch_size,
            shuffle=(split_name == "train"),
        )
        for split_name, subset in named_splits.items()
    )


def train_and_evaluate(model, train_loader, val_loader, test_loader, epochs: int = 10) -> Dict[str, float]:
    """Fit ``model`` with a ``Trainer`` and score it on the test loader.

    Args:
        model: Model instance handed to ``Trainer``.
        train_loader: Dataloader used for training.
        val_loader: Dataloader used for validation; ``roc_auc`` is the
            monitored metric (criterion ``"max"``).
        test_loader: Dataloader passed to ``Trainer.evaluate`` after training.
        epochs: Number of training epochs.

    Returns:
        The value returned by ``trainer.evaluate(test_loader)``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    trainer = Trainer(
        model=model,
        metrics=["pr_auc", "roc_auc", "f1"],
        device=device,
    )

    training_options = {
        "train_dataloader": train_loader,
        "val_dataloader": val_loader,
        "epochs": epochs,
        "optimizer_class": torch.optim.Adam,
        "optimizer_params": {"lr": 0.001},
        "monitor": "roc_auc",
        "monitor_criterion": "max",
        "load_best_model_at_last": True,
    }
    trainer.train(**training_options)

    return trainer.evaluate(test_loader)


def describe_temporal_fields(sample_dataset) -> None:
    """Print a summary of the temporal metadata found on the processed samples.

    Three things are printed: the keys of the first sample (or a placeholder
    when the dataset is empty), the sorted unique valid ``discharge_year``
    values, and the number of samples per ``split_group``.

    Args:
        sample_dataset: Sized, integer-indexable collection of dict-like
            samples (each sample must support ``.get``).
    """
    num_samples = len(sample_dataset)

    # Guard the empty case: indexing element 0 of an empty dataset would raise
    # IndexError before any summary could be printed.
    if num_samples:
        print("sample keys:", sample_dataset[0].keys())
    else:
        print("sample keys: <dataset is empty>")

    years = set()
    groups = {"early": 0, "late": 0}
    # Index-based access is kept on purpose: only len() and [] are relied on.
    for i in range(num_samples):
        sample = sample_dataset[i]

        year = sample.get("discharge_year")
        # Scalar containers exposing .item() (e.g. tensors) are unwrapped.
        if hasattr(year, "item"):
            year = int(year.item())
        # -1 is the sentinel for a missing/invalid year.
        if isinstance(year, int) and year != -1:
            years.add(year)

        group = sample.get("split_group")
        if group is not None:
            group_key = str(group)
            groups[group_key] = groups.get(group_key, 0) + 1

    if years:
        print("unique discharge years:", sorted(years))
    else:
        print("no valid discharge_year values found in processed samples")

    print("split_group counts:", groups)


def main() -> None:
    """Run the RNN-vs-MLP baseline comparison on the temporal mortality task.

    Seeds the RNGs, loads and describes the sample dataset, builds the
    dataloaders once, then trains and evaluates an RNN followed by an MLP on
    the same loaders and prints both result sets.
    """
    set_seed(42)

    sample_dataset = load_sample_dataset(DATA_ROOT)
    print(sample_dataset)
    print(sample_dataset[0])
    print("num samples:", len(sample_dataset))
    describe_temporal_fields(sample_dataset)

    loaders = build_dataloaders(sample_dataset=sample_dataset, batch_size=8)
    train_loader, val_loader, test_loader = loaders

    # Both baselines share the same loaders and epoch budget.
    shared_run_options = {
        "train_loader": train_loader,
        "val_loader": val_loader,
        "test_loader": test_loader,
        "epochs": 10,
    }

    print("\n=== Training RNN baseline ===")
    rnn_model = RNN(dataset=sample_dataset, embedding_dim=64, hidden_dim=64)
    rnn_results = train_and_evaluate(model=rnn_model, **shared_run_options)
    print("RNN results:", rnn_results)

    print("\n=== Training MLP baseline ===")
    mlp_model = MLP(dataset=sample_dataset, embedding_dim=64)
    mlp_results = train_and_evaluate(model=mlp_model, **shared_run_options)
    print("MLP results:", mlp_results)

    print("\n=== Summary ===")
    task_label = "TemporalMortalityPredictionEICU"
    for model_label, metrics in (("RNN", rnn_results), ("MLP", mlp_results)):
        print({"task": task_label, "model": model_label, **metrics})


# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
1 change: 1 addition & 0 deletions pyhealth/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
MultimodalMortalityPredictionMIMIC3,
MultimodalMortalityPredictionMIMIC4,
)
from .temporal_mortality import TemporalMortalityPredictionEICU
from .survival_preprocess_support2 import SurvivalPreprocessSupport2
from .mortality_prediction_stagenet_mimic4 import (
MortalityPredictionStageNetMIMIC4,
Expand Down
123 changes: 123 additions & 0 deletions pyhealth/tasks/temporal_mortality.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from typing import Any, Dict, List, Optional

from .base_task import BaseTask


class TemporalMortalityPredictionEICU(BaseTask):
    """Task for temporal mortality prediction using the eICU dataset.

    This task predicts whether the patient will die in the *next* hospital stay
    based on diagnoses, physical exams, and medications from the current ICU stay.

    It extends the standard eICU mortality task with temporal metadata:
    - ``discharge_year`` for coarse calendar-time grouping (``-1`` when the
      year is missing or cannot be parsed)
    - ``stay_order`` for within-patient chronology
    - ``split_group`` (early/late) for simple temporal cohorting

    Each emitted sample contains ``visit_id``, ``patient_id``, ``conditions``
    (diagnosis ``icd9code`` values), ``procedures`` (``physicalexampath``
    values), ``drugs`` (``drugname`` values), ``mortality``,
    ``discharge_year``, ``stay_order`` and ``split_group``.

    Examples:
        >>> from pyhealth.datasets import eICUDataset
        >>> from pyhealth.tasks.temporal_mortality import TemporalMortalityPredictionEICU
        >>> dataset = eICUDataset(
        ...     root="/path/to/eicu-crd/2.0",
        ...     tables=["diagnosis", "medication", "physicalexam"],
        ... )
        >>> task = TemporalMortalityPredictionEICU()
        >>> samples = dataset.set_task(task)
    """

    task_name: str = "TemporalMortalityPredictionEICU"
    input_schema: Dict[str, str] = {
        "conditions": "sequence",
        "procedures": "sequence",
        "drugs": "sequence",
    }
    output_schema: Dict[str, str] = {"mortality": "binary"}

    @staticmethod
    def _normalize_year(value: Optional[Any]) -> int:
        """Converts the provided year value into an integer if possible.

        Args:
            value: Raw year value; may be ``None``, a number, or a string such
                as ``"2015"`` or ``"2015.0"``.

        Returns:
            The year as an ``int`` when it parses and lies in ``[1900, 2100]``,
            otherwise ``-1``.
        """
        if value is None:
            return -1

        year: Optional[int] = None
        try:
            year = int(value)
        except (TypeError, ValueError, OverflowError):
            # int() rejects float-formatted strings ("2015.0") and raises
            # OverflowError for infinities; retry through float() so the
            # former parse and the latter fall through to the sentinel.
            try:
                year = int(float(value))
            except (TypeError, ValueError, OverflowError):
                year = None

        # Basic sanity range for eICU years.
        if year is not None and 1900 <= year <= 2100:
            return year
        return -1

    @staticmethod
    def _collect_codes(events: Any, attribute: str) -> List[Any]:
        """Returns the truthy ``attribute`` values found on ``events``, in order."""
        values = (getattr(event, attribute, None) for event in events)
        return [value for value in values if value]

    def __call__(self, patient: Any) -> List[Dict[str, Any]]:
        """Processes a single patient into temporal mortality samples.

        Every stay except the last one is a candidate: its features come from
        the stay itself and its label from the discharge status of the stay
        that follows it.

        Args:
            patient: Patient object exposing ``get_events`` and ``patient_id``.

        Returns:
            One dict per candidate stay that has at least one condition, one
            physical-exam entry and one drug; an empty list when the patient
            has fewer than two stays.
        """
        # NOTE(review): this relies on get_events returning stays in
        # chronological order -- confirm against the dataset implementation.
        patient_stays = patient.get_events(event_type="patient")
        if len(patient_stays) <= 1:
            return []

        num_candidate_stays = len(patient_stays) - 1
        # The first half of the candidate stays (at least one) is "early".
        early_cutoff = max(1, num_candidate_stays // 2)

        samples: List[Dict[str, Any]] = []
        for i in range(num_candidate_stays):
            stay = patient_stays[i]
            next_stay = patient_stays[i + 1]

            # Only an explicit "Expired" status on the next stay is a positive
            # label; "Alive" and any missing/unknown status are labelled 0.
            discharge_status = getattr(next_stay, "hospitaldischargestatus", None)
            mortality_label = 1 if discharge_status == "Expired" else 0

            stay_id = str(getattr(stay, "patientunitstayid", ""))
            stay_filter = [("patientunitstayid", "==", stay_id)]

            diagnoses = patient.get_events(
                event_type="diagnosis",
                filters=stay_filter,
            )
            physical_exams = patient.get_events(
                event_type="physicalexam",
                filters=stay_filter,
            )
            medications = patient.get_events(
                event_type="medication",
                filters=stay_filter,
            )

            conditions = self._collect_codes(diagnoses, "icd9code")
            procedures_list = self._collect_codes(physical_exams, "physicalexampath")
            drugs = self._collect_codes(medications, "drugname")

            # Skip stays that lack any one of the three feature types.
            if not (conditions and procedures_list and drugs):
                continue

            discharge_year = self._normalize_year(
                getattr(stay, "hospitaldischargeyear", None)
            )
            split_group = "early" if i < early_cutoff else "late"

            samples.append(
                {
                    "visit_id": stay_id,
                    "patient_id": patient.patient_id,
                    "conditions": conditions,
                    "procedures": procedures_list,
                    "drugs": drugs,
                    "mortality": mortality_label,
                    "discharge_year": discharge_year,
                    "stay_order": i,
                    "split_group": split_group,
                }
            )

        return samples
Loading