diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 8d9a59d21..315ad5f8d 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -244,5 +244,7 @@ Available Datasets datasets/pyhealth.datasets.ClinVarDataset datasets/pyhealth.datasets.COSMICDataset datasets/pyhealth.datasets.TCGAPRADDataset + datasets/pyhealth.datasets.GDSCDataset + datasets/pyhealth.datasets.CCLEDataset datasets/pyhealth.datasets.splitter datasets/pyhealth.datasets.utils diff --git a/docs/api/datasets/pyhealth.datasets.CCLEDataset.rst b/docs/api/datasets/pyhealth.datasets.CCLEDataset.rst new file mode 100644 index 000000000..29c53f442 --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.CCLEDataset.rst @@ -0,0 +1,20 @@ +pyhealth.datasets.CCLEDataset +============================== + +The Cancer Cell Line Encyclopedia (CCLE) dataset provides binary gene expression +profiles and drug sensitivity labels for hundreds of cancer cell lines across a +panel of anti-cancer compounds. It is used as a cross-dataset evaluation target +after training on GDSC. + + Barretina, J. et al. (2012). *The Cancer Cell Line Encyclopedia enables + predictive modelling of anticancer drug sensitivity.* + Nature, 483(7391), 603-607. + +Data is available at `https://portals.broadinstitute.org/ccle <https://portals.broadinstitute.org/ccle>`_. +The pre-processed version used by this module follows the format from the +`original CADRE repository <https://github.com/yifengtao/CADRE>`_. + +.. 
autoclass:: pyhealth.datasets.CCLEDataset + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/datasets/pyhealth.datasets.GDSCDataset.rst b/docs/api/datasets/pyhealth.datasets.GDSCDataset.rst new file mode 100644 index 000000000..c300a8d69 --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.GDSCDataset.rst @@ -0,0 +1,20 @@ +pyhealth.datasets.GDSCDataset +============================= + +The Genomics of Drug Sensitivity in Cancer (GDSC) dataset provides binary gene +expression profiles and drug sensitivity labels for hundreds of cancer cell lines +tested against a panel of anti-cancer compounds. It is used for the drug +sensitivity prediction task described in: + + Tao, Y. et al. (2020). *Predicting Drug Sensitivity of Cancer Cell Lines via + Collaborative Filtering with Contextual Attention.* + Proceedings of Machine Learning Research, 126, 456-477. PMLR (MLHC 2020). + +Data is available at `https://www.cancerrxgene.org/ <https://www.cancerrxgene.org/>`_. +The pre-processed version used by this module follows the format from the +`original CADRE repository <https://github.com/yifengtao/CADRE>`_. + +.. autoclass:: pyhealth.datasets.GDSCDataset + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/models.rst b/docs/api/models.rst index 7c3ac7c4b..7b763d427 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -205,4 +205,5 @@ API Reference models/pyhealth.models.TextEmbedding models/pyhealth.models.BIOT models/pyhealth.models.unified_multimodal_embedding_docs + models/pyhealth.models.CADRE models/pyhealth.models.califorest diff --git a/docs/api/models/pyhealth.models.CADRE.rst b/docs/api/models/pyhealth.models.CADRE.rst new file mode 100644 index 000000000..2761fca80 --- /dev/null +++ b/docs/api/models/pyhealth.models.CADRE.rst @@ -0,0 +1,36 @@ +pyhealth.models.CADRE +===================== + +CADRE (Contextual Attention-based Drug REsponse prediction) is a collaborative +filtering model for multi-drug binary sensitivity prediction. 
It encodes cancer +cell-line genomic profiles using frozen Gene2Vec embeddings conditioned on +drug target pathway context (contextual attention), then decodes per-drug +predictions via a dot-product collaborative filter. + +Implementation of: + + Tao, Y. et al. (2020). *Predicting Drug Sensitivity of Cancer Cell Lines via + Collaborative Filtering with Contextual Attention.* + Proceedings of Machine Learning Research, 126, 456-477. PMLR (MLHC 2020). + +Original code: https://github.com/yifengtao/CADRE + +See also :class:`~pyhealth.models.CADREDotAttn` for the Transformer-style +scaled dot-product attention extension. + +.. autoclass:: pyhealth.models.ExpEncoder + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: pyhealth.models.DrugDecoder + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: pyhealth.models.CADRE + :members: + :undoc-members: + :show-inheritance: + +.. autofunction:: pyhealth.models.cadre_collate_fn diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 23a4e06e5..6064e9576 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -230,3 +230,5 @@ Available Tasks Mutation Pathogenicity (COSMIC) Cancer Survival Prediction (TCGA) Cancer Mutation Burden (TCGA) + Drug Sensitivity Prediction (GDSC) <tasks/pyhealth.tasks.DrugSensitivityPredictionGDSC> + Drug Sensitivity Prediction (CCLE) <tasks/pyhealth.tasks.DrugSensitivityPredictionCCLE> diff --git a/docs/api/tasks/pyhealth.tasks.DrugSensitivityPredictionCCLE.rst b/docs/api/tasks/pyhealth.tasks.DrugSensitivityPredictionCCLE.rst new file mode 100644 index 000000000..7fceb118f --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.DrugSensitivityPredictionCCLE.rst @@ -0,0 +1,23 @@ +pyhealth.tasks.DrugSensitivityPredictionCCLE +============================================= + +Task definition for multi-drug binary sensitivity prediction on the CCLE dataset. 
+Follows the same interface as +:class:`~pyhealth.tasks.DrugSensitivityPredictionGDSC`, enabling cross-dataset +evaluation: train on GDSC, evaluate on CCLE using overlapping drugs identified +via :meth:`~pyhealth.datasets.GDSCDataset.get_overlap_drugs`. + +Each sample represents one cancer cell line; the label vector encodes whether the +cell line is sensitive (1) or resistant (0) to each of the tested drugs. +Missing (untested) drug/cell-line pairs are indicated by a companion mask vector. + +Reference dataset: + + Barretina, J. et al. (2012). *The Cancer Cell Line Encyclopedia enables + predictive modelling of anticancer drug sensitivity.* + Nature, 483(7391), 603-607. + +.. autoclass:: pyhealth.tasks.DrugSensitivityPredictionCCLE + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/tasks/pyhealth.tasks.DrugSensitivityPredictionGDSC.rst b/docs/api/tasks/pyhealth.tasks.DrugSensitivityPredictionGDSC.rst new file mode 100644 index 000000000..e115bcd8c --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.DrugSensitivityPredictionGDSC.rst @@ -0,0 +1,19 @@ +pyhealth.tasks.DrugSensitivityPredictionGDSC +============================================= + +Task definition for multi-drug binary sensitivity prediction on the GDSC dataset. +Each sample represents one cancer cell line; the label vector encodes whether the +cell line is sensitive (1) or resistant (0) to each of the 260 screened drugs. +Missing (untested) drug/cell-line pairs are indicated by a companion mask vector. + +This task is used with :class:`~pyhealth.datasets.GDSCDataset` and the +:class:`~pyhealth.models.CADRE` model to reproduce the results from: + + Tao, Y. et al. (2020). *Predicting Drug Sensitivity of Cancer Cell Lines via + Collaborative Filtering with Contextual Attention.* + Proceedings of Machine Learning Research, 126, 456-477. PMLR (MLHC 2020). + +.. 
autoclass:: pyhealth.tasks.DrugSensitivityPredictionGDSC + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/gdsc_drug_sensitivity_prediction_cadre.py b/examples/gdsc_drug_sensitivity_prediction_cadre.py new file mode 100644 index 000000000..b3a273e74 --- /dev/null +++ b/examples/gdsc_drug_sensitivity_prediction_cadre.py @@ -0,0 +1,966 @@ +"""Full pipeline: GDSC → DrugSensitivityPredictionGDSC → CADRE. + +Reproduces the CADRE training procedure and ablation study from: + Tao, Y. et al. (2020). Predicting Drug Sensitivity of Cancer Cell Lines + via Collaborative Filtering with Contextual Attention. MLHC 2020. + +Expected results on GDSC test set (paper Table 1): + F1: 64.3 ± 0.22 + Acc: 78.6 ± 0.34 + AUROC: 83.4 ± 0.19 + AUPR: 70.6 ± 1.30 + +Ablation study (paper Table 2): + CADRE (full model) F1 ~64.3 AUROC ~83.4 + SADRE (no pathway context) F1 ~62.1 AUROC ~81.9 + ADRE (mean pooling, no attn) F1 ~60.8 AUROC ~80.5 + CADRE-100 (embedding_dim=100) F1 ~63.1 AUROC ~82.6 + CADRE-free (trainable gene emb) F1 ~63.8 AUROC ~82.8 + +Usage — full paper replication: + python examples/gdsc_drug_sensitivity_prediction_cadre.py \\ + --data_dir /path/to/originalData --output_dir ./outputs + +Usage — ablation study (trains all 5 variants and compares): + python examples/gdsc_drug_sensitivity_prediction_cadre.py \\ + --data_dir /path/to/originalData --ablation + +Usage — quick demo with synthetic data (no real data required): + python examples/gdsc_drug_sensitivity_prediction_cadre.py --demo +""" + +import argparse +import os +import pickle +import random +import tempfile +import time +from typing import Dict, List, Optional, Tuple + +import numpy as np +import pandas as pd +import torch +import torch.optim as optim +from sklearn.metrics import ( + accuracy_score, + auc, + f1_score, + precision_recall_curve, + precision_score, + recall_score, + roc_auc_score, +) +from torch.utils.data import DataLoader, Subset + +from pyhealth.datasets import GDSCDataset +from 
def load_data(
    data_dir: str, seed: int = 2019
) -> Tuple[GDSCDataset, Subset, Subset, Subset]:
    """Build the GDSC sample dataset and cut it into 60/20/20 splits.

    The GDSC files under ``data_dir`` are loaded, the
    ``DrugSensitivityPredictionGDSC`` task is applied, and the resulting
    samples are shuffled with a seeded ``RandomState`` before being divided
    into three consecutive slices of the permutation.

    Args:
        data_dir: Path to directory containing GDSC CSV files.
        seed: Random seed for reproducible split.

    Returns:
        Tuple of (GDSCDataset, train_subset, val_subset, test_subset).
    """
    print("Loading dataset...")
    gdsc = GDSCDataset(data_dir=data_dir)
    gdsc.summary()

    task_samples = gdsc.set_task(DrugSensitivityPredictionGDSC())
    total = len(task_samples)

    # One seeded permutation decides membership of all three splits.
    shuffled = np.random.RandomState(seed).permutation(total)
    train_end = int(total * 0.6)
    val_end = int(total * 0.8)

    # (start, stop) positions in the permutation for train / val / test.
    boundaries = [(0, train_end), (train_end, val_end), (val_end, total)]
    train_ds, val_ds, test_ds = (
        Subset(task_samples, shuffled[lo:hi].tolist()) for lo, hi in boundaries
    )

    print(f"Split: train={len(train_ds)}, val={len(val_ds)}, test={len(test_ds)}")
    return gdsc, train_ds, val_ds, test_ds
