diff --git a/.gitignore b/.gitignore index 9993737db..51e8a5610 100644 Binary files a/.gitignore and b/.gitignore differ diff --git a/docs/api/models.rst b/docs/api/models.rst index 7c3ac7c4b..53f81250f 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -167,8 +167,11 @@ API Reference ------------- .. toctree:: - :maxdepth: 3 + :maxdepth: 4 + models/pyhealth.models.rnn + models/pyhealth.models.transformer + models/pyhealth.models.dila models/pyhealth.models.BaseModel models/pyhealth.models.LogisticRegression models/pyhealth.models.MLP diff --git a/docs/api/models/pyhealth.models.dila.rst b/docs/api/models/pyhealth.models.dila.rst new file mode 100644 index 000000000..8137bc0ab --- /dev/null +++ b/docs/api/models/pyhealth.models.dila.rst @@ -0,0 +1,17 @@ +pyhealth.models.dila +==================== + +.. automodule:: pyhealth.models.dila + +Overview +-------- +This module implements the Dictionary Label Attention (DILA) model, utilizing a sparse +autoencoder to disentangle dense pre-trained language model embeddings into distinct, +interpretable dictionary features for extreme multi-label medical coding. + +API Reference +------------- +.. autoclass:: pyhealth.models.DILA + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/examples/mimic3_icd9_dila.py b/examples/mimic3_icd9_dila.py new file mode 100644 index 000000000..c80e2c137 --- /dev/null +++ b/examples/mimic3_icd9_dila.py @@ -0,0 +1,176 @@ +# Contributor: Nikhil Ajit +# NetID/Email: najit2@illinois.edu +# Paper Title: DILA: Dictionary Label Attention for Mechanistic Interpretability in High-dimensional Multi-label Medical Coding Prediction +# Paper Link: https://arxiv.org/abs/2409.10504 +# Description: Implementation of the DILA model utilizing a sparse autoencoder +# and a globally interpretable dictionary projection matrix for medical coding. + +"""Ablation Study: DILA (Dictionary Label Attention) on MIMIC-III. 
# Ablation study: DILA (Dictionary Label Attention) on the MIMIC-III demo set.
#
# Sweeps two hyperparameters and reports Micro F1 / Macro F1 / ROC-AUC:
#   * dictionary size m in [1000, 3000, 6088]
#   * sparsity penalty lambda_saenc in [1e-5, 1e-6]
#
# NOTE(review): on the restricted demo dataset the metrics are essentially
# flat (Micro F1 ~0.056, Macro F1 ~0.047, ROC AUC NaN — the tiny test split
# lacks positive samples for most of the 581 extracted ICD-9 classes); the
# full MIMIC-III dataset is required to observe a real sparsity/accuracy
# tradeoff.
# NOTE(review): the accompanying patch also registers the model via
# `from .dila import DILA` in pyhealth/models/__init__.py.

import torch
import pandas as pd
from typing import List, Dict, Any
from pyhealth.datasets import MIMIC3Dataset
from pyhealth.datasets.splitter import split_by_patient
from pyhealth.datasets import get_dataloader
from pyhealth.models import DILA
from pyhealth.trainer import Trainer


class ICD9CodingTask:
    """Task definition: extract ICD-9 diagnosis codes per hospital admission.

    Attributes:
        task_name (str): Identifier for the task.
        input_schema (Dict[str, str]): Schema definition for input features.
        output_schema (Dict[str, str]): Schema definition for output labels.
    """

    task_name = "ICD9_Coding_Task"
    input_schema = {"conditions": "sequence"}
    output_schema = {"label": "multilabel"}

    def pre_filter(self, global_event_df: pd.DataFrame) -> pd.DataFrame:
        """Hook applied to the global event dataframe; identity here.

        Args:
            global_event_df (pd.DataFrame): Raw event dataframe.

        Returns:
            pd.DataFrame: The dataframe, unmodified.
        """
        return global_event_df

    def __call__(self, patient: Any) -> List[Dict[str, Any]]:
        """Builds one sample per admission that has at least one diagnosis.

        Args:
            patient (Any): Patient object exposing historical hospital visits.

        Returns:
            List[Dict[str, Any]]: Parsed samples ready for processing.
        """
        parsed: List[Dict[str, Any]] = []
        for visit in patient.get_events(event_type="admissions"):
            codes = [
                ev.icd9_code
                for ev in patient.get_events(
                    event_type="diagnoses_icd",
                    filters=[("hadm_id", "==", visit.hadm_id)],
                )
            ]
            # Skip admissions with no coded diagnoses.
            if not codes:
                continue
            # NOTE(review): "conditions" and "label" are the same code list,
            # so the input leaks the target; presumably intentional for this
            # demo ablation — confirm before drawing accuracy conclusions.
            parsed.append(
                {
                    "visit_id": visit.hadm_id,
                    "patient_id": patient.patient_id,
                    "conditions": codes,
                    "label": codes,
                }
            )
        return parsed


