diff --git a/docs/api/models.rst b/docs/api/models.rst index 7368dec94..918461103 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -30,6 +30,9 @@ routes each feature type automatically. * - :doc:`models/pyhealth.models.RNN` - Your features are sequences of medical codes (diagnoses, procedures, drugs) across visits - One RNN per feature, hidden states concatenated; ``rnn_type`` can be ``"GRU"`` (default) or ``"LSTM"`` + * - :doc:`models/pyhealth.models.ShiftLSTM` + - Your time-series task may exhibit phase-specific or time-varying relationships + - Uses segment-specific LSTM parameters; ``num_segments=1`` acts as the shared-parameter baseline * - :doc:`models/pyhealth.models.Transformer` - You have longer code histories and want attention to capture long-range dependencies - Self-attention across the sequence; tends to work well when visit order matters @@ -174,6 +177,7 @@ API Reference models/pyhealth.models.MLP models/pyhealth.models.CNN models/pyhealth.models.RNN + models/pyhealth.models.ShiftLSTM models/pyhealth.models.GNN models/pyhealth.models.Transformer models/pyhealth.models.TransformersModel diff --git a/docs/api/models/pyhealth.models.ShiftLSTM.rst b/docs/api/models/pyhealth.models.ShiftLSTM.rst new file mode 100644 index 000000000..ce908415a --- /dev/null +++ b/docs/api/models/pyhealth.models.ShiftLSTM.rst @@ -0,0 +1,26 @@ +pyhealth.models.ShiftLSTM +======================================== + + +The segment-wise recurrent layer and the complete ShiftLSTM model. + +``ShiftLSTM`` relaxes parameter sharing over time by dividing the sequence +into ``K`` temporal segments. Each segment uses its own ``LSTMCell`` while the +hidden and cell states continue flowing through the full sequence. When +``num_segments=1``, the model reduces to the shared-parameter baseline. + +This implementation is inspired by: + + Oh, J., Wang, J., Wiens, J. (2019). + "Relaxed Parameter Sharing: Effectively Modeling Time-Varying Relationships + in Clinical Time-Series." + +.. autoclass:: pyhealth.models.ShiftLSTMLayer + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: pyhealth.models.ShiftLSTM + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/synthetic/shift_lstm_synthetic_data.py b/examples/synthetic/shift_lstm_synthetic_data.py new file mode 100644 index 000000000..f1d944305 --- /dev/null +++ b/examples/synthetic/shift_lstm_synthetic_data.py @@ -0,0 +1,361 @@ +"""Synthetic data generation for shiftLSTM ablation experiments. + +This script follows Section 4.1 of: +"Relaxed Parameter Sharing: Effectively Modeling Time-Varying Relationships +in Clinical Time-Series" (Oh et al., 2019). + +It also borrows the high-level structure of the authors' released synthetic +data generator that maintains: + - sparse input sequences X in R^{N x T x d} + - time-varying temporal weights over the previous l timesteps + - time-varying feature weights over d dimensions + - smooth drift controlled by delta + +Compared with the original prototype-style script, this version: + - is CPU-friendly + - is reproducible through a random seed + - can export NumPy arrays + - can convert synthetic arrays into PyHealth SampleDataset-ready samples +""" + +from __future__ import annotations + +import argparse +import json +from dataclasses import asdict, dataclass +from datetime import datetime, timedelta +from pathlib import Path +from typing import Any, Iterable, Optional + +import numpy as np + + +@dataclass +class SyntheticConfig: + """Configuration for the synthetic generator. + + Args: + N: Number of samples. + T: Sequence length. + d: Number of input features. + l: Lookback window used to generate targets. The paper uses l=10. + delta: Drift magnitude controlling how much the temporal/feature + distributions can change between adjacent timesteps. + sparsity: Probability that an input entry is non-zero. + value_low: Lower bound for non-zero values. + value_high: Upper bound for non-zero values. + seed: Random seed for reproducibility. + """ + + N: int = 1000 + T: int = 30 + d: int = 3 + l: int = 10 + delta: float = 0.2 + sparsity: float = 0.1 + value_low: float = 0.0 + value_high: float = 100.0 + seed: int = 42 + + def validate(self) -> None: + if self.T <= self.l: + raise ValueError("T must be greater than l for target generation.") + if not 0.0 < self.sparsity <= 1.0: + raise ValueError("sparsity must be in (0, 1].") + if self.delta < 0: + raise ValueError("delta must be non-negative.") + if self.value_high <= self.value_low: + raise ValueError("value_high must be greater than value_low.") + + +def normalize_to_distribution(values: np.ndarray) -> np.ndarray: + """Projects a vector to a probability distribution. + + The authors' script first min-max normalizes, then renormalizes to sum to 1. + We retain that behavior but add guards for constant vectors. + """ + + values = np.asarray(values, dtype=float) + vmin = float(np.min(values)) + vmax = float(np.max(values)) + if np.isclose(vmax, vmin): + return np.full_like(values, fill_value=1.0 / len(values), dtype=float) + + scaled = (values - vmin) / (vmax - vmin) + total = float(np.sum(scaled)) + if np.isclose(total, 0.0): + return np.full_like(values, fill_value=1.0 / len(values), dtype=float) + return scaled / total + + +def generate_sparse_inputs(config: SyntheticConfig, rng: np.random.Generator) -> np.ndarray: + """Generates sparse inputs X of shape (N, T, d). + + Following the paper, each entry is active with probability 0.1 by default, + and active values are sampled uniformly on [0, 100]. + """ + + active = rng.binomial(1, config.sparsity, size=(config.N, config.T, config.d)) + values = rng.uniform( + low=config.value_low, + high=config.value_high, + size=(config.N, config.T, config.d), + ) + return active * values + + +def generate_weight_trajectories( + config: SyntheticConfig, + rng: np.random.Generator, + k_dist: Optional[list[np.ndarray]] = None, + d_dist: Optional[list[np.ndarray]] = None, +) -> tuple[list[np.ndarray], list[np.ndarray]]: + """Generates time-varying temporal and feature distributions. + + At timestep t >= l: + - k_dist[t] is a distribution over the previous l timesteps + - d_dist[t] is a distribution over the d feature dimensions + + Distances evolve smoothly with additive perturbations bounded by delta. + """ + + if k_dist is not None and d_dist is not None: + return k_dist, d_dist + + temporal_weights: list[np.ndarray] = [] + feature_weights: list[np.ndarray] = [] + + for t in range(config.T): + if t < config.l: + temporal_weights.append(np.ones(config.l, dtype=float) / config.l) + feature_weights.append(np.ones(config.d, dtype=float) / config.d) + continue + + if t == config.l: + temporal_weights.append( + normalize_to_distribution(rng.uniform(size=config.l)) + ) + feature_weights.append( + normalize_to_distribution(rng.uniform(size=config.d)) + ) + continue + + delta_t = rng.uniform(-config.delta, config.delta, size=config.l) + delta_d = rng.uniform(-config.delta, config.delta, size=config.d) + temporal_weights.append( + normalize_to_distribution(temporal_weights[t - 1] + delta_t) + ) + feature_weights.append( + normalize_to_distribution(feature_weights[t - 1] + delta_d) + ) + + return temporal_weights, feature_weights + + +def generate_targets( + x: np.ndarray, + config: SyntheticConfig, + k_dist: list[np.ndarray], + d_dist: list[np.ndarray], +) -> np.ndarray: + """Generates regression targets Y of shape (N, T, 1). + + For timestep t >= l, the target is formed by: + 1. combining feature dimensions at each previous timestep using d_dist[t] + 2. combining the resulting l-length history using k_dist[t] + """ + + y = np.ones((config.N, config.T, 1), dtype=float) + for t in range(config.l, config.T): + history = x[:, t - config.l : t, :] # (N, l, d) + # Use explicit weighted sums instead of np.matmul to avoid spurious + # BLAS-level overflow warnings on some local NumPy builds. + feature_agg = np.einsum("nld,d->nl", history, d_dist[t]) # (N, l) + y[:, t, 0] = np.einsum("nl,l->n", feature_agg, k_dist[t]) # (N,) + return y + + +def generate_synthetic_arrays( + config: SyntheticConfig, + k_dist: Optional[list[np.ndarray]] = None, + d_dist: Optional[list[np.ndarray]] = None, +) -> dict[str, Any]: + """Generates a full synthetic dataset bundle. + + Returns: + A dictionary containing X, Y, k_dist, d_dist, and config. + """ + + config.validate() + rng = np.random.default_rng(config.seed) + + x = generate_sparse_inputs(config, rng) + k_dist, d_dist = generate_weight_trajectories(config, rng, k_dist, d_dist) + y = generate_targets(x, config, k_dist, d_dist) + + return { + "x": x.astype(np.float32), + "y": y.astype(np.float32), + "k_dist": np.asarray(k_dist, dtype=np.float32), + "d_dist": np.asarray(d_dist, dtype=np.float32), + "config": asdict(config), + } + + +def save_synthetic_bundle(bundle: dict[str, Any], output_path: str | Path) -> Path: + """Saves the generated arrays and metadata to a compressed NPZ file.""" + + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + np.savez_compressed( + output_path, + x=bundle["x"], + y=bundle["y"], + k_dist=bundle["k_dist"], + d_dist=bundle["d_dist"], + config_json=json.dumps(bundle["config"]), + ) + return output_path + + +def to_pyhealth_samples( + x: np.ndarray, + y: np.ndarray, + task: str = "binary_final_step", + start_time: Optional[datetime] = None, + threshold: Optional[float] = None, +) -> list[dict[str, Any]]: + """Converts arrays to raw Python samples for ``create_sample_dataset``. + + The default output is a sequence-level binary label derived from the final + timestep target, which is convenient for shiftLSTM ablation studies. + + Each sample contains one timeseries feature: + - "signal": (timestamps, values) + - "label": scalar + """ + + if start_time is None: + start_time = datetime(2020, 1, 1, 0, 0, 0) + + final_targets = y[:, -1, 0] + if threshold is None: + threshold = float(np.median(final_targets)) + + timestamps = None + samples: list[dict[str, Any]] = [] + for idx in range(x.shape[0]): + if timestamps is None: + timestamps = [start_time + timedelta(hours=t) for t in range(x.shape[1])] + + if task == "binary_final_step": + label = int(final_targets[idx] > threshold) + elif task == "regression_final_step": + label = float(final_targets[idx]) + else: + raise ValueError(f"Unsupported task: {task}") + + samples.append( + { + "patient_id": f"synthetic-patient-{idx}", + "visit_id": f"synthetic-visit-{idx}", + "signal": (timestamps, x[idx]), + "label": label, + } + ) + return samples + + +def create_repeated_bundles( + config: SyntheticConfig, + synth_num: int, + output_dir: str | Path, + run_prefix: str = "synthetic_shift", +) -> list[Path]: + """Creates multiple synthetic datasets with different random seeds.""" + + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + saved_paths: list[Path] = [] + for run in range(synth_num): + run_config = SyntheticConfig(**{**asdict(config), "seed": config.seed + run}) + bundle = generate_synthetic_arrays(run_config) + out_path = output_dir / f"{run_prefix}_model{run}.npz" + save_synthetic_bundle(bundle, out_path) + saved_paths.append(out_path) + return saved_paths + + +def _build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="Generate synthetic data for shiftLSTM ablation experiments." + ) + parser.add_argument("--N", type=int, default=1000, help="Number of samples.") + parser.add_argument("--T", type=int, default=30, help="Sequence length.") + parser.add_argument("--d", type=int, default=3, help="Number of features.") + parser.add_argument( + "--l", type=int, default=10, help="Lookback window used for target generation." + ) + parser.add_argument( + "--delta", + type=float, + default=0.2, + help="Maximum drift per step for temporal/feature distributions.", + ) + parser.add_argument( + "--sparsity", + type=float, + default=0.1, + help="Probability that an input entry is non-zero.", + ) + parser.add_argument( + "--seed", type=int, default=42, help="Random seed for reproducibility." + ) + parser.add_argument( + "--synth-num", + type=int, + default=1, + help="How many independent synthetic datasets to generate.", + ) + parser.add_argument( + "--savedir", + type=str, + default="examples/synthetic/generated", + help="Directory for generated NPZ files.", + ) + parser.add_argument( + "--runname", + type=str, + default="synthetic_shift", + help="Filename prefix for generated datasets.", + ) + return parser + + +def main() -> None: + parser = _build_arg_parser() + args = parser.parse_args() + + config = SyntheticConfig( + N=args.N, + T=args.T, + d=args.d, + l=args.l, + delta=args.delta, + sparsity=args.sparsity, + seed=args.seed, + ) + saved = create_repeated_bundles( + config=config, + synth_num=args.synth_num, + output_dir=args.savedir, + run_prefix=args.runname, + ) + print(f"Generated {len(saved)} synthetic dataset(s):") + for path in saved: + print(path) + + +if __name__ == "__main__": + main() diff --git a/examples/synthetic_sequence_classification_shift_lstm.py b/examples/synthetic_sequence_classification_shift_lstm.py new file mode 100644 index 000000000..90845b1b4 --- /dev/null +++ b/examples/synthetic_sequence_classification_shift_lstm.py @@ -0,0 +1,219 @@ +"""Ablation script for ShiftLSTM on synthetic time-varying data. + +This example is designed to satisfy the course project's +"Ablation Study / Example Usage" requirement while staying lightweight enough +for local smoke runs. + +It compares ShiftLSTM with different segment counts K on synthetic data +generated following Section 4.1 of Oh et al. (2019): + + - K = 1 acts as the shared-parameter LSTM baseline + - K > 1 relaxes parameter sharing over time +""" + +from __future__ import annotations + +import argparse +import importlib.util +import json +import sys +import tempfile +from pathlib import Path + +from pyhealth.datasets import create_sample_dataset, get_dataloader, split_by_patient +from pyhealth.models import ShiftLSTM +from pyhealth.trainer import Trainer + + +THIS_DIR = Path(__file__).resolve().parent +SYNTHETIC_MODULE_PATH = THIS_DIR / "synthetic" / "shift_lstm_synthetic_data.py" + + +def load_synthetic_module(): + """Loads the synthetic generator module from examples/synthetic.""" + + spec = importlib.util.spec_from_file_location( + "shift_lstm_synthetic_data", SYNTHETIC_MODULE_PATH + ) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + + +def build_sample_dataset( + num_samples: int, + seq_len: int, + num_features: int, + lookback: int, + delta: float, + seed: int, +): + """Builds a PyHealth SampleDataset from synthetic arrays.""" + + synth = load_synthetic_module() + config = synth.SyntheticConfig( + N=num_samples, + T=seq_len, + d=num_features, + l=lookback, + delta=delta, + seed=seed, + ) + bundle = synth.generate_synthetic_arrays(config) + samples = synth.to_pyhealth_samples(bundle["x"], bundle["y"]) + dataset = create_sample_dataset( + samples=samples, + input_schema={"signal": "timeseries"}, + output_schema={"label": "binary"}, + dataset_name="synthetic_shift_lstm", + ) + return dataset + + +def run_single_experiment( + dataset, + num_segments: int, + embedding_dim: int, + hidden_dim: int, + dropout: float, + batch_size: int, + epochs: int, + learning_rate: float, + seed: int, +): + """Runs one ShiftLSTM configuration and returns validation/test metrics.""" + + train_dataset, val_dataset, test_dataset = split_by_patient( + dataset, [0.7, 0.15, 0.15], seed=seed + ) + train_loader = get_dataloader(train_dataset, batch_size=batch_size, shuffle=True) + val_loader = get_dataloader(val_dataset, batch_size=batch_size, shuffle=False) + test_loader = get_dataloader(test_dataset, batch_size=batch_size, shuffle=False) + + model = ShiftLSTM( + dataset=dataset, + embedding_dim=embedding_dim, + hidden_dim=hidden_dim, + num_segments=num_segments, + dropout=dropout, + ) + + with tempfile.TemporaryDirectory() as tmpdir: + trainer = Trainer( + model=model, + metrics=["accuracy", "roc_auc", "pr_auc", "f1"], + enable_logging=False, + output_path=tmpdir, + exp_name=f"shift_lstm_k{num_segments}", + ) + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=epochs, + optimizer_params={"lr": learning_rate}, + monitor="roc_auc", + monitor_criterion="max", + load_best_model_at_last=False, + ) + val_scores = trainer.evaluate(val_loader) + test_scores = trainer.evaluate(test_loader) + + return { + "num_segments": num_segments, + "val": val_scores, + "test": test_scores, + } + + +def format_results_table(results: list[dict]) -> str: + """Formats a compact human-readable table.""" + + header = ( + f"{'Model':<18} {'K':<4} {'Val AUROC':<10} " + f"{'Test AUROC':<11} {'Test AUPRC':<11} {'Test Acc':<9}" + ) + rows = [header, "-" * len(header)] + for result in results: + k = result["num_segments"] + name = "LSTM baseline" if k == 1 else "ShiftLSTM" + val = result["val"] + test = result["test"] + rows.append( + f"{name:<18} {k:<4} {val['roc_auc']:<10.4f} " + f"{test['roc_auc']:<11.4f} {test['pr_auc']:<11.4f} {test['accuracy']:<9.4f}" + ) + return "\n".join(rows) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Run ShiftLSTM ablations on synthetic sequence classification." + ) + parser.add_argument("--num-samples", type=int, default=3000) + parser.add_argument("--seq-len", type=int, default=30) + parser.add_argument("--num-features", type=int, default=3) + parser.add_argument("--lookback", type=int, default=10) + parser.add_argument("--delta", type=float, default=0.2) + parser.add_argument("--embedding-dim", type=int, default=32) + parser.add_argument("--hidden-dim", type=int, default=32) + parser.add_argument("--dropout", type=float, default=0.0) + parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--epochs", type=int, default=5) + parser.add_argument("--lr", type=float, default=1e-3) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument( + "--segments", + type=int, + nargs="+", + default=[1, 2, 4], + help="Segment counts K to compare.", + ) + parser.add_argument( + "--save-json", + type=str, + default=None, + help="Optional path to save the full metrics as JSON.", + ) + return parser.parse_args() + + +def main(): + args = parse_args() + dataset = build_sample_dataset( + num_samples=args.num_samples, + seq_len=args.seq_len, + num_features=args.num_features, + lookback=args.lookback, + delta=args.delta, + seed=args.seed, + ) + + results = [] + for num_segments in args.segments: + result = run_single_experiment( + dataset=dataset, + num_segments=num_segments, + embedding_dim=args.embedding_dim, + hidden_dim=args.hidden_dim, + dropout=args.dropout, + batch_size=args.batch_size, + epochs=args.epochs, + learning_rate=args.lr, + seed=args.seed, + ) + results.append(result) + + print(format_results_table(results)) + + if args.save_json is not None: + save_path = Path(args.save_json) + save_path.parent.mkdir(parents=True, exist_ok=True) + with open(save_path, "w", encoding="utf-8") as f: + json.dump(results, f, indent=2) + print(f"\nSaved metrics to: {save_path}") + + +if __name__ == "__main__": + main() diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index 5233b1726..a5ebd27f3 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -21,6 +21,7 @@ from .molerec import MoleRec, MoleRecLayer from .retain import MultimodalRETAIN, RETAIN, RETAINLayer from .rnn import MultimodalRNN, RNN, RNNLayer +from .shift_lstm import ShiftLSTM, ShiftLSTMLayer from .safedrug import SafeDrug, SafeDrugLayer from .sparcnet import DenseBlock, DenseLayer, SparcNet, TransitionLayer from .stagenet import StageNet, StageNetLayer diff --git a/pyhealth/models/shift_lstm.py b/pyhealth/models/shift_lstm.py new file mode 100644 index 000000000..14b915fdc --- /dev/null +++ b/pyhealth/models/shift_lstm.py @@ -0,0 +1,267 @@ +from typing import Dict, Optional, Tuple + +import torch +import torch.nn as nn + +from pyhealth.datasets import SampleDataset + +from .base_model import BaseModel +from .embedding import EmbeddingModel + + +class ShiftLSTMLayer(nn.Module): + """Segment-wise LSTM layer with relaxed parameter sharing over time. + + The layer divides each sequence into ``num_segments`` temporal chunks. A + dedicated ``nn.LSTMCell`` is used within each chunk, while the hidden and + cell states are propagated through the entire sequence. + + This implements the core idea of shiftLSTM from + "Relaxed Parameter Sharing: Effectively Modeling Time-Varying Relationships + in Clinical Time-Series". + + Args: + input_size: Input feature size. + hidden_size: Hidden state size. + num_segments: Number of temporal segments. ``1`` reduces to a standard + shared-parameter LSTM. + dropout: Dropout applied to the input sequence before recurrence. + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + num_segments: int = 1, + dropout: float = 0.5, + ): + super().__init__() + if num_segments < 1: + raise ValueError("num_segments must be >= 1") + + self.input_size = input_size + self.hidden_size = hidden_size + self.num_segments = num_segments + self.dropout = dropout + + self.dropout_layer = nn.Dropout(dropout) + self.cells = nn.ModuleList( + [nn.LSTMCell(input_size, hidden_size) for _ in range(num_segments)] + ) + + def _compute_lengths( + self, x: torch.Tensor, mask: Optional[torch.Tensor] + ) -> torch.Tensor: + batch_size = x.size(0) + if mask is None: + lengths = torch.full( + (batch_size,), + fill_value=x.size(1), + dtype=torch.long, + device=x.device, + ) + else: + lengths = mask.long().sum(dim=-1).clamp_min(1) + return lengths + + def _segment_index(self, step: int, lengths: torch.Tensor) -> torch.Tensor: + # Relative segment assignment per sample: + # floor(step / length * K), clamped to [0, K - 1]. + seg = torch.div( + step * self.num_segments, + lengths, + rounding_mode="floor", + ) + return seg.clamp(max=self.num_segments - 1) + + def forward( + self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward propagation. + + Args: + x: Tensor of shape ``[batch_size, seq_len, input_size]``. + mask: Optional tensor of shape ``[batch_size, seq_len]`` with 1 for + valid steps and 0 for padding. + + Returns: + outputs: Tensor of shape ``[batch_size, seq_len, hidden_size]``. + last_outputs: Tensor of shape ``[batch_size, hidden_size]``. + """ + x = self.dropout_layer(x) + batch_size, seq_len, _ = x.shape + + if mask is not None: + mask = mask.to(x.device).bool() + + lengths = self._compute_lengths(x, mask) + outputs = torch.zeros( + batch_size, seq_len, self.hidden_size, device=x.device, dtype=x.dtype + ) + last_outputs = torch.zeros( + batch_size, self.hidden_size, device=x.device, dtype=x.dtype + ) + + h_t = torch.zeros( + batch_size, self.hidden_size, device=x.device, dtype=x.dtype + ) + c_t = torch.zeros( + batch_size, self.hidden_size, device=x.device, dtype=x.dtype + ) + + for step in range(seq_len): + if mask is None: + valid = step < lengths + else: + valid = mask[:, step] + + if not valid.any(): + continue + + step_segment = self._segment_index(step, lengths) + + next_h = h_t.clone() + next_c = c_t.clone() + + for segment_idx, cell in enumerate(self.cells): + select = valid & (step_segment == segment_idx) + if not select.any(): + continue + h_sel, c_sel = cell(x[select, step, :], (h_t[select], c_t[select])) + next_h[select] = h_sel + next_c[select] = c_sel + + h_t = next_h + c_t = next_c + outputs[valid, step, :] = h_t[valid] + last_outputs[valid] = h_t[valid] + + return outputs, last_outputs + + +class ShiftLSTM(BaseModel): + """PyHealth model wrapper for shiftLSTM. + + This model mirrors the high-level structure of :class:`pyhealth.models.RNN`, + but replaces the shared-parameter recurrent layer with ``ShiftLSTMLayer`` to + better model time-varying input-label relationships. + + Args: + dataset: Dataset used to infer feature and label schemas. + embedding_dim: Shared embedding size for all input features. + hidden_dim: Hidden state size of each shiftLSTM layer. + num_segments: Number of temporal segments per feature sequence. + dropout: Input dropout applied before recurrence. + """ + + def __init__( + self, + dataset: SampleDataset, + embedding_dim: int = 128, + hidden_dim: int = 128, + num_segments: int = 2, + dropout: float = 0.5, + ): + super().__init__(dataset=dataset) + assert len(self.label_keys) == 1, ( + "Only one label key is supported if ShiftLSTM is initialized" + ) + + self.embedding_dim = embedding_dim + self.hidden_dim = hidden_dim + self.num_segments = num_segments + self.dropout = dropout + self.label_key = self.label_keys[0] + self.mode = self.dataset.output_schema[self.label_key] + + self.embedding_model = EmbeddingModel(dataset, embedding_dim) + self.shift_lstm = nn.ModuleDict() + for feature_key in self.dataset.input_processors.keys(): + self.shift_lstm[feature_key] = ShiftLSTMLayer( + input_size=embedding_dim, + hidden_size=hidden_dim, + num_segments=num_segments, + dropout=dropout, + ) + + output_size = self.get_output_size() + self.fc = nn.Linear(len(self.feature_keys) * self.hidden_dim, output_size) + + def _extract_inputs_and_masks(self, **kwargs): + inputs = {} + masks = {} + + for feature_key in self.feature_keys: + feature = kwargs[feature_key] + if isinstance(feature, torch.Tensor): + feature = (feature,) + + schema = self.dataset.input_processors[feature_key].schema() + value = feature[schema.index("value")] if "value" in schema else None + mask = feature[schema.index("mask")] if "mask" in schema else None + + if value is None: + raise ValueError( + f"Feature '{feature_key}' must contain 'value' in the schema." + ) + + inputs[feature_key] = value + if mask is not None: + masks[feature_key] = mask + + return inputs, masks + + def _prepare_sequence_feature( + self, + feature_key: str, + embedded_feature: torch.Tensor, + masks: Dict[str, torch.Tensor], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + x = embedded_feature + x_dim_orig = x.dim() + + if x_dim_orig == 4: + x = x.sum(dim=2) + if feature_key in masks: + mask = (masks[feature_key].to(self.device).sum(dim=-1) > 0).int() + else: + mask = (torch.abs(x).sum(dim=-1) != 0).int() + elif x_dim_orig == 2: + x = x.unsqueeze(1) + mask = None + else: + if feature_key in masks: + mask = masks[feature_key].to(self.device).int() + if mask.dim() == 3: + mask = (mask.sum(dim=-1) > 0).int() + else: + mask = (torch.abs(x).sum(dim=-1) != 0).int() + + return x, mask + + def forward(self, **kwargs) -> Dict[str, torch.Tensor]: + """Forward propagation.""" + inputs, masks = self._extract_inputs_and_masks(**kwargs) + embedded = self.embedding_model(inputs, masks=masks) + + patient_emb = [] + for feature_key in self.feature_keys: + x, mask = self._prepare_sequence_feature( + feature_key, embedded[feature_key], masks + ) + _, last_output = self.shift_lstm[feature_key](x, mask) + patient_emb.append(last_output) + + patient_emb = torch.cat(patient_emb, dim=1) + logits = self.fc(patient_emb) + + y_true = kwargs[self.label_key].to(self.device) + loss = self.get_loss_function()(logits, y_true) + y_prob = self.prepare_y_prob(logits) + + results = {"loss": loss, "y_prob": y_prob, "y_true": y_true, "logit": logits} + if kwargs.get("embed", False): + results["embed"] = patient_emb + return results diff --git a/tests/core/test_shift_lstm.py b/tests/core/test_shift_lstm.py new file mode 100644 index 000000000..d975786f5 --- /dev/null +++ b/tests/core/test_shift_lstm.py @@ -0,0 +1,170 @@ +import importlib.util +import sys +import tempfile +import unittest +from datetime import datetime, timedelta +from pathlib import Path + +import numpy as np +import torch + +from pyhealth.datasets import create_sample_dataset, get_dataloader +from pyhealth.models import ShiftLSTM +from pyhealth.models.shift_lstm import ShiftLSTMLayer + + +class TestShiftLSTM(unittest.TestCase): + """Test cases for the ShiftLSTM model and synthetic generator.""" + + def setUp(self): + base = datetime(2020, 1, 1) + self.samples = [] + for i in range(4): + timestamps = [base + timedelta(hours=t) for t in range(6)] + values = np.array( + [ + [float(i + t), float((i + 1) * (t + 1) % 5), float(t % 2)] + for t in range(6) + ], + dtype=float, + ) + self.samples.append( + { + "patient_id": f"patient-{i}", + "visit_id": f"visit-{i}", + "signal": (timestamps, values), + "label": i % 2, + } + ) + + self.dataset = create_sample_dataset( + samples=self.samples, + input_schema={"signal": "timeseries"}, + output_schema={"label": "binary"}, + dataset_name="synthetic_shift_test", + ) + self.model = ShiftLSTM( + dataset=self.dataset, + embedding_dim=16, + hidden_dim=8, + num_segments=2, + dropout=0.0, + ) + + def _load_synthetic_module(self): + module_path = ( + Path(__file__).resolve().parents[2] + / "examples" + / "synthetic" + / "shift_lstm_synthetic_data.py" + ) + spec = importlib.util.spec_from_file_location( + "shift_lstm_synthetic_data", module_path + ) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + + def test_model_initialization(self): + """Test that the ShiftLSTM model initializes correctly.""" + self.assertIsInstance(self.model, ShiftLSTM) + self.assertEqual(self.model.embedding_dim, 16) + self.assertEqual(self.model.hidden_dim, 8) + self.assertEqual(self.model.num_segments, 2) + self.assertEqual(self.model.label_key, "label") + self.assertIn("signal", self.model.shift_lstm) + self.assertEqual(len(self.model.shift_lstm["signal"].cells), 2) + + def test_model_forward(self): + """Test that the ShiftLSTM forward pass works correctly.""" + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + ret = self.model(**data_batch) + + self.assertIn("loss", ret) + self.assertIn("y_prob", ret) + self.assertIn("y_true", ret) + self.assertIn("logit", ret) + self.assertEqual(ret["y_prob"].shape[0], 2) + self.assertEqual(ret["y_true"].shape[0], 2) + self.assertEqual(ret["logit"].shape[0], 2) + self.assertEqual(ret["loss"].dim(), 0) + + def test_model_backward(self): + """Test that the ShiftLSTM backward pass works correctly.""" + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + data_batch = next(iter(train_loader)) + ret = self.model(**data_batch) + ret["loss"].backward() + + has_gradient = False + for param in self.model.parameters(): + if param.requires_grad and param.grad is not None: + has_gradient = True + break + self.assertTrue(has_gradient, "No gradients found after backward pass.") + + def test_shift_lstm_layer_shapes(self): + """Test the low-level ShiftLSTMLayer output shapes and segment count.""" + layer = ShiftLSTMLayer( + input_size=4, hidden_size=3, num_segments=3, dropout=0.0 + ) + x = torch.randn(2, 5, 4) + mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]], dtype=torch.int64) + + outputs, last_outputs = layer(x, mask) + + self.assertEqual(outputs.shape, (2, 5, 3)) + self.assertEqual(last_outputs.shape, (2, 3)) + self.assertEqual(len(layer.cells), 3) + + def test_synthetic_generator_outputs(self): + """Test synthetic data generator shapes and distribution normalization.""" + synth = self._load_synthetic_module() + config = synth.SyntheticConfig(N=8, T=12, d=3, l=4, delta=0.2, seed=7) + bundle = synth.generate_synthetic_arrays(config) + + self.assertEqual(bundle["x"].shape, (8, 12, 3)) + self.assertEqual(bundle["y"].shape, (8, 12, 1)) + self.assertEqual(bundle["k_dist"].shape, (12, 4)) + self.assertEqual(bundle["d_dist"].shape, (12, 3)) + + # For valid prediction timesteps, temporal and feature weights should sum to 1. + self.assertTrue( + np.allclose(bundle["k_dist"][config.l :].sum(axis=1), 1.0, atol=1e-5) + ) + self.assertTrue( + np.allclose(bundle["d_dist"][config.l :].sum(axis=1), 1.0, atol=1e-5) + ) + + samples = synth.to_pyhealth_samples(bundle["x"], bundle["y"]) + self.assertEqual(len(samples), 8) + self.assertIn("signal", samples[0]) + self.assertIn("label", samples[0]) + + def test_synthetic_bundle_save_with_tempdir(self): + """Test saving synthetic bundles via TemporaryDirectory with cleanup.""" + synth = self._load_synthetic_module() + config = synth.SyntheticConfig(N=6, T=10, d=3, l=4, delta=0.1, seed=11) + bundle = synth.generate_synthetic_arrays(config) + + with tempfile.TemporaryDirectory() as tmpdir: + output_path = Path(tmpdir) / "synthetic_bundle.npz" + saved_path = synth.save_synthetic_bundle(bundle, output_path) + + self.assertEqual(saved_path, output_path) + self.assertTrue(saved_path.exists()) + + loaded = np.load(saved_path, allow_pickle=False) + self.assertEqual(tuple(loaded["x"].shape), (6, 10, 3)) + self.assertEqual(tuple(loaded["y"].shape), (6, 10, 1)) + self.assertEqual(tuple(loaded["k_dist"].shape), (10, 4)) + self.assertEqual(tuple(loaded["d_dist"].shape), (10, 3)) + + +if __name__ == "__main__": + unittest.main()