Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified .gitignore
Binary file not shown.
5 changes: 4 additions & 1 deletion docs/api/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,11 @@ API Reference
-------------

.. toctree::
:maxdepth: 3
:maxdepth: 4

models/pyhealth.models.rnn
models/pyhealth.models.transformer
models/pyhealth.models.dila
models/pyhealth.models.BaseModel
models/pyhealth.models.LogisticRegression
models/pyhealth.models.MLP
Expand Down
17 changes: 17 additions & 0 deletions docs/api/models/pyhealth.models.dila.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
pyhealth.models.dila
====================

.. automodule:: pyhealth.models.dila

Overview
--------
This module implements the Dictionary Label Attention (DILA) model, utilizing a sparse
autoencoder to disentangle dense pre-trained language model embeddings into distinct,
interpretable dictionary features for extreme multi-label medical coding.

API Reference
-------------
.. autoclass:: pyhealth.models.DILA
:members:
:undoc-members:
:show-inheritance:
176 changes: 176 additions & 0 deletions examples/mimic3_icd9_dila.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
# Contributor: Nikhil Ajit
# NetID/Email: najit2@illinois.edu
# Paper Title: DILA: Dictionary Label Attention for Mechanistic Interpretability in High-dimensional Multi-label Medical Coding Prediction
# Paper Link: https://arxiv.org/abs/2409.10504
# Description: Implementation of the DILA model utilizing a sparse autoencoder
# and a globally interpretable dictionary projection matrix for medical coding.

"""Ablation Study: DILA (Dictionary Label Attention) on MIMIC-III.

Paper: DILA: Dictionary Label Attention for Mechanistic Interpretability
in High-dimensional Multi-label Medical Coding Prediction.

Experimental Setup:
This script performs an ablation study to quantify the trade-off between strict
interpretability (enforced by sparsity) and downstream predictive accuracy. We
evaluate the model's performance on the medical coding task by varying two key
hyperparameters:

1. Dictionary Size (m): Controls the total number of sparse features allowed.
- Values tested: [1000, 3000, 6088]
2. Sparsity Penalty (lambda_saenc): Controls the L1/L2 penalty threshold.
- Values tested: [1e-5, 1e-6]

Metrics Tracked:
- Micro F1
- Macro F1
- ROC-AUC

Actual Findings:
When trained on the restricted MIMIC-III demo dataset, modifying the dictionary
size (m) and sparsity penalty (lambda_saenc) yielded negligible differences in
performance. Micro F1 remained essentially flat at ~0.056 across all
configurations, and Macro F1 hovered around ~0.047. The ROC AUC metric
evaluated to NaN due to the extremely small test split lacking positive samples
for the vast majority of the 581 extracted ICD-9 classes. Running this on the
full MIMIC-III dataset is required to observe the true sparsity/accuracy tradeoff.
"""

import torch
import pandas as pd
from typing import List, Dict, Any
from pyhealth.datasets import MIMIC3Dataset
from pyhealth.datasets.splitter import split_by_patient
from pyhealth.datasets import get_dataloader
from pyhealth.models import DILA
from pyhealth.trainer import Trainer


