Add SHy model and DiagnosisPrediction task #991
Lilin-Huang wants to merge 3 commits into sunlabuiuc:master from
Conversation
Implement Self-Explaining Hypergraph Neural Network (SHy) for diagnosis prediction, with MIMIC-III and MIMIC-IV task classes.

New files:
- pyhealth/models/shy.py: SHy model implementation
- pyhealth/tasks/diagnosis_prediction.py: DiagnosisPrediction tasks
- tests/core/test_shy.py: 12 unit tests with synthetic data
- examples/mimic4_diagnosis_prediction_shy.py: ablation study script
- docs/api/models/pyhealth.models.SHy.rst: model documentation
- docs/api/tasks/pyhealth.tasks.DiagnosisPrediction.rst: task documentation

Paper: Yu et al., "Self-Explaining Hypergraph Neural Networks for Diagnosis Prediction", CHIL 2025.
Hi @Lilin-Huang. SHy is a non-trivial model to implement (hypergraph message passing, Gumbel-Softmax phenotype extraction, multi-objective loss), and you have it all in working form with real ablation numbers from MIMIC-IV.

**Scalability: `forward` is sample-by-sample.**

```python
for i in range(batch_size):
    H = self._build_incidence_matrix(codes_batch[i]).to(self.device)
    ...
    tp_mats, tp_embs = self._encode_patient(X, H)
```

Each sample gets its own incidence matrix, HGNN pass, K extractors, and decoder call inside a Python loop. This will not scale beyond the 1000-patient dev subset used in your ablations. Per-patient hypergraphs make full batching nontrivial, but block-diagonal stacking (a standard trick from PyTorch Geometric) would let you batch the HGNN and phenotype extraction. Not a merge blocker, but please add a note in the class docstring documenting the expected scalability (e.g., "This implementation processes samples sequentially in `forward`. For large batches or datasets, consider batching via block-diagonal hypergraph construction").
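To make the block-diagonal suggestion concrete, here is a minimal sketch (the helper `batch_incidence` is hypothetical, not part of the PR): per-sample incidence matrices are stacked with `torch.block_diag` so a single HGNN pass can cover the whole batch, with offsets to slice each patient's nodes back out.

```python
import torch

def batch_incidence(h_list):
    """Stack per-patient incidence matrices into one block-diagonal
    matrix so one HGNN pass can process the whole batch.
    (Sketch; h_list holds per-sample H as built by the model.)"""
    H = torch.block_diag(*h_list)  # (sum n_codes_i, sum n_visits_i)
    # Cumulative node counts let us slice each patient's rows back out.
    node_offsets = torch.tensor([0] + [h.shape[0] for h in h_list]).cumsum(0)
    return H, node_offsets

# Toy usage: two patients with different numbers of codes and visits.
h1 = torch.ones(3, 2)  # 3 codes, 2 visits
h2 = torch.ones(4, 1)  # 4 codes, 1 visit
H, offsets = batch_incidence([h1, h2])
print(H.shape)  # torch.Size([7, 3])
```

Because the off-diagonal blocks are zero, message passing never mixes codes across patients, so the batched result matches the per-sample loop.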
**The `logit` output key holds probabilities, not logits.**

```python
pred = torch.sigmoid(self.predict(combined))
...
return {"loss": loss, "y_prob": pred, "y_true": y_true, "logit": pred}
```

`pred` has already been passed through the sigmoid, so downstream code expecting raw scores under `"logit"` will be misled. Consider storing the pre-sigmoid output under that key.
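A toy sketch of keeping the raw scores for the `"logit"` key (`predict` here is a stand-in linear head, not the PR's actual module):

```python
import torch

# Stand-in for the model's prediction head over a 5-code label space.
predict = torch.nn.Linear(8, 5)
combined = torch.randn(2, 8)  # stand-in for the combined patient representation

logits = predict(combined)      # raw, unbounded scores -> belongs under "logit"
y_prob = torch.sigmoid(logits)  # probabilities in (0, 1) -> belongs under "y_prob"
out = {"y_prob": y_prob, "logit": logits}
```

This keeps both views available: metrics that expect probabilities use `y_prob`, while anything that needs raw scores (e.g., temperature scaling) can recover them from `logit`.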
**Input validation uses `assert`.**

```python
assert len(self.label_keys) == 1, "SHy supports exactly one label key (multilabel)"
assert len(self.feature_keys) == 1, "SHy expects exactly one feature key (nested_sequence)"
```

Asserts are stripped when Python is run with `-O`, so these checks silently vanish in optimized mode; prefer raising `ValueError` in `__init__`.

**Empty-history samples contribute to prediction loss from zero-vector input.**

```python
if H.sum() == 0:
    zero = torch.zeros(self.num_tp, self.hidden_dim, device=self.device)
    latent_list.append(zero)
    continue
```

When a patient has an all-zero incidence matrix (possible if all codes got clipped or visits are empty), the classifier still runs on the zero vector and the loss still counts. The task filters to …

**Fidelity loss does not handle class imbalance.** Your prediction loss correctly reweights positives:

```python
num_pos = y_true.sum(dim=1, keepdim=True).clamp(min=1)
num_neg = (y_true.shape[1] - num_pos).clamp(min=1)
pos_weight = (num_neg / num_pos).expand_as(y_true)
```

But the fidelity loss is plain BCE on the incidence matrix:

```python
F.binary_cross_entropy(r.clamp(1e-9, 1 - 1e-9), h.float())
```

The incidence matrix is typically very sparse (most codes are not in a visit). Without reweighting, fidelity encourages predicting all zeros, which is the uninformative solution. Your ablation shows fidelity has a small positive weight (0.1), so this may be contained, but it is worth applying the same `pos_weight` treatment or commenting on why raw BCE is intended.

**Tests verify it runs but do not verify the core model behavior.** The 12 tests cover shapes, output keys, backward, probability range, and a few hyperparameter variations. None of them check: …
These would strengthen the PR by making it harder for future refactors to silently break the paper's core mechanism. A test that inspects the shape and sparsity of what …

**Teacher forcing in the decoder during evaluation.**

```python
# Teacher forcing
prev_codes = H.T[t]
```

The decoder uses ground-truth codes from visit `t` …

**Smaller items**
```python
>>> out = model(**batch)
>>> out["loss"]
```

The docstring example shows no expected output. Consider either removing the last line or adding a …
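On the fidelity-imbalance point above, a minimal sketch of applying the same positive-reweighting trick to the reconstruction BCE (the helper name `weighted_fidelity` and the element-wise `weight` formulation are my own, not from the PR):

```python
import torch
import torch.nn.functional as F

def weighted_fidelity(r, h, eps=1e-9):
    """Fidelity BCE with the same neg/pos reweighting as the prediction
    loss, so a sparse incidence matrix does not push the reconstruction
    toward all-zeros. r, h: one patient's reconstructed and true H."""
    num_pos = h.sum().clamp(min=1)
    num_neg = (h.numel() - num_pos).clamp(min=1)
    # Up-weight the rare 1-entries by the negative/positive ratio.
    w = torch.where(h > 0, num_neg / num_pos, torch.ones_like(h))
    return F.binary_cross_entropy(r.clamp(eps, 1 - eps), h.float(), weight=w)

# Toy example: a 6-code x 4-visit incidence matrix with only two 1-entries.
h = torch.zeros(6, 4)
h[0, 0] = 1.0
h[2, 1] = 1.0
r = torch.full((6, 4), 0.1)  # reconstruction that leans toward "all zeros"
loss = weighted_fidelity(r, h)
```

With the reweighting, the near-all-zeros reconstruction is penalized more heavily than under plain BCE, which is the intended pressure.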
Contributor
Type of Contribution
Model + Task (Option 4)
Paper
Leisheng Yu, Yanxiao Cai, Minxing Zhang, and Xia Hu.
Self-Explaining Hypergraph Neural Networks for Diagnosis Prediction.
Proceedings of Machine Learning Research (CHIL), 2025.
https://arxiv.org/abs/2502.10689
Description
Implement SHy (Self-Explaining Hypergraph Neural Network) for diagnosis prediction. The model builds a patient hypergraph from diagnosis codes, runs UniGIN message passing, extracts K temporal phenotype sub-hypergraphs via Gumbel-Softmax sampling, aggregates each phenotype with a GRU + attention, and predicts next-visit diagnoses. A multi-objective loss combines prediction BCE, fidelity (reconstruction), distinctness (phenotype overlap penalty), and alpha diversity.
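For readers unfamiliar with the extraction step, here is a minimal, self-contained illustration (not the PR's actual code) of sampling a differentiable hyperedge mask with straight-through Gumbel-Softmax; all shapes and names are made up for the example:

```python
import torch
import torch.nn.functional as F

num_edges = 5                       # hyperedges (visits) in a patient hypergraph
scores = torch.randn(num_edges, 2)  # per-edge logits for (keep, drop)

# Straight-through Gumbel-Softmax: a hard 0/1 choice in the forward pass,
# soft gradients in the backward pass; temperature tau controls sharpness.
mask = F.gumbel_softmax(scores, tau=0.5, hard=True)[:, 0]  # (num_edges,)

# Masked columns of the incidence matrix form one phenotype sub-hypergraph.
H = torch.ones(7, num_edges)  # toy incidence matrix: 7 codes x 5 visits
H_sub = H * mask
```

Repeating this K times with independent score heads yields the K temporal phenotype sub-hypergraphs, and the temperature `tau` is exactly the knob the ablation study sweeps.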
Also adds `DiagnosisPredictionMIMIC3` and `DiagnosisPredictionMIMIC4` standalone task classes that extract per-visit diagnosis histories from MIMIC-III/IV. The example scripts run ablation studies over 4 axes: number of temporal phenotypes (K), HGNN layers, loss components, and Gumbel-Softmax temperature (novel extension).
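The four loss terms described above combine as a weighted sum; a toy sketch (only the 0.1 fidelity weight appears in the review discussion, and the other weights and component values here are placeholders, not the model's defaults):

```python
import torch

# Placeholder component values for illustration.
pred_loss = torch.tensor(0.7)   # prediction BCE
fidelity = torch.tensor(0.4)    # reconstruction term
distinct = torch.tensor(0.2)    # phenotype-overlap penalty
alpha_div = torch.tensor(0.1)   # alpha-diversity term

# lam_d and lam_a are made-up weights; 0.1 for fidelity is from the review.
lam_f, lam_d, lam_a = 0.1, 0.05, 0.05

loss = pred_loss + lam_f * fidelity + lam_d * distinct + lam_a * alpha_div
```

The loss-component ablation axis amounts to zeroing out one of these weights at a time and re-measuring prediction quality.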
File Guide
- pyhealth/models/shy.py
- pyhealth/tasks/diagnosis_prediction.py
- pyhealth/models/__init__.py
- pyhealth/tasks/__init__.py
- tests/core/test_shy.py
- examples/mimic3_diagnosis_prediction_shy.py
- examples/mimic4_diagnosis_prediction_shy.py
- docs/api/models/pyhealth.models.SHy.rst
- docs/api/tasks/pyhealth.tasks.DiagnosisPrediction.rst
- docs/api/models.rst
- docs/api/tasks.rst

Ablation Results (MIMIC-III dev=True, 1000 patients, 50 epochs)
Ablation Results (MIMIC-IV dev=True, 1000 patients, 5 epochs)