Skip to content

Add Deep Cox Mixtures survival model#1055

Open
sirlittle wants to merge 1 commit intosunlabuiuc:masterfrom
sirlittle:feat/deep-cox-mixtures
Open

Add Deep Cox Mixtures survival model#1055
sirlittle wants to merge 1 commit intosunlabuiuc:masterfrom
sirlittle:feat/deep-cox-mixtures

Conversation

@sirlittle
Copy link
Copy Markdown

@sirlittle sirlittle commented Apr 21, 2026

Contributors

Type of contribution

ModelDeepCoxMixtures, a new PyHealth model inheriting BaseModel.

Paper

📄 Deep Cox Mixtures for Survival Regression (Nagpal, Yadlowsky, Rostamzadeh, Heller; MLHC 2021)

High-level description

DCM extends Cox proportional-hazards regression by modelling the patient population as a mixture of K latent subgroups, each with its own Cox expert. A shared MLP embedding feeds a softmax gate (mixture weights) and a clamped expert head (per-component log-hazard-ratios). Each component's baseline cumulative hazard is a non-parametric Breslow estimator smoothed with a scipy.interpolate.UnivariateSpline. Training alternates hard-assignment E-steps with gradient-descent M-steps on the per-component Cox partial likelihood; Breslow baselines are refit on the full training set between epochs.

The model consumes a SampleDataset whose output_schema carries two keys — {time: regression, event: binary} — so no new processor is introduced. Prediction helpers predict_survival_curve, predict_latent_z, and predict_risk cover the standard survival-analysis evaluation surface.

File guide

  • pyhealth/models/deep_cox_mixtures.py — core model.
  • pyhealth/models/__init__.py — add the export.
  • docs/api/models/pyhealth.models.DeepCoxMixtures.rst — autodoc page.
  • docs/api/models.rst — toctree entry.
  • tests/core/test_deep_cox_mixtures.py — 9 unit tests on synthetic data covering forward contract, gradient flow, gamma-clamp, survival-curve monotonicity, latent-z normalisation, the k=1 degenerate-Cox case, missing-label rejection, and the Breslow refit fallback. Full suite runs in ~4s.
  • examples/synthetic_survival_deep_cox_mixtures.py — deterministic ablation over k in {1, 2, 3} on 400/100 Weibull-censored synthetic samples with two latent subgroups; reproduces the paper's hypothesis that k=2 beats k=1 (under-specified) and k=3 (over-parameterised):
k C-index
1 0.6209
2 0.6499
3 0.6365

@sirlittle sirlittle marked this pull request as ready for review April 21, 2026 04:21
@sirlittle sirlittle force-pushed the feat/deep-cox-mixtures branch from cfd7ef0 to 99e6cad Compare April 21, 2026 04:23
Implements a simplified, PyHealth-native port of Nagpal et al. (2021,
MLHC) "Deep Cox Mixtures for Survival Regression": a shared MLP
embedding feeding a softmax gate and K Cox experts, each with a
nonparametric Breslow baseline cumulative hazard smoothed via a scipy
spline. Training alternates hard-assignment E-steps with gradient-
descent M-steps on the per-component Cox partial likelihood.

Files:
- pyhealth/models/deep_cox_mixtures.py (new model)
- pyhealth/models/__init__.py           (export)
- docs/api/models/pyhealth.models.DeepCoxMixtures.rst (autodoc page)
- docs/api/models.rst                   (toctree entry)
- tests/core/test_deep_cox_mixtures.py  (9 synthetic-data tests, <5s)
- examples/synthetic_survival_deep_cox_mixtures.py (k in {1,2,3}
    ablation on 400/100 Weibull synthetic data; k=2 wins at C-index
    0.6499 vs 0.6209 for k=1 and 0.6365 for k=3)

Paper: https://proceedings.mlr.press/v149/nagpal21a/nagpal21a.pdf
@sirlittle sirlittle force-pushed the feat/deep-cox-mixtures branch from 99e6cad to a646b80 Compare April 21, 2026 04:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant