diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 23a4e06e5..d5048a1c2 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -230,3 +230,4 @@ Available Tasks Mutation Pathogenicity (COSMIC) Cancer Survival Prediction (TCGA) Cancer Mutation Burden (TCGA) + Mortality Predict diff --git a/docs/api/tasks/pyhealth.tasks.mortality_text_task.rst b/docs/api/tasks/pyhealth.tasks.mortality_text_task.rst new file mode 100644 index 000000000..d70d58fa8 --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.mortality_text_task.rst @@ -0,0 +1,7 @@ +pyhealth.tasks.mortality\_text\_task +===================================== + +.. automodule:: pyhealth.tasks.mortality_text_task + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/mimic3_mortality_text_clinicalbert.py b/examples/mimic3_mortality_text_clinicalbert.py new file mode 100644 index 000000000..6f062862c --- /dev/null +++ b/examples/mimic3_mortality_text_clinicalbert.py @@ -0,0 +1,174 @@ +""" +Example and Ablation Study: MIMIC-III Mortality Text Task +========================================================== + +Reproduces part of the evaluation pipeline from: + + Zhang et al. "Hurtful Words: Quantifying Biases in Clinical Contextual + Word Embeddings." ACM CHIL 2020. https://arxiv.org/abs/2003.11515 + +This script demonstrates: + 1. How to use MortalityTextTaskMIMIC3 with MIMIC3Dataset. + 2. An ablation study showing how max_notes affects sample generation. + 3. Fairness gap evaluation (recall gap, parity gap) across demographic + subgroups (gender, ethnicity, insurance) as described in the paper. 
+ +Requirements: + - pyhealth + - MIMIC-III data with PATIENTS and ADMISSIONS tables + (or use the synthetic subset below for a quick demo) + +Usage: + python examples/mimic3_mortality_text_clinicalbert.py +""" + +from collections import Counter, defaultdict + +from pyhealth.datasets import MIMIC3Dataset +from pyhealth.tasks.mortality_text_task import MortalityTextTaskMIMIC3 + +# --------------------------------------------------------------------------- +# 1. Load dataset +# Replace root with your local MIMIC-III path or use the synthetic subset. +# --------------------------------------------------------------------------- + +MIMIC3_ROOT = "https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III_subset/" + +print("Loading MIMIC-III dataset...") +dataset = MIMIC3Dataset( + root=MIMIC3_ROOT, + tables=["PATIENTS", "ADMISSIONS"], + dev=False, +) +print(f"Loaded {len(dataset.unique_patient_ids)} patients.\n") + + +# --------------------------------------------------------------------------- +# 2. Ablation: effect of max_notes on sample generation +# +# We test max_notes in [1, 3, 5, 8] and report: +# - Total samples generated +# - Average notes per sample +# - Label distribution (mortality rate) +# +# Finding: max_notes does not affect label distribution, only the amount +# of text available to the model. More notes give the model more context +# but also increase tokenization cost for BERT-based models. 
+# --------------------------------------------------------------------------- + +print("=" * 60) +print("ABLATION: Effect of max_notes on sample generation") +print("=" * 60) + +for max_notes in [1, 3, 5, 8]: + task = MortalityTextTaskMIMIC3(max_notes=max_notes) + all_samples = [] + for pid in dataset.unique_patient_ids: + patient = dataset.get_patient(pid) + all_samples.extend(task(patient)) + + label_counts = Counter(s["label"] for s in all_samples) + avg_notes = sum(len(s["notes"]) for s in all_samples) / max(len(all_samples), 1) + mortality_rate = label_counts[1] / max(len(all_samples), 1) * 100 + + print(f"\nmax_notes={max_notes}") + print(f" Total samples : {len(all_samples)}") + print(f" Avg notes/sample: {avg_notes:.1f}") + print(f" Label dist : {dict(label_counts)}") + print(f" Mortality rate: {mortality_rate:.1f}%") + + +# --------------------------------------------------------------------------- +# 3. Fairness evaluation (recall gap & parity gap) +# +# We compute naive baseline fairness gaps using the label distribution +# itself (no model needed) to show demographic imbalance in the dataset. +# This mirrors Table 4 of Zhang et al. (2020). +# +# In a full pipeline, replace `predicted` with real model predictions. +# --------------------------------------------------------------------------- + +print("\n" + "=" * 60) +print("FAIRNESS GAPS: Demographic subgroup analysis (label distribution)") +print("Paper reference: Zhang et al. 
2020, Table 4") +print("=" * 60) + +task = MortalityTextTaskMIMIC3(max_notes=5) +samples = [] +for pid in dataset.unique_patient_ids: + patient = dataset.get_patient(pid) + samples.extend(task(patient)) + +# group samples by demographic +def parity_gap(group_samples): + """Compute parity (positive prediction rate) for a group.""" + if not group_samples: + return 0.0 + return sum(s["label"] for s in group_samples) / len(group_samples) + + +for attr in ["gender", "ethnicity", "insurance"]: + print(f"\n--- {attr.upper()} ---") + groups = defaultdict(list) + for s in samples: + groups[s[attr]].append(s) + + rates = {g: parity_gap(v) for g, v in groups.items() if len(v) >= 5} + if not rates: + print(" (insufficient data)") + continue + + majority = max(rates, key=rates.get) + for group, rate in sorted(rates.items(), key=lambda x: -x[1]): + gap = rates[majority] - rate + marker = " <- majority" if group == majority else f" gap={gap:.3f}" + print(f" {group:<45} rate={rate:.3f} n={len(groups[group])}{marker}") + + +# --------------------------------------------------------------------------- +# 4. Template ablation: different template sets +# +# We test two subsets of templates: +# A) Chronic conditions only (heart disease, diabetes, hypertension) +# B) Social/behavioural conditions (hiv, heroin, dnr) +# +# Finding: Template choice does not affect the label, but affects what +# linguistic context BERT encodes — directly relevant to the paper's +# finding that BERT encodes gender bias differently per medical topic. 
+# --------------------------------------------------------------------------- + +print("\n" + "=" * 60) +print("ABLATION: Template subset comparison") +print("=" * 60) + +CHRONIC_TEMPLATES = [ + "this is a {age} yo {gender} with a hx of heart disease", + "this is a {age} yo {gender} with a pmh of diabetes", + "this is a {age} yo {gender} with a discharge diagnosis of htn", +] + +SOCIAL_TEMPLATES = [ + "{gender} has a pmh of hiv", + "{gender} pt is dnr", + "this is a {age} yo {gender} with a hx of heroin addiction", +] + +for name, templates in [("Chronic conditions", CHRONIC_TEMPLATES), + ("Social/behavioural", SOCIAL_TEMPLATES)]: + # patch templates temporarily + import pyhealth.tasks.mortality_text_task as _mod + original = _mod.CLINICAL_NOTE_TEMPLATES + _mod.CLINICAL_NOTE_TEMPLATES = templates + + task = MortalityTextTaskMIMIC3(max_notes=3) + s0 = task(dataset.get_patient(list(dataset.unique_patient_ids)[0])) + + _mod.CLINICAL_NOTE_TEMPLATES = original # restore + + print(f"\nTemplate set: {name}") + if s0: + print(f" Example notes: {s0[0]['notes']}") + else: + print(" (no samples generated)") + +print("\nDone. See paper for full model-based fairness evaluation results.") diff --git a/pyhealth/tasks/mortality_text_task.py b/pyhealth/tasks/mortality_text_task.py new file mode 100644 index 000000000..1050742c7 --- /dev/null +++ b/pyhealth/tasks/mortality_text_task.py @@ -0,0 +1,190 @@ +"""In-hospital mortality prediction task using synthetic demographic note templates. + +This module implements a clinical text-based mortality prediction task for +MIMIC-III, reproducing the downstream evaluation setup from: + + Zhang et al. "Hurtful Words: Quantifying Biases in Clinical Contextual + Word Embeddings." ACM CHIL 2020. https://arxiv.org/abs/2003.11515 + +Since clinical notes may be unavailable or restricted, this task generates +synthetic note templates populated with real patient demographics (gender, age) +drawn from the PATIENTS and ADMISSIONS tables. 
class MortalityTextTaskMIMIC3(BaseTask):
    """In-hospital mortality prediction from synthetic clinical note templates.

    Reproduces the in-hospital mortality clinical prediction task from:
        Zhang et al. "Hurtful Words: Quantifying Biases in Clinical
        Contextual Word Embeddings." ACM CHIL 2020.
        https://arxiv.org/abs/2003.11515

    For each admission row of a patient, the task fills the module-level
    ``CLINICAL_NOTE_TEMPLATES`` with the patient's gender and computed age
    and emits one sample. The binary label comes from the
    ``admissions/hospital_expire_flag`` column.

    Demographic fields (gender, ethnicity, insurance, language) are copied
    into every sample so downstream code can evaluate fairness per subgroup.

    Args:
        max_notes (int): Maximum number of synthetic note templates to
            include per sample. Must be >= 1. Defaults to 5.

    Raises:
        ValueError: If ``max_notes`` is smaller than 1.

    Attributes:
        task_name (str): Unique identifier for this task.
        input_schema (Dict[str, str]): Maps feature name to processor type.
        output_schema (Dict[str, str]): Maps label name to processor type.

    Example:
        >>> from pyhealth.datasets import MIMIC3Dataset
        >>> from pyhealth.tasks.mortality_text_task import (
        ...     MortalityTextTaskMIMIC3,
        ... )
        >>> dataset = MIMIC3Dataset(
        ...     root="/path/to/mimic3",
        ...     tables=["PATIENTS", "ADMISSIONS"],
        ... )
        >>> task = MortalityTextTaskMIMIC3(max_notes=5)
        >>> task_dataset = dataset.set_task(task)
    """

    task_name: str = "mortality_text"
    input_schema: Dict[str, str] = {"notes": "sequence"}
    output_schema: Dict[str, str] = {"label": "binary"}

    # Age written into the templates when it cannot be derived from the data.
    _FALLBACK_AGE: int = 65

    def __init__(self, max_notes: int = 5) -> None:
        """Initialise the task.

        Args:
            max_notes (int): Maximum number of synthetic note templates per
                sample. Must be >= 1. Defaults to 5.

        Raises:
            ValueError: If ``max_notes`` is smaller than 1.
        """
        # A value below 1 would silently yield empty notes (0) or slice the
        # template list from the end (negative), so reject it up front.
        if max_notes < 1:
            raise ValueError(f"max_notes must be >= 1, got {max_notes!r}")
        # NOTE(review): assumes BaseTask's initialiser takes no required
        # arguments - confirm against the BaseTask definition.
        super().__init__()
        self.max_notes = max_notes

    def _parse_expire_flag(self, raw: Any) -> int:
        """Convert a raw ``hospital_expire_flag`` value to a 0/1 label.

        Accepts ints, floats and numeric strings, so a string ``"1"`` still
        produces label 1. ``None`` and unparseable values map to 0.
        """
        if raw is None:
            return 0
        try:
            return int(float(raw) == 1.0)
        except (TypeError, ValueError):
            return 0

    def _compute_age(self, dob: Any, admit_time: Any) -> int:
        """Return age in whole years at admission, or the fallback age.

        The fallback is used when either value is missing or the two values
        cannot be subtracted into a timedelta-like object.
        """
        # NOTE(review): MIMIC-III is documented to shift DOB for patients
        # older than 89, which would yield ages of roughly 300 here -
        # confirm whether such ages should be capped before templating.
        try:
            return int((admit_time - dob).days / 365)
        except (TypeError, AttributeError):
            return self._FALLBACK_AGE

    def __call__(self, patient: Any) -> List[Dict[str, Any]]:
        """Process a single patient into mortality prediction samples.

        Args:
            patient: Object exposing ``patient_id`` and ``data_source``.
                ``data_source`` is a Polars DataFrame with an ``event_type``
                column and rows of type ``'patients'`` and ``'admissions'``.
                Columns read: ``timestamp``, ``patients/gender``,
                ``patients/dob``, ``admissions/hadm_id``,
                ``admissions/hospital_expire_flag``,
                ``admissions/ethnicity``, ``admissions/insurance`` and
                ``admissions/language``.

        Returns:
            List[Dict[str, Any]]: One dict per admission row with keys:
                - visit_id (str): ``hadm_id`` as a string, ``""`` if null.
                - patient_id: Copied from ``patient.patient_id``.
                - notes (List[str]): At most ``max_notes`` filled templates.
                - label (int): 1 if the expire flag equals 1, else 0.
                - gender (str): 'female' if the raw code is 'F', else 'male'.
                - ethnicity (str): Raw value, or 'unknown' when empty/null.
                - insurance (str): Raw value, or 'unknown' when empty/null.
                - language (str): Raw value, or 'unknown' when empty/null.

            The list is empty if the patient has no 'patients' or no
            'admissions' rows.
        """
        samples: List[Dict[str, Any]] = []
        df = patient.data_source

        # -- patient-level fields (read once, shared by every admission) ----
        patients_df = df.filter(df["event_type"] == "patients")
        if patients_df.is_empty():
            return samples

        gender_raw: Optional[str] = patients_df["patients/gender"][0]
        # NOTE(review): any code other than 'F' (including null) is mapped
        # to 'male' - confirm this is intended for fairness analyses.
        gender: str = "female" if gender_raw == "F" else "male"
        dob = patients_df["patients/dob"][0]

        # -- one sample per admission row ----------------------------------
        admissions_df = df.filter(df["event_type"] == "admissions")
        if admissions_df.is_empty():
            return samples

        for row in admissions_df.iter_rows(named=True):
            age = self._compute_age(dob, row.get("timestamp"))

            # The module global is read at call time so callers may swap the
            # template list; slice first so unused templates are not filled.
            notes: List[str] = [
                template.format(gender=gender, age=age)
                for template in CLINICAL_NOTE_TEMPLATES[: self.max_notes]
            ]

            # A null hadm_id becomes "" instead of the string "None".
            hadm_id = row.get("admissions/hadm_id")

            samples.append(
                {
                    "visit_id": "" if hadm_id is None else str(hadm_id),
                    "patient_id": patient.patient_id,
                    "notes": notes,
                    "label": self._parse_expire_flag(
                        row.get("admissions/hospital_expire_flag")
                    ),
                    "gender": gender,
                    "ethnicity": row.get("admissions/ethnicity") or "unknown",
                    "insurance": row.get("admissions/insurance") or "unknown",
                    "language": row.get("admissions/language") or "unknown",
                }
            )

        return samples