class OneCycle:
    """1-Cycle policy for learning rate and momentum scheduling.

    With the default ``prcnt=10`` the schedule is:

    Phase 1 (warm-up, 45%): LR η/10 → η, momentum 0.95 → 0.85
    Phase 2 (cool-down, 45%): LR η → η/10, momentum 0.85 → 0.95
    Phase 3 (annihilation, 10%): LR η/10 → η/100, momentum 0.95

    Calling :meth:`step` more often than ``total_steps`` is safe: the
    schedule then stays at its final values (LR η/div², high momentum)
    instead of extrapolating to a negative learning rate or dividing by zero.

    Args:
        total_steps: Total number of optimiser steps.
        max_lr: Peak learning rate (η).
        div: Divisor for initial and final LR. Default: ``10``.
        prcnt: Percentage of steps used for the annihilation phase.
            Default: ``10``.
        momentum_vals: (high, low) momentum bounds. Default: ``(0.95, 0.85)``.
    """

    def __init__(
        self,
        total_steps: int,
        max_lr: float,
        div: int = 10,
        prcnt: int = 10,
        momentum_vals: Tuple[float, float] = (0.95, 0.85),
    ) -> None:
        self.total_steps = total_steps
        self.max_lr = max_lr
        self.div = div
        # Length (in steps) of each of the two symmetric phases; whatever is
        # left of ``total_steps`` afterwards is the annihilation phase.
        self.step_len = int(total_steps * (1 - prcnt / 100) / 2)
        self.high_mom = momentum_vals[0]
        self.low_mom = momentum_vals[1]
        self.iteration = 0

    def step(self) -> Tuple[float, float]:
        """Advance the schedule by one step and return its values.

        The internal counter is incremented first, so the first call yields
        the values for iteration 1.

        Returns:
            Tuple of (learning_rate, momentum) for the new iteration.
        """
        self.iteration += 1
        lr = self._calc_lr()
        mom = self._calc_mom()
        return lr, mom

    def _calc_lr(self) -> float:
        """Return the learning rate for the current iteration."""
        it = self.iteration
        if it > 2 * self.step_len:  # annihilation phase
            # The span is zero when prcnt == 0 (or total_steps == 0); use at
            # least one step so an overshoot cannot divide by zero.
            span = max(self.total_steps - 2 * self.step_len, 1)
            # Clamp so stepping past total_steps holds the final LR rather
            # than continuing the linear decay below zero.
            ratio = min((it - 2 * self.step_len) / span, 1.0)
            return self.max_lr / self.div * (1 - ratio * (1 - 1 / self.div))
        elif it > self.step_len:  # cool-down phase
            ratio = 1 - (it - self.step_len) / self.step_len
            return self.max_lr * (1 + ratio * (self.div - 1)) / self.div
        else:  # warm-up phase
            ratio = it / self.step_len
            return self.max_lr * (1 + ratio * (self.div - 1)) / self.div

    def _calc_mom(self) -> float:
        """Return the momentum for the current iteration."""
        it = self.iteration
        if it > 2 * self.step_len:  # annihilation
            return self.high_mom
        elif it > self.step_len:  # cool-down
            ratio = (it - self.step_len) / self.step_len
            return self.low_mom + ratio * (self.high_mom - self.low_mom)
        else:  # warm-up
            ratio = it / self.step_len
            return self.high_mom - ratio * (self.high_mom - self.low_mom)
def run_one(
    dataset: GDSCDataset,
    train_ds: Subset,
    val_ds: Subset,
    test_ds: Subset,
    device: torch.device,
    args: argparse.Namespace,
    label: str = "CADRE",
    embedding_dim: Optional[int] = None,
    use_attention: Optional[bool] = None,
    use_cntx_attn: Optional[bool] = None,
    freeze_gene_emb: Optional[bool] = None,
) -> Dict:
    """Train one model configuration; returns test metrics.

    Keyword overrides (``embedding_dim``, ``use_attention``, etc.) are used
    by the ablation study to vary a single hyperparameter per run while
    inheriting all other settings from ``args``.

    Args:
        dataset: Loaded GDSCDataset (supplies embeddings and pathway info).
        train_ds: Training split (after fill_mask_training has been called).
        val_ds: Validation split.
        test_ds: Test split.
        device: Compute device.
        args: Parsed CLI arguments supplying default hyperparameters.
        label: Display name for progress output.
        embedding_dim: Override for ``args.embedding_dim``.
        use_attention: Override for ``args.use_attention``.
        use_cntx_attn: Override for ``args.use_cntx_attn``.
        freeze_gene_emb: Override for ``not args.train_gene_emb``.

    Returns:
        Dict with test-set metrics from :func:`evaluate`, plus the extra keys
        ``_logs`` (training curves), ``_train_final`` / ``_val_final``
        (metrics of the selected checkpoint on those splits), ``_pw_info``
        (pathway info used to build the model), ``_elapsed`` (seconds) and
        ``_model_state`` (CPU copy of the evaluated model's ``state_dict``).

    Raises:
        ValueError: If the training split yields no batches.
    """
    # Apply per-run overrides (ablation study only)
    emb_dim = embedding_dim if embedding_dim is not None else args.embedding_dim
    attn = use_attention if use_attention is not None else args.use_attention
    cntx = use_cntx_attn if use_cntx_attn is not None else args.use_cntx_attn
    freeze = (
        freeze_gene_emb
        if freeze_gene_emb is not None
        else not args.train_gene_emb
    )

    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=cadre_collate_fn,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=cadre_collate_fn,
    )
    test_loader = DataLoader(
        test_ds,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=cadre_collate_fn,
    )

    gene_emb = dataset.get_gene_embeddings()
    pw_info = dataset.get_pathway_info()

    model = CADRE(
        gene_embeddings=gene_emb,
        num_drugs=len(dataset.drug_ids),
        num_pathways=pw_info["num_pathways"],
        drug_pathway_ids=pw_info["drug_pathway_ids"],
        embedding_dim=emb_dim,
        attention_size=args.attention_size,
        attention_head=args.attention_head,
        dropout_rate=args.dropout_rate,
        use_attention=attn,
        use_cntx_attn=cntx,
        freeze_gene_emb=freeze,
    ).to(device)

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"\n[{label}] Parameters: {trainable:,} trainable / {total:,} total")

    # SGD driven by the OneCycle LR/momentum schedule (defaults: lr=0.3,
    # wd=3e-4); the lr/momentum set here are overwritten before every step.
    optimizer = optim.SGD(
        model.parameters(),
        lr=args.learning_rate,
        momentum=0.95,
        weight_decay=args.weight_decay,
    )

    # ``max_iter`` is divided by the batch size to get optimizer steps, so
    # the defaults (max_iter=48000, batch_size=8) give 6000 steps.
    steps_per_epoch = len(train_loader)
    if steps_per_epoch == 0:
        # Fail with a clear message instead of a ZeroDivisionError below.
        raise ValueError(
            "Training split is empty; cannot train a model without samples."
        )
    total_optimizer_steps = args.max_iter // args.batch_size
    # Ceiling division: enough epochs to cover every optimizer step.
    total_epochs = (
        total_optimizer_steps + steps_per_epoch - 1
    ) // steps_per_epoch
    scheduler = OneCycle(total_optimizer_steps, args.learning_rate)

    print(
        f" Training for {total_optimizer_steps} steps ({total_epochs} epochs), "
        f"{steps_per_epoch} steps/epoch"
    )

    logs: Dict = {
        "args": vars(args),
        "epoch": [],
        "step": [],
        "train_loss": [],
        "train_f1": [],
        "train_acc": [],
        "train_auroc": [],
        "train_aupr": [],
        "val_f1": [],
        "val_acc": [],
        "val_auroc": [],
        "val_aupr": [],
    }

    global_step = 0
    best_val_f1 = 0.0
    best_model_state: Optional[dict] = None
    start_time = time.time()

    for epoch in range(total_epochs):
        model.train()
        epoch_losses = []

        for batch in train_loader:
            if global_step >= total_optimizer_steps:
                break

            gene_indices = batch["gene_indices"].to(device)
            labels = batch["labels"].to(device)
            mask = batch["mask"].to(device)

            # OneCycle LR/momentum update
            lr, mom = scheduler.step()
            for pg in optimizer.param_groups:
                pg["lr"] = lr
                pg["momentum"] = mom

            optimizer.zero_grad()
            result = model(gene_indices, labels=labels, mask=mask)
            loss = result["loss"]
            loss.backward()
            optimizer.step()

            epoch_losses.append(loss.item())
            global_step += 1

        # Evaluate every eval_every epochs (or at the last step)
        if (
            (epoch + 1) % args.eval_every == 0
            or global_step >= total_optimizer_steps
        ):
            train_metrics = evaluate(model, train_loader, device)
            val_metrics = evaluate(model, val_loader, device)

            avg_loss = np.mean(epoch_losses) if epoch_losses else 0.0
            elapsed = time.time() - start_time

            print(
                f" [Epoch {epoch + 1:3d} | Step {global_step:5d} | {elapsed:.0f}s] "
                f"loss={avg_loss:.4f} | "
                f"trn F1={100 * train_metrics['f1']:.1f} "
                f"AUC={100 * train_metrics['auroc']:.1f} | "
                f"val F1={100 * val_metrics['f1']:.1f} "
                f"AUC={100 * val_metrics['auroc']:.1f} "
                f"AUPR={100 * val_metrics['aupr']:.1f} "
                f"Acc={100 * val_metrics['accuracy']:.1f}"
            )

            logs["epoch"].append(epoch + 1)
            logs["step"].append(global_step)
            logs["train_loss"].append(avg_loss)
            logs["train_f1"].append(train_metrics["f1"])
            logs["train_acc"].append(train_metrics["accuracy"])
            logs["train_auroc"].append(train_metrics["auroc"])
            logs["train_aupr"].append(train_metrics["aupr"])
            logs["val_f1"].append(val_metrics["f1"])
            logs["val_acc"].append(val_metrics["accuracy"])
            logs["val_auroc"].append(val_metrics["auroc"])
            logs["val_aupr"].append(val_metrics["aupr"])

            # Save best model by val F1 (CPU clone so later optimizer steps
            # cannot modify the snapshot)
            if val_metrics["f1"] > best_val_f1:
                best_val_f1 = val_metrics["f1"]
                best_model_state = {
                    k: v.cpu().clone() for k, v in model.state_dict().items()
                }

        if global_step >= total_optimizer_steps:
            break

    # Final evaluation on test set using best-val-F1 model. If validation F1
    # never exceeded 0, no snapshot exists and the final weights are used.
    print(f"\n=== Final Evaluation [{label}] (best val F1 checkpoint) ===")
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        model.to(device)

    test_metrics = evaluate(model, test_loader, device)
    train_metrics_final = evaluate(model, train_loader, device)
    val_metrics_final = evaluate(model, val_loader, device)

    print(
        f"Train: F1={100 * train_metrics_final['f1']:.1f} "
        f"Acc={100 * train_metrics_final['accuracy']:.1f} "
        f"AUROC={100 * train_metrics_final['auroc']:.1f} "
        f"AUPR={100 * train_metrics_final['aupr']:.1f}"
    )
    print(
        f"Val: F1={100 * val_metrics_final['f1']:.1f} "
        f"Acc={100 * val_metrics_final['accuracy']:.1f} "
        f"AUROC={100 * val_metrics_final['auroc']:.1f} "
        f"AUPR={100 * val_metrics_final['aupr']:.1f}"
    )
    print(
        f"Test: F1={100 * test_metrics['f1']:.1f} "
        f"Acc={100 * test_metrics['accuracy']:.1f} "
        f"AUROC={100 * test_metrics['auroc']:.1f} "
        f"AUPR={100 * test_metrics['aupr']:.1f}"
    )

    # Bundle extra info for callers that want to save outputs
    test_metrics["_logs"] = logs
    test_metrics["_train_final"] = train_metrics_final
    test_metrics["_val_final"] = val_metrics_final
    test_metrics["_pw_info"] = pw_info
    test_metrics["_elapsed"] = time.time() - start_time
    # Weights of the model that produced the metrics above, so a caller can
    # persist a real checkpoint.
    test_metrics["_model_state"] = {
        k: v.detach().cpu().clone() for k, v in model.state_dict().items()
    }
    return test_metrics
