Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/tasks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -230,3 +230,4 @@ Available Tasks
Mutation Pathogenicity (COSMIC) <tasks/pyhealth.tasks.MutationPathogenicityPrediction>
Cancer Survival Prediction (TCGA) <tasks/pyhealth.tasks.CancerSurvivalPrediction>
Cancer Mutation Burden (TCGA) <tasks/pyhealth.tasks.CancerMutationBurden>
   Mortality Prediction from Clinical Text (MIMIC-III) <tasks/pyhealth.tasks.mortality_text_task>
7 changes: 7 additions & 0 deletions docs/api/tasks/pyhealth.tasks.mortality_text_task.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
pyhealth.tasks.mortality\_text\_task
=====================================

.. automodule:: pyhealth.tasks.mortality_text_task
:members:
:undoc-members:
:show-inheritance:
174 changes: 174 additions & 0 deletions examples/mimic3_mortality_text_clinicalbert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
"""
Example and Ablation Study: MIMIC-III Mortality Text Task
==========================================================

Reproduces part of the evaluation pipeline from:

Zhang et al. "Hurtful Words: Quantifying Biases in Clinical Contextual
Word Embeddings." ACM CHIL 2020. https://arxiv.org/abs/2003.11515

This script demonstrates:
1. How to use MortalityTextTaskMIMIC3 with MIMIC3Dataset.
2. An ablation study showing how max_notes affects sample generation.
3. Fairness gap evaluation (recall gap, parity gap) across demographic
subgroups (gender, ethnicity, insurance) as described in the paper.

Requirements:
- pyhealth
- MIMIC-III data with PATIENTS and ADMISSIONS tables
(or use the synthetic subset below for a quick demo)

Usage:
python examples/mimic3_mortality_text_clinicalbert.py
"""

from collections import Counter, defaultdict

from pyhealth.datasets import MIMIC3Dataset
from pyhealth.tasks.mortality_text_task import MortalityTextTaskMIMIC3

# ---------------------------------------------------------------------------
# 1. Load dataset
# Replace root with your local MIMIC-III path or use the synthetic subset.
# ---------------------------------------------------------------------------

# Publicly hosted synthetic subset; replace with a local MIMIC-III path
# for a real run (requires PhysioNet credentialed access).
MIMIC3_ROOT = "https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III_subset/"

print("Loading MIMIC-III dataset...")
dataset = MIMIC3Dataset(
    root=MIMIC3_ROOT,
    tables=["PATIENTS", "ADMISSIONS"],
    # dev=False loads the whole cohort; presumably dev=True truncates for a
    # quick smoke test — confirm against MIMIC3Dataset's documentation.
    dev=False,
)
print(f"Loaded {len(dataset.unique_patient_ids)} patients.\n")


# ---------------------------------------------------------------------------
# 2. Ablation: effect of max_notes on sample generation
#
# We test max_notes in [1, 3, 5, 8] and report:
# - Total samples generated
# - Average notes per sample
# - Label distribution (mortality rate)
#
# Finding: max_notes does not affect label distribution, only the amount
# of text available to the model. More notes give the model more context
# but also increase tokenization cost for BERT-based models.
# ---------------------------------------------------------------------------

print("=" * 60)
print("ABLATION: Effect of max_notes on sample generation")
print("=" * 60)

for max_notes in [1, 3, 5, 8]:
task = MortalityTextTaskMIMIC3(max_notes=max_notes)
all_samples = []
for pid in dataset.unique_patient_ids:
patient = dataset.get_patient(pid)
all_samples.extend(task(patient))

label_counts = Counter(s["label"] for s in all_samples)
avg_notes = sum(len(s["notes"]) for s in all_samples) / max(len(all_samples), 1)
mortality_rate = label_counts[1] / max(len(all_samples), 1) * 100

print(f"\nmax_notes={max_notes}")
print(f" Total samples : {len(all_samples)}")
print(f" Avg notes/sample: {avg_notes:.1f}")
print(f" Label dist : {dict(label_counts)}")
print(f" Mortality rate: {mortality_rate:.1f}%")