class ICD9CodingTask:
    """Task definition that maps each hospital admission to its ICD-9 codes.

    Attributes:
        task_name (str): Identifier for the task.
        input_schema (Dict[str, str]): Schema definition for input features.
        output_schema (Dict[str, str]): Schema definition for output labels.
    """

    task_name = "ICD9_Coding_Task"
    input_schema = {"conditions": "sequence"}
    output_schema = {"label": "multilabel"}

    def pre_filter(self, global_event_df: pd.DataFrame) -> pd.DataFrame:
        """Hook applied to the global event dataframe before patient parsing.

        Args:
            global_event_df (pd.DataFrame): Raw event dataframe.

        Returns:
            pd.DataFrame: The dataframe, passed through unchanged.
        """
        return global_event_df

    def __call__(self, patient: Any) -> List[Dict[str, Any]]:
        """Builds one sample per admission that has at least one diagnosis.

        Args:
            patient (Any): Patient object exposing ``get_events``.

        Returns:
            List[Dict[str, Any]]: Parsed samples, one per qualifying admission.
        """
        parsed: List[Dict[str, Any]] = []

        for adm in patient.get_events(event_type="admissions"):
            # Collect the ICD-9 codes recorded for this admission only.
            codes = [
                ev.icd9_code
                for ev in patient.get_events(
                    event_type="diagnoses_icd",
                    filters=[("hadm_id", "==", adm.hadm_id)],
                )
            ]

            # Admissions without any diagnosis code yield no sample.
            if codes:
                parsed.append({
                    "visit_id": adm.hadm_id,
                    "patient_id": patient.patient_id,
                    "conditions": codes,
                    "label": codes,
                })

        return parsed


if __name__ == "__main__":
    print("Loading MIMIC-III Demo Dataset...")
    dataset = MIMIC3Dataset(
        root="./data/mimic-iii-clinical-database-demo-1.4/",
        tables=["DIAGNOSES_ICD", "ADMISSIONS", "PATIENTS"],
    )

    # Attach the coding task: one sample per admission with >= 1 diagnosis.
    dataset = dataset.set_task(ICD9CodingTask())

    # Patient-level split so no patient appears in more than one partition.
    train_dataset, val_dataset, test_dataset = split_by_patient(
        dataset, [0.8, 0.1, 0.1], seed=42
    )

    train_loader = get_dataloader(train_dataset, batch_size=8, shuffle=True)
    val_loader = get_dataloader(val_dataset, batch_size=8, shuffle=False)
    test_loader = get_dataloader(test_dataset, batch_size=8, shuffle=False)

    # Ablation grid: dictionary size (m) x sparsity penalty (lambda_saenc).
    ablation_grid = [
        (size, lam)
        for size in (1000, 3000, 6088)
        for lam in (1e-5, 1e-6)
    ]
    results_log = {}

    print("\nStarting DILA Ablation Study...")
    print("=" * 50)

    for size, lam in ablation_grid:
        config_name = f"DictSize_{size}_Penalty_{lam}"
        print(f"\nTraining configuration: {config_name}")

        model = DILA(
            dataset=dataset,
            feature_keys=["conditions"],
            label_key="label",
            mode="multilabel",
            embedding_dim=768,
            dictionary_size=size,
            sparsity_penalty=lam,
        )

        trainer = Trainer(
            model=model,
            metrics=["roc_auc_macro", "f1_macro", "f1_micro"],
        )

        trainer.train(
            train_dataloader=train_loader,
            val_dataloader=val_loader,
            epochs=3,
            optimizer_class=torch.optim.AdamW,
            optimizer_params={"lr": 5e-5},
            weight_decay=0.01,
        )

        print(f"Evaluating {config_name} on Test Set...")
        results_log[config_name] = trainer.evaluate(dataloader=test_loader)

    print("\n" + "=" * 50)
    print("ABLATION STUDY RESULTS SUMMARY")
    print("=" * 50)
    print(f"{'Configuration':<30} | {'Micro F1':<10} | {'Macro F1':<10} | {'ROC AUC (Macro)':<15}")
    print("-" * 75)

    for config, metrics in results_log.items():
        # Missing metrics (e.g. NaN-producing configs) default to 0.0.
        scores = [
            metrics.get('f1_micro', 0.0),
            metrics.get('f1_macro', 0.0),
            metrics.get('roc_auc_macro', 0.0),
        ]
        print(f"{config:<30} | {scores[0]:<10.4f} | {scores[1]:<10.4f} | {scores[2]:<15.4f}")