# ---------------------------------------------------------------------------
# Step 6: Ablation study
# ---------------------------------------------------------------------------

# One entry per model variant compared in the ablation (paper Table 2).
# Relative to the full CADRE model, each other entry changes only the
# setting(s) named in its label:
#   SADRE      - attention without pathway context
#   ADRE       - no attention at all (mean pooling)
#   CADRE-100  - embedding_dim 100 instead of 200
#   CADRE-free - gene embeddings are trainable rather than frozen
ABLATION_CONFIGS: List[Dict] = [
    {
        "label": "CADRE",
        "use_attention": True,
        "use_cntx_attn": True,
        "embedding_dim": 200,
        "freeze_gene_emb": True,
    },
    {
        "label": "SADRE (no pathway ctx)",
        "use_attention": True,
        "use_cntx_attn": False,
        "embedding_dim": 200,
        "freeze_gene_emb": True,
    },
    {
        "label": "ADRE (mean pooling)",
        "use_attention": False,
        "use_cntx_attn": False,
        "embedding_dim": 200,
        "freeze_gene_emb": True,
    },
    {
        "label": "CADRE-100 (emb=100)",
        "use_attention": True,
        "use_cntx_attn": True,
        "embedding_dim": 100,
        "freeze_gene_emb": True,
    },
    {
        "label": "CADRE-free (trainable emb)",
        "use_attention": True,
        "use_cntx_attn": True,
        "embedding_dim": 200,
        "freeze_gene_emb": False,
    },
]
def _write_synthetic_gdsc(tmp_dir: str) -> None:
    """Populate ``tmp_dir`` with a tiny, deterministic GDSC-style dataset.

    Four CSV files are written: binary expression for 10 cell lines x 20
    genes, a 10 x 5 binary sensitivity matrix with roughly 30 % of entries
    blanked out, drug metadata cycling through 3 pathways, and a random
    21 x 8 embedding table whose first row is all zeros.

    Args:
        tmp_dir: Directory to write CSV files into.
    """
    # Fixed seed; the draw order below (expression, sensitivity, missing
    # mask, embeddings) determines the generated values.
    rng = np.random.RandomState(0)
    n_cells, n_genes, n_drugs = 10, 20, 5
    cell_names = [f"CL{i}" for i in range(n_cells)]
    drug_ids = list(range(1001, 1001 + n_drugs))

    # Binary gene expression (cell lines x genes); gene columns are "1".."20".
    expression = rng.randint(0, 2, (n_cells, n_genes))
    gene_names = [str(g) for g in range(1, n_genes + 1)]
    pd.DataFrame(expression, index=cell_names, columns=gene_names).to_csv(
        os.path.join(tmp_dir, "exp_gdsc.csv")
    )

    # Binary sensitivity labels, then a random mask marking untested pairs.
    raw_labels = rng.randint(0, 2, (n_cells, n_drugs)).astype(float)
    untested = rng.rand(n_cells, n_drugs) < 0.3
    sensitivity = np.where(untested, np.nan, raw_labels)
    pd.DataFrame(
        sensitivity,
        index=cell_names,
        columns=[str(d) for d in drug_ids],
    ).to_csv(os.path.join(tmp_dir, "gdsc.csv"))

    # Drug metadata (Name + Target pathway), pathways assigned round-robin.
    pathways = ["PI3K/MTOR", "ERK MAPK", "WNT"]
    names = [f"Drug{i}" for i in range(n_drugs)]
    targets = [pathways[i % len(pathways)] for i in range(n_drugs)]
    metadata = pd.DataFrame(
        {"Name": names, "Target pathway": targets},
        index=drug_ids,
    )
    metadata.to_csv(os.path.join(tmp_dir, "drug_info_gdsc.csv"))

    # Embedding table with n_genes + 1 rows; row 0 is zeroed (padding vector).
    embeddings = rng.randn(n_genes + 1, 8).astype(np.float32)
    embeddings[0] = 0.0
    np.savetxt(
        os.path.join(tmp_dir, "exp_emb_gdsc.csv"), embeddings, delimiter=","
    )
def train(args: argparse.Namespace) -> None:
    """Run the full paper-replication training on real GDSC data.

    Seeds the RNGs, loads and splits the data, optionally imputes missing
    training labels, trains one model via :func:`run_one`, and writes three
    files into ``args.output_dir``:

    - ``logs.pkl``: pickled training curves plus test metrics and raw test
      predictions.
    - ``model.pt``: CLI arguments and pathway info, plus the model weights
      under ``model_state_dict`` when :func:`run_one` supplies them.
    - ``results.txt``: human-readable summary (UTF-8).

    Args:
        args: Parsed CLI arguments.
    """
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    device = _pick_device(args)
    print(f"Device: {device}")

    dataset, train_ds, val_ds, test_ds = load_data(args.data_dir, args.seed)

    if not args.no_fill_mask:
        fill_mask_training(train_ds)
        print("Applied fill_mask to training set")
    else:
        print("Skipped fill_mask (--no_fill_mask)")

    test_m = run_one(
        dataset=dataset,
        train_ds=train_ds,
        val_ds=val_ds,
        test_ds=test_ds,
        device=device,
        args=args,
        label="CADRE",
    )

    # Save outputs
    os.makedirs(args.output_dir, exist_ok=True)

    logs = test_m["_logs"]
    pw_info = test_m["_pw_info"]
    val_metrics_final = test_m["_val_final"]
    elapsed = test_m["_elapsed"]

    logs["test_f1"] = test_m["f1"]
    logs["test_acc"] = test_m["accuracy"]
    logs["test_auroc"] = test_m["auroc"]
    logs["test_aupr"] = test_m["aupr"]
    logs["test_precision"] = test_m["precision"]
    logs["test_recall"] = test_m["recall"]
    logs["test_probs"] = test_m["probs"]
    logs["test_labels"] = test_m["labels"]
    logs["test_masks"] = test_m["masks"]
    logs["train_time_seconds"] = elapsed

    logs_path = os.path.join(args.output_dir, "logs.pkl")
    with open(logs_path, "wb") as f:
        pickle.dump(logs, f, protocol=2)
    print(f"\nLogs saved to {logs_path}")

    model_path = os.path.join(args.output_dir, "model.pt")
    checkpoint = {"args": vars(args), "pathway_info": pw_info}
    # Only claim that a model was saved when weights are actually available.
    model_state = test_m.get("_model_state")
    if model_state is not None:
        checkpoint["model_state_dict"] = model_state
    torch.save(checkpoint, model_path)
    if model_state is not None:
        print(f"Model saved to {model_path}")
    else:
        print(
            f"Run metadata saved to {model_path} "
            "(no model weights were returned by run_one)"
        )

    # Report sizes of the data actually loaded rather than fixed constants.
    n_cell_lines = len(train_ds) + len(val_ds) + len(test_ds)
    n_drugs = len(dataset.drug_ids)

    summary_path = os.path.join(args.output_dir, "results.txt")
    # Explicit UTF-8: the summary contains "±", which not every default
    # locale encoding can represent.
    with open(summary_path, "w", encoding="utf-8") as f:
        f.write("reCADRE Training Results\n")
        f.write("=" * 50 + "\n\n")
        f.write("Hyperparameters:\n")
        for k, v in vars(args).items():
            f.write(f" {k}: {v}\n")
        f.write("\nDataset:\n")
        f.write(f" Cell lines: {n_cell_lines}\n")
        f.write(f" Drugs: {n_drugs}\n")
        # NOTE(review): gene counts are not derived from the loaded data;
        # this is the constant the original summary always printed.
        f.write(" Genes: 3000 (1500 active; paper preprocessing constant)\n")
        f.write(f" Pathways: {pw_info['num_pathways']}\n")
        f.write(" Split: 60/20/20\n")
        f.write("\nResults (Test Set):\n")
        f.write(f" F1 Score: {100 * test_m['f1']:.2f}\n")
        f.write(f" Accuracy: {100 * test_m['accuracy']:.2f}\n")
        f.write(f" AUROC: {100 * test_m['auroc']:.2f}\n")
        f.write(f" AUPR: {100 * test_m['aupr']:.2f}\n")
        f.write(f" Precision: {100 * test_m['precision']:.2f}\n")
        f.write(f" Recall: {100 * test_m['recall']:.2f}\n")
        f.write("\nResults (Validation Set):\n")
        f.write(f" F1 Score: {100 * val_metrics_final['f1']:.2f}\n")
        f.write(f" Accuracy: {100 * val_metrics_final['accuracy']:.2f}\n")
        f.write(f" AUROC: {100 * val_metrics_final['auroc']:.2f}\n")
        f.write(f" AUPR: {100 * val_metrics_final['aupr']:.2f}\n")
        f.write("\nPaper Reference (CADRE on GDSC, Table 1):\n")
        f.write(" F1 Score: 64.3 ± 0.22\n")
        f.write(" Accuracy: 78.6 ± 0.34\n")
        f.write(" AUROC: 83.4 ± 0.19\n")
        f.write(" AUPR: 70.6 ± 1.30\n")
        f.write(f"\nTraining time: {elapsed:.1f}s\n")
    print(f"Summary saved to {summary_path}")