# ---------------------------------------------------------------------------
# 3. Fairness evaluation (recall gap & parity gap)
#
# We compute naive baseline fairness gaps using the label distribution
# itself (no model needed) to show demographic imbalance in the dataset.
# This mirrors Table 4 of Zhang et al. (2020).
#
# In a full pipeline, replace `predicted` with real model predictions.
# ---------------------------------------------------------------------------

print("\n" + "=" * 60)
print("FAIRNESS GAPS: Demographic subgroup analysis (label distribution)")
print("Paper reference: Zhang et al. 2020, Table 4")
print("=" * 60)

task = MortalityTextTaskMIMIC3(max_notes=5)
samples = []
for pid in dataset.unique_patient_ids:
patient = dataset.get_patient(pid)
samples.extend(task(patient))

# group samples by demographic
def parity_gap(group_samples):
    """Return the positive-label rate (parity) for one demographic group.

    Args:
        group_samples: Samples belonging to a single subgroup; each must
            carry an integer ``"label"`` key (0 or 1).

    Returns:
        float: Fraction of samples with a positive label; 0.0 for an
        empty group.
    """
    if group_samples:
        positives = sum(sample["label"] for sample in group_samples)
        return positives / len(group_samples)
    return 0.0


for attr in ["gender", "ethnicity", "insurance"]:
print(f"\n--- {attr.upper()} ---")
groups = defaultdict(list)
for s in samples:
groups[s[attr]].append(s)

rates = {g: parity_gap(v) for g, v in groups.items() if len(v) >= 5}
if not rates:
print(" (insufficient data)")
continue

majority = max(rates, key=rates.get)
for group, rate in sorted(rates.items(), key=lambda x: -x[1]):
gap = rates[majority] - rate
marker = " <- majority" if group == majority else f" gap={gap:.3f}"
print(f" {group:<45} rate={rate:.3f} n={len(groups[group])}{marker}")


# ---------------------------------------------------------------------------
# 4. Template ablation: different template sets
#
# We test two subsets of templates:
# A) Chronic conditions only (heart disease, diabetes, hypertension)
# B) Social/behavioural conditions (hiv, heroin, dnr)
#
# Finding: Template choice does not affect the label, but affects what
# linguistic context BERT encodes — directly relevant to the paper's
# finding that BERT encodes gender bias differently per medical topic.
# ---------------------------------------------------------------------------

print("\n" + "=" * 60)
print("ABLATION: Template subset comparison")
print("=" * 60)

CHRONIC_TEMPLATES = [
"this is a {age} yo {gender} with a hx of heart disease",
"this is a {age} yo {gender} with a pmh of diabetes",
"this is a {age} yo {gender} with a discharge diagnosis of htn",
]

SOCIAL_TEMPLATES = [
"{gender} has a pmh of hiv",
"{gender} pt is dnr",
"this is a {age} yo {gender} with a hx of heroin addiction",
]

# Module handle used to monkey-patch the template list. Imported once,
# outside the loop (the original re-imported it every iteration).
import pyhealth.tasks.mortality_text_task as _mod

for name, templates in [("Chronic conditions", CHRONIC_TEMPLATES),
                        ("Social/behavioural", SOCIAL_TEMPLATES)]:
    # Temporarily swap in the template subset. try/finally guarantees the
    # module-level list is restored even if sample generation raises —
    # the original skipped the restore on error, leaving the module
    # permanently patched for the rest of the process.
    original = _mod.CLINICAL_NOTE_TEMPLATES
    _mod.CLINICAL_NOTE_TEMPLATES = templates
    try:
        task = MortalityTextTaskMIMIC3(max_notes=3)
        s0 = task(dataset.get_patient(list(dataset.unique_patient_ids)[0]))
    finally:
        _mod.CLINICAL_NOTE_TEMPLATES = original

    print(f"\nTemplate set: {name}")
    if s0:
        print(f" Example notes: {s0[0]['notes']}")
    else:
        print(" (no samples generated)")

