Add DynamicSurvivalModel and DecompensationDSA task#1126
Open
hards2 wants to merge 1 commit intosunlabuiuc:masterfrom
Open
Add DynamicSurvivalModel and DecompensationDSA task#1126hards2 wants to merge 1 commit intosunlabuiuc:masterfrom
hards2 wants to merge 1 commit intosunlabuiuc:masterfrom
Conversation
Implements the Dynamic Survival Analysis pipeline from Yeche et al. (CHIL 2024). Includes: - DynamicSurvivalModel (pyhealth/models/) with GRU, LSTM, and causal Transformer backbones, L1-regularised embedding, and hazard head with bias initialisation from empirical mean hazard rates. - DecompensationDSA task (pyhealth/tasks/) with a synthetic data factory for credential-free reproduction. - End-to-end example (examples/) with two ablations. - 33 unit tests (tests/core/) using synthetic data only. - Sphinx RST docs plus updates to models.rst and tasks.rst toctrees. Paper: https://proceedings.mlr.press/v248/yeche24a.html Made-with: Cursor
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Contributor
Type of contribution
This PR adds:
DynamicSurvivalModel) — Option 2 of the contribution rubric.DecompensationDSA) — Option 3 of the contribution rubric.Both the model and task are novel within PyHealth (no existing
DynamicSurvivalModelorDecompensationDSAsymbols before this PR).Paper being reproduced
Yèche, H.; Burger, M.; Veshchezerova, D.; Rätsch, G. (2024). Dynamic Survival Analysis for Early Event Prediction. In Proceedings of the Conference on Health, Inference, and Learning (CHIL), PMLR 248:540–557.
Paper link: https://proceedings.mlr.press/v248/yeche24a.html
Authors' reference implementation: https://github.com/ratschlab/dsa-for-eep
Video walkthrough (4–7 min): https://drive.google.com/drive/folders/1mF3sfgjo-lG299ORsvddwiaEhPVikS0s?usp=share_link
High-level description
Early Event Prediction (EEP) in the ICU asks whether a clinical event — circulatory failure, decompensation, mortality — will occur within a fixed horizon h. The dominant approach trains a neural network to estimate the cumulative failure function
F_θ(h | X_t). While this works for a single scalar risk score, it collapses the entire horizon into one number: a patient whose event is one hour away and another whose event is eleven hours away receive the same score if their 24-hour probabilities happen to align. As a result, simple threshold-based alarm policies are suboptimal.Yèche et al. (2024) propose Dynamic Survival Analysis (DSA): instead of estimating the scalar
F, estimate the full hazard functionλ_θ(k | X_t)for every horizonk ∈ {1, …, h}. From the hazard one recovers both the original EEP risk score and a full probability mass function over time-to-event, which supports the imminent-prioritisation alarm policy that up-weights near-term predictions and improves Alarm/Event AuPRC by up to 11% over standard EEP baselines.This PR implements the DSA pipeline as:
DynamicSurvivalModel(pyhealth/models/dynamic_survival_model.py) — linear embedding with L1 regularisation → temporal backbone (GRU, LSTM, or causal Transformer) → hazard head with bias-initialisation from empirical mean hazard rates. Inherits frompyhealth.models.BaseModel; returns the PyHealth-Trainer-compatible dict{loss, y_prob, y_true, logit}fromforward(**batch).DecompensationDSA(pyhealth/tasks/decompensation_dsa.py) — subclass ofpyhealth.tasks.BaseTaskwithinput_schema = {"timeseries": "tensor"}andoutput_schema = {"label": "binary"}. Includes amake_synthetic_dsa_samples()factory so the whole pipeline is reproducible without a credentialed dataset.examples/synthetic_decompensation_dsa_model.py) — trains the model on synthetic data for three epochs and runs two ablations (hidden-dim sweep and horizon sweep). Runs in under 30 seconds on CPU with no PyHealth-install prerequisites.Ablation study: the example script sweeps
hidden_dim ∈ {64, 128, 256}andhorizon ∈ {6, 12, 24}, reporting the final training loss and wall-clock time per configuration.File guide
Exactly 10 files change or are added. Review order (suggested):
pyhealth/models/dynamic_survival_model.pyclass DynamicSurvivalModel(BaseModel)+ internal_LinearEmbedding,_CausalTransformerEncoder,_HazardHead. Forward pass starts at line 369.pyhealth/tasks/decompensation_dsa.pyclass DecompensationDSA(BaseTask)+make_synthetic_dsa_samples()factory.examples/synthetic_decompensation_dsa_model.pytests/core/test_dynamic_survival_model.pytests/core/test_decompensation_dsa.pymake_synthetic_dsa_samples(11 tests) andDecompensationDSAschema (4 tests).docs/api/models/pyhealth.models.DynamicSurvivalModel.rstautoclass.docs/api/tasks/pyhealth.tasks.DecompensationDSA.rstautoclass+autofunction.docs/api/models.rstmodels/pyhealth.models.DynamicSurvivalModelto the toctree (line 209).docs/api/tasks.rstICU Decompensation DSA <tasks/pyhealth.tasks.DecompensationDSA>to the toctree (line 233).pyhealth/models/__init__.pyandpyhealth/tasks/__init__.pyDynamicSurvivalModelandDecompensationDSArespectively.Test plan
All tests use synthetic data (no credentialed datasets, no MIMIC demo), complete in milliseconds per test, and seconds for the entire suite:
Expected:
Timing breakdown (verified via
pytest --durations=0):Reproducibility
No external dataset is required. The example and the tests rely only on
make_synthetic_dsa_samples(seeded random generator) andtorch.randnfixtures.test_reproducibilityinsidetest_decompensation_dsa.pyasserts byte-identical output for the same seed across invocations.Checklist
upstream/masterbefore opening this PRexamples/demonstrates both the model and the taskBaseModel,BaseTask)docs/api/models.rstanddocs/api/tasks.rstare updatedThanks for reviewing — happy to make any changes you'd like.