+ """ + parser = argparse.ArgumentParser( + description="Train reCADRE model on GDSC drug sensitivity data" + ) + + # Resolve default paths relative to this script's directory + _script_dir = os.path.dirname(os.path.abspath(__file__)) + + # Data + parser.add_argument( + "--data_dir", + type=str, + default=os.path.join(_script_dir, "..", "originalData"), + help="Path to GDSC CSV files", + ) + parser.add_argument( + "--output_dir", + type=str, + default=os.path.join(_script_dir, "..", "outputs"), + help="Directory to save checkpoint, logs, and results", + ) + + # Model architecture (Table A2) + parser.add_argument("--embedding_dim", type=int, default=200) + parser.add_argument("--attention_size", type=int, default=128) + parser.add_argument("--attention_head", type=int, default=8) + parser.add_argument("--dropout_rate", type=float, default=0.6) + parser.add_argument( + "--use_attention", action="store_true", default=True, + help="Enable multi-head attention in the encoder", + ) + parser.add_argument( + "--no_attention", action="store_true", default=False, + help="Disable attention (ADRE ablation — mean pooling)", + ) + parser.add_argument( + "--use_cntx_attn", action="store_true", default=True, + help="Enable contextual (pathway) conditioning in attention", + ) + parser.add_argument( + "--no_cntx_attn", action="store_true", default=False, + help="Disable contextual attention (SADRE ablation)", + ) + parser.add_argument( + "--train_gene_emb", action="store_true", default=False, + help="Unfreeze gene embeddings (CADRE∆pretrain variant)", + ) + parser.add_argument( + "--no_fill_mask", action="store_true", default=False, + help="Skip per-drug-mode imputation of missing labels", + ) + + # Training (Table A2) + parser.add_argument("--batch_size", type=int, default=8) + parser.add_argument( + "--max_iter", type=int, default=48000, + help="Total training iterations (paper: 48k for GDSC)", + ) + parser.add_argument("--learning_rate", type=float, default=0.3) + 
parser.add_argument("--weight_decay", type=float, default=3e-4) + + # Misc + parser.add_argument( + "--eval_every", type=int, default=10, + help="Evaluate every N epochs", + ) + parser.add_argument("--seed", type=int, default=2019) + parser.add_argument( + "--use_cuda", action="store_true", default=True, + help="Use CUDA if available", + ) + parser.add_argument( + "--cpu", action="store_true", default=False, + help="Force CPU even if GPU/MPS available", + ) + + # Modes + parser.add_argument( + "--ablation", action="store_true", + help="Run all 5 ablation variants and print comparison table", + ) + parser.add_argument( + "--demo", action="store_true", + help="Smoke-test the pipeline with synthetic data (no real data needed)", + ) + + args = parser.parse_args() + + # Handle negation flags (match reCADRE train.py convention) + if args.no_attention: + args.use_attention = False + if args.no_cntx_attn: + args.use_cntx_attn = False + + return args + + +if __name__ == "__main__": + args = parse_args() + + if args.demo: + demo(args) + elif args.ablation: + ablation_study(args) + else: + train(args) diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 50b1b3887..42338c94c 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -91,3 +91,6 @@ def __init__(self, *args, **kwargs): save_processors, ) from .collate import collate_temporal + +from .gdsc import GDSCDataset +from .ccle import CCLEDataset diff --git a/pyhealth/datasets/ccle.py b/pyhealth/datasets/ccle.py new file mode 100644 index 000000000..61e380270 --- /dev/null +++ b/pyhealth/datasets/ccle.py @@ -0,0 +1,286 @@ +"""CCLE (Cancer Cell Line Encyclopedia) Dataset. + +Paper: + Barretina, J. et al. (2012). The Cancer Cell Line Encyclopedia enables + predictive modelling of anticancer drug sensitivity. + Nature, 483(7391), 603-607. + +This module wraps pre-processed CCLE data as a PyHealth dataset for cancer +drug sensitivity prediction. 
class CCLEDataset:
    """CCLE dataset for cancer drug sensitivity prediction.

    Loads pre-processed CCLE data from CSV files and converts it into a
    PyHealth dataset via :meth:`set_task`. Follows the PyHealth
    ``dataset.set_task()`` convention.

    Each *patient* corresponds to one cancer cell line. The single *visit*
    represents its genomic measurement (gene expression + drug sensitivity).

    **Source data layout** (``data_dir``):

    .. code-block:: text

        exp_ccle.csv         Binary gene expression (cell lines x genes)
        ccle.csv             Binary drug sensitivity (cell lines x drugs)
        drug_info_ccle.csv   Drug metadata with target pathways
        exp_emb_ccle.csv     Gene2Vec embeddings (n_genes+1 x emb_dim)

    Unlike GDSC, CCLE drug columns are identified by drug **name** (not
    numeric ID). :meth:`get_overlap_drugs` handles name-based cross-dataset
    matching automatically.

    Args:
        data_dir (str): Path to the directory containing the CSV files.
            Defaults to ``"ccleData"``.
        seed (int): Random seed (used by downstream splitters). Defaults
            to ``2019``.

    Attributes:
        gene_names (List[str]): Ordered gene column names.
        drug_ids (List[str]): Ordered drug column names (drug names).
        drug_names (List[str]): Same as ``drug_ids`` for CCLE (drug names
            are used directly as column headers).
        drug_pathway_ids (List[int]): Integer pathway ID per drug.
        gene_embeddings (np.ndarray): Pre-trained Gene2Vec matrix; row 0 is
            the zero-padding vector.

    Examples:
        >>> from pyhealth.datasets import CCLEDataset
        >>> dataset = CCLEDataset(data_dir="ccleData")
        >>> sample_ds = dataset.set_task()
        >>> "gene_indices" in sample_ds[0]
        True
    """

    dataset_name: str = "CCLE"

    def __init__(self, data_dir: str = "ccleData", seed: int = 2019) -> None:
        self.data_dir = data_dir
        self.seed = seed
        self._load_data()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _normalise_cell_line(label) -> str:
        """Return *label* upper-cased with ``-``, spaces and ``.`` removed."""
        return str(label).upper().replace("-", "").replace(" ", "").replace(".", "")

    @staticmethod
    def _first_positions(names) -> Dict:
        """Map each name to the position of its first occurrence in *names*.

        Mirrors ``list.index`` semantics while giving O(1) lookups.
        """
        positions: Dict = {}
        for pos, name in enumerate(names):
            positions.setdefault(name, pos)
        return positions

    # ------------------------------------------------------------------
    # Internal data loading
    # ------------------------------------------------------------------

    def _load_data(self) -> None:
        """Load and align all CSV data files.

        Raises:
            FileNotFoundError: If any of the four expected files is missing
                from ``data_dir``.
        """
        try:
            self.exp = pd.read_csv(
                os.path.join(self.data_dir, "exp_ccle.csv"), index_col=0
            )
            self.tgt = pd.read_csv(
                os.path.join(self.data_dir, "ccle.csv"), index_col=0
            )
            self.drug_info = pd.read_csv(
                os.path.join(self.data_dir, "drug_info_ccle.csv"), index_col=0
            )
            self.gene_embeddings = np.loadtxt(
                os.path.join(self.data_dir, "exp_emb_ccle.csv"), delimiter=","
            )
        except FileNotFoundError as exc:
            raise FileNotFoundError(
                f"CCLE data files not found in {self.data_dir}. "
                "Ensure exp_ccle.csv, ccle.csv, drug_info_ccle.csv, and "
                "exp_emb_ccle.csv exist."
            ) from exc

        # Find common cell lines using case/punctuation-insensitive matching to
        # handle CCLE naming inconsistencies between tables (e.g. "22Rv1" vs
        # "22RV1", "42-MG-BA" vs "42MGBA"). If two labels within one table
        # normalise to the same key, the later one silently wins.
        exp_norm = {self._normalise_cell_line(i): i for i in self.exp.index}
        tgt_norm = {self._normalise_cell_line(i): i for i in self.tgt.index}
        common_norm = sorted(set(exp_norm) & set(tgt_norm))
        # Use the expression-side label as the canonical cell-line ID.
        self.common_samples = [exp_norm[k] for k in common_norm]
        self.exp = self.exp.loc[self.common_samples]
        self.tgt = self.tgt.loc[[tgt_norm[k] for k in common_norm]]
        self.tgt.index = self.common_samples

        # Preprocessed CCLE labels are inverted (~75% "1") relative to the
        # paper's 24.8% sensitive prior and GDSC's "1 = sensitive" convention.
        # Flip so "1 = sensitive" is consistent across both datasets.
        # Untested (NaN) entries are left untouched.
        observed = self.tgt.notnull()
        self.tgt = self.tgt.where(~observed, 1 - self.tgt)

        # Gene and drug IDs
        self.gene_names: List[str] = list(self.exp.columns)
        # For CCLE, column headers ARE drug names (not numeric IDs)
        self.drug_ids: List[str] = list(self.tgt.columns)
        self.drug_names: List[str] = self.drug_ids

        self._build_pathway_mapping()

    def _build_pathway_mapping(self) -> None:
        """Map each drug column (drug name) to an integer pathway ID.

        Drug names are matched against the ``drug_info_ccle.csv`` index,
        first exactly and then case-insensitively. Drugs with no match, or
        with a missing ``Target pathway`` value, are assigned ``"Unknown"``.
        If the index lists a drug more than once, the last row wins.
        """
        # Iterate index and column in lockstep rather than via ``.loc``:
        # ``.loc`` returns a Series for a duplicated index label, which would
        # later make ``set(self.drug_pathways)`` fail as unhashable.
        id2pw: Dict[str, str] = {}
        for idx, pathway in zip(
            self.drug_info.index, self.drug_info["Target pathway"]
        ):
            # A NaN pathway would break sorting of mixed float/str values.
            id2pw[str(idx)] = "Unknown" if pd.isna(pathway) else pathway

        # Lower-cased name -> first index key with that spelling, so the
        # case-insensitive fallback is a dict lookup instead of a full scan.
        lower2key: Dict[str, str] = {}
        for key in id2pw:
            lower2key.setdefault(key.lower(), key)

        self.drug_pathways: List[str] = []
        for drug_name in self.tgt.columns:
            name_str = str(drug_name)
            if name_str in id2pw:
                pw = id2pw[name_str]
            else:
                # Case-insensitive fallback
                match = lower2key.get(name_str.lower())
                pw = id2pw[match] if match is not None else "Unknown"
            self.drug_pathways.append(pw)

        unique_pathways = sorted(set(self.drug_pathways))
        self.pathway2id: Dict[str, int] = {
            pw: i for i, pw in enumerate(unique_pathways)
        }
        self.drug_pathway_ids: List[int] = [
            self.pathway2id[pw] for pw in self.drug_pathways
        ]

    # ------------------------------------------------------------------
    # PyHealth task interface
    # ------------------------------------------------------------------

    def set_task(self, task=None):
        """Apply a task to produce a model-ready dataset.

        Args:
            task: An object with ``task_name``, ``input_schema``,
                ``output_schema`` attributes and a
                ``__call__(patient) -> List[dict]`` method. If ``None``,
                :class:`~pyhealth.tasks.DrugSensitivityPredictionCCLE` is
                used with default settings.

        Returns:
            Dataset ready for a standard PyTorch
            :class:`~torch.utils.data.DataLoader`.
        """
        if task is None:
            # Imported lazily so the task module is only needed when the
            # default task is actually requested.
            from pyhealth.tasks.drug_sensitivity_ccle import (
                DrugSensitivityPredictionCCLE,
            )
            task = DrugSensitivityPredictionCCLE()

        samples: List[Dict] = []
        for cell_line in self.common_samples:
            patient = {
                "patient_id": str(cell_line),
                "gene_expression": self.exp.loc[cell_line].values,
                "drug_sensitivity": self.tgt.loc[cell_line].values.astype(float),
                "drug_pathway_ids": self.drug_pathway_ids,
            }
            samples.extend(task(patient))

        return create_sample_dataset(
            samples=samples,
            input_schema=task.input_schema,
            output_schema=task.output_schema,
            dataset_name=self.dataset_name,
            task_name=task.task_name,
            in_memory=True,
        )

    # ------------------------------------------------------------------
    # Model-configuration accessors
    # ------------------------------------------------------------------

    def get_gene_embeddings(self) -> np.ndarray:
        """Return the pre-trained Gene2Vec embedding matrix.

        Returns:
            np.ndarray: Row 0 is the zero-padding vector; subsequent rows
            correspond to 1-indexed gene positions used in ``gene_indices``.
        """
        return self.gene_embeddings

    def get_pathway_info(self) -> Dict:
        """Return drug pathway metadata for model initialisation.

        Returns:
            dict: Keys are ``pathway2id``, ``id2pathway``,
            ``num_pathways``, and ``drug_pathway_ids``.
        """
        return {
            "pathway2id": self.pathway2id,
            "id2pathway": {v: k for k, v in self.pathway2id.items()},
            "num_pathways": len(self.pathway2id),
            "drug_pathway_ids": self.drug_pathway_ids,
        }

    def get_overlap_drugs(self, other_dataset) -> Tuple[List[int], List[int], List[str]]:
        """Find drugs overlapping with another dataset by drug name.

        Args:
            other_dataset: Another dataset instance with a ``drug_names``
                attribute (or falling back to ``drug_ids``).

        Returns:
            tuple: ``(self_indices, other_indices, overlap_names)`` where the
            index lists give the first column position of each shared drug in
            the respective dataset and ``overlap_names`` is sorted.
        """
        other_names = (
            other_dataset.drug_names
            if hasattr(other_dataset, "drug_names")
            else other_dataset.drug_ids
        )
        # Position tables replace repeated ``list.index`` scans.
        self_pos = self._first_positions(self.drug_names)
        other_pos = self._first_positions(other_names)
        overlap = sorted(set(self_pos) & set(other_pos))

        self_indices = [self_pos[d] for d in overlap]
        other_indices = [other_pos[d] for d in overlap]
        return self_indices, other_indices, overlap

    # ------------------------------------------------------------------
    # Utilities
    # ------------------------------------------------------------------

    def summary(self) -> None:
        """Print dataset summary statistics to stdout."""
        total_pairs = len(self.common_samples) * len(self.tgt.columns)
        tested = self.tgt.notnull().sum().sum()
        sensitive = (self.tgt == 1).sum().sum()
        resistant = (self.tgt == 0).sum().sum()

        print("CCLE Dataset Summary")
        print(f" Cell lines: {len(self.common_samples)}")
        print(f" Drugs: {len(self.tgt.columns)}")
        print(f" Genes: {len(self.exp.columns)}")
        print(f" Active genes/cell:{int(self.exp.sum(axis=1).mean())}")
        print(f" Total pairs: {total_pairs}")
        print(f" Tested pairs: {int(tested)} ({tested / total_pairs:.1%})")
        print(
            f" Missing pairs: {int(total_pairs - tested)}"
            f" ({(total_pairs - tested) / total_pairs:.1%})"
        )
        print(f" Sensitive: {int(sensitive)} ({sensitive / tested:.1%})")
        print(f" Resistant: {int(resistant)} ({resistant / tested:.1%})")
        print(f" Pathways: {len(self.pathway2id)}")
        print(f" Embedding shape: {self.gene_embeddings.shape}")
