diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 8d9a59d21..3765f8e57 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -234,6 +234,7 @@ Available Datasets datasets/pyhealth.datasets.SHHSDataset datasets/pyhealth.datasets.SleepEDFDataset datasets/pyhealth.datasets.EHRShotDataset + datasets/pyhealth.datasets.HiRIDDataset datasets/pyhealth.datasets.Support2Dataset datasets/pyhealth.datasets.BMDHSDataset datasets/pyhealth.datasets.COVID19CXRDataset diff --git a/docs/api/datasets/pyhealth.datasets.HiRIDDataset.rst b/docs/api/datasets/pyhealth.datasets.HiRIDDataset.rst new file mode 100644 index 000000000..cf658e3ad --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.HiRIDDataset.rst @@ -0,0 +1,9 @@ +pyhealth.datasets.HiRIDDataset +=================================== + +The HiRID (High time Resolution ICU Dataset) contains ~34,000 ICU admissions from Bern University Hospital with high-resolution time-series data. Refer to `PhysioNet `_ for more information. + +.. autoclass:: pyhealth.datasets.HiRIDDataset + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 23a4e06e5..c210aeb6a 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -223,6 +223,7 @@ Available Tasks Temple University EEG Tasks Sleep Staging v2 Benchmark EHRShot + FAMEWS Fairness Audit ChestX-ray14 Binary Classification De-Identification NER ChestX-ray14 Multilabel Classification diff --git a/docs/api/tasks/pyhealth.tasks.FAMEWS_fairness_audit.rst b/docs/api/tasks/pyhealth.tasks.FAMEWS_fairness_audit.rst new file mode 100644 index 000000000..84602b8e3 --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.FAMEWS_fairness_audit.rst @@ -0,0 +1,7 @@ +pyhealth.tasks.FAMEWS_fairness_audit +=================================== + +.. autoclass:: pyhealth.tasks.FAMEWS_fairness_audit.FAMEWSFairnessAudit + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/HiRID_fairness_audit.ipynb b/examples/HiRID_fairness_audit.ipynb new file mode 100644 index 000000000..9d2cfb12d --- /dev/null +++ b/examples/HiRID_fairness_audit.ipynb @@ -0,0 +1,178 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "edaf0cae", + "metadata": {}, + "source": [ + "# FAMEWS Fairness Audit (Quick Start)\n", + "\n", + "This notebook shows a minimal end-to-end example for:\n", + "\n", + "1. Loading `HiRIDDataset`\n", + "2. Running `FAMEWSFairnessAudit` on a patient\n", + "3. Building a `SampleDataset` with `set_task(...)`\n", + "4. Printing a quick subgroup summary (sex and age group)\n", + "\n", + "Use `dev=True` for a fast smoke test, then switch to `dev=False` for full runs." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e33e0fa7", + "metadata": {}, + "outputs": [ + { + "ename": "", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31mRunning cells with 'Python 3.12.7' requires the ipykernel package.\n", + "\u001b[1;31mCreate a Python Environment with the required packages.\n", + "\u001b[1;31mOr install 'ipykernel' using the command: '/opt/homebrew/bin/python3.12 -m pip install ipykernel -U --user --force-reinstall'" + ] + } + ], + "source": [ + "from collections import Counter\n", + "from pathlib import Path\n", + "import sys\n", + "\n", + "import pandas as pd\n", + "\n", + "def find_repo_root(start: Path) -> Path:\n", + " for path in [start, *start.parents]:\n", + " if (path / \"pyproject.toml\").exists() and (path / \"pyhealth\").exists():\n", + " return path\n", + " raise FileNotFoundError(\"Could not locate the PyHealth repository root from the current working directory.\")\n", + "\n", + "REPO_ROOT = find_repo_root(Path.cwd().resolve())\n", + "for name in list(sys.modules):\n", + " if name == \"pyhealth\" or name.startswith(\"pyhealth.\"):\n", + " del sys.modules[name]\n", + "if str(REPO_ROOT) in sys.path:\n", + " sys.path.remove(str(REPO_ROOT))\n", + "sys.path.insert(0, str(REPO_ROOT))\n", + "\n", + "from pyhealth.datasets.hirid import HiRIDDataset\n", + "from pyhealth.tasks.HiRID_fairness_audit import FAMEWSFairnessAudit\n", + "\n", + "HIRID_ROOT = REPO_ROOT / \"test-resources\" / \"core\" / \"hiriddemo\"\n", + "\n", + "assert HIRID_ROOT.exists(), (\n", + " f\"Expected HiRID root at {HIRID_ROOT}, but it was not found.\"\n", + ")\n", + "\n", + "print(f\"Using HiRID root: {HIRID_ROOT}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9ad8a1ef", + "metadata": {}, + "outputs": [], + "source": [ + "# Quick setup: use imputed stage + dev mode for fast iteration.\n", + "dataset = HiRIDDataset(\n", + " root=str(HIRID_ROOT),\n", + " stage=\"imputed\",\n", + " dev=True,\n", + ")\n", + "\n", + "task = FAMEWSFairnessAudit(stage_table=\"imputed_stage\")\n", + "\n", + "# Inspect one raw task sample from a single patient.\n", + "first_pid = dataset.unique_patient_ids[0]\n", + "first_patient = dataset.get_patient(first_pid)\n", + "raw_samples = task(first_patient)\n", + "\n", + "print(f\"First patient id: {first_pid}\")\n", + "print(f\"Raw samples generated for first patient: {len(raw_samples)}\")\n", + "\n", + "if raw_samples:\n", + " raw0 = raw_samples[0]\n", + " print(\"Raw sample keys:\", sorted(raw0.keys()))\n", + " print(\"Demographics:\", {\n", + " \"sex\": raw0.get(\"sex\"),\n", + " \"age\": raw0.get(\"age\"),\n", + " \"age_group\": raw0.get(\"age_group\"),\n", + " \"discharge_status\": raw0.get(\"discharge_status\"),\n", + " })\n", + "\n", + " ts, values = raw0[\"signals\"]\n", + " print(\"Timeseries length:\", len(ts))\n", + " print(\"Values shape:\", values.shape)\n", + "\n", + " display(pd.DataFrame(values[:5], columns=raw0[\"feature_columns\"]))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "78981b54", + "metadata": {}, + "outputs": [], + "source": [ + "# Build a PyHealth SampleDataset (this applies processors and caches outputs).\n", + "sample_dataset = dataset.set_task(task, num_workers=1)\n", + "\n", + "print(\"Dataset stats:\")\n", + "dataset.stats()\n", + "print(f\"\\nNumber of processed ML samples: {len(sample_dataset)}\")\n", + "\n", + "sample0 = sample_dataset[0]\n", + "print(\"SampleDataset keys:\", sorted(sample0.keys()))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "98afa328", + "metadata": {}, + "outputs": [], + "source": [ + "# Quick fairness-oriented summary over a small cohort slice.\n", + "# We intentionally use raw task outputs here to keep metadata fields explicit.\n", + "max_patients = 200\n", + "sex_counter = Counter()\n", + "age_group_counter = Counter()\n", + "valid_patients = 0\n", + "\n", + "for pid in dataset.unique_patient_ids[:max_patients]:\n", + " p = dataset.get_patient(pid)\n", + " samples = task(p)\n", + " if not samples:\n", + " continue\n", + " valid_patients += 1\n", + " s = samples[0]\n", + " sex_counter[str(s.get(\"sex\"))] += 1\n", + " age_group_counter[str(s.get(\"age_group\"))] += 1\n", + "\n", + "print(f\"Patients scanned: {max_patients}\")\n", + "print(f\"Patients with at least one task sample: {valid_patients}\")\n", + "\n", + "summary_df = pd.DataFrame({\n", + " \"sex\": dict(sex_counter),\n", + " \"age_group\": dict(age_group_counter),\n", + "})\n", + "\n", + "display(summary_df.fillna(0).astype(int))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.12.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 50b1b3887..5587cb5e9 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -55,6 +55,7 @@ def __init__(self, *args, **kwargs): from .dreamt import DREAMTDataset from .ehrshot import EHRShotDataset from .eicu import eICUDataset +from .hirid import HiRIDDataset from .isruc import ISRUCDataset from .medical_transcriptions import MedicalTranscriptionsDataset from .mimic3 import MIMIC3Dataset diff --git a/pyhealth/datasets/configs/hirid.yaml b/pyhealth/datasets/configs/hirid.yaml new file mode 100644 index 000000000..0b67d77c6 --- /dev/null +++ b/pyhealth/datasets/configs/hirid.yaml @@ -0,0 +1,89 @@ +version: "1.1.1" +tables: + general_table: + file_path: "general_table.csv" + patient_id: "patientid" + timestamp: "admissiontime" + attributes: + - "sex" + - "age" + - "discharge_status" + + merged_stage: + file_path: "hirid-merged-pyhealth.csv" + patient_id: "patientid" + timestamp: "datetime" + attributes: + - "heart_rate" + - "systolic_bp_invasive" + - "diastolic_bp_invasive" + - "mean_arterial_pressure" + - "cardiac_output" + - "spo2" + - "rass" + - "peak_inspiratory_pressure" + - "lactate_arterial" + - "lactate_venous" + - "inr" + - "serum_glucose" + - "c_reactive_protein" + - "dobutamine" + - "milrinone" + - "levosimendan" + - "theophyllin" + - "non_opioid_analgesics" + + observation_tables: + file_path: "hirid-observations-pyhealth.csv" + patient_id: "patientid" + timestamp: "datetime" + attributes: + - "entertime" + - "variableid" + - "value" + - "status" + - "stringvalue" + - "type" + + pharma_records: + file_path: "hirid-pharma-pyhealth.csv" + patient_id: "patientid" + timestamp: "givenat" + attributes: + - "pharmaid" + - "enteredentryat" + - "givendose" + - "cumulativedose" + - "fluidamount_calc" + - "cumulfluidamount_calc" + - "doseunit" + - "route" + - "infusionid" + - "typeid" + - "subtypeid" + - "recordstatus" + + imputed_stage: + file_path: "hirid-imputed-pyhealth.csv" + patient_id: "patientid" + timestamp: null + attributes: + - "reldatetime" + - "heart_rate" + - "systolic_bp_invasive" + - "diastolic_bp_invasive" + - "mean_arterial_pressure" + - "cardiac_output" + - "spo2" + - "rass" + - "peak_inspiratory_pressure" + - "lactate_arterial" + - "lactate_venous" + - "inr" + - "serum_glucose" + - "c_reactive_protein" + - "dobutamine" + - "milrinone" + - "levosimendan" + - "theophyllin" + - "non_opioid_analgesics" diff --git a/pyhealth/datasets/hirid.py b/pyhealth/datasets/hirid.py new file mode 100644 index 000000000..1cff56198 --- /dev/null +++ b/pyhealth/datasets/hirid.py @@ -0,0 +1,287 @@ +import io +import logging +import os +import tarfile +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import pandas as pd + +from .base_dataset import BaseDataset + +logger = logging.getLogger(__name__) + + +class HiRIDDataset(BaseDataset): + """HiRID v1.1.1 ICU dataset. + + The HiRID (High time Resolution ICU Dataset) is a freely accessible + critical care dataset containing data from ~34,000 ICU admissions + to the Department of Intensive Care Medicine of the Bern University + Hospital, Switzerland. + + The dataset provides high-resolution time-series data (2-minute + intervals) including vital signs, lab values, ventilator parameters, + and medication records. + + Dataset link: + https://physionet.org/content/hirid/1.1.1/ + + Paper: + Hyland et al. "Early prediction of circulatory failure in the + intensive care unit using machine learning." Nature Medicine, 2020. + + Args: + root: Root directory of the HiRID dataset. + stage: Data processing stage to use. One of ``"merged"``, + ``"raw"``, or ``"imputed"``. + dataset_name: Name of the dataset. Defaults to ``"hirid"``. + config_path: Path to YAML config. Defaults to built-in config. + **kwargs: Additional arguments passed to + :class:`~pyhealth.datasets.BaseDataset`. + + Attributes: + stage: The selected data processing stage. + + Examples: + >>> from pyhealth.datasets import HiRIDDataset + >>> dataset = HiRIDDataset( + ... root="/path/to/hirid/1.1.1", + ... stage="merged", + ... ) + >>> dataset.stats() + >>> patient = dataset.get_patient("1") + """ + + MERGED_COLUMN_MAP: Dict[str, str] = { + "vm1": "heart_rate", + "vm3": "systolic_bp_invasive", + "vm4": "diastolic_bp_invasive", + "vm5": "mean_arterial_pressure", + "vm13": "cardiac_output", + "vm20": "spo2", + "vm28": "rass", + "vm62": "peak_inspiratory_pressure", + "vm136": "lactate_arterial", + "vm146": "lactate_venous", + "vm172": "inr", + "vm174": "serum_glucose", + "vm176": "c_reactive_protein", + "pm41": "dobutamine", + "pm42": "milrinone", + "pm43": "levosimendan", + "pm44": "theophyllin", + "pm87": "non_opioid_analgesics", + } + + def __init__( + self, + root: str, + stage: str = "merged", + dataset_name: str = "hirid", + config_path: Optional[str] = None, + **kwargs, + ) -> None: + if stage not in {"merged", "raw", "imputed"}: + raise ValueError( + "stage must be one of 'merged', 'raw', or 'imputed', " + f"got '{stage}'" + ) + + self.stage = stage + + if config_path is None: + config_path = str(Path(__file__).parent / "configs" / "hirid.yaml") + + self._verify_data(root, stage) + self._prepare_data(root, stage) + + if stage == "merged": + tables = ["general_table", "merged_stage"] + elif stage == "raw": + tables = ["general_table", "observation_tables", "pharma_records"] + else: + tables = ["general_table", "imputed_stage"] + + super().__init__( + root=root, + tables=tables, + dataset_name=dataset_name, + config_path=config_path, + **kwargs, + ) + + def _verify_data(self, root: str, stage: str) -> None: + if not os.path.exists(root): + raise FileNotFoundError(f"Dataset path does not exist: {root}") + + general_table_path = os.path.join(root, "general_table.csv") + if not os.path.exists(general_table_path): + raise FileNotFoundError( + f"Required file not found: {general_table_path}" + ) + + requirements = self._stage_requirements(root, stage) + for processed_path, tar_path in requirements: + if os.path.exists(processed_path): + continue + if not os.path.exists(tar_path): + raise FileNotFoundError(f"Required file not found: {tar_path}") + + def _prepare_data(self, root: str, stage: str) -> None: + if stage == "merged": + self._prepare_merged(root) + elif stage == "raw": + self._prepare_raw(root) + else: + self._prepare_imputed(root) + + def _prepare_merged(self, root: str) -> None: + output_path = os.path.join(root, "hirid-merged-pyhealth.csv") + if os.path.exists(output_path): + logger.info("Processed file exists, skipping: %s", output_path) + return + + tar_path = os.path.join(root, "merged_stage", "merged_stage_csv.tar.gz") + self._extract_and_concat_tar( + tar_path=tar_path, + output_path=output_path, + column_map=self.MERGED_COLUMN_MAP, + ) + + def _prepare_raw(self, root: str) -> None: + observation_output_path = os.path.join( + root, "hirid-observations-pyhealth.csv" + ) + if os.path.exists(observation_output_path): + logger.info( + "Processed file exists, skipping: %s", + observation_output_path, + ) + else: + observation_tar_path = os.path.join( + root, + "raw_stage", + "observation_tables_csv.tar.gz", + ) + self._extract_and_concat_tar( + tar_path=observation_tar_path, + output_path=observation_output_path, + ) + + pharma_output_path = os.path.join(root, "hirid-pharma-pyhealth.csv") + if os.path.exists(pharma_output_path): + logger.info("Processed file exists, skipping: %s", pharma_output_path) + return + + pharma_tar_path = os.path.join( + root, + "raw_stage", + "pharma_records_csv.tar.gz", + ) + self._extract_and_concat_tar( + tar_path=pharma_tar_path, + output_path=pharma_output_path, + ) + + def _prepare_imputed(self, root: str) -> None: + output_path = os.path.join(root, "hirid-imputed-pyhealth.csv") + if os.path.exists(output_path): + logger.info("Processed file exists, skipping: %s", output_path) + return + + tar_path = os.path.join(root, "imputed_stage", "imputed_stage_csv.tar.gz") + self._extract_and_concat_tar( + tar_path=tar_path, + output_path=output_path, + column_map=self.MERGED_COLUMN_MAP, + ) + + def _extract_and_concat_tar( + self, + tar_path: str, + output_path: str, + column_map: Optional[Dict[str, str]] = None, + ) -> None: + logger.info("Extracting %s -> %s", tar_path, output_path) + first = True + + with tarfile.open(tar_path, "r:gz") as tar: + members = sorted( + ( + member + for member in tar.getmembers() + if member.isfile() and member.name.endswith(".csv") + ), + key=lambda member: member.name, + ) + if not members: + raise ValueError(f"No CSV files found in archive: {tar_path}") + + for index, member in enumerate(members, start=1): + extracted = tar.extractfile(member) + if extracted is None: + continue + + with io.TextIOWrapper(extracted) as buffer: + df = pd.read_csv(buffer) + + if column_map: + df.rename(columns=column_map, inplace=True) + + df.to_csv( + output_path, + mode="w" if first else "a", + header=first, + index=False, + ) + first = False + + if index % 1000 == 0: + logger.info("Processed %s/%s patient files", index, len(members)) + + logger.info("Created %s", output_path) + + def _stage_requirements(self, root: str, stage: str) -> List[Tuple[str, str]]: + if stage == "merged": + return [ + ( + os.path.join(root, "hirid-merged-pyhealth.csv"), + os.path.join( + root, + "merged_stage", + "merged_stage_csv.tar.gz", + ), + ) + ] + + if stage == "raw": + return [ + ( + os.path.join(root, "hirid-observations-pyhealth.csv"), + os.path.join( + root, + "raw_stage", + "observation_tables_csv.tar.gz", + ), + ), + ( + os.path.join(root, "hirid-pharma-pyhealth.csv"), + os.path.join( + root, + "raw_stage", + "pharma_records_csv.tar.gz", + ), + ), + ] + + return [ + ( + os.path.join(root, "hirid-imputed-pyhealth.csv"), + os.path.join( + root, + "imputed_stage", + "imputed_stage_csv.tar.gz", + ), + ) + ] diff --git a/pyhealth/tasks/FAMEWS_fairness_audit.py b/pyhealth/tasks/FAMEWS_fairness_audit.py new file mode 100644 index 000000000..ee325f7dd --- /dev/null +++ b/pyhealth/tasks/FAMEWS_fairness_audit.py @@ -0,0 +1,228 @@ +""" +PyHealth task for FAMEWS-style fairness analysis on HiRID. + +Dataset link: + https://physionet.org/content/hirid/1.1.1/ + +Dataset paper: (please cite if you use this dataset) + Hoche, M.; Mineeva, O.; Burger, M.; Blasimme, A.; and + Ratsch, G. 2024. FAMEWS: a Fairness Auditing tool for + Medical Early-Warning Systems. In Proceedings of Ma- + chine Learning Research, volume 248, 297–311. Confer- + ence on Health, Inference, and Learning (CHIL) 2024. + +Dataset paper link: + https://proceedings.mlr.press/v248/hoche24a.html + +Author: + John Doll (doll3@illinois.edu) +""" + +from datetime import datetime, timedelta +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional + +import pandas as pd +import polars as pl + +from pyhealth.tasks import BaseTask + + +class FAMEWSFairnessAudit(BaseTask): + """Build per-patient samples for FAMEWS-style fairness analysis on HiRID. + + This task is designed for :class:`~pyhealth.datasets.HiRIDDataset` and + returns one sample per patient containing: + + - A multivariate timeseries (for model features) + - Demographic metadata used for fairness subgrouping + + By default it reads from ``imputed_stage``. For merged data, set + ``stage_table="merged_stage"``. + + The HiRID subgroup configuration used by the original FAMEWS fairness + pipeline is bundled at ``FAMEWSFairnessAudit.group_config_path``. + + Examples: + >>> from pyhealth.datasets import HiRIDDataset + >>> from pyhealth.tasks import FAMEWSFairnessAudit + >>> dataset = HiRIDDataset( + ... root="/path/to/hirid/1.1.1", + ... stage="imputed", + ... ) + >>> task = FAMEWSFairnessAudit(stage_table="imputed_stage") + >>> samples = dataset.set_task(task) + """ + + task_name: str = "FAMEWS_fairness_audit" + input_schema: Dict[str, str] = {"signals": "timeseries"} + output_schema: Dict[str, str] = {} + group_config_path: str = str( + Path(__file__).parent / "configs" / "group_hirid_complete.yaml" + ) + + DEFAULT_FEATURE_COLUMNS: List[str] = [ + "heart_rate", + "systolic_bp_invasive", + "diastolic_bp_invasive", + "mean_arterial_pressure", + "cardiac_output", + "spo2", + "rass", + "peak_inspiratory_pressure", + "lactate_arterial", + "lactate_venous", + "inr", + "serum_glucose", + "c_reactive_protein", + "dobutamine", + "milrinone", + "levosimendan", + "theophyllin", + "non_opioid_analgesics", + ] + + def __init__( + self, + stage_table: str = "imputed_stage", + feature_columns: Optional[Iterable[str]] = None, + ) -> None: + if stage_table not in {"merged_stage", "imputed_stage"}: + raise ValueError( + "stage_table must be one of {'merged_stage', 'imputed_stage'}, " + f"got '{stage_table}'" + ) + self.stage_table = stage_table + self.feature_columns = list(feature_columns) if feature_columns else list( + self.DEFAULT_FEATURE_COLUMNS + ) + # Included in vars(task) so cache key changes after time-axis bug fix. + self.time_axis_version = "v2" + + @staticmethod + def _build_age_group(age_value: Any) -> Optional[str]: + try: + age = float(age_value) + except (TypeError, ValueError): + return None + + if age < 50: + return "<50" + if age < 65: + return "50-65" + if age < 75: + return "65-75" + if age < 85: + return "75-85" + return ">85" + + def _extract_time_axis(self, events_df: pl.DataFrame) -> List[Any]: + default_axis = [ + datetime(1970, 1, 1) + timedelta(hours=i) + for i in range(events_df.height) + ] + + if "timestamp" in events_df.columns: + timestamp_series = events_df.get_column("timestamp") + if timestamp_series.null_count() < events_df.height: + parsed = pd.to_datetime(timestamp_series.to_list(), errors="coerce") + if parsed.notna().any(): + time_axis: List[datetime] = [] + previous = datetime(1970, 1, 1) + for ts in parsed: + if pd.isna(ts): + previous = previous + timedelta(hours=1) + else: + previous = ts.to_pydatetime() + time_axis.append(previous) + return time_axis + + relative_time_col = f"{self.stage_table}/reldatetime" + if relative_time_col in events_df.columns: + raw_relative = events_df.get_column(relative_time_col).to_list() + + if len(raw_relative) == 0: + return default_axis + + numeric_values: List[float] = [] + numeric_only = True + for value in raw_relative: + try: + numeric_values.append(float(value)) + except (TypeError, ValueError): + numeric_only = False + break + + if numeric_only: + non_negative_diffs = [ + numeric_values[i] - numeric_values[i - 1] + for i in range(1, len(numeric_values)) + if numeric_values[i] >= numeric_values[i - 1] + ] + # HiRID relative times are commonly encoded in minutes; some exports + # may use seconds. Infer units from the median step size. + step_median = ( + pd.Series(non_negative_diffs).median() + if non_negative_diffs + else 0 + ) + as_seconds = bool(step_median and step_median > 20) + origin = datetime(1970, 1, 1) + return [ + origin + timedelta(seconds=v if as_seconds else v * 60.0) + for v in numeric_values + ] + + parsed_relative = pd.to_timedelta(raw_relative, errors="coerce") + if parsed_relative.notna().any(): + origin = datetime(1970, 1, 1) + time_axis = [] + previous = origin + for delta in parsed_relative: + if pd.isna(delta): + previous = previous + timedelta(hours=1) + else: + previous = origin + delta.to_pytimedelta() + time_axis.append(previous) + return time_axis + + return default_axis + + def __call__(self, patient: Any) -> List[Dict[str, Any]]: + patient_general = patient.get_events(event_type="general_table") + if len(patient_general) == 0: + return [] + + stage_df = patient.get_events(event_type=self.stage_table, return_df=True) + if stage_df.height == 0: + return [] + + available_features = [ + col + for col in self.feature_columns + if f"{self.stage_table}/{col}" in stage_df.columns + ] + if len(available_features) == 0: + return [] + + value_df = stage_df.select( + [ + pl.col(f"{self.stage_table}/{col}").cast(pl.Float64) + for col in available_features + ] + ) + + general_event = patient_general[0] + age_group = self._build_age_group(getattr(general_event, "age", None)) + + sample = { + "patient_id": patient.patient_id, + "signals": (self._extract_time_axis(stage_df), value_df.to_numpy()), + "feature_columns": available_features, + "sex": getattr(general_event, "sex", None), + "age": getattr(general_event, "age", None), + "age_group": age_group, + "discharge_status": getattr(general_event, "discharge_status", None), + } + + return [sample] diff --git a/pyhealth/tasks/HiRID_fairness_audit.py b/pyhealth/tasks/HiRID_fairness_audit.py new file mode 100644 index 000000000..9299b802d --- /dev/null +++ b/pyhealth/tasks/HiRID_fairness_audit.py @@ -0,0 +1,224 @@ +""" +PyHealth task for FAMEWS-style fairness analysis on HiRID. + +Dataset link: + https://physionet.org/content/hirid/1.1.1/ + +Dataset paper: (please cite if you use this dataset) + Hoche, M.; Mineeva, O.; Burger, M.; Blasimme, A.; and + Ratsch, G. 2024. FAMEWS: a Fairness Auditing tool for + Medical Early-Warning Systems. In Proceedings of Ma- + chine Learning Research, volume 248, 297–311. Confer- + ence on Health, Inference, and Learning (CHIL) 2024. + +Dataset paper link: + https://proceedings.mlr.press/v248/hoche24a.html + +Author: + John Doll (doll3@illinois.edu) +""" + +from datetime import datetime, timedelta +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional + +import pandas as pd +import polars as pl + +from pyhealth.tasks import BaseTask + + +class FAMEWSFairnessAudit(BaseTask): + """Build per-patient samples for FAMEWS-style fairness analysis on HiRID. + + This task is designed for :class:`~pyhealth.datasets.HiRIDDataset` and + returns one sample per patient containing: + + - A multivariate timeseries (for model features) + - Demographic metadata used for fairness subgrouping + + By default it reads from ``imputed_stage``. For merged data, set + ``stage_table="merged_stage"``. + + The HiRID subgroup configuration used by the original FAMEWS fairness + pipeline is bundled at ``FAMEWSFairnessAudit.group_config_path``. + + Examples: + >>> from pyhealth.datasets import HiRIDDataset + >>> from pyhealth.tasks import FAMEWSFairnessAudit + >>> dataset = HiRIDDataset( + ... root="/path/to/hirid/1.1.1", + ... stage="imputed", + ... ) + >>> task = FAMEWSFairnessAudit(stage_table="imputed_stage") + >>> samples = dataset.set_task(task) + """ + + task_name: str = "FAMEWS_fairness_audit" + input_schema: Dict[str, str] = {"signals": "timeseries"} + output_schema: Dict[str, str] = {} + group_config_path: str = str( + Path(__file__).parent / "configs" / "group_hirid_complete.yaml" + ) + + DEFAULT_FEATURE_COLUMNS: List[str] = [ + "heart_rate", + "systolic_bp_invasive", + "diastolic_bp_invasive", + "mean_arterial_pressure", + "cardiac_output", + "spo2", + "rass", + "peak_inspiratory_pressure", + "lactate_arterial", + "lactate_venous", + "inr", + "serum_glucose", + "c_reactive_protein", + "dobutamine", + "milrinone", + "levosimendan", + "theophyllin", + "non_opioid_analgesics", + ] + + def __init__( + self, + stage_table: str = "imputed_stage", + feature_columns: Optional[Iterable[str]] = None, + ) -> None: + if stage_table not in {"merged_stage", "imputed_stage"}: + raise ValueError( + "stage_table must be one of {'merged_stage', 'imputed_stage'}, " + f"got '{stage_table}'" + ) + self.stage_table = stage_table + self.feature_columns = list(feature_columns) if feature_columns else list( + self.DEFAULT_FEATURE_COLUMNS + ) + # Included in vars(task) so cache key changes after time-axis bug fix. + self.time_axis_version = "v2" + + @staticmethod + def _build_age_group(age_value: Any) -> Optional[str]: + try: + age = float(age_value) + except (TypeError, ValueError): + return None + + if age < 50: + return "<50" + if age < 65: + return "50-65" + if age < 75: + return "65-75" + if age < 85: + return "75-85" + return ">85" + + def _extract_time_axis(self, events_df: pl.DataFrame) -> List[Any]: + default_axis = [ + datetime(1970, 1, 1) + timedelta(hours=i) + for i in range(events_df.height) + ] + + if "timestamp" in events_df.columns: + timestamp_series = events_df.get_column("timestamp") + if timestamp_series.null_count() < events_df.height: + parsed = pd.to_datetime(timestamp_series.to_list(), errors="coerce") + if parsed.notna().any(): + time_axis: List[datetime] = [] + previous = datetime(1970, 1, 1) + for ts in parsed: + if pd.isna(ts): + previous = previous + timedelta(hours=1) + else: + previous = ts.to_pydatetime() + time_axis.append(previous) + return time_axis + + relative_time_col = f"{self.stage_table}/reldatetime" + if relative_time_col in events_df.columns: + raw_relative = events_df.get_column(relative_time_col).to_list() + + if len(raw_relative) == 0: + return default_axis + + numeric_values: List[float] = [] + numeric_only = True + for value in raw_relative: + try: + numeric_values.append(float(value)) + except (TypeError, ValueError): + numeric_only = False + break + + if numeric_only: + non_negative_diffs = [ + numeric_values[i] - numeric_values[i - 1] + for i in range(1, len(numeric_values)) + if numeric_values[i] >= numeric_values[i - 1] + ] + # HiRID relative times are commonly encoded in minutes; some exports + # may use seconds. Infer units from the median step size. + step_median = pd.Series(non_negative_diffs).median() if non_negative_diffs else 0 + as_seconds = bool(step_median and step_median > 20) + origin = datetime(1970, 1, 1) + return [ + origin + timedelta(seconds=v if as_seconds else v * 60.0) + for v in numeric_values + ] + + parsed_relative = pd.to_timedelta(raw_relative, errors="coerce") + if parsed_relative.notna().any(): + origin = datetime(1970, 1, 1) + time_axis = [] + previous = origin + for delta in parsed_relative: + if pd.isna(delta): + previous = previous + timedelta(hours=1) + else: + previous = origin + delta.to_pytimedelta() + time_axis.append(previous) + return time_axis + + return default_axis + + def __call__(self, patient: Any) -> List[Dict[str, Any]]: + patient_general = patient.get_events(event_type="general_table") + if len(patient_general) == 0: + return [] + + stage_df = patient.get_events(event_type=self.stage_table, return_df=True) + if stage_df.height == 0: + return [] + + available_features = [ + col + for col in self.feature_columns + if f"{self.stage_table}/{col}" in stage_df.columns + ] + if len(available_features) == 0: + return [] + + value_df = stage_df.select( + [ + pl.col(f"{self.stage_table}/{col}").cast(pl.Float64) + for col in available_features + ] + ) + + general_event = patient_general[0] + age_group = self._build_age_group(getattr(general_event, "age", None)) + + sample = { + "patient_id": patient.patient_id, + "signals": (self._extract_time_axis(stage_df), value_df.to_numpy()), + "feature_columns": available_features, + "sex": getattr(general_event, "sex", None), + "age": getattr(general_event, "age", None), + "age_group": age_group, + "discharge_status": getattr(general_event, "discharge_status", None), + } + + return [sample] diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index a32618f9c..2882a4256 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -22,6 +22,9 @@ drug_recommendation_mimic4_fn, drug_recommendation_omop_fn, ) +from .EEG_abnormal import EEG_isAbnormal_fn +from .EEG_events import EEG_events_fn +from .HiRID_fairness_audit import FAMEWSFairnessAudit from .in_hospital_mortality_mimic4 import InHospitalMortalityMIMIC4 from .length_of_stay_prediction import ( LengthOfStayPredictioneICU, diff --git a/pyhealth/tasks/configs/group_hirid_complete.yaml b/pyhealth/tasks/configs/group_hirid_complete.yaml new file mode 100644 index 000000000..5ff749e5d --- /dev/null +++ b/pyhealth/tasks/configs/group_hirid_complete.yaml @@ -0,0 +1,9 @@ +group: + sex: [F, M] + age_group: [<50, 50-65, 65-75, 75-85, '>85'] + APACHE_group: [Cardiovascular, Neurological, Gastrointestinal, Respiratory, Other, Trauma, Metabolic] + surgical_status: [Surgical, Non-surgical] + +type_table_group: + binary_group: [sex, surgical_status] + multicat_group: [age_group, APACHE_group] diff --git a/test-resources/core/hiriddemo/general_table.csv b/test-resources/core/hiriddemo/general_table.csv new file mode 100644 index 000000000..7050f53b6 --- /dev/null +++ b/test-resources/core/hiriddemo/general_table.csv @@ -0,0 +1,4 @@ +patientid,admissiontime,sex,age,discharge_status +148,2183-05-12 11:45:00,M,72,alive +229,2151-03-02 07:10:00,F,61,dead +307,2160-09-18 18:30:00,F,47,alive diff --git a/test-resources/core/hiriddemo/hirid-imputed-pyhealth.csv b/test-resources/core/hiriddemo/hirid-imputed-pyhealth.csv new file mode 100644 index 000000000..e4c76a578 --- /dev/null +++ b/test-resources/core/hiriddemo/hirid-imputed-pyhealth.csv @@ -0,0 +1,10 @@ +patientid,reldatetime,heart_rate,systolic_bp_invasive,diastolic_bp_invasive,mean_arterial_pressure,cardiac_output,spo2,rass,peak_inspiratory_pressure,lactate_arterial,lactate_venous,inr,serum_glucose,c_reactive_protein,dobutamine,milrinone,levosimendan,theophyllin,non_opioid_analgesics +148,0.0,119.0,125.0,75.0,90.0,5.2,84.0,0.0,18.0,1.0,1.1,0.95,105.0,4.0,0.0,0.0,0.0,0.0,1.0 +148,30.0,114.0,122.0,73.0,88.0,5.1,89.0,0.0,19.0,1.0,1.1,0.96,109.0,4.2,0.0,0.0,0.0,0.0,1.0 +148,60.0,108.0,118.0,70.0,84.0,4.9,92.0,-1.0,19.0,0.9,1.0,0.98,111.0,4.4,0.0,0.0,0.0,0.0,0.0 +229,0.0,96.0,110.0,64.0,79.0,4.3,95.0,-1.0,17.0,1.3,1.4,1.10,132.0,12.0,2.0,0.0,0.0,0.0,1.0 +229,30.0,99.0,108.0,62.0,77.0,4.1,94.0,-1.0,17.0,1.5,1.5,1.15,138.0,14.0,2.0,0.0,0.0,0.0,1.0 +229,60.0,104.0,104.0,59.0,73.0,3.8,92.0,-2.0,18.0,1.8,1.7,1.20,145.0,17.0,3.0,0.0,0.0,0.0,1.0 +307,0.0,88.0,116.0,68.0,83.0,4.8,97.0,0.0,16.0,0.7,0.8,1.00,98.0,3.0,0.0,0.0,0.0,0.0,0.0 +307,30.0,86.0,118.0,70.0,85.0,4.9,98.0,0.0,16.0,0.7,0.8,1.00,96.0,2.9,0.0,0.0,0.0,0.0,0.0 +307,60.0,84.0,117.0,69.0,84.0,4.8,98.0,0.0,16.0,0.6,0.8,0.99,95.0,2.8,0.0,0.0,0.0,0.0,0.0 diff --git a/test-resources/core/hiriddemo/hirid-merged-pyhealth.csv b/test-resources/core/hiriddemo/hirid-merged-pyhealth.csv new file mode 100644 index 000000000..c03d82c84 --- /dev/null +++ b/test-resources/core/hiriddemo/hirid-merged-pyhealth.csv @@ -0,0 +1,10 @@ +patientid,datetime,heart_rate,systolic_bp_invasive,diastolic_bp_invasive,mean_arterial_pressure,cardiac_output,spo2,rass,peak_inspiratory_pressure,lactate_arterial,lactate_venous,inr,serum_glucose,c_reactive_protein,dobutamine,milrinone,levosimendan,theophyllin,non_opioid_analgesics +148,2183-05-12 12:00:00,119.0,125.0,75.0,90.0,5.2,84.0,0.0,18.0,1.0,1.1,0.95,105.0,4.0,0.0,0.0,0.0,0.0,1.0 +148,2183-05-12 12:30:00,114.0,122.0,73.0,88.0,5.1,89.0,0.0,19.0,1.0,1.1,0.96,109.0,4.2,0.0,0.0,0.0,0.0,1.0 +148,2183-05-12 13:00:00,108.0,118.0,70.0,84.0,4.9,92.0,-1.0,19.0,0.9,1.0,0.98,111.0,4.4,0.0,0.0,0.0,0.0,0.0 +229,2151-03-02 07:30:00,96.0,110.0,64.0,79.0,4.3,95.0,-1.0,17.0,1.3,1.4,1.10,132.0,12.0,2.0,0.0,0.0,0.0,1.0 +229,2151-03-02 08:00:00,99.0,108.0,62.0,77.0,4.1,94.0,-1.0,17.0,1.5,1.5,1.15,138.0,14.0,2.0,0.0,0.0,0.0,1.0 +229,2151-03-02 08:30:00,104.0,104.0,59.0,73.0,3.8,92.0,-2.0,18.0,1.8,1.7,1.20,145.0,17.0,3.0,0.0,0.0,0.0,1.0 +307,2160-09-18 19:00:00,88.0,116.0,68.0,83.0,4.8,97.0,0.0,16.0,0.7,0.8,1.00,98.0,3.0,0.0,0.0,0.0,0.0,0.0 +307,2160-09-18 19:30:00,86.0,118.0,70.0,85.0,4.9,98.0,0.0,16.0,0.7,0.8,1.00,96.0,2.9,0.0,0.0,0.0,0.0,0.0 +307,2160-09-18 20:00:00,84.0,117.0,69.0,84.0,4.8,98.0,0.0,16.0,0.6,0.8,0.99,95.0,2.8,0.0,0.0,0.0,0.0,0.0 diff --git a/test-resources/core/hiriddemo/hirid-observations-pyhealth.csv b/test-resources/core/hiriddemo/hirid-observations-pyhealth.csv new file mode 100644 index 000000000..13727a020 --- /dev/null +++ b/test-resources/core/hiriddemo/hirid-observations-pyhealth.csv @@ -0,0 +1,7 @@ +patientid,datetime,entertime,variableid,value,status,stringvalue,type +148,2183-05-12 12:00:00,2183-05-12 12:05:00,vm1,119.0,0,,numeric +148,2183-05-12 12:00:00,2183-05-12 12:05:00,vm20,84.0,0,,numeric +229,2151-03-02 07:30:00,2151-03-02 07:36:00,vm5,79.0,0,,numeric +229,2151-03-02 08:00:00,2151-03-02 08:06:00,vm174,138.0,0,,numeric +307,2160-09-18 19:00:00,2160-09-18 19:04:00,vm1,88.0,0,,numeric +307,2160-09-18 19:30:00,2160-09-18 19:35:00,vm176,2.9,0,,numeric diff --git a/test-resources/core/hiriddemo/hirid-pharma-pyhealth.csv b/test-resources/core/hiriddemo/hirid-pharma-pyhealth.csv new file mode 100644 index 000000000..44661e350 --- /dev/null +++ b/test-resources/core/hiriddemo/hirid-pharma-pyhealth.csv @@ -0,0 +1,5 @@ +patientid,pharmaid,givenat,enteredentryat,givendose,cumulativedose,fluidamount_calc,cumulfluidamount_calc,doseunit,route,infusionid,typeid,subtypeid,recordstatus +148,1000251,2183-05-12 12:54:00,2183-05-12 14:03:39,25.0,25.0,0.5,0.5,ug,iv-inj,1616089,1,8,780 +148,1000251,2183-05-12 12:58:00,2183-05-12 14:03:56,25.0,50.0,0.5,1.0,ug,iv-inj,257507,1,8,780 +229,1000302,2151-03-02 07:45:00,2151-03-02 08:00:00,2.0,2.0,10.0,10.0,mg,iv,200001,1,4,780 +307,1000403,2160-09-18 19:10:00,2160-09-18 19:22:00,1.0,1.0,5.0,5.0,mg,iv,300002,2,6,780 diff --git a/tests/core/test_famews.py b/tests/core/test_famews.py new file mode 100644 index 000000000..8bb0caade --- /dev/null +++ b/tests/core/test_famews.py @@ -0,0 +1,101 @@ +""" +Unit tests for the HiRIDDataset and FAMEWSFairnessAudit task. + +Author: + John Doll +""" + +import unittest +from datetime import datetime +from types import SimpleNamespace + +import polars as pl + +from pyhealth.tasks import FAMEWSFairnessAudit + + +class _MockPatient: + def __init__( + self, + patient_id: str, + general_events, + stage_df: pl.DataFrame, + ): + self.patient_id = patient_id + self._general_events = general_events + self._stage_df = stage_df + + def get_events(self, event_type: str, return_df: bool = False): + if event_type == "general_table": + return self._general_events + if return_df: + return self._stage_df + return [] + + +class TestFAMEWSFairnessAudit(unittest.TestCase): + def test_invalid_stage_table_raises(self): + with self.assertRaises(ValueError): + FAMEWSFairnessAudit(stage_table="raw_stage") + + def test_task_generates_sample_with_datetime_axis(self): + task = FAMEWSFairnessAudit(stage_table="imputed_stage") + stage_df = pl.DataFrame( + { + "imputed_stage/reldatetime": ["0", "5", "10"], + "imputed_stage/heart_rate": [80.0, 82.0, 84.0], + "imputed_stage/spo2": [96.0, 97.0, 98.0], + } + ) + general_event = SimpleNamespace( + sex="F", + age=73, + discharge_status="alive", + ) + patient = _MockPatient("p1", [general_event], stage_df) + + samples = task(patient) + + self.assertEqual(len(samples), 1) + sample = samples[0] + self.assertEqual(sample["patient_id"], "p1") + self.assertEqual(sample["age_group"], "65-75") + self.assertListEqual(sample["feature_columns"], ["heart_rate", "spo2"]) + + timestamps, values = sample["signals"] + self.assertEqual(len(timestamps), 3) + self.assertTrue(all(isinstance(ts, datetime) for ts in timestamps)) + self.assertEqual(values.shape, (3, 2)) + + def test_returns_empty_when_general_table_missing(self): + task = FAMEWSFairnessAudit(stage_table="imputed_stage") + stage_df = pl.DataFrame( + { + "imputed_stage/reldatetime": [0, 1], + "imputed_stage/heart_rate": [70.0, 75.0], + } + ) + patient = _MockPatient("p2", [], stage_df) + + self.assertEqual(task(patient), []) + + def test_returns_empty_when_no_feature_columns_present(self): + task = FAMEWSFairnessAudit(stage_table="imputed_stage") + stage_df = pl.DataFrame( + { + "imputed_stage/reldatetime": [0, 1, 2], + "imputed_stage/non_matching_feature": [1.0, 2.0, 3.0], + } + ) + general_event = SimpleNamespace( + sex="M", + age=45, + discharge_status="dead", + ) + patient = _MockPatient("p3", [general_event], stage_df) + + self.assertEqual(task(patient), []) + + +if __name__ == "__main__": + unittest.main()