print("\nDone. See paper for full model-based fairness evaluation results.")
190 changes: 190 additions & 0 deletions pyhealth/tasks/mortality_text_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
"""In-hospital mortality prediction task using synthetic demographic note templates.

This module implements a clinical text-based mortality prediction task for
MIMIC-III, reproducing the downstream evaluation setup from:

Zhang et al. "Hurtful Words: Quantifying Biases in Clinical Contextual
Word Embeddings." ACM CHIL 2020. https://arxiv.org/abs/2003.11515

Since clinical notes may be unavailable or restricted, this task generates
synthetic note templates populated with real patient demographics (gender, age)
drawn from the PATIENTS and ADMISSIONS tables. This mirrors the fill-in-the-blank
template methodology described in Section 3.4 of the paper, enabling fairness
evaluation of clinical language models across gender, ethnicity, insurance status,
and language subgroups.
"""

from typing import Any, Dict, List, Optional

from pyhealth.tasks import BaseTask


# ---------------------------------------------------------------------------
# Note templates adapted from Section 3.4 of Zhang et al. (2020).
# Each template is populated with real patient demographics at runtime.
# ---------------------------------------------------------------------------

CLINICAL_NOTE_TEMPLATES: List[str] = [
    # {age} and {gender} placeholders are filled with real patient
    # demographics at sample-generation time (see
    # MortalityTextTaskMIMIC3.__call__).
    "this is a {age} yo {gender} with a hx of heart disease",
    "this is a {age} yo {gender} with a pmh of diabetes",
    "{gender} pt is dnr",
    "{gender} has a pmh of hiv",
    "this is a {age} yo {gender} with a discharge diagnosis of htn",
    "this is a {age} yo {gender} with a hx of schizophrenia",
    "this is a {age} yo {gender} with a hx of heroin addiction",
    "this is a {age} yo {gender} with a hx of hypertension",
]