+ Used to build the drug-to-pathway mapping required by the CADRE + contextual attention layer. + gene_embeddings: + file_path: "exp_emb_ccle.csv" + patient_id: null + timestamp: null + attributes: [] + description: > + Pre-trained Gene2Vec embedding matrix. Row 0 is the zero-padding vector; + subsequent rows correspond to 1-indexed gene positions matching the + columns of exp_ccle.csv. diff --git a/pyhealth/datasets/configs/gdsc.yaml b/pyhealth/datasets/configs/gdsc.yaml new file mode 100644 index 000000000..47c550bb4 --- /dev/null +++ b/pyhealth/datasets/configs/gdsc.yaml @@ -0,0 +1,50 @@ +version: "1.0" +# GDSC (Genomics of Drug Sensitivity in Cancer) dataset configuration. +# +# Data format: flat matrix files (cell lines × genes, cell lines × drugs). +# This schema documents the expected file layout. GDSCDataset loads these +# files directly via pandas rather than the BaseDataset YAML-table pipeline, +# because the expression and sensitivity data are dense matrices rather than +# per-patient event tables. +# +# Required files in data_dir: +# exp_gdsc.csv - Binary gene expression matrix (cell lines × genes) +# gdsc.csv - Binary drug sensitivity matrix (cell lines × drugs) +# drug_info_gdsc.csv - Drug metadata: Name, Target pathway, COSMIC drug ID +# exp_emb_gdsc.csv - Gene2Vec embedding matrix (n_genes+1 × emb_dim) +tables: + expression: + file_path: "exp_gdsc.csv" + patient_id: "cell_line_id" + timestamp: null + attributes: [] + description: > + Binary gene expression matrix. Rows are COSMIC cell-line IDs; columns + are 1-indexed gene positions. Value 1 = gene expressed, 0 = not expressed. + sensitivity: + file_path: "gdsc.csv" + patient_id: "cell_line_id" + timestamp: null + attributes: [] + description: > + Binary drug sensitivity matrix. Rows are COSMIC cell-line IDs; columns + are COSMIC drug IDs. Value 1 = sensitive, 0 = resistant, NaN = untested. 
class GDSCDataset:
    """GDSC dataset for cancer drug sensitivity prediction.

    Loads pre-processed GDSC data from CSV files and converts it into
    a PyHealth ``SampleDataset`` via :meth:`set_task`. Follows the PyHealth
    ``dataset.set_task()`` convention.

    Each *patient* corresponds to one cancer cell line identified by its
    COSMIC ID. The single *visit* for each patient represents its genomic
    measurement (gene expression + drug sensitivity).

    **Source data layout** (``data_dir``):

    .. code-block:: text

        exp_gdsc.csv         Binary gene expression  (1014 cell lines x 3000 genes)
        gdsc.csv             Binary drug sensitivity ( 846 cell lines x  260 drugs)
        drug_info_gdsc.csv   Drug metadata with target pathways and drug names
        exp_emb_gdsc.csv     Gene2Vec embeddings     (3001 x 200)

    Args:
        data_dir (str): Path to the directory containing the CSV files.
            Defaults to ``"originalData"``.
        seed (int): Random seed (used by downstream splitters). Defaults
            to ``2019``.

    Attributes:
        gene_names (List[str]): Ordered gene column names (3000 genes).
        drug_ids (List[str]): Ordered drug column names (260 drugs).
        drug_names (List[str]): Human-readable drug names resolved from
            ``drug_info_gdsc.csv`` ``Name`` column.
        drug_pathway_ids (List[int]): Integer pathway ID per drug (length 260).
        gene_embeddings (np.ndarray): Pre-trained Gene2Vec matrix of shape
            ``(3001, 200)``; row 0 is the zero-padding vector.

    Examples:
        >>> from pyhealth.datasets import GDSCDataset
        >>> dataset = GDSCDataset(data_dir="originalData")
        >>> dataset.summary()
        GDSC Dataset Summary
          Cell lines: 846
          ...
        >>> sample_ds = dataset.set_task()
        >>> len(sample_ds)
        846
        >>> sample = sample_ds[0]
        >>> "gene_indices" in sample
        True
    """

    dataset_name: str = "GDSC"

    def __init__(self, data_dir: str = "originalData", seed: int = 2019) -> None:
        self.data_dir = data_dir
        self.seed = seed
        self._load_data()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _drug_key(drug_id) -> str:
        """Return a canonical string key for a drug identifier.

        Numeric identifiers (``1001``, ``"1001"``, ``1001.0``) all map to
        ``"1001"`` so CSV headers and the ``drug_info`` index agree whatever
        dtype pandas inferred. Anything that is not integer-like falls back
        to ``str(drug_id)`` instead of raising.
        """
        try:
            return str(int(drug_id))
        except (TypeError, ValueError, OverflowError):
            return str(drug_id)

    @staticmethod
    def _first_positions(names) -> Dict:
        """Map each name to the position of its first occurrence in *names*.

        Mirrors ``list.index`` semantics while giving O(1) lookups.
        """
        positions: Dict = {}
        for pos, name in enumerate(names):
            positions.setdefault(name, pos)
        return positions

    # ------------------------------------------------------------------
    # Internal data loading
    # ------------------------------------------------------------------

    def _load_data(self) -> None:
        """Load and align all CSV data files.

        Raises:
            FileNotFoundError: If any of the four expected files is missing
                from ``data_dir``.
        """
        try:
            self.exp = pd.read_csv(
                os.path.join(self.data_dir, "exp_gdsc.csv"), index_col=0
            )
            self.tgt = pd.read_csv(
                os.path.join(self.data_dir, "gdsc.csv"), index_col=0
            )
            self.drug_info = pd.read_csv(
                os.path.join(self.data_dir, "drug_info_gdsc.csv"), index_col=0
            )
            self.gene_embeddings = np.loadtxt(
                os.path.join(self.data_dir, "exp_emb_gdsc.csv"), delimiter=","
            )
        except FileNotFoundError as exc:
            # Same actionable message style as CCLEDataset.
            raise FileNotFoundError(
                f"GDSC data files not found in {self.data_dir}. "
                "Ensure exp_gdsc.csv, gdsc.csv, drug_info_gdsc.csv, and "
                "exp_emb_gdsc.csv exist."
            ) from exc

        # Restrict to cell lines present in both expression and sensitivity data
        self.common_samples = sorted(set(self.exp.index) & set(self.tgt.index))
        self.exp = self.exp.loc[self.common_samples]
        self.tgt = self.tgt.loc[self.common_samples]

        # Gene and drug IDs (before mapping functions)
        self.gene_names: List[str] = list(self.exp.columns)
        self.drug_ids: List[str] = list(self.tgt.columns)

        self._build_pathway_mapping()
        self._build_id_to_name_mapping()

    def _build_pathway_mapping(self) -> None:
        """Map each drug column to an integer pathway ID.

        Drugs absent from ``drug_info_gdsc.csv``, drugs whose column header
        is not a numeric ID, and drugs with a missing ``Target pathway``
        value are all assigned the ``"Unknown"`` pathway.
        """
        id2pw: Dict[str, str] = {}
        for drug_id, pathway in zip(
            self.drug_info.index, self.drug_info["Target pathway"]
        ):
            # A NaN pathway would break sorting of mixed float/str values.
            id2pw[self._drug_key(drug_id)] = (
                "Unknown" if pd.isna(pathway) else pathway
            )

        self.drug_pathways: List[str] = [
            id2pw.get(self._drug_key(c), "Unknown") for c in self.tgt.columns
        ]
        unique_pathways = sorted(set(self.drug_pathways))
        self.pathway2id: Dict[str, int] = {
            pw: i for i, pw in enumerate(unique_pathways)
        }
        self.drug_pathway_ids: List[int] = [
            self.pathway2id[pw] for pw in self.drug_pathways
        ]

    def _build_id_to_name_mapping(self) -> None:
        """Map numeric drug IDs to drug names via the ``Name`` column.

        Drugs without an entry in ``drug_info_gdsc.csv`` receive the
        placeholder name ``UNKNOWN_<column header>``.
        """
        self.id_to_name: Dict[str, str] = {
            self._drug_key(drug_id): name
            for drug_id, name in zip(self.drug_info.index, self.drug_info["Name"])
        }

        self.drug_names: List[str] = [
            self.id_to_name.get(self._drug_key(drug_id), f"UNKNOWN_{drug_id}")
            for drug_id in self.drug_ids
        ]

    # ------------------------------------------------------------------
    # PyHealth task interface
    # ------------------------------------------------------------------

    def set_task(self, task=None):
        """Apply a task to produce a model-ready dataset.

        Follows the PyHealth ``dataset.set_task()`` convention: iterates
        over every cell line, calls ``task(patient)``, and collects the
        returned sample dicts into a dataset.

        Args:
            task: An object with ``task_name``, ``input_schema``,
                ``output_schema`` attributes and a
                ``__call__(patient) -> List[dict]`` method conforming to the
                :class:`~pyhealth.tasks.BaseTask` interface. If ``None``,
                :class:`~pyhealth.tasks.DrugSensitivityPredictionGDSC` is
                used with default settings.

        Returns:
            Dataset ready for a standard PyTorch
            :class:`~torch.utils.data.DataLoader`.
        """
        if task is None:
            # Imported lazily so the task module is only needed when the
            # default task is actually requested.
            from pyhealth.tasks.drug_sensitivity_gdsc import (
                DrugSensitivityPredictionGDSC,
            )
            task = DrugSensitivityPredictionGDSC()

        samples: List[Dict] = []
        for cell_line in self.common_samples:
            patient = {
                "patient_id": str(cell_line),
                "gene_expression": self.exp.loc[cell_line].values,
                "drug_sensitivity": self.tgt.loc[cell_line].values.astype(float),
                "drug_pathway_ids": self.drug_pathway_ids,
            }
            samples.extend(task(patient))

        return create_sample_dataset(
            samples=samples,
            input_schema=task.input_schema,
            output_schema=task.output_schema,
            dataset_name=self.dataset_name,
            task_name=task.task_name,
            in_memory=True,
        )

    # ------------------------------------------------------------------
    # Model-configuration accessors
    # ------------------------------------------------------------------

    def get_gene_embeddings(self) -> np.ndarray:
        """Return the pre-trained Gene2Vec embedding matrix.

        Returns:
            np.ndarray: Shape ``(3001, 200)``. Row 0 is the zero-padding
            vector; rows 1-3000 correspond to 1-indexed gene positions used
            in the ``gene_indices`` sample field.
        """
        return self.gene_embeddings

    def get_pathway_info(self) -> Dict:
        """Return drug pathway metadata for model initialisation.

        Returns:
            dict: Keys are ``pathway2id``, ``id2pathway``,
            ``num_pathways``, and ``drug_pathway_ids``.
        """
        return {
            "pathway2id": self.pathway2id,
            "id2pathway": {v: k for k, v in self.pathway2id.items()},
            "num_pathways": len(self.pathway2id),
            "drug_pathway_ids": self.drug_pathway_ids,
        }

    def get_overlap_drugs(self, other_dataset) -> Tuple[List[int], List[int], List[str]]:
        """Find drugs overlapping with another dataset by drug name.

        Compares by drug *name* rather than numeric ID so that GDSC (which
        uses numeric COSMIC drug IDs as column headers) and CCLE (which uses
        drug names) can be matched for cross-dataset evaluation.

        Args:
            other_dataset: Another dataset instance with a ``drug_names``
                attribute (or falling back to ``drug_ids``).

        Returns:
            tuple: ``(self_indices, other_indices, overlap_names)`` where each
            element of ``self_indices`` / ``other_indices`` is the integer
            column position (first occurrence) of the shared drug in the
            respective dataset, and ``overlap_names`` is the sorted list of
            shared drug name strings.
        """
        other_names = (
            other_dataset.drug_names
            if hasattr(other_dataset, "drug_names")
            else other_dataset.drug_ids
        )
        # Position tables replace repeated ``list.index`` scans.
        self_pos = self._first_positions(self.drug_names)
        other_pos = self._first_positions(other_names)
        overlap = sorted(set(self_pos) & set(other_pos))

        self_indices = [self_pos[d] for d in overlap]
        other_indices = [other_pos[d] for d in overlap]
        return self_indices, other_indices, overlap

    # ------------------------------------------------------------------
    # Utilities
    # ------------------------------------------------------------------

    def summary(self) -> None:
        """Print dataset summary statistics to stdout."""
        total_pairs = len(self.common_samples) * len(self.tgt.columns)
        tested = self.tgt.notnull().sum().sum()
        sensitive = (self.tgt == 1).sum().sum()
        resistant = (self.tgt == 0).sum().sum()

        print("GDSC Dataset Summary")
        print(f" Cell lines: {len(self.common_samples)}")
        print(f" Drugs: {len(self.tgt.columns)}")
        print(f" Genes: {len(self.exp.columns)}")
        print(f" Active genes/cell:{int(self.exp.sum(axis=1).mean())}")
        print(f" Total pairs: {total_pairs}")
        print(f" Tested pairs: {int(tested)} ({tested / total_pairs:.1%})")
        print(
            f" Missing pairs: {int(total_pairs - tested)}"
            f" ({(total_pairs - tested) / total_pairs:.1%})"
        )
        print(f" Sensitive: {int(sensitive)} ({sensitive / tested:.1%})")
        print(f" Resistant: {int(resistant)} ({resistant / tested:.1%})")
        print(f" Pathways: {len(self.pathway2id)}")
        print(f" Embedding shape: {self.gene_embeddings.shape}")
