From e77e542bd6ba031ffde3de4a811d9e3a84548ae5 Mon Sep 17 00:00:00 2001 From: siddesh2-sys Date: Mon, 13 Apr 2026 20:55:38 -0400 Subject: [PATCH 01/19] initial implementation --- pyhealth/models/mixlstm.py | 415 +++++++++++++++++++++++++++++++++++++ 1 file changed, 415 insertions(+) create mode 100644 pyhealth/models/mixlstm.py diff --git a/pyhealth/models/mixlstm.py b/pyhealth/models/mixlstm.py new file mode 100644 index 000000000..846ff2569 --- /dev/null +++ b/pyhealth/models/mixlstm.py @@ -0,0 +1,415 @@ +import torch, math +import torch.nn as nn +from torch.autograd import Variable +import torch.nn.functional as F +import numpy as np +from collections import abc +from abc import ABC + +from pyhealth.datasets import SampleDataset +from pyhealth.models import BaseModel + +class MLP(nn.Module): + + def __init__(self, neuron_sizes, activation=nn.LeakyReLU, bias=True): + super(MLP, self).__init__() + self.neuron_sizes = neuron_sizes + + layers = [] + for s0, s1 in zip(neuron_sizes[:-1], neuron_sizes[1:]): + layers.extend([ + nn.Linear(s0, s1, bias=bias), + activation() + ]) + + self.classifier = nn.Sequential(*layers[:-1]) + + def eval_forward(self, x, y): + self.eval() + return self.forward(x) + + def forward(self, x): + x = x.contiguous() + x = x.view(-1, self.neuron_sizes[0]) + return self.classifier(x) + +############################ main models ################################## +class MoE(nn.Module): + + ''' + This is a abstract base class for mixture of experts + + it supports: + a) specifiying experts + b) specifying the gating function (having parameter or not) + + it needs combining functions (either MoO or MoE) + ''' + + def __init__(self, experts, gate): + super(MoE, self).__init__() + self.experts = experts + self.gate = gate + +class MoO(MoE): + + ''' + mixture of outputs + ''' + def __init__(self, experts, gate, bs_dim=1, expert_dim=0): + super(MoO, self).__init__(experts, gate) + # this is for RNN architecture: bs_dim = 2 for RNN + self.bs_dim = bs_dim 
+ self.expert_dim = expert_dim + + def combine(self, o, coef): + + if isinstance(o[0], abc.Sequence): # account for multi_output setting + return [self.combine(o_, coef) for o_ in zip(*o)] + else: + o = torch.stack(o) + # reshape o to (_, bs, n_expert) b/c coef is (bs, n_expert) + o = o.transpose(self.expert_dim, -1) + o = o.transpose(self.bs_dim, -2) + + # change back + res = o * coef + res = res.transpose(self.expert_dim, -1) + res = res.transpose(self.bs_dim, -2) + return res.sum(0) + + def forward(self, x, coef=None): # coef is previous coefficient: for IDGate + coef = self.gate(x, coef) # (bs, n_expert) or n_expert + self.last_coef = coef + o = [expert(x) for expert in self.experts] + return self.combine(o, coef) + +class MoW(MoE): + + def forward(self, x, coef=None): + # assume experts has already been assembled + coef = self.gate(x, coef) + self.last_coef = coef + return self.experts(x, coef) + + +################## sample gating functions for get_coefficients ########### +class Gate(ABC, nn.Module): + + ''' + gate function + ''' + + def forward(self, x, coef=None): + raise NotImplementedError() + +class AdaptiveLSTMGate(Gate): + + def __init__(self, input_size, num_experts, normalize=False): + super(self.__class__, self).__init__() + self.forward_function = MLP([input_size, num_experts]) + self.normalize = normalize + + def forward(self, x, coef=None): + x, (h, c) = x # h (_, bs, d) + o = self.forward_function(h.transpose(0,1)) # (bs, num_experts) + if self.normalize: + return nn.functional.softmax(o, 1) + else: + return o + +class NonAdaptiveGate(Gate): + + def __init__(self, num_experts, coef=None, fixed=False, normalize=False): + ''' + fixed coefficient: resnet like with predefined not learnable gate values + normalize: take softmax of the parameters + ''' + super(self.__class__, self).__init__() + self.normalize = normalize + if coef is None: # initialization + coef = torch.ones(num_experts) + nn.init.uniform_(coef) + if fixed: + coef = 
nn.Parameter(coef, requires_grad=False) + else: + coef = nn.Parameter(coef) + + self.coefficients = coef + + def forward(self, x, coef=None): + if self.normalize: + return nn.functional.softmax(self.coefficients, 0) + else: + return self.coefficients + +class IDGate(Gate): # identity gate + + def forward(self, x, coef): # coef is previous coefficient + return coef + + +################ time series example models ################ +def moo_linear(in_features, out_features, num_experts, bs_dim=1, expert_dim=0): + # repeat a linear model for self.num_experts times + experts = nn.ModuleList() + for _ in range(num_experts): + experts.append(nn.Linear(in_features, out_features)) + + # tie weights later + return MoO(experts, IDGate(), bs_dim=bs_dim, expert_dim=expert_dim) + +class mowLSTM_(nn.Module): + + ''' + helper module for mowLSTM + ''' + def __init__(self, input_size, hidden_size, num_experts=2, batch_first=False): + + super(mowLSTM_, self).__init__() + + self.input_size = input_size + self.hidden_size = hidden_size + self.num_experts = num_experts + self.batch_first = batch_first + + # build cell + self.input_weights = moo_linear(input_size, 4 * hidden_size, + self.num_experts, bs_dim=2) # i,f,g,o + self.hidden_weights = moo_linear(hidden_size, 4 * hidden_size, + self.num_experts, bs_dim=2) + # init same as pytorch version + stdv = 1.0 / math.sqrt(self.hidden_size) + for m in self.input_weights.experts: + for name, weight in m.named_parameters(): + nn.init.uniform_(weight, -stdv, stdv) + # if 'weight' in name: + # nn.init.uniform_(weight) + for m in self.hidden_weights.experts: + for name, weight in m.named_parameters(): + nn.init.uniform_(weight, -stdv, stdv) + # if 'weight' in name: + # nn.init.orthogonal_(weight) + + # maybe: layer normalization: see jeeheh's code + # maybe: orthogonal initialization: see jeeheh's code + # note: pytorch implementation does neither + + def rnn_step(self, x, hidden, coef): # one step of rnn + bs = x.shape[1] + h, c = hidden + 
gates = self.input_weights(x, coef) + self.hidden_weights(h, coef) + # maybe: layer normalization: see jeeheh's code + + ingate, forgetgate, cellgate, outgate = gates.view(bs, -1).chunk(4, 1) + ingate = torch.sigmoid(ingate) + forgetgate = torch.sigmoid(forgetgate) + cellgate = torch.tanh(cellgate) + outgate = torch.sigmoid(outgate) + + c = forgetgate * c + ingate * cellgate + h = outgate * torch.tanh(c) # maybe use layer norm here as well + return h, c + + def forward(self, x, hidden, coef): + if self.batch_first: # change to seq_len first + x = x.transpose(0, 1) + + seq_len = x.shape[0] + output = [] + for t in range(seq_len): + hidden = self.rnn_step(x[t].unsqueeze(0), hidden, coef) + output.append(hidden[0]) # seq_len x (_, bs, d) + + output = torch.cat(output, 0) + return output, hidden + +class mowLSTM(nn.Module): + + ''' + helper for mowLSTM, + responsible for stacking and bidirectional LSTM + stack according to + https://stackoverflow.com/questions/49224413/difference-between-1-lstm-with-num-layers-2-and-2-lstms-in-pytorch + + ''' + def __init__(self, input_size, hidden_size, num_classes, num_experts=2, + num_layers=1, batch_first=False, dropout=0, bidirectional=False, + activation=None): + + super(mowLSTM, self).__init__() + + self.input_size = input_size + self.hidden_size = hidden_size + self.num_classes = num_classes + self.num_experts = num_experts + self.num_layers = num_layers + self.num_directions = 2 if bidirectional else 1 + self.batch_first = batch_first + self.dropouts = nn.ModuleList() + + self.h2o = moo_linear(self.num_directions * self.hidden_size, + self.num_classes, self.num_experts, bs_dim=2) + + if activation: + self.activation = activation + else: + self.activation = lambda x: x + + self.rnns = nn.ModuleList() + for i in range(num_layers * self.num_directions): + input_size = input_size if i == 0 else hidden_size + self.rnns.append(mowLSTM_(input_size, hidden_size, num_experts, batch_first)) + self.dropouts.append(nn.Dropout(p=dropout)) 
+ + def forward(self, x, coef): + x, hidden = x + self.last_coef = coef + + h, c = hidden + hs, cs = [], [] + for i in range(self.num_layers): + if i != 0 and i != (self.num_layers - 1): + x = self.dropouts[i](x) # waste 1 droput out but no problem + x, hidden = self.rnns[i](x, (h[i].unsqueeze(0), c[i].unsqueeze(0)), coef) + hs.append(hidden[0]) + cs.append(hidden[1]) + + # todo: bidirectional stacked LSTM, see reference here + # https://github.com/allenai/allennlp/blob/master/allennlp/modules/stacked_bidirectional_lstm.py; it basically concat layer output + + h = torch.cat(hs, 0) + c = torch.cat(cs, 0) + o = x + # run through prediction layer: o: (seq_len, bs, d) + o = self.dropouts[0](o) + o = self.h2o(o, coef) + o = self.activation(o) + + return o, (h, c) + + +class ExampleMowLSTM(nn.Module): + + ''' + recreate LSTM architectre + then stack them according to + + ''' + def __init__(self, input_size, hidden_size, num_classes, + num_layers=1, num_directions=1, dropout=0, activation=None): + super(ExampleMowLSTM, self).__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.num_classes = num_classes + self.num_layers = num_layers + self.num_directions = num_directions + self.dropout = dropout + self.activation = activation + + def setKT(self, k, t): # k models t steps + '''k experts with maximum of t time steps''' + self.k = k + self.T = t + self.cells = nn.ModuleList() + + experts = mowLSTM(self.input_size, self.hidden_size, + self.num_classes, num_experts=self.k, + num_layers=self.num_layers, dropout=self.dropout, + bidirectional = (self.num_directions==2), + activation=self.activation) + self.experts = experts + + for _ in range(t): + gate = NonAdaptiveGate(self.k, normalize=True) + # gate = AdaptiveLSTMGate(self.hidden_size *\ + # self.num_layers *\ + # self.num_directions, + # self.k, + # normalize=True) + self.cells.append(MoW(experts, gate)) + + def forward(self, x, hidden): + seq_len, bs, _ = x.shape + o = [] + for t in 
range(seq_len): + o_, hidden = self.cells[t]((x[t].view(1, bs, -1), hidden)) + o.append(o_) + + o = torch.cat(o, 0) # (seq_len, bs, d) + return o, hidden + + + +def orthogonal(shape): + flat_shape = (int(shape[0]), int(np.prod(shape[1:]))) + a = np.random.normal(0.0, 1.0, flat_shape) + u, _, v = np.linalg.svd(a, full_matrices=False) + q = u if u.shape == flat_shape else v + return q.reshape(shape) + +def lstm_ortho_initializer(shape, scale=1.0): + size_x = shape[0] + size_h = int(shape[1]/4) # assumes lstm. + t = np.zeros(shape) + t[:, :size_h] = orthogonal([size_x, size_h])*scale + t[:, size_h:size_h*2] = orthogonal([size_x, size_h])*scale + t[:, size_h*2:size_h*3] = orthogonal([size_x, size_h])*scale + t[:, size_h*3:] = orthogonal([size_x, size_h])*scale + return t + +class mixLSTM(BaseModel): + + def __init__(self, dataset: SampleDataset, num_experts=2, hidden_size=100): + super(mixLSTM, self).__init__(dataset) + + #Process dataset to get input dimension and time steps + input_keys = list(dataset.input_processors.keys()) + sample = dataset[0] + val = sample[input_keys[0]] + if isinstance(val, (list, tuple)): + for item in val: + if torch.is_tensor(item) or isinstance(item, (list, tuple, np.ndarray)): + val = item + break + if torch.is_tensor(val): + input_dim = val.shape[-1] if val.dim() >= 2 else 1 + T = val.shape[0] + else: + arr = np.array(val) + input_dim = arr.shape[-1] if arr.ndim >= 2 else 1 + T = len(val) + + self.input_size = int(input_dim) + self.time_steps = int(T) + num_classes = int(self.get_output_size()) + + self.model = ExampleMowLSTM(self.input_size, hidden_size, + num_classes, num_layers=1, + num_directions=1, dropout=0, + activation=nn.LogSoftmax(dim=-1)) + + self.num_layers = 1 + self.num_directions = 1 + self.hidden_size = hidden_size + self.model.setKT(num_experts, self.time_steps) + + def forward(self, x): + # change x from (bs, seq_len, d) => (seq_len, bs, d) + x = x.permute(1, 0, 2) + batch_size = x.size(1) + # set initial hidden and 
cell states on the model device + device = self.device + h = torch.zeros(self.num_layers * self.num_directions, + batch_size, self.hidden_size, device=device) + c = torch.zeros(self.num_layers * self.num_directions, + batch_size, self.hidden_size, device=device) + + states = (h, c) + outputs, states = self.model(x, states) + + return outputs.permute(1, 0, 2) + + def after_backward(self): + return + \ No newline at end of file From b88a78cab9dd9ad9b193a0efb136b62bb1b6de63 Mon Sep 17 00:00:00 2001 From: siddesh2-sys Date: Sat, 18 Apr 2026 22:04:30 -0400 Subject: [PATCH 02/19] Added test and fixed minor bugs --- examples/mixlstm/mixlstm_test.ipynb | 895 ++++++++++++++++++++++++++++ pyhealth/models/__init__.py | 1 + pyhealth/models/mixlstm.py | 34 +- 3 files changed, 925 insertions(+), 5 deletions(-) create mode 100644 examples/mixlstm/mixlstm_test.ipynb diff --git a/examples/mixlstm/mixlstm_test.ipynb b/examples/mixlstm/mixlstm_test.ipynb new file mode 100644 index 000000000..2c3fdb980 --- /dev/null +++ b/examples/mixlstm/mixlstm_test.ipynb @@ -0,0 +1,895 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "5b176866", + "metadata": {}, + "source": [ + "# mixLSTM Test Notebook\n", + "\n", + "This notebook demonstrates how to create a small sample dataset, initialize `mixLSTM`, and run a forward pass." + ] + }, + { + "cell_type": "markdown", + "id": "2cb149bb", + "metadata": {}, + "source": [ + "## 1. 
Environment Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "08ef7f3e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running on device: cpu\n" + ] + } + ], + "source": [ + "import random\n", + "import numpy as np\n", + "import torch\n", + "\n", + "from pyhealth.datasets import create_sample_dataset, get_dataloader\n", + "from pyhealth.datasets.splitter import split_by_sample\n", + "\n", + "SEED = 42\n", + "random.seed(SEED)\n", + "np.random.seed(SEED)\n", + "torch.manual_seed(SEED)\n", + "if torch.cuda.is_available():\n", + " torch.cuda.manual_seed_all(SEED)\n", + "\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "print(f\"Running on device: {device}\")" + ] + }, + { + "cell_type": "markdown", + "id": "1327ae03", + "metadata": {}, + "source": [ + "## 2. Create Sample Dataset\n", + "\n", + "We create synthetic time-series samples with shape `(T, input_dim)` stored under the input key `series`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "d3fea9bc", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Label label vocab: {0: 0, 1: 1, 2: 2, 3: 3, 4: 4}\n", + "Created dataset with 200 samples\n", + "Input schema: {'series': 'tensor'}\n", + "Output schema: {'label': 'multiclass'}\n" + ] + } + ], + "source": [ + "# Dataset parameters\n", + "num_samples = 200\n", + "T = 50 # sequence length\n", + "input_dim = 3\n", + "n_classes = 5\n", + "\n", + "samples = [\n", + " {\n", + " \"patient_id\": f\"patient-{i}\",\n", + " \"visit_id\": \"visit-0\",\n", + " \"series\": torch.randn(T, input_dim).numpy().tolist(),\n", + " \"label\": int(i % n_classes),\n", + " }\n", + " for i in range(num_samples)\n", + "]\n", + "\n", + "input_schema = {\"series\": \"tensor\"}\n", + "output_schema = {\"label\": \"multiclass\"}\n", + "\n", + "dataset = create_sample_dataset(\n", + " samples=samples,\n", + " input_schema=input_schema,\n", + " output_schema=output_schema,\n", + " dataset_name=\"mixlstm_demo\",\n", + ")\n", + "\n", + "print(f\"Created dataset with {len(dataset)} samples\")\n", + "print(f\"Input schema: {dataset.input_schema}\")\n", + "print(f\"Output schema: {dataset.output_schema}\")" + ] + }, + { + "cell_type": "markdown", + "id": "f2a5248a", + "metadata": {}, + "source": [ + "## 3. 
Split Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "168a6299", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train: 140 samples\n", + "Val: 30 samples\n", + "Test: 30 samples\n" + ] + } + ], + "source": [ + "train_data, val_data, test_data = split_by_sample(dataset, [0.7, 0.15, 0.15], seed=SEED)\n", + "\n", + "print(f\"Train: {len(train_data)} samples\")\n", + "print(f\"Val: {len(val_data)} samples\")\n", + "print(f\"Test: {len(test_data)} samples\")\n", + "\n", + "train_loader = get_dataloader(train_data, batch_size=8, shuffle=True)\n", + "val_loader = get_dataloader(val_data, batch_size=8, shuffle=False)\n", + "test_loader = get_dataloader(test_data, batch_size=8, shuffle=False)" + ] + }, + { + "cell_type": "markdown", + "id": "63cad735", + "metadata": {}, + "source": [ + "## 4. Initialize `mixLSTM` Model" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "215bf9bd", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model created with 180390 parameters\n" + ] + } + ], + "source": [ + "from pyhealth.models import MixLSTM\n", + "\n", + "model = MixLSTM(dataset=dataset, num_experts=10, hidden_size=64)\n", + "model = model.to(device)\n", + "print(f\"Model created with {sum(p.numel() for p in model.parameters())} parameters\")" + ] + }, + { + "cell_type": "markdown", + "id": "e978eecb", + "metadata": {}, + "source": [ + "## 5. 
Test Forward Pass" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "49633e2c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Output keys: dict_keys(['logit', 'y_prob', 'loss', 'y_true'])\n", + "Loss: 1.1942\n", + "Logits shape: torch.Size([8, 5])\n", + "Predictions: tensor([[-1.6454, -1.5354, -1.8028, -1.0847, -2.4207],\n", + " [-2.9820, -3.0563, -2.2066, -1.8453, -0.4554],\n", + " [-1.8053, -2.3038, -2.2325, -0.6283, -2.3546],\n", + " [-1.4553, -1.7886, -1.6593, -1.3155, -1.9599],\n", + " [-1.1825, -1.4733, -2.4646, -2.1457, -1.3383],\n", + " [-1.3650, -1.9290, -1.5502, -1.5415, -1.7541],\n", + " [-1.6411, -1.6955, -1.3702, -2.0516, -1.4266],\n", + " [-1.5647, -2.0988, -1.5074, -1.5615, -1.4399]])\n" + ] + } + ], + "source": [ + "# Fetch a batch and run a forward pass\n", + "batch = next(iter(train_loader))\n", + "\n", + "with torch.no_grad():\n", + " outputs = model(**batch)\n", + "\n", + "print(\"Output keys:\", outputs.keys())\n", + "print(f\"Loss: {outputs['loss'].item():.4f}\")\n", + "print(f\"Logits shape: {outputs['logit'].shape}\")" + ] + }, + { + "cell_type": "markdown", + "id": "f2c97123", + "metadata": {}, + "source": [ + "## 6. Optional Training (Example)\n", + "\n", + "The following is an example sketch for training using PyHealth's `Trainer`. Uncomment to run training." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "b3bce217", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MixLSTM(\n", + " (model): ExampleMowLSTM(\n", + " (activation): LogSoftmax(dim=-1)\n", + " (cells): ModuleList(\n", + " (0-49): 50 x MoW(\n", + " (experts): mowLSTM(\n", + " (dropouts): ModuleList(\n", + " (0): Dropout(p=0, inplace=False)\n", + " )\n", + " (h2o): MoO(\n", + " (experts): ModuleList(\n", + " (0-9): 10 x Linear(in_features=64, out_features=5, bias=True)\n", + " )\n", + " (gate): IDGate()\n", + " )\n", + " (activation): LogSoftmax(dim=-1)\n", + " (rnns): ModuleList(\n", + " (0): mowLSTM_(\n", + " (input_weights): MoO(\n", + " (experts): ModuleList(\n", + " (0-9): 10 x Linear(in_features=3, out_features=256, bias=True)\n", + " )\n", + " (gate): IDGate()\n", + " )\n", + " (hidden_weights): MoO(\n", + " (experts): ModuleList(\n", + " (0-9): 10 x Linear(in_features=64, out_features=256, bias=True)\n", + " )\n", + " (gate): IDGate()\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (gate): NonAdaptiveGate()\n", + " )\n", + " )\n", + " (experts): mowLSTM(\n", + " (dropouts): ModuleList(\n", + " (0): Dropout(p=0, inplace=False)\n", + " )\n", + " (h2o): MoO(\n", + " (experts): ModuleList(\n", + " (0-9): 10 x Linear(in_features=64, out_features=5, bias=True)\n", + " )\n", + " (gate): IDGate()\n", + " )\n", + " (activation): LogSoftmax(dim=-1)\n", + " (rnns): ModuleList(\n", + " (0): mowLSTM_(\n", + " (input_weights): MoO(\n", + " (experts): ModuleList(\n", + " (0-9): 10 x Linear(in_features=3, out_features=256, bias=True)\n", + " )\n", + " (gate): IDGate()\n", + " )\n", + " (hidden_weights): MoO(\n", + " (experts): ModuleList(\n", + " (0-9): 10 x Linear(in_features=64, out_features=256, bias=True)\n", + " )\n", + " (gate): IDGate()\n", + " )\n", + " )\n", + " )\n", + " )\n", + " )\n", + ")\n", + "Metrics: None\n", + "Device: cpu\n", + "\n", + "Training:\n", + "Batch size: 8\n", + 
"Optimizer: \n", + "Optimizer params: {'lr': 0.1}\n", + "Weight decay: 0.0\n", + "Max grad norm: None\n", + "Val dataloader: \n", + "Monitor: accuracy\n", + "Monitor criterion: max\n", + "Epochs: 10\n", + "Patience: None\n", + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "4b86a8a571724d679d5deb755da2d2b8", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Epoch 0 / 10: 0%| | 0/18 [00:00 (seq_len, bs, d) x = x.permute(1, 0, 2) batch_size = x.size(1) @@ -407,8 +412,27 @@ def forward(self, x): states = (h, c) outputs, states = self.model(x, states) - - return outputs.permute(1, 0, 2) + + # outputs: (seq_len, batch, num_classes) -> (batch, seq_len, num_classes) + logits_seq = outputs.permute(1, 0, 2) + + # For sequence models used for classification tasks, provide a + # per-sample logit by selecting the last timestep. + logits = logits_seq[:, -1, :] + + results = {} + results["logit"] = logits + results["y_prob"] = self.prepare_y_prob(logits) + + # If labels were provided in kwargs (Trainer passes them), compute loss + if hasattr(self, "label_keys") and len(self.label_keys) > 0 and self.label_keys[0] in kwargs: + y_true = kwargs[self.label_keys[0]].to(self.device) + loss_fn = self.get_loss_function() + loss = loss_fn(logits, y_true) + results["loss"] = loss + results["y_true"] = y_true + + return results def after_backward(self): return From faaad49abaaa2eddba06bf2f0faea05b3e366759 Mon Sep 17 00:00:00 2001 From: siddesh2-sys Date: Sun, 19 Apr 2026 21:28:24 -0400 Subject: [PATCH 03/19] Added tests, ablation, and support for regression to mixlstm --- examples/mixlstm/mixlstm_test.ipynb | 662 +------------- examples/mixlstm/mixlstm_test_ablation.ipynb | 890 +++++++++++++++++++ pyhealth/models/mixlstm.py | 78 +- tests/test_mixlstm.py | 279 ++++++ 4 files changed, 1227 insertions(+), 682 deletions(-) create mode 100644 examples/mixlstm/mixlstm_test_ablation.ipynb create mode 100644 tests/test_mixlstm.py diff 
--git a/examples/mixlstm/mixlstm_test.ipynb b/examples/mixlstm/mixlstm_test.ipynb index 2c3fdb980..b46a02a84 100644 --- a/examples/mixlstm/mixlstm_test.ipynb +++ b/examples/mixlstm/mixlstm_test.ipynb @@ -71,7 +71,13 @@ "name": "stdout", "output_type": "stream", "text": [ - "Label label vocab: {0: 0, 1: 1, 2: 2, 3: 3, 4: 4}\n", + "Label label vocab: {0: 0, 1: 1, 2: 2, 3: 3, 4: 4}\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ "Created dataset with 200 samples\n", "Input schema: {'series': 'tensor'}\n", "Output schema: {'label': 'multiclass'}\n" @@ -156,7 +162,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 4, "id": "215bf9bd", "metadata": {}, "outputs": [ @@ -186,7 +192,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "49633e2c", "metadata": {}, "outputs": [ @@ -195,16 +201,8 @@ "output_type": "stream", "text": [ "Output keys: dict_keys(['logit', 'y_prob', 'loss', 'y_true'])\n", - "Loss: 1.1942\n", - "Logits shape: torch.Size([8, 5])\n", - "Predictions: tensor([[-1.6454, -1.5354, -1.8028, -1.0847, -2.4207],\n", - " [-2.9820, -3.0563, -2.2066, -1.8453, -0.4554],\n", - " [-1.8053, -2.3038, -2.2325, -0.6283, -2.3546],\n", - " [-1.4553, -1.7886, -1.6593, -1.3155, -1.9599],\n", - " [-1.1825, -1.4733, -2.4646, -2.1457, -1.3383],\n", - " [-1.3650, -1.9290, -1.5502, -1.5415, -1.7541],\n", - " [-1.6411, -1.6955, -1.3702, -2.0516, -1.4266],\n", - " [-1.5647, -2.0988, -1.5074, -1.5615, -1.4399]])\n" + "Loss: 1.6047\n", + "Logits shape: torch.Size([8, 5])\n" ] } ], @@ -220,644 +218,6 @@ "print(f\"Logits shape: {outputs['logit'].shape}\")" ] }, - { - "cell_type": "markdown", - "id": "f2c97123", - "metadata": {}, - "source": [ - "## 6. Optional Training (Example)\n", - "\n", - "The following is an example sketch for training using PyHealth's `Trainer`. Uncomment to run training." 
- ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "b3bce217", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "MixLSTM(\n", - " (model): ExampleMowLSTM(\n", - " (activation): LogSoftmax(dim=-1)\n", - " (cells): ModuleList(\n", - " (0-49): 50 x MoW(\n", - " (experts): mowLSTM(\n", - " (dropouts): ModuleList(\n", - " (0): Dropout(p=0, inplace=False)\n", - " )\n", - " (h2o): MoO(\n", - " (experts): ModuleList(\n", - " (0-9): 10 x Linear(in_features=64, out_features=5, bias=True)\n", - " )\n", - " (gate): IDGate()\n", - " )\n", - " (activation): LogSoftmax(dim=-1)\n", - " (rnns): ModuleList(\n", - " (0): mowLSTM_(\n", - " (input_weights): MoO(\n", - " (experts): ModuleList(\n", - " (0-9): 10 x Linear(in_features=3, out_features=256, bias=True)\n", - " )\n", - " (gate): IDGate()\n", - " )\n", - " (hidden_weights): MoO(\n", - " (experts): ModuleList(\n", - " (0-9): 10 x Linear(in_features=64, out_features=256, bias=True)\n", - " )\n", - " (gate): IDGate()\n", - " )\n", - " )\n", - " )\n", - " )\n", - " (gate): NonAdaptiveGate()\n", - " )\n", - " )\n", - " (experts): mowLSTM(\n", - " (dropouts): ModuleList(\n", - " (0): Dropout(p=0, inplace=False)\n", - " )\n", - " (h2o): MoO(\n", - " (experts): ModuleList(\n", - " (0-9): 10 x Linear(in_features=64, out_features=5, bias=True)\n", - " )\n", - " (gate): IDGate()\n", - " )\n", - " (activation): LogSoftmax(dim=-1)\n", - " (rnns): ModuleList(\n", - " (0): mowLSTM_(\n", - " (input_weights): MoO(\n", - " (experts): ModuleList(\n", - " (0-9): 10 x Linear(in_features=3, out_features=256, bias=True)\n", - " )\n", - " (gate): IDGate()\n", - " )\n", - " (hidden_weights): MoO(\n", - " (experts): ModuleList(\n", - " (0-9): 10 x Linear(in_features=64, out_features=256, bias=True)\n", - " )\n", - " (gate): IDGate()\n", - " )\n", - " )\n", - " )\n", - " )\n", - " )\n", - ")\n", - "Metrics: None\n", - "Device: cpu\n", - "\n", - "Training:\n", - "Batch size: 8\n", - 
"Optimizer: \n", - "Optimizer params: {'lr': 0.1}\n", - "Weight decay: 0.0\n", - "Max grad norm: None\n", - "Val dataloader: \n", - "Monitor: accuracy\n", - "Monitor criterion: max\n", - "Epochs: 10\n", - "Patience: None\n", - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "4b86a8a571724d679d5deb755da2d2b8", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Epoch 0 / 10: 0%| | 0/18 [00:00= l` (`l = prev_used_timestamps`), `y[t]` is a weighted sum over the previous `l` timesteps and `input_dim` features. The weight distributions (`k_dist`, `d_dist`) drift slowly by `change_between_tasks` per step. For `t < l`, `y[t] = 1` (ignored during loss)." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "d3fea9bc", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Created dataset with 1000 samples\n", + "Input schema: {'series': 'tensor'}\n", + "Output schema: {'y': 'tensor'}\n" + ] + } + ], + "source": [ + "# Dataset parameters (matching MLHC2019 synthetic task)\n", + "num_samples = 1000\n", + "T = 30 # sequence length\n", + "input_dim = 3\n", + "prev_used_timestamps = 10\n", + "change_between_tasks = 0.05\n", + "\n", + "def convert_distb(a):\n", + " a_min = min(a)\n", + " a_max = max(a)\n", + " a = (a-a_min)/(a_max-a_min)\n", + " a_sum = sum(a)\n", + " a = a/a_sum\n", + " return a\n", + "\n", + "\"\"\"Gen X\"\"\"\n", + "x_size = num_samples*T*input_dim\n", + "x=np.zeros(x_size)\n", + "x[np.random.choice(x_size, size=int(x_size/10), replace=False)]=np.random.uniform(size=int(x_size/10))*100\n", + "x=np.resize(x, (num_samples,T,input_dim))\n", + "\n", + "\"\"\"Gen y\"\"\"\n", + "k_dist = []\n", + "d_dist = []\n", + "for i in range(T):\n", + " if i=prev_used_timestamps:\n", + " y[:,i,0] = np.matmul(np.matmul(x[:,i-prev_used_timestamps:i,:],d_dist[i]), k_dist[i])\n", + "\n", + "# Build samples: per-timestep regression target matching original 
MLHC2019 synthetic setup\n", + "samples = [\n", + " {\n", + " \"patient_id\": f\"patient-{i}\",\n", + " \"visit_id\": \"visit-0\",\n", + " \"series\": x[i].tolist(),\n", + " \"y\": y[i].squeeze(-1).tolist(),\n", + " }\n", + " for i in range(num_samples)\n", + "]\n", + "\n", + "input_schema = {\"series\": \"tensor\"}\n", + "output_schema = {\"y\": \"tensor\"}\n", + "\n", + "dataset = create_sample_dataset(\n", + " samples=samples,\n", + " input_schema=input_schema,\n", + " output_schema=output_schema,\n", + " dataset_name=\"mixlstm_demo\",\n", + ")\n", + "\n", + "print(f\"Created dataset with {len(dataset)} samples\")\n", + "print(f\"Input schema: {dataset.input_schema}\")\n", + "print(f\"Output schema: {dataset.output_schema}\")" + ] + }, + { + "cell_type": "markdown", + "id": "f2a5248a", + "metadata": {}, + "source": [ + "## 3. Split Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "168a6299", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train: 700 samples\n", + "Val: 150 samples\n", + "Test: 150 samples\n" + ] + } + ], + "source": [ + "train_data, val_data, test_data = split_by_sample(dataset, [0.7, 0.15, 0.15], seed=SEED)\n", + "\n", + "print(f\"Train: {len(train_data)} samples\")\n", + "print(f\"Val: {len(val_data)} samples\")\n", + "print(f\"Test: {len(test_data)} samples\")\n", + "\n", + "train_loader = get_dataloader(train_data, batch_size=8, shuffle=True)\n", + "val_loader = get_dataloader(val_data, batch_size=8, shuffle=False)\n", + "test_loader = get_dataloader(test_data, batch_size=8, shuffle=False)" + ] + }, + { + "cell_type": "markdown", + "id": "63cad735", + "metadata": {}, + "source": [ + "## 4. 
Initialize `mixLSTM` Model" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "215bf9bd", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model created with 2021062 parameters\n", + "Mode: None, Per-timestep: True\n" + ] + } + ], + "source": [ + "from pyhealth.models import MixLSTM\n", + "\n", + "model = MixLSTM(dataset=dataset, num_experts=2, hidden_size=500,\n", + " prev_used_timestamps=prev_used_timestamps)\n", + "model = model.to(device)\n", + "print(f\"Model created with {sum(p.numel() for p in model.parameters())} parameters\")\n", + "print(f\"Mode: {model.mode}, Per-timestep: {model._per_timestep}\")" + ] + }, + { + "cell_type": "markdown", + "id": "e978eecb", + "metadata": {}, + "source": [ + "## 5. Test Forward Pass" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "49633e2c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Output keys: dict_keys(['logit', 'y_prob', 'loss', 'y_true'])\n", + "Loss (MSE): 61.5612\n", + "Logit shape: torch.Size([8, 30, 1])\n" + ] + } + ], + "source": [ + "# Fetch a batch and run a forward pass\n", + "batch = next(iter(train_loader))\n", + "\n", + "with torch.no_grad():\n", + " outputs = model(**batch)\n", + "\n", + "print(\"Output keys:\", outputs.keys())\n", + "print(f\"Loss (MSE): {outputs['loss'].item():.4f}\")\n", + "print(f\"Logit shape: {outputs['logit'].shape}\") # (batch, T, 1) for regression" + ] + }, + { + "cell_type": "markdown", + "id": "f2c97123", + "metadata": {}, + "source": [ + "## 6. Optional Training (Example)\n", + "\n", + "The following is an example sketch for training using PyHealth's `Trainer`. Uncomment to run training." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "b3bce217", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MixLSTM(\n", + " (model): ExampleMowLSTM(\n", + " (cells): ModuleList(\n", + " (0-29): 30 x MoW(\n", + " (experts): mowLSTM(\n", + " (dropouts): ModuleList(\n", + " (0): Dropout(p=0, inplace=False)\n", + " )\n", + " (h2o): MoO(\n", + " (experts): ModuleList(\n", + " (0-1): 2 x Linear(in_features=500, out_features=1, bias=True)\n", + " )\n", + " (gate): IDGate()\n", + " )\n", + " (rnns): ModuleList(\n", + " (0): mowLSTM_(\n", + " (input_weights): MoO(\n", + " (experts): ModuleList(\n", + " (0-1): 2 x Linear(in_features=3, out_features=2000, bias=True)\n", + " )\n", + " (gate): IDGate()\n", + " )\n", + " (hidden_weights): MoO(\n", + " (experts): ModuleList(\n", + " (0-1): 2 x Linear(in_features=500, out_features=2000, bias=True)\n", + " )\n", + " (gate): IDGate()\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (gate): NonAdaptiveGate()\n", + " )\n", + " )\n", + " (experts): mowLSTM(\n", + " (dropouts): ModuleList(\n", + " (0): Dropout(p=0, inplace=False)\n", + " )\n", + " (h2o): MoO(\n", + " (experts): ModuleList(\n", + " (0-1): 2 x Linear(in_features=500, out_features=1, bias=True)\n", + " )\n", + " (gate): IDGate()\n", + " )\n", + " (rnns): ModuleList(\n", + " (0): mowLSTM_(\n", + " (input_weights): MoO(\n", + " (experts): ModuleList(\n", + " (0-1): 2 x Linear(in_features=3, out_features=2000, bias=True)\n", + " )\n", + " (gate): IDGate()\n", + " )\n", + " (hidden_weights): MoO(\n", + " (experts): ModuleList(\n", + " (0-1): 2 x Linear(in_features=500, out_features=2000, bias=True)\n", + " )\n", + " (gate): IDGate()\n", + " )\n", + " )\n", + " )\n", + " )\n", + " )\n", + ")\n", + "Metrics: None\n", + "Device: cpu\n", + "\n", + "Training:\n", + "Batch size: 8\n", + "Optimizer: \n", + "Optimizer params: {'lr': 0.001}\n", + "Weight decay: 0.0\n", + "Max grad norm: None\n", + "Val dataloader: 
\n", + "Monitor: loss\n", + "Monitor criterion: min\n", + "Epochs: 10\n", + "Patience: None\n", + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c862d88a46c6429a997d306d1c08b33e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Epoch 0 / 10: 0%| | 0/88 [00:00= prev_used_timestamps`, matching the original repo where the first `l` targets are trivially zero.\n", + "- **Classification support**: If `output_schema` is set to a standard label type (`\"multiclass\"`, `\"binary\"`, etc.), the model automatically switches to last-timestep classification using `get_loss_function()` and `prepare_y_prob()` from `BaseModel` — no flag needed.\n", + "- `MixLSTM` expects input tensors of shape `(batch, seq_len, input_dim)`." + ] + }, + { + "cell_type": "markdown", + "id": "a28ae383", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyhealth/models/mixlstm.py b/pyhealth/models/mixlstm.py index 2a0cf51ef..6c2cd0400 100644 --- a/pyhealth/models/mixlstm.py +++ b/pyhealth/models/mixlstm.py @@ -275,9 +275,6 @@ def forward(self, x, coef): hs.append(hidden[0]) cs.append(hidden[1]) - # todo: bidirectional stacked LSTM, see reference here - # https://github.com/allenai/allennlp/blob/master/allennlp/modules/stacked_bidirectional_lstm.py; it basically concat layer output - h = torch.cat(hs, 0) c = torch.cat(cs, 0) o = x @@ -360,15 +357,17 @@ def lstm_ortho_initializer(shape, scale=1.0): class MixLSTM(BaseModel): - def __init__(self, dataset: SampleDataset, num_experts=2, hidden_size=100): + def __init__(self, 
dataset: SampleDataset, num_experts=2, hidden_size=100, + prev_used_timestamps=0): super(MixLSTM, self).__init__(dataset) - #Process dataset to get input dimension and time steps + # Identify primary input key and infer shape input_keys = list(dataset.input_processors.keys()) - # remember the primary input key so Trainer can call model(**batch) self.input_key = input_keys[0] + self.label_key = self.label_keys[0] if self.label_keys else None + sample = dataset[0] - val = sample[input_keys[0]] + val = sample[self.input_key] if isinstance(val, (list, tuple)): for item in val: if torch.is_tensor(item) or isinstance(item, (list, tuple, np.ndarray)): @@ -384,56 +383,73 @@ def __init__(self, dataset: SampleDataset, num_experts=2, hidden_size=100): self.input_size = int(input_dim) self.time_steps = int(T) - num_classes = int(self.get_output_size()) + self.prev_used_timestamps = prev_used_timestamps + + # Detect per-timestep regression: output target is a tensor, not a + # standard label type. In that case self.mode is None / unrecognised. + self._per_timestep = ( + self.mode not in ("binary", "multiclass", "multilabel", "regression") + ) + + if self._per_timestep: + num_classes = 1 # predict one scalar per timestep + else: + num_classes = int(self.get_output_size()) self.model = ExampleMowLSTM(self.input_size, hidden_size, - num_classes, num_layers=1, - num_directions=1, dropout=0, - activation=nn.LogSoftmax(dim=-1)) + num_classes, num_layers=1, + num_directions=1, dropout=0, + activation=None) self.num_layers = 1 self.num_directions = 1 self.hidden_size = hidden_size self.model.setKT(num_experts, self.time_steps) - + def forward(self, **kwargs): - # Extract input tensor when called as `model(**batch)` by Trainer. 
x = kwargs.get(self.input_key) - # change x from (bs, seq_len, d) => (seq_len, bs, d) + # (bs, seq_len, d) => (seq_len, bs, d) x = x.permute(1, 0, 2) batch_size = x.size(1) - # set initial hidden and cell states on the model device device = self.device h = torch.zeros(self.num_layers * self.num_directions, batch_size, self.hidden_size, device=device) c = torch.zeros(self.num_layers * self.num_directions, batch_size, self.hidden_size, device=device) - - states = (h, c) - outputs, states = self.model(x, states) - # outputs: (seq_len, batch, num_classes) -> (batch, seq_len, num_classes) + outputs, _ = self.model(x, (h, c)) + # (seq_len, bs, out) => (bs, seq_len, out) logits_seq = outputs.permute(1, 0, 2) - # For sequence models used for classification tasks, provide a - # per-sample logit by selecting the last timestep. - logits = logits_seq[:, -1, :] + if self._per_timestep: + # --- Per-timestep regression (original MLHC2019 synthetic task) --- + results = {"logit": logits_seq, "y_prob": logits_seq} + if self.label_key and self.label_key in kwargs: + y_true = kwargs[self.label_key].to(device) + if y_true.dim() == 2: + y_true = y_true.unsqueeze(-1) + l = self.prev_used_timestamps + pred = logits_seq[:, l:, :].contiguous() + target = y_true[:, l:, :].contiguous() + loss = F.mse_loss(pred.view(-1, pred.size(-1)), + target.view(-1, target.size(-1))) + results["loss"] = loss + results["y_true"] = y_true + return results - results = {} - results["logit"] = logits - results["y_prob"] = self.prepare_y_prob(logits) + logits = logits_seq[:, -1, :] + y_prob = self.prepare_y_prob(logits) + results = {"logit": logits, "y_prob": y_prob} - # If labels were provided in kwargs (Trainer passes them), compute loss - if hasattr(self, "label_keys") and len(self.label_keys) > 0 and self.label_keys[0] in kwargs: - y_true = kwargs[self.label_keys[0]].to(self.device) - loss_fn = self.get_loss_function() - loss = loss_fn(logits, y_true) + if self.label_key and self.label_key in kwargs: + 
y_true = kwargs[self.label_key].to(device) + loss = self.get_loss_function()(logits, y_true) results["loss"] = loss results["y_true"] = y_true return results def after_backward(self): - return + pass \ No newline at end of file diff --git a/tests/test_mixlstm.py b/tests/test_mixlstm.py new file mode 100644 index 000000000..9adf68e2a --- /dev/null +++ b/tests/test_mixlstm.py @@ -0,0 +1,279 @@ +import unittest +import tempfile +import shutil +import numpy as np +import torch + +from pyhealth.datasets import create_sample_dataset, get_dataloader +from pyhealth.models import MixLSTM + + +class TestMixLSTMRegression(unittest.TestCase): + """Test MixLSTM in per-timestep regression mode (MLHC2019 synthetic task).""" + + def setUp(self): + """Set up small synthetic regression dataset and model.""" + self.tmp_dir = tempfile.mkdtemp() + + T = 10 + input_dim = 2 + prev_used = 3 + n = 20 + + rng = np.random.RandomState(42) + x = np.zeros((n, T, input_dim)) + nz = int(n * T * input_dim * 0.1) + idx = rng.choice(n * T * input_dim, size=nz, replace=False) + x.flat[idx] = rng.uniform(size=nz) * 10 + y = np.zeros((n, T)) + for t in range(prev_used, T): + y[:, t] = x[:, t - prev_used:t, :].sum(axis=(1, 2)) + + self.samples = [ + { + "patient_id": f"p-{i}", + "visit_id": "v-0", + "series": x[i].tolist(), + "y": y[i].tolist(), + } + for i in range(n) + ] + + self.dataset = create_sample_dataset( + samples=self.samples, + input_schema={"series": "tensor"}, + output_schema={"y": "tensor"}, + dataset_name="test_mixlstm_reg", + ) + + self.model = MixLSTM( + dataset=self.dataset, + num_experts=2, + hidden_size=16, + prev_used_timestamps=prev_used, + ) + self.batch_size = 4 + + def tearDown(self): + shutil.rmtree(self.tmp_dir, ignore_errors=True) + + def test_instantiation(self): + """Test that model initializes with correct attributes.""" + self.assertIsInstance(self.model, MixLSTM) + self.assertTrue(self.model._per_timestep) + self.assertEqual(self.model.input_size, 2) + 
self.assertEqual(self.model.time_steps, 10) + self.assertEqual(self.model.hidden_size, 16) + self.assertEqual(self.model.prev_used_timestamps, 3) + + def test_forward_output_keys(self): + """Test that forward returns expected keys for regression.""" + loader = get_dataloader(self.dataset, batch_size=self.batch_size, shuffle=False) + batch = next(iter(loader)) + + with torch.no_grad(): + ret = self.model(**batch) + + self.assertIn("loss", ret) + self.assertIn("logit", ret) + self.assertIn("y_prob", ret) + self.assertIn("y_true", ret) + + def test_forward_output_shapes(self): + """Test output tensor shapes for per-timestep regression.""" + loader = get_dataloader(self.dataset, batch_size=self.batch_size, shuffle=False) + batch = next(iter(loader)) + + with torch.no_grad(): + ret = self.model(**batch) + + bs = ret["logit"].shape[0] + # logit: (batch, T, 1) + self.assertEqual(ret["logit"].shape, (bs, 10, 1)) + # y_prob same as logit for regression + self.assertEqual(ret["y_prob"].shape, (bs, 10, 1)) + # y_true: (batch, T, 1) + self.assertEqual(ret["y_true"].shape, (bs, 10, 1)) + # loss is scalar + self.assertEqual(ret["loss"].dim(), 0) + + def test_forward_no_labels(self): + """Test forward without labels returns logit/y_prob but no loss.""" + loader = get_dataloader(self.dataset, batch_size=self.batch_size, shuffle=False) + batch = next(iter(loader)) + # Remove the label key + del batch["y"] + + with torch.no_grad(): + ret = self.model(**batch) + + self.assertIn("logit", ret) + self.assertIn("y_prob", ret) + self.assertNotIn("loss", ret) + self.assertNotIn("y_true", ret) + + def test_backward_gradients(self): + """Test that loss.backward() produces gradients on all trainable params.""" + loader = get_dataloader(self.dataset, batch_size=self.batch_size, shuffle=False) + batch = next(iter(loader)) + + ret = self.model(**batch) + ret["loss"].backward() + + has_gradient = any( + p.requires_grad and p.grad is not None + for p in self.model.parameters() + ) + 
self.assertTrue(has_gradient, "No parameters received gradients") + + def test_loss_is_finite(self): + """Test that loss is finite (not NaN or Inf).""" + loader = get_dataloader(self.dataset, batch_size=self.batch_size, shuffle=False) + batch = next(iter(loader)) + + with torch.no_grad(): + ret = self.model(**batch) + + self.assertTrue(torch.isfinite(ret["loss"]).item(), "Loss is not finite") + + def test_custom_hyperparameters(self): + """Test model with different num_experts and hidden_size.""" + model = MixLSTM( + dataset=self.dataset, + num_experts=4, + hidden_size=32, + prev_used_timestamps=3, + ) + loader = get_dataloader(self.dataset, batch_size=self.batch_size, shuffle=False) + batch = next(iter(loader)) + + with torch.no_grad(): + ret = model(**batch) + + self.assertIn("loss", ret) + self.assertEqual(ret["logit"].shape[2], 1) + + +class TestMixLSTMClassification(unittest.TestCase): + """Test MixLSTM in classification mode (standard PyHealth label task).""" + + def setUp(self): + """Set up small synthetic classification dataset and model.""" + self.tmp_dir = tempfile.mkdtemp() + + T = 8 + input_dim = 3 + n = 16 + n_classes = 3 + + rng = np.random.RandomState(0) + self.samples = [ + { + "patient_id": f"p-{i}", + "visit_id": "v-0", + "series": rng.randn(T, input_dim).tolist(), + "label": int(i % n_classes), + } + for i in range(n) + ] + + self.dataset = create_sample_dataset( + samples=self.samples, + input_schema={"series": "tensor"}, + output_schema={"label": "multiclass"}, + dataset_name="test_mixlstm_cls", + ) + + self.model = MixLSTM( + dataset=self.dataset, + num_experts=2, + hidden_size=16, + ) + self.batch_size = 4 + + def tearDown(self): + shutil.rmtree(self.tmp_dir, ignore_errors=True) + + def test_instantiation(self): + """Test that classification model initializes correctly.""" + self.assertIsInstance(self.model, MixLSTM) + self.assertFalse(self.model._per_timestep) + self.assertEqual(self.model.mode, "multiclass") + + def 
test_forward_output_keys(self): + """Test that forward returns expected keys for classification.""" + loader = get_dataloader(self.dataset, batch_size=self.batch_size, shuffle=False) + batch = next(iter(loader)) + + with torch.no_grad(): + ret = self.model(**batch) + + self.assertIn("loss", ret) + self.assertIn("logit", ret) + self.assertIn("y_prob", ret) + self.assertIn("y_true", ret) + + def test_forward_output_shapes(self): + """Test output tensor shapes for classification (last timestep).""" + loader = get_dataloader(self.dataset, batch_size=self.batch_size, shuffle=False) + batch = next(iter(loader)) + + with torch.no_grad(): + ret = self.model(**batch) + + bs = ret["logit"].shape[0] + n_classes = 3 + # logit: (batch, n_classes) + self.assertEqual(ret["logit"].shape, (bs, n_classes)) + # y_prob: (batch, n_classes) — softmax output + self.assertEqual(ret["y_prob"].shape, (bs, n_classes)) + # y_true: (batch,) + self.assertEqual(ret["y_true"].shape[0], bs) + # loss is scalar + self.assertEqual(ret["loss"].dim(), 0) + + def test_forward_no_labels(self): + """Test forward without labels returns logit/y_prob but no loss.""" + loader = get_dataloader(self.dataset, batch_size=self.batch_size, shuffle=False) + batch = next(iter(loader)) + del batch["label"] + + with torch.no_grad(): + ret = self.model(**batch) + + self.assertIn("logit", ret) + self.assertIn("y_prob", ret) + self.assertNotIn("loss", ret) + self.assertNotIn("y_true", ret) + + def test_backward_gradients(self): + """Test that loss.backward() produces gradients.""" + loader = get_dataloader(self.dataset, batch_size=self.batch_size, shuffle=False) + batch = next(iter(loader)) + + ret = self.model(**batch) + ret["loss"].backward() + + has_gradient = any( + p.requires_grad and p.grad is not None + for p in self.model.parameters() + ) + self.assertTrue(has_gradient, "No parameters received gradients") + + def test_y_prob_sums_to_one(self): + """Test that y_prob (softmax) sums to ~1 for each sample.""" + loader 
= get_dataloader(self.dataset, batch_size=self.batch_size, shuffle=False) + batch = next(iter(loader)) + + with torch.no_grad(): + ret = self.model(**batch) + + prob_sums = ret["y_prob"].sum(dim=1) + self.assertTrue( + torch.allclose(prob_sums, torch.ones_like(prob_sums), atol=1e-5), + "y_prob rows do not sum to 1", + ) + + +if __name__ == "__main__": + unittest.main() From d4f7653e4f81b95ed92a55c9a4f001f6250a89c4 Mon Sep 17 00:00:00 2001 From: amanluth03 Date: Mon, 20 Apr 2026 16:45:41 -0500 Subject: [PATCH 04/19] rst file and docs added --- docs/api/models/pyhealth.models.mixlstm.rst | 20 + pyhealth/models/mixlstm.py | 435 +++++++++++++++----- 2 files changed, 349 insertions(+), 106 deletions(-) create mode 100644 docs/api/models/pyhealth.models.mixlstm.rst diff --git a/docs/api/models/pyhealth.models.mixlstm.rst b/docs/api/models/pyhealth.models.mixlstm.rst new file mode 100644 index 000000000..1735e2947 --- /dev/null +++ b/docs/api/models/pyhealth.models.mixlstm.rst @@ -0,0 +1,20 @@ +pyhealth.models.MixLSTM +======================= + +The MixLSTM model from Oh et al. 2020, "Relaxed Parameter Sharing: +Effectively Modeling Time-Varying Relationships in Clinical Time-Series" +(https://arxiv.org/abs/1906.02898). + +MixLSTM addresses the problem of *temporal conditional shift* in clinical +time-series, i.e., settings in which the relationship between input features +and outcomes changes over the course of a patient's hospital stay. Instead +of sharing a single set of LSTM parameters across all time steps, MixLSTM +maintains ``K`` independent LSTM cells and, at every time step, computes a +learned convex combination of their parameters using mixing coefficients. +This enables smooth transitions between different temporal dynamics without +requiring hard segment boundaries. + +.. 
autoclass:: pyhealth.models.MixLSTM + :members: + :undoc-members: + :show-inheritance: diff --git a/pyhealth/models/mixlstm.py b/pyhealth/models/mixlstm.py index 6c2cd0400..38acd7510 100644 --- a/pyhealth/models/mixlstm.py +++ b/pyhealth/models/mixlstm.py @@ -1,93 +1,136 @@ -import torch, math +"""MixLSTM model for clinical time-series prediction. + +Implementation of the mixLSTM architecture from Oh et al. 2020, +"Relaxed Parameter Sharing: Effectively Modeling Time-Varying +Relationships in Clinical Time-Series" (https://arxiv.org/abs/1906.02898). + +The key idea is to relax the parameter-sharing constraint of a standard +LSTM by maintaining K independent LSTM cells and combining their +parameters at each time step using learned mixing coefficients. This +allows the model to capture temporal conditional shift, where the +relationship between features and outcomes changes over time. +""" + +import math +from abc import ABC +from collections import abc + +import numpy as np +import torch import torch.nn as nn -from torch.autograd import Variable import torch.nn.functional as F -import numpy as np -from collections import abc -from abc import ABC +from torch.autograd import Variable from pyhealth.datasets import SampleDataset from pyhealth.models import BaseModel + class MLP(nn.Module): + """A simple multi-layer perceptron used as a building block. + + Args: + neuron_sizes: List of layer sizes, e.g. ``[in_dim, hidden, out_dim]``. + activation: Activation function class (default: ``nn.LeakyReLU``). + bias: Whether linear layers include a bias term. 
+ """ - def __init__(self, neuron_sizes, activation=nn.LeakyReLU, bias=True): + def __init__(self, neuron_sizes, activation=nn.LeakyReLU, bias=True): super(MLP, self).__init__() self.neuron_sizes = neuron_sizes - + layers = [] for s0, s1 in zip(neuron_sizes[:-1], neuron_sizes[1:]): layers.extend([ nn.Linear(s0, s1, bias=bias), activation() ]) - + self.classifier = nn.Sequential(*layers[:-1]) def eval_forward(self, x, y): + """Run a forward pass in eval mode (ignores ``y``).""" self.eval() return self.forward(x) - + def forward(self, x): + """Flatten the input and pass it through the MLP.""" x = x.contiguous() x = x.view(-1, self.neuron_sizes[0]) return self.classifier(x) + ############################ main models ################################## class MoE(nn.Module): + """Abstract base class for mixture-of-experts modules. - ''' - This is a abstract base class for mixture of experts - - it supports: - a) specifiying experts - b) specifying the gating function (having parameter or not) + Supports specifying a set of experts and a gating function. Subclasses + must implement how experts are combined (see ``MoO`` and ``MoW``). - it needs combining functions (either MoO or MoE) - ''' + Args: + experts: The expert modules to be combined. + gate: The gating function that produces mixing coefficients. + """ def __init__(self, experts, gate): - super(MoE, self).__init__() + super(MoE, self).__init__() self.experts = experts self.gate = gate + class MoO(MoE): + """Mixture of Outputs. + + Each expert produces an output independently, and the outputs are + combined via a weighted sum using coefficients from the gate. + + Args: + experts: The expert modules. + gate: The gating function. + bs_dim: Batch-size dimension of expert outputs (default: 1). + expert_dim: Expert dimension after stacking (default: 0). 
+ """ - ''' - mixture of outputs - ''' def __init__(self, experts, gate, bs_dim=1, expert_dim=0): - super(MoO, self).__init__(experts, gate) + super(MoO, self).__init__(experts, gate) # this is for RNN architecture: bs_dim = 2 for RNN self.bs_dim = bs_dim self.expert_dim = expert_dim def combine(self, o, coef): - - if isinstance(o[0], abc.Sequence): # account for multi_output setting + """Combine expert outputs using the mixing coefficients.""" + if isinstance(o[0], abc.Sequence): # account for multi_output setting return [self.combine(o_, coef) for o_ in zip(*o)] else: o = torch.stack(o) # reshape o to (_, bs, n_expert) b/c coef is (bs, n_expert) o = o.transpose(self.expert_dim, -1) - o = o.transpose(self.bs_dim, -2) + o = o.transpose(self.bs_dim, -2) # change back res = o * coef res = res.transpose(self.expert_dim, -1) res = res.transpose(self.bs_dim, -2) return res.sum(0) - - def forward(self, x, coef=None): # coef is previous coefficient: for IDGate - coef = self.gate(x, coef) # (bs, n_expert) or n_expert + + def forward(self, x, coef=None): + """Compute each expert's output and combine them.""" + coef = self.gate(x, coef) # (bs, n_expert) or n_expert self.last_coef = coef o = [expert(x) for expert in self.experts] return self.combine(o, coef) + class MoW(MoE): + """Mixture of Weights. + + Instead of combining expert outputs, this module combines expert + parameters before the forward pass, effectively producing a single + assembled expert per time step. 
+ """ def forward(self, x, coef=None): - # assume experts has already been assembled + """Run the assembled expert on the input.""" + # assume experts has already been assembled coef = self.gate(x, coef) self.last_coef = coef return self.experts(x, coef) @@ -95,39 +138,50 @@ def forward(self, x, coef=None): ################## sample gating functions for get_coefficients ########### class Gate(ABC, nn.Module): - - ''' - gate function - ''' + """Abstract base class for gating functions.""" def forward(self, x, coef=None): raise NotImplementedError() + class AdaptiveLSTMGate(Gate): + """A gate that computes mixing coefficients from the LSTM hidden state. + + Args: + input_size: Size of the hidden state used as input. + num_experts: Number of experts in the mixture. + normalize: If True, apply softmax to the coefficients. + """ def __init__(self, input_size, num_experts, normalize=False): super(self.__class__, self).__init__() self.forward_function = MLP([input_size, num_experts]) self.normalize = normalize - + def forward(self, x, coef=None): - x, (h, c) = x # h (_, bs, d) - o = self.forward_function(h.transpose(0,1)) # (bs, num_experts) - if self.normalize: + """Produce mixing coefficients from the hidden state.""" + x, (h, c) = x # h (_, bs, d) + o = self.forward_function(h.transpose(0, 1)) # (bs, num_experts) + if self.normalize: return nn.functional.softmax(o, 1) else: return o - + + class NonAdaptiveGate(Gate): + """A gate with learnable (or fixed) coefficients that do not depend on x. + + Args: + num_experts: Number of experts. + coef: Optional initial coefficient tensor. If None, randomly init. + fixed: If True, coefficients are not trainable. + normalize: If True, apply softmax to the coefficients. 
+ """ def __init__(self, num_experts, coef=None, fixed=False, normalize=False): - ''' - fixed coefficient: resnet like with predefined not learnable gate values - normalize: take softmax of the parameters - ''' super(self.__class__, self).__init__() self.normalize = normalize - if coef is None: # initialization + if coef is None: # initialization coef = torch.ones(num_experts) nn.init.uniform_(coef) if fixed: @@ -138,36 +192,63 @@ def __init__(self, num_experts, coef=None, fixed=False, normalize=False): self.coefficients = coef def forward(self, x, coef=None): + """Return the (optionally normalized) mixing coefficients.""" if self.normalize: return nn.functional.softmax(self.coefficients, 0) else: return self.coefficients -class IDGate(Gate): # identity gate - def forward(self, x, coef): # coef is previous coefficient +class IDGate(Gate): + """Identity gate that passes through a previous coefficient unchanged.""" + + def forward(self, x, coef): + """Return the coefficient that was passed in.""" return coef ################ time series example models ################ def moo_linear(in_features, out_features, num_experts, bs_dim=1, expert_dim=0): + """Create a MoO over a set of linear layers with tied shape. + + Args: + in_features: Input feature size. + out_features: Output feature size. + num_experts: Number of expert linear layers. + bs_dim: Batch-size dimension (see ``MoO``). + expert_dim: Expert dimension (see ``MoO``). + + Returns: + A ``MoO`` module wrapping ``num_experts`` linear layers with an + identity gate. + """ # repeat a linear model for self.num_experts times experts = nn.ModuleList() for _ in range(num_experts): experts.append(nn.Linear(in_features, out_features)) - # tie weights later - return MoO(experts, IDGate(), bs_dim=bs_dim, expert_dim=expert_dim) + # tie weights later + return MoO(experts, IDGate(), bs_dim=bs_dim, expert_dim=expert_dim) + class mowLSTM_(nn.Module): + """Internal helper implementing one layer of the mixture-of-weights LSTM. 
- ''' - helper module for mowLSTM - ''' - def __init__(self, input_size, hidden_size, num_experts=2, batch_first=False): + Applies a per-time-step mixture of LSTM cells by combining the input + and hidden weight matrices across experts. + + Args: + input_size: Input feature dimension. + hidden_size: Hidden state dimension. + num_experts: Number of expert cells to mix (K). + batch_first: If True, expects input shape (batch, seq_len, dim). + """ + + def __init__(self, input_size, hidden_size, num_experts=2, + batch_first=False): super(mowLSTM_, self).__init__() - + self.input_size = input_size self.hidden_size = hidden_size self.num_experts = num_experts @@ -175,7 +256,7 @@ def __init__(self, input_size, hidden_size, num_experts=2, batch_first=False): # build cell self.input_weights = moo_linear(input_size, 4 * hidden_size, - self.num_experts, bs_dim=2) # i,f,g,o + self.num_experts, bs_dim=2) # i,f,g,o self.hidden_weights = moo_linear(hidden_size, 4 * hidden_size, self.num_experts, bs_dim=2) # init same as pytorch version @@ -183,20 +264,17 @@ def __init__(self, input_size, hidden_size, num_experts=2, batch_first=False): for m in self.input_weights.experts: for name, weight in m.named_parameters(): nn.init.uniform_(weight, -stdv, stdv) - # if 'weight' in name: - # nn.init.uniform_(weight) for m in self.hidden_weights.experts: for name, weight in m.named_parameters(): - nn.init.uniform_(weight, -stdv, stdv) - # if 'weight' in name: - # nn.init.orthogonal_(weight) - + nn.init.uniform_(weight, -stdv, stdv) + # maybe: layer normalization: see jeeheh's code # maybe: orthogonal initialization: see jeeheh's code # note: pytorch implementation does neither - def rnn_step(self, x, hidden, coef): # one step of rnn - bs = x.shape[1] + def rnn_step(self, x, hidden, coef): + """Run a single LSTM step with mixed expert parameters.""" + bs = x.shape[1] h, c = hidden gates = self.input_weights(x, coef) + self.hidden_weights(h, coef) # maybe: layer normalization: see jeeheh's code @@ 
-208,37 +286,47 @@ def rnn_step(self, x, hidden, coef): # one step of rnn outgate = torch.sigmoid(outgate) c = forgetgate * c + ingate * cellgate - h = outgate * torch.tanh(c) # maybe use layer norm here as well + h = outgate * torch.tanh(c) # maybe use layer norm here as well return h, c - + def forward(self, x, hidden, coef): - if self.batch_first: # change to seq_len first + """Run the mixture LSTM over a full sequence.""" + if self.batch_first: # change to seq_len first x = x.transpose(0, 1) seq_len = x.shape[0] output = [] for t in range(seq_len): hidden = self.rnn_step(x[t].unsqueeze(0), hidden, coef) - output.append(hidden[0]) # seq_len x (_, bs, d) + output.append(hidden[0]) # seq_len x (_, bs, d) output = torch.cat(output, 0) return output, hidden -class mowLSTM(nn.Module): - ''' - helper for mowLSTM, - responsible for stacking and bidirectional LSTM - stack according to - https://stackoverflow.com/questions/49224413/difference-between-1-lstm-with-num-layers-2-and-2-lstms-in-pytorch +class mowLSTM(nn.Module): + """Stacked mixture-of-weights LSTM used internally by ``MixLSTM``. + + Handles multi-layer stacking, dropout, and the final output projection. + + Args: + input_size: Input feature size. + hidden_size: Hidden state size. + num_classes: Output dimension of the final projection. + num_experts: Number of expert cells to mix (K). + num_layers: Number of stacked LSTM layers. + batch_first: If True, expects input shape (batch, seq_len, dim). + dropout: Dropout probability between layers. + bidirectional: Whether to use a bidirectional LSTM. + activation: Optional activation applied to the final output. 
+ """ - ''' def __init__(self, input_size, hidden_size, num_classes, num_experts=2, - num_layers=1, batch_first=False, dropout=0, bidirectional=False, - activation=None): + num_layers=1, batch_first=False, dropout=0, + bidirectional=False, activation=None): super(mowLSTM, self).__init__() - + self.input_size = input_size self.hidden_size = hidden_size self.num_classes = num_classes @@ -250,30 +338,33 @@ def __init__(self, input_size, hidden_size, num_classes, num_experts=2, self.h2o = moo_linear(self.num_directions * self.hidden_size, self.num_classes, self.num_experts, bs_dim=2) - + if activation: self.activation = activation else: self.activation = lambda x: x - + self.rnns = nn.ModuleList() for i in range(num_layers * self.num_directions): input_size = input_size if i == 0 else hidden_size - self.rnns.append(mowLSTM_(input_size, hidden_size, num_experts, batch_first)) + self.rnns.append(mowLSTM_(input_size, hidden_size, num_experts, + batch_first)) self.dropouts.append(nn.Dropout(p=dropout)) def forward(self, x, coef): + """Forward pass through the stacked mixture LSTM.""" x, hidden = x self.last_coef = coef - + h, c = hidden hs, cs = [], [] for i in range(self.num_layers): if i != 0 and i != (self.num_layers - 1): - x = self.dropouts[i](x) # waste 1 droput out but no problem - x, hidden = self.rnns[i](x, (h[i].unsqueeze(0), c[i].unsqueeze(0)), coef) + x = self.dropouts[i](x) # waste 1 dropout but no problem + x, hidden = self.rnns[i](x, (h[i].unsqueeze(0), + c[i].unsqueeze(0)), coef) hs.append(hidden[0]) - cs.append(hidden[1]) + cs.append(hidden[1]) h = torch.cat(hs, 0) c = torch.cat(cs, 0) @@ -284,15 +375,25 @@ def forward(self, x, coef): o = self.activation(o) return o, (h, c) - -class ExampleMowLSTM(nn.Module): - ''' - recreate LSTM architectre - then stack them according to +class ExampleMowLSTM(nn.Module): + """Wrapper that instantiates a mixture LSTM with per-time-step gates. 
+ + For each of the ``t`` time steps, a separate ``NonAdaptiveGate`` is + created so that the mixing coefficients can vary over time. All gates + share the same underlying experts. + + Args: + input_size: Input feature size. + hidden_size: Hidden state size. + num_classes: Output dimension. + num_layers: Number of stacked LSTM layers. + num_directions: 1 (unidirectional) or 2 (bidirectional). + dropout: Dropout probability. + activation: Optional output activation. + """ - ''' def __init__(self, input_size, hidden_size, num_classes, num_layers=1, num_directions=1, dropout=0, activation=None): super(ExampleMowLSTM, self).__init__() @@ -304,8 +405,13 @@ def __init__(self, input_size, hidden_size, num_classes, self.dropout = dropout self.activation = activation - def setKT(self, k, t): # k models t steps - '''k experts with maximum of t time steps''' + def setKT(self, k, t): + """Configure the model for ``k`` experts and ``t`` time steps. + + Args: + k: Number of expert cells to mix. + t: Maximum number of time steps; one gate is created per step. 
+ """ self.k = k self.T = t self.cells = nn.ModuleList() @@ -313,10 +419,10 @@ def setKT(self, k, t): # k models t steps experts = mowLSTM(self.input_size, self.hidden_size, self.num_classes, num_experts=self.k, num_layers=self.num_layers, dropout=self.dropout, - bidirectional = (self.num_directions==2), + bidirectional=(self.num_directions == 2), activation=self.activation) self.experts = experts - + for _ in range(t): gate = NonAdaptiveGate(self.k, normalize=True) # gate = AdaptiveLSTMGate(self.hidden_size *\ @@ -327,35 +433,122 @@ def setKT(self, k, t): # k models t steps self.cells.append(MoW(experts, gate)) def forward(self, x, hidden): + """Run the mixture LSTM step-by-step using the per-step gates.""" seq_len, bs, _ = x.shape o = [] for t in range(seq_len): o_, hidden = self.cells[t]((x[t].view(1, bs, -1), hidden)) o.append(o_) - - o = torch.cat(o, 0) # (seq_len, bs, d) + + o = torch.cat(o, 0) # (seq_len, bs, d) return o, hidden - def orthogonal(shape): + """Generate an orthogonal matrix of the given shape via SVD.""" flat_shape = (int(shape[0]), int(np.prod(shape[1:]))) a = np.random.normal(0.0, 1.0, flat_shape) u, _, v = np.linalg.svd(a, full_matrices=False) q = u if u.shape == flat_shape else v return q.reshape(shape) + def lstm_ortho_initializer(shape, scale=1.0): + """Initialize LSTM weights with orthogonal blocks for each of the 4 gates. + + Args: + shape: Target shape where the second dimension must be divisible by 4. + scale: Scalar to multiply the orthogonal matrices by. + + Returns: + A numpy array of the requested shape. + """ size_x = shape[0] - size_h = int(shape[1]/4) # assumes lstm. + size_h = int(shape[1] / 4) # assumes lstm. 
t = np.zeros(shape) - t[:, :size_h] = orthogonal([size_x, size_h])*scale - t[:, size_h:size_h*2] = orthogonal([size_x, size_h])*scale - t[:, size_h*2:size_h*3] = orthogonal([size_x, size_h])*scale - t[:, size_h*3:] = orthogonal([size_x, size_h])*scale + t[:, :size_h] = orthogonal([size_x, size_h]) * scale + t[:, size_h:size_h * 2] = orthogonal([size_x, size_h]) * scale + t[:, size_h * 2:size_h * 3] = orthogonal([size_x, size_h]) * scale + t[:, size_h * 3:] = orthogonal([size_x, size_h]) * scale return t + class MixLSTM(BaseModel): + """Mixture-of-LSTMs model for clinical time-series prediction. + + Implements the mixLSTM architecture from Oh et al. 2020 for handling + temporal conditional shift: settings where the relationship between + input features and the target changes over time. Instead of sharing a + single set of LSTM parameters across all time steps, MixLSTM maintains + ``num_experts`` independent LSTM cells and, at every time step, + computes a learned convex combination of their parameters using mixing + coefficients constrained to the simplex. This enables smooth + transitions between different temporal dynamics without hard segment + boundaries. + + The model inherits from PyHealth's ``BaseModel`` and infers the input + dimension and sequence length from the ``SampleDataset`` passed at + construction time, so it can be used with any existing PyHealth task + whose input is a time-series tensor. + + The model supports two operating modes that are chosen automatically + based on the dataset's output schema: + + * Standard classification (``binary``, ``multiclass``, ``multilabel``, + or ``regression``): predictions are taken from the final time step of + the sequence and the appropriate PyHealth loss function is applied. + * Per-timestep regression (when the output schema is a raw ``tensor``): + the model outputs a value at every time step and the MSE loss is + computed over timesteps beginning at ``prev_used_timestamps``. 
This + reproduces the synthetic copy-memory task described in Section 4.1 + of the paper. + + Paper: + Oh et al. 2020, "Relaxed Parameter Sharing: Effectively Modeling + Time-Varying Relationships in Clinical Time-Series." + https://arxiv.org/abs/1906.02898 + + Args: + dataset: A ``SampleDataset`` used to infer input feature size and + sequence length from the first sample. + num_experts: Number of expert LSTM cells to mix (``K`` in the paper). + Higher values give the model more flexibility to vary parameters + over time at the cost of more parameters. Defaults to 2. + hidden_size: Size of the LSTM hidden state. Defaults to 100. + prev_used_timestamps: For the per-timestep regression mode, the + index of the first time step at which the loss is computed. + Earlier time steps are skipped because their targets are + trivially defined in the synthetic task. Ignored in the + standard classification mode. Defaults to 0. + + Attributes: + input_size: Inferred input feature dimension. + time_steps: Inferred sequence length. + hidden_size: LSTM hidden state size. + _per_timestep: ``True`` when the model is running in per-timestep + regression mode, ``False`` when it is running in standard + classification mode. + + Example: + >>> from pyhealth.datasets import create_sample_dataset + >>> from pyhealth.models import MixLSTM + >>> samples = [ + ... { + ... "patient_id": f"p-{i}", + ... "visit_id": "v-0", + ... "series": torch.randn(48, 76).numpy().tolist(), + ... "label": int(i % 2), + ... } + ... for i in range(100) + ... ] + >>> dataset = create_sample_dataset( + ... samples=samples, + ... input_schema={"series": "tensor"}, + ... output_schema={"label": "multiclass"}, + ... dataset_name="demo", + ... 
) + >>> model = MixLSTM(dataset=dataset, num_experts=2, hidden_size=64) + """ def __init__(self, dataset: SampleDataset, num_experts=2, hidden_size=100, prev_used_timestamps=0): @@ -370,7 +563,8 @@ def __init__(self, dataset: SampleDataset, num_experts=2, hidden_size=100, val = sample[self.input_key] if isinstance(val, (list, tuple)): for item in val: - if torch.is_tensor(item) or isinstance(item, (list, tuple, np.ndarray)): + if torch.is_tensor(item) or isinstance( + item, (list, tuple, np.ndarray)): val = item break if torch.is_tensor(val): @@ -397,9 +591,9 @@ def __init__(self, dataset: SampleDataset, num_experts=2, hidden_size=100, num_classes = int(self.get_output_size()) self.model = ExampleMowLSTM(self.input_size, hidden_size, - num_classes, num_layers=1, - num_directions=1, dropout=0, - activation=None) + num_classes, num_layers=1, + num_directions=1, dropout=0, + activation=None) self.num_layers = 1 self.num_directions = 1 @@ -407,6 +601,35 @@ def __init__(self, dataset: SampleDataset, num_experts=2, hidden_size=100, self.model.setKT(num_experts, self.time_steps) def forward(self, **kwargs): + """Run a forward pass. + + Expects the input tensor under ``kwargs[self.input_key]`` with + shape ``(batch, seq_len, input_dim)``. If a label tensor is also + provided under ``self.label_key``, the appropriate loss is + computed and returned. + + Args: + **kwargs: Batch dictionary, typically produced by a PyHealth + DataLoader and passed by ``Trainer`` as ``model(**batch)``. + + Returns: + A dictionary with the following keys, where shapes depend on + which mode the model is operating in: + + * Classification mode (``_per_timestep = False``): + - ``logit``: ``(batch, num_classes)`` from the final step. + - ``y_prob``: Probabilities produced by ``prepare_y_prob``. + - ``loss`` (optional): PyHealth's standard loss for the task. + - ``y_true`` (optional): Ground-truth labels. 
+ + * Per-timestep regression mode (``_per_timestep = True``): + - ``logit``: ``(batch, seq_len, 1)`` — one prediction per + time step. + - ``y_prob``: Same tensor as ``logit``. + - ``loss`` (optional): MSE computed over time steps from + ``prev_used_timestamps`` onward. + - ``y_true`` (optional): Ground-truth target tensor. + """ x = kwargs.get(self.input_key) # (bs, seq_len, d) => (seq_len, bs, d) @@ -423,7 +646,7 @@ def forward(self, **kwargs): logits_seq = outputs.permute(1, 0, 2) if self._per_timestep: - # --- Per-timestep regression (original MLHC2019 synthetic task) --- + # --- Per-timestep regression (original MLHC2019 synthetic task) results = {"logit": logits_seq, "y_prob": logits_seq} if self.label_key and self.label_key in kwargs: y_true = kwargs[self.label_key].to(device) @@ -451,5 +674,5 @@ def forward(self, **kwargs): return results def after_backward(self): + """Hook called after backward(); no-op for this model.""" pass - \ No newline at end of file From f19b6fed4e34c5ce25ea92e76debcaa39ddd9e7b Mon Sep 17 00:00:00 2001 From: amanluth03 Date: Mon, 20 Apr 2026 17:24:21 -0500 Subject: [PATCH 05/19] updated index --- docs/api/models.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/api/models.rst b/docs/api/models.rst index 7368dec94..505640dc8 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -174,6 +174,7 @@ API Reference models/pyhealth.models.MLP models/pyhealth.models.CNN models/pyhealth.models.RNN + models/pyhealth.models.MixLSTM models/pyhealth.models.GNN models/pyhealth.models.Transformer models/pyhealth.models.TransformersModel From 634a80edc8836cc10d7b12f377a24d0db41cd432 Mon Sep 17 00:00:00 2001 From: Tanmay Mittal Date: Wed, 22 Apr 2026 05:11:51 -0500 Subject: [PATCH 06/19] ablation study --- examples/mimic3_synthetic_mixlstm.py | 562 +++++++++++++++++++++++++++ 1 file changed, 562 insertions(+) create mode 100644 examples/mimic3_synthetic_mixlstm.py diff --git a/examples/mimic3_synthetic_mixlstm.py 
b/examples/mimic3_synthetic_mixlstm.py new file mode 100644 index 000000000..49c487db6 --- /dev/null +++ b/examples/mimic3_synthetic_mixlstm.py @@ -0,0 +1,562 @@ +""" +MixLSTM Hyperparameter Search Experiment +Synthetic time-series regression task with PyHealth. + +All intermediate results (distributions, predictions, search metrics) +are kept in memory and passed directly to the visualization functions +instead of being written to / read from disk. +""" + +import os +import random +import logging +from dataclasses import dataclass, field + +import numpy as np +import pandas as pd +import torch +import torch.optim as optim +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import seaborn as sns +from pyhealth.datasets import create_sample_dataset, get_dataloader +from pyhealth.trainer import Trainer +from pyhealth.models import MixLSTM + + +# ────────────────────────────────────────────────────────────── +# In-memory result containers +# ────────────────────────────────────────────────────────────── + +@dataclass +class AblationResult: + """Everything produced by a single ablation run.""" + learning_rate: float + optimizer_name: str + results_df: pd.DataFrame + k_dist: list[np.ndarray] + d_dist: list[np.ndarray] + best_predictions: dict | None = None # {"pred": ..., "y_true": ..., "k": ..., "hidden_size": ...} + best_model_state: dict | None = None + + @property + def label(self) -> str: + """Human-readable label for plots and logs.""" + return f"{self.optimizer_name} lr={self.learning_rate}" + + +# ────────────────────────────────────────────────────────────── +# Configuration +# ────────────────────────────────────────────────────────────── + +SEED = 42 +NUM_SAMPLES = 1000 +T = 30 # sequence length +INPUT_DIM = 3 +PREV_USED_TIMESTAMPS = 10 # l +CHANGE_BETWEEN_TASKS = 0.05 # delta + +BATCH_SIZE = 100 +K_LIST = [2] +HIDDEN_SIZE_LIST = [100, 150, 300, 500, 700, 900, 1100] +NUM_RUNS = 1 # 20 +MAX_EPOCHS = 2 # 30 + +SAVE_DIR = 
os.path.dirname(os.path.abspath(__file__)) + +# Visualization +MAX_MSE = 100 +ABLATION_LRS = [0.0001, 0.0005, 0.001, 0.005, 0.01] + + +# ────────────────────────────────────────────────────────────── +# Utility functions +# ────────────────────────────────────────────────────────────── + +def set_seed(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + +def get_device() -> torch.device: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Running on device: {device}") + return device + + +# ────────────────────────────────────────────────────────────── +# Data generation +# ────────────────────────────────────────────────────────────── + +def convert_distb(a: np.ndarray) -> np.ndarray: + a_min = min(a) + a_max = max(a) + a = (a - a_min) / (a_max - a_min) + a_sum = sum(a) + a = a / a_sum + return a + + +def generate_distributions( + T: int, + prev_used_timestamps: int, + input_dim: int, + change_between_tasks: float, +) -> tuple[list[np.ndarray], list[np.ndarray]]: + k_dist = [] + d_dist = [] + for i in range(T): + if i < prev_used_timestamps: + k_dist.append(np.ones(prev_used_timestamps)) + d_dist.append(np.ones(input_dim)) + elif i == prev_used_timestamps: + k_dist.append(convert_distb(np.random.uniform(size=(prev_used_timestamps,)))) + d_dist.append(convert_distb(np.random.uniform(size=(input_dim,)))) + else: + delta_t = np.random.uniform( + -change_between_tasks, change_between_tasks, size=(prev_used_timestamps,) + ) + delta_d = np.random.uniform( + -change_between_tasks, change_between_tasks, size=(input_dim,) + ) + k_dist.append(convert_distb(k_dist[i - 1] + delta_t)) + d_dist.append(convert_distb(d_dist[i - 1] + delta_d)) + return k_dist, d_dist + + +def generate_xy( + num_samples: int, + T: int, + input_dim: int, + prev_used_timestamps: int, + k_dist: list[np.ndarray], + d_dist: list[np.ndarray], +) -> tuple[np.ndarray, 
np.ndarray]: + x_size = num_samples * T * input_dim + x = np.zeros(x_size) + sparse_count = int(x_size / 10) + x[np.random.choice(x_size, size=sparse_count, replace=False)] = ( + np.random.uniform(size=sparse_count) * 100 + ) + x = np.resize(x, (num_samples, T, input_dim)) + + y = np.ones((num_samples, T, 1)) + for i in range(T): + if i >= prev_used_timestamps: + y[:, i, 0] = np.matmul( + np.matmul(x[:, i - prev_used_timestamps : i, :], d_dist[i]), + k_dist[i], + ) + return x, y + + +# ────────────────────────────────────────────────────────────── +# PyHealth dataset helpers +# ────────────────────────────────────────────────────────────── + +def make_dataset(x: np.ndarray, y: np.ndarray, split_name: str): + samples = [ + { + "patient_id": f"{split_name}-patient-{i}", + "visit_id": "visit-0", + "series": x[i].tolist(), + "y": y[i].squeeze(-1).tolist(), + } + for i in range(len(x)) + ] + return create_sample_dataset( + samples=samples, + input_schema={"series": "tensor"}, + output_schema={"y": "tensor"}, + dataset_name=f"mixlstm_{split_name}", + ) + + +def build_dataloaders( + k_dist, d_dist, num_samples, T, input_dim, prev_used_timestamps, batch_size +): + x_train, y_train = generate_xy(num_samples, T, input_dim, prev_used_timestamps, k_dist, d_dist) + x_val, y_val = generate_xy(num_samples, T, input_dim, prev_used_timestamps, k_dist, d_dist) + x_test, y_test = generate_xy(num_samples, T, input_dim, prev_used_timestamps, k_dist, d_dist) + + train_data = make_dataset(x_train, y_train, "train") + val_data = make_dataset(x_val, y_val, "val") + test_data = make_dataset(x_test, y_test, "test") + + train_loader = get_dataloader(train_data, batch_size=batch_size, shuffle=True) + val_loader = get_dataloader(val_data, batch_size=batch_size, shuffle=True) + test_loader = get_dataloader(test_data, batch_size=batch_size, shuffle=True) + + return train_data, train_loader, val_loader, test_loader + + +# ────────────────────────────────────────────────────────────── +# Training & 
evaluation +# ────────────────────────────────────────────────────────────── + +def collect_predictions(model, test_loader, device): + """Run inference and return predictions + ground truth as numpy arrays.""" + model.eval() + l = model.prev_used_timestamps + preds, y_trues = [], [] + + with torch.no_grad(): + for batch in test_loader: + batch = { + k_: v.to(device) if isinstance(v, torch.Tensor) else v + for k_, v in batch.items() + } + output = model(**batch) + preds.append(output["y_prob"][:, l:, :].cpu().numpy()) + y_trues.append(output["y_true"][:, l:, :].cpu().numpy()) + + return { + "pred": np.concatenate(preds, axis=0).flatten(), + "y_true": np.concatenate(y_trues, axis=0).flatten(), + } + + +def run_hyperparameter_search( + train_data, + train_loader, + val_loader, + test_loader, + device, + prev_used_timestamps, + k_list, + hidden_size_list, + num_runs, + max_epochs, + learning_rate, + optimizer_class=optim.Adam, +): + """Execute the hyperparameter search. Returns (results_df, best_predictions, best_model_state).""" + results = [] + best_val_loss_overall = np.inf + best_predictions = None + best_model_state = None + + for run in range(num_runs): + k = random.choice(k_list) + hidden_size = random.choice(hidden_size_list) + + print(f"\n{'=' * 60}") + print(f"Run {run + 1}/{num_runs} | k (num_experts): {k} | hidden_size: {hidden_size}") + print("=" * 60) + + model = MixLSTM( + dataset=train_data, + num_experts=k, + hidden_size=hidden_size, + prev_used_timestamps=prev_used_timestamps, + ) + model = model.to(device) + + trainer = Trainer(model=model, device=device) + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + optimizer_class=optimizer_class, + optimizer_params={"lr": learning_rate}, + epochs=max_epochs, + monitor="loss", + monitor_criterion="min", + ) + + print(f"\nEvaluating Best Model for Run {run + 1}...") + val_metrics = trainer.evaluate(val_loader) + test_metrics = trainer.evaluate(test_loader) + + val_loss = 
val_metrics.get("loss", None) + test_loss = test_metrics.get("loss", None) + + if val_loss < best_val_loss_overall: + best_val_loss_overall = val_loss + print(f" New best val loss: {val_loss:.6f}") + predictions = collect_predictions(model, test_loader, device) + predictions["k"] = k + predictions["hidden_size"] = hidden_size + predictions["run"] = run + best_predictions = predictions + best_model_state = {k_: v.cpu().clone() for k_, v in model.state_dict().items()} + + results.append({ + "Run": run + 1, + "k (experts)": k, + "Hidden Size": hidden_size, + "Val Loss": val_loss, + "Test Loss": test_loss, + "num_params": sum(p.numel() for p in model.parameters() if p.requires_grad), + "epoch": max_epochs, + }) + + return pd.DataFrame(results), best_predictions, best_model_state + + +# ────────────────────────────────────────────────────────────── +# Ablation study — learning rate sweep +# ────────────────────────────────────────────────────────────── + +def run_single_ablation( + learning_rate: float, + optimizer_class=optim.Adam, + optimizer_name: str = "Adam", +) -> AblationResult: + """Run the full search for one (optimizer, lr) combination and return all results in memory.""" + set_seed(SEED) + device = get_device() + logging.getLogger("pyhealth.trainer").setLevel(logging.WARNING) + + if device.type == "cuda": + torch.set_default_device(device) + + k_dist, d_dist = generate_distributions( + T, PREV_USED_TIMESTAMPS, INPUT_DIM, CHANGE_BETWEEN_TASKS + ) + + train_data, train_loader, val_loader, test_loader = build_dataloaders( + k_dist, d_dist, NUM_SAMPLES, T, INPUT_DIM, PREV_USED_TIMESTAMPS, BATCH_SIZE + ) + + print(f"\n{'#' * 60}") + print(f" ABLATION — optimizer = {optimizer_name}, learning_rate = {learning_rate}") + print(f"{'#' * 60}") + + results_df, best_predictions, best_model_state = run_hyperparameter_search( + train_data=train_data, + train_loader=train_loader, + val_loader=val_loader, + test_loader=test_loader, + device=device, + 
prev_used_timestamps=PREV_USED_TIMESTAMPS, + k_list=K_LIST, + hidden_size_list=HIDDEN_SIZE_LIST, + num_runs=NUM_RUNS, + max_epochs=MAX_EPOCHS, + learning_rate=learning_rate, + optimizer_class=optimizer_class, + ) + + best = results_df.sort_values(by="Test Loss").reset_index(drop=True) + print(f"\nTop 5 results for {optimizer_name} lr={learning_rate}:") + print(best.head(5)) + + return AblationResult( + learning_rate=learning_rate, + optimizer_name=optimizer_name, + results_df=results_df, + k_dist=k_dist, + d_dist=d_dist, + best_predictions=best_predictions, + best_model_state=best_model_state, + ) + + +def run_all_ablations() -> list[AblationResult]: + """Run learning-rate ablations with Adam (original behaviour).""" + ablation_results = [ + run_single_ablation(lr, optim.Adam, "Adam") for lr in ABLATION_LRS + ] + _print_summary("Learning Rate Sweep (Adam)", ablation_results) + return ablation_results + + +# ────────────────────────────────────────────────────────────── +# Ablation study — optimizer comparison (Adam vs SGD) +# ────────────────────────────────────────────────────────────── + +ABLATION_OPTIMIZER_LR = 0.001 # fixed LR used for the optimizer comparison + +def ablations_optimizing_adam() -> AblationResult: + """Ablation: Adam optimizer at lr=0.001.""" + return run_single_ablation(ABLATION_OPTIMIZER_LR, optim.Adam, "Adam") + + +def ablations_optimizing_sgd() -> AblationResult: + """Ablation: SGD optimizer at lr=0.001.""" + return run_single_ablation(ABLATION_OPTIMIZER_LR, optim.SGD, "SGD") + + +def run_optimizer_ablations() -> list[AblationResult]: + """Run Adam vs SGD at a fixed learning rate and print a comparison.""" + results = [ + ablations_optimizing_adam(), + ablations_optimizing_sgd(), + ] + _print_summary("Optimizer Comparison (Adam vs SGD)", results) + return results + + +def _print_summary(title: str, ablation_results: list[AblationResult]): + """Pretty-print a summary table for a list of ablation results.""" + summary_rows = [] + for result in 
ablation_results: + best_row = result.results_df.sort_values(by="Test Loss").iloc[0] + summary_rows.append({ + "Optimizer": result.optimizer_name, + "Learning Rate": result.learning_rate, + "Best Val Loss": best_row["Val Loss"], + "Best Test Loss": best_row["Test Loss"], + "k (experts)": best_row["k (experts)"], + "Hidden Size": best_row["Hidden Size"], + "num_params": best_row["num_params"], + }) + + summary_df = pd.DataFrame(summary_rows) + print("\n" + "=" * 60) + print(f" {title}") + print("=" * 60) + print(summary_df.to_string(index=False)) + + +# ────────────────────────────────────────────────────────────── +# Visualization (all functions take in-memory data) +# ────────────────────────────────────────────────────────────── + +def visualize_hyperparameter_search(ablation_results: list[AblationResult], prefix: str = ""): + """Plot MSE loss vs. hidden size for every learning rate.""" + print("--- 1. Analyzing Hyperparameter Search (Ablation) ---") + + plt.figure(figsize=(12, 7)) + palette = sns.color_palette("tab10", len(ablation_results)) + + for i, result in enumerate(ablation_results): + tag = result.label + df = result.results_df.copy() + df = df[(df["Val Loss"] <= MAX_MSE) & (df["Test Loss"] <= MAX_MSE)] + + best = df.sort_values(by="Val Loss").head(1) + print( + f" {tag} Best Val Loss: {best['Val Loss'].values[0]:.6f} " + f"(Hidden Size={best['Hidden Size'].values[0]})" + ) + + color = palette[i] + + sns.scatterplot( + data=df, x="Hidden Size", y="Val Loss", + label=f"Val ({tag})", color=color, + marker="o", alpha=0.4, s=40, + ) + sns.scatterplot( + data=df, x="Hidden Size", y="Test Loss", + label=f"Test ({tag})", color=color, + marker="x", alpha=0.4, s=40, + ) + + val_mean = df.groupby("Hidden Size")["Val Loss"].mean().sort_index() + test_mean = df.groupby("Hidden Size")["Test Loss"].mean().sort_index() + plt.plot(val_mean.index, val_mean.values, color=color, linewidth=2, linestyle="-") + plt.plot(test_mean.index, test_mean.values, color=color, 
linewidth=2, linestyle="--") + + plt.title("Ablation: MSE Loss vs. Hidden Size by Learning Rate") + plt.xlabel("Hidden Size") + plt.ylabel("MSE Loss") + plt.legend(title="Loss Type / LR", bbox_to_anchor=(1.05, 1), loc="upper left") + plt.grid(True, linestyle="--", alpha=0.6) + plt.ylim(0, MAX_MSE) + plt.tight_layout() + + out_path = os.path.join(SAVE_DIR, f"{prefix}ablation_loss_vs_hidden_size.png") + plt.savefig(out_path, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {out_path}") + + +def visualize_predictions(ablation_results: list[AblationResult], num_samples: int = 3, prefix: str = ""): + """Plot predicted vs. true values for a few test samples per learning rate.""" + print("\n--- 2. Analyzing Predictions (Ablation) ---") + + # Only include results that have saved predictions + valid_results = [r for r in ablation_results if r.best_predictions is not None] + if not valid_results: + print(" No predictions available to plot.") + return + + fig, axes = plt.subplots(len(valid_results), num_samples, figsize=(15, 4 * len(valid_results))) + if len(valid_results) == 1: + axes = [axes] + + for row, result in enumerate(valid_results): + tag = result.label + y_true_flat = result.best_predictions["y_true"] + pred_flat = result.best_predictions["pred"] + + l = result.best_predictions.get("k", PREV_USED_TIMESTAMPS) + eval_steps = T - l + num_test_samples = len(y_true_flat) // eval_steps + limit = num_test_samples * eval_steps + + y_true = np.reshape(y_true_flat[:limit], (num_test_samples, eval_steps)) + pred = np.reshape(pred_flat[:limit], (num_test_samples, eval_steps)) + + sample_indices = np.random.choice( + num_test_samples, min(num_samples, num_test_samples), replace=False + ) + + for col, sample_idx in enumerate(sample_indices): + ax = axes[row][col] + ax.plot(y_true[sample_idx], label="True", color="blue", marker="o", markersize=4) + ax.plot(pred[sample_idx], label="Predicted", color="red", linestyle="--", marker="x", markersize=4) + 
ax.set_title(f"{tag} | Sample #{sample_idx}") + ax.set_xlabel("Time Steps") + ax.set_ylabel("Value") + ax.legend(fontsize=8) + ax.grid(True, alpha=0.3) + + plt.tight_layout() + out_path = os.path.join(SAVE_DIR, f"{prefix}ablation_predictions.png") + plt.savefig(out_path, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {out_path}") + + +def visualize_synthetic_shift(ablation_results: list[AblationResult], prefix: str = ""): + """Plot the k_dist heatmap for each learning rate's task distributions.""" + print("\n--- 3. Analyzing Synthetic Data Shift (Ablation) ---") + + fig, axes = plt.subplots(1, len(ablation_results), figsize=(6 * len(ablation_results), 5)) + if len(ablation_results) == 1: + axes = [axes] + + for i, result in enumerate(ablation_results): + k_dist_matrix = np.stack(result.k_dist) + sns.heatmap(k_dist_matrix.T, cmap="viridis", ax=axes[i], cbar_kws={"label": "Weight"}) + axes[i].set_title(f"k_dist Shift ({result.label})") + axes[i].set_xlabel("Time Step (T)") + axes[i].set_ylabel("Lookback Step (l)") + + plt.tight_layout() + out_path = os.path.join(SAVE_DIR, f"{prefix}ablation_synthetic_shift.png") + plt.savefig(out_path, dpi=150, bbox_inches="tight") + plt.close() + print(f" Saved: {out_path}") + + +def run_all_visualizations(ablation_results: list[AblationResult], prefix: str = ""): + """Generate all three ablation plots from in-memory results.""" + visualize_hyperparameter_search(ablation_results, prefix=prefix) + visualize_synthetic_shift(ablation_results, prefix=prefix) + visualize_predictions(ablation_results, prefix=prefix) + + +# ────────────────────────────────────────────────────────────── +# Main +# ────────────────────────────────────────────────────────────── + +def main(): + # Learning-rate sweep (Adam only, multiple LRs) + lr_results = run_all_ablations() + run_all_visualizations(lr_results, prefix="lr_sweep_") + + # Optimizer comparison (Adam vs SGD at fixed LR) + optimizer_results = run_optimizer_ablations() + 
run_all_visualizations(optimizer_results, prefix="optim_comp_") + + +if __name__ == "__main__": + main() \ No newline at end of file From a298239ab7247b56901a40ea95ee4a6e045029e2 Mon Sep 17 00:00:00 2001 From: Tanmay Mittal Date: Wed, 22 Apr 2026 05:14:39 -0500 Subject: [PATCH 07/19] fixed abalation study --- examples/mimic3_synthetic_mixlstm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/mimic3_synthetic_mixlstm.py b/examples/mimic3_synthetic_mixlstm.py index 49c487db6..94c949dbc 100644 --- a/examples/mimic3_synthetic_mixlstm.py +++ b/examples/mimic3_synthetic_mixlstm.py @@ -60,8 +60,8 @@ def label(self) -> str: BATCH_SIZE = 100 K_LIST = [2] HIDDEN_SIZE_LIST = [100, 150, 300, 500, 700, 900, 1100] -NUM_RUNS = 1 # 20 -MAX_EPOCHS = 2 # 30 +NUM_RUNS = 20 # 20 +MAX_EPOCHS = 30 # 30 SAVE_DIR = os.path.dirname(os.path.abspath(__file__)) From e567d0ac9b70513729405bc4e1ae2cdb42c80149 Mon Sep 17 00:00:00 2001 From: Tanmay Mittal Date: Wed, 22 Apr 2026 05:46:01 -0500 Subject: [PATCH 08/19] comments draft 1 --- examples/mimic3_synthetic_mixlstm.py | 47 ++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/examples/mimic3_synthetic_mixlstm.py b/examples/mimic3_synthetic_mixlstm.py index 94c949dbc..e7f9ad570 100644 --- a/examples/mimic3_synthetic_mixlstm.py +++ b/examples/mimic3_synthetic_mixlstm.py @@ -25,6 +25,53 @@ from pyhealth.models import MixLSTM +# ====================================================================== +# MixLSTM Hyperparameter Search Experiment +# Synthetic time-series regression task with PyHealth +# ====================================================================== +# +# EXPERIMENTAL SETUP +# ------------------ +# Dataset: Synthetic non-stationary time-series regression. 1,000 sequences +# per split (train/val/test), length T=30, 3 input features. Inputs are +# 90% sparse. 
Targets from step l=10 onward are weighted combinations of +# prior inputs, where the weights drift by delta=0.05 per step to simulate +# distribution shift. +# +# Model: MixLSTM (PyHealth) with k=2 experts and lookback window l=10. +# Hidden size sampled from {100, 150, 300, 500, 700, 900, 1100}. +# 20 random-search runs per config, 30 epochs each, batch size 100. +# +# ABLATION STUDIES +# ---------------- +# 1) Learning rate sweep: Adam at lr in {0.0001, 0.0005, 0.001, 0.005, 0.01} +# 2) Optimizer comparison: Adam vs SGD at lr=0.001 +# 3) Every other parameter kept as default +# +# FINDINGS +# ---------------- +# Adam consistently better than SGD +# Adam lowest val loss MSE = 0.430089, test loss MSE = 0.467544 +# SGD loweest val loss MSE = 16.388920, test loss MSE = 16.411073 +# +# Learning Rate comparison +# learnng rate hidden-parameter 100, 150, 300, 500, 700, 900, 1100 MSE val loss MSE test loss +# 0.0001 +# 0.0005 +# 0.001 +# 0.005 +# 0.01 +# +# +# +# +# +# +# +# +# + + # ────────────────────────────────────────────────────────────── # In-memory result containers # ────────────────────────────────────────────────────────────── From 1178691863c9f4236d001e7fe0cf58c4dc3cb934 Mon Sep 17 00:00:00 2001 From: Tanmay Mittal Date: Wed, 22 Apr 2026 15:56:47 -0500 Subject: [PATCH 09/19] abalation explanation done --- examples/mimic3_synthetic_mixlstm.py | 51 ++++++++++++++++++---------- 1 file changed, 34 insertions(+), 17 deletions(-) diff --git a/examples/mimic3_synthetic_mixlstm.py b/examples/mimic3_synthetic_mixlstm.py index e7f9ad570..b1fc6997f 100644 --- a/examples/mimic3_synthetic_mixlstm.py +++ b/examples/mimic3_synthetic_mixlstm.py @@ -50,25 +50,42 @@ # # FINDINGS # ---------------- -# Adam consistently better than SGD -# Adam lowest val loss MSE = 0.430089, test loss MSE = 0.467544 -# SGD loweest val loss MSE = 16.388920, test loss MSE = 16.411073 -# -# Learning Rate comparison -# learnng rate hidden-parameter 100, 150, 300, 500, 700, 900, 1100 MSE val loss 
MSE test loss -# 0.0001 -# 0.0005 -# 0.001 -# 0.005 -# 0.01 -# -# -# -# -# -# +# 1. OPTIMIZER COMPARISON +# ---------------------------------------------------------------------------- +# Conclusion: Adam consistently outperformed SGD across training runs. +# +# | Optimizer | Lowest Val Loss (MSE) | Lowest Test Loss (MSE) | +# |-----------|-----------------------|------------------------| +# | Adam | 0.430089 | 0.467544 | +# | SGD | 16.388920 | 16.411073 | # +# 2. LEARNING RATE VS. HIDDEN SIZE COMPARISON +# -------------------------------------------------------------------------------------------------------------------------- +# Format: (Validation Loss MSE - Test Loss MSE) +# +# | Hidden Size +# LR | 100 150 300 500 700 900 1100 +# ----------|--------------------------------------------------------------------------------------------------------------- +# 0.0001 | (-) (14.14 - 14.54) (10.76 - 11.05) (7.59 - 7.73) (5.53 - 5.44) (4.60 - 4.76) (4.31 - 4.39) +# 0.0005 | (10.78 - 11.06) (9.30 - 9.76) (5.31 - 5.87) (4.02 - 4.39) (2.81 - 3.09) (1.61 - 1.89) (1.33 - 1.52) +# 0.001 | (6.37 - 6.51) (4.49 - 4.62) (2.60 - 2.67) (1.26 - 1.31) (0.87 - 0.91) (0.69 - 0.77) (0.43 - 0.46) +# 0.005 | (2.20 - 2.28) (1.41 - 1.53) (0.68 - 0.77) (0.48 - 0.62) (0.89 - 1.01) (-) (0.68 - 0.74) +# 0.01 | (1.79 - 1.88) (1.42 - 1.47) (1.10 - 1.14) (1.03 - 1.10) (1.54 - 1.58) (0.91 - 0.98) (2.24 - 2.41) +# ========================================================================================================================== +# +# Conclution: +# LR = 0.0001 was the worst performer overall across all hidden sizes +# LR = 0.0005 was also the second word performer overall across almost all hidden states +# LR = 0.001 this was the learning rate that the paper used. LR value 0.01 and 0.005 were better in the lower hidden sizes +# eg 100, 150, 300, 500. 
For the rest, LR 0.001 was the best choice overall
+# LR = 0.005 this rate was the best overall for the lower hidden sizes from 100 to 500, but had a spike
+# at 700 before coming back down. Ideal for lower hidden sizes
+# LR = 0.01 this rate was quite sporadic and unstable, going up and down multiple times, and is not recommended
#
+# Overall Conclusion of the entire study:
+ """ random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) @@ -147,6 +159,12 @@ def set_seed(seed: int) -> None: def get_device() -> torch.device: + """Detect and return the best available compute device. + + Returns: + ``torch.device("cuda")`` when a CUDA GPU is available, + otherwise ``torch.device("cpu")``. + """ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Running on device: {device}") return device @@ -157,6 +175,19 @@ def get_device() -> torch.device: # ────────────────────────────────────────────────────────────── def convert_distb(a: np.ndarray) -> np.ndarray: + """Min-max normalize an array and rescale it to sum to one. + + The array is first shifted and scaled to the [0, 1] range via + min-max normalization, then divided by its sum so that it + forms a valid discrete probability distribution. + + Args: + a: 1-D numpy array of raw (un-normalized) weights. + + Returns: + A 1-D numpy array of the same shape whose elements are + non-negative and sum to 1. + """ a_min = min(a) a_max = max(a) a = (a - a_min) / (a_max - a_min) @@ -171,6 +202,34 @@ def generate_distributions( input_dim: int, change_between_tasks: float, ) -> tuple[list[np.ndarray], list[np.ndarray]]: + """Generate time-varying weight distributions for synthetic targets. + + Creates ``k_dist`` (temporal) and ``d_dist`` (feature) weight + vectors that drift by a small delta at each step beyond the + lookback window, simulating non-stationary distribution shift. + + For time steps before *prev_used_timestamps*, both distributions + are uniform placeholders. At step *prev_used_timestamps* the + distributions are initialized randomly, and at each subsequent + step a uniform perturbation in ``[-change_between_tasks, + +change_between_tasks]`` is added before re-normalization. + + Args: + T: Total sequence length. + prev_used_timestamps: Lookback window size (*l*). + Distributions before this index are uniform placeholders. 
+ input_dim: Number of input features per time step. + change_between_tasks: Maximum per-step drift (*delta*) + applied uniformly at random to each weight element. + + Returns: + A tuple ``(k_dist, d_dist)`` where: + + * ``k_dist`` is a list of *T* arrays, each of shape + ``(prev_used_timestamps,)``. + * ``d_dist`` is a list of *T* arrays, each of shape + ``(input_dim,)``. + """ k_dist = [] d_dist = [] for i in range(T): @@ -200,6 +259,31 @@ def generate_xy( k_dist: list[np.ndarray], d_dist: list[np.ndarray], ) -> tuple[np.ndarray, np.ndarray]: + + """Generate sparse input sequences and their regression targets. + + Inputs are 90 % sparse (zeros) with the remaining 10 % drawn + uniformly from ``[0, 100)``. For time steps ``t >= l`` the target + is ``x[t-l:t, :] @ d_dist[t] @ k_dist[t]``; earlier targets are + ones (placeholders). + + Args: + num_samples: Number of independent sequences to generate. + T: Sequence length (number of time steps). + input_dim: Dimensionality of input features. + prev_used_timestamps: Lookback window size (*l*). + k_dist: Temporal weight distributions as returned by + :func:`generate_distributions`. + d_dist: Feature weight distributions as returned by + :func:`generate_distributions`. + + Returns: + A tuple ``(x, y)`` where: + + * ``x`` has shape ``(num_samples, T, input_dim)``. + * ``y`` has shape ``(num_samples, T, 1)``. + """ + x_size = num_samples * T * input_dim x = np.zeros(x_size) sparse_count = int(x_size / 10) @@ -223,6 +307,23 @@ def generate_xy( # ────────────────────────────────────────────────────────────── def make_dataset(x: np.ndarray, y: np.ndarray, split_name: str): + """Wrap numpy arrays into a PyHealth ``SampleDataset``. + + Each sequence is registered as a separate patient with a single + visit containing the full time-series. + + Args: + x: Input tensor of shape ``(N, T, D)``. + y: Target tensor of shape ``(N, T, 1)``. + split_name: Identifier for the split (e.g. ``"train"``, + ``"val"``, ``"test"``). 
Used in patient IDs and as the + PyHealth dataset name suffix. + + Returns: + A PyHealth ``SampleDataset`` ready to be passed to + ``get_dataloader``. + """ + samples = [ { "patient_id": f"{split_name}-patient-{i}", @@ -243,6 +344,30 @@ def make_dataset(x: np.ndarray, y: np.ndarray, split_name: str): def build_dataloaders( k_dist, d_dist, num_samples, T, input_dim, prev_used_timestamps, batch_size ): + + """Generate train / val / test splits and wrap them in DataLoaders. + + Three independent datasets are synthesized from the same + underlying distributions so that the only source of variance is + the random sparse masking and the ordering of non-zero entries. + + Args: + k_dist: Temporal weight distributions (see + :func:`generate_distributions`). + d_dist: Feature weight distributions (see + :func:`generate_distributions`). + num_samples: Number of sequences per split. + T: Sequence length. + input_dim: Number of input features. + prev_used_timestamps: Lookback window size (*l*). + batch_size: Mini-batch size for every DataLoader. + + Returns: + A tuple ``(train_dataset, train_loader, val_loader, + test_loader)``. The raw ``train_dataset`` is also returned + because ``MixLSTM.__init__`` requires it to infer schema + metadata. + """ x_train, y_train = generate_xy(num_samples, T, input_dim, prev_used_timestamps, k_dist, d_dist) x_val, y_val = generate_xy(num_samples, T, input_dim, prev_used_timestamps, k_dist, d_dist) x_test, y_test = generate_xy(num_samples, T, input_dim, prev_used_timestamps, k_dist, d_dist) @@ -263,7 +388,24 @@ def build_dataloaders( # ────────────────────────────────────────────────────────────── def collect_predictions(model, test_loader, device): - """Run inference and return predictions + ground truth as numpy arrays.""" + """Run inference on *test_loader* and collect predictions. + + The model is set to eval mode and gradients are disabled. Only + time steps from index *l* onward (the non-placeholder region) + are retained. 
+ + Args: + model: A trained ``MixLSTM`` model instance. + test_loader: DataLoader yielding test batches. + device: Device the model resides on. + + Returns: + A dictionary with two keys: + + * ``"pred"`` — flattened 1-D numpy array of predicted values. + * ``"y_true"`` — flattened 1-D numpy array of ground-truth + values, aligned element-wise with ``"pred"``. + """ model.eval() l = model.prev_used_timestamps preds, y_trues = [], [] @@ -298,7 +440,48 @@ def run_hyperparameter_search( learning_rate, optimizer_class=optim.Adam, ): - """Execute the hyperparameter search. Returns (results_df, best_predictions, best_model_state).""" + + """Execute a random hyperparameter search over MixLSTM configs. + + Each run samples ``k`` (number of experts) and ``hidden_size`` + uniformly from the provided lists, trains for *max_epochs*, and + records validation / test loss. The model with the lowest + validation loss is retained. + + Args: + train_data: PyHealth ``SampleDataset`` used to initialize + ``MixLSTM`` (needed for schema inference). + train_loader: DataLoader for the training split. + val_loader: DataLoader for the validation split. + test_loader: DataLoader for the test split. + device: Compute device (CPU or CUDA). + prev_used_timestamps: Lookback window size (*l*) passed to + ``MixLSTM``. + k_list: Candidate values for the number of mixture experts. + hidden_size_list: Candidate values for the LSTM hidden + dimension. + num_runs: Total number of random configurations to evaluate. + max_epochs: Training epochs per run. + learning_rate: Learning rate passed to the optimizer. + optimizer_class: PyTorch optimizer class (e.g. + ``torch.optim.Adam``). + + Returns: + A tuple ``(results_df, best_predictions, best_model_state)`` + where: + + * ``results_df`` — DataFrame with columns ``Run``, + ``k (experts)``, ``Hidden Size``, ``Val Loss``, + ``Test Loss``, ``num_params``, and ``epoch``. 
+ * ``best_predictions`` — dictionary as returned by + :func:`collect_predictions`, augmented with ``"k"``, + ``"hidden_size"``, and ``"run"`` keys. ``None`` when no + valid model was found. + * ``best_model_state`` — CPU ``state_dict`` of the + best-performing model. ``None`` when no valid model was + found. + """ + results = [] best_val_loss_overall = np.inf best_predictions = None @@ -370,7 +553,26 @@ def run_single_ablation( optimizer_class=optim.Adam, optimizer_name: str = "Adam", ) -> AblationResult: - """Run the full search for one (optimizer, lr) combination and return all results in memory.""" + + """Run the full hyperparameter search for one (optimizer, lr) pair. + + This is the main entry point for a single ablation cell. It + seeds RNGs, generates data, builds data loaders, trains all + random-search runs, and packages the results into an + :class:`AblationResult`. + + Args: + learning_rate: Learning rate forwarded to the optimizer. + optimizer_class: PyTorch optimizer class to use (e.g. + ``torch.optim.Adam``, ``torch.optim.SGD``). + optimizer_name: Human-readable name stored in the result + object and used in plot labels. + + Returns: + An :class:`AblationResult` containing the results DataFrame, + weight distributions, best predictions, and best model state. + """ + set_seed(SEED) device = get_device() logging.getLogger("pyhealth.trainer").setLevel(logging.WARNING) @@ -421,7 +623,18 @@ def run_single_ablation( def run_all_ablations() -> list[AblationResult]: - """Run learning-rate ablations with Adam (original behaviour).""" + """Run the learning-rate sweep ablation using the Adam optimizer. + + Iterates over every learning rate in :data:`ABLATION_LRS`, runs + the full hyperparameter search for each, and prints a summary + table. + + Returns: + A list of :class:`AblationResult` objects, one per learning + rate, in the same order as :data:`ABLATION_LRS`. 
+ """ + + ablation_results = [ run_single_ablation(lr, optim.Adam, "Adam") for lr in ABLATION_LRS ] @@ -436,17 +649,41 @@ def run_all_ablations() -> list[AblationResult]: ABLATION_OPTIMIZER_LR = 0.001 # fixed LR used for the optimizer comparison def ablations_optimizing_adam() -> AblationResult: - """Ablation: Adam optimizer at lr=0.001.""" + + """Run the Adam ablation at the fixed comparison learning rate. + + Returns: + An :class:`AblationResult` for Adam at + lr = :data:`ABLATION_OPTIMIZER_LR`. + """ + + return run_single_ablation(ABLATION_OPTIMIZER_LR, optim.Adam, "Adam") def ablations_optimizing_sgd() -> AblationResult: - """Ablation: SGD optimizer at lr=0.001.""" + + """Run the SGD ablation at the fixed comparison learning rate. + + Returns: + An :class:`AblationResult` for SGD at + lr = :data:`ABLATION_OPTIMIZER_LR`. + """ + return run_single_ablation(ABLATION_OPTIMIZER_LR, optim.SGD, "SGD") def run_optimizer_ablations() -> list[AblationResult]: - """Run Adam vs SGD at a fixed learning rate and print a comparison.""" + """Compare Adam and SGD at a fixed learning rate. + + Both optimizers are trained with + lr = :data:`ABLATION_OPTIMIZER_LR` and the results are printed + side by side. + + Returns: + A two-element list ``[adam_result, sgd_result]``. + """ + results = [ ablations_optimizing_adam(), ablations_optimizing_sgd(), @@ -456,7 +693,17 @@ def run_optimizer_ablations() -> list[AblationResult]: def _print_summary(title: str, ablation_results: list[AblationResult]): - """Pretty-print a summary table for a list of ablation results.""" + + """Pretty-print a summary table for a list of ablation results. + + For each :class:`AblationResult` the row with the lowest test + loss is selected and its key metrics are displayed. + + Args: + title: Header string printed above the table. + ablation_results: Results to summarize. 
+ """ + summary_rows = [] for result in ablation_results: best_row = result.results_df.sort_values(by="Test Loss").iloc[0] @@ -482,7 +729,20 @@ def _print_summary(title: str, ablation_results: list[AblationResult]): # ────────────────────────────────────────────────────────────── def visualize_hyperparameter_search(ablation_results: list[AblationResult], prefix: str = ""): - """Plot MSE loss vs. hidden size for every learning rate.""" + + """Plot MSE loss vs. hidden size for every learning rate. + + Individual run results are shown as translucent scatter points + and per-hidden-size means are overlaid as solid (validation) and + dashed (test) lines. + + Args: + ablation_results: One :class:`AblationResult` per learning + rate / optimizer configuration. + prefix: String prepended to the output filename (e.g. + ``"lr_sweep_"``). + """ + print("--- 1. Analyzing Hyperparameter Search (Ablation) ---") plt.figure(figsize=(12, 7)) @@ -532,7 +792,22 @@ def visualize_hyperparameter_search(ablation_results: list[AblationResult], pref def visualize_predictions(ablation_results: list[AblationResult], num_samples: int = 3, prefix: str = ""): - """Plot predicted vs. true values for a few test samples per learning rate.""" + + """Plot predicted vs. true values for sample test sequences. + + One row of subplots is created per ablation result, with + *num_samples* columns each showing a randomly chosen test + sequence. + + Args: + ablation_results: Ablation results whose + ``best_predictions`` field will be visualized. Entries + with ``best_predictions is None`` are silently skipped. + num_samples: Number of randomly selected test sequences to + plot per ablation. + prefix: String prepended to the output filename. + """ + print("\n--- 2. 
Analyzing Predictions (Ablation) ---") # Only include results that have saved predictions @@ -580,7 +855,19 @@ def visualize_predictions(ablation_results: list[AblationResult], num_samples: i def visualize_synthetic_shift(ablation_results: list[AblationResult], prefix: str = ""): - """Plot the k_dist heatmap for each learning rate's task distributions.""" + """Plot the ``k_dist`` heatmap for each ablation's task distributions. + + Each subplot shows how the temporal weight distribution evolves + across the *T* time steps (x-axis) over the *l* lookback + positions (y-axis). + + Args: + ablation_results: Ablation results whose ``k_dist`` fields + will be visualized. + prefix: String prepended to the output filename. + """ + + print("\n--- 3. Analyzing Synthetic Data Shift (Ablation) ---") fig, axes = plt.subplots(1, len(ablation_results), figsize=(6 * len(ablation_results), 5)) @@ -602,7 +889,18 @@ def visualize_synthetic_shift(ablation_results: list[AblationResult], prefix: st def run_all_visualizations(ablation_results: list[AblationResult], prefix: str = ""): - """Generate all three ablation plots from in-memory results.""" + """Generate all three ablation plot types from in-memory results. + + Delegates to :func:`visualize_hyperparameter_search`, + :func:`visualize_synthetic_shift`, and + :func:`visualize_predictions`. + + Args: + ablation_results: The list of :class:`AblationResult` objects + to visualize. + prefix: Filename prefix forwarded to each plotting function. 
+ """ + visualize_hyperparameter_search(ablation_results, prefix=prefix) visualize_synthetic_shift(ablation_results, prefix=prefix) visualize_predictions(ablation_results, prefix=prefix) From da77126fb2a66f92c0fc5bf2ea4bc1c35820cd95 Mon Sep 17 00:00:00 2001 From: Tanmay Mittal Date: Wed, 22 Apr 2026 17:01:01 -0500 Subject: [PATCH 11/19] more comments --- examples/mimic3_synthetic_mixlstm.py | 50 +++++++++++++++++++++++----- 1 file changed, 41 insertions(+), 9 deletions(-) diff --git a/examples/mimic3_synthetic_mixlstm.py b/examples/mimic3_synthetic_mixlstm.py index 91bdebda0..99f3b0c06 100644 --- a/examples/mimic3_synthetic_mixlstm.py +++ b/examples/mimic3_synthetic_mixlstm.py @@ -102,18 +102,41 @@ @dataclass class AblationResult: - """Everything produced by a single ablation run.""" + """Container for every artefact produced by a single ablation run. + + Attributes: + learning_rate: The learning rate used for this ablation. + optimizer_name: Human-readable optimizer name (e.g. ``"Adam"``). + results_df: DataFrame with one row per random-search run. + Columns include ``Run``, ``k (experts)``, ``Hidden Size``, + ``Val Loss``, ``Test Loss``, ``num_params``, and ``epoch``. + k_dist: List of *T* numpy arrays representing the temporal + weight distribution at each time step. + d_dist: List of *T* numpy arrays representing the feature + weight distribution at each time step. + best_predictions: Dictionary with keys ``"pred"``, + ``"y_true"``, ``"k"``, ``"hidden_size"``, and ``"run"`` + for the model that achieved the lowest validation loss. + ``None`` if no valid model was produced. + best_model_state: ``state_dict`` (on CPU) of the best model. + ``None`` if no valid model was produced. 
+ """ learning_rate: float optimizer_name: str results_df: pd.DataFrame k_dist: list[np.ndarray] d_dist: list[np.ndarray] - best_predictions: dict | None = None # {"pred": ..., "y_true": ..., "k": ..., "hidden_size": ...} + best_predictions: dict | None = None best_model_state: dict | None = None @property def label(self) -> str: - """Human-readable label for plots and logs.""" + """Return a human-readable label for plots and logs. + + Returns: + A string of the form ``" lr="``. + """ + return f"{self.optimizer_name} lr={self.learning_rate}" @@ -131,8 +154,8 @@ def label(self) -> str: BATCH_SIZE = 100 K_LIST = [2] HIDDEN_SIZE_LIST = [100, 150, 300, 500, 700, 900, 1100] -NUM_RUNS = 20 # 20 -MAX_EPOCHS = 30 # 30 +NUM_RUNS = 1 # 20 +MAX_EPOCHS = 2 # 30 SAVE_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -368,9 +391,15 @@ def build_dataloaders( because ``MixLSTM.__init__`` requires it to infer schema metadata. """ - x_train, y_train = generate_xy(num_samples, T, input_dim, prev_used_timestamps, k_dist, d_dist) - x_val, y_val = generate_xy(num_samples, T, input_dim, prev_used_timestamps, k_dist, d_dist) - x_test, y_test = generate_xy(num_samples, T, input_dim, prev_used_timestamps, k_dist, d_dist) + x_train, y_train = generate_xy( + num_samples, T, input_dim, prev_used_timestamps, k_dist, d_dist + ) + x_val, y_val = generate_xy( + num_samples, T, input_dim, prev_used_timestamps, k_dist, d_dist + ) + x_test, y_test = generate_xy( + num_samples, T, input_dim, prev_used_timestamps, k_dist, d_dist + ) train_data = make_dataset(x_train, y_train, "train") val_data = make_dataset(x_val, y_val, "val") @@ -840,7 +869,10 @@ def visualize_predictions(ablation_results: list[AblationResult], num_samples: i for col, sample_idx in enumerate(sample_indices): ax = axes[row][col] ax.plot(y_true[sample_idx], label="True", color="blue", marker="o", markersize=4) - ax.plot(pred[sample_idx], label="Predicted", color="red", linestyle="--", marker="x", markersize=4) + ax.plot( + 
pred[sample_idx], label="Predicted", + color="red", linestyle="--", marker="x", markersize=4 + ) ax.set_title(f"{tag} | Sample #{sample_idx}") ax.set_xlabel("Time Steps") ax.set_ylabel("Value") From 7d6e082f67363b549306fb334087a8b4b3ec9342 Mon Sep 17 00:00:00 2001 From: siddesh2-sys Date: Wed, 22 Apr 2026 18:26:19 -0400 Subject: [PATCH 12/19] Added type hints --- pyhealth/models/mixlstm.py | 67 +++++++++++++++++++------------------- 1 file changed, 34 insertions(+), 33 deletions(-) diff --git a/pyhealth/models/mixlstm.py b/pyhealth/models/mixlstm.py index 38acd7510..0fc2479a9 100644 --- a/pyhealth/models/mixlstm.py +++ b/pyhealth/models/mixlstm.py @@ -14,6 +14,7 @@ import math from abc import ABC from collections import abc +from typing import Any, Dict, List, Optional, Tuple, Type, Union import numpy as np import torch @@ -34,7 +35,7 @@ class MLP(nn.Module): bias: Whether linear layers include a bias term. """ - def __init__(self, neuron_sizes, activation=nn.LeakyReLU, bias=True): + def __init__(self, neuron_sizes: List[int], activation: Type[nn.Module] = nn.LeakyReLU, bias: bool = True) -> None: super(MLP, self).__init__() self.neuron_sizes = neuron_sizes @@ -47,12 +48,12 @@ def __init__(self, neuron_sizes, activation=nn.LeakyReLU, bias=True): self.classifier = nn.Sequential(*layers[:-1]) - def eval_forward(self, x, y): + def eval_forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """Run a forward pass in eval mode (ignores ``y``).""" self.eval() return self.forward(x) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: """Flatten the input and pass it through the MLP.""" x = x.contiguous() x = x.view(-1, self.neuron_sizes[0]) @@ -71,7 +72,7 @@ class MoE(nn.Module): gate: The gating function that produces mixing coefficients. 
""" - def __init__(self, experts, gate): + def __init__(self, experts: nn.Module, gate: "Gate") -> None: super(MoE, self).__init__() self.experts = experts self.gate = gate @@ -90,13 +91,13 @@ class MoO(MoE): expert_dim: Expert dimension after stacking (default: 0). """ - def __init__(self, experts, gate, bs_dim=1, expert_dim=0): + def __init__(self, experts: nn.ModuleList, gate: "Gate", bs_dim: int = 1, expert_dim: int = 0) -> None: super(MoO, self).__init__(experts, gate) # this is for RNN architecture: bs_dim = 2 for RNN self.bs_dim = bs_dim self.expert_dim = expert_dim - def combine(self, o, coef): + def combine(self, o: List[torch.Tensor], coef: torch.Tensor) -> Union[torch.Tensor, List[torch.Tensor]]: """Combine expert outputs using the mixing coefficients.""" if isinstance(o[0], abc.Sequence): # account for multi_output setting return [self.combine(o_, coef) for o_ in zip(*o)] @@ -112,7 +113,7 @@ def combine(self, o, coef): res = res.transpose(self.bs_dim, -2) return res.sum(0) - def forward(self, x, coef=None): + def forward(self, x: torch.Tensor, coef: Optional[torch.Tensor] = None) -> Union[torch.Tensor, List[torch.Tensor]]: """Compute each expert's output and combine them.""" coef = self.gate(x, coef) # (bs, n_expert) or n_expert self.last_coef = coef @@ -128,7 +129,7 @@ class MoW(MoE): assembled expert per time step. 
""" - def forward(self, x, coef=None): + def forward(self, x: Any, coef: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """Run the assembled expert on the input.""" # assume experts has already been assembled coef = self.gate(x, coef) @@ -140,7 +141,7 @@ def forward(self, x, coef=None): class Gate(ABC, nn.Module): """Abstract base class for gating functions.""" - def forward(self, x, coef=None): + def forward(self, x: Any, coef: Optional[torch.Tensor] = None) -> torch.Tensor: raise NotImplementedError() @@ -153,12 +154,12 @@ class AdaptiveLSTMGate(Gate): normalize: If True, apply softmax to the coefficients. """ - def __init__(self, input_size, num_experts, normalize=False): + def __init__(self, input_size: int, num_experts: int, normalize: bool = False) -> None: super(self.__class__, self).__init__() self.forward_function = MLP([input_size, num_experts]) self.normalize = normalize - def forward(self, x, coef=None): + def forward(self, x: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], coef: Optional[torch.Tensor] = None) -> torch.Tensor: """Produce mixing coefficients from the hidden state.""" x, (h, c) = x # h (_, bs, d) o = self.forward_function(h.transpose(0, 1)) # (bs, num_experts) @@ -178,7 +179,7 @@ class NonAdaptiveGate(Gate): normalize: If True, apply softmax to the coefficients. 
""" - def __init__(self, num_experts, coef=None, fixed=False, normalize=False): + def __init__(self, num_experts: int, coef: Optional[torch.Tensor] = None, fixed: bool = False, normalize: bool = False) -> None: super(self.__class__, self).__init__() self.normalize = normalize if coef is None: # initialization @@ -191,7 +192,7 @@ def __init__(self, num_experts, coef=None, fixed=False, normalize=False): self.coefficients = coef - def forward(self, x, coef=None): + def forward(self, x: Any, coef: Optional[torch.Tensor] = None) -> torch.Tensor: """Return the (optionally normalized) mixing coefficients.""" if self.normalize: return nn.functional.softmax(self.coefficients, 0) @@ -202,13 +203,13 @@ def forward(self, x, coef=None): class IDGate(Gate): """Identity gate that passes through a previous coefficient unchanged.""" - def forward(self, x, coef): + def forward(self, x: Any, coef: torch.Tensor) -> torch.Tensor: """Return the coefficient that was passed in.""" return coef ################ time series example models ################ -def moo_linear(in_features, out_features, num_experts, bs_dim=1, expert_dim=0): +def moo_linear(in_features: int, out_features: int, num_experts: int, bs_dim: int = 1, expert_dim: int = 0) -> MoO: """Create a MoO over a set of linear layers with tied shape. Args: @@ -244,8 +245,8 @@ class mowLSTM_(nn.Module): batch_first: If True, expects input shape (batch, seq_len, dim). 
""" - def __init__(self, input_size, hidden_size, num_experts=2, - batch_first=False): + def __init__(self, input_size: int, hidden_size: int, num_experts: int = 2, + batch_first: bool = False) -> None: super(mowLSTM_, self).__init__() @@ -272,7 +273,7 @@ def __init__(self, input_size, hidden_size, num_experts=2, # maybe: orthogonal initialization: see jeeheh's code # note: pytorch implementation does neither - def rnn_step(self, x, hidden, coef): + def rnn_step(self, x: torch.Tensor, hidden: Tuple[torch.Tensor, torch.Tensor], coef: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """Run a single LSTM step with mixed expert parameters.""" bs = x.shape[1] h, c = hidden @@ -289,7 +290,7 @@ def rnn_step(self, x, hidden, coef): h = outgate * torch.tanh(c) # maybe use layer norm here as well return h, c - def forward(self, x, hidden, coef): + def forward(self, x: torch.Tensor, hidden: Tuple[torch.Tensor, torch.Tensor], coef: torch.Tensor) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """Run the mixture LSTM over a full sequence.""" if self.batch_first: # change to seq_len first x = x.transpose(0, 1) @@ -321,9 +322,9 @@ class mowLSTM(nn.Module): activation: Optional activation applied to the final output. 
""" - def __init__(self, input_size, hidden_size, num_classes, num_experts=2, - num_layers=1, batch_first=False, dropout=0, - bidirectional=False, activation=None): + def __init__(self, input_size: int, hidden_size: int, num_classes: int, num_experts: int = 2, + num_layers: int = 1, batch_first: bool = False, dropout: float = 0, + bidirectional: bool = False, activation: Optional[nn.Module] = None) -> None: super(mowLSTM, self).__init__() @@ -351,7 +352,7 @@ def __init__(self, input_size, hidden_size, num_classes, num_experts=2, batch_first)) self.dropouts.append(nn.Dropout(p=dropout)) - def forward(self, x, coef): + def forward(self, x: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], coef: torch.Tensor) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """Forward pass through the stacked mixture LSTM.""" x, hidden = x self.last_coef = coef @@ -394,8 +395,8 @@ class ExampleMowLSTM(nn.Module): activation: Optional output activation. """ - def __init__(self, input_size, hidden_size, num_classes, - num_layers=1, num_directions=1, dropout=0, activation=None): + def __init__(self, input_size: int, hidden_size: int, num_classes: int, + num_layers: int = 1, num_directions: int = 1, dropout: float = 0, activation: Optional[nn.Module] = None) -> None: super(ExampleMowLSTM, self).__init__() self.input_size = input_size self.hidden_size = hidden_size @@ -405,7 +406,7 @@ def __init__(self, input_size, hidden_size, num_classes, self.dropout = dropout self.activation = activation - def setKT(self, k, t): + def setKT(self, k: int, t: int) -> None: """Configure the model for ``k`` experts and ``t`` time steps. 
Args: @@ -432,7 +433,7 @@ def setKT(self, k, t): # normalize=True) self.cells.append(MoW(experts, gate)) - def forward(self, x, hidden): + def forward(self, x: torch.Tensor, hidden: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """Run the mixture LSTM step-by-step using the per-step gates.""" seq_len, bs, _ = x.shape o = [] @@ -444,7 +445,7 @@ def forward(self, x, hidden): return o, hidden -def orthogonal(shape): +def orthogonal(shape: Tuple[int, ...]) -> np.ndarray: """Generate an orthogonal matrix of the given shape via SVD.""" flat_shape = (int(shape[0]), int(np.prod(shape[1:]))) a = np.random.normal(0.0, 1.0, flat_shape) @@ -453,7 +454,7 @@ def orthogonal(shape): return q.reshape(shape) -def lstm_ortho_initializer(shape, scale=1.0): +def lstm_ortho_initializer(shape: Tuple[int, ...], scale: float = 1.0) -> np.ndarray: """Initialize LSTM weights with orthogonal blocks for each of the 4 gates. Args: @@ -550,8 +551,8 @@ class MixLSTM(BaseModel): >>> model = MixLSTM(dataset=dataset, num_experts=2, hidden_size=64) """ - def __init__(self, dataset: SampleDataset, num_experts=2, hidden_size=100, - prev_used_timestamps=0): + def __init__(self, dataset: SampleDataset, num_experts: int = 2, hidden_size: int = 100, + prev_used_timestamps: int = 0) -> None: super(MixLSTM, self).__init__(dataset) # Identify primary input key and infer shape @@ -600,7 +601,7 @@ def __init__(self, dataset: SampleDataset, num_experts=2, hidden_size=100, self.hidden_size = hidden_size self.model.setKT(num_experts, self.time_steps) - def forward(self, **kwargs): + def forward(self, **kwargs: Any) -> Dict[str, torch.Tensor]: """Run a forward pass. 
Expects the input tensor under ``kwargs[self.input_key]`` with @@ -673,6 +674,6 @@ def forward(self, **kwargs): return results - def after_backward(self): + def after_backward(self) -> None: """Hook called after backward(); no-op for this model.""" pass From 271176dd9d94efd2fd0d6dcf40352270dd899b05 Mon Sep 17 00:00:00 2001 From: Tanmay Mittal Date: Wed, 22 Apr 2026 19:12:26 -0500 Subject: [PATCH 13/19] type hints --- examples/mimic3_synthetic_mixlstm.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/examples/mimic3_synthetic_mixlstm.py b/examples/mimic3_synthetic_mixlstm.py index 99f3b0c06..1f3859d13 100644 --- a/examples/mimic3_synthetic_mixlstm.py +++ b/examples/mimic3_synthetic_mixlstm.py @@ -154,8 +154,8 @@ def label(self) -> str: BATCH_SIZE = 100 K_LIST = [2] HIDDEN_SIZE_LIST = [100, 150, 300, 500, 700, 900, 1100] -NUM_RUNS = 1 # 20 -MAX_EPOCHS = 2 # 30 +NUM_RUNS = 20 # 20 +MAX_EPOCHS = 30 # 30 SAVE_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -329,7 +329,7 @@ def generate_xy( # PyHealth dataset helpers # ────────────────────────────────────────────────────────────── -def make_dataset(x: np.ndarray, y: np.ndarray, split_name: str): +def make_dataset(x: np.ndarray, y: np.ndarray, split_name: str) -> "SampleDataset": """Wrap numpy arrays into a PyHealth ``SampleDataset``. Each sequence is registered as a separate patient with a single @@ -366,7 +366,7 @@ def make_dataset(x: np.ndarray, y: np.ndarray, split_name: str): def build_dataloaders( k_dist, d_dist, num_samples, T, input_dim, prev_used_timestamps, batch_size -): +) -> tuple["SampleDataset", torch.utils.data.DataLoader, torch.utils.data.DataLoader, torch.utils.data.DataLoader]: """Generate train / val / test splits and wrap them in DataLoaders. 
@@ -416,7 +416,7 @@ def build_dataloaders( # Training & evaluation # ────────────────────────────────────────────────────────────── -def collect_predictions(model, test_loader, device): +def collect_predictions(model, test_loader, device) -> dict[str, np.ndarray]: """Run inference on *test_loader* and collect predictions. The model is set to eval mode and gradients are disabled. Only @@ -468,7 +468,7 @@ def run_hyperparameter_search( max_epochs, learning_rate, optimizer_class=optim.Adam, -): +) -> tuple[pd.DataFrame, dict | None, dict | None]: """Execute a random hyperparameter search over MixLSTM configs. @@ -721,7 +721,7 @@ def run_optimizer_ablations() -> list[AblationResult]: return results -def _print_summary(title: str, ablation_results: list[AblationResult]): +def _print_summary(title: str, ablation_results: list[AblationResult])-> None: """Pretty-print a summary table for a list of ablation results. @@ -757,7 +757,7 @@ def _print_summary(title: str, ablation_results: list[AblationResult]): # Visualization (all functions take in-memory data) # ────────────────────────────────────────────────────────────── -def visualize_hyperparameter_search(ablation_results: list[AblationResult], prefix: str = ""): +def visualize_hyperparameter_search(ablation_results: list[AblationResult], prefix: str = "") -> None: """Plot MSE loss vs. hidden size for every learning rate. @@ -820,7 +820,7 @@ def visualize_hyperparameter_search(ablation_results: list[AblationResult], pref print(f" Saved: {out_path}") -def visualize_predictions(ablation_results: list[AblationResult], num_samples: int = 3, prefix: str = ""): +def visualize_predictions(ablation_results: list[AblationResult], num_samples: int = 3, prefix: str = "") -> None: """Plot predicted vs. true values for sample test sequences. 
@@ -886,7 +886,7 @@ def visualize_predictions(ablation_results: list[AblationResult], num_samples: i print(f" Saved: {out_path}") -def visualize_synthetic_shift(ablation_results: list[AblationResult], prefix: str = ""): +def visualize_synthetic_shift(ablation_results: list[AblationResult], prefix: str = "") -> None: """Plot the ``k_dist`` heatmap for each ablation's task distributions. Each subplot shows how the temporal weight distribution evolves @@ -920,7 +920,7 @@ def visualize_synthetic_shift(ablation_results: list[AblationResult], prefix: st print(f" Saved: {out_path}") -def run_all_visualizations(ablation_results: list[AblationResult], prefix: str = ""): +def run_all_visualizations(ablation_results: list[AblationResult], prefix: str = "") -> None: """Generate all three ablation plot types from in-memory results. Delegates to :func:`visualize_hyperparameter_search`, @@ -942,7 +942,7 @@ def run_all_visualizations(ablation_results: list[AblationResult], prefix: str = # Main # ────────────────────────────────────────────────────────────── -def main(): +def main() -> None: # Learning-rate sweep (Adam only, multiple LRs) lr_results = run_all_ablations() run_all_visualizations(lr_results, prefix="lr_sweep_") From 00769daf118f3a65dbd2fb5263f4f2782d30ba1e Mon Sep 17 00:00:00 2001 From: Tanmay Mittal Date: Wed, 22 Apr 2026 19:39:12 -0500 Subject: [PATCH 14/19] file path fixed --- examples/mimic3_synthetic_mixlstm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/mimic3_synthetic_mixlstm.py b/examples/mimic3_synthetic_mixlstm.py index 1f3859d13..24d9909f0 100644 --- a/examples/mimic3_synthetic_mixlstm.py +++ b/examples/mimic3_synthetic_mixlstm.py @@ -154,10 +154,10 @@ def label(self) -> str: BATCH_SIZE = 100 K_LIST = [2] HIDDEN_SIZE_LIST = [100, 150, 300, 500, 700, 900, 1100] -NUM_RUNS = 20 # 20 -MAX_EPOCHS = 30 # 30 +NUM_RUNS = 20 # default set to 20 +MAX_EPOCHS = 30 # default set to 30 -SAVE_DIR = 
os.path.dirname(os.path.abspath(__file__)) +SAVE_DIR = "." # Visualization MAX_MSE = 100 From dbb2160465bf27694e0c159c9c94d716aafa5330 Mon Sep 17 00:00:00 2001 From: Tanmay Mittal Date: Wed, 22 Apr 2026 19:40:14 -0500 Subject: [PATCH 15/19] do a pip install seaborn=0.13.2 --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 934d4f1bb..f0da85254 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dependencies = [ "more-itertools~=10.8.0", "einops>=0.8.0", "linear-attention-transformer>=0.19.1", + "seaborn~=0.13.2", ] license = "BSD-3-Clause" license-files = ["LICENSE.md"] From 39efecc62c60e07dc88813b0e7c2fc6a4437fb36 Mon Sep 17 00:00:00 2001 From: siddesh2-sys Date: Wed, 22 Apr 2026 20:48:46 -0400 Subject: [PATCH 16/19] Cleaned up comments --- pyhealth/models/mixlstm.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/pyhealth/models/mixlstm.py b/pyhealth/models/mixlstm.py index 0fc2479a9..5289e97ff 100644 --- a/pyhealth/models/mixlstm.py +++ b/pyhealth/models/mixlstm.py @@ -269,16 +269,12 @@ def __init__(self, input_size: int, hidden_size: int, num_experts: int = 2, for name, weight in m.named_parameters(): nn.init.uniform_(weight, -stdv, stdv) - # maybe: layer normalization: see jeeheh's code - # maybe: orthogonal initialization: see jeeheh's code - # note: pytorch implementation does neither def rnn_step(self, x: torch.Tensor, hidden: Tuple[torch.Tensor, torch.Tensor], coef: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """Run a single LSTM step with mixed expert parameters.""" bs = x.shape[1] h, c = hidden gates = self.input_weights(x, coef) + self.hidden_weights(h, coef) - # maybe: layer normalization: see jeeheh's code ingate, forgetgate, cellgate, outgate = gates.view(bs, -1).chunk(4, 1) ingate = torch.sigmoid(ingate) @@ -426,11 +422,6 @@ def setKT(self, k: int, t: int) -> None: for _ in range(t): gate = NonAdaptiveGate(self.k, normalize=True) - # gate = 
AdaptiveLSTMGate(self.hidden_size *\ - # self.num_layers *\ - # self.num_directions, - # self.k, - # normalize=True) self.cells.append(MoW(experts, gate)) def forward(self, x: torch.Tensor, hidden: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: @@ -647,7 +638,7 @@ def forward(self, **kwargs: Any) -> Dict[str, torch.Tensor]: logits_seq = outputs.permute(1, 0, 2) if self._per_timestep: - # --- Per-timestep regression (original MLHC2019 synthetic task) + # --- Per-timestep regression mode --- results = {"logit": logits_seq, "y_prob": logits_seq} if self.label_key and self.label_key in kwargs: y_true = kwargs[self.label_key].to(device) From d7acb61d653950106822114e29926bdcc745f600 Mon Sep 17 00:00:00 2001 From: Tanmay Mittal Date: Wed, 22 Apr 2026 20:11:41 -0500 Subject: [PATCH 17/19] added type hints --- tests/test_mixlstm.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/tests/test_mixlstm.py b/tests/test_mixlstm.py index 9adf68e2a..5b9b083ab 100644 --- a/tests/test_mixlstm.py +++ b/tests/test_mixlstm.py @@ -11,7 +11,7 @@ class TestMixLSTMRegression(unittest.TestCase): """Test MixLSTM in per-timestep regression mode (MLHC2019 synthetic task).""" - def setUp(self): + def setUp(self) -> None: """Set up small synthetic regression dataset and model.""" self.tmp_dir = tempfile.mkdtemp() @@ -57,7 +57,7 @@ def setUp(self): def tearDown(self): shutil.rmtree(self.tmp_dir, ignore_errors=True) - def test_instantiation(self): + def test_instantiation(self) -> None: """Test that model initializes with correct attributes.""" self.assertIsInstance(self.model, MixLSTM) self.assertTrue(self.model._per_timestep) @@ -66,7 +66,7 @@ def test_instantiation(self): self.assertEqual(self.model.hidden_size, 16) self.assertEqual(self.model.prev_used_timestamps, 3) - def test_forward_output_keys(self): + def test_forward_output_keys(self) -> None: """Test that forward returns expected keys for 
regression.""" loader = get_dataloader(self.dataset, batch_size=self.batch_size, shuffle=False) batch = next(iter(loader)) @@ -79,7 +79,7 @@ def test_forward_output_keys(self): self.assertIn("y_prob", ret) self.assertIn("y_true", ret) - def test_forward_output_shapes(self): + def test_forward_output_shapes(self) -> None: """Test output tensor shapes for per-timestep regression.""" loader = get_dataloader(self.dataset, batch_size=self.batch_size, shuffle=False) batch = next(iter(loader)) @@ -97,7 +97,7 @@ def test_forward_output_shapes(self): # loss is scalar self.assertEqual(ret["loss"].dim(), 0) - def test_forward_no_labels(self): + def test_forward_no_labels(self) -> None: """Test forward without labels returns logit/y_prob but no loss.""" loader = get_dataloader(self.dataset, batch_size=self.batch_size, shuffle=False) batch = next(iter(loader)) @@ -112,7 +112,7 @@ def test_forward_no_labels(self): self.assertNotIn("loss", ret) self.assertNotIn("y_true", ret) - def test_backward_gradients(self): + def test_backward_gradients(self) -> None: """Test that loss.backward() produces gradients on all trainable params.""" loader = get_dataloader(self.dataset, batch_size=self.batch_size, shuffle=False) batch = next(iter(loader)) @@ -126,7 +126,7 @@ def test_backward_gradients(self): ) self.assertTrue(has_gradient, "No parameters received gradients") - def test_loss_is_finite(self): + def test_loss_is_finite(self) -> None: """Test that loss is finite (not NaN or Inf).""" loader = get_dataloader(self.dataset, batch_size=self.batch_size, shuffle=False) batch = next(iter(loader)) @@ -136,7 +136,7 @@ def test_loss_is_finite(self): self.assertTrue(torch.isfinite(ret["loss"]).item(), "Loss is not finite") - def test_custom_hyperparameters(self): + def test_custom_hyperparameters(self) -> None: """Test model with different num_experts and hidden_size.""" model = MixLSTM( dataset=self.dataset, @@ -157,7 +157,7 @@ def test_custom_hyperparameters(self): class 
TestMixLSTMClassification(unittest.TestCase): """Test MixLSTM in classification mode (standard PyHealth label task).""" - def setUp(self): + def setUp(self) -> None: """Set up small synthetic classification dataset and model.""" self.tmp_dir = tempfile.mkdtemp() @@ -194,13 +194,13 @@ def setUp(self): def tearDown(self): shutil.rmtree(self.tmp_dir, ignore_errors=True) - def test_instantiation(self): + def test_instantiation(self) -> None: """Test that classification model initializes correctly.""" self.assertIsInstance(self.model, MixLSTM) self.assertFalse(self.model._per_timestep) self.assertEqual(self.model.mode, "multiclass") - def test_forward_output_keys(self): + def test_forward_output_keys(self) -> None: """Test that forward returns expected keys for classification.""" loader = get_dataloader(self.dataset, batch_size=self.batch_size, shuffle=False) batch = next(iter(loader)) @@ -213,7 +213,7 @@ def test_forward_output_keys(self): self.assertIn("y_prob", ret) self.assertIn("y_true", ret) - def test_forward_output_shapes(self): + def test_forward_output_shapes(self) -> None: """Test output tensor shapes for classification (last timestep).""" loader = get_dataloader(self.dataset, batch_size=self.batch_size, shuffle=False) batch = next(iter(loader)) @@ -232,7 +232,7 @@ def test_forward_output_shapes(self): # loss is scalar self.assertEqual(ret["loss"].dim(), 0) - def test_forward_no_labels(self): + def test_forward_no_labels(self) -> None: """Test forward without labels returns logit/y_prob but no loss.""" loader = get_dataloader(self.dataset, batch_size=self.batch_size, shuffle=False) batch = next(iter(loader)) @@ -246,7 +246,7 @@ def test_forward_no_labels(self): self.assertNotIn("loss", ret) self.assertNotIn("y_true", ret) - def test_backward_gradients(self): + def test_backward_gradients(self) -> None: """Test that loss.backward() produces gradients.""" loader = get_dataloader(self.dataset, batch_size=self.batch_size, shuffle=False) batch = next(iter(loader)) 
@@ -260,7 +260,7 @@ def test_backward_gradients(self): ) self.assertTrue(has_gradient, "No parameters received gradients") - def test_y_prob_sums_to_one(self): + def test_y_prob_sums_to_one(self) -> None: """Test that y_prob (softmax) sums to ~1 for each sample.""" loader = get_dataloader(self.dataset, batch_size=self.batch_size, shuffle=False) batch = next(iter(loader)) From 8562c1f1bb006f1bae44aef13847bdadede97b45 Mon Sep 17 00:00:00 2001 From: Tanmay Mittal Date: Wed, 22 Apr 2026 20:18:42 -0500 Subject: [PATCH 18/19] type hints 2 --- tests/test_mixlstm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_mixlstm.py b/tests/test_mixlstm.py index 5b9b083ab..3d9e885d0 100644 --- a/tests/test_mixlstm.py +++ b/tests/test_mixlstm.py @@ -54,7 +54,7 @@ def setUp(self) -> None: ) self.batch_size = 4 - def tearDown(self): + def tearDown(self) -> None: shutil.rmtree(self.tmp_dir, ignore_errors=True) def test_instantiation(self) -> None: @@ -191,7 +191,7 @@ def setUp(self) -> None: ) self.batch_size = 4 - def tearDown(self): + def tearDown(self) -> None: shutil.rmtree(self.tmp_dir, ignore_errors=True) def test_instantiation(self) -> None: From 6e83f5d270f48319e32f29841f57be4a87c98ef0 Mon Sep 17 00:00:00 2001 From: Tanmay Mittal Date: Wed, 22 Apr 2026 20:48:08 -0500 Subject: [PATCH 19/19] comments --- examples/mimic3_synthetic_mixlstm.py | 4 + pyhealth/models/mixlstm.py | 235 +++++++++++++++++++++++---- 2 files changed, 208 insertions(+), 31 deletions(-) diff --git a/examples/mimic3_synthetic_mixlstm.py b/examples/mimic3_synthetic_mixlstm.py index 24d9909f0..5900f2ad7 100644 --- a/examples/mimic3_synthetic_mixlstm.py +++ b/examples/mimic3_synthetic_mixlstm.py @@ -30,6 +30,10 @@ # Synthetic time-series regression task with PyHealth # ====================================================================== # +# REQUIREMENTS +# Please do pip install seaborn~=0.13.2 to make sure the graphs are displayed and the ablation study +# runs smoothly.
Added pip install seaborn~=0.13.2 in pyproject.toml file +# # EXPERIMENTAL SETUP # ------------------ # Dataset: Synthetic non-stationary time-series regression. 1,000 sequences diff --git a/pyhealth/models/mixlstm.py b/pyhealth/models/mixlstm.py index 5289e97ff..f63b1a001 100644 --- a/pyhealth/models/mixlstm.py +++ b/pyhealth/models/mixlstm.py @@ -35,7 +35,10 @@ class MLP(nn.Module): bias: Whether linear layers include a bias term. """ - def __init__(self, neuron_sizes: List[int], activation: Type[nn.Module] = nn.LeakyReLU, bias: bool = True) -> None: + def __init__( + self, neuron_sizes: List[int], + activation: Type[nn.Module] = nn.LeakyReLU, bias: bool = True + ) -> None: super(MLP, self).__init__() self.neuron_sizes = neuron_sizes @@ -91,14 +94,30 @@ class MoO(MoE): expert_dim: Expert dimension after stacking (default: 0). """ - def __init__(self, experts: nn.ModuleList, gate: "Gate", bs_dim: int = 1, expert_dim: int = 0) -> None: + def __init__(self, experts: nn.ModuleList, gate: "Gate", + bs_dim: int = 1, expert_dim: int = 0 + ) -> None: super(MoO, self).__init__(experts, gate) # this is for RNN architecture: bs_dim = 2 for RNN self.bs_dim = bs_dim self.expert_dim = expert_dim - def combine(self, o: List[torch.Tensor], coef: torch.Tensor) -> Union[torch.Tensor, List[torch.Tensor]]: - """Combine expert outputs using the mixing coefficients.""" + def combine( + self, o: List[torch.Tensor], coef: torch.Tensor + ) -> Union[torch.Tensor, List[torch.Tensor]]: + + """Combine expert outputs using the mixing coefficients. + + Args: + o: List of expert output tensors. + coef: Mixing coefficient tensor of shape + ``(batch, num_experts)``. + + Returns: + Weighted sum of expert outputs, or a list of such + sums if experts return multi-output tuples. 
+ """ + if isinstance(o[0], abc.Sequence): # account for multi_output setting return [self.combine(o_, coef) for o_ in zip(*o)] else: @@ -113,8 +132,20 @@ def combine(self, o: List[torch.Tensor], coef: torch.Tensor) -> Union[torch.Tens res = res.transpose(self.bs_dim, -2) return res.sum(0) - def forward(self, x: torch.Tensor, coef: Optional[torch.Tensor] = None) -> Union[torch.Tensor, List[torch.Tensor]]: - """Compute each expert's output and combine them.""" + def forward( + self, x: torch.Tensor, coef: Optional[torch.Tensor] = None + ) -> Union[torch.Tensor, List[torch.Tensor]]: + + """Compute each expert's output and combine them. + + Args: + x: Input tensor. + coef: Optional pre-computed mixing coefficients. + + Returns: + Combined expert output tensor. + """ + coef = self.gate(x, coef) # (bs, n_expert) or n_expert self.last_coef = coef o = [expert(x) for expert in self.experts] @@ -129,8 +160,20 @@ class MoW(MoE): assembled expert per time step. """ - def forward(self, x: Any, coef: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """Run the assembled expert on the input.""" + def forward( + self, x: Any, coef: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + + """Run the assembled expert on the input. + + Args: + x: Input tensor (or tuple of tensor and hidden state). + coef: Optional pre-computed mixing coefficients. + + Returns: + Tuple of output tensor and new hidden state. + """ + # assume experts has already been assembled coef = self.gate(x, coef) self.last_coef = coef @@ -141,7 +184,21 @@ def forward(self, x: Any, coef: Optional[torch.Tensor] = None) -> Tuple[torch.Te class Gate(ABC, nn.Module): """Abstract base class for gating functions.""" - def forward(self, x: Any, coef: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, x: Any, coef: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Produce mixing coefficients from the input. 
+ + Args: + x: Input data (format depends on the subclass). + coef: Optional pre-computed coefficient tensor. + + Returns: + Mixing coefficient tensor. + + Raises: + NotImplementedError: Always; subclasses must override. + """ raise NotImplementedError() @@ -154,13 +211,28 @@ class AdaptiveLSTMGate(Gate): normalize: If True, apply softmax to the coefficients. """ - def __init__(self, input_size: int, num_experts: int, normalize: bool = False) -> None: + def __init__( + self, input_size: int, num_experts: int, normalize: bool = False + ) -> None: super(self.__class__, self).__init__() self.forward_function = MLP([input_size, num_experts]) self.normalize = normalize - def forward(self, x: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], coef: Optional[torch.Tensor] = None) -> torch.Tensor: - """Produce mixing coefficients from the hidden state.""" + def forward( + self, x: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + coef: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Produce mixing coefficients from the hidden state. + + Args: + x: Tuple of ``(input, (h, c))`` where ``h`` is the + hidden state used to compute coefficients. + coef: Ignored; kept for interface compatibility. + + Returns: + Mixing coefficients of shape ``(batch, num_experts)``. + """ + x, (h, c) = x # h (_, bs, d) o = self.forward_function(h.transpose(0, 1)) # (bs, num_experts) if self.normalize: @@ -179,7 +251,10 @@ class NonAdaptiveGate(Gate): normalize: If True, apply softmax to the coefficients. 
""" - def __init__(self, num_experts: int, coef: Optional[torch.Tensor] = None, fixed: bool = False, normalize: bool = False) -> None: + def __init__( + self, num_experts: int, coef: Optional[torch.Tensor] = None, + fixed: bool = False, normalize: bool = False + ) -> None: super(self.__class__, self).__init__() self.normalize = normalize if coef is None: # initialization @@ -192,8 +267,19 @@ def __init__(self, num_experts: int, coef: Optional[torch.Tensor] = None, fixed: self.coefficients = coef - def forward(self, x: Any, coef: Optional[torch.Tensor] = None) -> torch.Tensor: - """Return the (optionally normalized) mixing coefficients.""" + def forward( + self, x: Any, coef: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Return the (optionally normalized) mixing coefficients. + + Args: + x: Ignored; kept for interface compatibility. + coef: Ignored; kept for interface compatibility. + + Returns: + Mixing coefficient tensor of shape ``(num_experts,)``. + """ + if self.normalize: return nn.functional.softmax(self.coefficients, 0) else: @@ -204,12 +290,24 @@ class IDGate(Gate): """Identity gate that passes through a previous coefficient unchanged.""" def forward(self, x: Any, coef: torch.Tensor) -> torch.Tensor: - """Return the coefficient that was passed in.""" + """Return the coefficient that was passed in. + + Args: + x: Ignored. + coef: Pre-computed mixing coefficients. + + Returns: + ``coef`` unchanged. + """ + return coef ################ time series example models ################ -def moo_linear(in_features: int, out_features: int, num_experts: int, bs_dim: int = 1, expert_dim: int = 0) -> MoO: +def moo_linear( + in_features: int, out_features: int, + num_experts: int, bs_dim: int = 1, expert_dim: int = 0 + ) -> MoO: """Create a MoO over a set of linear layers with tied shape. 
Args: @@ -270,8 +368,23 @@ def __init__(self, input_size: int, hidden_size: int, num_experts: int = 2, nn.init.uniform_(weight, -stdv, stdv) - def rnn_step(self, x: torch.Tensor, hidden: Tuple[torch.Tensor, torch.Tensor], coef: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """Run a single LSTM step with mixed expert parameters.""" + def rnn_step( + self, x: torch.Tensor, hidden: Tuple[torch.Tensor, torch.Tensor], + coef: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Run a single LSTM step with mixed expert parameters. + + Args: + x: Input tensor for this time step, shape + ``(1, batch, input_size)``. + hidden: Tuple ``(h, c)`` of previous hidden and cell + states. + coef: Mixing coefficients for the experts. + + Returns: + Tuple ``(h, c)`` of updated hidden and cell states. + """ + bs = x.shape[1] h, c = hidden gates = self.input_weights(x, coef) + self.hidden_weights(h, coef) @@ -286,8 +399,25 @@ def rnn_step(self, x: torch.Tensor, hidden: Tuple[torch.Tensor, torch.Tensor], c h = outgate * torch.tanh(c) # maybe use layer norm here as well return h, c - def forward(self, x: torch.Tensor, hidden: Tuple[torch.Tensor, torch.Tensor], coef: torch.Tensor) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """Run the mixture LSTM over a full sequence.""" + def forward( + self, x: torch.Tensor, + hidden: Tuple[torch.Tensor, torch.Tensor], coef: torch.Tensor + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + + """Run the mixture LSTM over a full sequence. + + Args: + x: Input tensor of shape + ``(seq_len, batch, input_size)`` (or transposed + if ``batch_first``). + hidden: Tuple ``(h, c)`` of initial hidden/cell states. + coef: Mixing coefficients for the experts. + + Returns: + Tuple of ``(output, (h, c))`` where output has shape + ``(seq_len, batch, hidden_size)``. 
+ """ + if self.batch_first: # change to seq_len first x = x.transpose(0, 1) @@ -318,9 +448,12 @@ class mowLSTM(nn.Module): activation: Optional activation applied to the final output. """ - def __init__(self, input_size: int, hidden_size: int, num_classes: int, num_experts: int = 2, - num_layers: int = 1, batch_first: bool = False, dropout: float = 0, - bidirectional: bool = False, activation: Optional[nn.Module] = None) -> None: + def __init__( + self, input_size: int, hidden_size: int, num_classes: int, + num_experts: int = 2, num_layers: int = 1, + batch_first: bool = False, dropout: float = 0, + bidirectional: bool = False, activation: Optional[nn.Module] = None + ) -> None: super(mowLSTM, self).__init__() @@ -348,8 +481,22 @@ def __init__(self, input_size: int, hidden_size: int, num_classes: int, num_expe batch_first)) self.dropouts.append(nn.Dropout(p=dropout)) - def forward(self, x: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], coef: torch.Tensor) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """Forward pass through the stacked mixture LSTM.""" + def forward( + self, x: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + coef: torch.Tensor + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + + """Forward pass through the stacked mixture LSTM. + + Args: + x: Tuple of ``(input_tensor, (h, c))``. + coef: Mixing coefficients for the experts. + + Returns: + Tuple of ``(output, (h, c))`` where output has shape + ``(seq_len, batch, num_classes)``. + """ + x, hidden = x self.last_coef = coef @@ -404,10 +551,14 @@ def __init__(self, input_size: int, hidden_size: int, num_classes: int, def setKT(self, k: int, t: int) -> None: """Configure the model for ``k`` experts and ``t`` time steps. - + Args: k: Number of expert cells to mix. - t: Maximum number of time steps; one gate is created per step. + t: Maximum number of time steps; one gate is created + per step. + + Raises: + ValueError: If ``k < 1`` or ``t < 1``. 
""" self.k = k self.T = t @@ -424,8 +575,22 @@ def setKT(self, k: int, t: int) -> None: gate = NonAdaptiveGate(self.k, normalize=True) self.cells.append(MoW(experts, gate)) - def forward(self, x: torch.Tensor, hidden: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """Run the mixture LSTM step-by-step using the per-step gates.""" + def forward( + self, x: torch.Tensor, hidden: Tuple[torch.Tensor, torch.Tensor] + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + + """Run the mixture LSTM step-by-step using per-step gates. + + Args: + x: Input tensor of shape + ``(seq_len, batch, input_size)``. + hidden: Tuple ``(h, c)`` of initial hidden/cell states. + + Returns: + Tuple of ``(output, (h, c))`` where output has shape + ``(seq_len, batch, num_classes)``. + """ + seq_len, bs, _ = x.shape o = [] for t in range(seq_len): @@ -437,7 +602,15 @@ def forward(self, x: torch.Tensor, hidden: Tuple[torch.Tensor, torch.Tensor]) -> def orthogonal(shape: Tuple[int, ...]) -> np.ndarray: - """Generate an orthogonal matrix of the given shape via SVD.""" + """Generate an orthogonal matrix of the given shape via SVD. + + Args: + shape: Target shape for the orthogonal matrix. + + Returns: + A numpy array with orthogonal rows/columns. + """ + flat_shape = (int(shape[0]), int(np.prod(shape[1:]))) a = np.random.normal(0.0, 1.0, flat_shape) u, _, v = np.linalg.svd(a, full_matrices=False)