if __name__ == "__main__":
    print("Loading MIMIC-III Demo Dataset...")
    dataset = MIMIC3Dataset(
        root="./data/mimic-iii-clinical-database-demo-1.4/",
        tables=["DIAGNOSES_ICD", "ADMISSIONS", "PATIENTS"],
    )
    dataset = dataset.set_task(ICD9CodingTask())

    # Patient-level split avoids leaking one patient across partitions.
    train_ds, val_ds, test_ds = split_by_patient(dataset, [0.8, 0.1, 0.1], seed=42)

    train_loader = get_dataloader(train_ds, batch_size=8, shuffle=True)
    val_loader = get_dataloader(val_ds, batch_size=8, shuffle=False)
    test_loader = get_dataloader(test_ds, batch_size=8, shuffle=False)

    results_log: Dict[str, Dict[str, Any]] = {}

    print("\nStarting DILA Ablation Study...")
    print("=" * 50)

    # Grid: dictionary size m (outer) x sparsity penalty lambda (inner).
    for size in [1000, 3000, 6088]:
        for lam in [1e-5, 1e-6]:
            config_name = f"DictSize_{size}_Penalty_{lam}"
            print(f"\nTraining configuration: {config_name}")

            model = DILA(
                dataset=dataset,
                feature_keys=["conditions"],
                label_key="label",
                mode="multilabel",
                embedding_dim=768,
                dictionary_size=size,
                sparsity_penalty=lam,
            )

            trainer = Trainer(
                model=model,
                metrics=["roc_auc_macro", "f1_macro", "f1_micro"],
            )

            trainer.train(
                train_dataloader=train_loader,
                val_dataloader=val_loader,
                epochs=3,
                optimizer_class=torch.optim.AdamW,
                optimizer_params={"lr": 5e-5},
                weight_decay=0.01,
            )

            print(f"Evaluating {config_name} on Test Set...")
            results_log[config_name] = trainer.evaluate(dataloader=test_loader)

    print("\n" + "=" * 50)
    print("ABLATION STUDY RESULTS SUMMARY")
    print("=" * 50)
    print(f"{'Configuration':<30} | {'Micro F1':<10} | {'Macro F1':<10} | {'ROC AUC (Macro)':<15}")
    print("-" * 75)

    for config, metrics in results_log.items():
        print(
            f"{config:<30} | {metrics.get('f1_micro', 0.0):<10.4f} | "
            f"{metrics.get('f1_macro', 0.0):<10.4f} | "
            f"{metrics.get('roc_auc_macro', 0.0):<15.4f}"
        )
# Contributor: Nikhil Ajit
# NetID/Email: najit2@illinois.edu
# Paper Title: DILA: Dictionary Label Attention for Mechanistic Interpretability in High-dimensional Multi-label Medical Coding Prediction
# Paper Link: https://arxiv.org/abs/2409.10504
# Description: Implementation of the DILA model utilizing a sparse autoencoder
# and a globally interpretable dictionary projection matrix for medical coding.

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Dict, Tuple, Optional, Any
from pyhealth.models.base_model import BaseModel
from pyhealth.datasets import SampleEHRDataset


