Add DynamicSurvivalModel and DecompensationDSA task #1126

Open
hards2 wants to merge 1 commit into sunlabuiuc:master from hards2:dsa-dynamic-survival-model

hards2 commented Apr 23, 2026

Contributor

  • Name: Hard Shah
  • NetID: hards2
  • Email: hards2@illinois.edu (project registration used the UIUC email per the course spreadsheet)
  • Course: CS 598 DL4H — University of Illinois Urbana-Champaign, Spring 2026

Type of contribution

This PR adds:

  • A new model (DynamicSurvivalModel) — Option 2 of the contribution rubric.
  • A new task (DecompensationDSA) — Option 3 of the contribution rubric.
  • A full-pipeline example showing the two working together end-to-end on synthetic data.

Both the model and the task are new to PyHealth: no DynamicSurvivalModel or DecompensationDSA symbol existed before this PR.

Paper being reproduced

Yèche, H.; Burger, M.; Veshchezerova, D.; Rätsch, G. (2024). Dynamic Survival Analysis for Early Event Prediction. In Proceedings of the Conference on Health, Inference, and Learning (CHIL), PMLR 248:540–557.

Paper link: https://proceedings.mlr.press/v248/yeche24a.html

Authors' reference implementation: https://github.com/ratschlab/dsa-for-eep

Video walkthrough (4–7 min): https://drive.google.com/drive/folders/1mF3sfgjo-lG299ORsvddwiaEhPVikS0s?usp=share_link

High-level description

Early Event Prediction (EEP) in the ICU asks whether a clinical event — circulatory failure, decompensation, mortality — will occur within a fixed horizon h. The dominant approach trains a neural network to estimate the cumulative failure function F_θ(h | X_t). While this works for a single scalar risk score, it collapses the entire horizon into one number: a patient whose event is one hour away and another whose event is eleven hours away receive the same score if their 24-hour probabilities happen to align. As a result, simple threshold-based alarm policies are suboptimal.

Yèche et al. (2024) propose Dynamic Survival Analysis (DSA): instead of estimating the scalar F, estimate the full hazard function λ_θ(k | X_t) for every horizon k ∈ {1, …, h}. From the hazard one recovers both the original EEP risk score and a full probability mass function over time-to-event, which supports the imminent-prioritisation alarm policy that up-weights near-term predictions and improves Alarm/Event AuPRC by up to 11% over standard EEP baselines.
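
To make the relationship between the hazard and the two derived quantities concrete, here is a minimal sketch in plain Python using the standard discrete-time survival identities. The function name `hazard_to_pmf_and_risk` and the example hazard values are illustrative assumptions, not code from this PR or from the authors' implementation.

```python
from typing import List, Tuple


def hazard_to_pmf_and_risk(hazards: List[float]) -> Tuple[List[float], float]:
    """Convert per-step hazards into a time-to-event PMF and a horizon risk.

    hazards[k-1] is lambda(k | X_t): the probability that the event happens
    at step k given that it has not happened before step k.
    """
    pmf = []
    survival = 1.0  # S(0) = 1: no event has happened yet
    for lam in hazards:
        if not 0.0 <= lam <= 1.0:
            raise ValueError("each hazard must lie in [0, 1]")
        pmf.append(lam * survival)   # P(T = k) = lambda(k) * S(k - 1)
        survival *= 1.0 - lam        # S(k) = S(k - 1) * (1 - lambda(k))
    risk = 1.0 - survival            # F(h) = 1 - S(h) = sum of the PMF
    return pmf, risk


# Hypothetical hazards for a horizon of h = 4 steps.
pmf, risk = hazard_to_pmf_and_risk([0.10, 0.20, 0.05, 0.05])
```

Two patients can share the same `risk` while having very different `pmf` shapes; that difference is the information a scalar EEP score discards.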

This PR implements the DSA pipeline as:

  1. DynamicSurvivalModel (pyhealth/models/dynamic_survival_model.py) — linear embedding with L1 regularisation → temporal backbone (GRU, LSTM, or causal Transformer) → hazard head with bias-initialisation from empirical mean hazard rates. Inherits from pyhealth.models.BaseModel; returns the PyHealth-Trainer-compatible dict {loss, y_prob, y_true, logit} from forward(**batch).
  2. DecompensationDSA (pyhealth/tasks/decompensation_dsa.py) — subclass of pyhealth.tasks.BaseTask with input_schema = {"timeseries": "tensor"} and output_schema = {"label": "binary"}. Includes a make_synthetic_dsa_samples() factory so the whole pipeline is reproducible without a credentialed dataset.
  3. End-to-end example (examples/synthetic_decompensation_dsa_model.py) — trains the model on synthetic data for three epochs and runs two ablations (hidden-dim sweep and horizon sweep). Runs in under 30 seconds on CPU with no PyHealth-install prerequisites.
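
The bias initialisation mentioned in item 1 can be illustrated as follows. This is a sketch under assumptions: it estimates the empirical hazard at each horizon step from a list of time-to-event labels and sets the bias to the logit of that rate, so that a sigmoid hazard head starts out predicting the base rate. The helper names and the label encoding (`None` for no event within the horizon) are hypothetical and may differ from what `_HazardHead` actually does.

```python
import math
from typing import List, Optional


def empirical_hazards(times: List[Optional[int]], horizon: int) -> List[float]:
    """Fraction of samples still at risk at step k whose event occurs at step k.

    times[i] is the step (1..horizon) at which the event occurs for sample i,
    or None if no event occurs within the horizon (an assumed encoding).
    """
    hazards = []
    for k in range(1, horizon + 1):
        at_risk = sum(1 for t in times if t is None or t >= k)
        events = sum(1 for t in times if t == k)
        hazards.append(events / at_risk if at_risk else 0.0)
    return hazards


def logit_bias(hazards: List[float], eps: float = 1e-6) -> List[float]:
    """Bias b such that sigmoid(b) equals the (clipped) empirical hazard."""
    biases = []
    for lam in hazards:
        lam = min(max(lam, eps), 1.0 - eps)  # avoid log(0) for empty steps
        biases.append(math.log(lam / (1.0 - lam)))
    return biases


# Hypothetical labels: 10 samples, horizon of 3 steps.
times = [1, 2, 2, 3, None, None, None, None, None, None]
hazards = empirical_hazards(times, horizon=3)
biases = logit_bias(hazards)
```

Starting from the base rate matters because events are rare: a zero-initialised bias would make every initial hazard 0.5 and inflate the early loss.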

Ablation study: the example script sweeps hidden_dim ∈ {64, 128, 256} and horizon ∈ {6, 12, 24}, reporting the final training loss and wall-clock time per configuration.
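
The shape of the two sweeps can be sketched as below. `train_once` is a hypothetical stand-in that returns a dummy loss; in the example script its place would be taken by an actual training run of DynamicSurvivalModel. The baseline values used for the parameter that is not being swept are also assumptions.

```python
import time
from typing import Dict, List


def train_once(hidden_dim: int, horizon: int, epochs: int = 3) -> float:
    """Hypothetical placeholder for a real training run; returns a dummy loss."""
    loss = 1.0
    for _ in range(epochs):
        loss *= 0.9  # pretend the loss shrinks each epoch
    return loss


def sweep(name: str, values: List[int], defaults: Dict[str, int]) -> List[dict]:
    """Vary one hyperparameter while holding the others at their defaults."""
    rows = []
    for value in values:
        config = {**defaults, name: value}
        start = time.perf_counter()
        final_loss = train_once(**config)
        elapsed = time.perf_counter() - start
        rows.append({**config, "final_loss": final_loss, "seconds": elapsed})
    return rows


defaults = {"hidden_dim": 128, "horizon": 12}  # assumed baseline configuration
results = sweep("hidden_dim", [64, 128, 256], defaults)
results += sweep("horizon", [6, 12, 24], defaults)

for row in results:
    print(row)
```

Sweeping one parameter at a time gives six runs rather than the nine of a full grid, which keeps the example short on CPU.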

File guide

The changes are listed below in 10 rows; row 10 covers two __init__.py files, so 11 files are touched in total. Suggested review order:

| # | Path | Change | What to look at |
|---|------|--------|-----------------|
| 1 | pyhealth/models/dynamic_survival_model.py | new (418 LOC) | class DynamicSurvivalModel(BaseModel) + internal _LinearEmbedding, _CausalTransformerEncoder, _HazardHead. Forward pass starts at line 369. |
| 2 | pyhealth/tasks/decompensation_dsa.py | new (212 LOC) | class DecompensationDSA(BaseTask) + make_synthetic_dsa_samples() factory. |
| 3 | examples/synthetic_decompensation_dsa_model.py | new (248 LOC) | Runnable end-to-end example with two ablations. |
| 4 | tests/core/test_dynamic_survival_model.py | new (278 LOC, 18 tests) | Covers instantiation, forward shapes, gradient flow, three encoder types, bias init. |
| 5 | tests/core/test_decompensation_dsa.py | new (145 LOC, 15 tests) | Covers make_synthetic_dsa_samples (11 tests) and DecompensationDSA schema (4 tests). |
| 6 | docs/api/models/pyhealth.models.DynamicSurvivalModel.rst | new | Sphinx docs: Quick Start + autoclass. |
| 7 | docs/api/tasks/pyhealth.tasks.DecompensationDSA.rst | new | Sphinx docs: Quick Start + schema table + autoclass + autofunction. |
| 8 | docs/api/models.rst | modified | Added models/pyhealth.models.DynamicSurvivalModel to the toctree (line 209). |
| 9 | docs/api/tasks.rst | modified | Added ICU Decompensation DSA <tasks/pyhealth.tasks.DecompensationDSA> to the toctree (line 233). |
| 10 | pyhealth/models/__init__.py and pyhealth/tasks/__init__.py | modified | Registered DynamicSurvivalModel and DecompensationDSA respectively. |

Test plan

All tests use synthetic data (no credentialed datasets, no MIMIC demo). Each test completes in milliseconds and the entire suite in about 2 seconds:

pytest tests/core/test_dynamic_survival_model.py tests/core/test_decompensation_dsa.py -v

Expected:

tests/core/test_dynamic_survival_model.py ..................      [ 54%]
tests/core/test_decompensation_dsa.py ...............             [100%]
================== 33 passed in ~2s ==================

Timing breakdown (verified via pytest --durations=0):

  • Slowest single test: 0.05 s
  • Full suite: ≈ 2 seconds
  • Zero files written to disk; no temp-dir cleanup needed.

Reproducibility

No external dataset is required. The example and the tests rely only on make_synthetic_dsa_samples (seeded random generator) and torch.randn fixtures. test_reproducibility inside test_decompensation_dsa.py asserts byte-identical output for the same seed across invocations.
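
As an illustration of the seeded-generator check, the sketch below uses a hypothetical stand-in factory, `make_samples`, built on Python's `random.Random`. It is not the PR's `make_synthetic_dsa_samples`, whose signature is not shown here. The point is only the pattern: a generator local to the call, seeded explicitly, and a byte-level comparison of two invocations.

```python
import pickle
import random
from typing import List


def make_samples(n: int, length: int, seed: int) -> List[dict]:
    """Hypothetical stand-in for a seeded synthetic-sample factory."""
    rng = random.Random(seed)  # local generator: no dependence on global state
    samples = []
    for i in range(n):
        series = [rng.gauss(0.0, 1.0) for _ in range(length)]
        label = int(rng.random() < 0.2)
        samples.append({"patient_id": i, "timeseries": series, "label": label})
    return samples


first = pickle.dumps(make_samples(n=8, length=5, seed=42))
second = pickle.dumps(make_samples(n=8, length=5, seed=42))
other = pickle.dumps(make_samples(n=8, length=5, seed=43))

assert first == second   # same seed: byte-identical serialisation
assert first != other    # different seed: different data
```

Using a generator object scoped to the factory, rather than seeding the global module state, keeps the result independent of whatever other tests ran earlier in the session.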

Checklist

  • Rebased against upstream/master before opening this PR
  • "Allow edits by maintainers" enabled on this PR
  • Link to the original paper is included above
  • Example script under examples/ demonstrates both the model and the task
  • Tests use synthetic data only (no MIMIC demo, no real HiRID/eICU)
  • All new classes inherit from PyHealth base classes (BaseModel, BaseTask)
  • Google-style docstrings with type hints on every public method
  • Sphinx RST files added for both the model and the task, and the toctrees in docs/api/models.rst and docs/api/tasks.rst are updated

Thanks for reviewing — happy to make any changes you'd like.

Implements the Dynamic Survival Analysis pipeline from Yeche et al.
(CHIL 2024). Includes:

- DynamicSurvivalModel (pyhealth/models/) with GRU, LSTM, and causal
  Transformer backbones, L1-regularised embedding, and hazard head
  with bias initialisation from empirical mean hazard rates.
- DecompensationDSA task (pyhealth/tasks/) with a synthetic data
  factory for credential-free reproduction.
- End-to-end example (examples/) with two ablations.
- 33 unit tests (tests/core/) using synthetic data only.
- Sphinx RST docs plus updates to models.rst and tasks.rst toctrees.

Paper: https://proceedings.mlr.press/v248/yeche24a.html
Made-with: Cursor