3 changes: 2 additions & 1 deletion pyhealth/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,5 @@
from .sdoh import SdohClassifier
from .medlink import MedLink
from .unified_embedding import UnifiedMultimodalEmbeddingModel, SinusoidalTimeEmbedding
from .califorest import CaliForest
from .califorest import CaliForest
from .dila import DILA
137 changes: 137 additions & 0 deletions pyhealth/models/dila.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Contributor: Nikhil Ajit
# NetID/Email: najit2@illinois.edu
# Paper Title: DILA: Dictionary Label Attention for Mechanistic Interpretability in High-dimensional Multi-label Medical Coding Prediction
# Paper Link: https://arxiv.org/abs/2409.10504
# Description: Implementation of the DILA model utilizing a sparse autoencoder
# and a globally interpretable dictionary projection matrix for medical coding.

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Dict, Tuple, Optional, Any
from pyhealth.models.base_model import BaseModel
from pyhealth.datasets import SampleEHRDataset


class DILA(BaseModel):
    """Dictionary Label Attention (DILA) Model for Medical Coding.

    This model implements the Dictionary Label Attention mechanism to predict
    medical codes from clinical sequences. It uses a sparse autoencoder to
    disentangle dense embeddings into sparse, interpretable dictionary features,
    which are then projected to the label space.

    Attributes:
        feature_keys (List[str]): Keys to access input features in the dataset.
        label_key (str): Key to access the ground truth labels.
        mode (str): Mode of the task, e.g., "multilabel".
        embedding_dim (int): Dimension of the token embeddings.
        dictionary_size (int): Number of sparse dictionary features (m).
        sparsity_penalty (float): Penalty weight for the L1/L2 regularization.
        simulated_plm (nn.Embedding): Embedding layer simulating PLM token features.
        encoder_weight (nn.Linear): Linear projection for the sparse encoder.
        decoder_weight (nn.Linear): Linear projection for the sparse decoder.
        decoder_bias (nn.Parameter): Bias term for the sparse autoencoder.
        sparse_projection (nn.Parameter): Globally interpretable projection matrix.
        fc (nn.Linear): Final linear decision layer for predictions.

    Example:
        >>> from pyhealth.models import DILA
        >>> model = DILA(
        ...     dataset=dataset,
        ...     feature_keys=["conditions"],
        ...     label_key="label",
        ...     mode="multilabel"
        ... )
        >>> # inputs must be integer token ids (< 5000) for the embedding
        >>> # lookup; a float tensor such as torch.randn(...) would raise.
        >>> tokens = torch.randint(0, 5000, (4, 128))
        >>> labels = torch.empty(4, 50).random_(2)
        >>> outputs = model(conditions=tokens, label=labels)
        >>> loss = outputs["loss"]
    """

    def __init__(
        self,
        dataset: SampleEHRDataset,
        feature_keys: Optional[List[str]] = None,
        label_key: Optional[str] = None,
        mode: Optional[str] = None,
        embedding_dim: int = 768,
        dictionary_size: int = 6088,
        sparsity_penalty: float = 1e-6,
        **kwargs: Any
    ) -> None:
        """Initializes the DILA model.

        Args:
            dataset (SampleEHRDataset): Dataset used to size the label space.
            feature_keys (Optional[List[str]]): Input feature keys; defaults
                to ["conditions"].
            label_key (Optional[str]): Label key; defaults to "label".
            mode (Optional[str]): Task mode; defaults to "multilabel".
            embedding_dim (int): Dimension of the token embeddings.
            dictionary_size (int): Number of sparse dictionary features (m).
            sparsity_penalty (float): Weight for the L1/L2 sparsity loss.
            **kwargs: Forwarded to ``BaseModel``.
        """
        super().__init__(dataset=dataset, **kwargs)

        self.feature_keys = feature_keys or ["conditions"]
        self.label_key = label_key or "label"
        self.mode = mode or "multilabel"

        self.embedding_dim = embedding_dim
        self.dictionary_size = dictionary_size
        self.sparsity_penalty = sparsity_penalty

        # Stand-in for a pre-trained language model: maps integer token ids
        # (must be < 5000; id 0 is padding) to dense embeddings.
        self.simulated_plm = nn.Embedding(5000, embedding_dim, padding_idx=0)

        # Sparse autoencoder: encoder expands dense embeddings to the
        # dictionary space; decoder reconstructs them.
        self.encoder_weight = nn.Linear(embedding_dim, dictionary_size)
        self.decoder_weight = nn.Linear(dictionary_size, embedding_dim)
        self.decoder_bias = nn.Parameter(torch.zeros(embedding_dim))

        # NOTE(review): assumes BaseModel.get_output_size() returns the
        # number of labels for this dataset/task — confirm against BaseModel.
        num_labels = self.get_output_size()

        # Globally interpretable dictionary-feature -> label projection.
        self.sparse_projection = nn.Parameter(torch.randn(dictionary_size, num_labels))

        self.fc = nn.Linear(embedding_dim, num_labels)

    def sparse_autoencoder(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Disentangles dense embeddings into sparse features.

        Args:
            x (torch.Tensor): Dense token embeddings of shape (batch, seq_len, dim).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: A tuple containing the sparse
            dictionary features (batch, seq_len, dictionary_size) and the
            reconstructed dense embeddings (same shape as ``x``).
        """
        # Center with the decoder bias before encoding; ReLU enforces
        # non-negative (sparse) dictionary activations.
        x_bar = x - self.decoder_bias
        f = torch.relu(self.encoder_weight(x_bar))
        x_hat = self.decoder_weight(f) + self.decoder_bias
        return f, x_hat

    def forward(self, **kwargs: Any) -> Dict[str, torch.Tensor]:
        """Forward pass for the DILA model.

        Args:
            **kwargs: Keyword arguments containing the input features (integer
                token ids under ``feature_keys[0]``) and labels
                (multi-hot floats under ``label_key``).

        Returns:
            Dict[str, torch.Tensor]: A dictionary containing the total loss,
            predicted probabilities, and true labels.
        """
        input_sequence = kwargs[self.feature_keys[0]]

        # Collapse a trailing singleton dim, e.g. (batch, seq, 1) -> (batch, seq).
        # NOTE(review): squeeze(-1) is a no-op when the last dim is > 1, so a
        # 3-D float-embedding input would still fail the lookup below.
        if input_sequence.dim() == 3:
            input_sequence = input_sequence.squeeze(-1)

        x = self.simulated_plm(input_sequence)
        f, x_hat = self.sparse_autoencoder(x)

        # Label-wise attention over the sequence, scored through the
        # interpretable dictionary projection; softmax over the seq axis.
        attention_scores = torch.matmul(f, self.sparse_projection)
        a_laat = torch.softmax(attention_scores, dim=1)

        # (batch, labels, seq) @ (batch, seq, dim) -> (batch, labels, dim),
        # then mean-pool over labels before the final decision layer.
        x_att = torch.matmul(a_laat.transpose(-2, -1), x)
        x_att_pooled = x_att.mean(dim=1)
        logits = self.fc(x_att_pooled)

        # Sparse-autoencoder objective: reconstruction + weighted L1/L2
        # sparsity penalty on the dictionary activations.
        mse_loss = F.mse_loss(x_hat, x)
        l1_loss = torch.norm(f, p=1)
        l2_loss = torch.norm(f, p=2) ** 2
        sae_loss = mse_loss + self.sparsity_penalty * (l1_loss + l2_loss)

        y_true = kwargs[self.label_key].float()
        bce_loss = F.binary_cross_entropy_with_logits(logits, y_true)

        total_loss = bce_loss + sae_loss

        return {
            "loss": total_loss,
            "y_prob": torch.sigmoid(logits),
            "y_true": y_true
        }
Loading