From 206a17ccdaaa35f87fb91ba64184c95b274b9b34 Mon Sep 17 00:00:00 2001 From: Vismayak Mohanarajan Date: Mon, 6 Apr 2026 18:04:46 -0500 Subject: [PATCH 1/3] Add chartevents config and test notebook Add a chartevents table mapping to pyhealth/datasets/configs/mimic4_ehr.yaml. Add test_notebook.ipynb to aid manual testing (dev/demo usage) for TPC-LoS changes. --- pyhealth/datasets/configs/mimic4_ehr.yaml | 16 ++ test_notebook.ipynb | 177 ++++++++++++++++++++++ 2 files changed, 193 insertions(+) create mode 100644 test_notebook.ipynb diff --git a/pyhealth/datasets/configs/mimic4_ehr.yaml b/pyhealth/datasets/configs/mimic4_ehr.yaml index 84c570bb9..67c917ec6 100644 --- a/pyhealth/datasets/configs/mimic4_ehr.yaml +++ b/pyhealth/datasets/configs/mimic4_ehr.yaml @@ -117,3 +117,19 @@ tables: - "hcpcs_cd" - "seq_num" - "short_description" + + chartevents: + file_path: "icu/chartevents.csv.gz" + patient_id: "subject_id" + timestamp: "charttime" + attributes: + - "hadm_id" + - "stay_id" + - "caregiver_id" + - "storetime" + - "itemid" + - "value" + - "valuenum" + - "valueuom" + - "warning" + diff --git a/test_notebook.ipynb b/test_notebook.ipynb new file mode 100644 index 000000000..39fb9f071 --- /dev/null +++ b/test_notebook.ipynb @@ -0,0 +1,177 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "138ff2cb", + "metadata": {}, + "source": [ + "## Notebook to test changes to PyHealth Code for TPC-LoS" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6a25c84a", + "metadata": {}, + "outputs": [], + "source": [ + "from pyhealth.datasets import MIMIC4Dataset\n", + "import json\n", + "from pathlib import Path" + ] + }, + { + "cell_type": "markdown", + "id": "ea441ac1", + "metadata": {}, + "source": [ + "### Load the data with modified MIMIC4Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3261cf9d", + "metadata": {}, + "outputs": [], + "source": [ + "DATASETS_ROOT = Path(\"~/Desktop/CS 598 - Deep Learning for 
Healthcare/datasets\")\n", + "\n", + "# MIMIC-IV Demo Dataset\n", + "MIMIC4_ROOT = DATASETS_ROOT / \"mimic-iv-demo\" / \"2.2\"\n", + "\n", + "MIMIC4_TABLES = [\n", + " \"diagnoses_icd\",\n", + " \"prescriptions\",\n", + " \"labevents\",\n", + " \"chartevents\",\n", + "]\n", + "\n", + "mimic4_base = MIMIC4Dataset(\n", + " ehr_root=str(MIMIC4_ROOT),\n", + " ehr_tables=MIMIC4_TABLES,\n", + " dev=True, # use 1000 patients while exploring\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8d836659", + "metadata": {}, + "outputs": [], + "source": [ + "#checking stats\n", + "print(\"Number of MIMIC patients:\", len(mimic4_base.unique_patient_ids))\n", + "# inspecting one MIMIC patient and its loaded event partitions\n", + "\n", + "first_patient = next(mimic4_base.iter_patients())\n", + "\n", + "print(\"Patient type:\", type(first_patient)) #info about patients\n", + "print(\"Patient fields:\", list(vars(first_patient).keys()))\n", + "print(\"Patient dict:\", vars(first_patient))\n", + "\n", + "\n", + "#see event partitions pyhealth created, which tables actually landed inside patient object, how each partition is stored\n", + "print(\"\\nType of event_type_partitions:\", type(first_patient.event_type_partitions))\n", + "print(\"Partition keys:\", list(first_patient.event_type_partitions.keys()))\n", + "\n", + "for key, value in first_patient.event_type_partitions.items():\n", + " print(f\"\\nPartition: {key}\")\n", + " print(\"Type:\", type(value))\n", + " if isinstance(value, list):\n", + " print(\"Length:\", len(value))\n", + " if len(value) > 0:\n", + " print(\"First item type:\", type(value[0]))\n", + " print(\"First item:\", value[0])\n", + " elif isinstance(value, dict):\n", + " print(\"Keys sample:\", list(value.keys())[:10])\n", + " else:\n", + " print(value)\n", + "\n", + "print(\"PARTITION SUMMARY\")\n", + "for key, value in first_patient.event_type_partitions.items():\n", + " if isinstance(value, list):\n", + " print(f\"{key}: 
list with {len(value)} events\")\n", + " elif isinstance(value, dict):\n", + " print(f\"{key}: dict with {len(value)} keys\")\n", + " else:\n", + " print(f\"{key}: {type(value)}\")\n", + "\n", + "#inspecting more patients\n", + "print(\"CHECKING MULTIPLE PATIENTS\")\n", + "for i, patient in enumerate(mimic4_base.iter_patients()):\n", + " print(f\"\\nPatient {i}\")\n", + " print(\"Patient fields:\", list(vars(patient).keys()))\n", + " print(\"Partition keys:\", list(patient.event_type_partitions.keys()))\n", + " if i == 2:\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "315538f1", + "metadata": {}, + "outputs": [], + "source": [ + "# Testing chartevents \n", + "dir(mimic4_base.get_patient(\"10000032\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12f25f5b", + "metadata": {}, + "outputs": [], + "source": [ + "mimic4_base.get_patient(\"10000032\").event_type_partitions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7d4ba1c1", + "metadata": {}, + "outputs": [], + "source": [ + "patient = mimic4_base.get_patient(\"10000032\")\n", + "\n", + "# get all chartevents for this patient\n", + "chartevent_df = patient.get_events(event_type=\"chartevents\")\n", + "print(chartevent_df)\n", + "len(chartevent_df)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a9826e39", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pyhealth-tpc", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From c560a23d83bfb09e57fa6654f965a10299d17344 Mon Sep 17 00:00:00 2001 From: Vismayak Mohanarajan Date: Mon, 6 Apr 2026 
20:04:30 -0500 Subject: [PATCH 2/3] Update test_notebook.ipynb --- test_notebook.ipynb | 33 +++++++++------------------------ 1 file changed, 9 insertions(+), 24 deletions(-) diff --git a/test_notebook.ipynb b/test_notebook.ipynb index 39fb9f071..5630c93a1 100644 --- a/test_notebook.ipynb +++ b/test_notebook.ipynb @@ -109,45 +109,30 @@ ] }, { - "cell_type": "code", - "execution_count": null, - "id": "315538f1", - "metadata": {}, - "outputs": [], - "source": [ - "# Testing chartevents \n", - "dir(mimic4_base.get_patient(\"10000032\"))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "12f25f5b", + "cell_type": "markdown", + "id": "9e22478a", "metadata": {}, - "outputs": [], "source": [ - "mimic4_base.get_patient(\"10000032\").event_type_partitions" + "#3# Testing chartevents works for specific patient" ] }, { "cell_type": "code", "execution_count": null, - "id": "7d4ba1c1", + "id": "d4dd0916", "metadata": {}, "outputs": [], "source": [ - "patient = mimic4_base.get_patient(\"10000032\")\n", - "\n", - "# get all chartevents for this patient\n", - "chartevent_df = patient.get_events(event_type=\"chartevents\")\n", - "print(chartevent_df)\n", - "len(chartevent_df)" + "patient = mimic4_base.get_patient(mimic4_base.unique_patient_ids[0])\n", + "events = patient.get_events(event_type=\"chartevents\")\n", + "assert len(events) > 0\n", + "print(events[0].itemid, events[0].valuenum, events[0].timestamp)" ] }, { "cell_type": "code", "execution_count": null, - "id": "a9826e39", + "id": "52edc32d", "metadata": {}, "outputs": [], "source": [] From fcb1302817f9cc9690fd908160497b526a7f6a87 Mon Sep 17 00:00:00 2001 From: Kevin Chen Date: Wed, 22 Apr 2026 18:14:30 -0700 Subject: [PATCH 3/3] Implement TPC --- docs/api/models.rst | 1 + docs/api/models/pyhealth.models.TPC.rst | 14 + docs/api/processors.rst | 5 +- ...processors.RegressionSequenceProcessor.rst | 9 + ...pyhealth.processors.TPCStaticProcessor.rst | 9 + ...alth.processors.TPCTimeseriesProcessor.rst | 9 
+ docs/api/tasks.rst | 1 + ....tasks.RemainingLengthOfStayTPC_MIMIC4.rst | 10 + examples/mimic4_remaining_los_tpc.py | 168 +++++++ examples/mimic4_remaining_los_tpc_ablation.py | 362 +++++++++++++++ pyhealth/__init__.py | 6 +- pyhealth/models/__init__.py | 1 + pyhealth/models/tpc.py | 395 ++++++++++++++++ pyhealth/processors/__init__.py | 6 + .../regression_sequence_processor.py | 46 ++ pyhealth/processors/tpc_static_processor.py | 206 +++++++++ .../processors/tpc_timeseries_processor.py | 216 +++++++++ pyhealth/tasks/__init__.py | 1 + .../remaining_length_of_stay_tpc_mimic4.py | 432 ++++++++++++++++++ test_notebook.ipynb | 162 ------- tests/core/test_tpc.py | 336 ++++++++++++++ 21 files changed, 2230 insertions(+), 165 deletions(-) create mode 100644 docs/api/models/pyhealth.models.TPC.rst create mode 100644 docs/api/processors/pyhealth.processors.RegressionSequenceProcessor.rst create mode 100644 docs/api/processors/pyhealth.processors.TPCStaticProcessor.rst create mode 100644 docs/api/processors/pyhealth.processors.TPCTimeseriesProcessor.rst create mode 100644 docs/api/tasks/pyhealth.tasks.RemainingLengthOfStayTPC_MIMIC4.rst create mode 100644 examples/mimic4_remaining_los_tpc.py create mode 100644 examples/mimic4_remaining_los_tpc_ablation.py create mode 100644 pyhealth/models/tpc.py create mode 100644 pyhealth/processors/regression_sequence_processor.py create mode 100644 pyhealth/processors/tpc_static_processor.py create mode 100644 pyhealth/processors/tpc_timeseries_processor.py create mode 100644 pyhealth/tasks/remaining_length_of_stay_tpc_mimic4.py delete mode 100644 test_notebook.ipynb create mode 100644 tests/core/test_tpc.py diff --git a/docs/api/models.rst b/docs/api/models.rst index 7368dec94..dcbec9c86 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -204,3 +204,4 @@ API Reference models/pyhealth.models.TextEmbedding models/pyhealth.models.BIOT models/pyhealth.models.unified_multimodal_embedding_docs + models/pyhealth.models.TPC diff 
--git a/docs/api/models/pyhealth.models.TPC.rst b/docs/api/models/pyhealth.models.TPC.rst new file mode 100644 index 000000000..ac29d0e88 --- /dev/null +++ b/docs/api/models/pyhealth.models.TPC.rst @@ -0,0 +1,14 @@ +pyhealth.models.TPC +=================================== + +The separate callable TPCBlock layer and the complete TPC model. + +.. autoclass:: pyhealth.models.tpc.TPCBlock + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: pyhealth.models.tpc.TPC + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/processors.rst b/docs/api/processors.rst index a06e3c955..e5b5563be 100644 --- a/docs/api/processors.rst +++ b/docs/api/processors.rst @@ -496,4 +496,7 @@ API Reference processors/pyhealth.processors.MultiHotProcessor processors/pyhealth.processors.StageNetProcessor processors/pyhealth.processors.StageNetTensorProcessor - processors/pyhealth.processors.GraphProcessor \ No newline at end of file + processors/pyhealth.processors.GraphProcessor + processors/pyhealth.processors.RegressionSequenceProcessor + processors/pyhealth.processors.TPCStaticProcessor + processors/pyhealth.processors.TPCTimeseriesProcessor \ No newline at end of file diff --git a/docs/api/processors/pyhealth.processors.RegressionSequenceProcessor.rst b/docs/api/processors/pyhealth.processors.RegressionSequenceProcessor.rst new file mode 100644 index 000000000..0c6c96b39 --- /dev/null +++ b/docs/api/processors/pyhealth.processors.RegressionSequenceProcessor.rst @@ -0,0 +1,9 @@ +pyhealth.processors.RegressionSequenceProcessor +=================================== + +Label processor for variable-length remaining LoS regression sequences. + +.. 
autoclass:: pyhealth.processors.RegressionSequenceProcessor + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/processors/pyhealth.processors.TPCStaticProcessor.rst b/docs/api/processors/pyhealth.processors.TPCStaticProcessor.rst new file mode 100644 index 000000000..c5072881c --- /dev/null +++ b/docs/api/processors/pyhealth.processors.TPCStaticProcessor.rst @@ -0,0 +1,9 @@ +pyhealth.processors.TPCStaticProcessor +=================================== + +Feature processor for TPC static inputs (demographics and admission-time measurements). + +.. autoclass:: pyhealth.processors.TPCStaticProcessor + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/processors/pyhealth.processors.TPCTimeseriesProcessor.rst b/docs/api/processors/pyhealth.processors.TPCTimeseriesProcessor.rst new file mode 100644 index 000000000..c00b8ad8f --- /dev/null +++ b/docs/api/processors/pyhealth.processors.TPCTimeseriesProcessor.rst @@ -0,0 +1,9 @@ +pyhealth.processors.TPCTimeseriesProcessor +=================================== + +Feature processor for TPC time-series inputs with hourly resampling and decay indicators. + +.. 
autoclass:: pyhealth.processors.TPCTimeseriesProcessor + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 399b8f1aa..427818ff8 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -229,3 +229,4 @@ Available Tasks Mutation Pathogenicity (COSMIC) Cancer Survival Prediction (TCGA) Cancer Mutation Burden (TCGA) + Remaining Length of Stay (TPC, MIMIC-IV) diff --git a/docs/api/tasks/pyhealth.tasks.RemainingLengthOfStayTPC_MIMIC4.rst b/docs/api/tasks/pyhealth.tasks.RemainingLengthOfStayTPC_MIMIC4.rst new file mode 100644 index 000000000..c02f8d327 --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.RemainingLengthOfStayTPC_MIMIC4.rst @@ -0,0 +1,10 @@ +pyhealth.tasks.RemainingLengthOfStayTPC_MIMIC4 +=============================================== + +Task Classes +------------ + +.. autoclass:: pyhealth.tasks.remaining_length_of_stay_tpc_mimic4.RemainingLengthOfStayTPC_MIMIC4 + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/mimic4_remaining_los_tpc.py b/examples/mimic4_remaining_los_tpc.py new file mode 100644 index 000000000..5e024e609 --- /dev/null +++ b/examples/mimic4_remaining_los_tpc.py @@ -0,0 +1,168 @@ +from __future__ import annotations + +""" +Minimal example: TPC remaining ICU LoS (MIMIC-IV). + +This is the paper-style setting: + - remaining ICU length-of-stay regression + - hourly predictions starting at hour 5 + - MSLE loss +""" + +import os +import sys + +# Put cache inside repo by default (avoids sandbox permission errors). +os.environ.setdefault("PYHEALTH_CACHE_PATH", os.path.join(os.path.dirname(__file__), "..", ".pyhealth_cache")) + +# Ensure we import the *local* repo `pyhealth/` rather than any site-packages install. 
+_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) + +from pyhealth.datasets import MIMIC4EHRDataset, split_by_patient, get_dataloader +from pyhealth.tasks import RemainingLengthOfStayTPC_MIMIC4 +from pyhealth.models import TPC +from pyhealth.trainer import Trainer + +# Using labevents and chartevents itemids reported +# in table 17 of the paper by Rocheteau et al. (2021) +LABEVENTS_ITEMIDS = [ + 50861, # Alanine Aminotransferase (ALT) + 50863, # Alkaline Phosphatase + 50868, # Anion Gap + 50878, # Asparate Aminotransferase (AST) + 50882, # Bicarbonate + 50885, # Bilirubin, Total + 50893, # Calcium, Total + 50804, # Calculated Total CO2 + 50902, # Chloride + 50912, # Creatinine + 50808, # Free Calcium + 50931, # Glucose + 51221, # Hematocrit + 50810, # Hematocrit, Calculated + 51222, # Hemoglobin + 51237, # INR(PT) + 50813, # Lactate + 51248, # MCH + 51249, # MCHC + 51250, # MCV + 50960, # Magnesium + 50817, # Oxygen Saturation + 51274, # PT + 51275, # PTT + 50970, # Phosphate + 51265, # Platelet Count + 50971, # Potassium + 51277, # RDW + 52172, # RDW-SD + 51279, # Red Blood Cells + 50983, # Sodium + 51006, # Urea Nitrogen + 51301, # White Blood Cells + 50818, # pCO2 + 50820, # pH + 50821, # pO2 +] + +CHARTEVENTS_ITEMIDS = [ + 229319, # Activity / Mobility (JH-HLM) + 223876, # Apnea Interval + 220058, # Arterial Blood Pressure Alarm - High + 220056, # Arterial Blood Pressure Alarm - Low + 220051, # Arterial Blood Pressure diastolic + 220052, # Arterial Blood Pressure mean + 220050, # Arterial Blood Pressure systolic + 229323, # Current Dyspnea Assessment + 224639, # Daily Weight + 226871, # Expiratory Ratio + 223875, # Fspn High + 220739, # GCS - Eye Opening + 223901, # GCS - Motor Response + 223900, # GCS - Verbal Response + 225664, # Glucose finger stick (range 70-100) + 220045, # Heart Rate + 220047, # Heart Rate Alarm - Low + 220046, # Heart rate Alarm - High + 223835, # 
Inspired O2 Fraction + 224697, # Mean Airway Pressure + 224687, # Minute Volume + 220293, # Minute Volume Alarm - High + 220292, # Minute Volume Alarm - Low + 220180, # Non Invasive Blood Pressure diastolic + 220181, # Non Invasive Blood Pressure mean + 220179, # Non Invasive Blood Pressure systolic + 223751, # Non-Invasive Blood Pressure Alarm - High + 223752, # Non-Invasive Blood Pressure Alarm - Low + 223834, # O2 Flow + 223770, # O2 Saturation Pulseoxymetry Alarm - Low + 220277, # O2 saturation pulseoxymetry + 220339, # PEEP set + 224701, # PSV Level + 223791, # Pain Level + 224409, # Pain Level Response + 223873, # Paw High + 224695, # Peak Insp. Pressure + 225677, # Phosphorous + 224696, # Plateau Pressure + 224161, # Resp Alarm - High + 224162, # Resp Alarm - Low + 220210, # Respiratory Rate + 224688, # Respiratory Rate (Set) + 224690, # Respiratory Rate (Total) + 224689, # Respiratory Rate (spontaneous) + 228096, # Richmond-RAS Scale + 228409, # Strength L Arm + 228410, # Strength L Leg + 228412, # Strength R Arm + 228411, # Strength R Leg + 223761, # Temperature Fahrenheit + 224685, # Tidal Volume (observed) + 224684, # Tidal Volume (set) + 224686, # Tidal Volume (spontaneous) + 224700, # Total PEEP Level + 223849, # Ventilator Mode + 223874, # Vti High +] + + +def main(): + # Adjust these paths for your environment. 
+ ehr_root = "./datasets/mimic-iv-demo/2.2" + cache_dir = os.path.join(_REPO_ROOT, ".pyhealth_dataset_cache") + + dataset = MIMIC4EHRDataset( + root=ehr_root, + tables=["patients", "admissions", "icustays", "labevents", "chartevents"], + dev=True, + num_workers=1, + cache_dir=cache_dir, + ) + + task = RemainingLengthOfStayTPC_MIMIC4( + labevent_itemids=LABEVENTS_ITEMIDS, + chartevent_itemids=CHARTEVENTS_ITEMIDS, + ) + sample_dataset = dataset.set_task(task) + + train_ds, val_ds, test_ds = split_by_patient(sample_dataset, ratios=[0.8, 0.1, 0.1]) + train_loader = get_dataloader(train_ds, batch_size=8, shuffle=True) + val_loader = get_dataloader(val_ds, batch_size=8, shuffle=False) + test_loader = get_dataloader(test_ds, batch_size=8, shuffle=False) + + model = TPC(dataset=sample_dataset) + trainer = Trainer(model, metrics=["mae", "mse"]) + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + test_dataloader=test_loader, + epochs=5, + monitor="mae", + monitor_criterion="min", + ) + + +if __name__ == "__main__": + main() + diff --git a/examples/mimic4_remaining_los_tpc_ablation.py b/examples/mimic4_remaining_los_tpc_ablation.py new file mode 100644 index 000000000..ec25520cb --- /dev/null +++ b/examples/mimic4_remaining_los_tpc_ablation.py @@ -0,0 +1,362 @@ +from __future__ import annotations + +""" +Ablation study: lab-only vs chart-only vs lab+chart for TPC remaining LoS. + +Converted from examples/ablation.ipynb for local / CLI runs. +Install deps first: pip install -e . 
scikit-learn pandas +""" + +import argparse +import os +import sys +from typing import List + +import numpy as np +import pandas as pd +import torch + +_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) + +os.environ.setdefault( + "PYHEALTH_CACHE_PATH", + os.path.join(_REPO_ROOT, ".pyhealth_cache"), +) + +from pyhealth.datasets import MIMIC4EHRDataset, get_dataloader, split_by_patient +from pyhealth.models import TPC +from pyhealth.tasks import RemainingLengthOfStayTPC_MIMIC4 +from pyhealth.tasks.length_of_stay_prediction import categorize_los +from pyhealth.trainer import Trainer + + +def bin_remaining_los_days(y_days: np.ndarray) -> np.ndarray: + flat = y_days.reshape(-1) + out = np.zeros_like(flat, dtype=int) + for i, value in enumerate(flat): + if value == 0: + out[i] = -1 + else: + out[i] = categorize_los(int(np.floor(value))) + return out.reshape(y_days.shape) + + +def masked_kappa_and_accuracy(y_true_days: np.ndarray, y_pred_days: np.ndarray): + from sklearn.metrics import cohen_kappa_score + + y_true_bin = bin_remaining_los_days(y_true_days) + y_pred_bin = bin_remaining_los_days(y_pred_days) + + mask = y_true_bin != -1 + yt = y_true_bin[mask] + yp = y_pred_bin[mask] + + acc = float((yt == yp).mean()) if yt.size else float("nan") + kappa = float(cohen_kappa_score(yt, yp)) if yt.size else float("nan") + return kappa, acc + + +def safe_variable_length_inference(model, dataloader, device=None): + """Run inference batch-by-batch and stack predictions for metric computation. + + TPC returns ``y_true`` / ``y_prob`` as **1-D** masked tensors (all valid + timesteps in the batch flattened). Some models return **2-D** ``(B, T)`` + padded sequences; both layouts are supported. 
+ """ + if device is None: + device = next(model.parameters()).device + + model.eval() + + y_true_all: list[np.ndarray] = [] + y_pred_all: list[np.ndarray] = [] + loss_all: list[float] = [] + max_len = 0 + use_2d = False + + with torch.no_grad(): + for data in dataloader: + batch = {} + for k, v in data.items(): + if isinstance(v, torch.Tensor): + batch[k] = v.to(device) + else: + batch[k] = v + + output = model(**batch) + + y_true = output["y_true"].detach().cpu().numpy() + y_pred = output["y_prob"].detach().cpu().numpy() + loss_all.append(float(output["loss"].item())) + + if y_true.ndim == 1: + y_true_all.append(y_true) + y_pred_all.append(y_pred) + elif y_true.ndim == 2: + use_2d = True + max_len = max(max_len, y_true.shape[1]) + y_true_all.append(y_true) + y_pred_all.append(y_pred) + else: + raise ValueError(f"Unexpected y_true shape: {y_true.shape}") + + if not use_2d: + y_true_concat = np.concatenate(y_true_all, axis=0) if y_true_all else np.array([]) + y_pred_concat = np.concatenate(y_pred_all, axis=0) if y_pred_all else np.array([]) + else: + padded_true = [] + padded_pred = [] + for yt, yp in zip(y_true_all, y_pred_all): + pad_width = max_len - yt.shape[1] + if pad_width > 0: + yt = np.pad(yt, ((0, 0), (0, pad_width)), mode="constant", constant_values=0) + yp = np.pad(yp, ((0, 0), (0, pad_width)), mode="constant", constant_values=0) + padded_true.append(yt) + padded_pred.append(yp) + y_true_concat = np.concatenate(padded_true, axis=0) + y_pred_concat = np.concatenate(padded_pred, axis=0) + + loss_mean = float(np.mean(loss_all)) if loss_all else float("nan") + return y_true_concat, y_pred_concat, loss_mean + + +def run_experiment( + name: str, + root: str, + cache_dir: str, + labevent_itemids: List[str], + chartevent_itemids: List[str], + epochs: int = 1, +): + print(f"\n========== Running experiment: {name} ==========") + print(f"Lab features: {len(labevent_itemids)}") + print(f"Chart features: {len(chartevent_itemids)}") + + dataset = MIMIC4EHRDataset( + 
root=root, + tables=["patients", "admissions", "icustays", "labevents", "chartevents"], + dev=False, + num_workers=2, + cache_dir=cache_dir, + ) + + task = RemainingLengthOfStayTPC_MIMIC4( + labevent_itemids=labevent_itemids, + chartevent_itemids=chartevent_itemids, + ) + + sample_dataset = dataset.set_task(task) + print("Task dataset built. Number of samples:", len(sample_dataset)) + + train_ds, val_ds, test_ds = split_by_patient(sample_dataset, ratios=[0.8, 0.1, 0.1]) + + print("Train:", len(train_ds)) + print("Val:", len(val_ds)) + print("Test:", len(test_ds)) + + train_loader = get_dataloader(train_ds, batch_size=8, shuffle=True) + val_loader = get_dataloader(val_ds, batch_size=8, shuffle=False) + test_loader = get_dataloader(test_ds, batch_size=8, shuffle=False) + + model = TPC( + dataset=sample_dataset, + temporal_channels=11, + pointwise_channels=5, + num_layers=8, + kernel_size=5, + main_dropout=0.0, + temporal_dropout=0.05, + use_batchnorm=True, + final_hidden=36, + ) + + trainer = Trainer( + model=model, + metrics=["mae", "mse"], + enable_logging=False, + ) + + trainer.train( + train_dataloader=train_loader, + val_dataloader=None, + test_dataloader=None, + epochs=epochs, + monitor="mae", + monitor_criterion="min", + optimizer_params={"lr": 0.00221}, + ) + + y_true, y_pred, _ = safe_variable_length_inference(model, test_loader) + + kappa, acc = masked_kappa_and_accuracy(y_true, y_pred) + mae = float(np.mean(np.abs(y_true - y_pred))) + mse = float(np.mean((y_true - y_pred) ** 2)) + + result = { + "experiment": name, + "mae": mae, + "mse": mse, + "kappa": kappa, + "accuracy": acc, + } + + print("Result:", result) + return result + + +def _check_required_files(data_root: str) -> bool: + required_files = [ + ("hosp", "patients.csv.gz"), + ("hosp", "admissions.csv.gz"), + ("icu", "icustays.csv.gz"), + ("hosp", "labevents.csv.gz"), + ("hosp", "d_labitems.csv.gz"), + ("icu", "chartevents.csv.gz"), + ] + print("Checking required files...\n") + all_ok = True + for 
sub, fn in required_files: + path = os.path.join(data_root, sub, fn) + exists = os.path.exists(path) + print(f"{path}: {exists}") + if not exists: + all_ok = False + print("\nAll files found:", all_ok) + return all_ok + + +def main() -> None: + default_data = os.path.join(_REPO_ROOT, "datasets", "mimic-iv-demo", "2.2") + default_csv = os.path.join(_REPO_ROOT, "ablation_results.csv") + + parser = argparse.ArgumentParser(description="Run TPC ablation experiments (from ablation.ipynb).") + parser.add_argument( + "--data-root", + default=os.environ.get("MIMIC4_DATA_ROOT", default_data), + help="MIMIC-IV (or demo) root containing hosp/ and icu/ (default: repo demo path or MIMIC4_DATA_ROOT).", + ) + parser.add_argument( + "--output-csv", + default=default_csv, + help="Where to write the results table.", + ) + parser.add_argument("--epochs", type=int, default=1, help="Training epochs per experiment.") + parser.add_argument( + "--skip-sanity", + action="store_true", + help="Skip the small dev=True dataset load used in the notebook as a sanity check.", + ) + parser.add_argument( + "--skip-baseline", + action="store_true", + help="Skip the lab_only_debug run before the three named ablations.", + ) + args = parser.parse_args() + + data_root = os.path.abspath(args.data_root) + print("DATA_ROOT:", data_root) + + if not _check_required_files(data_root): + raise SystemExit(1) + + labevent_itemids = [ + "50824", "52455", "50983", "52623", "50822", "52452", "50971", + "52610", "50806", "52434", "50902", "52535", "50803", "50804", + "50809", "52027", "50931", "52569", "50808", "51624", "50960", + "50868", "52500", "52031", "50964", "51701", "50970", + ] + + chart_vitals_itemids = [ + "220045", + "220179", + "220180", + "220181", + "220210", + "224690", + "223761", + "223762", + "220277", + "225664", + "220621", + "226537", + ] + + if not args.skip_sanity: + dataset = MIMIC4EHRDataset( + root=data_root, + tables=["patients", "admissions", "icustays", "labevents", "chartevents"], + 
dev=True, + num_workers=2, + cache_dir=os.path.join(_REPO_ROOT, ".sanity_cache"), + ) + print(dataset) + task = RemainingLengthOfStayTPC_MIMIC4( + labevent_itemids=labevent_itemids, + chartevent_itemids=[], + ) + sample_dataset = dataset.set_task(task) + print(sample_dataset) + print("Number of samples:", len(sample_dataset)) + + if not args.skip_baseline: + run_experiment( + name="lab_only_debug", + root=data_root, + cache_dir=os.path.join(_REPO_ROOT, ".ablation_cache", "lab_only_debug"), + labevent_itemids=labevent_itemids, + chartevent_itemids=[], + epochs=args.epochs, + ) + + experiments = [ + {"name": "lab_only", "labevent_itemids": labevent_itemids, "chartevent_itemids": []}, + { + "name": "chart_only", + "labevent_itemids": [], + "chartevent_itemids": chart_vitals_itemids, + }, + { + "name": "lab_plus_chart", + "labevent_itemids": labevent_itemids, + "chartevent_itemids": chart_vitals_itemids, + }, + ] + + all_results = [] + for exp in experiments: + result = run_experiment( + name=exp["name"], + root=data_root, + cache_dir=os.path.join(_REPO_ROOT, ".ablation_cache", exp["name"]), + labevent_itemids=exp["labevent_itemids"], + chartevent_itemids=exp["chartevent_itemids"], + epochs=args.epochs, + ) + all_results.append(result) + + results_df = pd.DataFrame(all_results) + print("\nResults:") + print(results_df) + print("\nSorted by MAE:") + print(results_df.sort_values("mae")) + + out_path = os.path.abspath(args.output_csv) + os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True) + results_df.to_csv(out_path, index=False) + print("Saved to:", out_path) + + best_row = results_df.sort_values("mae").iloc[0] + print("\nBest configuration:") + print(best_row) + print("\nInterpretation template:") + print( + f"The best-performing ablation setting was {best_row['experiment']} " + f"with MAE={best_row['mae']:.4f}, MSE={best_row['mse']:.4f}, " + f"Kappa={best_row['kappa']:.4f}, Accuracy={best_row['accuracy']:.4f}." 
class TPCBlock(nn.Module):
    """A single TPC layer: temporal branch + pointwise branch + dense skips.

    The temporal branch runs a causal, per-feature convolution (grouped Conv1d,
    so no weights are shared across features). The pointwise branch mixes every
    feature and the static inputs at each timestep. Both outputs are stitched
    together with skip connections, so the feature axis grows by Z
    (``pointwise_channels``) per layer.

    Args:
        in_features: Number of input features R.
        in_channels: Channels per feature C. 2 on layer 0 (value, decay),
            temporal_channels + 1 thereafter.
        temporal_channels: Temporal conv output channels Y per feature.
        pointwise_channels: Pointwise conv output channels Z.
        kernel_size: Temporal conv kernel size k.
        dilation: Dilation factor d. Set to layer_idx + 1 per layer.
        main_dropout: Dropout after pointwise branch.
        temporal_dropout: Dropout after temporal branch.
        use_batchnorm: Batch normalisation after each branch. Default: True.
        static_dim: Static feature dimension S injected into pointwise branch.

    Examples:
        >>> block = TPCBlock(
        ...     in_features=101, in_channels=2,
        ...     temporal_channels=11, pointwise_channels=5,
        ...     kernel_size=5, dilation=1,
        ...     main_dropout=0.0, temporal_dropout=0.05,
        ...     static_dim=32,
        ... )
        >>> x = torch.randn(8, 100, 101, 2)
        >>> out = block(x, static=torch.randn(8, 32))
        >>> out.shape  # (8, 100, 106, 12)
    """

    def __init__(
        self,
        *,
        in_features: int,
        in_channels: int,
        temporal_channels: int,
        pointwise_channels: int,
        kernel_size: int,
        dilation: int,
        main_dropout: float,
        temporal_dropout: float,
        use_batchnorm: bool = True,
        static_dim: int = 0,
    ) -> None:
        super().__init__()
        self.in_features = int(in_features)
        self.in_channels = int(in_channels)
        self.temporal_channels = int(temporal_channels)
        self.pointwise_channels = int(pointwise_channels)
        self.kernel_size = int(kernel_size)
        self.dilation = int(dilation)
        self.use_batchnorm = bool(use_batchnorm)
        self.static_dim = int(static_dim)

        # Temporal branch: groups == in_features gives each feature its own
        # independent filter bank (paper's "temporal" convolution).
        self.temporal_conv = nn.Conv1d(
            in_channels=self.in_features * self.in_channels,
            out_channels=self.in_features * self.temporal_channels,
            kernel_size=self.kernel_size,
            dilation=self.dilation,
            groups=self.in_features,
            bias=True,
        )
        self.bn_temporal = nn.BatchNorm1d(self.in_features * self.temporal_channels)
        self.dropout_temporal = nn.Dropout(temporal_dropout)

        # Pointwise branch: a per-timestep Linear over the flattened features.
        # Its input r = [value_skip, temporal_out] carries (Y + 1) channels.
        point_in_dim = self.in_features * (self.temporal_channels + 1) + self.static_dim
        self.pointwise = nn.Linear(point_in_dim, self.pointwise_channels)
        self.bn_pointwise = nn.BatchNorm1d(self.pointwise_channels)
        self.dropout_main = nn.Dropout(main_dropout)

        self.relu = nn.ReLU()

    def forward(
        self,
        x: torch.Tensor,
        static: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run one TPC block.

        Args:
            x: (B, T, R, C_in) time-series input tensor.
            static: (B, S) static feature tensor, or None.

        Returns:
            torch.Tensor: output tensor of shape (B, T, R + Z, Y + 1).
        """
        batch, steps, feats, chans = x.shape
        if feats != self.in_features or chans != self.in_channels:
            raise ValueError(
                "TPCBlock got x shape "
                f"{x.shape}, expected (B,T,{self.in_features},{self.in_channels})"
            )

        # --- Temporal branch ---
        # Lay channels out feature-major as (B, R*C, T) to match the grouped
        # conv, then left-pad so the convolution stays causal.
        flat = x.permute(0, 2, 3, 1).reshape(batch, feats * chans, steps)
        left_pad = (self.kernel_size - 1) * self.dilation
        flat = F.pad(flat, (left_pad, 0), mode="constant", value=0.0)
        conv_out = self.temporal_conv(flat)  # (B, R*Y, T)
        if self.use_batchnorm:
            conv_out = self.bn_temporal(conv_out)
        conv_out = self.dropout_temporal(conv_out)
        temporal = conv_out.reshape(
            batch, feats, self.temporal_channels, steps
        ).permute(0, 3, 1, 2)  # (B,T,R,Y)

        # Skip connection: carry the raw value channel alongside the conv output.
        skip = x[..., 0:1]  # (B,T,R,1)
        joined = torch.cat([skip, temporal], dim=-1)  # (B,T,R,Y+1)

        # --- Pointwise branch ---
        joined_flat = joined.reshape(batch, steps, feats * (self.temporal_channels + 1))
        if static is None:
            point_in = joined_flat
        else:
            tiled = static.unsqueeze(1).expand(batch, steps, static.shape[-1])
            point_in = torch.cat([joined_flat, tiled], dim=-1)

        point = self.pointwise(point_in)  # (B,T,Z)
        if self.use_batchnorm:
            point = self.bn_pointwise(point.reshape(batch * steps, -1)).reshape(
                batch, steps, -1
            )
        point = self.dropout_main(point)

        # Treat the Z pointwise outputs as new features broadcast over channels.
        new_feats = point.unsqueeze(-1).expand(
            batch, steps, self.pointwise_channels, self.temporal_channels + 1
        )

        return self.relu(torch.cat([joined, new_feats], dim=2))  # (B,T,R+Z,Y+1)
class TPC(BaseModel):
    """Temporal Pointwise Convolution (TPC) for remaining ICU length-of-stay.

    Paper: Rocheteau et al., *Temporal Pointwise Convolutional Networks for Length of
    Stay Prediction in the Intensive Care Unit* (ACM CHIL 2021).

    Note:
        Predicts remaining LoS in days each ICU hour from the configured start hour.
        Temporal convolutions are per-feature (no cross-feature weight sharing);
        pointwise layers mix features. Inputs must include ``ts`` (values + decay) and
        ``static``, as produced by ``RemainingLengthOfStayTPC_MIMIC4`` with
        ``TPCTimeseriesProcessor`` and ``TPCStaticProcessor``.

        ``mode`` is set to ``'regression'`` because ``'regression_sequence'`` is not
        a recognised PyHealth mode string. With labels, ``y_prob`` and ``y_true`` are
        flattened over valid timesteps so padded zeros do not affect metrics.

    Args:
        dataset: Dataset used to infer ``F`` (time-series features) and ``S`` (static
            width) from fitted processors.
        temporal_channels: Temporal conv output channels ``Y`` per feature. Default 11.
        pointwise_channels: Pointwise output channels ``Z``. Default 5.
        num_layers: Number of stacked TPC blocks ``N``. Default 8.
        kernel_size: Temporal conv kernel size ``k``. Default 5.
        main_dropout: Dropout after the pointwise branch. Default 0.0.
        temporal_dropout: Dropout after the temporal branch. Default 0.05.
        use_batchnorm: Whether to apply batch norm after each branch. Default True.
        final_hidden: Hidden size of the two-layer prediction head. Default 36.
        decay_clip_min_days: ``HardTanh`` minimum in days. Default ``1/48``.
        decay_clip_max_days: ``HardTanh`` maximum in days. Default ``100``.
    """

    def __init__(
        self,
        dataset: SampleDataset,
        *,
        temporal_channels: int = 11,
        pointwise_channels: int = 5,
        num_layers: int = 8,
        kernel_size: int = 5,
        main_dropout: float = 0.0,
        temporal_dropout: float = 0.05,
        use_batchnorm: bool = True,
        final_hidden: int = 36,
        decay_clip_min_days: float = 1.0 / 48.0,
        decay_clip_max_days: float = 100.0,
    ) -> None:
        super().__init__(dataset=dataset)
        assert "ts" in self.feature_keys and "static" in self.feature_keys, (
            "TPC expects dataset.input_schema to contain 'ts' and 'static'."
        )
        assert len(self.label_keys) == 1, "TPC currently supports a single label key."
        self.label_key = self.label_keys[0]

        # Hardcoded: label processor is "regression_sequence" but BaseModel mode must
        # be "regression" so metrics and Trainer wiring match the regression path.
        self.mode = "regression"

        self.temporal_channels = int(temporal_channels)
        self.pointwise_channels = int(pointwise_channels)
        self.num_layers = int(num_layers)
        self.kernel_size = int(kernel_size)
        self.use_batchnorm = bool(use_batchnorm)
        self.final_hidden = int(final_hidden)

        self.min_days = float(decay_clip_min_days)
        self.max_days = float(decay_clip_max_days)

        # We infer feature/static dimensions from the fitted dataset processors.
        ts_proc = dataset.input_processors["ts"]
        static_proc = dataset.input_processors["static"]
        self.F = ts_proc.size()
        self.S = static_proc.size()

        # Stack TPC blocks; the feature dimension grows by Z each layer.
        blocks = []
        in_features = self.F
        in_channels = 2  # (value, decay)
        for layer_idx in range(self.num_layers):
            blocks.append(
                TPCBlock(
                    in_features=in_features,
                    in_channels=in_channels,
                    temporal_channels=self.temporal_channels,
                    pointwise_channels=self.pointwise_channels,
                    kernel_size=self.kernel_size,
                    dilation=layer_idx + 1,
                    main_dropout=main_dropout,
                    temporal_dropout=temporal_dropout,
                    use_batchnorm=self.use_batchnorm,
                    static_dim=self.S,
                )
            )
            # After the first block, channels become (Y+1) and features grow by Z.
            in_channels = self.temporal_channels + 1
            in_features = in_features + self.pointwise_channels
        self.blocks = nn.ModuleList(blocks)

        # Final per-time-step head (2-layer pointwise MLP).
        final_in = in_features * (self.temporal_channels + 1) + self.S
        self.head_fc1 = nn.Linear(final_in, self.final_hidden)
        self.head_relu = nn.ReLU()
        self.head_fc2 = nn.Linear(self.final_hidden, 1)

        self.hardtanh = nn.Hardtanh(min_val=self.min_days, max_val=self.max_days)

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Forward propagation.

        Args:
            **kwargs: keyword arguments for the model. Must contain all feature
                keys and the label key. Specifically:
                - ts: (B, T, F, 2) padded time-series tensor produced by
                  TPCTimeseriesProcessor (value + decay channels).
                - static: (B, S) static feature tensor produced by
                  TPCStaticProcessor.
                - (optional) label key: (B, T) padded target tensor; when
                  present, loss and y_true are added to the output.

        Returns:
            Dict[str, torch.Tensor]: A dictionary with the following keys:

            - logit: a (B, T) tensor of raw log-space predictions.
            - y_prob: predicted remaining LoS in days. When labels are
              provided this is a masked 1-D tensor of real timesteps only;
              otherwise (B, T).
            - loss (when labels provided): scalar MSLE loss over
              unpadded timesteps.
            - y_true (when labels provided): masked 1-D tensor of
              ground-truth LoS values aligned with y_prob.
        """
        ts: torch.Tensor = kwargs["ts"].to(self.device)  # (B,T,F,2) padded
        static: torch.Tensor = kwargs["static"].to(self.device)  # (B,S)
        y_true: Optional[torch.Tensor] = None
        if self.label_key in kwargs:
            y_true = kwargs[self.label_key].to(self.device)  # (B,T) padded

        # Fix: unpack the feature axis as ``n_feat`` (not ``F``) so the local
        # does not shadow the module-level ``torch.nn.functional as F`` alias.
        B, T, n_feat, n_chan = ts.shape
        if n_chan != 2:
            raise ValueError(f"TPC expects ts channels=2, got {n_chan}.")
        if n_feat != self.F:
            raise ValueError(f"TPC expects F={self.F} features, got {n_feat}.")

        h = ts
        for block in self.blocks:
            h = block(h, static=static)  # grows feature dim, channels -> (Y+1)

        # Final predictions per hour.
        h_flat = h.reshape(B, T, -1)  # (B,T, features*channels)
        static_rep = static.unsqueeze(1).expand(B, T, static.shape[-1])
        head_in = torch.cat([h_flat, static_rep], dim=-1)

        hidden = self.head_relu(self.head_fc1(head_in))
        logit = self.head_fc2(hidden).squeeze(-1)  # (B,T)

        # Predict log(LoS) then exponentiate + clip (paper Appendix A).
        # exp() overflow yields inf, which Hardtanh clamps to max_days.
        y_pred = self.hardtanh(torch.exp(logit))

        results: Dict[str, torch.Tensor] = {
            "logit": logit,
            "y_prob": y_pred,
        }

        if y_true is not None:
            # Padding uses 0; real labels are >= 1/48 day after task clipping.
            valid = y_true != 0
            mask = valid.float()
            # MSLE = mean((log(y_pred) - log(y_true))^2) over valid timesteps.
            eps = 1e-8
            log_pred = torch.log(torch.clamp(y_pred, min=eps))
            log_true = torch.log(torch.clamp(y_true, min=eps))
            se = (log_pred - log_true) ** 2
            results["loss"] = (se * mask).sum() / torch.clamp(mask.sum(), min=1.0)
            # Flatten valid positions so batch collation ignores padding.
            results["y_prob"] = y_pred[valid]
            results["y_true"] = y_true[valid]

        return results
@register_processor("regression_sequence")
class RegressionSequenceProcessor(TensorProcessor):
    """Float label processor for variable-length remaining-LoS sequences.

    A thin wrapper around :class:`TensorProcessor` fixed to ``torch.float32``
    with a single temporal (spatial) axis. Each :meth:`process` call turns the
    per-hour remaining-LoS list (in days) emitted by
    ``RemainingLengthOfStayTPC_MIMIC4`` into a 1-D float tensor of shape
    ``(T,)`` — up to 332 steps for a 336-hour stay predicted from hour 5.

    Note:
        The constructor takes no arguments; dtype and spatial layout are fixed
        for the TPC label pipeline.

    Examples:
        >>> processor = RegressionSequenceProcessor()
        >>> processor.fit([], "y")
        >>> out = processor.process([2.0, 1.5, 1.0, 0.5])
        >>> out.shape
        torch.Size([4])
        >>> out.dtype
        torch.float32
    """

    def __init__(self) -> None:
        """Fix dtype to float32 and declare one spatial (time) dimension."""
        super().__init__(dtype=torch.float32, spatial_dims=(True,))
        self._n_dim = 1

    def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None:
        """Stateless: regression-sequence labels need no fitting."""
        return None

    def size(self) -> int:
        """Width per label timestep: a single scalar."""
        return 1
@register_processor("tpc_static")
class TPCStaticProcessor(FeatureProcessor):
    """Feature processor for TPC static inputs.

    Encodes the 12 static features for one ICU stay into a fixed-size 1D float
    tensor. Categorical features are one-hot encoded over a vocabulary built
    during fit(). Numeric features are scaled to [-1, 1] using per-feature
    5th/95th percentiles computed during fit(), then clipped to
    [clip_min, clip_max].

    Categorical features (one-hot encoded):
        gender, race (paper: ethnicity), admission_location, insurance,
        first_careunit

    Numeric features (robust scaled):
        hour_of_admission, admission_height, admission_weight,
        gcs_eye, gcs_motor, gcs_verbal, anchor_age

    Input format (dict) produced by RemainingLengthOfStayTPC_MIMIC4 task:
        {
            "gender": str,              # "M" or "F"
            "race": str,                # categorical - paper calls this ethnicity
            "admission_location": str,  # categorical
            "insurance": str,           # categorical
            "first_careunit": str,      # categorical
            "hour_of_admission": int,   # 0-23
            "admission_height": float,  # cm, or None if not recorded
            "admission_weight": float,  # kg, or None if not recorded
            "gcs_eye": float,           # 1-4, or None if not recorded
            "gcs_motor": float,         # 1-6, or None if not recorded
            "gcs_verbal": float,        # 1-5, or None if not recorded
            "anchor_age": int,          # age at ICU admission
        }

    Args:
        clip_min: Lower clip bound after scaling. Default: -4.0.
        clip_max: Upper clip bound after scaling. Default: 4.0.

    Returns:
        ``torch.FloatTensor`` of shape ``(S,)`` where ``S`` is the sum of one-hot
        vocabulary sizes plus seven numeric features.
        Missing or unknown categorical values map to the reserved index-0
        one-hot position. Missing or non-finite numeric values map to 0.0
        (the scaled midpoint).

    Examples:
        >>> processor = TPCStaticProcessor()
        >>> samples = [{"static": {"gender": "M", "race": "WHITE", ...}}]
        >>> processor.fit(samples, "static")
        >>> out = processor.process({"gender": "M", "race": "WHITE", ...})
        >>> out.shape  # (S,) where S depends on vocab sizes seen during fit
        >>> out.dtype  # torch.float32
    """

    CATEGORICAL_KEYS: Tuple[str, ...] = (
        "gender",
        "race",
        "admission_location",
        "insurance",
        "first_careunit",
    )
    NUMERIC_KEYS: Tuple[str, ...] = (
        "hour_of_admission",
        "admission_height",
        "admission_weight",
        "gcs_eye",
        "gcs_motor",
        "gcs_verbal",
        "anchor_age",
    )

    def __init__(self, clip_min: float = -4.0, clip_max: float = 4.0) -> None:
        self.clip_min = float(clip_min)
        self.clip_max = float(clip_max)

        self._cat_vocab: Dict[str, List[str]] = {k: [] for k in self.CATEGORICAL_KEYS}
        self._cat_index: Dict[str, Dict[str, int]] = {
            k: {} for k in self.CATEGORICAL_KEYS
        }

        self._p5: Dict[str, float] = {}
        self._p95: Dict[str, float] = {}

    def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None:
        """Build vocabularies and numeric scaling bounds from ``samples``.

        Collects categorical tokens and numeric values for each static field,
        then stores sorted vocabularies (index 0 reserved for unknown/missing)
        and 5th/95th percentiles per numeric feature.
        """
        cat_values: Dict[str, set[str]] = {k: set() for k in self.CATEGORICAL_KEYS}
        num_values: Dict[str, List[float]] = {k: [] for k in self.NUMERIC_KEYS}

        for sample in samples:
            if field not in sample or sample[field] is None:
                continue
            s = sample[field]
            if not isinstance(s, dict):
                continue

            for k in self.CATEGORICAL_KEYS:
                v = s.get(k, None)
                if v is None:
                    continue
                cat_values[k].add(str(v))

            for k in self.NUMERIC_KEYS:
                v = s.get(k, None)
                if v is None:
                    continue
                try:
                    fv = float(v)
                except Exception:
                    continue
                # Fix: drop NaN/inf so percentile bounds stay finite
                # (np.nanpercentile tolerates NaN but not inf).
                if np.isfinite(fv):
                    num_values[k].append(fv)

        for k in self.CATEGORICAL_KEYS:
            vocab = sorted(cat_values[k])
            # reserve index 0 for unknown/missing
            self._cat_vocab[k] = [""] + vocab
            self._cat_index[k] = {tok: i for i, tok in enumerate(self._cat_vocab[k])}

        for k in self.NUMERIC_KEYS:
            arr = np.asarray(num_values[k], dtype=float)
            if arr.size == 0:
                self._p5[k] = 0.0
                self._p95[k] = 1.0
            else:
                self._p5[k] = float(np.nanpercentile(arr, 5))
                self._p95[k] = float(np.nanpercentile(arr, 95))

    def _scale(self, key: str, x: float) -> float:
        """Linearly scale ``x`` into ``[-1, 1]`` using stored percentiles.

        Values are clipped to ``[clip_min, clip_max]`` after scaling. If the
        percentile range is degenerate, returns ``0.0``.
        """
        p5 = self._p5.get(key, 0.0)
        p95 = self._p95.get(key, 1.0)
        if p95 == p5:
            return 0.0
        scaled = 2.0 * (x - p5) / (p95 - p5) - 1.0
        return float(np.clip(scaled, self.clip_min, self.clip_max))

    def process(self, value: Dict[str, Any]) -> torch.Tensor:
        """Encode ``value`` into a 1D float tensor.

        Categorical columns are one-hot encoded; numeric columns are
        robust-scaled. Non-finite numeric inputs are treated as missing.
        """
        parts: List[float] = []

        # Categorical one-hots; unseen tokens fall back to index 0 (unknown).
        for k in self.CATEGORICAL_KEYS:
            vocab = self._cat_vocab.get(k, [""])
            idx_map = self._cat_index.get(k, {"": 0})
            one_hot = np.zeros(len(vocab), dtype=float)
            raw = value.get(k, None)
            tok = "" if raw is None else str(raw)
            one_hot[idx_map.get(tok, 0)] = 1.0
            parts.extend(one_hot.tolist())

        # Numeric robust scaling.
        for k in self.NUMERIC_KEYS:
            raw = value.get(k, None)
            if raw is None:
                parts.append(0.0)
                continue
            try:
                fv = float(raw)
            except Exception:
                parts.append(0.0)
                continue
            # Fix: float("nan")/inf pass float() but would leak NaN through
            # _scale (np.clip(nan) == nan) into the model; map them to 0.0.
            parts.append(self._scale(k, fv) if np.isfinite(fv) else 0.0)

        return torch.tensor(parts, dtype=torch.float32)

    def size(self) -> int:
        """Return total static dimension (one-hot widths plus numeric count)."""
        cat_size = sum(
            len(self._cat_vocab.get(k, [""])) for k in self.CATEGORICAL_KEYS
        )
        return cat_size + len(self.NUMERIC_KEYS)

    def is_token(self) -> bool:
        """Static features are continuous, not discrete tokens."""
        return False

    def schema(self) -> tuple[str, ...]:
        """Output is a tuple of (value) tensor."""
        return ("value",)

    def dim(self) -> tuple[int, ...]:
        """Output is a 1D tensor."""
        return (1,)

    def spatial(self) -> tuple[bool, ...]:
        """Static features are not spatial."""
        return (False,)
@register_processor("tpc_timeseries")
class TPCTimeseriesProcessor(FeatureProcessor):
    """Hourly TPC time-series tensor from irregular MIMIC observations.

    Consumes the ``ts`` payload from ``RemainingLengthOfStayTPC_MIMIC4`` and
    emits shape ``(T, F, 2)``: scaled values plus decay weights. Robust scaling
    uses per-feature 5th/95th percentiles learned in :meth:`fit`.

    The payload dict contains ``prefill_start``, ``icu_start``, ``pred_start``,
    ``pred_end``, ordered ``feature_itemids``, and ``long_df`` with keys
    ``timestamp``, ``itemid``, ``value``, and ``source`` (``chartevents`` or
    ``labevents``).

    Args:
        sampling_rate: Resampling interval; must be one hour (paper default).
        decay_base: Base ``b`` for ``decay = b ** hours_since_last_observation``.
        clip_min: Lower clip after scaling.
        clip_max: Upper clip after scaling.

    Returns:
        From :meth:`process`, a ``torch.float32`` tensor ``(T, F, 2)``. Channel
        0 is forward-filled scaled values (0 before first observation). Channel
        1 is the decay trace (1 at a fresh sample, ``decay_base**j`` after ``j``
        hours of silence, 0 if never observed).

    Examples:
        >>> from datetime import datetime, timedelta
        >>> processor = TPCTimeseriesProcessor()
        >>> prefill = datetime(2020, 1, 1, 0)
        >>> payload = {
        ...     "prefill_start": prefill,
        ...     "icu_start": prefill,
        ...     "pred_start": prefill + timedelta(hours=5),
        ...     "pred_end": prefill + timedelta(hours=10),
        ...     "feature_itemids": ["A", "B"],
        ...     "long_df": {
        ...         "timestamp": [prefill],
        ...         "itemid": ["A"],
        ...         "value": [80.0],
        ...         "source": ["chartevents"],
        ...     }
        ... }
        >>> processor.fit([{"ts": payload}], "ts")
        >>> processor.process(payload).shape
        torch.Size([5, 2, 2])
    """

    def __init__(
        self,
        sampling_rate: timedelta = timedelta(hours=1),
        decay_base: float = 0.75,
        clip_min: float = -4.0,
        clip_max: float = 4.0,
    ) -> None:
        self.sampling_rate = sampling_rate
        self.decay_base = float(decay_base)
        self.clip_min = float(clip_min)
        self.clip_max = float(clip_max)

        # Robust-scaling bounds per itemid, learned in fit().
        self._p5: Dict[str, float] = {}
        self._p95: Dict[str, float] = {}
        self._feature_itemids: List[str] = []

    def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None:
        """Collect values per itemid and store 5th/95th percentile bounds."""
        per_item: Dict[str, List[float]] = {}
        ordering: List[str] | None = None

        for sample in samples:
            if field not in sample or sample[field] is None:
                continue
            payload = sample[field]
            if not isinstance(payload, dict):
                continue
            # Feature ordering is shared across samples; take the first seen.
            if ordering is None:
                ordering = [str(x) for x in payload.get("feature_itemids", [])]
            long_df = payload.get("long_df") or {}
            for itemid, raw in zip(long_df.get("itemid", []), long_df.get("value", [])):
                if raw is None:
                    continue
                try:
                    fv = float(raw)
                except Exception:
                    continue
                per_item.setdefault(str(itemid), []).append(fv)

        self._feature_itemids = ordering or sorted(per_item.keys())
        for itemid in self._feature_itemids:
            arr = np.asarray(per_item.get(itemid, []), dtype=float)
            if arr.size == 0:
                # Never-seen feature: identity-ish bounds so _scale stays defined.
                self._p5[itemid] = 0.0
                self._p95[itemid] = 1.0
                continue
            self._p5[itemid] = float(np.nanpercentile(arr, 5))
            self._p95[itemid] = float(np.nanpercentile(arr, 95))

    def _scale(self, itemid: str, x: float) -> float:
        """Scale ``x`` with stored percentiles and clip to ``[clip_min, clip_max]``."""
        lo = self._p5.get(itemid, 0.0)
        hi = self._p95.get(itemid, 1.0)
        if hi == lo:
            return 0.0
        return float(np.clip(2.0 * (x - lo) / (hi - lo) - 1.0, self.clip_min, self.clip_max))

    def process(self, value: Dict[str, Any]) -> torch.Tensor:
        """Build the hourly forward-filled tensor for one ICU stay."""
        prefill_start: datetime = value["prefill_start"]
        pred_start: datetime = value["pred_start"]
        pred_end: datetime = value["pred_end"]
        feature_itemids: Sequence[str] = value["feature_itemids"]
        long_df = value["long_df"]

        if int(self.sampling_rate.total_seconds() // 3600) != 1:
            raise ValueError(
                "TPCTimeseriesProcessor currently supports 1-hour sampling only."
            )

        total_steps = int((pred_end - prefill_start).total_seconds() // 3600)
        if total_steps <= 0:
            raise ValueError("Invalid time window for TPC time series.")

        start_idx = int((pred_start - prefill_start).total_seconds() // 3600)
        pred_steps = int((pred_end - pred_start).total_seconds() // 3600)
        if pred_steps <= 0:
            raise ValueError("Invalid prediction window for TPC time series.")

        n_feat = len(feature_itemids)
        grid = np.full((total_steps, n_feat), np.nan, dtype=float)
        seen = np.zeros((total_steps, n_feat), dtype=bool)
        col_of = {str(itemid): j for j, itemid in enumerate(feature_itemids)}

        # Drop each observation into its hourly bucket (later events in the
        # same hour overwrite earlier ones); unknown itemids are ignored.
        triples = zip(
            long_df.get("timestamp", []),
            long_df.get("itemid", []),
            long_df.get("value", []),
        )
        for stamp, raw_item, raw_val in triples:
            if stamp is None or raw_item is None or raw_val is None:
                continue
            key = str(raw_item)
            if key not in col_of:
                continue
            try:
                hour = int((stamp - prefill_start).total_seconds() // 3600)
                if hour < 0 or hour >= total_steps:
                    continue
                scaled = self._scale(key, float(raw_val))
            except Exception:
                # Unparseable timestamp/value: skip this observation.
                continue
            col = col_of[key]
            grid[hour, col] = scaled
            seen[hour, col] = True

        # Forward-fill each feature and track the decay trace alongside it.
        filled = np.zeros((total_steps, n_feat), dtype=float)
        decay = np.zeros((total_steps, n_feat), dtype=float)
        for col in range(n_feat):
            carry = 0.0
            last_obs: int | None = None
            for step in range(total_steps):
                if seen[step, col] and not np.isnan(grid[step, col]):
                    carry = float(grid[step, col])
                    last_obs = step
                    filled[step, col] = carry
                    decay[step, col] = 1.0
                else:
                    filled[step, col] = carry
                    if last_obs is None:
                        decay[step, col] = 0.0
                    else:
                        decay[step, col] = float(self.decay_base ** (step - last_obs))

        # Keep only the prediction window; pre-ICU hours served as ffill seed.
        filled = filled[start_idx : start_idx + pred_steps]
        decay = decay[start_idx : start_idx + pred_steps]

        return torch.tensor(np.stack([filled, decay], axis=-1), dtype=torch.float32)

    def size(self) -> int:
        """Number of time-series features (length of ``feature_itemids``)."""
        return len(self._feature_itemids)

    def is_token(self) -> bool:
        """Continuous values, not discrete tokens."""
        return False

    def schema(self) -> tuple[str, ...]:
        """Schema tag for the value channel (decay rides alongside)."""
        return ("value",)

    def dim(self) -> tuple[int, ...]:
        """Three-dimensional output: time, feature, channel."""
        return (3,)

    def spatial(self) -> tuple[bool, ...]:
        """Only the leading time axis is spatial."""
        return (True, False, False)
+1,432 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime, timedelta +from typing import Any, ClassVar, Dict, List, Optional, Sequence, Tuple + +import polars as pl + +from .base_task import BaseTask + + +@dataclass(frozen=True) +class _TPCWindow: + """ICU stay time boundaries used for prefill, resampling, and labels.""" + + prefill_start: datetime + icu_start: datetime + pred_start: datetime + pred_end: datetime + + +class RemainingLengthOfStayTPC_MIMIC4(BaseTask): + """Paper-setting remaining ICU LoS (MIMIC-IV) for TPC. + + This task matches the *paper setting* (Rocheteau et al., CHIL 2021; + `arXiv:2007.09483 `__, PDF + `https://arxiv.org/pdf/2007.09483` ) for MIMIC-IV: + predict *remaining* ICU LoS at hourly intervals, starting at `start_hour` hours + after ICU admission, truncated to the first `max_hours` hours of the stay. + + Output samples are designed for a dedicated TPC processor which performs: + hourly resampling, forward-fill, and decay-indicator construction, while allowing + pre-ICU values to be used for forward-fill and then removed. + + Notes: + - We intentionally do **not** include diagnoses for MIMIC-IV (paper §4.2). + - Time series use ``chartevents`` / ``labevents`` with user-configured itemids. + + Required dataset tables: + - patients, admissions, icustays + - chartevents and/or labevents (depending on provided itemids) + + Attributes: + input_schema (Dict[str, Any]): Built per instance from the itemid lists passed + to the constructor. Maps feature keys to ``(processor_name, kwargs)`` pairs. + + .. code-block:: python + + { + "ts": ("tpc_timeseries", {}), + "static": ("tpc_static", {}), + } + + output_schema (Dict[str, Any]): Declared next to ``input_schema`` with the same + pattern for labels. + + .. 
code-block:: python + + { + "y": ("regression_sequence", {}), + } + """ + + task_name: str = "RemainingLengthOfStayTPC_MIMIC4" + + # Common MIMIC-IV itemids (Chartevents) for static proxies (optional). + # These are intentionally kept minimal and can be overridden by the user. + DEFAULT_HEIGHT_ITEMIDS: ClassVar[List[str]] = ["226730"] # height (cm) + DEFAULT_WEIGHT_ITEMIDS: ClassVar[List[str]] = ["226512"] # weight (kg) + DEFAULT_GCS_EYE_ITEMIDS: ClassVar[List[str]] = ["220739"] + DEFAULT_GCS_VERBAL_ITEMIDS: ClassVar[List[str]] = ["223900"] + DEFAULT_GCS_MOTOR_ITEMIDS: ClassVar[List[str]] = ["223901"] + + def __init__( + self, + *, + start_hour: int = 5, + max_hours: int = 14 * 24, + pre_icu_hours: int = 24, + min_icu_hours: int = 5, + chartevent_itemids: Optional[Sequence[str]] = None, + labevent_itemids: Optional[Sequence[str]] = None, + static_height_itemids: Optional[Sequence[str]] = None, + static_weight_itemids: Optional[Sequence[str]] = None, + static_gcs_eye_itemids: Optional[Sequence[str]] = None, + static_gcs_verbal_itemids: Optional[Sequence[str]] = None, + static_gcs_motor_itemids: Optional[Sequence[str]] = None, + ) -> None: + self.start_hour = int(start_hour) + self.max_hours = int(max_hours) + self.pre_icu_hours = int(pre_icu_hours) + self.min_icu_hours = int(min_icu_hours) + + self.chartevent_itemids = [str(x) for x in (chartevent_itemids or [])] + self.labevent_itemids = [str(x) for x in (labevent_itemids or [])] + self.feature_itemids = self.chartevent_itemids + self.labevent_itemids + + self.static_height_itemids = [ + str(x) for x in (static_height_itemids or self.DEFAULT_HEIGHT_ITEMIDS) + ] + self.static_weight_itemids = [ + str(x) for x in (static_weight_itemids or self.DEFAULT_WEIGHT_ITEMIDS) + ] + self.static_gcs_eye_itemids = [ + str(x) for x in (static_gcs_eye_itemids or self.DEFAULT_GCS_EYE_ITEMIDS) + ] + self.static_gcs_verbal_itemids = [ + str(x) + for x in (static_gcs_verbal_itemids or self.DEFAULT_GCS_VERBAL_ITEMIDS) + ] + 
self.static_gcs_motor_itemids = [ + str(x) for x in (static_gcs_motor_itemids or self.DEFAULT_GCS_MOTOR_ITEMIDS) + ] + + # Input/Output schemas use explicit processor registrations. + # - ts: custom TPC time-series processor will produce (T, F, 2) + # - static: custom TPC static processor will encode/scale features + # - y: custom regression-sequence label processor will produce (T,) + self.input_schema: Dict[str, Any] = { + "ts": ("tpc_timeseries", {}), + "static": ("tpc_static", {}), + } + self.output_schema: Dict[str, Any] = { + "y": ("regression_sequence", {}), + } + + def _get_admission_for_stay(self, patient: Any, hadm_id: str) -> Optional[Any]: + """Return the admissions row for ``hadm_id``, or ``None`` if missing.""" + admissions = patient.get_events( + event_type="admissions", + filters=[("hadm_id", "==", hadm_id)], + ) + if not admissions: + return None + # Choose the first match (should be unique). + return admissions[0] + + def _build_window( + self, icu_start: datetime, icu_end: datetime + ) -> Optional[_TPCWindow]: + """Compute the prediction window for a single ICU stay.""" + if icu_end <= icu_start: + return None # malformed data + duration_hours = (icu_end - icu_start).total_seconds() / 3600.0 + if duration_hours < self.min_icu_hours: + return None # shorter than minimum ICU length + + prefill_start = icu_start - timedelta(hours=self.pre_icu_hours) + pred_start = icu_start + timedelta(hours=self.start_hour) + pred_end_cap = icu_start + timedelta(hours=self.max_hours) + pred_end = min(icu_end, pred_end_cap) + if pred_end <= pred_start: + return None + + return _TPCWindow( + prefill_start=prefill_start, + icu_start=icu_start, + pred_start=pred_start, + pred_end=pred_end, + ) + + def _extract_static_from_events( + self, + patient: Any, + *, + stay: Any, + admission: Optional[Any], + icu_start: datetime, + prefill_start: datetime, + ) -> Dict[str, Any]: + """Extract the 12 static features from MIMIC-IV tables for one ICU stay. 
+ + Produces a raw dict for ``TPCStaticProcessor``. + + Args: + patient: Patient record with PyHealth event accessors. + stay: ``icustays`` row for this ICU episode. + admission: Matching ``admissions`` row, or ``None``. + icu_start: ICU ``intime`` (used for age and chart window). + prefill_start: Start of the pre-ICU lookback for early vitals. + + Returns: + Mapping of static field names to raw values before encoding. + + Note: + Height, weight, and GCS proxies use the first chart value in + ``[prefill_start, icu_start + 1 hour]``. + """ + static: Dict[str, Any] = {} + + # Table 6 (paper) core fields. + demographics = patient.get_events(event_type="patients") + if demographics: + demo = demographics[0] + static["gender"] = getattr(demo, "gender", None) + # static["anchor_age"] = getattr(demo, "anchor_age", None) + # Compute age at ICU admission + # Age from ICU intime calendar year vs patient anchor_year. + try: + anchor_age = int(demo.anchor_age) + anchor_year = int(demo.anchor_year) + static["anchor_age"] = anchor_age + (icu_start.year - anchor_year) + except Exception: + static["anchor_age"] = None + + + if admission is not None: + static["race"] = getattr(admission, "race", None) + static["admission_location"] = getattr( + admission, "admission_location", None + ) + static["insurance"] = getattr(admission, "insurance", None) + + static["first_careunit"] = getattr(stay, "first_careunit", None) + static["hour_of_admission"] = int(icu_start.hour) + + # Approximate admission height/weight/GCS from early chartevents: first value in + # [prefill_start, icu_start + 1h]. 
+ early_end = icu_start + timedelta(hours=1) + ce_df = patient.get_events( + event_type="chartevents", + start=prefill_start, + end=early_end, + return_df=True, + ) + if ce_df is not None and ce_df.height > 0: + ce_df = ce_df.select( + pl.col("timestamp"), + pl.col("chartevents/itemid").cast(pl.Utf8), + pl.col("chartevents/valuenum").cast(pl.Float64), + ).drop_nulls(["timestamp", "chartevents/itemid"]) + if ce_df.height > 0: + # Take first value per itemid by time. + ce_df = ce_df.sort("timestamp") + + def first_item_value(itemids: Sequence[str]) -> Optional[float]: + """First matching ``valuenum`` in ``ce_df``, else ``None``.""" + sub = ce_df.filter( + pl.col("chartevents/itemid").is_in([str(x) for x in itemids]) + ) + if sub.height == 0: + return None + # first non-null value + sub = sub.drop_nulls(["chartevents/valuenum"]) + if sub.height == 0: + return None + return float(sub["chartevents/valuenum"][0]) + + static["admission_height"] = first_item_value( + self.static_height_itemids + ) + static["admission_weight"] = first_item_value( + self.static_weight_itemids + ) + static["gcs_eye"] = first_item_value(self.static_gcs_eye_itemids) + static["gcs_verbal"] = first_item_value(self.static_gcs_verbal_itemids) + static["gcs_motor"] = first_item_value(self.static_gcs_motor_itemids) + + return static + + def _extract_timeseries( + self, + patient: Any, + *, + prefill_start: datetime, + pred_end: datetime, + stay_id: str, + ) -> Tuple[List[datetime], Dict[str, List[Any]]]: + """Return irregular observations for requested itemids. + + The returned dataframe is in long format: (timestamp, itemid, value, source). 
+ """ + frames: List[pl.DataFrame] = [] + + if self.chartevent_itemids: + ce = patient.get_events( + event_type="chartevents", + start=prefill_start, + end=pred_end, + filters=[("stay_id", "==", stay_id)], + return_df=True, + ) + if ce is not None and ce.height > 0: + ce = ( + ce.select( + pl.col("timestamp"), + pl.col("chartevents/itemid").cast(pl.Utf8).alias("itemid"), + pl.col("chartevents/valuenum").cast(pl.Float64).alias("value"), + ) + .filter(pl.col("itemid").is_in(self.chartevent_itemids)) + .drop_nulls(["timestamp", "itemid"]) + .with_columns(pl.lit("chartevents").alias("source")) + ) + if ce.height > 0: + frames.append(ce) + + if self.labevent_itemids: + le = patient.get_events( + event_type="labevents", + start=prefill_start, + end=pred_end, + return_df=True, + ) + if le is not None and le.height > 0: + le = ( + le.select( + pl.col("timestamp"), + pl.col("labevents/itemid").cast(pl.Utf8).alias("itemid"), + pl.col("labevents/valuenum").cast(pl.Float64).alias("value"), + pl.col("labevents/hadm_id").cast(pl.Utf8).alias("hadm_id"), + ) + .filter(pl.col("itemid").is_in(self.labevent_itemids)) + .drop_nulls(["timestamp", "itemid"]) + .with_columns(pl.lit("labevents").alias("source")) + ) + if le.height > 0: + frames.append(le.drop("hadm_id")) + + if not frames: + return [], {"timestamp": [], "itemid": [], "value": [], "source": []} + + df = pl.concat(frames, how="vertical").sort("timestamp") + timestamps = df["timestamp"].to_list() + # Convert to a pure-Python payload for robust pickling in task caching. + payload = df.select( + "timestamp", "itemid", "value", "source" + ).to_dict(as_series=False) + return timestamps, payload + + def __call__(self, patient: Any) -> List[Dict[str, Any]]: + """Build TPC training samples for every qualifying ICU stay on ``patient``. + + Args: + patient: PyHealth patient with MIMIC-IV tables attached. + + Returns: + List of sample dicts with ``patient_id``, ``stay_id``, ``hadm_id``, + ``static``, ``ts``, and ``y``. 
Empty when no stay passes cohort and + data-availability checks. + """ + # Adult cohort: anchor age/year fetched once per patient. + demographics = patient.get_events(event_type="patients") + if not demographics: + return [] + try: + anchor_age = int(demographics[0].anchor_age) + anchor_year = int(demographics[0].anchor_year) + except Exception: + return [] + + + stays = patient.get_events(event_type="icustays") + if not stays: + return [] + + samples: List[Dict[str, Any]] = [] + for stay in stays: + try: + icu_start: datetime = stay.timestamp + icu_end: datetime = datetime.strptime(stay.outtime, "%Y-%m-%d %H:%M:%S") + except Exception: + continue + # cohort filter: age at this specific ICU admission + if anchor_age + (icu_start.year - anchor_year) < 18: + continue + + + window = self._build_window(icu_start, icu_end) + if window is None: + continue + + stay_id = str(getattr(stay, "stay_id", "")) + hadm_id = str(getattr(stay, "hadm_id", "")) + if not stay_id or not hadm_id: + continue + + admission = self._get_admission_for_stay(patient, hadm_id) + + # Require at least one requested feature id (otherwise model has no inputs). + if len(self.feature_itemids) == 0: + continue + + ts_timestamps, ts_long_payload = self._extract_timeseries( + patient, + prefill_start=window.prefill_start, + pred_end=window.pred_end, + stay_id=stay_id, + ) + if not ts_timestamps: + continue + + static = self._extract_static_from_events( + patient, + stay=stay, + admission=admission, + icu_start=window.icu_start, + prefill_start=window.prefill_start, + ) + + # Labels: remaining ICU LoS (days) per hour in [pred_start, pred_end). 
+ total_hours = int( + (window.pred_end - window.pred_start).total_seconds() // 3600 + ) + if total_hours <= 0: + continue + y = [] + for h in range(total_hours): + t = window.pred_start + timedelta(hours=h) + rem_days = (icu_end - t).total_seconds() / 86400.0 + # Remaining LoS is positive by construction; enforce the same lower clip + # used by the paper's output clipping (30 minutes = 1/48 days). + y.append(max(rem_days, 1.0 / 48.0)) + + sample: Dict[str, Any] = { + "patient_id": patient.patient_id, + "stay_id": stay_id, + "hadm_id": hadm_id, + "static": static, + "ts": { + "prefill_start": window.prefill_start, + "icu_start": window.icu_start, + "pred_start": window.pred_start, + "pred_end": window.pred_end, + "feature_itemids": self.feature_itemids, + "long_df": ts_long_payload, # dict[str, list] + }, + "y": y, + } + samples.append(sample) + + return samples + diff --git a/test_notebook.ipynb b/test_notebook.ipynb deleted file mode 100644 index 5630c93a1..000000000 --- a/test_notebook.ipynb +++ /dev/null @@ -1,162 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "138ff2cb", - "metadata": {}, - "source": [ - "## Notebook to test changes to PyHealth Code for TPC-LoS" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6a25c84a", - "metadata": {}, - "outputs": [], - "source": [ - "from pyhealth.datasets import MIMIC4Dataset\n", - "import json\n", - "from pathlib import Path" - ] - }, - { - "cell_type": "markdown", - "id": "ea441ac1", - "metadata": {}, - "source": [ - "### Load the data with modified MIMIC4Dataset" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3261cf9d", - "metadata": {}, - "outputs": [], - "source": [ - "DATASETS_ROOT = Path(\"~/Desktop/CS 598 - Deep Learning for Healthcare/datasets\")\n", - "\n", - "# MIMIC-IV Demo Dataset\n", - "MIMIC4_ROOT = DATASETS_ROOT / \"mimic-iv-demo\" / \"2.2\"\n", - "\n", - "MIMIC4_TABLES = [\n", - " \"diagnoses_icd\",\n", - " \"prescriptions\",\n", - " 
import unittest
from datetime import datetime, timedelta

import numpy as np
import torch

from pyhealth.datasets import create_sample_dataset, get_dataloader
from pyhealth.models import TPC


def _make_ts_payload(
    n_features: int = 2,
    pred_hours: int = 5,
    seed: int = 0,
) -> dict:
    """Build a minimal raw ts payload as produced by RemainingLengthOfStayTPC_MIMIC4.

    One observation per feature at ``prefill_start`` so forward-fill has a
    seed value; labels then cover ``pred_hours`` hourly steps.
    """
    rng = np.random.RandomState(seed)
    prefill_start = datetime(2020, 1, 1, 0, 0, 0)
    pred_start = prefill_start + timedelta(hours=5)
    pred_end = pred_start + timedelta(hours=pred_hours)

    feature_itemids = [str(i) for i in range(n_features)]

    # One observation per feature at the prefill boundary (hour 0).
    timestamps = [prefill_start] * n_features
    itemids = feature_itemids[:]
    values = rng.uniform(1.0, 10.0, size=n_features).tolist()

    return {
        "prefill_start": prefill_start,
        "icu_start": prefill_start,
        "pred_start": pred_start,
        "pred_end": pred_end,
        "feature_itemids": feature_itemids,
        "long_df": {
            "timestamp": timestamps,
            "itemid": itemids,
            "value": values,
            "source": ["chartevents"] * n_features,
        },
    }


def _make_static_dict() -> dict:
    """Build a minimal static feature dict as produced by RemainingLengthOfStayTPC_MIMIC4."""
    return {
        "gender": "M",
        "race": "WHITE",
        "admission_location": "EMERGENCY ROOM",
        "insurance": "Medicare",
        "first_careunit": "Medical Intensive Care Unit (MICU)",
        "hour_of_admission": 8,
        "admission_height": 170.0,
        "admission_weight": 80.0,
        "gcs_eye": 4.0,
        "gcs_motor": 6.0,
        "gcs_verbal": 5.0,
        "anchor_age": 65,
    }


def _make_dataset(
    n_samples: int = 4,
    n_features: int = 2,
    pred_hours: int = 5,
):
    """Create a minimal in-memory dataset for TPC testing.

    Returns:
        SampleDataset: in-memory dataset with TPC input/output schemas.
    """
    samples = []
    for i in range(n_samples):
        ts = _make_ts_payload(n_features=n_features, pred_hours=pred_hours, seed=i)
        static = _make_static_dict()
        # Alternate gender/careunit for categorical vocab diversity
        static["gender"] = "M" if i % 2 == 0 else "F"
        static["first_careunit"] = (
            "Medical Intensive Care Unit (MICU)"
            if i % 2 == 0
            else "Surgical Intensive Care Unit (SICU)"
        )
        y = [max(float(pred_hours - h) / 24.0, 1.0 / 48.0) for h in range(pred_hours)]
        samples.append(
            {
                "patient_id": f"p{i}",
                "stay_id": f"s{i}",
                "hadm_id": f"h{i}",
                "ts": ts,
                "static": static,
                "y": y,
            }
        )

    return create_sample_dataset(
        samples=samples,
        input_schema={"ts": ("tpc_timeseries", {}), "static": ("tpc_static", {})},
        output_schema={"y": ("regression_sequence", {})},
        dataset_name="test_tpc",
        in_memory=True,
    )


class TestTPC(unittest.TestCase):
    """Unit tests for the TPC model."""

    def setUp(self):
        """Set up shared dataset and model for all tests."""
        torch.manual_seed(0)
        np.random.seed(0)

        self.n_features = 2
        self.pred_hours = 5
        self.batch_size = 2

        self.dataset = _make_dataset(
            n_samples=4,
            n_features=self.n_features,
            pred_hours=self.pred_hours,
        )

        self.model = TPC(
            dataset=self.dataset,
            temporal_channels=4,
            pointwise_channels=3,
            num_layers=2,
            kernel_size=2,
            main_dropout=0.0,
            temporal_dropout=0.0,
            use_batchnorm=False,  # off for small batch sizes in tests
            final_hidden=8,
        )

        self.loader = get_dataloader(
            self.dataset, batch_size=self.batch_size, shuffle=False
        )

    def test_model_initialisation(self):
        """Model initialises correctly and infers F and S from processors."""
        self.assertIsInstance(self.model, TPC)
        self.assertEqual(self.model.F, self.n_features)
        self.assertGreater(self.model.S, 0)
        self.assertEqual(self.model.num_layers, 2)
        self.assertEqual(self.model.temporal_channels, 4)
        self.assertEqual(self.model.pointwise_channels, 3)
        self.assertIn("ts", self.model.feature_keys)
        self.assertIn("static", self.model.feature_keys)
        self.assertEqual(len(self.model.label_keys), 1)
        self.assertEqual(self.model.label_keys[0], "y")
        self.assertEqual(self.model.mode, "regression")

    def test_blocks_count(self):
        """Number of TPCBlocks equals num_layers."""
        self.assertEqual(len(self.model.blocks), self.model.num_layers)

    def test_feature_dimension_growth(self):
        """Feature dimension grows by pointwise_channels each layer."""
        expected_in_features = self.n_features
        for i, block in enumerate(self.model.blocks):
            self.assertEqual(block.in_features, expected_in_features)
            expected_in_features += self.model.pointwise_channels

    def test_forward_output_keys(self):
        """Forward pass returns required output keys."""
        batch = next(iter(self.loader))
        with torch.no_grad():
            ret = self.model(**batch)

        self.assertIn("loss", ret)
        self.assertIn("y_prob", ret)
        self.assertIn("y_true", ret)
        self.assertIn("logit", ret)

    def test_forward_loss_is_scalar(self):
        """Loss is a scalar tensor."""
        batch = next(iter(self.loader))
        with torch.no_grad():
            ret = self.model(**batch)
        self.assertEqual(ret["loss"].dim(), 0)
        self.assertTrue(torch.isfinite(ret["loss"]))

    def test_forward_y_prob_y_true_are_1d_masked(self):
        """y_prob and y_true are 1D and contain only real (unpadded) timesteps."""
        batch = next(iter(self.loader))
        with torch.no_grad():
            ret = self.model(**batch)

        self.assertEqual(ret["y_prob"].dim(), 1)
        self.assertEqual(ret["y_true"].dim(), 1)
        # Both should have the same length (all real timesteps in the batch)
        self.assertEqual(ret["y_prob"].shape[0], ret["y_true"].shape[0])
        # No padded zeros in y_true
        self.assertTrue((ret["y_true"] > 0).all())

    def test_forward_logit_shape(self):
        """Logit has shape (B, T_max) — full padded output."""
        batch = next(iter(self.loader))
        with torch.no_grad():
            ret = self.model(**batch)

        self.assertEqual(ret["logit"].dim(), 2)
        self.assertEqual(ret["logit"].shape[0], self.batch_size)

    def test_forward_y_prob_bounds(self):
        """Predictions are within HardTanh bounds [1/48, 100] days."""
        batch = next(iter(self.loader))
        with torch.no_grad():
            ret = self.model(**batch)

        self.assertTrue((ret["y_prob"] >= 1.0 / 48.0).all())
        self.assertTrue((ret["y_prob"] <= 100.0).all())

    def test_forward_without_labels(self):
        """Forward pass without labels returns y_prob and logit but no loss or y_true."""
        batch = next(iter(self.loader))
        batch_no_labels = {k: v for k, v in batch.items() if k != "y"}

        with torch.no_grad():
            ret = self.model(**batch_no_labels)

        self.assertIn("y_prob", ret)
        self.assertIn("logit", ret)
        self.assertNotIn("loss", ret)
        self.assertNotIn("y_true", ret)

    # ------------------------------------------------------------------
    # Backward pass
    # ------------------------------------------------------------------

    def test_backward_pass(self):
        """Loss backward populates gradients on model parameters."""
        batch = next(iter(self.loader))
        ret = self.model(**batch)
        ret["loss"].backward()

        has_grad = any(
            p.requires_grad and p.grad is not None
            for p in self.model.parameters()
        )
        self.assertTrue(has_grad, "No parameters have gradients after backward pass.")

    def test_loss_decreases_after_one_step(self):
        """A single gradient step reduces the training loss."""
        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-2)
        batch = next(iter(self.loader))

        ret_before = self.model(**batch)
        loss_before = ret_before["loss"].item()

        optimizer.zero_grad()
        ret_before["loss"].backward()
        optimizer.step()

        with torch.no_grad():
            ret_after = self.model(**batch)
        loss_after = ret_after["loss"].item()

        self.assertLess(loss_after, loss_before)

    def test_custom_hyperparameters(self):
        """Model initialises and runs forward with non-default hyperparameters."""
        model = TPC(
            dataset=self.dataset,
            temporal_channels=8,
            pointwise_channels=4,
            num_layers=3,
            kernel_size=3,
            main_dropout=0.1,
            temporal_dropout=0.1,
            use_batchnorm=False,
            final_hidden=16,
        )

        batch = next(iter(self.loader))
        with torch.no_grad():
            ret = model(**batch)

        self.assertIn("loss", ret)
        self.assertIn("y_prob", ret)
        self.assertTrue(torch.isfinite(ret["loss"]))

    def test_tpc_block_output_shape(self):
        """TPCBlock output has correct shape (B, T, R+Z, Y+1)."""
        from pyhealth.models.tpc import TPCBlock

        B, T, R, C = 2, 10, self.n_features, 2
        Y, Z, S = 4, 3, self.model.S

        block = TPCBlock(
            in_features=R,
            in_channels=C,
            temporal_channels=Y,
            pointwise_channels=Z,
            kernel_size=2,
            dilation=1,
            main_dropout=0.0,
            temporal_dropout=0.0,
            use_batchnorm=False,
            static_dim=S,
        )

        x = torch.randn(B, T, R, C)
        static = torch.randn(B, S)

        with torch.no_grad():
            out = block(x, static=static)

        self.assertEqual(out.shape, (B, T, R + Z, Y + 1))

    def test_tpc_block_wrong_input_raises(self):
        """TPCBlock raises ValueError when input shape does not match."""
        from pyhealth.models.tpc import TPCBlock

        block = TPCBlock(
            in_features=4,
            in_channels=2,
            temporal_channels=4,
            pointwise_channels=3,
            kernel_size=2,
            dilation=1,
            main_dropout=0.0,
            temporal_dropout=0.0,
            use_batchnorm=False,
            static_dim=0,
        )

        x_wrong = torch.randn(2, 10, 3, 2)  # R=3 but block expects R=4
        with self.assertRaises(ValueError):
            block(x_wrong)

    def test_msle_loss_is_non_negative(self):
        """MSLE loss is always >= 0."""
        batch = next(iter(self.loader))
        with torch.no_grad():
            ret = self.model(**batch)
        self.assertGreaterEqual(ret["loss"].item(), 0.0)


if __name__ == "__main__":
    unittest.main()