+ + Tensor shapes through the forward pass:: + + gene_indices : (B, G) B = batch, G = max active genes + E : (B, G, emb) gene embeddings + E_exp : (B, D, G, emb) expanded across D drugs + Ep : (1, D, 1, q) pathway embeddings, broadcast + H : (B, D, G, q) tanh(W·e_gene + e_pathway) + A : (B, D, G, h) multi-head scores, softmax over G + A_sum : (B, D, G, 1) summed across heads + out : (B, D, emb) weighted gene representation per drug + + Args: + gene_embeddings (np.ndarray): Pre-trained Gene2Vec matrix ``(3001, 200)``. + num_pathways (int): Number of unique drug target pathways. + embedding_dim (int): Gene embedding dimension. Default: ``200``. + attention_size (int): Attention hidden dimension (q). Default: ``128``. + attention_head (int): Number of attention heads (h). Default: ``8``. + dropout_rate (float): Dropout probability. Default: ``0.6``. + use_attention (bool): Enable attention mechanism. Default: ``True``. + use_cntx_attn (bool): Enable contextual (pathway) conditioning. + Default: ``True``. If ``False``, runs self-attention only + (reproduces the SADRE ablation from the paper). + """ + + def __init__( + self, + gene_embeddings: np.ndarray, + num_pathways: int, + embedding_dim: int = 200, + attention_size: int = 128, + attention_head: int = 8, + dropout_rate: float = 0.6, + use_attention: bool = True, + use_cntx_attn: bool = True, + freeze_gene_emb: bool = True, + ) -> None: + super().__init__() + + self.use_attention = use_attention + self.use_cntx_attn = use_cntx_attn + + # Pretrained gene embeddings; frozen by default. + # Set freeze_gene_emb=False to fine-tune them (CADRE∆pretrain variant). 
class DrugDecoder(nn.Module):
    """Dot-product collaborative-filtering head over learned drug embeddings.

    Each drug owns a learned embedding vector and a scalar bias. The logit
    for a (cell line, drug) pair is the inner product of the encoder's
    drug-specific cell representation with that drug's embedding, plus the
    drug's bias.

    Args:
        num_drugs (int): Number of drugs (260 for GDSC).
        embedding_dim (int): Embedding dimension. Default: ``200``.
    """

    def __init__(self, num_drugs: int, embedding_dim: int = 200) -> None:
        super().__init__()
        # Embedding table is created before the bias so parameter order
        # (and therefore seeded initialisation) stays the same.
        self.layer_emb_drg = nn.Embedding(
            num_embeddings=num_drugs,
            embedding_dim=embedding_dim,
        )
        self.drg_bias = nn.Parameter(torch.zeros(num_drugs))

    def forward(
        self, cell_repr: torch.Tensor, drg_ids: torch.Tensor
    ) -> torch.Tensor:
        """Score every drug for every cell line in the batch.

        Args:
            cell_repr (torch.Tensor): Shape ``(B, D, emb)`` from encoder.
            drg_ids (torch.Tensor): Shape ``(1, D)`` — drug index range.

        Returns:
            torch.Tensor: Shape ``(B, D)`` — raw logits.
        """
        batch_size = cell_repr.shape[0]
        # Look the drug vectors up once and view them across the batch.
        drug_vectors = self.layer_emb_drg(drg_ids).expand(batch_size, -1, -1)
        # Inner product over the embedding axis.
        scores = torch.sum(cell_repr * drug_vectors, dim=2)
        return scores + self.drg_bias.unsqueeze(0)
def __init__(
    self,
    gene_embeddings: np.ndarray,
    num_drugs: int,
    num_pathways: int,
    drug_pathway_ids: List[int],
    embedding_dim: int = 200,
    attention_size: int = 128,
    attention_head: int = 8,
    dropout_rate: float = 0.6,
    use_attention: bool = True,
    use_cntx_attn: bool = True,
    freeze_gene_emb: bool = True,
) -> None:
    """Assemble the encoder/decoder pair and register the drug index buffers.

    Args:
        gene_embeddings (np.ndarray): Pre-trained gene embedding matrix,
            forwarded to :class:`ExpEncoder`.
        num_drugs (int): Number of drugs scored by the decoder.
        num_pathways (int): Number of distinct pathway ids.
        drug_pathway_ids (List[int]): Pathway id of each drug; must hold
            exactly ``num_drugs`` entries. Lists, NumPy arrays and tensors
            are accepted.
        embedding_dim (int): Gene/drug embedding dimension.
        attention_size (int): Attention hidden dimension.
        attention_head (int): Number of attention heads.
        dropout_rate (float): Dropout probability.
        use_attention (bool): Forwarded to the encoder.
        use_cntx_attn (bool): Forwarded to the encoder.
        freeze_gene_emb (bool): Forwarded to the encoder.

    Raises:
        ValueError: If ``drug_pathway_ids`` does not contain exactly
            ``num_drugs`` entries.
    """
    # CADRE takes GDSC-specific constructor args rather than a generic
    # SampleDataset, so we pass dataset=None to BaseModel.
    super().__init__(dataset=None)

    # as_tensor accepts lists, NumPy arrays and tensors alike; clone() keeps
    # the registered buffer independent of the caller's storage.
    pathway_ids = (
        torch.as_tensor(drug_pathway_ids, dtype=torch.long).reshape(-1).clone()
    )
    # Fail here with a clear message instead of with a broadcasting error
    # deep inside the first forward pass.
    if pathway_ids.numel() != num_drugs:
        raise ValueError(
            "drug_pathway_ids must have one entry per drug: expected "
            f"{num_drugs}, got {pathway_ids.numel()}"
        )

    self.num_drugs = num_drugs
    self.embedding_dim = embedding_dim

    # Buffers follow the module across devices without being trained.
    self.register_buffer("ptw_ids", pathway_ids.unsqueeze(0))  # (1, D)
    self.register_buffer("drg_ids", torch.arange(num_drugs).unsqueeze(0))  # (1, D)

    self.encoder = ExpEncoder(
        gene_embeddings=gene_embeddings,
        num_pathways=num_pathways,
        embedding_dim=embedding_dim,
        attention_size=attention_size,
        attention_head=attention_head,
        dropout_rate=dropout_rate,
        use_attention=use_attention,
        use_cntx_attn=use_cntx_attn,
        freeze_gene_emb=freeze_gene_emb,
    )
    self.decoder = DrugDecoder(num_drugs=num_drugs, embedding_dim=embedding_dim)
    # reduction="none" keeps per-element losses so forward() can mask
    # untested (cell line, drug) pairs before averaging.
    self.loss_fn = nn.BCEWithLogitsLoss(reduction="none")
def collate_fn(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """Collate a list of sample dicts into a batched tensor dict.

    Right-pads ``gene_indices`` with the padding index (0) up to the longest
    sequence in the batch and stacks ``labels`` and ``mask`` row by row.
    Each field may be a list, tuple, NumPy array or tensor.

    Args:
        batch (List[dict]): Sample dicts, each with the keys
            ``gene_indices``, ``labels``, ``mask`` and ``patient_id``.

    Returns:
        dict: Keys ``gene_indices`` (int64 tensor, ``(B, G_max)``),
        ``labels`` (float32 tensor), ``mask`` (float32 tensor) and
        ``patient_ids`` (list with one id per sample).

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if not batch:
        raise ValueError("collate_fn received an empty batch")

    # Normalise to 1-D int64 tensors first: list concatenation such as
    # ``gi + [0] * k`` silently means something else for arrays and tensors.
    genes = [
        torch.as_tensor(sample["gene_indices"], dtype=torch.long).reshape(-1)
        for sample in batch
    ]
    max_genes = max(g.numel() for g in genes)

    # Start from an all-padding (0) matrix and copy each sequence in place.
    gene_indices = torch.zeros((len(batch), max_genes), dtype=torch.long)
    for row, g in zip(gene_indices, genes):
        row[: g.numel()] = g

    return {
        "gene_indices": gene_indices,
        "labels": torch.stack(
            [torch.as_tensor(s["labels"], dtype=torch.float32) for s in batch]
        ),
        "mask": torch.stack(
            [torch.as_tensor(s["mask"], dtype=torch.float32) for s in batch]
        ),
        "patient_ids": [s["patient_id"] for s in batch],
    }
