diff --git a/docs/api/models.rst b/docs/api/models.rst index 7c3ac7c4b..1a6acf24d 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -205,4 +205,5 @@ API Reference models/pyhealth.models.TextEmbedding models/pyhealth.models.BIOT models/pyhealth.models.unified_multimodal_embedding_docs + models/pyhealth.models.SHy models/pyhealth.models.califorest diff --git a/docs/api/models/pyhealth.models.SHy.rst b/docs/api/models/pyhealth.models.SHy.rst new file mode 100644 index 000000000..8c5f6d8b5 --- /dev/null +++ b/docs/api/models/pyhealth.models.SHy.rst @@ -0,0 +1,9 @@ +pyhealth.models.SHy +=================================== + +SHy (Self-Explaining Hypergraph Neural Network) for diagnosis prediction. + +.. autoclass:: pyhealth.models.SHy + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 23a4e06e5..9fc8876a6 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -230,3 +230,4 @@ Available Tasks Mutation Pathogenicity (COSMIC) Cancer Survival Prediction (TCGA) Cancer Mutation Burden (TCGA) + Diagnosis Prediction diff --git a/docs/api/tasks/pyhealth.tasks.DiagnosisPrediction.rst b/docs/api/tasks/pyhealth.tasks.DiagnosisPrediction.rst new file mode 100644 index 000000000..97538a37b --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.DiagnosisPrediction.rst @@ -0,0 +1,14 @@ +pyhealth.tasks.DiagnosisPredictionMIMIC3 +========================================== + +Diagnosis prediction task for the MIMIC-III dataset. + +.. autoclass:: pyhealth.tasks.DiagnosisPredictionMIMIC3 + :members: + :undoc-members: + :show-inheritance: + +.. 
"""
Diagnosis Prediction with SHy on MIMIC-III.

Ablation study over four axes:
1. Number of temporal phenotypes (K=1, 3, 5)
2. Number of HGNN layers (0, 1, 2)
3. Loss components (w/ and w/o each auxiliary loss)
4. Gumbel-Softmax temperature (0.5, 1.0, 2.0) — novel extension

Paper: Leisheng Yu, Yanxiao Cai, Minxing Zhang, and Xia Hu.
    Self-Explaining Hypergraph Neural Networks for Diagnosis Prediction.
    Proceedings of Machine Learning Research (CHIL), 2025.

Results (MIMIC-III dev=True, 1000 patients, 50 epochs, lr=1e-3):

    config               jaccard       f1    pr_auc   roc_auc
    ---------------------------------------------------------
    K=1                   0.0339   0.0652    0.1732    0.7240
    K=3                   0.0401   0.0762    0.1294    0.6905
    K=5                   0.0402   0.0766    0.1533    0.7126
    hgnn=0                0.0436   0.0827    0.1517    0.7067
    hgnn=1                0.0413   0.0787    0.1398    0.6997
    hgnn=2                0.0400   0.0759    0.1352    0.7142
    no auxiliary loss     0.0426   0.0808    0.1671    0.7134
    no fidelity           0.0420   0.0799    0.1422    0.6990
    no distinct           0.0390   0.0743    0.1459    0.6905
    no alpha              0.0408   0.0776    0.1429    0.6917
    full (all loss)       0.0347   0.0666    0.1389    0.6881
    temp=0.5              0.0408   0.0778    0.1265    0.7095
    temp=1.0              0.0397   0.0757    0.1354    0.6961
    temp=2.0              0.0411   0.0780    0.1431    0.6948
"""

import random

import numpy as np
import torch

from pyhealth.datasets import MIMIC3Dataset
from pyhealth.datasets.splitter import split_by_patient
from pyhealth.datasets.utils import get_dataloader
from pyhealth.models import SHy
from pyhealth.tasks import DiagnosisPredictionMIMIC3
from pyhealth.trainer import Trainer

# Fix every RNG we rely on so runs are reproducible.
SEED = 123
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)


def run_one(sample_dataset, train_loader, val_loader, test_loader, name, **kw):
    """train + eval a single SHy config, return test metrics"""
    print(f"\n{'='*55}")
    print(f" {name}")
    print(f"{'='*55}")

    shy = SHy(dataset=sample_dataset, **kw)
    trainer = Trainer(
        model=shy,
        metrics=["jaccard_samples", "f1_samples", "pr_auc_samples", "roc_auc_samples"],
        enable_logging=False,
    )
    trainer.train(
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        epochs=50,
        optimizer_params={"lr": 1e-3},
        monitor="pr_auc_samples",
        monitor_criterion="max",
    )

    metrics = trainer.evaluate(test_loader)
    print(f"=> {metrics}")
    return metrics


if __name__ == "__main__":

    # -- load mimic-iii (adjust root to the local copy) --
    base_dataset = MIMIC3Dataset(
        root="/path/to/mimic-iii/1.4",
        tables=["DIAGNOSES_ICD"],
        dev=True,
    )
    base_dataset.stats()

    # -- set up task + patient-level splits --
    samples = base_dataset.set_task(DiagnosisPredictionMIMIC3())
    print(f"got {len(samples)} samples total")

    train_ds, val_ds, test_ds = split_by_patient(samples, [0.8, 0.1, 0.1], seed=SEED)
    print(f"split: train={len(train_ds)} val={len(val_ds)} test={len(test_ds)}")

    train_loader = get_dataloader(train_ds, batch_size=32, shuffle=True)
    val_loader = get_dataloader(val_ds, batch_size=32, shuffle=False)
    test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False)

    # default hyperparams (from paper)
    defaults = dict(
        embedding_dim=32,
        hgnn_dim=64,
        hgnn_layers=2,
        num_tp=5,
        hidden_dim=64,
        num_heads=8,
        dropout=0.1,
    )

    # Ordered ablation schedule: (results key, display name, overrides).
    # Runs in the same order as the four ablation sections of the study.
    schedule = []
    # ablation 1: vary K (number of phenotypes)
    for k in [1, 3, 5]:
        schedule.append((f"K={k}", f"K={k}", {"num_tp": k}))
    # ablation 2: vary hgnn layers
    for n in [0, 1, 2]:
        schedule.append((f"hgnn={n}", f"HGNN layers={n}", {"hgnn_layers": n}))
    # ablation 3: loss components
    schedule += [
        ("no auxiliary loss", "no auxiliary loss",
         dict(fidelity_weight=0, distinct_weight=0, alpha_weight=0)),
        ("no fidelity", "no fidelity",
         dict(fidelity_weight=0, distinct_weight=0.01, alpha_weight=0.01)),
        ("no distinct", "no distinct",
         dict(fidelity_weight=0.1, distinct_weight=0, alpha_weight=0.01)),
        ("no alpha", "no alpha",
         dict(fidelity_weight=0.1, distinct_weight=0.01, alpha_weight=0)),
        ("full (all loss)", "full (all loss)",
         dict(fidelity_weight=0.1, distinct_weight=0.01, alpha_weight=0.01)),
    ]
    # ablation 4 (extension): gumbel-softmax temperature;
    # lower temp = more discrete selections, higher = more exploration
    for temp in [0.5, 1.0, 2.0]:
        schedule.append((f"temp={temp}", f"temperature={temp}", {"temperature": temp}))

    results = {}
    for tag, display, overrides in schedule:
        results[tag] = run_one(
            samples,
            train_loader,
            val_loader,
            test_loader,
            name=display,
            **{**defaults, **overrides},
        )

    # -- print summary table --
    print(f"\n{'='*66}")
    print("ABLATION RESULTS")
    print(f"{'='*66}")
    print(f"{'config':<20} {'jaccard':>10} {'f1':>10} {'pr_auc':>10} {'roc_auc':>10}")
    print("-" * 76)
    for tag, metrics in results.items():
        j = metrics.get("jaccard_samples", 0)
        f = metrics.get("f1_samples", 0)
        p = metrics.get("pr_auc_samples", 0)
        a = metrics.get("roc_auc_samples", 0)
        print(f"{tag:<20} {j:>10.4f} {f:>10.4f} {p:>10.4f} {a:>10.4f}")
    print("=" * 76)
