diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 23a4e06e5..eb76599aa 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -230,3 +230,4 @@ Available Tasks Mutation Pathogenicity (COSMIC) Cancer Survival Prediction (TCGA) Cancer Mutation Burden (TCGA) + Temporal Mortality Prediction (eICU) diff --git a/examples/eicu_temporal_mortality_rnn_vs_mlp.py b/examples/eicu_temporal_mortality_rnn_vs_mlp.py new file mode 100644 index 000000000..d109a14fd --- /dev/null +++ b/examples/eicu_temporal_mortality_rnn_vs_mlp.py @@ -0,0 +1,147 @@ +import random +from typing import Dict + +import numpy as np +import torch +from pyhealth.datasets import eICUDataset, get_dataloader, split_by_sample +from pyhealth.models import MLP, RNN +from pyhealth.trainer import Trainer +from pyhealth.tasks.temporal_mortality import TemporalMortalityPredictionEICU + +DATA_ROOT = r"../dataset/eicu-collaborative-research-database-demo-2.0.1" + + +def set_seed(seed: int = 42) -> None: + """Sets random seeds for reproducibility.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + +def load_sample_dataset(root: str): + """Loads the eICU demo dataset and applies the custom temporal task.""" + dataset = eICUDataset( + root=root, + tables=["diagnosis", "medication", "physicalexam"], + dev=True, + ) + task = TemporalMortalityPredictionEICU() + sample_dataset = dataset.set_task(task) + return sample_dataset + + +def build_dataloaders(sample_dataset, batch_size: int = 8): + """Splits the dataset and creates train/val/test dataloaders.""" + train_ds, val_ds, test_ds = split_by_sample(sample_dataset, [0.6, 0.2, 0.2]) + + print(f"train: {len(train_ds)}") + print(f"val: {len(val_ds)}") + print(f"test: {len(test_ds)}") + + train_loader = get_dataloader(train_ds, batch_size=batch_size, shuffle=True) + val_loader = get_dataloader(val_ds, batch_size=batch_size, shuffle=False) + test_loader = get_dataloader(test_ds, batch_size=batch_size, shuffle=False) + return train_loader, val_loader, test_loader + + +def train_and_evaluate(model, train_loader, val_loader, test_loader, epochs: int = 10) -> Dict[str, float]: + """Trains a model and returns evaluation metrics on the test set.""" + trainer = Trainer( + model=model, + metrics=["pr_auc", "roc_auc", "f1"], + device="cuda" if torch.cuda.is_available() else "cpu", + ) + + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=epochs, + optimizer_class=torch.optim.Adam, + optimizer_params={"lr": 0.001}, + monitor="roc_auc", + monitor_criterion="max", + load_best_model_at_last=True, + ) + + results = trainer.evaluate(test_loader) + return results + + +def describe_temporal_fields(sample_dataset) -> None: + """Prints temporal metadata present in the processed samples.""" + print("sample keys:", sample_dataset[0].keys()) + years = [] + groups = {"early": 0, "late": 0} + for i in range(len(sample_dataset)): + sample = sample_dataset[i] + year = sample.get("discharge_year", None) + if hasattr(year, "item"): + year = int(year.item()) + if isinstance(year, int) and year != -1: + years.append(year) + + group = sample.get("split_group", None) + if group is not None: + groups[str(group)] = groups.get(str(group), 0) + 1 + + if years: + print("unique discharge years:", sorted(set(years))) + else: + print("no valid discharge_year values found in processed samples") + + print("split_group counts:", groups) + + +def main() -> None: + """Runs the full baseline experiment.""" + set_seed(42) + + sample_dataset = load_sample_dataset(DATA_ROOT) + print(sample_dataset) + print(sample_dataset[0]) + print("num samples:", len(sample_dataset)) + describe_temporal_fields(sample_dataset) + + train_loader, val_loader, test_loader = build_dataloaders( + sample_dataset=sample_dataset, + batch_size=8, + ) + + print("\n=== Training RNN baseline ===") + rnn_model = RNN( + dataset=sample_dataset, + embedding_dim=64, + hidden_dim=64, + ) + rnn_results = train_and_evaluate( + model=rnn_model, + train_loader=train_loader, + val_loader=val_loader, + test_loader=test_loader, + epochs=10, + ) + print("RNN results:", rnn_results) + + print("\n=== Training MLP baseline ===") + mlp_model = MLP( + dataset=sample_dataset, + embedding_dim=64, + ) + mlp_results = train_and_evaluate( + model=mlp_model, + train_loader=train_loader, + val_loader=val_loader, + test_loader=test_loader, + epochs=10, + ) + print("MLP results:", mlp_results) + + print("\n=== Summary ===") + print({"task": "TemporalMortalityPredictionEICU", "model": "RNN", **rnn_results}) + print({"task": "TemporalMortalityPredictionEICU", "model": "MLP", **mlp_results}) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index a32618f9c..dad25e001 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -41,6 +41,7 @@ MultimodalMortalityPredictionMIMIC3, MultimodalMortalityPredictionMIMIC4, ) +from .temporal_mortality import TemporalMortalityPredictionEICU from .survival_preprocess_support2 import SurvivalPreprocessSupport2 from .mortality_prediction_stagenet_mimic4 import ( MortalityPredictionStageNetMIMIC4, diff --git a/pyhealth/tasks/temporal_mortality.py b/pyhealth/tasks/temporal_mortality.py new file mode 100644 index 000000000..70f1f509f --- /dev/null +++ b/pyhealth/tasks/temporal_mortality.py @@ -0,0 +1,123 @@ +from typing import Any, Dict, List, Optional + +from .base_task import BaseTask + + +class TemporalMortalityPredictionEICU(BaseTask): + """Task for temporal mortality prediction using the eICU dataset. + + This task predicts whether the patient will die in the *next* hospital stay + based on diagnoses, physical exams, and medications from the current ICU stay. + + It extends the standard eICU mortality task with temporal metadata: + - ``discharge_year`` for coarse calendar-time grouping + - ``stay_order`` for within-patient chronology + - ``split_group`` (early/late) for simple temporal cohorting + + Examples: + >>> from pyhealth.datasets import eICUDataset + >>> from pyhealth.tasks.temporal_mortality import TemporalMortalityPredictionEICU + >>> dataset = eICUDataset( + ... root="/path/to/eicu-crd/2.0", + ... tables=["diagnosis", "medication", "physicalexam"], + ... ) + >>> task = TemporalMortalityPredictionEICU() + >>> samples = dataset.set_task(task) + """ + + task_name: str = "TemporalMortalityPredictionEICU" + input_schema: Dict[str, str] = { + "conditions": "sequence", + "procedures": "sequence", + "drugs": "sequence", + } + output_schema: Dict[str, str] = {"mortality": "binary"} + + @staticmethod + def _normalize_year(value: Optional[Any]) -> int: + """Converts the provided year value into an integer if possible.""" + if value is None: + return -1 + try: + year = int(value) + # Basic sanity range for eICU years. + if 1900 <= year <= 2100: + return year + except (TypeError, ValueError): + pass + return -1 + + def __call__(self, patient: Any) -> List[Dict[str, Any]]: + """Processes a single patient into temporal mortality samples.""" + samples: List[Dict[str, Any]] = [] + + patient_stays = patient.get_events(event_type="patient") + if len(patient_stays) <= 1: + return [] + + num_candidate_stays = len(patient_stays) - 1 + + for i in range(num_candidate_stays): + stay = patient_stays[i] + next_stay = patient_stays[i + 1] + + discharge_status = getattr(next_stay, "hospitaldischargestatus", None) + if discharge_status not in ["Alive", "Expired"]: + mortality_label = 0 + else: + mortality_label = 0 if discharge_status == "Alive" else 1 + + stay_id = str(getattr(stay, "patientunitstayid", "")) + + diagnoses = patient.get_events( + event_type="diagnosis", + filters=[("patientunitstayid", "==", stay_id)], + ) + physical_exams = patient.get_events( + event_type="physicalexam", + filters=[("patientunitstayid", "==", stay_id)], + ) + medications = patient.get_events( + event_type="medication", + filters=[("patientunitstayid", "==", stay_id)], + ) + + conditions = [ + getattr(event, "icd9code", "") + for event in diagnoses + if getattr(event, "icd9code", None) + ] + procedures_list = [ + getattr(event, "physicalexampath", "") + for event in physical_exams + if getattr(event, "physicalexampath", None) + ] + drugs = [ + getattr(event, "drugname", "") + for event in medications + if getattr(event, "drugname", None) + ] + + if len(conditions) * len(procedures_list) * len(drugs) == 0: + continue + + discharge_year = self._normalize_year( + getattr(stay, "hospitaldischargeyear", None) + ) + split_group = "early" if i < max(1, num_candidate_stays // 2) else "late" + + samples.append( + { + "visit_id": stay_id, + "patient_id": patient.patient_id, + "conditions": conditions, + "procedures": procedures_list, + "drugs": drugs, + "mortality": mortality_label, + "discharge_year": discharge_year, + "stay_order": i, + "split_group": split_group, + } + ) + + return samples \ No newline at end of file diff --git a/tests/test_temporal_mortality.py b/tests/test_temporal_mortality.py new file mode 100644 index 000000000..398f94c4d --- /dev/null +++ b/tests/test_temporal_mortality.py @@ -0,0 +1,195 @@ +""" +Tests for the TemporalMortalityPredictionEICU task. + +These tests use tiny synthetic patient objects and do not require real eICU data. +""" + +from dataclasses import dataclass +from typing import Dict, List, Optional + +from pyhealth.tasks.temporal_mortality import TemporalMortalityPredictionEICU + + +@dataclass +class FakeEvent: + """Simple event container for synthetic tests.""" + attrs: Dict[str, object] + + def __getattr__(self, item): + if item in self.attrs: + return self.attrs[item] + raise AttributeError(item) + + +class FakePatient: + """Minimal patient stub implementing the methods used by the task.""" + + def __init__( + self, + patient_id: str, + patient_events: Optional[List[FakeEvent]] = None, + diagnosis_events: Optional[List[FakeEvent]] = None, + physicalexam_events: Optional[List[FakeEvent]] = None, + medication_events: Optional[List[FakeEvent]] = None, + ) -> None: + self.patient_id = patient_id + self._events = { + "patient": patient_events or [], + "diagnosis": diagnosis_events or [], + "physicalexam": physicalexam_events or [], + "medication": medication_events or [], + } + + def get_events(self, event_type: str, filters=None, return_df: bool = False): + assert return_df is False + events = self._events.get(event_type, []) + if not filters: + return events + + filtered = [] + for event in events: + keep = True + for field, op, value in filters: + event_value = getattr(event, field) + if op == "==" and str(event_value) != str(value): + keep = False + break + if keep: + filtered.append(event) + return filtered + + +def make_valid_patient() -> FakePatient: + """Creates a patient with two stays so one sample can be generated.""" + patient_events = [ + FakeEvent( + { + "patientunitstayid": "stay_1", + "hospitaldischargestatus": "Alive", + "hospitaldischargeyear": 2014, + } + ), + FakeEvent( + { + "patientunitstayid": "stay_2", + "hospitaldischargestatus": "Expired", + "hospitaldischargeyear": 2015, + } + ), + ] + + diagnosis_events = [ + FakeEvent({"patientunitstayid": "stay_1", "icd9code": "038.9"}), + FakeEvent({"patientunitstayid": "stay_2", "icd9code": "518.81"}), + ] + physicalexam_events = [ + FakeEvent({"patientunitstayid": "stay_1", "physicalexampath": "cardiovascular"}), + FakeEvent({"patientunitstayid": "stay_2", "physicalexampath": "pulmonary"}), + ] + medication_events = [ + FakeEvent({"patientunitstayid": "stay_1", "drugname": "vancomycin"}), + FakeEvent({"patientunitstayid": "stay_2", "drugname": "norepinephrine"}), + ] + + return FakePatient( + patient_id="patient_1", + patient_events=patient_events, + diagnosis_events=diagnosis_events, + physicalexam_events=physicalexam_events, + medication_events=medication_events, + ) + + +def test_temporal_mortality_generates_samples(): + task = TemporalMortalityPredictionEICU() + patient = make_valid_patient() + + samples = list(task(patient)) + + assert len(samples) == 1 + sample = samples[0] + + assert sample["patient_id"] == "patient_1" + assert sample["visit_id"] == "stay_1" + assert "conditions" in sample + assert "procedures" in sample + assert "drugs" in sample + assert "mortality" in sample + assert "discharge_year" in sample + assert "stay_order" in sample + assert "split_group" in sample + + +def test_temporal_mortality_label_is_binary(): + task = TemporalMortalityPredictionEICU() + patient = make_valid_patient() + + sample = list(task(patient))[0] + assert sample["mortality"] in [0, 1] + + +def test_temporal_mortality_split_group_is_valid(): + task = TemporalMortalityPredictionEICU() + patient = make_valid_patient() + + sample = list(task(patient))[0] + assert sample["split_group"] in ["early", "late"] + + +def test_temporal_mortality_skips_visits_without_required_features(): + task = TemporalMortalityPredictionEICU() + + patient = FakePatient( + patient_id="patient_empty", + patient_events=[ + FakeEvent( + { + "patientunitstayid": "stay_x", + "hospitaldischargestatus": "Alive", + "hospitaldischargeyear": 2014, + } + ), + FakeEvent( + { + "patientunitstayid": "stay_y", + "hospitaldischargestatus": "Alive", + "hospitaldischargeyear": 2015, + } + ), + ], + diagnosis_events=[], + physicalexam_events=[], + medication_events=[], + ) + + samples = list(task(patient)) + assert samples == [] + + +def test_temporal_mortality_requires_multiple_visits(): + task = TemporalMortalityPredictionEICU() + + patient = FakePatient( + patient_id="single_visit_patient", + patient_events=[ + FakeEvent( + { + "patientunitstayid": "stay_only", + "hospitaldischargestatus": "Alive", + "hospitaldischargeyear": 2014, + } + ) + ], + diagnosis_events=[ + FakeEvent({"patientunitstayid": "stay_only", "icd9code": "486"}) + ], + physicalexam_events=[ + FakeEvent({"patientunitstayid": "stay_only", "physicalexampath": "respiratory"}) + ], + medication_events=[ + FakeEvent({"patientunitstayid": "stay_only", "drugname": "ceftriaxone"}) + ], + ) + + samples = list(task(patient)) + assert samples == [] \ No newline at end of file