Add Deep Cox Mixtures survival model #1055
Open
sirlittle wants to merge 1 commit into sunlabuiuc:master
Implements a simplified, PyHealth-native port of Nagpal et al. (2021,
MLHC) "Deep Cox Mixtures for Survival Regression": a shared MLP
embedding feeding a softmax gate and K Cox experts, each with a
nonparametric Breslow baseline cumulative hazard smoothed via a scipy
spline. Training alternates hard-assignment E-steps with gradient-
descent M-steps on the per-component Cox partial likelihood.
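The objective described above can be sketched in plain Python. This is an illustrative sketch only, not the code in this PR: the function names are hypothetical, the hard assignments are assumed to be given (the E-step that produces them is not shown), and the real model would compute the same quantity on tensors so that gradients can flow in the M-step.

```python
import math


def log_sum_exp(values):
    """Numerically stable log(sum(exp(v))) over a non-empty list."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def cox_neg_partial_log_likelihood(log_risk, time, event):
    """Negative Cox partial log-likelihood (Breslow convention for ties).

    Each observed event i contributes log_risk[i] minus the log-sum-exp of
    log_risk over its risk set {j : time[j] >= time[i]}.
    """
    n = len(time)
    total = 0.0
    for i in range(n):
        if not event[i]:
            continue  # censored samples only appear inside risk sets
        risk_set = [log_risk[j] for j in range(n) if time[j] >= time[i]]
        total -= log_risk[i] - log_sum_exp(risk_set)
    return total


def hard_assigned_loss(log_risk, time, event, assignment, k):
    """Sum the per-component losses, each over its hard-assigned samples.

    log_risk[i][c] is sample i's log-hazard-ratio under component c;
    assignment[i] is the component chosen for sample i (assumed given).
    """
    loss = 0.0
    for c in range(k):
        idx = [i for i, z in enumerate(assignment) if z == c]
        if not idx:
            continue  # an empty component contributes nothing
        loss += cox_neg_partial_log_likelihood(
            [log_risk[i][c] for i in idx],
            [time[i] for i in idx],
            [event[i] for i in idx],
        )
    return loss
```

With k=1 every sample falls in the same component, so the loss reduces to the ordinary Cox partial likelihood.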
Files:
- pyhealth/models/deep_cox_mixtures.py (new model)
- pyhealth/models/__init__.py (export)
- docs/api/models/pyhealth.models.DeepCoxMixtures.rst (autodoc page)
- docs/api/models.rst (toctree entry)
- tests/core/test_deep_cox_mixtures.py (9 synthetic-data tests, <5s)
- examples/synthetic_survival_deep_cox_mixtures.py (k in {1,2,3}
ablation on 400/100 Weibull synthetic data; k=2 wins at C-index
0.6499 vs 0.6209 for k=1 and 0.6365 for k=3)
Paper: https://proceedings.mlr.press/v149/nagpal21a/nagpal21a.pdf
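For readers comparing the C-index figures in the file list, a minimal sketch of Harrell's concordance index follows. It assumes a higher risk score means a shorter expected survival time; the function name is hypothetical, and the example script in this PR may compute the metric differently.

```python
def concordance_index(time, event, risk):
    """Harrell's C-index over all comparable pairs.

    A pair (i, j) is comparable when sample i has an observed event and
    time[i] < time[j]. It is concordant when risk[i] > risk[j]; tied risk
    scores count as half-concordant.
    """
    n = len(time)
    concordant = 0.0
    comparable = 0
    for i in range(n):
        if not event[i]:
            continue  # a censored sample cannot anchor a comparable pair
        for j in range(n):
            if time[i] < time[j]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0
                elif risk[i] == risk[j]:
                    concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs; C-index is undefined")
    return concordant / comparable
```

A value of 0.5 corresponds to random ranking and 1.0 to perfect ranking.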
Contributors
Type of contribution
Model: DeepCoxMixtures, a new PyHealth model inheriting BaseModel.
Paper
📄 Deep Cox Mixtures for Survival Regression (Nagpal, Yadlowsky, Rostamzadeh, Heller; MLHC 2021)
High-level description
DCM extends Cox proportional-hazards regression by modelling the patient population as a mixture of K latent subgroups, each with its own Cox expert. A shared MLP embedding feeds a softmax gate (mixture weights) and a clamped expert head (per-component log-hazard-ratios). Each component's baseline cumulative hazard is a non-parametric Breslow estimator smoothed with a scipy.interpolate.UnivariateSpline. Training alternates hard-assignment E-steps with gradient-descent M-steps on the per-component Cox partial likelihood; Breslow baselines are refit on the full training set between epochs.
The model consumes a SampleDataset whose output_schema carries two keys, {time: regression, event: binary}, so no new processor is introduced. Prediction helpers predict_survival_curve, predict_latent_z, and predict_risk cover the standard survival-analysis evaluation surface.
File guide
- pyhealth/models/deep_cox_mixtures.py: core model.
- pyhealth/models/__init__.py: add the export.
- docs/api/models/pyhealth.models.DeepCoxMixtures.rst: autodoc page.
- docs/api/models.rst: toctree entry.
- tests/core/test_deep_cox_mixtures.py: 9 unit tests on synthetic data covering forward contract, gradient flow, gamma-clamp, survival-curve monotonicity, latent-z normalisation, the k=1 degenerate-Cox case, missing-label rejection, and the Breslow refit fallback. Full suite runs in ~4s.
- examples/synthetic_survival_deep_cox_mixtures.py: deterministic ablation over k in {1, 2, 3} on 400/100 Weibull-censored synthetic samples with two latent subgroups; reproduces the paper's hypothesis that k=2 beats k=1 (under-specified) and k=3 (over-parameterised):