"""
Diagnosis Prediction with SHy on MIMIC-IV.

Ablation study over four axes:
1. Number of temporal phenotypes (K=1, 3, 5)
2. Number of HGNN layers (0, 1, 2)
3. Loss components (w/ and w/o each auxiliary loss)
4. Gumbel-Softmax temperature (0.5, 1.0, 2.0) — novel extension

Paper: Leisheng Yu, Yanxiao Cai, Minxing Zhang, and Xia Hu.
    Self-Explaining Hypergraph Neural Networks for Diagnosis Prediction.
    Proceedings of Machine Learning Research (CHIL), 2025.

Results (MIMIC-IV dev=True, 1000 patients, 5 epochs, lr=1e-3):

    config               jaccard       f1    pr_auc   roc_auc
    ---------------------------------------------------------
    K=1                   0.0083   0.0163    0.1068    0.8590
    K=3                   0.0075   0.0149    0.1432    0.8694
    K=5                   0.0079   0.0157    0.0989    0.8576
    hgnn=0                0.0079   0.0156    0.1277    0.8697
    hgnn=1                0.0079   0.0156    0.1323    0.8699
    hgnn=2                0.0082   0.0162    0.1081    0.8558
    no auxiliary loss     0.0081   0.0160    0.1199    0.8610
    no fidelity           0.0077   0.0153    0.1344    0.8628
    no distinct           0.0084   0.0166    0.1242    0.8583
    no alpha              0.0082   0.0162    0.1402    0.8685
    full (all loss)       0.0082   0.0162    0.1134    0.8601
    temp=0.5              0.0080   0.0159    0.1272    0.8678
    temp=1.0              0.0080   0.0158    0.1145    0.8592
    temp=2.0              0.0085   0.0168    0.1450    0.8691
"""

import random

import numpy as np
import torch

from pyhealth.datasets import MIMIC4EHRDataset
from pyhealth.datasets.splitter import split_by_patient
from pyhealth.datasets.utils import get_dataloader
from pyhealth.models import SHy
from pyhealth.tasks import DiagnosisPredictionMIMIC4
from pyhealth.trainer import Trainer

# Fix every RNG we rely on so runs are reproducible.
SEED = 123
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)


def run_one(sample_dataset, train_loader, val_loader, test_loader, name, **kw):
    """train + eval a single SHy config, return test metrics"""
    print(f"\n{'='*55}")
    print(f" {name}")
    print(f"{'='*55}")

    shy = SHy(dataset=sample_dataset, **kw)
    trainer = Trainer(
        model=shy,
        metrics=["jaccard_samples", "f1_samples", "pr_auc_samples", "roc_auc_samples"],
        enable_logging=False,
    )
    trainer.train(
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        epochs=5,
        optimizer_params={"lr": 1e-3},
        monitor="pr_auc_samples",
        monitor_criterion="max",
    )

    metrics = trainer.evaluate(test_loader)
    print(f"=> {metrics}")
    return metrics


if __name__ == "__main__":

    # -- load mimic-iv (adjust root to the local copy) --
    base_dataset = MIMIC4EHRDataset(
        root="/path/to/mimic-iv/3.1",
        tables=["diagnoses_icd"],
        dev=True,
    )
    base_dataset.stats()

    # -- set up task + patient-level splits --
    samples = base_dataset.set_task(DiagnosisPredictionMIMIC4())
    print(f"got {len(samples)} samples total")

    train_ds, val_ds, test_ds = split_by_patient(samples, [0.8, 0.1, 0.1], seed=SEED)
    print(f"split: train={len(train_ds)} val={len(val_ds)} test={len(test_ds)}")

    train_loader = get_dataloader(train_ds, batch_size=32, shuffle=True)
    val_loader = get_dataloader(val_ds, batch_size=32, shuffle=False)
    test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False)

    # default hyperparams (from paper)
    defaults = dict(
        embedding_dim=32,
        hgnn_dim=64,
        hgnn_layers=2,
        num_tp=5,
        hidden_dim=64,
        num_heads=8,
        dropout=0.1,
    )

    # Ordered ablation schedule: (results key, display name, overrides).
    # Runs in the same order as the four ablation sections of the study.
    schedule = []
    # ablation 1: vary K (number of phenotypes)
    for k in [1, 3, 5]:
        schedule.append((f"K={k}", f"K={k}", {"num_tp": k}))
    # ablation 2: vary hgnn layers
    for n in [0, 1, 2]:
        schedule.append((f"hgnn={n}", f"HGNN layers={n}", {"hgnn_layers": n}))
    # ablation 3: loss components
    schedule += [
        ("no auxiliary loss", "no auxiliary loss",
         dict(fidelity_weight=0, distinct_weight=0, alpha_weight=0)),
        ("no fidelity", "no fidelity",
         dict(fidelity_weight=0, distinct_weight=0.01, alpha_weight=0.01)),
        ("no distinct", "no distinct",
         dict(fidelity_weight=0.1, distinct_weight=0, alpha_weight=0.01)),
        ("no alpha", "no alpha",
         dict(fidelity_weight=0.1, distinct_weight=0.01, alpha_weight=0)),
        ("full (all loss)", "full (all loss)",
         dict(fidelity_weight=0.1, distinct_weight=0.01, alpha_weight=0.01)),
    ]
    # ablation 4 (extension): gumbel-softmax temperature;
    # lower temp = more discrete selections, higher = more exploration
    for temp in [0.5, 1.0, 2.0]:
        schedule.append((f"temp={temp}", f"temperature={temp}", {"temperature": temp}))

    results = {}
    for tag, display, overrides in schedule:
        results[tag] = run_one(
            samples,
            train_loader,
            val_loader,
            test_loader,
            name=display,
            **{**defaults, **overrides},
        )

    # -- print summary table --
    print(f"\n{'='*66}")
    print("ABLATION RESULTS")
    print(f"{'='*66}")
    print(f"{'config':<20} {'jaccard':>10} {'f1':>10} {'pr_auc':>10} {'roc_auc':>10}")
    print("-" * 76)
    for tag, metrics in results.items():
        j = metrics.get("jaccard_samples", 0)
        f = metrics.get("f1_samples", 0)
        p = metrics.get("pr_auc_samples", 0)
        a = metrics.get("roc_auc_samples", 0)
        print(f"{tag:<20} {j:>10.4f} {f:>10.4f} {p:>10.4f} {a:>10.4f}")
    print("=" * 76)
