diff --git a/docs/api/models.rst b/docs/api/models.rst
index 7368dec94..505640dc8 100644
--- a/docs/api/models.rst
+++ b/docs/api/models.rst
@@ -174,6 +174,7 @@ API Reference
    models/pyhealth.models.MLP
    models/pyhealth.models.CNN
    models/pyhealth.models.RNN
+   models/pyhealth.models.mixlstm
    models/pyhealth.models.GNN
    models/pyhealth.models.Transformer
    models/pyhealth.models.TransformersModel
diff --git a/docs/api/models/pyhealth.models.mixlstm.rst b/docs/api/models/pyhealth.models.mixlstm.rst
new file mode 100644
index 000000000..1735e2947
--- /dev/null
+++ b/docs/api/models/pyhealth.models.mixlstm.rst
@@ -0,0 +1,20 @@
+pyhealth.models.MixLSTM
+=======================
+
+The MixLSTM model from Oh et al. 2020, "Relaxed Parameter Sharing:
+Effectively Modeling Time-Varying Relationships in Clinical Time-Series"
+(https://arxiv.org/abs/1906.02898).
+
+MixLSTM addresses the problem of *temporal conditional shift* in clinical
+time-series, i.e., settings in which the relationship between input features
+and outcomes changes over the course of a patient's hospital stay. Instead
+of sharing a single set of LSTM parameters across all time steps, MixLSTM
+maintains ``K`` independent LSTM cells and, at every time step, computes a
+learned convex combination of their parameters using mixing coefficients.
+This enables smooth transitions between different temporal dynamics without
+requiring hard segment boundaries.
+
+.. autoclass:: pyhealth.models.MixLSTM
+    :members:
+    :undoc-members:
+    :show-inheritance:
diff --git a/examples/mimic3_synthetic_mixlstm.py b/examples/mimic3_synthetic_mixlstm.py
new file mode 100644
index 000000000..5900f2ad7
--- /dev/null
+++ b/examples/mimic3_synthetic_mixlstm.py
@@ -0,0 +1,960 @@
+"""
+MixLSTM Hyperparameter Search Experiment
+Synthetic time-series regression task with PyHealth.
+
+All intermediate results (distributions, predictions, search metrics)
+are kept in memory and passed directly to the visualization functions
+instead of being written to / read from disk.
+"""
+
+import os
+import random
+import logging
+from dataclasses import dataclass, field
+
+import numpy as np
+import pandas as pd
+import torch
+import torch.optim as optim
+import matplotlib
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt
+import seaborn as sns
+from pyhealth.datasets import create_sample_dataset, get_dataloader
+from pyhealth.trainer import Trainer
+from pyhealth.models import MixLSTM
+
+
+# ======================================================================
+# MixLSTM Hyperparameter Search Experiment
+# Synthetic time-series regression task with PyHealth
+# ======================================================================
+#
+# REQUIREMENTS
+# Run ``pip install seaborn~=0.13.2`` to make sure the graphs are displayed
+# and the ablation study runs smoothly. ``seaborn~=0.13.2`` has also been
+# added to pyproject.toml.
+#
+# EXPERIMENTAL SETUP
+# ------------------
+# Dataset: Synthetic non-stationary time-series regression. 1,000 sequences
+# per split (train/val/test), length T=30, 3 input features. Inputs are
+# 90% sparse. Targets from step l=10 onward are weighted combinations of
+# prior inputs, where the weights drift by delta=0.05 per step to simulate
+# distribution shift.
+#
+# Model: MixLSTM (PyHealth) with k=2 experts and lookback window l=10.
+# Hidden size sampled from {100, 150, 300, 500, 700, 900, 1100}.
+# 20 random-search runs per config, 30 epochs each, batch size 100.
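+#
+# Target construction (summarized from generate_xy below): for each step
+# t >= l the label is
+#     y[t] = dot(k_dist[t], x[t-l:t, :] @ d_dist[t])
+# i.e. the lookback window is first weighted across features by d_dist[t]
+# and then across time by k_dist[t]; steps t < l keep the placeholder value
+# 1 and are excluded from evaluation.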
+#
+# ABLATION STUDIES
+# ----------------
+# 1) Learning rate sweep: Adam at lr in {0.0001, 0.0005, 0.001, 0.005, 0.01}
+# 2) Optimizer comparison: Adam vs SGD at lr=0.001
+# 3) Every other parameter kept at its default
+#
+# FINDINGS
+# ----------------
+# 1. OPTIMIZER COMPARISON
+# ----------------------------------------------------------------------------
+# Conclusion: Adam consistently outperformed SGD across training runs.
+#
+# | Optimizer | Lowest Val Loss (MSE) | Lowest Test Loss (MSE) |
+# |-----------|-----------------------|------------------------|
+# | Adam      | 0.430089              | 0.467544               |
+# | SGD       | 16.388920             | 16.411073              |
+#
+# 2. LEARNING RATE VS. HIDDEN SIZE COMPARISON
+# --------------------------------------------------------------------------------------------------------------------------
+# Format: (Validation Loss MSE - Test Loss MSE)
+#
+#           | Hidden Size
+# LR        | 100              150              300              500              700              900              1100
+# ----------|---------------------------------------------------------------------------------------------------------------
+# 0.0001    | (-)              (14.14 - 14.54)  (10.76 - 11.05)  (7.59 - 7.73)    (5.53 - 5.44)    (4.60 - 4.76)    (4.31 - 4.39)
+# 0.0005    | (10.78 - 11.06)  (9.30 - 9.76)    (5.31 - 5.87)    (4.02 - 4.39)    (2.81 - 3.09)    (1.61 - 1.89)    (1.33 - 1.52)
+# 0.001     | (6.37 - 6.51)    (4.49 - 4.62)    (2.60 - 2.67)    (1.26 - 1.31)    (0.87 - 0.91)    (0.69 - 0.77)    (0.43 - 0.46)
+# 0.005     | (2.20 - 2.28)    (1.41 - 1.53)    (0.68 - 0.77)    (0.48 - 0.62)    (0.89 - 1.01)    (-)              (0.68 - 0.74)
+# 0.01      | (1.79 - 1.88)    (1.42 - 1.47)    (1.10 - 1.14)    (1.03 - 1.10)    (1.54 - 1.58)    (0.91 - 0.98)    (2.24 - 2.41)
+# ==========================================================================================================================
+#
+# Conclusion:
+# LR = 0.0001 was the worst performer overall across all hidden sizes.
+# LR = 0.0005 was the second-worst performer across almost all hidden sizes.
+# LR = 0.001 is the learning rate used in the paper. LR = 0.01 and LR = 0.005
+#     beat it at the smaller hidden sizes (100, 150, 300, 500); for the rest,
+#     LR = 0.001 was the best choice overall.
+# LR = 0.005 was the best overall for the smaller hidden sizes (100 to 500),
+#     spiked at 700, then came back down. Ideal for smaller hidden sizes.
+# LR = 0.01 was erratic and unstable, going up and down repeatedly, and is
+#     not recommended.
+#
+# Overall conclusion of the study:
+# The Adam optimizer gives the best results.
+# LR = 0.001 works best for hidden sizes above 500; LR = 0.005 is best for
+#     hidden sizes below 500.
+#
+# HOW TO RUN THE STUDY
+# pip install seaborn
+# Run this Python file.
+# Six .png files are produced, displaying the results as graphs.
+#
+
+
+# ──────────────────────────────────────────────────────────────
+# In-memory result containers
+# ──────────────────────────────────────────────────────────────
+
+@dataclass
+class AblationResult:
+    """Container for every artefact produced by a single ablation run.
+
+    Attributes:
+        learning_rate: The learning rate used for this ablation.
+        optimizer_name: Human-readable optimizer name (e.g. ``"Adam"``).
+        results_df: DataFrame with one row per random-search run.
+            Columns include ``Run``, ``k (experts)``, ``Hidden Size``,
+            ``Val Loss``, ``Test Loss``, ``num_params``, and ``epoch``.
+        k_dist: List of *T* numpy arrays representing the temporal
+            weight distribution at each time step.
+        d_dist: List of *T* numpy arrays representing the feature
+            weight distribution at each time step.
+ best_predictions: Dictionary with keys ``"pred"``, + ``"y_true"``, ``"k"``, ``"hidden_size"``, and ``"run"`` + for the model that achieved the lowest validation loss. + ``None`` if no valid model was produced. + best_model_state: ``state_dict`` (on CPU) of the best model. + ``None`` if no valid model was produced. + """ + learning_rate: float + optimizer_name: str + results_df: pd.DataFrame + k_dist: list[np.ndarray] + d_dist: list[np.ndarray] + best_predictions: dict | None = None + best_model_state: dict | None = None + + @property + def label(self) -> str: + """Return a human-readable label for plots and logs. + + Returns: + A string of the form ``" lr="``. + """ + + return f"{self.optimizer_name} lr={self.learning_rate}" + + +# ────────────────────────────────────────────────────────────── +# Configuration +# ────────────────────────────────────────────────────────────── + +SEED = 42 +NUM_SAMPLES = 1000 +T = 30 # sequence length +INPUT_DIM = 3 +PREV_USED_TIMESTAMPS = 10 # l +CHANGE_BETWEEN_TASKS = 0.05 # delta + +BATCH_SIZE = 100 +K_LIST = [2] +HIDDEN_SIZE_LIST = [100, 150, 300, 500, 700, 900, 1100] +NUM_RUNS = 20 # default set to 20 +MAX_EPOCHS = 30 # default set to 30 + +SAVE_DIR = "." + +# Visualization +MAX_MSE = 100 +ABLATION_LRS = [0.0001, 0.0005, 0.001, 0.005, 0.01] + + +# ────────────────────────────────────────────────────────────── +# Utility functions +# ────────────────────────────────────────────────────────────── + +def set_seed(seed: int) -> None: + """Set random seeds for Python, NumPy, and PyTorch for reproducibility. + + Args: + seed: Integer seed value applied to every RNG. + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + +def get_device() -> torch.device: + """Detect and return the best available compute device. + + Returns: + ``torch.device("cuda")`` when a CUDA GPU is available, + otherwise ``torch.device("cpu")``. + """ + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Running on device: {device}") + return device + + +# ────────────────────────────────────────────────────────────── +# Data generation +# ────────────────────────────────────────────────────────────── + +def convert_distb(a: np.ndarray) -> np.ndarray: + """Min-max normalize an array and rescale it to sum to one. + + The array is first shifted and scaled to the [0, 1] range via + min-max normalization, then divided by its sum so that it + forms a valid discrete probability distribution. + + Args: + a: 1-D numpy array of raw (un-normalized) weights. + + Returns: + A 1-D numpy array of the same shape whose elements are + non-negative and sum to 1. + """ + a_min = min(a) + a_max = max(a) + a = (a - a_min) / (a_max - a_min) + a_sum = sum(a) + a = a / a_sum + return a + + +def generate_distributions( + T: int, + prev_used_timestamps: int, + input_dim: int, + change_between_tasks: float, +) -> tuple[list[np.ndarray], list[np.ndarray]]: + """Generate time-varying weight distributions for synthetic targets. + + Creates ``k_dist`` (temporal) and ``d_dist`` (feature) weight + vectors that drift by a small delta at each step beyond the + lookback window, simulating non-stationary distribution shift. + + For time steps before *prev_used_timestamps*, both distributions + are uniform placeholders. 
At step *prev_used_timestamps* the + distributions are initialized randomly, and at each subsequent + step a uniform perturbation in ``[-change_between_tasks, + +change_between_tasks]`` is added before re-normalization. + + Args: + T: Total sequence length. + prev_used_timestamps: Lookback window size (*l*). + Distributions before this index are uniform placeholders. + input_dim: Number of input features per time step. + change_between_tasks: Maximum per-step drift (*delta*) + applied uniformly at random to each weight element. + + Returns: + A tuple ``(k_dist, d_dist)`` where: + + * ``k_dist`` is a list of *T* arrays, each of shape + ``(prev_used_timestamps,)``. + * ``d_dist`` is a list of *T* arrays, each of shape + ``(input_dim,)``. + """ + k_dist = [] + d_dist = [] + for i in range(T): + if i < prev_used_timestamps: + k_dist.append(np.ones(prev_used_timestamps)) + d_dist.append(np.ones(input_dim)) + elif i == prev_used_timestamps: + k_dist.append(convert_distb(np.random.uniform(size=(prev_used_timestamps,)))) + d_dist.append(convert_distb(np.random.uniform(size=(input_dim,)))) + else: + delta_t = np.random.uniform( + -change_between_tasks, change_between_tasks, size=(prev_used_timestamps,) + ) + delta_d = np.random.uniform( + -change_between_tasks, change_between_tasks, size=(input_dim,) + ) + k_dist.append(convert_distb(k_dist[i - 1] + delta_t)) + d_dist.append(convert_distb(d_dist[i - 1] + delta_d)) + return k_dist, d_dist + + +def generate_xy( + num_samples: int, + T: int, + input_dim: int, + prev_used_timestamps: int, + k_dist: list[np.ndarray], + d_dist: list[np.ndarray], +) -> tuple[np.ndarray, np.ndarray]: + + """Generate sparse input sequences and their regression targets. + + Inputs are 90 % sparse (zeros) with the remaining 10 % drawn + uniformly from ``[0, 100)``. For time steps ``t >= l`` the target + is ``x[t-l:t, :] @ d_dist[t] @ k_dist[t]``; earlier targets are + ones (placeholders). + + Args: + num_samples: Number of independent sequences to generate. + T: Sequence length (number of time steps). + input_dim: Dimensionality of input features. + prev_used_timestamps: Lookback window size (*l*). + k_dist: Temporal weight distributions as returned by + :func:`generate_distributions`. + d_dist: Feature weight distributions as returned by + :func:`generate_distributions`. + + Returns: + A tuple ``(x, y)`` where: + + * ``x`` has shape ``(num_samples, T, input_dim)``. + * ``y`` has shape ``(num_samples, T, 1)``. + """ + + x_size = num_samples * T * input_dim + x = np.zeros(x_size) + sparse_count = int(x_size / 10) + x[np.random.choice(x_size, size=sparse_count, replace=False)] = ( + np.random.uniform(size=sparse_count) * 100 + ) + x = np.resize(x, (num_samples, T, input_dim)) + + y = np.ones((num_samples, T, 1)) + for i in range(T): + if i >= prev_used_timestamps: + y[:, i, 0] = np.matmul( + np.matmul(x[:, i - prev_used_timestamps : i, :], d_dist[i]), + k_dist[i], + ) + return x, y + + +# ────────────────────────────────────────────────────────────── +# PyHealth dataset helpers +# ────────────────────────────────────────────────────────────── + +def make_dataset(x: np.ndarray, y: np.ndarray, split_name: str) -> "SampleDataset": + """Wrap numpy arrays into a PyHealth ``SampleDataset``. + + Each sequence is registered as a separate patient with a single + visit containing the full time-series. + + Args: + x: Input tensor of shape ``(N, T, D)``. + y: Target tensor of shape ``(N, T, 1)``. + split_name: Identifier for the split (e.g. ``"train"``, + ``"val"``, ``"test"``). 
Used in patient IDs and as the + PyHealth dataset name suffix. + + Returns: + A PyHealth ``SampleDataset`` ready to be passed to + ``get_dataloader``. + """ + + samples = [ + { + "patient_id": f"{split_name}-patient-{i}", + "visit_id": "visit-0", + "series": x[i].tolist(), + "y": y[i].squeeze(-1).tolist(), + } + for i in range(len(x)) + ] + return create_sample_dataset( + samples=samples, + input_schema={"series": "tensor"}, + output_schema={"y": "tensor"}, + dataset_name=f"mixlstm_{split_name}", + ) + + +def build_dataloaders( + k_dist, d_dist, num_samples, T, input_dim, prev_used_timestamps, batch_size +) -> tuple["SampleDataset", torch.utils.data.DataLoader, torch.utils.data.DataLoader, torch.utils.data.DataLoader]: + + """Generate train / val / test splits and wrap them in DataLoaders. + + Three independent datasets are synthesized from the same + underlying distributions so that the only source of variance is + the random sparse masking and the ordering of non-zero entries. + + Args: + k_dist: Temporal weight distributions (see + :func:`generate_distributions`). + d_dist: Feature weight distributions (see + :func:`generate_distributions`). + num_samples: Number of sequences per split. + T: Sequence length. + input_dim: Number of input features. + prev_used_timestamps: Lookback window size (*l*). + batch_size: Mini-batch size for every DataLoader. + + Returns: + A tuple ``(train_dataset, train_loader, val_loader, + test_loader)``. The raw ``train_dataset`` is also returned + because ``MixLSTM.__init__`` requires it to infer schema + metadata. + """ + x_train, y_train = generate_xy( + num_samples, T, input_dim, prev_used_timestamps, k_dist, d_dist + ) + x_val, y_val = generate_xy( + num_samples, T, input_dim, prev_used_timestamps, k_dist, d_dist + ) + x_test, y_test = generate_xy( + num_samples, T, input_dim, prev_used_timestamps, k_dist, d_dist + ) + + train_data = make_dataset(x_train, y_train, "train") + val_data = make_dataset(x_val, y_val, "val") + test_data = make_dataset(x_test, y_test, "test") + + train_loader = get_dataloader(train_data, batch_size=batch_size, shuffle=True) + val_loader = get_dataloader(val_data, batch_size=batch_size, shuffle=True) + test_loader = get_dataloader(test_data, batch_size=batch_size, shuffle=True) + + return train_data, train_loader, val_loader, test_loader + + +# ────────────────────────────────────────────────────────────── +# Training & evaluation +# ────────────────────────────────────────────────────────────── + +def collect_predictions(model, test_loader, device) -> dict[str, np.ndarray]: + """Run inference on *test_loader* and collect predictions. + + The model is set to eval mode and gradients are disabled. Only + time steps from index *l* onward (the non-placeholder region) + are retained. + + Args: + model: A trained ``MixLSTM`` model instance. + test_loader: DataLoader yielding test batches. + device: Device the model resides on. + + Returns: + A dictionary with two keys: + + * ``"pred"`` — flattened 1-D numpy array of predicted values. + * ``"y_true"`` — flattened 1-D numpy array of ground-truth + values, aligned element-wise with ``"pred"``. 
+ """ + model.eval() + l = model.prev_used_timestamps + preds, y_trues = [], [] + + with torch.no_grad(): + for batch in test_loader: + batch = { + k_: v.to(device) if isinstance(v, torch.Tensor) else v + for k_, v in batch.items() + } + output = model(**batch) + preds.append(output["y_prob"][:, l:, :].cpu().numpy()) + y_trues.append(output["y_true"][:, l:, :].cpu().numpy()) + + return { + "pred": np.concatenate(preds, axis=0).flatten(), + "y_true": np.concatenate(y_trues, axis=0).flatten(), + } + + +def run_hyperparameter_search( + train_data, + train_loader, + val_loader, + test_loader, + device, + prev_used_timestamps, + k_list, + hidden_size_list, + num_runs, + max_epochs, + learning_rate, + optimizer_class=optim.Adam, +) -> tuple[pd.DataFrame, dict | None, dict | None]: + + """Execute a random hyperparameter search over MixLSTM configs. + + Each run samples ``k`` (number of experts) and ``hidden_size`` + uniformly from the provided lists, trains for *max_epochs*, and + records validation / test loss. The model with the lowest + validation loss is retained. + + Args: + train_data: PyHealth ``SampleDataset`` used to initialize + ``MixLSTM`` (needed for schema inference). + train_loader: DataLoader for the training split. + val_loader: DataLoader for the validation split. + test_loader: DataLoader for the test split. + device: Compute device (CPU or CUDA). + prev_used_timestamps: Lookback window size (*l*) passed to + ``MixLSTM``. + k_list: Candidate values for the number of mixture experts. + hidden_size_list: Candidate values for the LSTM hidden + dimension. + num_runs: Total number of random configurations to evaluate. + max_epochs: Training epochs per run. + learning_rate: Learning rate passed to the optimizer. + optimizer_class: PyTorch optimizer class (e.g. + ``torch.optim.Adam``). + + Returns: + A tuple ``(results_df, best_predictions, best_model_state)`` + where: + + * ``results_df`` — DataFrame with columns ``Run``, + ``k (experts)``, ``Hidden Size``, ``Val Loss``, + ``Test Loss``, ``num_params``, and ``epoch``. + * ``best_predictions`` — dictionary as returned by + :func:`collect_predictions`, augmented with ``"k"``, + ``"hidden_size"``, and ``"run"`` keys. ``None`` when no + valid model was found. + * ``best_model_state`` — CPU ``state_dict`` of the + best-performing model. ``None`` when no valid model was + found. 
+ """ + + results = [] + best_val_loss_overall = np.inf + best_predictions = None + best_model_state = None + + for run in range(num_runs): + k = random.choice(k_list) + hidden_size = random.choice(hidden_size_list) + + print(f"\n{'=' * 60}") + print(f"Run {run + 1}/{num_runs} | k (num_experts): {k} | hidden_size: {hidden_size}") + print("=" * 60) + + model = MixLSTM( + dataset=train_data, + num_experts=k, + hidden_size=hidden_size, + prev_used_timestamps=prev_used_timestamps, + ) + model = model.to(device) + + trainer = Trainer(model=model, device=device) + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + optimizer_class=optimizer_class, + optimizer_params={"lr": learning_rate}, + epochs=max_epochs, + monitor="loss", + monitor_criterion="min", + ) + + print(f"\nEvaluating Best Model for Run {run + 1}...") + val_metrics = trainer.evaluate(val_loader) + test_metrics = trainer.evaluate(test_loader) + + val_loss = val_metrics.get("loss", None) + test_loss = test_metrics.get("loss", None) + + if val_loss < best_val_loss_overall: + best_val_loss_overall = val_loss + print(f" New best val loss: {val_loss:.6f}") + predictions = collect_predictions(model, test_loader, device) + predictions["k"] = k + predictions["hidden_size"] = hidden_size + predictions["run"] = run + best_predictions = predictions + best_model_state = {k_: v.cpu().clone() for k_, v in model.state_dict().items()} + + results.append({ + "Run": run + 1, + "k (experts)": k, + "Hidden Size": hidden_size, + "Val Loss": val_loss, + "Test Loss": test_loss, + "num_params": sum(p.numel() for p in model.parameters() if p.requires_grad), + "epoch": max_epochs, + }) + + return pd.DataFrame(results), best_predictions, best_model_state + + +# ────────────────────────────────────────────────────────────── +# Ablation study — learning rate sweep +# ────────────────────────────────────────────────────────────── + +def run_single_ablation( + learning_rate: float, + optimizer_class=optim.Adam, + optimizer_name: str = "Adam", +) -> AblationResult: + + """Run the full hyperparameter search for one (optimizer, lr) pair. + + This is the main entry point for a single ablation cell. It + seeds RNGs, generates data, builds data loaders, trains all + random-search runs, and packages the results into an + :class:`AblationResult`. + + Args: + learning_rate: Learning rate forwarded to the optimizer. + optimizer_class: PyTorch optimizer class to use (e.g. + ``torch.optim.Adam``, ``torch.optim.SGD``). + optimizer_name: Human-readable name stored in the result + object and used in plot labels. + + Returns: + An :class:`AblationResult` containing the results DataFrame, + weight distributions, best predictions, and best model state. 
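+
+    Example (illustrative; note this trains ``NUM_RUNS`` models, so it is
+    slow to run):
+        >>> result = run_single_ablation(0.001, optim.Adam, "Adam")
+        >>> result.label
+        'Adam lr=0.001'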
+ """ + + set_seed(SEED) + device = get_device() + logging.getLogger("pyhealth.trainer").setLevel(logging.WARNING) + + if device.type == "cuda": + torch.set_default_device(device) + + k_dist, d_dist = generate_distributions( + T, PREV_USED_TIMESTAMPS, INPUT_DIM, CHANGE_BETWEEN_TASKS + ) + + train_data, train_loader, val_loader, test_loader = build_dataloaders( + k_dist, d_dist, NUM_SAMPLES, T, INPUT_DIM, PREV_USED_TIMESTAMPS, BATCH_SIZE + ) + + print(f"\n{'#' * 60}") + print(f" ABLATION — optimizer = {optimizer_name}, learning_rate = {learning_rate}") + print(f"{'#' * 60}") + + results_df, best_predictions, best_model_state = run_hyperparameter_search( + train_data=train_data, + train_loader=train_loader, + val_loader=val_loader, + test_loader=test_loader, + device=device, + prev_used_timestamps=PREV_USED_TIMESTAMPS, + k_list=K_LIST, + hidden_size_list=HIDDEN_SIZE_LIST, + num_runs=NUM_RUNS, + max_epochs=MAX_EPOCHS, + learning_rate=learning_rate, + optimizer_class=optimizer_class, + ) + + best = results_df.sort_values(by="Test Loss").reset_index(drop=True) + print(f"\nTop 5 results for {optimizer_name} lr={learning_rate}:") + print(best.head(5)) + + return AblationResult( + learning_rate=learning_rate, + optimizer_name=optimizer_name, + results_df=results_df, + k_dist=k_dist, + d_dist=d_dist, + best_predictions=best_predictions, + best_model_state=best_model_state, + ) + + +def run_all_ablations() -> list[AblationResult]: + """Run the learning-rate sweep ablation using the Adam optimizer. + + Iterates over every learning rate in :data:`ABLATION_LRS`, runs + the full hyperparameter search for each, and prints a summary + table. + + Returns: + A list of :class:`AblationResult` objects, one per learning + rate, in the same order as :data:`ABLATION_LRS`. + """ + + + ablation_results = [ + run_single_ablation(lr, optim.Adam, "Adam") for lr in ABLATION_LRS + ] + _print_summary("Learning Rate Sweep (Adam)", ablation_results) + return ablation_results + + +# ────────────────────────────────────────────────────────────── +# Ablation study — optimizer comparison (Adam vs SGD) +# ────────────────────────────────────────────────────────────── + +ABLATION_OPTIMIZER_LR = 0.001 # fixed LR used for the optimizer comparison + +def ablations_optimizing_adam() -> AblationResult: + + """Run the Adam ablation at the fixed comparison learning rate. + + Returns: + An :class:`AblationResult` for Adam at + lr = :data:`ABLATION_OPTIMIZER_LR`. + """ + + + return run_single_ablation(ABLATION_OPTIMIZER_LR, optim.Adam, "Adam") + + +def ablations_optimizing_sgd() -> AblationResult: + + """Run the SGD ablation at the fixed comparison learning rate. + + Returns: + An :class:`AblationResult` for SGD at + lr = :data:`ABLATION_OPTIMIZER_LR`. + """ + + return run_single_ablation(ABLATION_OPTIMIZER_LR, optim.SGD, "SGD") + + +def run_optimizer_ablations() -> list[AblationResult]: + """Compare Adam and SGD at a fixed learning rate. + + Both optimizers are trained with + lr = :data:`ABLATION_OPTIMIZER_LR` and the results are printed + side by side. + + Returns: + A two-element list ``[adam_result, sgd_result]``. + """ + + results = [ + ablations_optimizing_adam(), + ablations_optimizing_sgd(), + ] + _print_summary("Optimizer Comparison (Adam vs SGD)", results) + return results + + +def _print_summary(title: str, ablation_results: list[AblationResult])-> None: + + """Pretty-print a summary table for a list of ablation results. 
+ + For each :class:`AblationResult` the row with the lowest test + loss is selected and its key metrics are displayed. + + Args: + title: Header string printed above the table. + ablation_results: Results to summarize. + """ + + summary_rows = [] + for result in ablation_results: + best_row = result.results_df.sort_values(by="Test Loss").iloc[0] + summary_rows.append({ + "Optimizer": result.optimizer_name, + "Learning Rate": result.learning_rate, + "Best Val Loss": best_row["Val Loss"], + "Best Test Loss": best_row["Test Loss"], + "k (experts)": best_row["k (experts)"], + "Hidden Size": best_row["Hidden Size"], + "num_params": best_row["num_params"], + }) + + summary_df = pd.DataFrame(summary_rows) + print("\n" + "=" * 60) + print(f" {title}") + print("=" * 60) + print(summary_df.to_string(index=False)) + + +# ────────────────────────────────────────────────────────────── +# Visualization (all functions take in-memory data) +# ────────────────────────────────────────────────────────────── + +def visualize_hyperparameter_search(ablation_results: list[AblationResult], prefix: str = "") -> None: + + """Plot MSE loss vs. hidden size for every learning rate. + + Individual run results are shown as translucent scatter points + and per-hidden-size means are overlaid as solid (validation) and + dashed (test) lines. + + Args: + ablation_results: One :class:`AblationResult` per learning + rate / optimizer configuration. + prefix: String prepended to the output filename (e.g. + ``"lr_sweep_"``). + """ + + print("--- 1. Analyzing Hyperparameter Search (Ablation) ---") + + plt.figure(figsize=(12, 7)) + palette = sns.color_palette("tab10", len(ablation_results)) + + for i, result in enumerate(ablation_results): + tag = result.label + df = result.results_df.copy() + df = df[(df["Val Loss"] <= MAX_MSE) & (df["Test Loss"] <= MAX_MSE)] + + best = df.sort_values(by="Val Loss").head(1) + print( + f" {tag} Best Val Loss: {best['Val Loss'].values[0]:.6f} " + f"(Hidden Size={best['Hidden Size'].values[0]})" + ) + + color = palette[i] + + sns.scatterplot( + data=df, x="Hidden Size", y="Val Loss", + label=f"Val ({tag})", color=color, + marker="o", alpha=0.4, s=40, + ) + sns.scatterplot( + data=df, x="Hidden Size", y="Test Loss", + label=f"Test ({tag})", color=color, + marker="x", alpha=0.4, s=40, + ) + + val_mean = df.groupby("Hidden Size")["Val Loss"].mean().sort_index() + test_mean = df.groupby("Hidden Size")["Test Loss"].mean().sort_index() + plt.plot(val_mean.index, val_mean.values, color=color, linewidth=2, linestyle="-") + plt.plot(test_mean.index, test_mean.values, color=color, linewidth=2, linestyle="--") + + plt.title("Ablation: MSE Loss vs. Hidden Size by Learning Rate") + plt.xlabel("Hidden Size") + plt.ylabel("MSE Loss") + plt.legend(title="Loss Type / LR", bbox_to_anchor=(1.05, 1), loc="upper left") + plt.grid(True, linestyle="--", alpha=0.6) + plt.ylim(0, MAX_MSE) + plt.tight_layout() + + out_path = os.path.join(SAVE_DIR, f"{prefix}ablation_loss_vs_hidden_size.png") + plt.savefig(out_path, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {out_path}") + + +def visualize_predictions(ablation_results: list[AblationResult], num_samples: int = 3, prefix: str = "") -> None: + + """Plot predicted vs. true values for sample test sequences. + + One row of subplots is created per ablation result, with + *num_samples* columns each showing a randomly chosen test + sequence. + + Args: + ablation_results: Ablation results whose + ``best_predictions`` field will be visualized. 
Entries + with ``best_predictions is None`` are silently skipped. + num_samples: Number of randomly selected test sequences to + plot per ablation. + prefix: String prepended to the output filename. + """ + + print("\n--- 2. Analyzing Predictions (Ablation) ---") + + # Only include results that have saved predictions + valid_results = [r for r in ablation_results if r.best_predictions is not None] + if not valid_results: + print(" No predictions available to plot.") + return + + fig, axes = plt.subplots(len(valid_results), num_samples, figsize=(15, 4 * len(valid_results))) + if len(valid_results) == 1: + axes = [axes] + + for row, result in enumerate(valid_results): + tag = result.label + y_true_flat = result.best_predictions["y_true"] + pred_flat = result.best_predictions["pred"] + + l = result.best_predictions.get("k", PREV_USED_TIMESTAMPS) + eval_steps = T - l + num_test_samples = len(y_true_flat) // eval_steps + limit = num_test_samples * eval_steps + + y_true = np.reshape(y_true_flat[:limit], (num_test_samples, eval_steps)) + pred = np.reshape(pred_flat[:limit], (num_test_samples, eval_steps)) + + sample_indices = np.random.choice( + num_test_samples, min(num_samples, num_test_samples), replace=False + ) + + for col, sample_idx in enumerate(sample_indices): + ax = axes[row][col] + ax.plot(y_true[sample_idx], label="True", color="blue", marker="o", markersize=4) + ax.plot( + pred[sample_idx], label="Predicted", + color="red", linestyle="--", marker="x", markersize=4 + ) + ax.set_title(f"{tag} | Sample #{sample_idx}") + ax.set_xlabel("Time Steps") + ax.set_ylabel("Value") + ax.legend(fontsize=8) + ax.grid(True, alpha=0.3) + + plt.tight_layout() + out_path = os.path.join(SAVE_DIR, f"{prefix}ablation_predictions.png") + plt.savefig(out_path, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {out_path}") + + +def visualize_synthetic_shift(ablation_results: list[AblationResult], prefix: str = "") -> None: + """Plot the ``k_dist`` heatmap for each ablation's task distributions. + + Each subplot shows how the temporal weight distribution evolves + across the *T* time steps (x-axis) over the *l* lookback + positions (y-axis). + + Args: + ablation_results: Ablation results whose ``k_dist`` fields + will be visualized. + prefix: String prepended to the output filename. + """ + + + print("\n--- 3. Analyzing Synthetic Data Shift (Ablation) ---") + + fig, axes = plt.subplots(1, len(ablation_results), figsize=(6 * len(ablation_results), 5)) + if len(ablation_results) == 1: + axes = [axes] + + for i, result in enumerate(ablation_results): + k_dist_matrix = np.stack(result.k_dist) + sns.heatmap(k_dist_matrix.T, cmap="viridis", ax=axes[i], cbar_kws={"label": "Weight"}) + axes[i].set_title(f"k_dist Shift ({result.label})") + axes[i].set_xlabel("Time Step (T)") + axes[i].set_ylabel("Lookback Step (l)") + + plt.tight_layout() + out_path = os.path.join(SAVE_DIR, f"{prefix}ablation_synthetic_shift.png") + plt.savefig(out_path, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {out_path}") + + +def run_all_visualizations(ablation_results: list[AblationResult], prefix: str = "") -> None: + """Generate all three ablation plot types from in-memory results. + + Delegates to :func:`visualize_hyperparameter_search`, + :func:`visualize_synthetic_shift`, and + :func:`visualize_predictions`. + + Args: + ablation_results: The list of :class:`AblationResult` objects + to visualize. + prefix: Filename prefix forwarded to each plotting function. 
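+
+    Example (illustrative; this mirrors what ``main()`` below does):
+        >>> lr_results = run_all_ablations()
+        >>> run_all_visualizations(lr_results, prefix="lr_sweep_")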
+ """ + + visualize_hyperparameter_search(ablation_results, prefix=prefix) + visualize_synthetic_shift(ablation_results, prefix=prefix) + visualize_predictions(ablation_results, prefix=prefix) + + +# ────────────────────────────────────────────────────────────── +# Main +# ────────────────────────────────────────────────────────────── + +def main() -> None: + # Learning-rate sweep (Adam only, multiple LRs) + lr_results = run_all_ablations() + run_all_visualizations(lr_results, prefix="lr_sweep_") + + # Optimizer comparison (Adam vs SGD at fixed LR) + optimizer_results = run_optimizer_ablations() + run_all_visualizations(optimizer_results, prefix="optim_comp_") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/mixlstm/mixlstm_test.ipynb b/examples/mixlstm/mixlstm_test.ipynb new file mode 100644 index 000000000..b46a02a84 --- /dev/null +++ b/examples/mixlstm/mixlstm_test.ipynb @@ -0,0 +1,255 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "5b176866", + "metadata": {}, + "source": [ + "# mixLSTM Test Notebook\n", + "\n", + "This notebook demonstrates how to create a small sample dataset, initialize `mixLSTM`, and run a forward pass." + ] + }, + { + "cell_type": "markdown", + "id": "2cb149bb", + "metadata": {}, + "source": [ + "## 1. Environment Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "08ef7f3e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running on device: cpu\n" + ] + } + ], + "source": [ + "import random\n", + "import numpy as np\n", + "import torch\n", + "\n", + "from pyhealth.datasets import create_sample_dataset, get_dataloader\n", + "from pyhealth.datasets.splitter import split_by_sample\n", + "\n", + "SEED = 42\n", + "random.seed(SEED)\n", + "np.random.seed(SEED)\n", + "torch.manual_seed(SEED)\n", + "if torch.cuda.is_available():\n", + " torch.cuda.manual_seed_all(SEED)\n", + "\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "print(f\"Running on device: {device}\")" + ] + }, + { + "cell_type": "markdown", + "id": "1327ae03", + "metadata": {}, + "source": [ + "## 2. Create Sample Dataset\n", + "\n", + "We create synthetic time-series samples with shape `(T, input_dim)` stored under the input key `series`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "d3fea9bc", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Label label vocab: {0: 0, 1: 1, 2: 2, 3: 3, 4: 4}\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Created dataset with 200 samples\n", + "Input schema: {'series': 'tensor'}\n", + "Output schema: {'label': 'multiclass'}\n" + ] + } + ], + "source": [ + "# Dataset parameters\n", + "num_samples = 200\n", + "T = 50 # sequence length\n", + "input_dim = 3\n", + "n_classes = 5\n", + "\n", + "samples = [\n", + " {\n", + " \"patient_id\": f\"patient-{i}\",\n", + " \"visit_id\": \"visit-0\",\n", + " \"series\": torch.randn(T, input_dim).numpy().tolist(),\n", + " \"label\": int(i % n_classes),\n", + " }\n", + " for i in range(num_samples)\n", + "]\n", + "\n", + "input_schema = {\"series\": \"tensor\"}\n", + "output_schema = {\"label\": \"multiclass\"}\n", + "\n", + "dataset = create_sample_dataset(\n", + " samples=samples,\n", + " input_schema=input_schema,\n", + " output_schema=output_schema,\n", + " dataset_name=\"mixlstm_demo\",\n", + ")\n", + "\n", + "print(f\"Created dataset with {len(dataset)} samples\")\n", + "print(f\"Input schema: {dataset.input_schema}\")\n", + "print(f\"Output schema: {dataset.output_schema}\")" + ] + }, + { + "cell_type": "markdown", + "id": "f2a5248a", + "metadata": {}, + "source": [ + "## 3. Split Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "168a6299", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train: 140 samples\n", + "Val: 30 samples\n", + "Test: 30 samples\n" + ] + } + ], + "source": [ + "train_data, val_data, test_data = split_by_sample(dataset, [0.7, 0.15, 0.15], seed=SEED)\n", + "\n", + "print(f\"Train: {len(train_data)} samples\")\n", + "print(f\"Val: {len(val_data)} samples\")\n", + "print(f\"Test: {len(test_data)} samples\")\n", + "\n", + "train_loader = get_dataloader(train_data, batch_size=8, shuffle=True)\n", + "val_loader = get_dataloader(val_data, batch_size=8, shuffle=False)\n", + "test_loader = get_dataloader(test_data, batch_size=8, shuffle=False)" + ] + }, + { + "cell_type": "markdown", + "id": "63cad735", + "metadata": {}, + "source": [ + "## 4. Initialize `mixLSTM` Model" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "215bf9bd", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model created with 180390 parameters\n" + ] + } + ], + "source": [ + "from pyhealth.models import MixLSTM\n", + "\n", + "model = MixLSTM(dataset=dataset, num_experts=10, hidden_size=64)\n", + "model = model.to(device)\n", + "print(f\"Model created with {sum(p.numel() for p in model.parameters())} parameters\")" + ] + }, + { + "cell_type": "markdown", + "id": "e978eecb", + "metadata": {}, + "source": [ + "## 5. 
Test Forward Pass" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "49633e2c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Output keys: dict_keys(['logit', 'y_prob', 'loss', 'y_true'])\n", + "Loss: 1.6047\n", + "Logits shape: torch.Size([8, 5])\n" + ] + } + ], + "source": [ + "# Fetch a batch and run a forward pass\n", + "batch = next(iter(train_loader))\n", + "\n", + "with torch.no_grad():\n", + " outputs = model(**batch)\n", + "\n", + "print(\"Output keys:\", outputs.keys())\n", + "print(f\"Loss: {outputs['loss'].item():.4f}\")\n", + "print(f\"Logits shape: {outputs['logit'].shape}\")" + ] + }, + { + "cell_type": "markdown", + "id": "4eaa5688", + "metadata": {}, + "source": [ + "## 7. Notes\n", + "\n", + "- Adjust `input_dim`, `T`, and model hyperparameters to match your real dataset.\n", + "- `mixLSTM` expects input tensors of shape `(batch, seq_len, input_dim)`.\n", + "- If you encounter device mismatches, ensure tensors and model are on the same `device`." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/mixlstm/mixlstm_test_ablation.ipynb b/examples/mixlstm/mixlstm_test_ablation.ipynb new file mode 100644 index 000000000..1686eb264 --- /dev/null +++ b/examples/mixlstm/mixlstm_test_ablation.ipynb @@ -0,0 +1,890 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "5b176866", + "metadata": {}, + "source": [ + "# mixLSTM Test Notebook\n", + "\n", + "This notebook demonstrates how to create a small sample dataset, initialize `mixLSTM`, and run a forward pass." + ] + }, + { + "cell_type": "markdown", + "id": "2cb149bb", + "metadata": {}, + "source": [ + "## 1. Environment Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "08ef7f3e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running on device: cpu\n" + ] + } + ], + "source": [ + "import random\n", + "import numpy as np\n", + "import torch\n", + "\n", + "from pyhealth.datasets import create_sample_dataset, get_dataloader\n", + "from pyhealth.datasets.splitter import split_by_sample\n", + "\n", + "SEED = 42\n", + "random.seed(SEED)\n", + "np.random.seed(SEED)\n", + "torch.manual_seed(SEED)\n", + "if torch.cuda.is_available():\n", + " torch.cuda.manual_seed_all(SEED)\n", + "\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "print(f\"Running on device: {device}\")" + ] + }, + { + "cell_type": "markdown", + "id": "1327ae03", + "metadata": {}, + "source": [ + "## 2. Create Sample Dataset\n", + "\n", + "Generates the synthetic per-timestep regression task from the MLHC 2019 paper.\n", + "- **Input `x`**: Sparse random matrix of shape `(T, input_dim)` with ~10% non-zero entries.\n", + "- **Target `y`**: At each timestep `t >= l` (`l = prev_used_timestamps`), `y[t]` is a weighted sum over the previous `l` timesteps and `input_dim` features. The weight distributions (`k_dist`, `d_dist`) drift slowly by `change_between_tasks` per step. For `t < l`, `y[t] = 1` (ignored during loss)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "d3fea9bc", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Created dataset with 1000 samples\n", + "Input schema: {'series': 'tensor'}\n", + "Output schema: {'y': 'tensor'}\n" + ] + } + ], + "source": [ + "# Dataset parameters (matching MLHC2019 synthetic task)\n", + "num_samples = 1000\n", + "T = 30 # sequence length\n", + "input_dim = 3\n", + "prev_used_timestamps = 10\n", + "change_between_tasks = 0.05\n", + "\n", + "def convert_distb(a):\n", + " a_min = min(a)\n", + " a_max = max(a)\n", + " a = (a-a_min)/(a_max-a_min)\n", + " a_sum = sum(a)\n", + " a = a/a_sum\n", + " return a\n", + "\n", + "\"\"\"Gen X\"\"\"\n", + "x_size = num_samples*T*input_dim\n", + "x=np.zeros(x_size)\n", + "x[np.random.choice(x_size, size=int(x_size/10), replace=False)]=np.random.uniform(size=int(x_size/10))*100\n", + "x=np.resize(x, (num_samples,T,input_dim))\n", + "\n", + "\"\"\"Gen y\"\"\"\n", + "k_dist = []\n", + "d_dist = []\n", + "for i in range(T):\n", + " if i=prev_used_timestamps:\n", + " y[:,i,0] = np.matmul(np.matmul(x[:,i-prev_used_timestamps:i,:],d_dist[i]), k_dist[i])\n", + "\n", + "# Build samples: per-timestep regression target matching original MLHC2019 synthetic setup\n", + "samples = [\n", + " {\n", + " \"patient_id\": f\"patient-{i}\",\n", + " \"visit_id\": \"visit-0\",\n", + " \"series\": x[i].tolist(),\n", + " \"y\": y[i].squeeze(-1).tolist(),\n", + " }\n", + " for i in range(num_samples)\n", + "]\n", + "\n", + "input_schema = {\"series\": \"tensor\"}\n", + "output_schema = {\"y\": \"tensor\"}\n", + "\n", + "dataset = create_sample_dataset(\n", + " samples=samples,\n", + " input_schema=input_schema,\n", + " output_schema=output_schema,\n", + " dataset_name=\"mixlstm_demo\",\n", + ")\n", + "\n", + "print(f\"Created dataset with {len(dataset)} samples\")\n", + "print(f\"Input schema: {dataset.input_schema}\")\n", + "print(f\"Output schema: {dataset.output_schema}\")" + ] + }, + { + "cell_type": "markdown", + "id": "f2a5248a", + "metadata": {}, + "source": [ + "## 3. Split Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "168a6299", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train: 700 samples\n", + "Val: 150 samples\n", + "Test: 150 samples\n" + ] + } + ], + "source": [ + "train_data, val_data, test_data = split_by_sample(dataset, [0.7, 0.15, 0.15], seed=SEED)\n", + "\n", + "print(f\"Train: {len(train_data)} samples\")\n", + "print(f\"Val: {len(val_data)} samples\")\n", + "print(f\"Test: {len(test_data)} samples\")\n", + "\n", + "train_loader = get_dataloader(train_data, batch_size=8, shuffle=True)\n", + "val_loader = get_dataloader(val_data, batch_size=8, shuffle=False)\n", + "test_loader = get_dataloader(test_data, batch_size=8, shuffle=False)" + ] + }, + { + "cell_type": "markdown", + "id": "63cad735", + "metadata": {}, + "source": [ + "## 4. 
Initialize `mixLSTM` Model" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "215bf9bd", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model created with 2021062 parameters\n", + "Mode: None, Per-timestep: True\n" + ] + } + ], + "source": [ + "from pyhealth.models import MixLSTM\n", + "\n", + "model = MixLSTM(dataset=dataset, num_experts=2, hidden_size=500,\n", + " prev_used_timestamps=prev_used_timestamps)\n", + "model = model.to(device)\n", + "print(f\"Model created with {sum(p.numel() for p in model.parameters())} parameters\")\n", + "print(f\"Mode: {model.mode}, Per-timestep: {model._per_timestep}\")" + ] + }, + { + "cell_type": "markdown", + "id": "e978eecb", + "metadata": {}, + "source": [ + "## 5. Test Forward Pass" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "49633e2c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Output keys: dict_keys(['logit', 'y_prob', 'loss', 'y_true'])\n", + "Loss (MSE): 61.5612\n", + "Logit shape: torch.Size([8, 30, 1])\n" + ] + } + ], + "source": [ + "# Fetch a batch and run a forward pass\n", + "batch = next(iter(train_loader))\n", + "\n", + "with torch.no_grad():\n", + " outputs = model(**batch)\n", + "\n", + "print(\"Output keys:\", outputs.keys())\n", + "print(f\"Loss (MSE): {outputs['loss'].item():.4f}\")\n", + "print(f\"Logit shape: {outputs['logit'].shape}\") # (batch, T, 1) for regression" + ] + }, + { + "cell_type": "markdown", + "id": "f2c97123", + "metadata": {}, + "source": [ + "## 6. Optional Training (Example)\n", + "\n", + "The following is an example sketch for training using PyHealth's `Trainer`. Uncomment to run training." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "b3bce217", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MixLSTM(\n", + " (model): ExampleMowLSTM(\n", + " (cells): ModuleList(\n", + " (0-29): 30 x MoW(\n", + " (experts): mowLSTM(\n", + " (dropouts): ModuleList(\n", + " (0): Dropout(p=0, inplace=False)\n", + " )\n", + " (h2o): MoO(\n", + " (experts): ModuleList(\n", + " (0-1): 2 x Linear(in_features=500, out_features=1, bias=True)\n", + " )\n", + " (gate): IDGate()\n", + " )\n", + " (rnns): ModuleList(\n", + " (0): mowLSTM_(\n", + " (input_weights): MoO(\n", + " (experts): ModuleList(\n", + " (0-1): 2 x Linear(in_features=3, out_features=2000, bias=True)\n", + " )\n", + " (gate): IDGate()\n", + " )\n", + " (hidden_weights): MoO(\n", + " (experts): ModuleList(\n", + " (0-1): 2 x Linear(in_features=500, out_features=2000, bias=True)\n", + " )\n", + " (gate): IDGate()\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (gate): NonAdaptiveGate()\n", + " )\n", + " )\n", + " (experts): mowLSTM(\n", + " (dropouts): ModuleList(\n", + " (0): Dropout(p=0, inplace=False)\n", + " )\n", + " (h2o): MoO(\n", + " (experts): ModuleList(\n", + " (0-1): 2 x Linear(in_features=500, out_features=1, bias=True)\n", + " )\n", + " (gate): IDGate()\n", + " )\n", + " (rnns): ModuleList(\n", + " (0): mowLSTM_(\n", + " (input_weights): MoO(\n", + " (experts): ModuleList(\n", + " (0-1): 2 x Linear(in_features=3, out_features=2000, bias=True)\n", + " )\n", + " (gate): IDGate()\n", + " )\n", + " (hidden_weights): MoO(\n", + " (experts): ModuleList(\n", + " (0-1): 2 x Linear(in_features=500, out_features=2000, bias=True)\n", + " )\n", + " (gate): IDGate()\n", + " )\n", + " )\n", + " )\n", + " )\n", + " )\n", + ")\n", + "Metrics: None\n", 
+ "Device: cpu\n", + "\n", + "Training:\n", + "Batch size: 8\n", + "Optimizer: \n", + "Optimizer params: {'lr': 0.001}\n", + "Weight decay: 0.0\n", + "Max grad norm: None\n", + "Val dataloader: \n", + "Monitor: loss\n", + "Monitor criterion: min\n", + "Epochs: 10\n", + "Patience: None\n", + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c862d88a46c6429a997d306d1c08b33e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Epoch 0 / 10: 0%| | 0/88 [00:00= prev_used_timestamps`, matching the original repo where the first `l` targets are trivially zero.\n", + "- **Classification support**: If `output_schema` is set to a standard label type (`\"multiclass\"`, `\"binary\"`, etc.), the model automatically switches to last-timestep classification using `get_loss_function()` and `prepare_y_prob()` from `BaseModel` — no flag needed.\n", + "- `MixLSTM` expects input tensors of shape `(batch, seq_len, input_dim)`." + ] + }, + { + "cell_type": "markdown", + "id": "a28ae383", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index 5233b1726..4089cbd24 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -44,3 +44,4 @@ from .sdoh import SdohClassifier from .medlink import MedLink from .unified_embedding import UnifiedMultimodalEmbeddingModel, SinusoidalTimeEmbedding +from .mixlstm import MixLSTM diff --git a/pyhealth/models/mixlstm.py b/pyhealth/models/mixlstm.py new file mode 100644 index 000000000..f63b1a001 --- /dev/null +++ b/pyhealth/models/mixlstm.py @@ -0,0 +1,843 @@ +"""MixLSTM model for clinical time-series prediction. + +Implementation of the mixLSTM architecture from Oh et al. 2020, +"Relaxed Parameter Sharing: Effectively Modeling Time-Varying +Relationships in Clinical Time-Series" (https://arxiv.org/abs/1906.02898). + +The key idea is to relax the parameter-sharing constraint of a standard +LSTM by maintaining K independent LSTM cells and combining their +parameters at each time step using learned mixing coefficients. This +allows the model to capture temporal conditional shift, where the +relationship between features and outcomes changes over time. +""" + +import math +from abc import ABC +from collections import abc +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable + +from pyhealth.datasets import SampleDataset +from pyhealth.models import BaseModel + + +class MLP(nn.Module): + """A simple multi-layer perceptron used as a building block. + + Args: + neuron_sizes: List of layer sizes, e.g. ``[in_dim, hidden, out_dim]``. + activation: Activation function class (default: ``nn.LeakyReLU``). + bias: Whether linear layers include a bias term. 
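+
+    Example (illustrative):
+        >>> mlp = MLP([16, 32, 4])
+        >>> mlp(torch.randn(8, 16)).shape
+        torch.Size([8, 4])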
+ """ + + def __init__( + self, neuron_sizes: List[int], + activation: Type[nn.Module] = nn.LeakyReLU, bias: bool = True + ) -> None: + super(MLP, self).__init__() + self.neuron_sizes = neuron_sizes + + layers = [] + for s0, s1 in zip(neuron_sizes[:-1], neuron_sizes[1:]): + layers.extend([ + nn.Linear(s0, s1, bias=bias), + activation() + ]) + + self.classifier = nn.Sequential(*layers[:-1]) + + def eval_forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """Run a forward pass in eval mode (ignores ``y``).""" + self.eval() + return self.forward(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Flatten the input and pass it through the MLP.""" + x = x.contiguous() + x = x.view(-1, self.neuron_sizes[0]) + return self.classifier(x) + + +############################ main models ################################## +class MoE(nn.Module): + """Abstract base class for mixture-of-experts modules. + + Supports specifying a set of experts and a gating function. Subclasses + must implement how experts are combined (see ``MoO`` and ``MoW``). + + Args: + experts: The expert modules to be combined. + gate: The gating function that produces mixing coefficients. + """ + + def __init__(self, experts: nn.Module, gate: "Gate") -> None: + super(MoE, self).__init__() + self.experts = experts + self.gate = gate + + +class MoO(MoE): + """Mixture of Outputs. + + Each expert produces an output independently, and the outputs are + combined via a weighted sum using coefficients from the gate. + + Args: + experts: The expert modules. + gate: The gating function. + bs_dim: Batch-size dimension of expert outputs (default: 1). + expert_dim: Expert dimension after stacking (default: 0). + """ + + def __init__(self, experts: nn.ModuleList, gate: "Gate", + bs_dim: int = 1, expert_dim: int = 0 + ) -> None: + super(MoO, self).__init__(experts, gate) + # this is for RNN architecture: bs_dim = 2 for RNN + self.bs_dim = bs_dim + self.expert_dim = expert_dim + + def combine( + self, o: List[torch.Tensor], coef: torch.Tensor + ) -> Union[torch.Tensor, List[torch.Tensor]]: + + """Combine expert outputs using the mixing coefficients. + + Args: + o: List of expert output tensors. + coef: Mixing coefficient tensor of shape + ``(batch, num_experts)``. + + Returns: + Weighted sum of expert outputs, or a list of such + sums if experts return multi-output tuples. + """ + + if isinstance(o[0], abc.Sequence): # account for multi_output setting + return [self.combine(o_, coef) for o_ in zip(*o)] + else: + o = torch.stack(o) + # reshape o to (_, bs, n_expert) b/c coef is (bs, n_expert) + o = o.transpose(self.expert_dim, -1) + o = o.transpose(self.bs_dim, -2) + + # change back + res = o * coef + res = res.transpose(self.expert_dim, -1) + res = res.transpose(self.bs_dim, -2) + return res.sum(0) + + def forward( + self, x: torch.Tensor, coef: Optional[torch.Tensor] = None + ) -> Union[torch.Tensor, List[torch.Tensor]]: + + """Compute each expert's output and combine them. + + Args: + x: Input tensor. + coef: Optional pre-computed mixing coefficients. + + Returns: + Combined expert output tensor. + """ + + coef = self.gate(x, coef) # (bs, n_expert) or n_expert + self.last_coef = coef + o = [expert(x) for expert in self.experts] + return self.combine(o, coef) + + +class MoW(MoE): + """Mixture of Weights. + + Instead of combining expert outputs, this module combines expert + parameters before the forward pass, effectively producing a single + assembled expert per time step. 
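+
+    Note: unlike :class:`MoO`, whose ``experts`` attribute is an
+    ``nn.ModuleList`` of independent modules, ``MoW`` expects ``experts`` to
+    be a single module (e.g. :class:`mowLSTM`) that receives the mixing
+    coefficients and assembles the mixed parameters internally.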
+ """ + + def forward( + self, x: Any, coef: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + + """Run the assembled expert on the input. + + Args: + x: Input tensor (or tuple of tensor and hidden state). + coef: Optional pre-computed mixing coefficients. + + Returns: + Tuple of output tensor and new hidden state. + """ + + # assume experts has already been assembled + coef = self.gate(x, coef) + self.last_coef = coef + return self.experts(x, coef) + + +################## sample gating functions for get_coefficients ########### +class Gate(ABC, nn.Module): + """Abstract base class for gating functions.""" + + def forward( + self, x: Any, coef: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Produce mixing coefficients from the input. + + Args: + x: Input data (format depends on the subclass). + coef: Optional pre-computed coefficient tensor. + + Returns: + Mixing coefficient tensor. + + Raises: + NotImplementedError: Always; subclasses must override. + """ + raise NotImplementedError() + + +class AdaptiveLSTMGate(Gate): + """A gate that computes mixing coefficients from the LSTM hidden state. + + Args: + input_size: Size of the hidden state used as input. + num_experts: Number of experts in the mixture. + normalize: If True, apply softmax to the coefficients. + """ + + def __init__( + self, input_size: int, num_experts: int, normalize: bool = False + ) -> None: + super(self.__class__, self).__init__() + self.forward_function = MLP([input_size, num_experts]) + self.normalize = normalize + + def forward( + self, x: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + coef: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Produce mixing coefficients from the hidden state. + + Args: + x: Tuple of ``(input, (h, c))`` where ``h`` is the + hidden state used to compute coefficients. + coef: Ignored; kept for interface compatibility. + + Returns: + Mixing coefficients of shape ``(batch, num_experts)``. + """ + + x, (h, c) = x # h (_, bs, d) + o = self.forward_function(h.transpose(0, 1)) # (bs, num_experts) + if self.normalize: + return nn.functional.softmax(o, 1) + else: + return o + + +class NonAdaptiveGate(Gate): + """A gate with learnable (or fixed) coefficients that do not depend on x. + + Args: + num_experts: Number of experts. + coef: Optional initial coefficient tensor. If None, randomly init. + fixed: If True, coefficients are not trainable. + normalize: If True, apply softmax to the coefficients. + """ + + def __init__( + self, num_experts: int, coef: Optional[torch.Tensor] = None, + fixed: bool = False, normalize: bool = False + ) -> None: + super(self.__class__, self).__init__() + self.normalize = normalize + if coef is None: # initialization + coef = torch.ones(num_experts) + nn.init.uniform_(coef) + if fixed: + coef = nn.Parameter(coef, requires_grad=False) + else: + coef = nn.Parameter(coef) + + self.coefficients = coef + + def forward( + self, x: Any, coef: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Return the (optionally normalized) mixing coefficients. + + Args: + x: Ignored; kept for interface compatibility. + coef: Ignored; kept for interface compatibility. + + Returns: + Mixing coefficient tensor of shape ``(num_experts,)``. 
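+
+        Example (illustrative):
+            >>> gate = NonAdaptiveGate(num_experts=2, normalize=True)
+            >>> coef = gate(x=None)
+            >>> bool(torch.isclose(coef.sum(), torch.tensor(1.0)))
+            True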
+ """ + + if self.normalize: + return nn.functional.softmax(self.coefficients, 0) + else: + return self.coefficients + + +class IDGate(Gate): + """Identity gate that passes through a previous coefficient unchanged.""" + + def forward(self, x: Any, coef: torch.Tensor) -> torch.Tensor: + """Return the coefficient that was passed in. + + Args: + x: Ignored. + coef: Pre-computed mixing coefficients. + + Returns: + ``coef`` unchanged. + """ + + return coef + + +################ time series example models ################ +def moo_linear( + in_features: int, out_features: int, + num_experts: int, bs_dim: int = 1, expert_dim: int = 0 + ) -> MoO: + """Create a MoO over a set of linear layers with tied shape. + + Args: + in_features: Input feature size. + out_features: Output feature size. + num_experts: Number of expert linear layers. + bs_dim: Batch-size dimension (see ``MoO``). + expert_dim: Expert dimension (see ``MoO``). + + Returns: + A ``MoO`` module wrapping ``num_experts`` linear layers with an + identity gate. + """ + # repeat a linear model for self.num_experts times + experts = nn.ModuleList() + for _ in range(num_experts): + experts.append(nn.Linear(in_features, out_features)) + + # tie weights later + return MoO(experts, IDGate(), bs_dim=bs_dim, expert_dim=expert_dim) + + +class mowLSTM_(nn.Module): + """Internal helper implementing one layer of the mixture-of-weights LSTM. + + Applies a per-time-step mixture of LSTM cells by combining the input + and hidden weight matrices across experts. + + Args: + input_size: Input feature dimension. + hidden_size: Hidden state dimension. + num_experts: Number of expert cells to mix (K). + batch_first: If True, expects input shape (batch, seq_len, dim). + """ + + def __init__(self, input_size: int, hidden_size: int, num_experts: int = 2, + batch_first: bool = False) -> None: + + super(mowLSTM_, self).__init__() + + self.input_size = input_size + self.hidden_size = hidden_size + self.num_experts = num_experts + self.batch_first = batch_first + + # build cell + self.input_weights = moo_linear(input_size, 4 * hidden_size, + self.num_experts, bs_dim=2) # i,f,g,o + self.hidden_weights = moo_linear(hidden_size, 4 * hidden_size, + self.num_experts, bs_dim=2) + # init same as pytorch version + stdv = 1.0 / math.sqrt(self.hidden_size) + for m in self.input_weights.experts: + for name, weight in m.named_parameters(): + nn.init.uniform_(weight, -stdv, stdv) + for m in self.hidden_weights.experts: + for name, weight in m.named_parameters(): + nn.init.uniform_(weight, -stdv, stdv) + + + def rnn_step( + self, x: torch.Tensor, hidden: Tuple[torch.Tensor, torch.Tensor], + coef: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Run a single LSTM step with mixed expert parameters. + + Args: + x: Input tensor for this time step, shape + ``(1, batch, input_size)``. + hidden: Tuple ``(h, c)`` of previous hidden and cell + states. + coef: Mixing coefficients for the experts. + + Returns: + Tuple ``(h, c)`` of updated hidden and cell states. 
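+
+        Note:
+            The update follows the standard LSTM cell equations, except
+            that the weight matrices are the coefficient-weighted
+            mixtures of the ``num_experts`` expert parameters (biases
+            omitted for brevity)::
+
+                i_t = sigmoid(W_i x_t + U_i h_{t-1})
+                f_t = sigmoid(W_f x_t + U_f h_{t-1})
+                g_t = tanh(W_g x_t + U_g h_{t-1})
+                o_t = sigmoid(W_o x_t + U_o h_{t-1})
+                c_t = f_t * c_{t-1} + i_t * g_t
+                h_t = o_t * tanh(c_t)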
+ """ + + bs = x.shape[1] + h, c = hidden + gates = self.input_weights(x, coef) + self.hidden_weights(h, coef) + + ingate, forgetgate, cellgate, outgate = gates.view(bs, -1).chunk(4, 1) + ingate = torch.sigmoid(ingate) + forgetgate = torch.sigmoid(forgetgate) + cellgate = torch.tanh(cellgate) + outgate = torch.sigmoid(outgate) + + c = forgetgate * c + ingate * cellgate + h = outgate * torch.tanh(c) # maybe use layer norm here as well + return h, c + + def forward( + self, x: torch.Tensor, + hidden: Tuple[torch.Tensor, torch.Tensor], coef: torch.Tensor + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + + """Run the mixture LSTM over a full sequence. + + Args: + x: Input tensor of shape + ``(seq_len, batch, input_size)`` (or transposed + if ``batch_first``). + hidden: Tuple ``(h, c)`` of initial hidden/cell states. + coef: Mixing coefficients for the experts. + + Returns: + Tuple of ``(output, (h, c))`` where output has shape + ``(seq_len, batch, hidden_size)``. + """ + + if self.batch_first: # change to seq_len first + x = x.transpose(0, 1) + + seq_len = x.shape[0] + output = [] + for t in range(seq_len): + hidden = self.rnn_step(x[t].unsqueeze(0), hidden, coef) + output.append(hidden[0]) # seq_len x (_, bs, d) + + output = torch.cat(output, 0) + return output, hidden + + +class mowLSTM(nn.Module): + """Stacked mixture-of-weights LSTM used internally by ``MixLSTM``. + + Handles multi-layer stacking, dropout, and the final output projection. + + Args: + input_size: Input feature size. + hidden_size: Hidden state size. + num_classes: Output dimension of the final projection. + num_experts: Number of expert cells to mix (K). + num_layers: Number of stacked LSTM layers. + batch_first: If True, expects input shape (batch, seq_len, dim). + dropout: Dropout probability between layers. + bidirectional: Whether to use a bidirectional LSTM. + activation: Optional activation applied to the final output. + """ + + def __init__( + self, input_size: int, hidden_size: int, num_classes: int, + num_experts: int = 2, num_layers: int = 1, + batch_first: bool = False, dropout: float = 0, + bidirectional: bool = False, activation: Optional[nn.Module] = None + ) -> None: + + super(mowLSTM, self).__init__() + + self.input_size = input_size + self.hidden_size = hidden_size + self.num_classes = num_classes + self.num_experts = num_experts + self.num_layers = num_layers + self.num_directions = 2 if bidirectional else 1 + self.batch_first = batch_first + self.dropouts = nn.ModuleList() + + self.h2o = moo_linear(self.num_directions * self.hidden_size, + self.num_classes, self.num_experts, bs_dim=2) + + if activation: + self.activation = activation + else: + self.activation = lambda x: x + + self.rnns = nn.ModuleList() + for i in range(num_layers * self.num_directions): + input_size = input_size if i == 0 else hidden_size + self.rnns.append(mowLSTM_(input_size, hidden_size, num_experts, + batch_first)) + self.dropouts.append(nn.Dropout(p=dropout)) + + def forward( + self, x: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + coef: torch.Tensor + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + + """Forward pass through the stacked mixture LSTM. + + Args: + x: Tuple of ``(input_tensor, (h, c))``. + coef: Mixing coefficients for the experts. + + Returns: + Tuple of ``(output, (h, c))`` where output has shape + ``(seq_len, batch, num_classes)``. 
+ """ + + x, hidden = x + self.last_coef = coef + + h, c = hidden + hs, cs = [], [] + for i in range(self.num_layers): + if i != 0 and i != (self.num_layers - 1): + x = self.dropouts[i](x) # waste 1 dropout but no problem + x, hidden = self.rnns[i](x, (h[i].unsqueeze(0), + c[i].unsqueeze(0)), coef) + hs.append(hidden[0]) + cs.append(hidden[1]) + + h = torch.cat(hs, 0) + c = torch.cat(cs, 0) + o = x + # run through prediction layer: o: (seq_len, bs, d) + o = self.dropouts[0](o) + o = self.h2o(o, coef) + o = self.activation(o) + + return o, (h, c) + + +class ExampleMowLSTM(nn.Module): + """Wrapper that instantiates a mixture LSTM with per-time-step gates. + + For each of the ``t`` time steps, a separate ``NonAdaptiveGate`` is + created so that the mixing coefficients can vary over time. All gates + share the same underlying experts. + + Args: + input_size: Input feature size. + hidden_size: Hidden state size. + num_classes: Output dimension. + num_layers: Number of stacked LSTM layers. + num_directions: 1 (unidirectional) or 2 (bidirectional). + dropout: Dropout probability. + activation: Optional output activation. + """ + + def __init__(self, input_size: int, hidden_size: int, num_classes: int, + num_layers: int = 1, num_directions: int = 1, dropout: float = 0, activation: Optional[nn.Module] = None) -> None: + super(ExampleMowLSTM, self).__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.num_classes = num_classes + self.num_layers = num_layers + self.num_directions = num_directions + self.dropout = dropout + self.activation = activation + + def setKT(self, k: int, t: int) -> None: + """Configure the model for ``k`` experts and ``t`` time steps. + + Args: + k: Number of expert cells to mix. + t: Maximum number of time steps; one gate is created + per step. + + Raises: + ValueError: If ``k < 1`` or ``t < 1``. + """ + self.k = k + self.T = t + self.cells = nn.ModuleList() + + experts = mowLSTM(self.input_size, self.hidden_size, + self.num_classes, num_experts=self.k, + num_layers=self.num_layers, dropout=self.dropout, + bidirectional=(self.num_directions == 2), + activation=self.activation) + self.experts = experts + + for _ in range(t): + gate = NonAdaptiveGate(self.k, normalize=True) + self.cells.append(MoW(experts, gate)) + + def forward( + self, x: torch.Tensor, hidden: Tuple[torch.Tensor, torch.Tensor] + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + + """Run the mixture LSTM step-by-step using per-step gates. + + Args: + x: Input tensor of shape + ``(seq_len, batch, input_size)``. + hidden: Tuple ``(h, c)`` of initial hidden/cell states. + + Returns: + Tuple of ``(output, (h, c))`` where output has shape + ``(seq_len, batch, num_classes)``. + """ + + seq_len, bs, _ = x.shape + o = [] + for t in range(seq_len): + o_, hidden = self.cells[t]((x[t].view(1, bs, -1), hidden)) + o.append(o_) + + o = torch.cat(o, 0) # (seq_len, bs, d) + return o, hidden + + +def orthogonal(shape: Tuple[int, ...]) -> np.ndarray: + """Generate an orthogonal matrix of the given shape via SVD. + + Args: + shape: Target shape for the orthogonal matrix. + + Returns: + A numpy array with orthogonal rows/columns. 
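+
+    Example:
+        Quick illustrative check for a square shape:
+
+        >>> q = orthogonal((4, 4))
+        >>> bool(np.allclose(q.T @ q, np.eye(4)))
+        True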
+ """ + + flat_shape = (int(shape[0]), int(np.prod(shape[1:]))) + a = np.random.normal(0.0, 1.0, flat_shape) + u, _, v = np.linalg.svd(a, full_matrices=False) + q = u if u.shape == flat_shape else v + return q.reshape(shape) + + +def lstm_ortho_initializer(shape: Tuple[int, ...], scale: float = 1.0) -> np.ndarray: + """Initialize LSTM weights with orthogonal blocks for each of the 4 gates. + + Args: + shape: Target shape where the second dimension must be divisible by 4. + scale: Scalar to multiply the orthogonal matrices by. + + Returns: + A numpy array of the requested shape. + """ + size_x = shape[0] + size_h = int(shape[1] / 4) # assumes lstm. + t = np.zeros(shape) + t[:, :size_h] = orthogonal([size_x, size_h]) * scale + t[:, size_h:size_h * 2] = orthogonal([size_x, size_h]) * scale + t[:, size_h * 2:size_h * 3] = orthogonal([size_x, size_h]) * scale + t[:, size_h * 3:] = orthogonal([size_x, size_h]) * scale + return t + + +class MixLSTM(BaseModel): + """Mixture-of-LSTMs model for clinical time-series prediction. + + Implements the mixLSTM architecture from Oh et al. 2020 for handling + temporal conditional shift: settings where the relationship between + input features and the target changes over time. Instead of sharing a + single set of LSTM parameters across all time steps, MixLSTM maintains + ``num_experts`` independent LSTM cells and, at every time step, + computes a learned convex combination of their parameters using mixing + coefficients constrained to the simplex. This enables smooth + transitions between different temporal dynamics without hard segment + boundaries. + + The model inherits from PyHealth's ``BaseModel`` and infers the input + dimension and sequence length from the ``SampleDataset`` passed at + construction time, so it can be used with any existing PyHealth task + whose input is a time-series tensor. + + The model supports two operating modes that are chosen automatically + based on the dataset's output schema: + + * Standard classification (``binary``, ``multiclass``, ``multilabel``, + or ``regression``): predictions are taken from the final time step of + the sequence and the appropriate PyHealth loss function is applied. + * Per-timestep regression (when the output schema is a raw ``tensor``): + the model outputs a value at every time step and the MSE loss is + computed over timesteps beginning at ``prev_used_timestamps``. This + reproduces the synthetic copy-memory task described in Section 4.1 + of the paper. + + Paper: + Oh et al. 2020, "Relaxed Parameter Sharing: Effectively Modeling + Time-Varying Relationships in Clinical Time-Series." + https://arxiv.org/abs/1906.02898 + + Args: + dataset: A ``SampleDataset`` used to infer input feature size and + sequence length from the first sample. + num_experts: Number of expert LSTM cells to mix (``K`` in the paper). + Higher values give the model more flexibility to vary parameters + over time at the cost of more parameters. Defaults to 2. + hidden_size: Size of the LSTM hidden state. Defaults to 100. + prev_used_timestamps: For the per-timestep regression mode, the + index of the first time step at which the loss is computed. + Earlier time steps are skipped because their targets are + trivially defined in the synthetic task. Ignored in the + standard classification mode. Defaults to 0. + + Attributes: + input_size: Inferred input feature dimension. + time_steps: Inferred sequence length. + hidden_size: LSTM hidden state size. 
+ _per_timestep: ``True`` when the model is running in per-timestep + regression mode, ``False`` when it is running in standard + classification mode. + + Example: + >>> from pyhealth.datasets import create_sample_dataset + >>> from pyhealth.models import MixLSTM + >>> samples = [ + ... { + ... "patient_id": f"p-{i}", + ... "visit_id": "v-0", + ... "series": torch.randn(48, 76).numpy().tolist(), + ... "label": int(i % 2), + ... } + ... for i in range(100) + ... ] + >>> dataset = create_sample_dataset( + ... samples=samples, + ... input_schema={"series": "tensor"}, + ... output_schema={"label": "multiclass"}, + ... dataset_name="demo", + ... ) + >>> model = MixLSTM(dataset=dataset, num_experts=2, hidden_size=64) + """ + + def __init__(self, dataset: SampleDataset, num_experts: int = 2, hidden_size: int = 100, + prev_used_timestamps: int = 0) -> None: + super(MixLSTM, self).__init__(dataset) + + # Identify primary input key and infer shape + input_keys = list(dataset.input_processors.keys()) + self.input_key = input_keys[0] + self.label_key = self.label_keys[0] if self.label_keys else None + + sample = dataset[0] + val = sample[self.input_key] + if isinstance(val, (list, tuple)): + for item in val: + if torch.is_tensor(item) or isinstance( + item, (list, tuple, np.ndarray)): + val = item + break + if torch.is_tensor(val): + input_dim = val.shape[-1] if val.dim() >= 2 else 1 + T = val.shape[0] + else: + arr = np.array(val) + input_dim = arr.shape[-1] if arr.ndim >= 2 else 1 + T = len(val) + + self.input_size = int(input_dim) + self.time_steps = int(T) + self.prev_used_timestamps = prev_used_timestamps + + # Detect per-timestep regression: output target is a tensor, not a + # standard label type. In that case self.mode is None / unrecognised. + self._per_timestep = ( + self.mode not in ("binary", "multiclass", "multilabel", "regression") + ) + + if self._per_timestep: + num_classes = 1 # predict one scalar per timestep + else: + num_classes = int(self.get_output_size()) + + self.model = ExampleMowLSTM(self.input_size, hidden_size, + num_classes, num_layers=1, + num_directions=1, dropout=0, + activation=None) + + self.num_layers = 1 + self.num_directions = 1 + self.hidden_size = hidden_size + self.model.setKT(num_experts, self.time_steps) + + def forward(self, **kwargs: Any) -> Dict[str, torch.Tensor]: + """Run a forward pass. + + Expects the input tensor under ``kwargs[self.input_key]`` with + shape ``(batch, seq_len, input_dim)``. If a label tensor is also + provided under ``self.label_key``, the appropriate loss is + computed and returned. + + Args: + **kwargs: Batch dictionary, typically produced by a PyHealth + DataLoader and passed by ``Trainer`` as ``model(**batch)``. + + Returns: + A dictionary with the following keys, where shapes depend on + which mode the model is operating in: + + * Classification mode (``_per_timestep = False``): + - ``logit``: ``(batch, num_classes)`` from the final step. + - ``y_prob``: Probabilities produced by ``prepare_y_prob``. + - ``loss`` (optional): PyHealth's standard loss for the task. + - ``y_true`` (optional): Ground-truth labels. + + * Per-timestep regression mode (``_per_timestep = True``): + - ``logit``: ``(batch, seq_len, 1)`` — one prediction per + time step. + - ``y_prob``: Same tensor as ``logit``. + - ``loss`` (optional): MSE computed over time steps from + ``prev_used_timestamps`` onward. + - ``y_true`` (optional): Ground-truth target tensor. 
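+
+        Example::
+
+            # Sketch of a manual call outside the Trainer; assumes the
+            # ``dataset`` / ``model`` pair from the class-level example.
+            from pyhealth.datasets import get_dataloader
+
+            loader = get_dataloader(dataset, batch_size=4, shuffle=False)
+            batch = next(iter(loader))
+            output = model(**batch)
+            # -> keys: "logit", "y_prob", "loss", "y_true"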
+ """ + x = kwargs.get(self.input_key) + + # (bs, seq_len, d) => (seq_len, bs, d) + x = x.permute(1, 0, 2) + batch_size = x.size(1) + device = self.device + h = torch.zeros(self.num_layers * self.num_directions, + batch_size, self.hidden_size, device=device) + c = torch.zeros(self.num_layers * self.num_directions, + batch_size, self.hidden_size, device=device) + + outputs, _ = self.model(x, (h, c)) + # (seq_len, bs, out) => (bs, seq_len, out) + logits_seq = outputs.permute(1, 0, 2) + + if self._per_timestep: + # --- Per-timestep regression mode --- + results = {"logit": logits_seq, "y_prob": logits_seq} + if self.label_key and self.label_key in kwargs: + y_true = kwargs[self.label_key].to(device) + if y_true.dim() == 2: + y_true = y_true.unsqueeze(-1) + l = self.prev_used_timestamps + pred = logits_seq[:, l:, :].contiguous() + target = y_true[:, l:, :].contiguous() + loss = F.mse_loss(pred.view(-1, pred.size(-1)), + target.view(-1, target.size(-1))) + results["loss"] = loss + results["y_true"] = y_true + return results + + logits = logits_seq[:, -1, :] + y_prob = self.prepare_y_prob(logits) + results = {"logit": logits, "y_prob": y_prob} + + if self.label_key and self.label_key in kwargs: + y_true = kwargs[self.label_key].to(device) + loss = self.get_loss_function()(logits, y_true) + results["loss"] = loss + results["y_true"] = y_true + + return results + + def after_backward(self) -> None: + """Hook called after backward(); no-op for this model.""" + pass diff --git a/pyproject.toml b/pyproject.toml index 934d4f1bb..f0da85254 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dependencies = [ "more-itertools~=10.8.0", "einops>=0.8.0", "linear-attention-transformer>=0.19.1", + "seaborn~=0.13.2", ] license = "BSD-3-Clause" license-files = ["LICENSE.md"] diff --git a/tests/test_mixlstm.py b/tests/test_mixlstm.py new file mode 100644 index 000000000..3d9e885d0 --- /dev/null +++ b/tests/test_mixlstm.py @@ -0,0 +1,279 @@ +import unittest +import tempfile +import shutil +import numpy as np +import torch + +from pyhealth.datasets import create_sample_dataset, get_dataloader +from pyhealth.models import MixLSTM + + +class TestMixLSTMRegression(unittest.TestCase): + """Test MixLSTM in per-timestep regression mode (MLHC2019 synthetic task).""" + + def setUp(self) -> None: + """Set up small synthetic regression dataset and model.""" + self.tmp_dir = tempfile.mkdtemp() + + T = 10 + input_dim = 2 + prev_used = 3 + n = 20 + + rng = np.random.RandomState(42) + x = np.zeros((n, T, input_dim)) + nz = int(n * T * input_dim * 0.1) + idx = rng.choice(n * T * input_dim, size=nz, replace=False) + x.flat[idx] = rng.uniform(size=nz) * 10 + y = np.zeros((n, T)) + for t in range(prev_used, T): + y[:, t] = x[:, t - prev_used:t, :].sum(axis=(1, 2)) + + self.samples = [ + { + "patient_id": f"p-{i}", + "visit_id": "v-0", + "series": x[i].tolist(), + "y": y[i].tolist(), + } + for i in range(n) + ] + + self.dataset = create_sample_dataset( + samples=self.samples, + input_schema={"series": "tensor"}, + output_schema={"y": "tensor"}, + dataset_name="test_mixlstm_reg", + ) + + self.model = MixLSTM( + dataset=self.dataset, + num_experts=2, + hidden_size=16, + prev_used_timestamps=prev_used, + ) + self.batch_size = 4 + + def tearDown(self) -> None: + shutil.rmtree(self.tmp_dir, ignore_errors=True) + + def test_instantiation(self) -> None: + """Test that model initializes with correct attributes.""" + self.assertIsInstance(self.model, MixLSTM) + self.assertTrue(self.model._per_timestep) + 
self.assertEqual(self.model.input_size, 2) + self.assertEqual(self.model.time_steps, 10) + self.assertEqual(self.model.hidden_size, 16) + self.assertEqual(self.model.prev_used_timestamps, 3) + + def test_forward_output_keys(self) -> None: + """Test that forward returns expected keys for regression.""" + loader = get_dataloader(self.dataset, batch_size=self.batch_size, shuffle=False) + batch = next(iter(loader)) + + with torch.no_grad(): + ret = self.model(**batch) + + self.assertIn("loss", ret) + self.assertIn("logit", ret) + self.assertIn("y_prob", ret) + self.assertIn("y_true", ret) + + def test_forward_output_shapes(self) -> None: + """Test output tensor shapes for per-timestep regression.""" + loader = get_dataloader(self.dataset, batch_size=self.batch_size, shuffle=False) + batch = next(iter(loader)) + + with torch.no_grad(): + ret = self.model(**batch) + + bs = ret["logit"].shape[0] + # logit: (batch, T, 1) + self.assertEqual(ret["logit"].shape, (bs, 10, 1)) + # y_prob same as logit for regression + self.assertEqual(ret["y_prob"].shape, (bs, 10, 1)) + # y_true: (batch, T, 1) + self.assertEqual(ret["y_true"].shape, (bs, 10, 1)) + # loss is scalar + self.assertEqual(ret["loss"].dim(), 0) + + def test_forward_no_labels(self) -> None: + """Test forward without labels returns logit/y_prob but no loss.""" + loader = get_dataloader(self.dataset, batch_size=self.batch_size, shuffle=False) + batch = next(iter(loader)) + # Remove the label key + del batch["y"] + + with torch.no_grad(): + ret = self.model(**batch) + + self.assertIn("logit", ret) + self.assertIn("y_prob", ret) + self.assertNotIn("loss", ret) + self.assertNotIn("y_true", ret) + + def test_backward_gradients(self) -> None: + """Test that loss.backward() produces gradients on all trainable params.""" + loader = get_dataloader(self.dataset, batch_size=self.batch_size, shuffle=False) + batch = next(iter(loader)) + + ret = self.model(**batch) + ret["loss"].backward() + + has_gradient = any( + p.requires_grad and p.grad is not None + for p in self.model.parameters() + ) + self.assertTrue(has_gradient, "No parameters received gradients") + + def test_loss_is_finite(self) -> None: + """Test that loss is finite (not NaN or Inf).""" + loader = get_dataloader(self.dataset, batch_size=self.batch_size, shuffle=False) + batch = next(iter(loader)) + + with torch.no_grad(): + ret = self.model(**batch) + + self.assertTrue(torch.isfinite(ret["loss"]).item(), "Loss is not finite") + + def test_custom_hyperparameters(self) -> None: + """Test model with different num_experts and hidden_size.""" + model = MixLSTM( + dataset=self.dataset, + num_experts=4, + hidden_size=32, + prev_used_timestamps=3, + ) + loader = get_dataloader(self.dataset, batch_size=self.batch_size, shuffle=False) + batch = next(iter(loader)) + + with torch.no_grad(): + ret = model(**batch) + + self.assertIn("loss", ret) + self.assertEqual(ret["logit"].shape[2], 1) + + +class TestMixLSTMClassification(unittest.TestCase): + """Test MixLSTM in classification mode (standard PyHealth label task).""" + + def setUp(self) -> None: + """Set up small synthetic classification dataset and model.""" + self.tmp_dir = tempfile.mkdtemp() + + T = 8 + input_dim = 3 + n = 16 + n_classes = 3 + + rng = np.random.RandomState(0) + self.samples = [ + { + "patient_id": f"p-{i}", + "visit_id": "v-0", + "series": rng.randn(T, input_dim).tolist(), + "label": int(i % n_classes), + } + for i in range(n) + ] + + self.dataset = create_sample_dataset( + samples=self.samples, + input_schema={"series": "tensor"}, + 
output_schema={"label": "multiclass"}, + dataset_name="test_mixlstm_cls", + ) + + self.model = MixLSTM( + dataset=self.dataset, + num_experts=2, + hidden_size=16, + ) + self.batch_size = 4 + + def tearDown(self) -> None: + shutil.rmtree(self.tmp_dir, ignore_errors=True) + + def test_instantiation(self) -> None: + """Test that classification model initializes correctly.""" + self.assertIsInstance(self.model, MixLSTM) + self.assertFalse(self.model._per_timestep) + self.assertEqual(self.model.mode, "multiclass") + + def test_forward_output_keys(self) -> None: + """Test that forward returns expected keys for classification.""" + loader = get_dataloader(self.dataset, batch_size=self.batch_size, shuffle=False) + batch = next(iter(loader)) + + with torch.no_grad(): + ret = self.model(**batch) + + self.assertIn("loss", ret) + self.assertIn("logit", ret) + self.assertIn("y_prob", ret) + self.assertIn("y_true", ret) + + def test_forward_output_shapes(self) -> None: + """Test output tensor shapes for classification (last timestep).""" + loader = get_dataloader(self.dataset, batch_size=self.batch_size, shuffle=False) + batch = next(iter(loader)) + + with torch.no_grad(): + ret = self.model(**batch) + + bs = ret["logit"].shape[0] + n_classes = 3 + # logit: (batch, n_classes) + self.assertEqual(ret["logit"].shape, (bs, n_classes)) + # y_prob: (batch, n_classes) — softmax output + self.assertEqual(ret["y_prob"].shape, (bs, n_classes)) + # y_true: (batch,) + self.assertEqual(ret["y_true"].shape[0], bs) + # loss is scalar + self.assertEqual(ret["loss"].dim(), 0) + + def test_forward_no_labels(self) -> None: + """Test forward without labels returns logit/y_prob but no loss.""" + loader = get_dataloader(self.dataset, batch_size=self.batch_size, shuffle=False) + batch = next(iter(loader)) + del batch["label"] + + with torch.no_grad(): + ret = self.model(**batch) + + self.assertIn("logit", ret) + self.assertIn("y_prob", ret) + self.assertNotIn("loss", ret) + self.assertNotIn("y_true", ret) + + def test_backward_gradients(self) -> None: + """Test that loss.backward() produces gradients.""" + loader = get_dataloader(self.dataset, batch_size=self.batch_size, shuffle=False) + batch = next(iter(loader)) + + ret = self.model(**batch) + ret["loss"].backward() + + has_gradient = any( + p.requires_grad and p.grad is not None + for p in self.model.parameters() + ) + self.assertTrue(has_gradient, "No parameters received gradients") + + def test_y_prob_sums_to_one(self) -> None: + """Test that y_prob (softmax) sums to ~1 for each sample.""" + loader = get_dataloader(self.dataset, batch_size=self.batch_size, shuffle=False) + batch = next(iter(loader)) + + with torch.no_grad(): + ret = self.model(**batch) + + prob_sums = ret["y_prob"].sum(dim=1) + self.assertTrue( + torch.allclose(prob_sums, torch.ones_like(prob_sums), atol=1e-5), + "y_prob rows do not sum to 1", + ) + + +if __name__ == "__main__": + unittest.main()