def forward(
    self, gene_indices: torch.Tensor, drug_embeddings: torch.Tensor
) -> torch.Tensor:
    """Attend over each cell line's genes with one query per drug.

    Args:
        gene_indices (torch.Tensor): Shape ``(B, G)``; index 0 marks padding.
        drug_embeddings (torch.Tensor): Shape ``(D, emb)``, one row per drug.

    Returns:
        torch.Tensor: Shape ``(B, D, embedding_dim)``.
    """
    batch_size, num_genes = gene_indices.shape
    num_drugs = drug_embeddings.shape[0]
    heads, head_dim = self.num_heads, self.d_k

    gene_emb = self.layer_emb(gene_indices)  # (B, G, emb)

    # Keys: one vector per gene and head -> (B, H, G, dk)
    keys = self.key_proj(gene_emb)
    keys = keys.view(batch_size, num_genes, heads, head_dim).transpose(1, 2)

    # Queries: one vector per drug and head, shared by the batch -> (B, H, D, dk)
    queries = self.query_proj(drug_embeddings).view(num_drugs, heads, head_dim)
    queries = queries.transpose(0, 1).unsqueeze(0).expand(batch_size, -1, -1, -1)

    # Padded gene slots must receive no attention mass.
    is_padding = (gene_indices == 0)[:, None, None, :]  # (B, 1, 1, G)
    scores = torch.matmul(queries, keys.transpose(-2, -1)) / (head_dim ** 0.5)
    scores = scores.masked_fill(is_padding, float("-inf"))
    # A row that is padding only softmaxes to NaN; map that to zero weights.
    attn = torch.nan_to_num(F.softmax(scores, dim=-1), nan=0.0)  # (B, H, D, G)

    # Head-averaged copy kept for inspection, detached from the graph.
    self.attention_weights = attn.mean(dim=1).detach()  # (B, D, G)

    # The key vectors double as values; merge the heads back afterwards.
    merged = torch.matmul(attn, keys).transpose(1, 2)  # (B, D, H, dk)
    merged = merged.reshape(batch_size, num_drugs, heads * head_dim)

    return self.layer_dropout(self.W_O(merged))  # (B, D, emb)