+""" + +from typing import Dict, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from pyhealth.datasets import SampleDataset +from pyhealth.models import BaseModel + + +def _scatter_mean(src: torch.Tensor, index: torch.Tensor, num: int) -> torch.Tensor: + """Average vectors by index.""" + out = torch.zeros(num, src.shape[1], device=src.device) + count = torch.zeros(num, 1, device=src.device) + idx = index.unsqueeze(1).expand_as(src) + out.scatter_add_(0, idx, src) + count.scatter_add_( + 0, index.unsqueeze(1), torch.ones(index.shape[0], 1, device=src.device) + ) + return out / count.clamp(min=1) + + +def _scatter_sum(src: torch.Tensor, index: torch.Tensor, num: int) -> torch.Tensor: + """Sum vectors by index.""" + out = torch.zeros(num, src.shape[1], device=src.device) + out.scatter_add_(0, index.unsqueeze(1).expand_as(src), src) + return out + + +class UniGINConv(nn.Module): + """ + One layer of hypergraph message passing (UniGIN). + + Two-stage aggregation: + 1. Node -> Hyperedge: mean of node features + 2. Hyperedge -> Node: sum of hyperedge features + Then a GIN-style learnable self-loop. + + Args: + in_dim: Input feature dimension. + out_dim: Output feature dimension. + """ + + def __init__(self, in_dim: int, out_dim: int): + super().__init__() + self.W = nn.Linear(in_dim, out_dim) + self.eps = nn.Parameter(torch.tensor(0.1)) + + def forward( + self, X: torch.Tensor, V: torch.Tensor, E: torch.Tensor + ) -> torch.Tensor: + """ + Args: + X: Node features (num_nodes, in_dim). + V: Node indices (COO format). + E: Hyperedge indices in COO format. 
+ """ + num_nodes = X.shape[0] + if E.numel() == 0: + return self.W((1 + self.eps) * X) + num_edges = E.max().item() + 1 + + # nodes -> hyperedges (mean) + edge_emb = _scatter_mean(X[V], E, num_edges) + + # hyperedges -> nodes (sum) + node_msg = _scatter_sum(edge_emb[E], V, num_nodes) + + return self.W((1 + self.eps) * X + node_msg) + + +class PhenotypeExtractor(nn.Module): + """ + Extracts the temporal phenotype from a patient. + + Three steps: + 1. Score each (code, visit) pair for inclusion probability + 2. Add false negatives + 3. Gumbel-Softmax + + Args: + emb_dim: Code embedding dimension. + temperature: Gumbel-Softmax temperature. + add_ratio: Fraction of false-negative connections. + """ + + def __init__(self, emb_dim: int, temperature: float = 1.0, add_ratio: float = 0.1): + super().__init__() + self.temperature = temperature + self.add_ratio = add_ratio + + # Score how likely a code belongs to a visit for this phenotype + self.scorer = nn.Sequential( + nn.Linear(emb_dim * 2, 64), + nn.ReLU(), + nn.Linear(64, 1), + ) + + def _add_false_negatives( + self, X: torch.Tensor, H: torch.Tensor, V: torch.Tensor, E: torch.Tensor + ) -> torch.Tensor: + """ + Find codes that are probably missing from visits and add them, + by using cosine similarity. + + Args: + X: Node features (num_nodes, in_dim). + H: Incidence matrix (num_nodes, num_visits) + V: Node indices (COO format). + E: Hyperedge indices in COO format. 
+ """ + num_edges = H.shape[1] + # Mean code embeddings in a visit + visit_emb = _scatter_mean(X[V], E, num_edges) + # Cosine similarity between every (code, visit) pairs + X_norm = F.normalize(X, dim=-1) + visit_norm = F.normalize(visit_emb, dim=-1) + sim = X_norm @ visit_norm.T + sim[H > 0] = -1e16 + + # Add the top-k most similar missing (code, visit) pairs + num_to_add = max(1, int(self.add_ratio * V.shape[0])) + num_to_add = min(num_to_add, sim.numel()) + _, flat_idx = torch.topk(sim.flatten(), num_to_add) + rows = flat_idx // sim.shape[1] + cols = flat_idx % sim.shape[1] + + enriched = H.clone() + enriched[rows, cols] = 1.0 + return enriched + + def _score_pairs( + self, X: torch.Tensor, V: torch.Tensor, E: torch.Tensor, num_edges: int + ) -> torch.Tensor: + """Score every (code, visit) pair for phenotype inclusion.""" + visit_emb = _scatter_mean(X[V], E, num_edges) + + code_rep = X.unsqueeze(1).expand(-1, num_edges, -1) + visit_rep = visit_emb.unsqueeze(0).expand(X.shape[0], -1, -1) + pair_feat = torch.cat([code_rep, visit_rep], dim=-1) + + return torch.sigmoid(self.scorer(pair_feat).squeeze(-1)) + + def _gumbel_sample(self, probs: torch.Tensor) -> torch.Tensor: + """Gumbel-Softmax.""" + # Gumbel noise + u = torch.rand_like(probs).clamp(1e-16, 1 - 1e-16) + gumbel = torch.log(u) - torch.log(1 - u) + logit = torch.log(probs.clamp(1e-16) / (1 - probs).clamp(1e-16)) + soft = torch.sigmoid((logit + gumbel) / self.temperature) + hard = (soft > 0.5).float() + return hard - soft.detach() + soft + + def forward( + self, X: torch.Tensor, H: torch.Tensor, V: torch.Tensor, E: torch.Tensor + ) -> torch.Tensor: + """ + Extract one phenotype sub-hypergraph. + + Args: + X: Node features (num_nodes, in_dim). + H: Incidence matrix (num_nodes, num_visits) + V: Node indices (COO format). + E: Hyperedge indices in COO format. + + Returns: + Phenotype incidence matrix (num_codes, num_visits). 
+ """ + # Step 1: Add potentially missing connections + enriched_H = self._add_false_negatives(X, H, V, E) + + # Step 2: Score each (code, visit) for this phenotype + probs = self._score_pairs(X, V, E, H.shape[1]) + + # Step 3: Sample binary mask + mask = self._gumbel_sample(probs) + + return enriched_H * mask + + +class PhenotypeAggregator(nn.Module): + """ + Put a phenotype sub-hypergraph into a vector. + + Args: + emb_dim: Input code embedding dimension. + hidden_dim: GRU hidden state dimension. + """ + + def __init__(self, emb_dim: int, hidden_dim: int): + super().__init__() + self.gru = nn.GRU(emb_dim, hidden_dim, batch_first=False) + self.attn = nn.Linear(hidden_dim, 1, bias=False) + + def forward(self, X: torch.Tensor, H: torch.Tensor) -> torch.Tensor: + """Args: + X: Code embeddings (num_codes, emb_dim). + H: Phenotype incidence matrix (num_codes, num_visits). + + Returns: + Phenotype embedding vector (hidden_dim,). + """ + # Weighted sum of code embeddings per visit + visit_emb = H.T.float() @ X # (num_visits, emb_dim) + + # GRU captures time patterns across visits + states, _ = self.gru(visit_emb) # (num_visits, hidden_dim) + + # Attention pooling: weight each visit's importance + weights = F.softmax(self.attn(states).squeeze(-1), dim=0) + return (weights.unsqueeze(-1) * states).sum(dim=0) # (hidden_dim,) + + +class Decoder(nn.Module): + """ + Reconstructs the original incidence matrix from phenotype embeddings. + + Args: + hidden_dim: Phenotype embedding dimension. + num_tp: Number of temporal phenotypes. + emb_dim: Code embedding dimension. + num_codes: Vocabulary size. 
+ """ + + def __init__(self, hidden_dim: int, num_tp: int, emb_dim: int, num_codes: int): + super().__init__() + self.W_context = nn.Linear(hidden_dim * num_tp, hidden_dim) + self.gru = nn.GRU(emb_dim, hidden_dim) + self.W_output = nn.Linear(hidden_dim, num_codes) + self.num_codes = num_codes + + def forward( + self, + phenotype_embs: torch.Tensor, + num_visits: int, + H: torch.Tensor, + X: torch.Tensor, + ) -> torch.Tensor: + """Args: + phenotype_embs: Concatenated phenotype vectors (num_tp, hidden_dim). + num_visits: Number of visits to reconstruct. + H: Original incidence matrix (num_codes, num_visits). + X: Code embedding table (num_codes, emb_dim). + + Returns: + Reconstructed incidence matrix (num_codes, num_visits). + """ + hidden = self.W_context(phenotype_embs.reshape(-1)) + hidden = hidden.unsqueeze(0).unsqueeze(0) # (1, 1, hidden_dim) + + result = [] + prev_codes = torch.zeros(self.num_codes, device=X.device) + + for t in range(num_visits): + embed_input = F.relu(prev_codes @ X).unsqueeze(0).unsqueeze(0) + out, hidden = self.gru(embed_input, hidden) + pred = torch.sigmoid(self.W_output(out.squeeze(0).squeeze(0))) + result.append(pred) + # Teacher forcing during training; use own predictions at eval + if self.training: + prev_codes = H.T[t] + else: + prev_codes = pred.detach() + + return torch.stack(result, dim=1) # (num_codes, num_visits) + + +class Classifier(nn.Module): + """ + Predicts diagnoses from K phenotype embeddings. + + Args: + hidden_dim: Phenotype embedding dimension. + num_codes: Number of diagnosis codes to predict. + num_tp: Number of temporal phenotypes. + num_heads: Self-attention heads. 
+ """ + + def __init__( + self, hidden_dim: int, num_codes: int, num_tp: int, num_heads: int = 4 + ): + super().__init__() + self.num_tp = num_tp + if num_tp > 1: + self.self_attn = nn.MultiheadAttention(hidden_dim, num_heads) + self.W_importance = nn.Linear(hidden_dim, 1, bias=False) + self.predict = nn.Linear(hidden_dim, num_codes) + + def forward( + self, phenotype_embs: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Args: + phenotype_embs: (batch, num_tp, hidden_dim) or (batch, hidden_dim). + + Returns: + pred: Predicted probabilities (batch, num_codes). + logit: Pre-sigmoid logits (batch, num_codes). + alpha: Phenotype importance weights (batch, num_tp). + """ + if self.num_tp > 1: + # nn.MultiheadAttention expects (seq_len, batch, dim) + x = phenotype_embs.transpose(0, 1) # (num_tp, batch, dim) + attended, _ = self.self_attn(x, x, x) + attended = attended.transpose(0, 1) # (batch, num_tp, dim) + + # Each phenotype's importance weight + alpha = F.softmax(self.W_importance(attended).squeeze(-1), dim=-1) + + # Weighted combination of attended phenotype embeddings + combined = (attended * alpha.unsqueeze(-1)).sum(dim=-2) # (batch, dim) + logit = self.predict(combined) + pred = torch.sigmoid(logit) + return pred, logit, alpha + else: + logit = self.predict(phenotype_embs.squeeze(1)) + pred = torch.sigmoid(logit) + alpha = torch.ones(phenotype_embs.shape[0], 1, device=pred.device) + return pred, logit, alpha + + +class SHy(BaseModel): + """ + SHy: Self-Explaining Hypergraph Neural Network. + + Pipeline: + 1. Embed diagnosis codes + 2. HGNN message passing personalizes embeddings per patient + 3. Extract K phenotype sub-hypergraphs + 4. GRU + attention put phenotype to a vector + 5. Decoder reconstructs original hypergraph + 6. Classifier predicts next-visit diagnoses + + Note: + This implementation processes samples sequentially in forward. + For large batches or datasets, consider batching via + block-diagonal hypergraph construction. 
class SHy(BaseModel):
    """
    SHy: Self-Explaining Hypergraph Neural Network.

    Pipeline:
        1. Embed diagnosis codes
        2. HGNN message passing personalizes embeddings per patient
        3. Extract K temporal-phenotype sub-hypergraphs
        4. GRU + attention condense each phenotype into a vector
        5. Decoder reconstructs the original hypergraph (fidelity)
        6. Classifier predicts next-visit diagnoses

    Note:
        This implementation processes samples sequentially in forward.
        For large batches or datasets, consider batching via
        block-diagonal hypergraph construction.

    Args:
        dataset: PyHealth SampleDataset.
        embedding_dim: Code embedding dimension. Default 32.
        hgnn_dim: HGNN output dimension. Default 64.
        hgnn_layers: Number of HGNN layers. Default 2.
        num_tp: Number of temporal phenotypes K. Default 5.
        hidden_dim: GRU/aggregator hidden dimension. Default 64.
        temperature: Gumbel-Softmax temperature. Default 1.0.
        add_ratio: False-negative addition ratio. Default 0.1.
        num_heads: Self-attention heads in classifier. Default 4.
        dropout: Dropout probability. Default 0.1.
        fidelity_weight: Weight for reconstruction loss. Default 0.1.
        distinct_weight: Weight for phenotype overlap penalty. Default 0.01.
        alpha_weight: Weight for attention diversity. Default 0.01.

    Examples:
        >>> from pyhealth.datasets import create_sample_dataset, get_dataloader
        >>> samples = [
        ...     {
        ...         "patient_id": "p0",
        ...         "diagnoses_hist": [["d1", "d2"], ["d3", "d6"]],
        ...         "diagnoses": ["d1", "d2", "d3"],
        ...     },
        ...     {
        ...         "patient_id": "p1",
        ...         "diagnoses_hist": [["d1", "d3"], ["d5"]],
        ...         "diagnoses": ["d2", "d3"],
        ...     },
        ... ]
        >>> dataset = create_sample_dataset(
        ...     samples=samples,
        ...     input_schema={"diagnoses_hist": "nested_sequence"},
        ...     output_schema={"diagnoses": "multilabel"},
        ...     dataset_name="test",
        ... )
        >>> model = SHy(dataset=dataset)
        >>> loader = get_dataloader(dataset, batch_size=2, shuffle=True)
        >>> batch = next(iter(loader))
        >>> out = model(**batch)
        >>> out["loss"].shape
        torch.Size([])
        >>> out["y_prob"].shape[1] == model.output_size
        True
    """

    def __init__(
        self,
        dataset: SampleDataset,
        embedding_dim: int = 32,
        hgnn_dim: int = 64,
        hgnn_layers: int = 2,
        num_tp: int = 5,
        hidden_dim: int = 64,
        temperature: float = 1.0,
        add_ratio: float = 0.1,
        num_heads: int = 4,
        dropout: float = 0.1,
        fidelity_weight: float = 0.1,
        distinct_weight: float = 0.01,
        alpha_weight: float = 0.01,
    ):
        super(SHy, self).__init__(dataset=dataset)

        if len(self.label_keys) != 1:
            raise ValueError("SHy supports exactly one label key (multilabel)")
        if len(self.feature_keys) != 1:
            raise ValueError("SHy expects exactly one feature key (nested_sequence)")

        self.label_key = self.label_keys[0]
        self.feature_key = self.feature_keys[0]

        self.num_tp = num_tp
        self.hidden_dim = hidden_dim
        self.fidelity_weight = fidelity_weight
        self.distinct_weight = distinct_weight
        self.alpha_weight = alpha_weight

        processor = dataset.input_processors[self.feature_key]
        self.vocab_size = processor.vocab_size()
        self.output_size = self.get_output_size()

        # 1. Code embedding
        self.code_embedding = nn.Embedding(self.vocab_size, embedding_dim)

        # 2. Message passing (optional stack of UniGIN layers)
        self.hgnn_layers_n = hgnn_layers
        if hgnn_layers > 0:
            dims = [embedding_dim] + [hgnn_dim] * hgnn_layers
            self.hgnn_convs = nn.ModuleList(
                UniGINConv(dims[i], dims[i + 1]) for i in range(hgnn_layers)
            )
            self.hgnn_out = nn.Linear(hgnn_dim, hgnn_dim)
        else:
            self.hgnn_out = nn.Linear(embedding_dim, hgnn_dim)
        self.act = nn.LeakyReLU()
        self.dropout = nn.Dropout(dropout)

        # 3. K phenotype extractors
        self.extractors = nn.ModuleList(
            PhenotypeExtractor(hgnn_dim, temperature, add_ratio)
            for _ in range(num_tp)
        )

        # 4. Aggregator: sub-hypergraph -> vector
        self.aggregator = PhenotypeAggregator(hgnn_dim, hidden_dim)

        # 5. Decoder: reconstruction
        self.decoder = Decoder(hidden_dim, num_tp, embedding_dim, self.vocab_size)

        # 6. Classifier: phenotype embeddings -> diagnosis prediction
        self.classifier = Classifier(hidden_dim, self.output_size, num_tp, num_heads)

    def _build_incidence_matrix(self, codes: torch.Tensor) -> torch.Tensor:
        """
        Convert a padded code tensor to a binary incidence matrix.

        Args:
            codes: (num_visits, max_codes_per_visit); index 0 is treated
                as padding (assumes the processor reserves 0 — TODO confirm).

        Returns:
            H: (vocab_size, num_visits), binary.
        """
        num_visits = codes.shape[0]
        H = torch.zeros(self.vocab_size, num_visits, device=codes.device)
        visit_idx = torch.arange(num_visits, device=codes.device)
        visit_idx = visit_idx.unsqueeze(1).expand_as(codes)
        mask = codes > 0
        rows = codes[mask].clamp(max=self.vocab_size - 1).long()
        cols = visit_idx[mask]
        H[rows, cols] = 1.0
        return H

    def _run_hgnn(
        self, X: torch.Tensor, V: torch.Tensor, E: torch.Tensor
    ) -> torch.Tensor:
        """Run HGNN message passing to personalize code embeddings."""
        if self.hgnn_layers_n > 0:
            for conv in self.hgnn_convs:
                X = self.dropout(self.act(conv(X, V, E)))
        return self.act(self.hgnn_out(X))

    def _encode_patient(self, X: torch.Tensor, H: torch.Tensor):
        """
        Encode one patient: HGNN -> extract K phenotypes.

        Returns:
            phenotype_matrices: list of K incidence matrices (C, V).
            phenotype_embs: (K, hidden_dim), or (hidden_dim,) if K == 1.
        """
        # COO indices from the incidence matrix
        nz = torch.nonzero(H)
        V, E = nz[:, 0], nz[:, 1]

        # Personalize embeddings via HGNN
        X_personal = self._run_hgnn(X, V, E)

        # Extract K phenotypes and embed each one
        tp_matrices = [ext(X_personal, H, V, E) for ext in self.extractors]
        tp_embs = [self.aggregator(X_personal, tp) for tp in tp_matrices]

        if self.num_tp > 1:
            return tp_matrices, torch.stack(tp_embs)
        return tp_matrices, tp_embs[0]

    def _compute_loss(self, pred, y_true, tp_list, recon_list, H_list, alphas):
        """
        SHy loss.
        L = L_pred + eps*L_fidelity + eta*L_distinct - omega*L_alpha
        """
        # 1. Prediction loss: weighted BCE (upweight rare positives)
        num_pos = y_true.sum(dim=1, keepdim=True).clamp(min=1)
        num_neg = (y_true.shape[1] - num_pos).clamp(min=1)
        pos_weight = (num_neg / num_pos).expand_as(y_true)
        weight = torch.where(y_true > 0.5, pos_weight, torch.ones_like(y_true))
        loss = F.binary_cross_entropy(
            pred.clamp(1e-9, 1 - 1e-9), y_true.float(), weight=weight
        )

        # 2. Fidelity loss: reconstruction (reweighted like prediction loss)
        if recon_list:
            fid_losses = []
            for r, h in zip(recon_list, H_list):
                h_f = h.float()
                n_pos = h_f.sum().clamp(min=1.0)
                n_neg = (h_f.numel() - n_pos).clamp(min=1.0)
                pos_w = n_neg / n_pos
                w = torch.where(h_f > 0.5, pos_w, torch.ones_like(h_f))
                fid_losses.append(
                    F.binary_cross_entropy(
                        r.clamp(1e-9, 1 - 1e-9), h_f, weight=w
                    )
                )
            fidelity = sum(fid_losses) / len(fid_losses)
            loss = loss + self.fidelity_weight * fidelity

        # 3. Distinctness loss: for each patient, the K phenotype columns
        #    of every visit should be near-orthonormal.
        #    FIX: the previous version divided the *running cross-patient
        #    total* by a single patient's visit count; normalize each
        #    patient's term by its own visit count before accumulating,
        #    then average over patients.
        if self.num_tp > 1 and tp_list:
            eye = torch.eye(self.num_tp, device=pred.device)
            distinct = torch.tensor(0.0, device=pred.device)
            for tps in tp_list:
                stacked = torch.stack(tps, dim=-1)  # (C, V, K)
                num_visits = stacked.shape[1]
                patient_term = torch.tensor(0.0, device=pred.device)
                for v in range(num_visits):
                    col = stacked[:, v, :]  # (C, K)
                    patient_term = patient_term + torch.norm(eye - col.T @ col)
                distinct = distinct + patient_term / max(num_visits, 1)
            loss = loss + self.distinct_weight * distinct / len(tp_list)

        # 4. Alpha diversity: encourage phenotype attention spread.
        #    FIX: only applied for K > 1 — torch.var over a size-1 dim with
        #    the default Bessel correction is NaN, which poisoned the loss
        #    whenever num_tp == 1.
        if self.num_tp > 1:
            alpha_div = torch.mean(
                torch.sqrt(torch.var(alphas, dim=1).clamp(min=1e-9))
            )
            loss = loss - self.alpha_weight * alpha_div

        return loss

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """
        Returns:
            Dict with keys: loss, y_prob, y_true, logit.
        """
        feature_data = kwargs[self.feature_key]
        codes_batch = (
            feature_data[0] if isinstance(feature_data, tuple) else feature_data
        )
        batch_size = codes_batch.shape[0]

        # Shared embedding table for the whole vocabulary.
        X = self.code_embedding(torch.arange(self.vocab_size, device=self.device))

        tp_list, recon_list, H_list, latent_list = [], [], [], []
        valid_mask = []

        for i in range(batch_size):
            H = self._build_incidence_matrix(codes_batch[i]).to(self.device)

            # Empty history: contribute a zero latent and skip the loss terms.
            if H.sum() == 0:
                zero = (
                    torch.zeros(self.num_tp, self.hidden_dim, device=self.device)
                    if self.num_tp > 1
                    else torch.zeros(self.hidden_dim, device=self.device)
                )
                latent_list.append(zero)
                valid_mask.append(False)
                continue

            tp_mats, tp_embs = self._encode_patient(X, H)
            tp_list.append(tp_mats)
            latent_list.append(tp_embs)
            H_list.append(H)
            recon_list.append(self.decoder(tp_embs, H.shape[1], H, X))
            valid_mask.append(True)

        # Classify: phenotype embeddings -> diagnosis prediction.
        # stacked is (batch, K, hidden) for K > 1, (batch, hidden) for K == 1;
        # the classifier handles both layouts.
        stacked = torch.stack(latent_list)
        pred, logit, alphas = self.classifier(stacked)

        # Labels
        y_true = kwargs[self.label_key].to(self.device).float()

        # Exclude empty-history samples from loss computation
        valid = torch.tensor(valid_mask, device=self.device)
        if valid.any():
            loss = self._compute_loss(
                pred[valid], y_true[valid],
                tp_list, recon_list, H_list, alphas[valid],
            )
        else:
            loss = torch.tensor(0.0, device=self.device, requires_grad=True)

        return {"loss": loss, "y_prob": pred, "y_true": y_true, "logit": logit}
