diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst
index 8d9a59d21..3765f8e57 100644
--- a/docs/api/datasets.rst
+++ b/docs/api/datasets.rst
@@ -234,6 +234,7 @@ Available Datasets
datasets/pyhealth.datasets.SHHSDataset
datasets/pyhealth.datasets.SleepEDFDataset
datasets/pyhealth.datasets.EHRShotDataset
+ datasets/pyhealth.datasets.HiRIDDataset
datasets/pyhealth.datasets.Support2Dataset
datasets/pyhealth.datasets.BMDHSDataset
datasets/pyhealth.datasets.COVID19CXRDataset
diff --git a/docs/api/datasets/pyhealth.datasets.HiRIDDataset.rst b/docs/api/datasets/pyhealth.datasets.HiRIDDataset.rst
new file mode 100644
index 000000000..cf658e3ad
--- /dev/null
+++ b/docs/api/datasets/pyhealth.datasets.HiRIDDataset.rst
@@ -0,0 +1,9 @@
+pyhealth.datasets.HiRIDDataset
+===================================
+
+The HiRID (High time Resolution ICU Dataset) contains ~34,000 ICU admissions from Bern University Hospital with high-resolution time-series data. Refer to `PhysioNet <https://physionet.org/content/hirid/1.1.1/>`_ for more information.
+
+.. autoclass:: pyhealth.datasets.HiRIDDataset
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst
index 23a4e06e5..c210aeb6a 100644
--- a/docs/api/tasks.rst
+++ b/docs/api/tasks.rst
@@ -223,6 +223,7 @@ Available Tasks
Temple University EEG Tasks
Sleep Staging v2
Benchmark EHRShot
+ FAMEWS Fairness Audit
ChestX-ray14 Binary Classification
De-Identification NER
ChestX-ray14 Multilabel Classification
diff --git a/docs/api/tasks/pyhealth.tasks.FAMEWS_fairness_audit.rst b/docs/api/tasks/pyhealth.tasks.FAMEWS_fairness_audit.rst
new file mode 100644
index 000000000..84602b8e3
--- /dev/null
+++ b/docs/api/tasks/pyhealth.tasks.FAMEWS_fairness_audit.rst
@@ -0,0 +1,7 @@
+pyhealth.tasks.FAMEWS_fairness_audit
+=====================================
+
+.. autoclass:: pyhealth.tasks.FAMEWS_fairness_audit.FAMEWSFairnessAudit
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/examples/HiRID_fairness_audit.ipynb b/examples/HiRID_fairness_audit.ipynb
new file mode 100644
index 000000000..9d2cfb12d
--- /dev/null
+++ b/examples/HiRID_fairness_audit.ipynb
@@ -0,0 +1,178 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "edaf0cae",
+ "metadata": {},
+ "source": [
+ "# FAMEWS Fairness Audit (Quick Start)\n",
+ "\n",
+ "This notebook shows a minimal end-to-end example for:\n",
+ "\n",
+ "1. Loading `HiRIDDataset`\n",
+ "2. Running `FAMEWSFairnessAudit` on a patient\n",
+ "3. Building a `SampleDataset` with `set_task(...)`\n",
+ "4. Printing a quick subgroup summary (sex and age group)\n",
+ "\n",
+ "Use `dev=True` for a fast smoke test, then switch to `dev=False` for full runs."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e33e0fa7",
+ "metadata": {},
+ "outputs": [],
+
+
+
+
+
+
+
+
+
+
+ "source": [
+ "from collections import Counter\n",
+ "from pathlib import Path\n",
+ "import sys\n",
+ "\n",
+ "import pandas as pd\n",
+ "\n",
+ "def find_repo_root(start: Path) -> Path:\n",
+ " for path in [start, *start.parents]:\n",
+ " if (path / \"pyproject.toml\").exists() and (path / \"pyhealth\").exists():\n",
+ " return path\n",
+ " raise FileNotFoundError(\"Could not locate the PyHealth repository root from the current working directory.\")\n",
+ "\n",
+ "REPO_ROOT = find_repo_root(Path.cwd().resolve())\n",
+ "for name in list(sys.modules):\n",
+ " if name == \"pyhealth\" or name.startswith(\"pyhealth.\"):\n",
+ " del sys.modules[name]\n",
+ "if str(REPO_ROOT) in sys.path:\n",
+ " sys.path.remove(str(REPO_ROOT))\n",
+ "sys.path.insert(0, str(REPO_ROOT))\n",
+ "\n",
+ "from pyhealth.datasets.hirid import HiRIDDataset\n",
+ "from pyhealth.tasks.HiRID_fairness_audit import FAMEWSFairnessAudit\n",
+ "\n",
+ "HIRID_ROOT = REPO_ROOT / \"test-resources\" / \"core\" / \"hiriddemo\"\n",
+ "\n",
+ "assert HIRID_ROOT.exists(), (\n",
+ " f\"Expected HiRID root at {HIRID_ROOT}, but it was not found.\"\n",
+ ")\n",
+ "\n",
+ "print(f\"Using HiRID root: {HIRID_ROOT}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9ad8a1ef",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Quick setup: use imputed stage + dev mode for fast iteration.\n",
+ "dataset = HiRIDDataset(\n",
+ " root=str(HIRID_ROOT),\n",
+ " stage=\"imputed\",\n",
+ " dev=True,\n",
+ ")\n",
+ "\n",
+ "task = FAMEWSFairnessAudit(stage_table=\"imputed_stage\")\n",
+ "\n",
+ "# Inspect one raw task sample from a single patient.\n",
+ "first_pid = dataset.unique_patient_ids[0]\n",
+ "first_patient = dataset.get_patient(first_pid)\n",
+ "raw_samples = task(first_patient)\n",
+ "\n",
+ "print(f\"First patient id: {first_pid}\")\n",
+ "print(f\"Raw samples generated for first patient: {len(raw_samples)}\")\n",
+ "\n",
+ "if raw_samples:\n",
+ " raw0 = raw_samples[0]\n",
+ " print(\"Raw sample keys:\", sorted(raw0.keys()))\n",
+ " print(\"Demographics:\", {\n",
+ " \"sex\": raw0.get(\"sex\"),\n",
+ " \"age\": raw0.get(\"age\"),\n",
+ " \"age_group\": raw0.get(\"age_group\"),\n",
+ " \"discharge_status\": raw0.get(\"discharge_status\"),\n",
+ " })\n",
+ "\n",
+ " ts, values = raw0[\"signals\"]\n",
+ " print(\"Timeseries length:\", len(ts))\n",
+ " print(\"Values shape:\", values.shape)\n",
+ "\n",
+ " display(pd.DataFrame(values[:5], columns=raw0[\"feature_columns\"]))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "78981b54",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Build a PyHealth SampleDataset (this applies processors and caches outputs).\n",
+ "sample_dataset = dataset.set_task(task, num_workers=1)\n",
+ "\n",
+ "print(\"Dataset stats:\")\n",
+ "dataset.stats()\n",
+ "print(f\"\\nNumber of processed ML samples: {len(sample_dataset)}\")\n",
+ "\n",
+ "sample0 = sample_dataset[0]\n",
+ "print(\"SampleDataset keys:\", sorted(sample0.keys()))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "98afa328",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Quick fairness-oriented summary over a small cohort slice.\n",
+ "# We intentionally use raw task outputs here to keep metadata fields explicit.\n",
+ "max_patients = 200\n",
+ "sex_counter = Counter()\n",
+ "age_group_counter = Counter()\n",
+ "valid_patients = 0\n",
+ "\n",
+ "for pid in dataset.unique_patient_ids[:max_patients]:\n",
+ " p = dataset.get_patient(pid)\n",
+ " samples = task(p)\n",
+ " if not samples:\n",
+ " continue\n",
+ " valid_patients += 1\n",
+ " s = samples[0]\n",
+ " sex_counter[str(s.get(\"sex\"))] += 1\n",
+ " age_group_counter[str(s.get(\"age_group\"))] += 1\n",
+ "\n",
+ "print(f\"Patients scanned: {max_patients}\")\n",
+ "print(f\"Patients with at least one task sample: {valid_patients}\")\n",
+ "\n",
+ "summary_df = pd.DataFrame({\n",
+ " \"sex\": dict(sex_counter),\n",
+ " \"age_group\": dict(age_group_counter),\n",
+ "})\n",
+ "\n",
+ "display(summary_df.fillna(0).astype(int))"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.12.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py
index 50b1b3887..5587cb5e9 100644
--- a/pyhealth/datasets/__init__.py
+++ b/pyhealth/datasets/__init__.py
@@ -55,6 +55,7 @@ def __init__(self, *args, **kwargs):
from .dreamt import DREAMTDataset
from .ehrshot import EHRShotDataset
from .eicu import eICUDataset
+from .hirid import HiRIDDataset
from .isruc import ISRUCDataset
from .medical_transcriptions import MedicalTranscriptionsDataset
from .mimic3 import MIMIC3Dataset
diff --git a/pyhealth/datasets/configs/hirid.yaml b/pyhealth/datasets/configs/hirid.yaml
new file mode 100644
index 000000000..0b67d77c6
--- /dev/null
+++ b/pyhealth/datasets/configs/hirid.yaml
@@ -0,0 +1,89 @@
+version: "1.1.1"
+tables:
+ general_table:
+ file_path: "general_table.csv"
+ patient_id: "patientid"
+ timestamp: "admissiontime"
+ attributes:
+ - "sex"
+ - "age"
+ - "discharge_status"
+
+ merged_stage:
+ file_path: "hirid-merged-pyhealth.csv"
+ patient_id: "patientid"
+ timestamp: "datetime"
+ attributes:
+ - "heart_rate"
+ - "systolic_bp_invasive"
+ - "diastolic_bp_invasive"
+ - "mean_arterial_pressure"
+ - "cardiac_output"
+ - "spo2"
+ - "rass"
+ - "peak_inspiratory_pressure"
+ - "lactate_arterial"
+ - "lactate_venous"
+ - "inr"
+ - "serum_glucose"
+ - "c_reactive_protein"
+ - "dobutamine"
+ - "milrinone"
+ - "levosimendan"
+ - "theophyllin"
+ - "non_opioid_analgesics"
+
+ observation_tables:
+ file_path: "hirid-observations-pyhealth.csv"
+ patient_id: "patientid"
+ timestamp: "datetime"
+ attributes:
+ - "entertime"
+ - "variableid"
+ - "value"
+ - "status"
+ - "stringvalue"
+ - "type"
+
+ pharma_records:
+ file_path: "hirid-pharma-pyhealth.csv"
+ patient_id: "patientid"
+ timestamp: "givenat"
+ attributes:
+ - "pharmaid"
+ - "enteredentryat"
+ - "givendose"
+ - "cumulativedose"
+ - "fluidamount_calc"
+ - "cumulfluidamount_calc"
+ - "doseunit"
+ - "route"
+ - "infusionid"
+ - "typeid"
+ - "subtypeid"
+ - "recordstatus"
+
+ imputed_stage:
+ file_path: "hirid-imputed-pyhealth.csv"
+ patient_id: "patientid"
+ timestamp: null
+ attributes:
+ - "reldatetime"
+ - "heart_rate"
+ - "systolic_bp_invasive"
+ - "diastolic_bp_invasive"
+ - "mean_arterial_pressure"
+ - "cardiac_output"
+ - "spo2"
+ - "rass"
+ - "peak_inspiratory_pressure"
+ - "lactate_arterial"
+ - "lactate_venous"
+ - "inr"
+ - "serum_glucose"
+ - "c_reactive_protein"
+ - "dobutamine"
+ - "milrinone"
+ - "levosimendan"
+ - "theophyllin"
+ - "non_opioid_analgesics"
diff --git a/pyhealth/datasets/hirid.py b/pyhealth/datasets/hirid.py
new file mode 100644
index 000000000..1cff56198
--- /dev/null
+++ b/pyhealth/datasets/hirid.py
@@ -0,0 +1,287 @@
+import io
+import logging
+import os
+import tarfile
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import pandas as pd
+
+from .base_dataset import BaseDataset
+
+logger = logging.getLogger(__name__)
+
+
+class HiRIDDataset(BaseDataset):
+ """HiRID v1.1.1 ICU dataset.
+
+ The HiRID (High time Resolution ICU Dataset) is a freely accessible
+ critical care dataset containing data from ~34,000 ICU admissions
+ to the Department of Intensive Care Medicine of the Bern University
+ Hospital, Switzerland.
+
+ The dataset provides high-resolution time-series data (2-minute
+ intervals) including vital signs, lab values, ventilator parameters,
+ and medication records.
+
+ Dataset link:
+ https://physionet.org/content/hirid/1.1.1/
+
+ Paper:
+ Hyland et al. "Early prediction of circulatory failure in the
+ intensive care unit using machine learning." Nature Medicine, 2020.
+
+ Args:
+ root: Root directory of the HiRID dataset.
+ stage: Data processing stage to use. One of ``"merged"``,
+ ``"raw"``, or ``"imputed"``.
+ dataset_name: Name of the dataset. Defaults to ``"hirid"``.
+ config_path: Path to YAML config. Defaults to built-in config.
+ **kwargs: Additional arguments passed to
+ :class:`~pyhealth.datasets.BaseDataset`.
+
+ Attributes:
+ stage: The selected data processing stage.
+
+ Examples:
+ >>> from pyhealth.datasets import HiRIDDataset
+ >>> dataset = HiRIDDataset(
+ ... root="/path/to/hirid/1.1.1",
+ ... stage="merged",
+ ... )
+ >>> dataset.stats()
+ >>> patient = dataset.get_patient("1")
+ """
+
+ MERGED_COLUMN_MAP: Dict[str, str] = {
+ "vm1": "heart_rate",
+ "vm3": "systolic_bp_invasive",
+ "vm4": "diastolic_bp_invasive",
+ "vm5": "mean_arterial_pressure",
+ "vm13": "cardiac_output",
+ "vm20": "spo2",
+ "vm28": "rass",
+ "vm62": "peak_inspiratory_pressure",
+ "vm136": "lactate_arterial",
+ "vm146": "lactate_venous",
+ "vm172": "inr",
+ "vm174": "serum_glucose",
+ "vm176": "c_reactive_protein",
+ "pm41": "dobutamine",
+ "pm42": "milrinone",
+ "pm43": "levosimendan",
+ "pm44": "theophyllin",
+ "pm87": "non_opioid_analgesics",
+ }
+
+ def __init__(
+ self,
+ root: str,
+ stage: str = "merged",
+ dataset_name: str = "hirid",
+ config_path: Optional[str] = None,
+ **kwargs,
+ ) -> None:
+ if stage not in {"merged", "raw", "imputed"}:
+ raise ValueError(
+ "stage must be one of 'merged', 'raw', or 'imputed', "
+ f"got '{stage}'"
+ )
+
+ self.stage = stage
+
+ if config_path is None:
+ config_path = str(Path(__file__).parent / "configs" / "hirid.yaml")
+
+ self._verify_data(root, stage)
+ self._prepare_data(root, stage)
+
+ if stage == "merged":
+ tables = ["general_table", "merged_stage"]
+ elif stage == "raw":
+ tables = ["general_table", "observation_tables", "pharma_records"]
+ else:
+ tables = ["general_table", "imputed_stage"]
+
+ super().__init__(
+ root=root,
+ tables=tables,
+ dataset_name=dataset_name,
+ config_path=config_path,
+ **kwargs,
+ )
+
+ def _verify_data(self, root: str, stage: str) -> None:
+ if not os.path.exists(root):
+ raise FileNotFoundError(f"Dataset path does not exist: {root}")
+
+ general_table_path = os.path.join(root, "general_table.csv")
+ if not os.path.exists(general_table_path):
+ raise FileNotFoundError(
+ f"Required file not found: {general_table_path}"
+ )
+
+ requirements = self._stage_requirements(root, stage)
+ for processed_path, tar_path in requirements:
+ if os.path.exists(processed_path):
+ continue
+ if not os.path.exists(tar_path):
+ raise FileNotFoundError(f"Required file not found: {tar_path}")
+
+ def _prepare_data(self, root: str, stage: str) -> None:
+ if stage == "merged":
+ self._prepare_merged(root)
+ elif stage == "raw":
+ self._prepare_raw(root)
+ else:
+ self._prepare_imputed(root)
+
+ def _prepare_merged(self, root: str) -> None:
+ output_path = os.path.join(root, "hirid-merged-pyhealth.csv")
+ if os.path.exists(output_path):
+ logger.info("Processed file exists, skipping: %s", output_path)
+ return
+
+ tar_path = os.path.join(root, "merged_stage", "merged_stage_csv.tar.gz")
+ self._extract_and_concat_tar(
+ tar_path=tar_path,
+ output_path=output_path,
+ column_map=self.MERGED_COLUMN_MAP,
+ )
+
+ def _prepare_raw(self, root: str) -> None:
+ observation_output_path = os.path.join(
+ root, "hirid-observations-pyhealth.csv"
+ )
+ if os.path.exists(observation_output_path):
+ logger.info(
+ "Processed file exists, skipping: %s",
+ observation_output_path,
+ )
+ else:
+ observation_tar_path = os.path.join(
+ root,
+ "raw_stage",
+ "observation_tables_csv.tar.gz",
+ )
+ self._extract_and_concat_tar(
+ tar_path=observation_tar_path,
+ output_path=observation_output_path,
+ )
+
+ pharma_output_path = os.path.join(root, "hirid-pharma-pyhealth.csv")
+ if os.path.exists(pharma_output_path):
+ logger.info("Processed file exists, skipping: %s", pharma_output_path)
+ return
+
+ pharma_tar_path = os.path.join(
+ root,
+ "raw_stage",
+ "pharma_records_csv.tar.gz",
+ )
+ self._extract_and_concat_tar(
+ tar_path=pharma_tar_path,
+ output_path=pharma_output_path,
+ )
+
+ def _prepare_imputed(self, root: str) -> None:
+ output_path = os.path.join(root, "hirid-imputed-pyhealth.csv")
+ if os.path.exists(output_path):
+ logger.info("Processed file exists, skipping: %s", output_path)
+ return
+
+ tar_path = os.path.join(root, "imputed_stage", "imputed_stage_csv.tar.gz")
+ self._extract_and_concat_tar(
+ tar_path=tar_path,
+ output_path=output_path,
+ column_map=self.MERGED_COLUMN_MAP,
+ )
+
+ def _extract_and_concat_tar(
+ self,
+ tar_path: str,
+ output_path: str,
+ column_map: Optional[Dict[str, str]] = None,
+ ) -> None:
+ logger.info("Extracting %s -> %s", tar_path, output_path)
+ first = True
+
+ with tarfile.open(tar_path, "r:gz") as tar:
+ members = sorted(
+ (
+ member
+ for member in tar.getmembers()
+ if member.isfile() and member.name.endswith(".csv")
+ ),
+ key=lambda member: member.name,
+ )
+ if not members:
+ raise ValueError(f"No CSV files found in archive: {tar_path}")
+
+ for index, member in enumerate(members, start=1):
+ extracted = tar.extractfile(member)
+ if extracted is None:
+ continue
+
+ with io.TextIOWrapper(extracted) as buffer:
+ df = pd.read_csv(buffer)
+
+ if column_map:
+ df.rename(columns=column_map, inplace=True)
+
+ df.to_csv(
+ output_path,
+ mode="w" if first else "a",
+ header=first,
+ index=False,
+ )
+ first = False
+
+ if index % 1000 == 0:
+ logger.info("Processed %s/%s patient files", index, len(members))
+
+ logger.info("Created %s", output_path)
+
+ def _stage_requirements(self, root: str, stage: str) -> List[Tuple[str, str]]:
+ if stage == "merged":
+ return [
+ (
+ os.path.join(root, "hirid-merged-pyhealth.csv"),
+ os.path.join(
+ root,
+ "merged_stage",
+ "merged_stage_csv.tar.gz",
+ ),
+ )
+ ]
+
+ if stage == "raw":
+ return [
+ (
+ os.path.join(root, "hirid-observations-pyhealth.csv"),
+ os.path.join(
+ root,
+ "raw_stage",
+ "observation_tables_csv.tar.gz",
+ ),
+ ),
+ (
+ os.path.join(root, "hirid-pharma-pyhealth.csv"),
+ os.path.join(
+ root,
+ "raw_stage",
+ "pharma_records_csv.tar.gz",
+ ),
+ ),
+ ]
+
+ return [
+ (
+ os.path.join(root, "hirid-imputed-pyhealth.csv"),
+ os.path.join(
+ root,
+ "imputed_stage",
+ "imputed_stage_csv.tar.gz",
+ ),
+ )
+ ]
diff --git a/pyhealth/tasks/FAMEWS_fairness_audit.py b/pyhealth/tasks/FAMEWS_fairness_audit.py
new file mode 100644
index 000000000..ee325f7dd
--- /dev/null
+++ b/pyhealth/tasks/FAMEWS_fairness_audit.py
@@ -0,0 +1,228 @@
+"""
+PyHealth task for FAMEWS-style fairness analysis on HiRID.
+
+Dataset link:
+ https://physionet.org/content/hirid/1.1.1/
+
+Task paper: (please cite if you use this task)
+    Hoche, M.; Mineeva, O.; Burger, M.; Blasimme, A.; and
+    Rätsch, G. 2024. FAMEWS: a Fairness Auditing tool for
+    Medical Early-Warning Systems. In Proceedings of Machine
+    Learning Research, volume 248, 297–311. Conference on
+    Health, Inference, and Learning (CHIL) 2024.
+
+Dataset paper link:
+ https://proceedings.mlr.press/v248/hoche24a.html
+
+Author:
+ John Doll (doll3@illinois.edu)
+"""
+
+from datetime import datetime, timedelta
+from pathlib import Path
+from typing import Any, Dict, Iterable, List, Optional
+
+import pandas as pd
+import polars as pl
+
+from pyhealth.tasks import BaseTask
+
+
+class FAMEWSFairnessAudit(BaseTask):
+ """Build per-patient samples for FAMEWS-style fairness analysis on HiRID.
+
+ This task is designed for :class:`~pyhealth.datasets.HiRIDDataset` and
+ returns one sample per patient containing:
+
+ - A multivariate timeseries (for model features)
+ - Demographic metadata used for fairness subgrouping
+
+ By default it reads from ``imputed_stage``. For merged data, set
+ ``stage_table="merged_stage"``.
+
+ The HiRID subgroup configuration used by the original FAMEWS fairness
+ pipeline is bundled at ``FAMEWSFairnessAudit.group_config_path``.
+
+ Examples:
+ >>> from pyhealth.datasets import HiRIDDataset
+ >>> from pyhealth.tasks import FAMEWSFairnessAudit
+ >>> dataset = HiRIDDataset(
+ ... root="/path/to/hirid/1.1.1",
+ ... stage="imputed",
+ ... )
+ >>> task = FAMEWSFairnessAudit(stage_table="imputed_stage")
+ >>> samples = dataset.set_task(task)
+ """
+
+ task_name: str = "FAMEWS_fairness_audit"
+ input_schema: Dict[str, str] = {"signals": "timeseries"}
+ output_schema: Dict[str, str] = {}
+ group_config_path: str = str(
+ Path(__file__).parent / "configs" / "group_hirid_complete.yaml"
+ )
+
+ DEFAULT_FEATURE_COLUMNS: List[str] = [
+ "heart_rate",
+ "systolic_bp_invasive",
+ "diastolic_bp_invasive",
+ "mean_arterial_pressure",
+ "cardiac_output",
+ "spo2",
+ "rass",
+ "peak_inspiratory_pressure",
+ "lactate_arterial",
+ "lactate_venous",
+ "inr",
+ "serum_glucose",
+ "c_reactive_protein",
+ "dobutamine",
+ "milrinone",
+ "levosimendan",
+ "theophyllin",
+ "non_opioid_analgesics",
+ ]
+
+ def __init__(
+ self,
+ stage_table: str = "imputed_stage",
+ feature_columns: Optional[Iterable[str]] = None,
+ ) -> None:
+ if stage_table not in {"merged_stage", "imputed_stage"}:
+ raise ValueError(
+ "stage_table must be one of {'merged_stage', 'imputed_stage'}, "
+ f"got '{stage_table}'"
+ )
+ self.stage_table = stage_table
+ self.feature_columns = list(feature_columns) if feature_columns else list(
+ self.DEFAULT_FEATURE_COLUMNS
+ )
+ # Included in vars(task) so cache key changes after time-axis bug fix.
+ self.time_axis_version = "v2"
+
+ @staticmethod
+ def _build_age_group(age_value: Any) -> Optional[str]:
+ try:
+ age = float(age_value)
+ except (TypeError, ValueError):
+ return None
+
+ if age < 50:
+ return "<50"
+ if age < 65:
+ return "50-65"
+ if age < 75:
+ return "65-75"
+ if age < 85:
+ return "75-85"
+ return ">85"
+
+ def _extract_time_axis(self, events_df: pl.DataFrame) -> List[Any]:
+ default_axis = [
+ datetime(1970, 1, 1) + timedelta(hours=i)
+ for i in range(events_df.height)
+ ]
+
+ if "timestamp" in events_df.columns:
+ timestamp_series = events_df.get_column("timestamp")
+ if timestamp_series.null_count() < events_df.height:
+ parsed = pd.to_datetime(timestamp_series.to_list(), errors="coerce")
+ if parsed.notna().any():
+ time_axis: List[datetime] = []
+ previous = datetime(1970, 1, 1)
+ for ts in parsed:
+ if pd.isna(ts):
+ previous = previous + timedelta(hours=1)
+ else:
+ previous = ts.to_pydatetime()
+ time_axis.append(previous)
+ return time_axis
+
+ relative_time_col = f"{self.stage_table}/reldatetime"
+ if relative_time_col in events_df.columns:
+ raw_relative = events_df.get_column(relative_time_col).to_list()
+
+ if len(raw_relative) == 0:
+ return default_axis
+
+ numeric_values: List[float] = []
+ numeric_only = True
+ for value in raw_relative:
+ try:
+ numeric_values.append(float(value))
+ except (TypeError, ValueError):
+ numeric_only = False
+ break
+
+ if numeric_only:
+ non_negative_diffs = [
+ numeric_values[i] - numeric_values[i - 1]
+ for i in range(1, len(numeric_values))
+ if numeric_values[i] >= numeric_values[i - 1]
+ ]
+ # HiRID relative times are commonly encoded in minutes; some exports
+ # may use seconds. Infer units from the median step size.
+ step_median = (
+ pd.Series(non_negative_diffs).median()
+ if non_negative_diffs
+ else 0
+ )
+ as_seconds = bool(step_median and step_median > 20)
+ origin = datetime(1970, 1, 1)
+ return [
+ origin + timedelta(seconds=v if as_seconds else v * 60.0)
+ for v in numeric_values
+ ]
+
+ parsed_relative = pd.to_timedelta(raw_relative, errors="coerce")
+ if parsed_relative.notna().any():
+ origin = datetime(1970, 1, 1)
+ time_axis = []
+ previous = origin
+ for delta in parsed_relative:
+ if pd.isna(delta):
+ previous = previous + timedelta(hours=1)
+ else:
+ previous = origin + delta.to_pytimedelta()
+ time_axis.append(previous)
+ return time_axis
+
+ return default_axis
+
+ def __call__(self, patient: Any) -> List[Dict[str, Any]]:
+ patient_general = patient.get_events(event_type="general_table")
+ if len(patient_general) == 0:
+ return []
+
+ stage_df = patient.get_events(event_type=self.stage_table, return_df=True)
+ if stage_df.height == 0:
+ return []
+
+ available_features = [
+ col
+ for col in self.feature_columns
+ if f"{self.stage_table}/{col}" in stage_df.columns
+ ]
+ if len(available_features) == 0:
+ return []
+
+ value_df = stage_df.select(
+ [
+ pl.col(f"{self.stage_table}/{col}").cast(pl.Float64)
+ for col in available_features
+ ]
+ )
+
+ general_event = patient_general[0]
+ age_group = self._build_age_group(getattr(general_event, "age", None))
+
+ sample = {
+ "patient_id": patient.patient_id,
+ "signals": (self._extract_time_axis(stage_df), value_df.to_numpy()),
+ "feature_columns": available_features,
+ "sex": getattr(general_event, "sex", None),
+ "age": getattr(general_event, "age", None),
+ "age_group": age_group,
+ "discharge_status": getattr(general_event, "discharge_status", None),
+ }
+
+ return [sample]
diff --git a/pyhealth/tasks/HiRID_fairness_audit.py b/pyhealth/tasks/HiRID_fairness_audit.py
new file mode 100644
index 000000000..9299b802d
--- /dev/null
+++ b/pyhealth/tasks/HiRID_fairness_audit.py
@@ -0,0 +1,224 @@
+"""
+PyHealth task for FAMEWS-style fairness analysis on HiRID.
+
+Dataset link:
+ https://physionet.org/content/hirid/1.1.1/
+
+Task paper: (please cite if you use this task)
+    Hoche, M.; Mineeva, O.; Burger, M.; Blasimme, A.; and
+    Rätsch, G. 2024. FAMEWS: a Fairness Auditing tool for
+    Medical Early-Warning Systems. In Proceedings of Machine
+    Learning Research, volume 248, 297–311. Conference on
+    Health, Inference, and Learning (CHIL) 2024.
+
+Dataset paper link:
+ https://proceedings.mlr.press/v248/hoche24a.html
+
+Author:
+ John Doll (doll3@illinois.edu)
+"""
+
+from datetime import datetime, timedelta
+from pathlib import Path
+from typing import Any, Dict, Iterable, List, Optional
+
+import pandas as pd
+import polars as pl
+
+from pyhealth.tasks import BaseTask
+
+
+class FAMEWSFairnessAudit(BaseTask):
+ """Build per-patient samples for FAMEWS-style fairness analysis on HiRID.
+
+ This task is designed for :class:`~pyhealth.datasets.HiRIDDataset` and
+ returns one sample per patient containing:
+
+ - A multivariate timeseries (for model features)
+ - Demographic metadata used for fairness subgrouping
+
+ By default it reads from ``imputed_stage``. For merged data, set
+ ``stage_table="merged_stage"``.
+
+ The HiRID subgroup configuration used by the original FAMEWS fairness
+ pipeline is bundled at ``FAMEWSFairnessAudit.group_config_path``.
+
+ Examples:
+ >>> from pyhealth.datasets import HiRIDDataset
+ >>> from pyhealth.tasks import FAMEWSFairnessAudit
+ >>> dataset = HiRIDDataset(
+ ... root="/path/to/hirid/1.1.1",
+ ... stage="imputed",
+ ... )
+ >>> task = FAMEWSFairnessAudit(stage_table="imputed_stage")
+ >>> samples = dataset.set_task(task)
+ """
+
+ task_name: str = "FAMEWS_fairness_audit"
+ input_schema: Dict[str, str] = {"signals": "timeseries"}
+ output_schema: Dict[str, str] = {}
+ group_config_path: str = str(
+ Path(__file__).parent / "configs" / "group_hirid_complete.yaml"
+ )
+
+ DEFAULT_FEATURE_COLUMNS: List[str] = [
+ "heart_rate",
+ "systolic_bp_invasive",
+ "diastolic_bp_invasive",
+ "mean_arterial_pressure",
+ "cardiac_output",
+ "spo2",
+ "rass",
+ "peak_inspiratory_pressure",
+ "lactate_arterial",
+ "lactate_venous",
+ "inr",
+ "serum_glucose",
+ "c_reactive_protein",
+ "dobutamine",
+ "milrinone",
+ "levosimendan",
+ "theophyllin",
+ "non_opioid_analgesics",
+ ]
+
+ def __init__(
+ self,
+ stage_table: str = "imputed_stage",
+ feature_columns: Optional[Iterable[str]] = None,
+ ) -> None:
+ if stage_table not in {"merged_stage", "imputed_stage"}:
+ raise ValueError(
+ "stage_table must be one of {'merged_stage', 'imputed_stage'}, "
+ f"got '{stage_table}'"
+ )
+ self.stage_table = stage_table
+ self.feature_columns = list(feature_columns) if feature_columns else list(
+ self.DEFAULT_FEATURE_COLUMNS
+ )
+ # Included in vars(task) so cache key changes after time-axis bug fix.
+ self.time_axis_version = "v2"
+
+ @staticmethod
+ def _build_age_group(age_value: Any) -> Optional[str]:
+ try:
+ age = float(age_value)
+ except (TypeError, ValueError):
+ return None
+
+ if age < 50:
+ return "<50"
+ if age < 65:
+ return "50-65"
+ if age < 75:
+ return "65-75"
+ if age < 85:
+ return "75-85"
+ return ">85"
+
+ def _extract_time_axis(self, events_df: pl.DataFrame) -> List[Any]:
+ default_axis = [
+ datetime(1970, 1, 1) + timedelta(hours=i)
+ for i in range(events_df.height)
+ ]
+
+ if "timestamp" in events_df.columns:
+ timestamp_series = events_df.get_column("timestamp")
+ if timestamp_series.null_count() < events_df.height:
+ parsed = pd.to_datetime(timestamp_series.to_list(), errors="coerce")
+ if parsed.notna().any():
+ time_axis: List[datetime] = []
+ previous = datetime(1970, 1, 1)
+ for ts in parsed:
+ if pd.isna(ts):
+ previous = previous + timedelta(hours=1)
+ else:
+ previous = ts.to_pydatetime()
+ time_axis.append(previous)
+ return time_axis
+
+ relative_time_col = f"{self.stage_table}/reldatetime"
+ if relative_time_col in events_df.columns:
+ raw_relative = events_df.get_column(relative_time_col).to_list()
+
+ if len(raw_relative) == 0:
+ return default_axis
+
+ numeric_values: List[float] = []
+ numeric_only = True
+ for value in raw_relative:
+ try:
+ numeric_values.append(float(value))
+ except (TypeError, ValueError):
+ numeric_only = False
+ break
+
+ if numeric_only:
+ non_negative_diffs = [
+ numeric_values[i] - numeric_values[i - 1]
+ for i in range(1, len(numeric_values))
+ if numeric_values[i] >= numeric_values[i - 1]
+ ]
+ # HiRID relative times are commonly encoded in minutes; some exports
+ # may use seconds. Infer units from the median step size.
+ step_median = pd.Series(non_negative_diffs).median() if non_negative_diffs else 0
+ as_seconds = bool(step_median and step_median > 20)
+ origin = datetime(1970, 1, 1)
+ return [
+ origin + timedelta(seconds=v if as_seconds else v * 60.0)
+ for v in numeric_values
+ ]
+
+ parsed_relative = pd.to_timedelta(raw_relative, errors="coerce")
+ if parsed_relative.notna().any():
+ origin = datetime(1970, 1, 1)
+ time_axis = []
+ previous = origin
+ for delta in parsed_relative:
+ if pd.isna(delta):
+ previous = previous + timedelta(hours=1)
+ else:
+ previous = origin + delta.to_pytimedelta()
+ time_axis.append(previous)
+ return time_axis
+
+ return default_axis
+
+ def __call__(self, patient: Any) -> List[Dict[str, Any]]:
+ patient_general = patient.get_events(event_type="general_table")
+ if len(patient_general) == 0:
+ return []
+
+ stage_df = patient.get_events(event_type=self.stage_table, return_df=True)
+ if stage_df.height == 0:
+ return []
+
+ available_features = [
+ col
+ for col in self.feature_columns
+ if f"{self.stage_table}/{col}" in stage_df.columns
+ ]
+ if len(available_features) == 0:
+ return []
+
+ value_df = stage_df.select(
+ [
+ pl.col(f"{self.stage_table}/{col}").cast(pl.Float64)
+ for col in available_features
+ ]
+ )
+
+ general_event = patient_general[0]
+ age_group = self._build_age_group(getattr(general_event, "age", None))
+
+ sample = {
+ "patient_id": patient.patient_id,
+ "signals": (self._extract_time_axis(stage_df), value_df.to_numpy()),
+ "feature_columns": available_features,
+ "sex": getattr(general_event, "sex", None),
+ "age": getattr(general_event, "age", None),
+ "age_group": age_group,
+ "discharge_status": getattr(general_event, "discharge_status", None),
+ }
+
+ return [sample]
diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py
index a32618f9c..2882a4256 100644
--- a/pyhealth/tasks/__init__.py
+++ b/pyhealth/tasks/__init__.py
@@ -22,6 +22,9 @@
drug_recommendation_mimic4_fn,
drug_recommendation_omop_fn,
)
+from .EEG_abnormal import EEG_isAbnormal_fn
+from .EEG_events import EEG_events_fn
+from .HiRID_fairness_audit import FAMEWSFairnessAudit
from .in_hospital_mortality_mimic4 import InHospitalMortalityMIMIC4
from .length_of_stay_prediction import (
LengthOfStayPredictioneICU,
diff --git a/pyhealth/tasks/configs/group_hirid_complete.yaml b/pyhealth/tasks/configs/group_hirid_complete.yaml
new file mode 100644
index 000000000..5ff749e5d
--- /dev/null
+++ b/pyhealth/tasks/configs/group_hirid_complete.yaml
@@ -0,0 +1,9 @@
+group:
+ sex: [F, M]
+ age_group: [<50, 50-65, 65-75, 75-85, '>85']
+ APACHE_group: [Cardiovascular, Neurological, Gastrointestinal, Respiratory, Other, Trauma, Metabolic]
+ surgical_status: [Surgical, Non-surgical]
+
+type_table_group:
+ binary_group: [sex, surgical_status]
+ multicat_group: [age_group, APACHE_group]
diff --git a/test-resources/core/hiriddemo/general_table.csv b/test-resources/core/hiriddemo/general_table.csv
new file mode 100644
index 000000000..7050f53b6
--- /dev/null
+++ b/test-resources/core/hiriddemo/general_table.csv
@@ -0,0 +1,4 @@
+patientid,admissiontime,sex,age,discharge_status
+148,2183-05-12 11:45:00,M,72,alive
+229,2151-03-02 07:10:00,F,61,dead
+307,2160-09-18 18:30:00,F,47,alive
diff --git a/test-resources/core/hiriddemo/hirid-imputed-pyhealth.csv b/test-resources/core/hiriddemo/hirid-imputed-pyhealth.csv
new file mode 100644
index 000000000..e4c76a578
--- /dev/null
+++ b/test-resources/core/hiriddemo/hirid-imputed-pyhealth.csv
@@ -0,0 +1,10 @@
+patientid,reldatetime,heart_rate,systolic_bp_invasive,diastolic_bp_invasive,mean_arterial_pressure,cardiac_output,spo2,rass,peak_inspiratory_pressure,lactate_arterial,lactate_venous,inr,serum_glucose,c_reactive_protein,dobutamine,milrinone,levosimendan,theophyllin,non_opioid_analgesics
+148,0.0,119.0,125.0,75.0,90.0,5.2,84.0,0.0,18.0,1.0,1.1,0.95,105.0,4.0,0.0,0.0,0.0,0.0,1.0
+148,30.0,114.0,122.0,73.0,88.0,5.1,89.0,0.0,19.0,1.0,1.1,0.96,109.0,4.2,0.0,0.0,0.0,0.0,1.0
+148,60.0,108.0,118.0,70.0,84.0,4.9,92.0,-1.0,19.0,0.9,1.0,0.98,111.0,4.4,0.0,0.0,0.0,0.0,0.0
+229,0.0,96.0,110.0,64.0,79.0,4.3,95.0,-1.0,17.0,1.3,1.4,1.10,132.0,12.0,2.0,0.0,0.0,0.0,1.0
+229,30.0,99.0,108.0,62.0,77.0,4.1,94.0,-1.0,17.0,1.5,1.5,1.15,138.0,14.0,2.0,0.0,0.0,0.0,1.0
+229,60.0,104.0,104.0,59.0,73.0,3.8,92.0,-2.0,18.0,1.8,1.7,1.20,145.0,17.0,3.0,0.0,0.0,0.0,1.0
+307,0.0,88.0,116.0,68.0,83.0,4.8,97.0,0.0,16.0,0.7,0.8,1.00,98.0,3.0,0.0,0.0,0.0,0.0,0.0
+307,30.0,86.0,118.0,70.0,85.0,4.9,98.0,0.0,16.0,0.7,0.8,1.00,96.0,2.9,0.0,0.0,0.0,0.0,0.0
+307,60.0,84.0,117.0,69.0,84.0,4.8,98.0,0.0,16.0,0.6,0.8,0.99,95.0,2.8,0.0,0.0,0.0,0.0,0.0
diff --git a/test-resources/core/hiriddemo/hirid-merged-pyhealth.csv b/test-resources/core/hiriddemo/hirid-merged-pyhealth.csv
new file mode 100644
index 000000000..c03d82c84
--- /dev/null
+++ b/test-resources/core/hiriddemo/hirid-merged-pyhealth.csv
@@ -0,0 +1,10 @@
+patientid,datetime,heart_rate,systolic_bp_invasive,diastolic_bp_invasive,mean_arterial_pressure,cardiac_output,spo2,rass,peak_inspiratory_pressure,lactate_arterial,lactate_venous,inr,serum_glucose,c_reactive_protein,dobutamine,milrinone,levosimendan,theophyllin,non_opioid_analgesics
+148,2183-05-12 12:00:00,119.0,125.0,75.0,90.0,5.2,84.0,0.0,18.0,1.0,1.1,0.95,105.0,4.0,0.0,0.0,0.0,0.0,1.0
+148,2183-05-12 12:30:00,114.0,122.0,73.0,88.0,5.1,89.0,0.0,19.0,1.0,1.1,0.96,109.0,4.2,0.0,0.0,0.0,0.0,1.0
+148,2183-05-12 13:00:00,108.0,118.0,70.0,84.0,4.9,92.0,-1.0,19.0,0.9,1.0,0.98,111.0,4.4,0.0,0.0,0.0,0.0,0.0
+229,2151-03-02 07:30:00,96.0,110.0,64.0,79.0,4.3,95.0,-1.0,17.0,1.3,1.4,1.10,132.0,12.0,2.0,0.0,0.0,0.0,1.0
+229,2151-03-02 08:00:00,99.0,108.0,62.0,77.0,4.1,94.0,-1.0,17.0,1.5,1.5,1.15,138.0,14.0,2.0,0.0,0.0,0.0,1.0
+229,2151-03-02 08:30:00,104.0,104.0,59.0,73.0,3.8,92.0,-2.0,18.0,1.8,1.7,1.20,145.0,17.0,3.0,0.0,0.0,0.0,1.0
+307,2160-09-18 19:00:00,88.0,116.0,68.0,83.0,4.8,97.0,0.0,16.0,0.7,0.8,1.00,98.0,3.0,0.0,0.0,0.0,0.0,0.0
+307,2160-09-18 19:30:00,86.0,118.0,70.0,85.0,4.9,98.0,0.0,16.0,0.7,0.8,1.00,96.0,2.9,0.0,0.0,0.0,0.0,0.0
+307,2160-09-18 20:00:00,84.0,117.0,69.0,84.0,4.8,98.0,0.0,16.0,0.6,0.8,0.99,95.0,2.8,0.0,0.0,0.0,0.0,0.0
diff --git a/test-resources/core/hiriddemo/hirid-observations-pyhealth.csv b/test-resources/core/hiriddemo/hirid-observations-pyhealth.csv
new file mode 100644
index 000000000..13727a020
--- /dev/null
+++ b/test-resources/core/hiriddemo/hirid-observations-pyhealth.csv
@@ -0,0 +1,7 @@
+patientid,datetime,entertime,variableid,value,status,stringvalue,type
+148,2183-05-12 12:00:00,2183-05-12 12:05:00,vm1,119.0,0,,numeric
+148,2183-05-12 12:00:00,2183-05-12 12:05:00,vm20,84.0,0,,numeric
+229,2151-03-02 07:30:00,2151-03-02 07:36:00,vm5,79.0,0,,numeric
+229,2151-03-02 08:00:00,2151-03-02 08:06:00,vm174,138.0,0,,numeric
+307,2160-09-18 19:00:00,2160-09-18 19:04:00,vm1,88.0,0,,numeric
+307,2160-09-18 19:30:00,2160-09-18 19:35:00,vm176,2.9,0,,numeric
diff --git a/test-resources/core/hiriddemo/hirid-pharma-pyhealth.csv b/test-resources/core/hiriddemo/hirid-pharma-pyhealth.csv
new file mode 100644
index 000000000..44661e350
--- /dev/null
+++ b/test-resources/core/hiriddemo/hirid-pharma-pyhealth.csv
@@ -0,0 +1,5 @@
+patientid,pharmaid,givenat,enteredentryat,givendose,cumulativedose,fluidamount_calc,cumulfluidamount_calc,doseunit,route,infusionid,typeid,subtypeid,recordstatus
+148,1000251,2183-05-12 12:54:00,2183-05-12 14:03:39,25.0,25.0,0.5,0.5,ug,iv-inj,1616089,1,8,780
+148,1000251,2183-05-12 12:58:00,2183-05-12 14:03:56,25.0,50.0,0.5,1.0,ug,iv-inj,257507,1,8,780
+229,1000302,2151-03-02 07:45:00,2151-03-02 08:00:00,2.0,2.0,10.0,10.0,mg,iv,200001,1,4,780
+307,1000403,2160-09-18 19:10:00,2160-09-18 19:22:00,1.0,1.0,5.0,5.0,mg,iv,300002,2,6,780
diff --git a/tests/core/test_famews.py b/tests/core/test_famews.py
new file mode 100644
index 000000000..8bb0caade
--- /dev/null
+++ b/tests/core/test_famews.py
@@ -0,0 +1,101 @@
+"""
+Unit tests for the FAMEWSFairnessAudit task, run against mocked HiRID patients.
+
+Author:
+ John Doll
+"""
+
+import unittest
+from datetime import datetime
+from types import SimpleNamespace
+
+import polars as pl
+
+from pyhealth.tasks import FAMEWSFairnessAudit
+
+
+class _MockPatient:
+ def __init__(
+ self,
+ patient_id: str,
+ general_events,
+ stage_df: pl.DataFrame,
+ ):
+ self.patient_id = patient_id
+ self._general_events = general_events
+ self._stage_df = stage_df
+
+ def get_events(self, event_type: str, return_df: bool = False):
+ if event_type == "general_table":
+ return self._general_events
+ if return_df:
+ return self._stage_df
+ return []
+
+
+class TestFAMEWSFairnessAudit(unittest.TestCase):
+ def test_invalid_stage_table_raises(self):
+ with self.assertRaises(ValueError):
+ FAMEWSFairnessAudit(stage_table="raw_stage")
+
+ def test_task_generates_sample_with_datetime_axis(self):
+ task = FAMEWSFairnessAudit(stage_table="imputed_stage")
+ stage_df = pl.DataFrame(
+ {
+ "imputed_stage/reldatetime": ["0", "5", "10"],
+ "imputed_stage/heart_rate": [80.0, 82.0, 84.0],
+ "imputed_stage/spo2": [96.0, 97.0, 98.0],
+ }
+ )
+ general_event = SimpleNamespace(
+ sex="F",
+ age=73,
+ discharge_status="alive",
+ )
+ patient = _MockPatient("p1", [general_event], stage_df)
+
+ samples = task(patient)
+
+ self.assertEqual(len(samples), 1)
+ sample = samples[0]
+ self.assertEqual(sample["patient_id"], "p1")
+ self.assertEqual(sample["age_group"], "65-75")
+ self.assertListEqual(sample["feature_columns"], ["heart_rate", "spo2"])
+
+ timestamps, values = sample["signals"]
+ self.assertEqual(len(timestamps), 3)
+ self.assertTrue(all(isinstance(ts, datetime) for ts in timestamps))
+ self.assertEqual(values.shape, (3, 2))
+
+ def test_returns_empty_when_general_table_missing(self):
+ task = FAMEWSFairnessAudit(stage_table="imputed_stage")
+ stage_df = pl.DataFrame(
+ {
+ "imputed_stage/reldatetime": [0, 1],
+ "imputed_stage/heart_rate": [70.0, 75.0],
+ }
+ )
+ patient = _MockPatient("p2", [], stage_df)
+
+ self.assertEqual(task(patient), [])
+
+ def test_returns_empty_when_no_feature_columns_present(self):
+ task = FAMEWSFairnessAudit(stage_table="imputed_stage")
+ stage_df = pl.DataFrame(
+ {
+ "imputed_stage/reldatetime": [0, 1, 2],
+ "imputed_stage/non_matching_feature": [1.0, 2.0, 3.0],
+ }
+ )
+ general_event = SimpleNamespace(
+ sex="M",
+ age=45,
+ discharge_status="dead",
+ )
+ patient = _MockPatient("p3", [general_event], stage_df)
+
+ self.assertEqual(task(patient), [])
+
+
+if __name__ == "__main__":
+ unittest.main()