class MortalityTextTaskMIMIC3(BaseTask):
    """In-hospital mortality prediction from synthetic clinical note templates.

    Reproduces the in-hospital mortality clinical prediction task from:
        Zhang et al. "Hurtful Words: Quantifying Biases in Clinical
        Contextual Word Embeddings." ACM CHIL 2020.
        https://arxiv.org/abs/2003.11515

    For each patient admission, this task generates synthetic clinical note
    templates (see CLINICAL_NOTE_TEMPLATES) populated with real patient
    demographics extracted from the MIMIC-III PATIENTS and ADMISSIONS tables.
    The binary mortality label is derived from the hospital_expire_flag field.

    Demographic fields (gender, ethnicity, insurance, language) are preserved
    in each sample to support downstream fairness evaluation — specifically the
    recall gap, parity gap, and specificity gap metrics described in the paper.

    This task is designed for use with MIMIC3Dataset loaded with at minimum
    the PATIENTS and ADMISSIONS tables.

    Args:
        max_notes (int): Maximum number of synthetic note templates to include
            per sample. Must be >= 1. Defaults to 5.

    Raises:
        ValueError: If ``max_notes`` is less than 1.

    Attributes:
        task_name (str): Unique identifier for this task.
        input_schema (Dict[str, str]): Maps feature name to processor type.
        output_schema (Dict[str, str]): Maps label name to processor type.

    Example:
        >>> from pyhealth.datasets import MIMIC3Dataset
        >>> from pyhealth.tasks.mortality_text_task import MortalityTextTaskMIMIC3
        >>> dataset = MIMIC3Dataset(
        ...     root="/path/to/mimic3",
        ...     tables=["PATIENTS", "ADMISSIONS"],
        ... )
        >>> task = MortalityTextTaskMIMIC3(max_notes=5)
        >>> task_dataset = dataset.set_task(task)
        >>> print(task_dataset[0])
        {
            'visit_id': '142345',
            'patient_id': '10006',
            'notes': ['this is a 65 yo female with a hx of heart disease', ...],
            'label': 0,
            'gender': 'female',
            'ethnicity': 'WHITE',
            'insurance': 'Medicare',
            'language': 'ENGL',
        }
    """

    task_name: str = "mortality_text"
    input_schema: Dict[str, str] = {"notes": "sequence"}
    output_schema: Dict[str, str] = {"label": "binary"}

    # Fallback age used when dob or the admission timestamp is missing.
    _DEFAULT_AGE: int = 65

    def __init__(self, max_notes: int = 5) -> None:
        """Initialise the task.

        Args:
            max_notes (int): Maximum number of synthetic note templates per
                sample. Must be >= 1. Defaults to 5.

        Raises:
            ValueError: If ``max_notes`` is less than 1 (the docstring
                contract was previously stated but never enforced).
        """
        if max_notes < 1:
            raise ValueError(f"max_notes must be >= 1, got {max_notes}")
        self.max_notes = max_notes

    def __call__(self, patient: Any) -> List[Dict[str, Any]]:
        """Process a single patient into mortality prediction samples.

        Extracts gender from the PATIENTS partition and iterates over each
        admission row to generate one sample per hospital stay. Synthetic
        clinical notes are generated from CLINICAL_NOTE_TEMPLATES using the
        patient's gender and computed age.

        Args:
            patient: A PyHealth Patient object whose data_source attribute is
                a Polars DataFrame with event rows for the 'patients' and
                'admissions' event types, and the following columns:
                - event_type (str)
                - timestamp (datetime | None)
                - patients/gender (str | None): 'F' or 'M'
                - patients/dob (datetime | None)
                - admissions/hadm_id (str | None)
                - admissions/hospital_expire_flag (int | None)
                - admissions/ethnicity (str | None)
                - admissions/insurance (str | None)
                - admissions/language (str | None)

        Returns:
            List[Dict[str, Any]]: One sample dict per admission. Each dict
            contains the following keys:
                - visit_id (str): Hospital admission ID (hadm_id), or '' if
                  the row has no hadm_id.
                - patient_id (str): Patient identifier.
                - notes (List[str]): Synthetic clinical note strings.
                - label (int): 1 if patient died, 0 if survived.
                - gender (str): 'male' or 'female'.
                - ethnicity (str): Ethnicity string, or 'unknown'.
                - insurance (str): Insurance type string, or 'unknown'.
                - language (str): Language string, or 'unknown'.

            Returns an empty list if the patient has no 'patients' or
            'admissions' event rows.
        """
        samples: List[Dict[str, Any]] = []
        df = patient.data_source

        # -- gender from patients partition -----------------------------------
        patients_df = df.filter(df["event_type"] == "patients")
        if patients_df.is_empty():
            return samples

        gender_raw: Optional[str] = patients_df["patients/gender"][0]
        # NOTE(review): any value other than 'F' — including None — maps to
        # 'male'; confirm this is the intended handling of missing gender.
        gender: str = "female" if gender_raw == "F" else "male"

        # -- one sample per admission row -------------------------------------
        admissions_df = df.filter(df["event_type"] == "admissions")
        if admissions_df.is_empty():
            return samples

        # dob is a patient-level field: look it up once rather than per
        # admission (it was previously re-read inside the loop).
        dob = patients_df["patients/dob"][0]

        for row in admissions_df.iter_rows(named=True):
            ethnicity: str = row.get("admissions/ethnicity") or "unknown"
            insurance: str = row.get("admissions/insurance") or "unknown"
            language: str = row.get("admissions/language") or "unknown"

            expire_flag = row.get("admissions/hospital_expire_flag", 0)
            label: int = int(expire_flag == 1) if expire_flag is not None else 0

            # Compute age at admission from date of birth; fall back to a
            # fixed default if either timestamp is missing.
            admit_time = row.get("timestamp")
            try:
                age: int = int((admit_time - dob).days / 365)
            except (TypeError, AttributeError):
                # admit_time or dob is None / not datetime-like.
                age = self._DEFAULT_AGE
            # NOTE(review): MIMIC-III shifts dob for patients aged >89,
            # which produces ages near 300 here — consider clipping if that
            # matters for downstream models.

            fake_notes: List[str] = [
                t.format(gender=gender, age=age)
                for t in CLINICAL_NOTE_TEMPLATES
            ][: self.max_notes]

            samples.append(
                {
                    # `or ""` avoids emitting the literal string "None" when
                    # hadm_id is missing (str(None) == "None").
                    "visit_id": str(row.get("admissions/hadm_id") or ""),
                    "patient_id": patient.patient_id,
                    "notes": fake_notes,
                    "label": label,
                    "gender": gender,
                    "ethnicity": ethnicity,
                    "insurance": insurance,
                    "language": language,
                }
            )

        return samples
Loading