kwargs[self.label_key].to(self.device).float() + + # Exclude empty-history samples from loss computation + valid = torch.tensor(valid_mask, device=self.device) + if valid.any(): + loss = self._compute_loss( + pred[valid], y_true[valid], + tp_list, recon_list, H_list, alphas[valid], + ) + else: + loss = torch.tensor(0.0, device=self.device, requires_grad=True) + + return {"loss": loss, "y_prob": pred, "y_true": y_true, "logit": logit} diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index a32618f9c..26b7288e3 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -67,3 +67,7 @@ VariantClassificationClinVar, ) from .patient_linkage_mimic3 import PatientLinkageMIMIC3Task +from .diagnosis_prediction import ( + DiagnosisPredictionMIMIC3, + DiagnosisPredictionMIMIC4, +) diff --git a/pyhealth/tasks/diagnosis_prediction.py b/pyhealth/tasks/diagnosis_prediction.py new file mode 100644 index 000000000..180a02073 --- /dev/null +++ b/pyhealth/tasks/diagnosis_prediction.py @@ -0,0 +1,181 @@ +""" +PyHealth task for diagnosis prediction using the MIMIC-III +and MIMIC-IV datasets. The task aims to predict diagnosis +of next visit for a patient based on historical records. + +Dataset Citation: +MIMIC-III: +Johnson, Alistair, et al. "MIMIC-III Clinical Database" (version 1.4). +PhysioNet (2016). RRID:SCR_007345. https://doi.org/10.13026/C2XW26 + +MIMIC-IV: +Johnson, Alistair, et al. "MIMIC-IV" (version 3.1). +PhysioNet (2024). RRID:SCR_007345. +https://doi.org/10.13026/kpb9-mt58 + +Paper Citation: +Yu, Leisheng, et al. "Self-Explaining Hypergraph Neural Networks +for Diagnosis Prediction." arXiv preprint arXiv:2502.10689 (2025). + +""" + +from typing import Any, Dict, List + +from pyhealth.tasks import BaseTask + + +def _extract_visit_diagnoses(patient, admissions, code_attr: str): + """Extract per-visit diagnosis codes from a patient's admissions. + + Args: + patient: A Patient object that supports get_events method. 
+ + admissions: List of admission events. + code_attr: Attribute name on the event object for the ICD code. + + Returns: + A list of (hadm_id, codes) tuples for visits that have at + least one diagnosis code. + """ + visits = [] + for admission in admissions: + diagnoses = patient.get_events( + event_type="diagnoses_icd", + filters=[("hadm_id", "==", admission.hadm_id)], + ) + codes = [ + getattr(event, code_attr) + for event in diagnoses + if getattr(event, code_attr, None) is not None + ] + if codes: + visits.append((admission.hadm_id, codes)) + return visits + + +class DiagnosisPredictionMIMIC3(BaseTask): + """A PyHealth task class for diagnosis prediction using the MIMIC-III dataset, + which is formulated as a multilabel classification problem. + + Attributes: + task_name (str): Name of the task. + input_schema (Dict[str, str]): The schema for model input: + - diagnoses_hist: Nested sequence of diagnosis codes across + historical visits. + output_schema (Dict[str, str]): The schema for model output: + - diagnoses: Multilabel set of diagnosis codes for next visit. + + Examples: + >>> from pyhealth.datasets import MIMIC3Dataset + >>> from pyhealth.tasks import DiagnosisPredictionMIMIC3 + >>> dataset = MIMIC3Dataset( + ... root="/path/to/mimic-iii/1.4", + ... tables=["diagnoses_icd"], + ... ) + >>> task = DiagnosisPredictionMIMIC3() + >>> samples = dataset.set_task(task) + """ + + task_name: str = "DiagnosisPredictionMIMIC3" + input_schema: Dict[str, str] = {"diagnoses_hist": "nested_sequence"} + output_schema: Dict[str, str] = {"diagnoses": "multilabel"} + + def __call__(self, patient: Any) -> List[Dict[str, Any]]: + """Processes a single patient for the diagnosis prediction task. + + Args: + patient: A Patient object that supports get_events method. + + Returns: + A list of sample dictionaries, each containing patient_id, + visit_id, diagnoses_hist, and diagnoses, or an empty list if the + patient has fewer than two valid visits. 
+ """ + admissions = patient.get_events(event_type="admissions") + if len(admissions) < 2: + return [] + + visits = _extract_visit_diagnoses(patient, admissions, "icd9_code") + + if len(visits) < 2: + return [] + + samples = [] + history = [] + for t in range(len(visits)): + visit_id, codes = visits[t] + if t > 0: + samples.append( + { + "patient_id": patient.patient_id, + "visit_id": visit_id, + "diagnoses_hist": [h for h in history], + "diagnoses": codes, + } + ) + history.append(codes) + + return samples + + +class DiagnosisPredictionMIMIC4(BaseTask): + """A PyHealth task class for diagnosis prediction using the MIMIC-IV dataset, + which is formulated as a multilabel classification problem. + + Attributes: + task_name (str): Name of the task. + input_schema (Dict[str, str]): The schema for model input: + - diagnoses_hist: Nested sequence of diagnosis codes across + historical visits. + output_schema (Dict[str, str]): The schema for model output: + - diagnoses: Multilabel set of diagnosis codes for next visit. + + Examples: + >>> from pyhealth.datasets import MIMIC4EHRDataset + >>> from pyhealth.tasks import DiagnosisPredictionMIMIC4 + >>> dataset = MIMIC4EHRDataset( + ... root="/path/to/mimic-iv/3.1", + ... tables=["diagnoses_icd"], + ... ) + >>> task = DiagnosisPredictionMIMIC4() + >>> samples = dataset.set_task(task) + """ + + task_name: str = "DiagnosisPredictionMIMIC4" + input_schema: Dict[str, str] = {"diagnoses_hist": "nested_sequence"} + output_schema: Dict[str, str] = {"diagnoses": "multilabel"} + + def __call__(self, patient: Any) -> List[Dict[str, Any]]: + """Processes a single patient for the diagnosis prediction task. + + Args: + patient: A Patient object that supports get_events method. + + Returns: + A list containing sample dictionaries, which contains patient_id, + visit_id, diagnoses_hist, and diagnoses, or an empty list if the + patient has fewer than two valid visits. 
+ """ + admissions = patient.get_events(event_type="admissions") + if len(admissions) < 2: + return [] + + visits = _extract_visit_diagnoses(patient, admissions, "icd_code") + if len(visits) < 2: + return [] + + samples = [] + history = [] + for t in range(len(visits)): + visit_id, codes = visits[t] + if t > 0: + samples.append( + { + "patient_id": patient.patient_id, + "visit_id": visit_id, + "diagnoses_hist": [h for h in history], + "diagnoses": codes, + } + ) + history.append(codes) + + return samples diff --git a/tests/core/test_shy.py b/tests/core/test_shy.py new file mode 100644 index 000000000..0ff4623b3 --- /dev/null +++ b/tests/core/test_shy.py @@ -0,0 +1,293 @@ +import unittest +import torch + +from pyhealth.datasets import create_sample_dataset, get_dataloader +from pyhealth.models import SHy + + +class TestSHyModel(unittest.TestCase): + """Tests for the SHy model.""" + + def setUp(self): + self.samples = [ + { + "patient_id": "p0", + "diagnoses_hist": [["d1", "d2", "d3"], ["d1", "d4"]], + "diagnoses": ["d1", "d2"], + }, + { + "patient_id": "p1", + "diagnoses_hist": [["d2", "d3"], ["d4", "d5"], ["d1", "d6"]], + "diagnoses": ["d3", "d4", "d5"], + }, + { + "patient_id": "p2", + "diagnoses_hist": [["d1", "d6"]], + "diagnoses": ["d2", "d6"], + }, + { + "patient_id": "p3", + "diagnoses_hist": [["d3", "d4"], ["d5", "d6"]], + "diagnoses": ["d1"], + }, + ] + + self.dataset = create_sample_dataset( + samples=self.samples, + input_schema={"diagnoses_hist": "nested_sequence"}, + output_schema={"diagnoses": "multilabel"}, + dataset_name="test_shy", + ) + + self.model = SHy( + dataset=self.dataset, + embedding_dim=16, + hgnn_dim=16, + hgnn_layers=1, + num_tp=2, + hidden_dim=16, + num_heads=2, + dropout=0.0, + ) + + def test_initialization(self): + """Check model sets up the right keys and params.""" + self.assertIsInstance(self.model, SHy) + self.assertEqual(self.model.feature_key, "diagnoses_hist") + self.assertEqual(self.model.label_key, "diagnoses") + 
self.assertEqual(self.model.num_tp, 2) + + def test_forward_output_keys(self): + """Forward pass should return loss, y_prob, y_true, logit.""" + loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + batch = next(iter(loader)) + + with torch.no_grad(): + out = self.model(**batch) + + self.assertIn("loss", out) + self.assertIn("y_prob", out) + self.assertIn("y_true", out) + self.assertIn("logit", out) + # loss should be a scalar + self.assertEqual(out["loss"].dim(), 0) + + def test_output_shapes(self): + """y_prob and y_true should match (batch, num_labels).""" + loader = get_dataloader(self.dataset, batch_size=4, shuffle=False) + batch = next(iter(loader)) + + with torch.no_grad(): + out = self.model(**batch) + + self.assertEqual(out["y_prob"].shape[1], self.model.output_size) + self.assertEqual(out["y_true"].shape[1], self.model.output_size) + + def test_backward(self): + """Make sure gradients flow through the model.""" + loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + batch = next(iter(loader)) + + out = self.model(**batch) + out["loss"].backward() + + # at least one param should have a gradient + got_grad = False + for p in self.model.parameters(): + if p.requires_grad and p.grad is not None: + got_grad = True + break + self.assertTrue(got_grad, "backward didn't produce any gradients") + + def test_probabilities_in_range(self): + """Predicted probs should all be between 0 and 1.""" + loader = get_dataloader(self.dataset, batch_size=4, shuffle=False) + batch = next(iter(loader)) + + with torch.no_grad(): + out = self.model(**batch) + + self.assertTrue(torch.all(out["y_prob"] >= 0)) + self.assertTrue(torch.all(out["y_prob"] <= 1)) + + def test_loss_not_nan(self): + """Loss shouldn't be NaN on normal input.""" + loader = get_dataloader(self.dataset, batch_size=4, shuffle=False) + batch = next(iter(loader)) + + with torch.no_grad(): + out = self.model(**batch) + + self.assertFalse(torch.isnan(out["loss"])) + + def 
test_single_phenotype(self): + """num_tp=1 should still work (no distinctness loss).""" + model = SHy( + dataset=self.dataset, + embedding_dim=16, + hgnn_dim=16, + hgnn_layers=1, + num_tp=1, + hidden_dim=16, + num_heads=2, + ) + loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + batch = next(iter(loader)) + + with torch.no_grad(): + out = model(**batch) + + self.assertFalse(torch.isnan(out["loss"])) + self.assertEqual(out["y_prob"].shape[0], 2) + + def test_no_hgnn(self): + """hgnn_layers=0 should fall back to linear projection.""" + model = SHy( + dataset=self.dataset, + embedding_dim=16, + hgnn_dim=16, + hgnn_layers=0, + num_tp=2, + hidden_dim=16, + num_heads=2, + ) + loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + batch = next(iter(loader)) + + with torch.no_grad(): + out = model(**batch) + + self.assertIn("loss", out) + self.assertFalse(torch.isnan(out["loss"])) + + def test_custom_hyperparams(self): + """Different embedding/hidden sizes should still run.""" + model = SHy( + dataset=self.dataset, + embedding_dim=8, + hgnn_dim=32, + hgnn_layers=2, + num_tp=3, + hidden_dim=32, + num_heads=4, + dropout=0.2, + ) + loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + batch = next(iter(loader)) + + with torch.no_grad(): + out = model(**batch) + + self.assertIn("loss", out) + self.assertIn("y_prob", out) + + def test_incidence_matrix(self): + """_build_incidence_matrix should give a valid binary matrix.""" + # fake padded codes: 2 visits, codes padded to length 3 + codes = torch.tensor([[1, 2, 0], [3, 0, 0]]) + H = self.model._build_incidence_matrix(codes) + + # rows = vocab size, cols = num visits + self.assertEqual(H.shape[0], self.model.vocab_size) + self.assertEqual(H.shape[1], 2) + + # should only contain 0s and 1s + self.assertTrue(torch.all((H == 0) | (H == 1))) + + # code 1 should be in visit 0, code 3 should be in visit 1 + self.assertEqual(H[1, 0].item(), 1.0) + self.assertEqual(H[3, 1].item(), 1.0) + # 
padding index 0 should not appear + self.assertEqual(H[0, 0].item(), 0.0) + + def test_phenotype_extractor_produces_k_subhypergraphs(self): + """For num_tp=K, the model should produce K phenotype matrices.""" + loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + batch = next(iter(loader)) + feature_data = batch[self.model.feature_key] + codes_batch = ( + feature_data[0] if isinstance(feature_data, tuple) else feature_data + ) + X = self.model.code_embedding(torch.arange(self.model.vocab_size)) + H = self.model._build_incidence_matrix(codes_batch[0]) + tp_mats, tp_embs = self.model._encode_patient(X, H) + self.assertEqual(len(tp_mats), self.model.num_tp) + for mat in tp_mats: + self.assertEqual(mat.shape, H.shape) + + def test_different_num_tp_gives_different_outputs(self): + """num_tp=1 and num_tp=3 should produce different losses.""" + torch.manual_seed(123) + model1 = SHy( + dataset=self.dataset, + embedding_dim=16, + hgnn_dim=16, + hgnn_layers=1, + num_tp=1, + hidden_dim=16, + num_heads=2, + dropout=0.0, + ) + torch.manual_seed(123) + model3 = SHy( + dataset=self.dataset, + embedding_dim=16, + hgnn_dim=16, + hgnn_layers=1, + num_tp=3, + hidden_dim=16, + num_heads=2, + dropout=0.0, + ) + loader = get_dataloader(self.dataset, batch_size=4, shuffle=False) + batch = next(iter(loader)) + with torch.no_grad(): + out1 = model1(**batch) + out3 = model3(**batch) + self.assertNotEqual(out1["loss"].item(), out3["loss"].item()) + + def test_add_false_negatives_changes_incidence(self): + """add_ratio > 0 should add entries to the incidence matrix.""" + loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + batch = next(iter(loader)) + feature_data = batch[self.model.feature_key] + codes_batch = ( + feature_data[0] if isinstance(feature_data, tuple) else feature_data + ) + X = self.model.code_embedding(torch.arange(self.model.vocab_size)) + H = self.model._build_incidence_matrix(codes_batch[0]) + nz = torch.nonzero(H) + V, E = nz[:, 0], nz[:, 1] 
+ X_personal = self.model._run_hgnn(X, V, E) + ext = self.model.extractors[0] + enriched = ext._add_false_negatives(X_personal, H, V, E) + # Enriched should have at least as many nonzeros as original + self.assertGreaterEqual(enriched.sum().item(), H.sum().item()) + + +class TestDiagnosisPredictionTask(unittest.TestCase): + """Tests for the diagnosis prediction task classes.""" + + def test_mimic3_task_schema(self): + from pyhealth.tasks import DiagnosisPredictionMIMIC3 + + task = DiagnosisPredictionMIMIC3() + self.assertEqual(task.task_name, "DiagnosisPredictionMIMIC3") + self.assertIn("diagnoses_hist", task.input_schema) + self.assertEqual(task.input_schema["diagnoses_hist"], "nested_sequence") + self.assertIn("diagnoses", task.output_schema) + self.assertEqual(task.output_schema["diagnoses"], "multilabel") + + def test_mimic4_task_schema(self): + from pyhealth.tasks import DiagnosisPredictionMIMIC4 + + task = DiagnosisPredictionMIMIC4() + self.assertEqual(task.task_name, "DiagnosisPredictionMIMIC4") + self.assertIn("diagnoses_hist", task.input_schema) + self.assertEqual(task.input_schema["diagnoses_hist"], "nested_sequence") + self.assertIn("diagnoses", task.output_schema) + self.assertEqual(task.output_schema["diagnoses"], "multilabel") + + +if __name__ == "__main__": + unittest.main()