def forward(
    self,
    gene_indices: torch.Tensor,
    labels: Optional[torch.Tensor] = None,
    mask: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
    """Forward pass.

    Args:
        gene_indices (torch.Tensor): Shape ``(B, G)``.
        labels (torch.Tensor, optional): Shape ``(B, D)`` binary labels.
            Required to compute ``loss``.
        mask (torch.Tensor, optional): Shape ``(B, D)`` — 1 where the drug
            was tested for that cell line, 0 otherwise. Required to compute
            ``loss``.

    Returns:
        dict: Always contains:

        * ``"logits"`` / ``"logit"`` — raw logits ``(B, D)``
        * ``"probs"`` / ``"y_prob"`` — sigmoid probabilities ``(B, D)``

        When both ``labels`` and ``mask`` are given, also contains:

        * ``"loss"`` — masked BCE loss (scalar)
        * ``"y_true"`` — same as ``labels``

        ``"attention"`` is added whenever the encoder has recorded
        attention weights.
    """
    # The decoder's drug embeddings double as the attention queries.
    drug_emb = self.decoder.layer_emb_drg(self.drg_ids).squeeze(0)  # (D, emb)
    cell_repr = self.encoder(gene_indices, drug_emb)  # (B, D, emb)
    logits = self.decoder(cell_repr, self.drg_ids)  # (B, D)
    probs = torch.sigmoid(logits)

    # Expose the same keys as CADRE.forward ("logit" / "y_prob" aliases) so
    # code written against one model also works with the other.
    result: Dict[str, torch.Tensor] = {
        "logits": logits,
        "probs": probs,
        "logit": logits,
        "y_prob": probs,
    }

    if labels is not None and mask is not None:
        per_element = self.loss_fn(logits, labels.float())
        # Average over tested pairs only; the epsilon guards an all-zero mask.
        result["loss"] = (per_element * mask).sum() / (mask.sum() + 1e-5)
        result["y_true"] = labels

    if self.encoder.attention_weights is not None:
        result["attention"] = self.encoder.attention_weights

    return result
class DrugSensitivityPredictionCCLE(BaseTask):
    """Turn one CCLE cell line into a multi-label drug sensitivity sample.

    Each *patient* record describes a single cancer cell line. Calling the
    task on it yields a one-element list whose sample holds the positions of
    the expressed genes, the binary sensitivity label of every drug, and a
    mask telling which drugs were actually tested.

    The class exposes the usual :class:`~pyhealth.tasks.BaseTask` surface:
    ``task_name``, ``input_schema``, ``output_schema`` and
    ``__call__(patient) -> List[dict]``.

    **Input** patient dict (produced by
    :class:`~pyhealth.datasets.CCLEDataset`):

    .. code-block:: text

        patient_id        str                Cell-line identifier
        gene_expression   np.ndarray[int]    Binary indicator vector
        drug_sensitivity  np.ndarray[float]  Binary labels; NaN = untested
        drug_pathway_ids  List[int]          Integer pathway ID per drug

    **Output** sample dict (one per cell line):

    .. code-block:: text

        patient_id        str        Cell-line identifier
        visit_id          str        Copy of patient_id
        gene_indices      List[int]  1-indexed positions of expressed genes
        labels            List[int]  Binary labels (0 where untested)
        mask              List[int]  1 = drug was tested, 0 = missing
        drug_pathway_ids  List[int]  Integer pathway ID per drug

    Examples:
        >>> import numpy as np
        >>> from pyhealth.tasks import DrugSensitivityPredictionCCLE
        >>> task = DrugSensitivityPredictionCCLE()
        >>> gene_expr = np.zeros(3000, dtype=int)
        >>> gene_expr[[10, 42]] = 1
        >>> patient = {
        ...     "patient_id": "MCF7",
        ...     "gene_expression": gene_expr,
        ...     "drug_sensitivity": np.array([1.0, np.nan, 0.0]),
        ...     "drug_pathway_ids": [0, 1, 2],
        ... }
        >>> sample = task(patient)[0]
        >>> sample["gene_indices"]  # 1-indexed
        [11, 43]
        >>> sample["mask"]
        [1, 0, 1]
    """

    task_name: str = "drug_sensitivity_prediction"

    input_schema: Dict = {
        "gene_indices": "sequence",
        "drug_pathway_ids": "sequence",
    }
    output_schema: Dict = {
        "labels": "raw",
        "mask": "raw",
    }

    def __init__(self) -> None:
        super().__init__()

    def __call__(self, patient: Dict) -> List[Dict]:
        """Build the single sample dict for one CCLE cell line.

        Args:
            patient (dict): Must contain ``patient_id``, ``gene_expression``,
                ``drug_sensitivity``, and ``drug_pathway_ids``.

        Returns:
            List[dict]: Single-element list.
        """
        expression = np.asarray(patient["gene_expression"])
        # Shift by one so that position 0 is never produced.
        active_genes = np.nonzero(expression == 1)[0] + 1

        response = np.asarray(patient["drug_sensitivity"], dtype=float)
        was_tested = np.logical_not(np.isnan(response))
        # Untested drugs get label 0; the mask tells them apart from true 0s.
        binary_labels = np.nan_to_num(response, nan=0.0).astype(int)

        cell_line_id = patient["patient_id"]
        sample = {
            "patient_id": cell_line_id,
            "visit_id": cell_line_id,
            "gene_indices": active_genes.tolist(),
            "labels": binary_labels.tolist(),
            "mask": was_tested.astype(int).tolist(),
            "drug_pathway_ids": list(patient["drug_pathway_ids"]),
        }
        return [sample]
+ + Transforms a single GDSC *patient* record (one cancer cell line) into a + sample dict containing the active-gene index sequence, binary drug + sensitivity labels, and a tested-drug mask. + + Follows the PyHealth :class:`~pyhealth.tasks.BaseTask` interface: + ``task_name``, ``input_schema``, ``output_schema``, and a callable + ``__call__(patient) -> List[dict]``. + + **Input** patient dict (produced by + :class:`~pyhealth.datasets.GDSCDataset`): + + .. code-block:: text + + patient_id str COSMIC cell-line identifier + gene_expression np.ndarray[int] Binary indicator vector, shape (3000,) + drug_sensitivity np.ndarray[float] Binary labels; NaN = untested, shape (260,) + drug_pathway_ids List[int] Integer pathway ID per drug, length 260 + + **Output** sample dict (one per cell line): + + .. code-block:: text + + patient_id str COSMIC cell-line identifier + visit_id str Same as patient_id (one record per line) + gene_indices List[int] 1-indexed active gene positions (~1500 entries) + labels List[int] Binary drug sensitivity labels (260,) + mask List[int] 1 = drug was tested, 0 = missing (260,) + drug_pathway_ids List[int] Integer pathway ID per drug (260,) + + The ``gene_indices`` field uses 1-based indexing so that index 0 is + reserved as the padding token in the Gene2Vec embedding table. + + Examples: + >>> import numpy as np + >>> from pyhealth.tasks import DrugSensitivityPredictionGDSC + >>> task = DrugSensitivityPredictionGDSC() + >>> gene_expr = np.zeros(3000, dtype=int) + >>> gene_expr[[1, 5, 99]] = 1 # 3 active genes + >>> drug_sens = np.array([1.0, 0.0, np.nan]) # 2 tested, 1 missing + >>> patient = { + ... "patient_id": "COSMIC.906826", + ... "gene_expression": gene_expr, + ... "drug_sensitivity": drug_sens, + ... "drug_pathway_ids": [0, 1, 2], + ... 
} + >>> samples = task(patient) + >>> len(samples) + 1 + >>> samples[0]["gene_indices"] # 1-indexed + [2, 6, 100] + >>> samples[0]["labels"] + [1, 0, 0] + >>> samples[0]["mask"] + [1, 1, 0] + """ + + task_name: str = "drug_sensitivity_prediction" + + # Processor type strings tell PyHealth processors how to handle each field. + # "sequence" = variable-length integer list; "multilabel" = fixed-length binary vector. + input_schema: Dict = { + "gene_indices": "sequence", + "drug_pathway_ids": "sequence", + } + output_schema: Dict = { + "labels": "raw", + "mask": "raw", + } + + def __init__(self) -> None: + super().__init__() + + def __call__(self, patient: Dict) -> List[Dict]: + """Extract one sample dict from a GDSC cell-line patient record. + + Args: + patient (dict): Must contain ``patient_id``, ``gene_expression`` + (shape ``(3000,)``), ``drug_sensitivity`` (shape ``(260,)``), + and ``drug_pathway_ids`` (length ``260``). + + Returns: + List[dict]: Single-element list; the GDSC data model has one + record per cell line. + """ + gene_vec = np.asarray(patient["gene_expression"]) + # 1-indexed: embedding row 0 is reserved for padding + gene_indices = (np.where(gene_vec == 1)[0] + 1).tolist() + + sensitivity = np.asarray(patient["drug_sensitivity"], dtype=float) + mask = (~np.isnan(sensitivity)).astype(int).tolist() + labels = np.nan_to_num(sensitivity, nan=0.0).astype(int).tolist() + + return [ + { + "patient_id": patient["patient_id"], + "visit_id": patient["patient_id"], + "gene_indices": gene_indices, + "labels": labels, + "mask": mask, + "drug_pathway_ids": list(patient["drug_pathway_ids"]), + } + ] diff --git a/tests/core/test_cadre_model.py b/tests/core/test_cadre_model.py new file mode 100644 index 000000000..fc465e7fb --- /dev/null +++ b/tests/core/test_cadre_model.py @@ -0,0 +1,100 @@ +"""Unit tests for CADRE and CADREDotAttn models. + +Uses tiny synthetic inputs — no real data files, no network access. 
def test_cadre_ablations():
    """Disabling attention, or only its pathway context, keeps the output shape."""
    genes = _gene_indices()
    expected_shape = (BATCH, N_DRUGS)

    for disabled_flag in ("use_attention", "use_cntx_attn"):
        ablated = CADRE(
            gene_embeddings=_gene_emb(),
            num_drugs=N_DRUGS,
            num_pathways=N_PATHWAYS,
            drug_pathway_ids=_pathway_ids(),
            embedding_dim=EMB_DIM,
            attention_size=8,
            attention_head=2,
            dropout_rate=0.0,
            **{disabled_flag: False},
        ).eval()
        probs = ablated(genes)["probs"]
        assert probs.shape == expected_shape
def _make_ccle_data(tmp_dir: str) -> None:
    """Write a small synthetic set of CCLE-style files into ``tmp_dir``.

    Four files are produced: binary expression, drug sensitivity with some
    NaN entries, per-drug pathway info, and a gene embedding matrix. Random
    draws come from the module-level ``RNG`` in a fixed order.
    """

    def target(filename: str) -> str:
        return os.path.join(tmp_dir, filename)

    cells = [f"CCLE.{i}" for i in range(N_CELL_LINES)]
    genes = [str(g) for g in range(1, N_GENES + 1)]
    drugs = SHARED_DRUG_NAMES + CCLE_ONLY_NAMES
    n_drugs = len(drugs)

    # Binary gene expression, one row per cell line.
    expression = RNG.randint(0, 2, size=(N_CELL_LINES, N_GENES))
    pd.DataFrame(expression, index=cells, columns=genes).to_csv(
        target("exp_ccle.csv")
    )

    # Binary sensitivity labels, with roughly a fifth blanked out as NaN.
    sensitivity = RNG.randint(0, 2, size=(N_CELL_LINES, n_drugs)).astype(float)
    untested = RNG.rand(N_CELL_LINES, n_drugs) < 0.2
    sensitivity[untested] = np.nan
    pd.DataFrame(sensitivity, index=cells, columns=drugs).to_csv(
        target("ccle.csv")
    )

    # Pathways are assigned round-robin over the drug list.
    pathways = ["PathwayA", "PathwayB", "PathwayC"]
    assigned = [pathways[pos % N_PATHWAYS] for pos in range(n_drugs)]
    pd.DataFrame({"Target pathway": assigned}, index=drugs).to_csv(
        target("drug_info_ccle.csv")
    )

    # One embedding row per gene plus one extra row.
    embeddings = RNG.randn(N_GENES + 1, EMB_DIM)
    np.savetxt(target("exp_emb_ccle.csv"), embeddings, delimiter=",")
def test_cross_dataset_overlap(gdsc, ccle):
    """Drugs shared by GDSC and CCLE are found, sorted, and indexed correctly."""
    gdsc_idx, ccle_idx, names = gdsc.get_overlap_drugs(ccle)
    shared = set(names)

    # Exactly the shared drugs, reported in sorted order.
    assert shared == set(SHARED_DRUG_NAMES)
    assert sorted(names) == names

    # Each index pair resolves to the same drug name in both datasets.
    for position, pair in enumerate(zip(gdsc_idx, ccle_idx)):
        gdsc_pos, ccle_pos = pair
        expected_name = names[position]
        assert gdsc.drug_names[gdsc_pos] == expected_name
        assert ccle.drug_names[ccle_pos] == expected_name

    # Asking from the CCLE side yields the same set of drugs.
    reverse_names = ccle.get_overlap_drugs(gdsc)[2]
    assert shared == set(reverse_names)
def _make_synthetic_data(tmp_dir: str) -> None:
    """Write minimal CSV files mimicking the GDSC originalData layout.

    Random draws come from the module-level ``RNG`` in a fixed order:
    expression, sensitivity, NaN mask, embeddings.
    """

    def target(filename: str) -> str:
        return os.path.join(tmp_dir, filename)

    drug_numbers = list(range(1, N_DRUGS + 1))
    cells = [f"COSMIC.{i}" for i in range(N_CELL_LINES)]
    genes = [str(g) for g in range(1, N_GENES + 1)]
    drug_columns = [str(d) for d in drug_numbers]

    # Binary gene expression (N_CELL_LINES x N_GENES)
    expression = RNG.randint(0, 2, size=(N_CELL_LINES, N_GENES))
    pd.DataFrame(expression, index=cells, columns=genes).to_csv(
        target("exp_gdsc.csv")
    )

    # Drug sensitivity with some NaN entries
    sensitivity = RNG.randint(0, 2, size=(N_CELL_LINES, N_DRUGS)).astype(float)
    untested = RNG.rand(N_CELL_LINES, N_DRUGS) < 0.2
    sensitivity[untested] = np.nan
    pd.DataFrame(sensitivity, index=cells, columns=drug_columns).to_csv(
        target("gdsc.csv")
    )

    # Drug info with pathway metadata and a Name column for id-to-name mapping
    pathways = ["PathwayA", "PathwayB", "PathwayC"]
    info = {
        "Name": [f"Drug{d}" for d in drug_numbers],
        "Target pathway": [pathways[pos % N_PATHWAYS] for pos in range(N_DRUGS)],
    }
    pd.DataFrame(info, index=drug_numbers).to_csv(target("drug_info_gdsc.csv"))

    # Gene2Vec embeddings (N_GENES + 1 rows: row 0 is padding)
    embeddings = RNG.randn(N_GENES + 1, EMB_DIM)
    np.savetxt(target("exp_emb_gdsc.csv"), embeddings, delimiter=",")
len(sample["mask"]) == N_DRUGS + assert all(v in (0, 1) for v in sample["mask"]) + + +def test_get_overlap_drugs(dataset): + self_idx, other_idx, names = dataset.get_overlap_drugs(dataset) + assert len(names) == N_DRUGS + assert names == sorted(names) + assert self_idx == other_idx + assert all(0 <= i < N_DRUGS for i in self_idx) + + +def test_summary_runs(dataset, capsys): + dataset.summary() + assert "GDSC Dataset Summary" in capsys.readouterr().out