class DILA(BaseModel):
    """Dictionary Label Attention (DILA) Model for Medical Coding.

    A sparse autoencoder disentangles dense token embeddings into sparse,
    interpretable dictionary features; a globally shared projection matrix
    maps those features to per-label attention scores, and attention-pooled
    embeddings feed a final linear decision layer.

    Attributes:
        feature_keys (List[str]): Keys to access input features in the dataset.
        label_key (str): Key to access the ground truth labels.
        mode (str): Mode of the task, e.g., "multilabel".
        embedding_dim (int): Dimension of the token embeddings.
        dictionary_size (int): Number of sparse dictionary features (m).
        sparsity_penalty (float): Penalty weight for the L1/L2 regularization.
        simulated_plm (nn.Embedding): Embedding layer simulating PLM token features.
        encoder_weight (nn.Linear): Linear projection for the sparse encoder.
        decoder_weight (nn.Linear): Linear projection for the sparse decoder.
        decoder_bias (nn.Parameter): Bias term for the sparse autoencoder.
        sparse_projection (nn.Parameter): Globally interpretable projection matrix.
        fc (nn.Linear): Final linear decision layer for predictions.

    Example:
        >>> from pyhealth.models import DILA
        >>> model = DILA(
        ...     dataset=dataset,
        ...     feature_keys=["conditions"],
        ...     label_key="label",
        ...     mode="multilabel"
        ... )
        >>> # Inputs are integer token ids: simulated_plm is an nn.Embedding
        >>> # lookup, so float embeddings would raise a type error.
        >>> outputs = model(
        ...     conditions=torch.randint(0, 5000, (4, 128)),
        ...     label=torch.empty(4, 50).random_(2),
        ... )
        >>> loss = outputs["loss"]
    """

    def __init__(
        self,
        dataset: SampleEHRDataset,
        feature_keys: Optional[List[str]] = None,
        label_key: Optional[str] = None,
        mode: Optional[str] = None,
        embedding_dim: int = 768,
        dictionary_size: int = 6088,
        sparsity_penalty: float = 1e-6,
        vocab_size: int = 5000,
        **kwargs: Any
    ) -> None:
        """Initializes the DILA model.

        Args:
            dataset (SampleEHRDataset): Dataset the model is built against;
                determines the label space via ``get_output_size``.
            feature_keys (Optional[List[str]]): Input feature keys.
                Defaults to ``["conditions"]``.
            label_key (Optional[str]): Ground-truth label key.
                Defaults to ``"label"``.
            mode (Optional[str]): Task mode. Defaults to ``"multilabel"``.
            embedding_dim (int): Dimension of the token embeddings.
            dictionary_size (int): Number of sparse dictionary features (m).
            sparsity_penalty (float): Weight of the L1/L2 sparsity regularizer.
            vocab_size (int): Vocabulary size of the simulated PLM embedding
                (index 0 is padding). Defaults to 5000, the previously
                hard-coded value, so existing callers are unaffected.
            **kwargs: Forwarded unchanged to ``BaseModel``.
        """
        super(DILA, self).__init__(dataset=dataset, **kwargs)

        self.feature_keys = feature_keys or ["conditions"]
        self.label_key = label_key or "label"
        self.mode = mode or "multilabel"

        self.embedding_dim = embedding_dim
        self.dictionary_size = dictionary_size
        self.sparsity_penalty = sparsity_penalty

        # Stand-in for a pre-trained language model: a plain embedding lookup.
        self.simulated_plm = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)

        # Sparse autoencoder: ReLU encoder to m features, linear decoder,
        # plus a shared pre-encoder/post-decoder bias.
        self.encoder_weight = nn.Linear(embedding_dim, dictionary_size)
        self.decoder_weight = nn.Linear(dictionary_size, embedding_dim)
        self.decoder_bias = nn.Parameter(torch.zeros(embedding_dim))

        num_labels = self.get_output_size()
        # Globally interpretable dictionary-feature -> label projection.
        self.sparse_projection = nn.Parameter(torch.randn(dictionary_size, num_labels))

        self.fc = nn.Linear(embedding_dim, num_labels)

    def sparse_autoencoder(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Disentangles dense embeddings into sparse features.

        Args:
            x (torch.Tensor): Dense token embeddings of shape (batch, seq_len, dim).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The sparse dictionary features
            ``f`` of shape (batch, seq_len, dictionary_size) and the
            reconstructed dense embeddings ``x_hat`` (same shape as ``x``).
        """
        # Subtract the decoder bias before encoding and add it back after
        # decoding, so reconstruction happens around a learned origin.
        x_bar = x - self.decoder_bias
        f = torch.relu(self.encoder_weight(x_bar))
        x_hat = self.decoder_weight(f) + self.decoder_bias
        return f, x_hat

    def forward(self, **kwargs: Any) -> Dict[str, torch.Tensor]:
        """Forward pass for the DILA model.

        Args:
            **kwargs: Must contain ``self.feature_keys[0]`` — integer token
                ids of shape (batch, seq_len), or (batch, seq_len, 1) which
                is squeezed — and ``self.label_key`` (multi-hot label matrix).

        Returns:
            Dict[str, torch.Tensor]: ``loss`` (BCE + reconstruction/sparsity),
            ``y_prob`` (sigmoid probabilities), and ``y_true``.
        """
        input_sequence = kwargs[self.feature_keys[0]]

        # Accept a trailing singleton feature axis from upstream processors.
        if input_sequence.dim() == 3:
            input_sequence = input_sequence.squeeze(-1)

        x = self.simulated_plm(input_sequence)
        f, x_hat = self.sparse_autoencoder(x)

        # Per-label attention weights over the sequence dimension.
        attention_scores = torch.matmul(f, self.sparse_projection)
        a_laat = torch.softmax(attention_scores, dim=1)

        # (batch, num_labels, dim): label-specific pooled embeddings,
        # then averaged over the label axis before the decision layer.
        x_att = torch.matmul(a_laat.transpose(-2, -1), x)
        x_att_pooled = x_att.mean(dim=1)
        logits = self.fc(x_att_pooled)

        # Autoencoder objective: reconstruction + elastic-net-style sparsity.
        mse_loss = F.mse_loss(x_hat, x)
        l1_loss = torch.norm(f, p=1)
        l2_loss = torch.norm(f, p=2) ** 2
        sae_loss = mse_loss + self.sparsity_penalty * (l1_loss + l2_loss)

        y_true = kwargs[self.label_key].float()
        bce_loss = F.binary_cross_entropy_with_logits(logits, y_true)

        total_loss = bce_loss + sae_loss

        return {
            "loss": total_loss,
            "y_prob": torch.sigmoid(logits),
            "y_true": y_true
        }
import torch
import torch.nn as nn
from typing import Dict, Optional


class DummyDILA(nn.Module):
    """Minimal stand-in for the DILA model used by the unit test.

    Attributes:
        encoder (nn.Linear): Simulated sparse encoder.
        sparse_projection (nn.Parameter): Simulated sparse projection matrix.
        fc (nn.Linear): Final classification layer.
    """

    def __init__(self, embedding_dim: int, dictionary_size: int, num_labels: int) -> None:
        super().__init__()
        # Layer creation order kept stable so seeded runs are reproducible.
        self.encoder = nn.Linear(embedding_dim, dictionary_size)
        self.sparse_projection = nn.Parameter(torch.randn(dictionary_size, num_labels))
        self.fc = nn.Linear(embedding_dim, num_labels)

    def forward(
        self, x_note: torch.Tensor, y_true: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """Runs sparse-attention pooling followed by classification.

        Args:
            x_note (torch.Tensor): Input token embeddings.
            y_true (torch.Tensor, optional): Ground truth labels. Defaults to None.

        Returns:
            Dict[str, torch.Tensor]: Probabilities, echoed labels, and loss
            when ground truth was supplied.
        """
        sparse_feats = self.encoder(x_note).relu()
        label_attention = (sparse_feats @ self.sparse_projection).softmax(dim=1)
        pooled = label_attention.transpose(-2, -1) @ x_note
        logits = self.fc(pooled).mean(dim=1)

        result: Dict[str, torch.Tensor] = {
            "y_prob": torch.sigmoid(logits),
            "y_true": y_true,
        }
        if y_true is not None:
            result["loss"] = nn.BCEWithLogitsLoss()(logits, y_true.float())

        return result


def test_dila_forward() -> None:
    """Tests the forward pass, output shapes, and gradients of the DILA model."""
    batch, seq_len, emb_dim = 4, 128, 768
    dict_size, n_labels = 100, 50

    embeddings = torch.randn(batch, seq_len, emb_dim)
    targets = torch.empty(batch, n_labels).random_(2)

    model = DummyDILA(
        embedding_dim=emb_dim,
        dictionary_size=dict_size,
        num_labels=n_labels,
    )
    outputs = model(x_note=embeddings, y_true=targets)

    for key in ("y_prob", "y_true", "loss"):
        assert key in outputs, f"Output dictionary is missing '{key}'."

    expected_shape = (batch, n_labels)
    actual_shape = outputs["y_prob"].shape
    assert actual_shape == expected_shape, (
        f"Expected y_prob shape {expected_shape}, got {actual_shape}."
    )

    outputs["loss"].backward()

    grad = model.sparse_projection.grad
    assert grad is not None, "Gradients did not compute for sparse_projection."
    assert torch.sum(torch.abs(grad)) > 0, (
        "Gradients for sparse_projection are zero."
    )