# ---------------------------------------------------------------------------
# Task under test
#
# The tests exercise the shipped implementation directly. A pasted copy of
# the class would keep passing after the real task changed or broke, so the
# names used throughout this file are aliases of the library objects.
# ---------------------------------------------------------------------------
from pyhealth.tasks.mortality_text_task import (
    CLINICAL_NOTE_TEMPLATES as TEMPLATES,
    MortalityTextTaskMIMIC3 as MortalityTextTask,
)
+ + Args: + patient_id: Unique patient identifier string. + gender: Raw gender code, 'F' or 'M'. + dob: Date of birth as datetime. + admit_time: Admission timestamp as datetime. + expire_flag: 1 if patient died, 0 otherwise. + ethnicity: Ethnicity label string. + insurance: Insurance type string. + language: Language code string. + n_admissions: Number of admission rows to generate. + + Returns: + MagicMock patient object with patient_id and data_source attributes. + """ + # Use a consistent Datetime dtype across both rows to avoid Polars + # SchemaError when concatenating Datetime('us') with Null columns. + dt_type = pl.Datetime("us") + + patient_rows = pl.DataFrame({ + "patient_id": [patient_id], + "event_type": ["patients"], + "timestamp": pl.Series([None], dtype=dt_type), + "patients/gender": [gender], + "patients/dob": pl.Series([dob], dtype=dt_type), + "admissions/hadm_id": pl.Series([None], dtype=pl.Utf8), + "admissions/hospital_expire_flag": pl.Series([None], dtype=pl.Int64), + "admissions/ethnicity": pl.Series([None], dtype=pl.Utf8), + "admissions/insurance": pl.Series([None], dtype=pl.Utf8), + "admissions/language": pl.Series([None], dtype=pl.Utf8), + }) + + admission_rows = pl.DataFrame({ + "patient_id": [patient_id] * n_admissions, + "event_type": ["admissions"] * n_admissions, + "timestamp": pl.Series([admit_time] * n_admissions, dtype=dt_type), + "patients/gender": pl.Series([None] * n_admissions, dtype=pl.Utf8), + "patients/dob": pl.Series([None] * n_admissions, dtype=dt_type), + "admissions/hadm_id": [f"10000{i}" for i in range(n_admissions)], + "admissions/hospital_expire_flag": pl.Series( + [expire_flag] * n_admissions, dtype=pl.Int64 + ), + "admissions/ethnicity": pl.Series([ethnicity] * n_admissions, dtype=pl.Utf8), + "admissions/insurance": pl.Series([insurance] * n_admissions, dtype=pl.Utf8), + "admissions/language": pl.Series([language] * n_admissions, dtype=pl.Utf8), + }) + + df = pl.concat([patient_rows, admission_rows]) + patient = 
MagicMock() + patient.patient_id = patient_id + patient.data_source = df + return patient + + +# --------------------------------------------------------------------------- +# Test cases +# --------------------------------------------------------------------------- + +class TestMortalityTextTaskSampleProcessing(unittest.TestCase): + """Tests for basic sample processing.""" + + def setUp(self) -> None: + self.task = MortalityTextTask(max_notes=5) + + def test_returns_list(self) -> None: + """__call__ should always return a list.""" + patient = make_synthetic_patient() + result = self.task(patient) + self.assertIsInstance(result, list) + + def test_single_admission_one_sample(self) -> None: + """One admission row should produce exactly one sample.""" + patient = make_synthetic_patient(n_admissions=1) + result = self.task(patient) + self.assertEqual(len(result), 1) + + def test_multiple_admissions_multiple_samples(self) -> None: + """Two admission rows should produce two samples.""" + patient = make_synthetic_patient(n_admissions=2) + result = self.task(patient) + self.assertEqual(len(result), 2) + + def test_sample_has_required_keys(self) -> None: + """Every sample must contain all required keys.""" + patient = make_synthetic_patient() + result = self.task(patient) + required = { + "visit_id", "patient_id", "notes", + "label", "gender", "ethnicity", "insurance", "language", + } + self.assertTrue(required.issubset(result[0].keys())) + + def test_patient_id_preserved(self) -> None: + """Sample patient_id must match the input patient.""" + patient = make_synthetic_patient(patient_id="p_42") + result = self.task(patient) + self.assertEqual(result[0]["patient_id"], "p_42") + + +class TestMortalityTextTaskLabelGeneration(unittest.TestCase): + """Tests for mortality label generation.""" + + def setUp(self) -> None: + self.task = MortalityTextTask(max_notes=5) + + def test_label_zero_for_survivor(self) -> None: + """hospital_expire_flag=0 should produce label 0.""" + patient 
= make_synthetic_patient(expire_flag=0) + result = self.task(patient) + self.assertEqual(result[0]["label"], 0) + + def test_label_one_for_death(self) -> None: + """hospital_expire_flag=1 should produce label 1.""" + patient = make_synthetic_patient(expire_flag=1) + result = self.task(patient) + self.assertEqual(result[0]["label"], 1) + + def test_label_is_integer(self) -> None: + """Label must be a plain Python int.""" + patient = make_synthetic_patient(expire_flag=1) + result = self.task(patient) + self.assertIsInstance(result[0]["label"], int) + + def test_none_expire_flag_defaults_to_zero(self) -> None: + """A None expire flag should safely default to label 0.""" + patient = make_synthetic_patient(expire_flag=0) + # manually null out the flag in the dataframe + df = patient.data_source + df = df.with_columns( + pl.when(pl.col("event_type") == "admissions") + .then(None) + .otherwise(pl.col("admissions/hospital_expire_flag")) + .alias("admissions/hospital_expire_flag") + ) + patient.data_source = df + result = self.task(patient) + self.assertEqual(result[0]["label"], 0) + + +class TestMortalityTextTaskFeatureExtraction(unittest.TestCase): + """Tests for synthetic note feature extraction.""" + + def setUp(self) -> None: + self.task = MortalityTextTask(max_notes=5) + + def test_notes_is_list(self) -> None: + """Notes field should be a list.""" + patient = make_synthetic_patient() + result = self.task(patient) + self.assertIsInstance(result[0]["notes"], list) + + def test_notes_are_strings(self) -> None: + """Every note should be a plain string.""" + patient = make_synthetic_patient() + result = self.task(patient) + for note in result[0]["notes"]: + self.assertIsInstance(note, str) + + def test_max_notes_respected(self) -> None: + """Notes list length must not exceed max_notes.""" + task = MortalityTextTask(max_notes=3) + patient = make_synthetic_patient() + result = task(patient) + self.assertLessEqual(len(result[0]["notes"]), 3) + + def test_max_notes_one(self) 
-> None: + """max_notes=1 should return exactly one note.""" + task = MortalityTextTask(max_notes=1) + patient = make_synthetic_patient() + result = task(patient) + self.assertEqual(len(result[0]["notes"]), 1) + + def test_notes_contain_gender_female(self) -> None: + """Notes should reference 'female' for F-coded patients.""" + patient = make_synthetic_patient(gender="F") + result = self.task(patient) + combined = " ".join(result[0]["notes"]) + self.assertIn("female", combined) + + def test_notes_contain_gender_male(self) -> None: + """Notes should reference 'male' for M-coded patients.""" + patient = make_synthetic_patient(gender="M") + result = self.task(patient) + combined = " ".join(result[0]["notes"]) + self.assertIn("male", combined) + + def test_notes_contain_age(self) -> None: + """Notes should embed the computed patient age.""" + dob = datetime(1960, 1, 1) + admit = datetime(2020, 1, 1) + expected_age = str(int((admit - dob).days / 365)) + patient = make_synthetic_patient(dob=dob, admit_time=admit) + result = self.task(patient) + combined = " ".join(result[0]["notes"]) + self.assertIn(expected_age, combined) + + +class TestMortalityTextTaskDemographics(unittest.TestCase): + """Tests that demographic fields are preserved in samples.""" + + def setUp(self) -> None: + self.task = MortalityTextTask(max_notes=5) + + def test_gender_female_mapping(self) -> None: + """'F' should map to 'female' in the sample.""" + patient = make_synthetic_patient(gender="F") + result = self.task(patient) + self.assertEqual(result[0]["gender"], "female") + + def test_gender_male_mapping(self) -> None: + """'M' should map to 'male' in the sample.""" + patient = make_synthetic_patient(gender="M") + result = self.task(patient) + self.assertEqual(result[0]["gender"], "male") + + def test_ethnicity_preserved(self) -> None: + """Ethnicity string should pass through unchanged.""" + patient = make_synthetic_patient(ethnicity="BLACK/AFRICAN AMERICAN") + result = self.task(patient) + 
self.assertEqual(result[0]["ethnicity"], "BLACK/AFRICAN AMERICAN") + + def test_insurance_preserved(self) -> None: + """Insurance type should pass through unchanged.""" + patient = make_synthetic_patient(insurance="Medicaid") + result = self.task(patient) + self.assertEqual(result[0]["insurance"], "Medicaid") + + def test_language_preserved(self) -> None: + """Language code should pass through unchanged.""" + patient = make_synthetic_patient(language="SPAN") + result = self.task(patient) + self.assertEqual(result[0]["language"], "SPAN") + + +class TestMortalityTextTaskEdgeCases(unittest.TestCase): + """Tests for edge cases and error handling.""" + + def setUp(self) -> None: + self.task = MortalityTextTask(max_notes=5) + + def test_no_patients_row_returns_empty(self) -> None: + """If no 'patients' event rows exist, return empty list.""" + dt = pl.Datetime("us") + df = pl.DataFrame({ + "patient_id": pl.Series(["p_edge"], dtype=pl.Utf8), + "event_type": pl.Series(["admissions"], dtype=pl.Utf8), + "timestamp": pl.Series([datetime(2020, 1, 1)], dtype=dt), + "patients/gender": pl.Series([None], dtype=pl.Utf8), + "patients/dob": pl.Series([None], dtype=dt), + "admissions/hadm_id": pl.Series(["999"], dtype=pl.Utf8), + "admissions/hospital_expire_flag": pl.Series([0], dtype=pl.Int64), + "admissions/ethnicity": pl.Series(["WHITE"], dtype=pl.Utf8), + "admissions/insurance": pl.Series(["Medicare"], dtype=pl.Utf8), + "admissions/language": pl.Series(["ENGL"], dtype=pl.Utf8), + }) + patient = MagicMock() + patient.patient_id = "p_edge" + patient.data_source = df + result = self.task(patient) + self.assertEqual(result, []) + + def test_no_admissions_row_returns_empty(self) -> None: + """If no 'admissions' event rows exist, return empty list.""" + dt = pl.Datetime("us") + df = pl.DataFrame({ + "patient_id": pl.Series(["p_edge2"], dtype=pl.Utf8), + "event_type": pl.Series(["patients"], dtype=pl.Utf8), + "timestamp": pl.Series([None], dtype=dt), + "patients/gender": pl.Series(["F"], 
dtype=pl.Utf8), + "patients/dob": pl.Series([datetime(1960, 1, 1)], dtype=dt), + "admissions/hadm_id": pl.Series([None], dtype=pl.Utf8), + "admissions/hospital_expire_flag": pl.Series([None], dtype=pl.Int64), + "admissions/ethnicity": pl.Series([None], dtype=pl.Utf8), + "admissions/insurance": pl.Series([None], dtype=pl.Utf8), + "admissions/language": pl.Series([None], dtype=pl.Utf8), + }) + patient = MagicMock() + patient.patient_id = "p_edge2" + patient.data_source = df + result = self.task(patient) + self.assertEqual(result, []) + + def test_missing_ethnicity_defaults_to_unknown(self) -> None: + """None ethnicity should default to 'unknown'.""" + patient = make_synthetic_patient(ethnicity=None) + result = self.task(patient) + self.assertEqual(result[0]["ethnicity"], "unknown") + + def test_missing_insurance_defaults_to_unknown(self) -> None: + """None insurance should default to 'unknown'.""" + patient = make_synthetic_patient(insurance=None) + result = self.task(patient) + self.assertEqual(result[0]["insurance"], "unknown") + + def test_bad_dob_falls_back_to_age_65(self) -> None: + """If age cannot be computed, notes should use fallback age 65.""" + patient = make_synthetic_patient() + # null out dob so age computation fails + df = patient.data_source.with_columns( + pl.when(pl.col("event_type") == "patients") + .then(None) + .otherwise(pl.col("patients/dob")) + .alias("patients/dob") + ) + patient.data_source = df + result = self.task(patient) + combined = " ".join(result[0]["notes"]) + self.assertIn("65", combined) + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + unittest.main(verbosity=2) \ No newline at end of file