
Add MixLSTM model for temporally shifting clinical time-series (Oh et al. 2020)#1127

Open
amanluth03 wants to merge 19 commits into sunlabuiuc:master from hseddis321:master

Conversation

amanluth03 commented Apr 23, 2026

Contributors

  • Aman Luthra — NetID: aluth3 (aluth3@illinois.edu)
  • Tanmay Mittal — NetID: tmitta3 (tmitta3@illinois.edu)
  • Siddesh Vijayakumar — NetID: siddesh2 (siddesh2@illinois.edu)

Type of Contribution

Model (Option 2)

Paper

Jeeheh Oh, Jiaxuan Wang, Shengpu Tang, Michael Sjoding, Jenna Wiens.
"Relaxed Parameter Sharing: Effectively Modeling Time-Varying Relationships in Clinical Time-Series." MLHC 2019.
https://arxiv.org/abs/1906.02898

Description

This PR implements the MixLSTM architecture from Oh et al. (2019), which addresses
temporal conditional shift in clinical time-series prediction.

Standard LSTMs share the same parameters across all time steps, which makes it
difficult to capture relationships between inputs and outcomes that change over
the course of a patient's hospital stay. MixLSTM relaxes this constraint by
maintaining K independent LSTM cells whose parameters are dynamically combined
at each time step via learned mixing coefficients constrained to the simplex.
This allows the model to smoothly transition between different temporal dynamics
without enforcing hard segment boundaries.
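The mixing mechanism described above can be sketched in a few lines of plain Python. This is an illustrative toy, not the PR's actual implementation: `experts`, `mix_logits`, and the per-timestep blend are stand-ins for the K LSTM cells' parameters and the learned mixing coefficients.

```python
import math
import random

random.seed(0)

def softmax(logits):
    """Map arbitrary logits onto the probability simplex."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

# Toy stand-ins: K expert parameter vectors, one mixing-logit vector per step.
K, P, T = 2, 4, 5          # experts, parameters per cell, timesteps
experts = [[random.gauss(0, 1) for _ in range(P)] for _ in range(K)]
mix_logits = [[random.gauss(0, 1) for _ in range(K)] for _ in range(T)]

mixed = []
for t in range(T):
    coeffs = softmax(mix_logits[t])    # on the simplex: all >= 0, sum to 1
    # Effective parameters at step t: convex combination of the K experts.
    theta_t = [sum(c * experts[k][p] for k, c in enumerate(coeffs))
               for p in range(P)]
    mixed.append(theta_t)
```

Because the coefficients vary smoothly with `t` rather than switching discretely, the effective parameters drift gradually, which is the "no hard segment boundaries" property the paper emphasizes.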

The implementation inherits from PyHealth's BaseModel and dynamically infers
input dimensions and sequence length from any SampleDataset passed to it, so
it can plug into existing PyHealth tasks without modification. It supports both
per-timestep regression (output shape (B, T, 1)) and classification (output
shape (B, num_classes)), auto-detected from the dataset's output schema.
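As a rough illustration of the two output contracts, here is a minimal sketch of schema-based head selection. The function and schema key names (`infer_head`, `"label"`) are hypothetical, chosen for illustration; they are not PyHealth's actual API.

```python
# Hypothetical sketch of auto-detecting the output head from a dataset's
# output schema; names here are illustrative, not PyHealth's real API.
def infer_head(output_schema: dict) -> str:
    """Choose regression vs. classification from the output schema."""
    kind = output_schema.get("label")
    return "classification" if kind in ("binary", "multiclass") else "regression"

def output_shape(head: str, batch: int, timesteps: int, num_classes: int = 2):
    """Shape contract for each head, matching the PR description."""
    if head == "regression":
        return (batch, timesteps, 1)   # per-timestep prediction: (B, T, 1)
    return (batch, num_classes)        # sequence-level prediction: (B, C)
```

For example, a 30-step regression batch of size 8 yields `(8, 30, 1)`, while a multiclass batch yields `(8, num_classes)`.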

File Guide

  • pyhealth/models/mixlstm.py — MixLSTM model implementation
  • pyhealth/models/__init__.py — register MixLSTM for import
  • tests/mixlstm_test.py — unit tests using synthetic data
  • examples/mimic3_synthetic_mixlstm.py — ablation study example
  • docs/api/models/pyhealth.models.MixLSTM.rst — API documentation
  • docs/api/models.rst — add MixLSTM to the models index
  • pyproject.toml — project dependencies

Ablation Study Summary

We ran two ablations on a synthetic non-stationary time-series regression task
(1,000 sequences per split, T=30 timesteps, input_dim=3, 90% sparse inputs,
lookback l=10, drift δ=0.05 per step). Each configuration is evaluated via
random search (20 runs × 30 epochs) over MixLSTM with K=2 experts and hidden
size sampled from {100, 150, 300, 500, 700, 900, 1100}.
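The random-search setup above can be sketched as follows; the variable names are illustrative and this is not the example script's actual code, just the search space it describes.

```python
import random

random.seed(0)

# Search space from the ablation description (hypothetical sketch).
HIDDEN_SIZES = [100, 150, 300, 500, 700, 900, 1100]
LEARNING_RATES = [0.0001, 0.0005, 0.001, 0.005, 0.01]
N_RUNS, N_EPOCHS, K_EXPERTS = 20, 30, 2

configs = [
    {
        "k": K_EXPERTS,
        "hidden": random.choice(HIDDEN_SIZES),
        "lr": random.choice(LEARNING_RATES),
        "epochs": N_EPOCHS,
    }
    for _ in range(N_RUNS)
]
# Each sampled config would then be trained for N_EPOCHS and scored on
# validation MSE, keeping the best run per configuration.
```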

1. Learning rate sweep (Adam): lr ∈ {0.0001, 0.0005, 0.001, 0.005, 0.01}

  • lr=0.0001 was the worst performer across all hidden sizes.
  • lr=0.001 (the paper's choice) was strongest for hidden sizes ≥ 500.
  • lr=0.005 performed best at smaller hidden sizes (100–500).
  • lr=0.01 was unstable and produced erratic loss curves.
  • Best configuration: Adam, lr=0.001, hidden=1100 → Val MSE 0.43, Test MSE 0.46.
2. Optimizer comparison (Adam vs. SGD at lr=0.001):

| Optimizer | Best Val Loss | Best Test Loss |
| --------- | ------------- | -------------- |
| Adam      | 0.43          | 0.47           |
| SGD       | 16.39         | 16.41          |

Adam dramatically outperforms SGD on this task — SGD fails to converge within
the 30-epoch budget. This is consistent with the paper's use of Adam as the
default optimizer.

Outputs: The script produces six .png visualizations: loss vs. hidden
size, predictions vs. ground truth on held-out test sequences, and heatmaps of
the synthetic task's time-varying weight distribution — for both the LR sweep
and the optimizer comparison. Full runtime is approximately 30–45 minutes on
CPU, faster on GPU.
