Skip to content

feat: Add ShiftLSTM model contribution with tests, docs, and examples#1032

Open
mamahhh wants to merge 3 commits intosunlabuiuc:masterfrom
mamahhh:project/shift-lstm
Open

feat: Add ShiftLSTM model contribution with tests, docs, and examples#1032
mamahhh wants to merge 3 commits intosunlabuiuc:masterfrom
mamahhh:project/shift-lstm

Conversation

@mamahhh
Copy link
Copy Markdown

@mamahhh mamahhh commented Apr 20, 2026

PR Draft: ShiftLSTM Reproduction in PyHealth

Contributor Information

  • Name: Sijia Ma
  • NetID / Email: sijiam2 / sijiam2@illinois.edu
  • Contribution Type: Model contribution

Original Paper

High-Level Description

This pull request adds ShiftLSTM to PyHealth as a new model contribution.
ShiftLSTM relaxes recurrent parameter sharing over time by dividing a
sequence into K temporal segments. Each segment uses its own LSTMCell,
while hidden state and cell state still propagate through the full sequence.

In this implementation:

  • num_segments = 1 acts as the shared-parameter baseline
  • num_segments > 1 relaxes parameter sharing over time

This PR also adds:

  • fast model tests using synthetic data only
  • a synthetic data generator based on Section 4.1 of the paper
  • an end-to-end ablation/example script
  • API documentation and model index updates

Why This PR Fits the Paper Reproduction

The original paper studies time-varying input-output relationships in
sequential prediction. This PR directly implements one of the paper's core
methods, shiftLSTM, and evaluates it with synthetic ablation experiments
aligned with the paper's synthetic setup.

File Guide

Core implementation

  • pyhealth/models/shift_lstm.py
    • Implements ShiftLSTMLayer and ShiftLSTM
  • pyhealth/models/__init__.py
    • Exports ShiftLSTM and ShiftLSTMLayer

Tests

  • tests/core/test_shift_lstm.py
    • Tests model initialization
    • Tests forward pass
    • Tests output shapes
    • Tests gradient computation
    • Tests synthetic generator behavior
    • Tests temporary directory save/cleanup

Example / Ablation

  • examples/synthetic/shift_lstm_synthetic_data.py
    • Synthetic data generator
  • examples/synthetic_sequence_classification_shift_lstm.py
    • End-to-end ablation script for K and delta

Documentation

  • docs/api/models/pyhealth.models.ShiftLSTM.rst
    • API docs page for the new model
  • docs/api/models.rst
    • Adds ShiftLSTM to the models index

Testing

Run the model tests with:

python -m unittest tests.core.test_shift_lstm -v

These tests use only synthetic or pseudo data and complete quickly.

Example Usage

Run the ablation/example script with:

python examples/synthetic_sequence_classification_shift_lstm.py \
  --num-samples 3000 \
  --seq-len 30 \
  --num-features 3 \
  --lookback 10 \
  --delta 0.2 \
  --embedding-dim 32 \
  --hidden-dim 32 \
  --batch-size 64 \
  --epochs 5 \
  --segments 1 2 4

Ablation Summary

We ran three kinds of synthetic experiments:

  1. Main comparison with K = 1, 2, 4
  2. K sweep with K = 1, 2, 4, 8
  3. Shift-strength sweep with delta = 0.0, 0.1, 0.2

Main results at delta = 0.2, averaged over 3 seeds:

Model K Test AUROC (mean ± std) Test AUPRC (mean ± std) Test Accuracy (mean ± std)
LSTM baseline 1 0.8870 ± 0.0185 0.8801 ± 0.0312 0.8111 ± 0.0236
ShiftLSTM 2 0.8951 ± 0.0308 0.8874 ± 0.0329 0.8244 ± 0.0323
ShiftLSTM 4 0.9194 ± 0.0313 0.9210 ± 0.0320 0.8415 ± 0.0461

These results support the paper's main qualitative claim that relaxing
parameter sharing can help when relationships vary over time.

Notes

  • This PR focuses on reproducing shiftLSTM, not mixLSTM
  • The experiments are synthetic and controlled
  • The clinical benchmark experiments from the paper were not fully reproduced in this PR

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant