Add MixLSTM model for temporally shifting clinical time-series (Oh et al. 2019) #1127
Open
amanluth03 wants to merge 19 commits into sunlabuiuc:master from
Conversation
Contributors
aluth3 — aluth3@illinois.edu
tmitta3 — tmitta3@illinois.edu
siddesh2 — siddesh2@illinois.edu
Type of Contribution
Model (Option 2)
Paper
Jeeheh Oh, Jiaxuan Wang, Shengpu Tang, Michael Sjoding, Jenna Wiens.
"Relaxed Parameter Sharing: Effectively Modeling Time-Varying Relationships in Clinical Time-Series." MLHC 2019.
https://arxiv.org/abs/1906.02898
Description
This PR implements the MixLSTM architecture from Oh et al. (2019), which addresses temporal conditional shift in clinical time-series prediction.
Standard LSTMs share the same parameters across all time steps, which makes it
difficult to capture relationships between inputs and outcomes that change over
the course of a patient's hospital stay. MixLSTM relaxes this constraint by
maintaining K independent LSTM cells whose parameters are dynamically combined at each time step via learned mixing coefficients constrained to the simplex.
This allows the model to smoothly transition between different temporal dynamics
without enforcing hard segment boundaries.
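The per-timestep parameter mixing described above can be sketched as follows. This is an illustrative reimplementation, not the PR's actual code: the class name, initialization scale, and the choice of a learned per-timestep logit table for the mixing coefficients are all assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MixLSTMCellSketch(nn.Module):
    """Sketch of relaxed parameter sharing: K independent LSTM parameter
    sets, mixed at each time step by simplex-constrained coefficients."""

    def __init__(self, input_dim: int, hidden_dim: int, K: int, T: int):
        super().__init__()
        self.hidden_dim = hidden_dim
        # K parameter sets for the 4 stacked LSTM gates (i, f, g, o).
        self.W_ih = nn.Parameter(torch.randn(K, 4 * hidden_dim, input_dim) * 0.1)
        self.W_hh = nn.Parameter(torch.randn(K, 4 * hidden_dim, hidden_dim) * 0.1)
        self.bias = nn.Parameter(torch.zeros(K, 4 * hidden_dim))
        # Unconstrained logits per time step; softmax keeps the mixing
        # coefficients on the simplex (non-negative, summing to 1).
        self.mix_logits = nn.Parameter(torch.zeros(T, K))

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (B, T, input_dim)
        B, T, _ = x.shape
        h = x.new_zeros(B, self.hidden_dim)
        c = x.new_zeros(B, self.hidden_dim)
        outs = []
        for t in range(T):
            lam = F.softmax(self.mix_logits[t], dim=-1)        # (K,)
            # Convex combination of the K parameter sets for this step.
            W_ih = torch.einsum("k,kij->ij", lam, self.W_ih)
            W_hh = torch.einsum("k,kij->ij", lam, self.W_hh)
            b = torch.einsum("k,kj->j", lam, self.bias)
            gates = x[:, t] @ W_ih.T + h @ W_hh.T + b
            i, f, g, o = gates.chunk(4, dim=-1)
            c = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
            h = torch.sigmoid(o) * torch.tanh(c)
            outs.append(h)
        return torch.stack(outs, dim=1)                        # (B, T, hidden_dim)
```

Because the coefficients are produced by a softmax, adjacent time steps can blend the K cells smoothly rather than switching between them at hard segment boundaries.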
The implementation inherits from PyHealth's BaseModel and dynamically infers input dimensions and sequence length from any SampleDataset passed to it, so it can plug into existing PyHealth tasks without modification. It supports both per-timestep regression (output shape (B, T, 1)) and classification (output shape (B, num_classes)), auto-detected from the dataset's output schema.
File Guide
pyhealth/models/mixlstm.py — MixLSTM model implementation
pyhealth/models/__init__.py — register MixLSTM for import
tests/mixlstm_test.py — unit tests using synthetic data
examples/mimic3_synthetic_mixlstm.py — ablation study example
docs/api/models/pyhealth.models.MixLSTM.rst — API documentation
docs/api/models.rst — add MixLSTM to the models index
pyproject.toml — project dependencies
Ablation Study Summary
We ran two ablations on a synthetic non-stationary time-series regression task
(1,000 sequences per split, T=30 timesteps, input_dim=3, 90% sparse inputs,
lookback l=10, drift δ=0.05 per step). Each configuration is evaluated via
random search (20 runs × 30 epochs) over MixLSTM with K=2 experts and hidden
size sampled from
{100, 150, 300, 500, 700, 900, 1100}.
1. Learning rate sweep (Adam): lr ∈ {0.0001, 0.0005, 0.001, 0.005, 0.01}
lr=0.0001 was the worst performer across all hidden sizes.
lr=0.001 (the paper's choice) was strongest for hidden sizes ≥ 500.
lr=0.005 performed best at smaller hidden sizes (100–500).
lr=0.01 was unstable and produced erratic loss curves.
2. Optimizer comparison (Adam vs. SGD at lr=0.001):
Adam dramatically outperforms SGD on this task — SGD fails to converge within
the 30-epoch budget. This is consistent with the paper's use of Adam as the
default optimizer.
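The synthetic non-stationary task described above (sparse inputs, weights drifting per step) could be generated along these lines. This is a hedged sketch, not the PR's actual script (which lives in examples/mimic3_synthetic_mixlstm.py); the function name and the random-walk drift model are assumptions, and the paper's lookback-window structure (l=10) is omitted for brevity.

```python
import numpy as np

def make_drifting_task(n_seq=1000, T=30, input_dim=3, sparsity=0.9,
                       drift=0.05, seed=0):
    """Toy non-stationary regression task: the true input->target weights
    drift a little each step, so the input/outcome relationship changes
    over time -- the temporal conditional shift MixLSTM is meant to model."""
    rng = np.random.default_rng(seed)
    # Sparse inputs: roughly `sparsity` of the entries are zeroed out.
    X = rng.normal(size=(n_seq, T, input_dim))
    X *= rng.random((n_seq, T, input_dim)) > sparsity
    # Time-varying true weights: a random walk with step scale `drift`.
    w = np.cumsum(rng.normal(scale=drift, size=(T, input_dim)), axis=0) + 1.0
    # Per-timestep regression targets, shape (n_seq, T, 1).
    y = np.einsum("bti,ti->bt", X, w)[..., None]
    return X.astype(np.float32), y.astype(np.float32), w
```

A time-shared LSTM has to fit a single averaged weight vector for this task, while a mixing model can track the drift, which is what makes it a useful ablation target.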
Outputs: The script produces six .png visualizations: loss vs. hidden size, predictions vs. ground truth on held-out test sequences, and heatmaps of
the synthetic task's time-varying weight distribution — for both the LR sweep
and the optimizer comparison. Full runtime is approximately 30–45 minutes on
CPU, faster on GPU.
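The random search described above (20 runs × 30 epochs over hidden size and learning rate) can be sketched as a simple loop. This is illustrative only: `train_and_eval` is a hypothetical callable standing in for the example script's training code, assumed to return a final validation loss.

```python
import random

def random_search(train_and_eval, n_runs=20, n_epochs=30, seed=0):
    """Sample a (hidden size, learning rate) configuration per run, train
    for a fixed epoch budget, and keep the best validation loss."""
    rng = random.Random(seed)
    hidden_sizes = [100, 150, 300, 500, 700, 900, 1100]
    lrs = [0.0001, 0.0005, 0.001, 0.005, 0.01]
    best = {"val_loss": float("inf")}
    for _ in range(n_runs):
        cfg = {"hidden": rng.choice(hidden_sizes), "lr": rng.choice(lrs)}
        # Hypothetical hook: trains MixLSTM with `cfg` for `n_epochs`
        # epochs and returns the final validation loss.
        val_loss = train_and_eval(cfg, n_epochs)
        if val_loss < best["val_loss"]:
            best = {"val_loss": val_loss, **cfg}
    return best
```

Random search over this small grid keeps the per-ablation budget fixed (20 runs × 30 epochs) regardless of how many hidden sizes or learning rates are in play.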