Skip to content

Add TransEHR model for clinical EHR time series (CS598 DL4H)#1130

Open
pintard2 wants to merge 1 commit intosunlabuiuc:masterfrom
pintard2:transehr-model
Open

Add TransEHR model for clinical EHR time series (CS598 DL4H)#1130
pintard2 wants to merge 1 commit intosunlabuiuc:masterfrom
pintard2:transehr-model

Conversation

@pintard2
Copy link
Copy Markdown

@pintard2 pintard2 commented Apr 23, 2026

Summary

  • Student: Reshaun Pintard; NetID: pintard2; CS598 Deep Learning for Healthcare
  • Contribution type: Option 2: Model
  • Paper: Xu et al., "TransEHR: Self-Supervised Transformer for Clinical Time Series Data", PMLR 2023. https://proceedings.mlr.press/v209/xu23a.html

What this PR adds

TransEHR is a transformer encoder designed for longitudinal EHR data. Its key novelty over the existing PyHealth Transformer model is the use of nested_sequence inputs that preserve visit-level temporal structure (patient → visits → codes), rather than a flat list of codes.

Architecture:

  1. Clinical codes are embedded and mean-pooled within each visit to produce a single visit vector.
  2. Sinusoidal positional encoding is applied over the visit sequence.
  3. A standard TransformerEncoder attends over visits (not individual codes).
  4. Valid visit representations are mean-pooled to a single patient vector.
  5. A linear layer produces the final prediction.

Files

File Description
pyhealth/models/trans_ehr.py TransEHR model implementation
tests/test_trans_ehr.py 22 unit tests using synthetic data (no dataset download needed)
examples/mimic4_mortality_trans_ehr.py Ablation study: num_layers, embedding_dim, num_heads
docs/api/models/pyhealth.models.TransEHR.rst Sphinx documentation
pyhealth/models/__init__.py Added TransEHR export
docs/api/models.rst Added TransEHR to toctree

Tests

All 22 unit tests pass on synthetic data with no real dataset required:

pytest tests/test_trans_ehr.py -v

Implements TransEHR (Xu et al., PMLR 2023) as a PyHealth model
contribution for CS598 DL4H (pintard2). TransEHR uses nested_sequence
inputs to preserve visit-level temporal structure — each patient's
visit sequence is encoded with sinusoidal positional encoding and
processed by a transformer encoder over visits, not individual codes.

Files added:
- pyhealth/models/trans_ehr.py     — TransEHR model implementation
- tests/test_trans_ehr.py          — 22 unit tests (synthetic data)
- examples/mimic4_mortality_trans_ehr.py — Ablation study script
- docs/api/models/pyhealth.models.TransEHR.rst — Sphinx documentation

Files updated:
- pyhealth/models/__init__.py      — export TransEHR
- docs/api/models.rst              — add TransEHR to toctree

Paper: https://proceedings.mlr.press/v209/xu23a.html
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant