diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 8d9a59d21..988ccb045 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -244,5 +244,6 @@ Available Datasets datasets/pyhealth.datasets.ClinVarDataset datasets/pyhealth.datasets.COSMICDataset datasets/pyhealth.datasets.TCGAPRADDataset + datasets/pyhealth.datasets.TCGARNASeqDataset datasets/pyhealth.datasets.splitter datasets/pyhealth.datasets.utils diff --git a/docs/api/datasets/pyhealth.datasets.TCGARNASeqDataset.rst b/docs/api/datasets/pyhealth.datasets.TCGARNASeqDataset.rst new file mode 100644 index 000000000..ffd0a1d36 --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.TCGARNASeqDataset.rst @@ -0,0 +1,13 @@ +pyhealth.datasets.TCGARNASeqDataset +=================================== + +Dataset is available at https://portal.gdc.cancer.gov/ + +The TCGA RNA-Seq dataset contains pan-cancer bulk RNA sequencing data along with associated clinical information. + +This dataset supports preprocessing and tokenization for downstream modeling tasks, following the methodology used in BulkRNABert. + +.. autoclass:: pyhealth.datasets.TCGARNASeqDataset + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/docs/api/models.rst b/docs/api/models.rst index 7c3ac7c4b..c56c08f0a 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -204,5 +204,6 @@ API Reference models/pyhealth.models.VisionEmbeddingModel models/pyhealth.models.TextEmbedding models/pyhealth.models.BIOT + models/pyhealth.models.BulkRNABert models/pyhealth.models.unified_multimodal_embedding_docs models/pyhealth.models.califorest diff --git a/docs/api/models/pyhealth.models.BulkRNABert.rst b/docs/api/models/pyhealth.models.BulkRNABert.rst new file mode 100644 index 000000000..56baaa7f8 --- /dev/null +++ b/docs/api/models/pyhealth.models.BulkRNABert.rst @@ -0,0 +1,11 @@ +pyhealth.models.BulkRNABert +=================================== + +The separate callable BulkRNABert model. + +.. 
autoclass:: pyhealth.models.bulk_rna_bert.BulkRNABert + :members: + :undoc-members: + :show-inheritance: + +.. autofunction:: pyhealth.models.bulk_rna_bert.cox_partial_likelihood_loss \ No newline at end of file diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 23a4e06e5..ab79501cf 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -230,3 +230,4 @@ Available Tasks Mutation Pathogenicity (COSMIC) Cancer Survival Prediction (TCGA) Cancer Mutation Burden (TCGA) + TCGA RNA-seq (BulkRNABert) diff --git a/docs/api/tasks/pyhealth.tasks.tcga_rnaseq_tasks.rst b/docs/api/tasks/pyhealth.tasks.tcga_rnaseq_tasks.rst new file mode 100644 index 000000000..b04f3906e --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.tcga_rnaseq_tasks.rst @@ -0,0 +1,12 @@ +pyhealth.tasks.tcga_rnaseq_tasks +================================= + +.. autoclass:: pyhealth.tasks.tcga_rnaseq_tasks.TCGACancerTypeTask + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: pyhealth.tasks.tcga_rnaseq_tasks.TCGASurvivalTask + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/tcga_rnaseq_cancer_type_bulk_rna_bert.py b/examples/tcga_rnaseq_cancer_type_bulk_rna_bert.py new file mode 100644 index 000000000..4bad1091a --- /dev/null +++ b/examples/tcga_rnaseq_cancer_type_bulk_rna_bert.py @@ -0,0 +1,359 @@ +"""BulkRNABert ablation study: TCGARNASeq + cancer type classification. + +Paper: Gélard et al., "BulkRNABert: Cancer prognosis from bulk RNA-seq +based language models", bioRxiv 2024. + +Ablations (extensions beyond the paper text): + 1. Binning resolution: B in {32, 64, 128} (paper fixes B=64 without sweep) + 2. Frozen vs IA3 vs full fine-tuning (adds full FT vs paper's frozen/IA3) + 3. Cox survival loss training curve (sanity check on synthetic survival) + 4. Built-in :class:`~pyhealth.models.MLP` on flattened gene bins vs + :class:`~pyhealth.models.bulk_rna_bert.BulkRNABert` (rubric baseline) + +Runs entirely on synthetic data. 
def _ensure_models_pkg() -> None:
    """Register ``pyhealth.models`` without executing ``models/__init__.py``.

    A bare stand-in module is installed in ``sys.modules`` under the name
    ``pyhealth.models``. Its ``__path__`` points at ``<repo>/pyhealth/models``
    so that individual submodules (for example
    ``pyhealth.models.base_model``) can still be imported one at a time.
    ``BaseModel`` is then re-exported on the stand-in module.

    The call is a no-op when ``pyhealth.models`` is already present in
    ``sys.modules``.
    """
    import importlib
    import sys
    import types
    from pathlib import Path

    # Already registered (for real, or by an earlier call): leave it alone.
    if "pyhealth.models" in sys.modules:
        return
    # Imported for its side effect only; presumably this guarantees the parent
    # package is loaded before a child module is registered -- TODO confirm.
    import pyhealth

    # ``parents[1]`` is the directory one level above this script's folder,
    # which is taken to be the repository root.
    repo = Path(__file__).resolve().parents[1]
    pkg = types.ModuleType("pyhealth.models")
    # A module with ``__path__`` is treated as a package by the import system.
    pkg.__path__ = [str(repo / "pyhealth" / "models")]
    # Must be registered before the submodule import below resolves its parent.
    sys.modules["pyhealth.models"] = pkg
    bm = importlib.import_module("pyhealth.models.base_model")
    pkg.BaseModel = bm.BaseModel
def _load_tokens(root, n_bins=64):
    """Preprocess the raw CSVs under ``root`` and load them as tensors.

    Runs ``TCGARNASeqDataset._prepare_metadata`` to write the tokenized
    RNA-seq and clinical CSVs into ``root``, then reads the tokenized table
    back.

    Args:
        root: Directory containing ``rna_seq.csv`` (and optionally
            ``clinical.csv``); the processed CSVs are written here as well.
        n_bins: Number of expression bins used for tokenization.

    Returns:
        Tuple ``(token_ids, labels, gene_cols)``: a ``torch.long`` tensor
        with one row per sample and one column per gene, a ``torch.long``
        tensor of indices into ``COHORTS``, and the gene column names.

    Raises:
        ValueError: If the tokenized table contains a cohort that is not
            listed in ``COHORTS``.
    """
    from pyhealth.datasets.tcga_rnaseq import TCGARNASeqDataset

    rnaseq_out = os.path.join(root, "tcga_rnaseq_tokenized-pyhealth.csv")
    clinical_out = os.path.join(root, "tcga_rnaseq_clinical-pyhealth.csv")
    TCGARNASeqDataset._prepare_metadata(root, n_bins, None, rnaseq_out, clinical_out)

    df = pd.read_csv(rnaseq_out)
    gene_cols = [c for c in df.columns if c not in ("patient_id", "cohort")]
    token_ids = torch.tensor(df[gene_cols].to_numpy(), dtype=torch.long)

    cohort_to_idx = {c: i for i, c in enumerate(COHORTS)}
    # Falling back to index 0 would silently relabel unknown cohorts as the
    # first class, so refuse to continue instead.
    unknown = sorted(set(df["cohort"]) - cohort_to_idx.keys(), key=str)
    if unknown:
        raise ValueError(
            f"Unknown cohort label(s) {unknown}; expected one of {COHORTS}"
        )
    labels = torch.tensor(
        [cohort_to_idx[c] for c in df["cohort"]], dtype=torch.long
    )
    return token_ids, labels, gene_cols
def ablation_finetuning_strategy():
    """Compare frozen backbone, IA3, and full fine-tuning.

    Each strategy trains a freshly built classifier for 20 Adam steps on the
    synthetic cohort data and prints one summary line.

    Returns:
        Dict mapping strategy name to ``(final_loss, n_trainable_params)``.
        ``final_loss`` is the loss of the last forward pass, i.e. it is
        measured before the final optimizer step is applied.
    """
    model_cls = _bulk_rna_bert().BulkRNABert

    with tempfile.TemporaryDirectory() as root:
        make_synthetic_data(root)
        token_ids, labels, gene_cols = _load_tokens(root)

        # (name, use_ia3, freeze_encoder)
        plans = (
            ("frozen_backbone", False, True),
            ("ia3_finetuning", True, True),
            ("full_finetuning", False, False),
        )

        results = {}
        for name, use_ia3, freeze_encoder in plans:
            model = model_cls(
                dataset=None,
                n_genes=len(gene_cols),
                n_bins=64,
                embedding_dim=64,
                n_layers=2,
                n_heads=4,
                ffn_dim=128,
                dropout=0.0,
                mlp_hidden=(32,),
                mode="classification",
                n_classes=len(COHORTS),
                use_ia3=use_ia3,
            )
            if freeze_encoder:
                # Freeze the encoder stack and both embedding tables.
                backbone = (
                    model.encoder,
                    model.gene_embedding,
                    model.expr_embedding,
                )
                for module in backbone:
                    for weight in module.parameters():
                        weight.requires_grad = False

            trainable = [w for w in model.parameters() if w.requires_grad]
            n_params = sum(w.numel() for w in trainable)
            optimizer = optim.Adam(trainable, lr=1e-3)

            model.train()
            loss = None
            for _ in range(20):
                optimizer.zero_grad()
                loss = model(token_ids=token_ids, cancer_type=labels)["loss"]
                loss.backward()
                optimizer.step()

            final_loss = loss.item()
            results[name] = (final_loss, n_params)
            print(
                f" {name:20s} | trainable={n_params:6d} "
                f"| loss={final_loss:.4f}"
            )

    return results
def ablation_mlp_vs_transformer():
    """Compare PyHealth ``MLP`` on flattened bins vs ``BulkRNABert``.

    Both models are trained for 15 Adam steps on the same synthetic samples.
    The cohort names given to the MLP are derived from the very ``labels``
    tensor fed to the transformer, so the two models are guaranteed to see
    identical targets regardless of how the synthetic data is ordered.

    Returns:
        Tuple ``(mlp_loss, bert_loss)`` holding the cross-entropy loss of the
        last forward pass of each model.
    """
    from pyhealth.datasets import create_sample_dataset, get_dataloader

    BulkRNABert = _bulk_rna_bert().BulkRNABert
    MLP = _mlp_class()

    with tempfile.TemporaryDirectory() as root:
        make_synthetic_data(root)
        token_ids, labels, gene_cols = _load_tokens(root, n_bins=64)

        # One tabular sample per patient: the bin ids become a float vector.
        samples = [
            {
                "patient_id": f"p{i}",
                "expr_vec": [float(x) for x in row],
                "cancer_type": COHORTS[label],
            }
            for i, (row, label) in enumerate(
                zip(token_ids.tolist(), labels.tolist())
            )
        ]
        input_schema = {"expr_vec": "tensor"}
        output_schema = {"cancer_type": "multiclass"}
        sample_ds = create_sample_dataset(
            samples=samples,
            input_schema=input_schema,
            output_schema=output_schema,
            dataset_name="tcga_synth_tabular",
            task_name="TCGACancerTypeTabular",
        )

        mlp = MLP(sample_ds, embedding_dim=32, hidden_dim=32, n_layers=2)
        bert = BulkRNABert(
            dataset=None,
            n_genes=len(gene_cols),
            n_bins=64,
            embedding_dim=32,
            n_layers=2,
            n_heads=4,
            ffn_dim=64,
            dropout=0.0,
            mlp_hidden=(32,),
            mode="classification",
            n_classes=len(COHORTS),
        )

        # A single full-size batch keeps the MLP on the same data as the
        # transformer, which consumes the whole tensor at once.
        loader = get_dataloader(sample_ds, batch_size=len(samples), shuffle=False)
        batch_mlp = next(iter(loader))

        opt_mlp = optim.Adam(mlp.parameters(), lr=1e-3)
        opt_bert = optim.Adam(bert.parameters(), lr=1e-3)
        mlp.train()
        bert.train()
        for _ in range(15):
            opt_mlp.zero_grad()
            out_m = mlp(**batch_mlp)
            out_m["loss"].backward()
            opt_mlp.step()

            opt_bert.zero_grad()
            out_b = bert(token_ids=token_ids, cancer_type=labels)
            out_b["loss"].backward()
            opt_bert.step()

        mlp_loss = float(out_m["loss"].item())
        bert_loss = float(out_b["loss"].item())
        print(f" MLP final CE loss: {mlp_loss:.4f}")
        print(f" BulkRNABert final CE loss: {bert_loss:.4f}")
        print("\nConclusion: compare tabular MLP vs sequence transformer on same bins.")
    return mlp_loss, bert_loss
def _filter_by_event_type_fast(self, df: "pl.DataFrame", event_type: Optional[str]) -> "pl.DataFrame":
    """Return the pre-partitioned events for ``event_type``.

    The partition index is probed first with the tuple key
    ``(event_type,)`` and then with the bare ``event_type`` key, so both
    key layouts are accepted.

    Args:
        df: Events to fall back on; returned unchanged when ``event_type``
            is falsy, and sliced to zero rows when no partition matches.
        event_type: Event type to select, or ``None``/empty for no filter.

    Returns:
        The matching partition, ``df`` itself, or the empty slice ``df[:0]``.
    """
    if not event_type:
        return df
    partitions = self.event_type_partitions
    for key in ((event_type,), event_type):
        found = partitions.get(key)
        if found is not None:
            return found
    return df[:0]
- .partition_by("patient_id", as_dict=True) + .partition_by(["patient_id"], as_dict=True) ) for patient_id, patient_df in patients.items(): - patient_id = patient_id[0] # Extract string from single-element list + patient_id = patient_id[0] # tuple key -> patient id string patient = Patient(patient_id=patient_id, data_source=patient_df) for sample in task(patient): writer.add_item(write_index, {"sample": pickle.dumps(sample)}) @@ -305,6 +305,21 @@ def _proc_transform_fn(args: tuple[int, Path, int, int, Path]) -> None: logger.info(f"Worker {worker_id} finished processing samples.") +def _scan_global_event_parquet(path: Path | str) -> pl.LazyFrame: + """Scan cached global events; supports single file or Dask parquet directory.""" + p = Path(path) + if p.is_dir(): + parts = sorted(p.glob("*.parquet")) + if not parts: + parts = sorted(p.glob("**/*.parquet")) + if not parts: + raise FileNotFoundError(f"No parquet files found under {p}") + if len(parts) == 1: + return pl.scan_parquet(parts[0], low_memory=True) + return pl.scan_parquet([str(f) for f in parts], low_memory=True) + return pl.scan_parquet(p, low_memory=True) + + class BaseDataset(ABC): """Abstract base class for all PyHealth datasets. @@ -571,10 +586,7 @@ def global_event_df(self) -> pl.LazyFrame: logger.info(f"Found cached event dataframe: {ret_path}") self._global_event_df = ret_path - return pl.scan_parquet( - self._global_event_df, - low_memory=True, - ) + return _scan_global_event_parquet(self._global_event_df) def load_data(self) -> dd.DataFrame: """Loads data from the specified tables. diff --git a/pyhealth/datasets/configs/tcga_rnaseq.yaml b/pyhealth/datasets/configs/tcga_rnaseq.yaml new file mode 100644 index 000000000..55205bf4a --- /dev/null +++ b/pyhealth/datasets/configs/tcga_rnaseq.yaml @@ -0,0 +1,24 @@ +# Reference ``DatasetConfig`` for TCGA RNA-seq (two-gene placeholder layout). 
def _infer_gene_columns_from_tokenized_csv(path: str) -> List[str]:
    """Read the gene column names from a tokenized RNA-seq CSV header.

    Only the header row is parsed. Every column except ``patient_id`` and
    ``cohort`` (matched case-insensitively) counts as a gene.

    Args:
        path: Path to the tokenized RNA-seq CSV file.

    Returns:
        Gene column names in file order.
    """
    non_gene = ("patient_id", "cohort")
    columns = pd.read_csv(path, nrows=0).columns.tolist()
    return [name for name in columns if name.lower() not in non_gene]
+ + This dataset supports two downstream tasks: + - Pan-cancer or cohort-specific cancer type classification + - Survival time prediction (time-to-event with right-censoring) + + Dataset available at: https://portal.gdc.cancer.gov/ + + A machine-readable ``DatasetConfig`` YAML is written to + ``{root}/tcga_rnaseq_pyhealth_config.yaml`` on init (unless + ``config_path`` is provided) so ``BaseDataset`` can load tables via + the standard schema. + + Args: + root: Root directory containing ``rna_seq.csv`` and + ``clinical.csv`` files. + n_bins: Number of expression bins for tokenization. Defaults to 64. + n_genes: Number of genes to retain. If None, uses all common genes. + tables: Optional additional tables to load. + dataset_name: Optional dataset name override. + config_path: Optional path to a valid ``DatasetConfig`` YAML. If + ``None``, a config is generated under ``root`` from gene names. + + Attributes: + n_bins: Number of discretization bins. + n_genes: Number of genes after filtering. + gene_names: List of gene names after filtering. 
@staticmethod
def _prepare_metadata(
    root: str,
    n_bins: int,
    n_genes: Optional[int],
    rnaseq_out: str,
    clinical_out: str,
) -> None:
    """Tokenize the raw RNA-seq matrix and normalize the clinical table.

    Expression values go through ``log10(1 + x)``, a per-sample
    max-normalization and linear binning into ``n_bins`` integer tokens in
    ``[0, n_bins - 1]``. The retained gene names are written to
    ``{root}/tcga_rnaseq_genes.txt``.

    When ``rna_seq.csv`` is missing, placeholder CSVs are written instead.
    When only ``clinical.csv`` is missing, an empty clinical table is
    written.

    Args:
        root: Directory containing ``rna_seq.csv`` and ``clinical.csv``.
        n_bins: Number of expression bins.
        n_genes: Number of highest-variance genes to keep. If ``None`` (or
            not smaller than the number of available genes), all are kept.
        rnaseq_out: Output path of the tokenized RNA-seq CSV.
        clinical_out: Output path of the clinical CSV.
    """
    rnaseq_raw = os.path.join(root, "rna_seq.csv")
    clinical_raw = os.path.join(root, "clinical.csv")

    if not os.path.exists(rnaseq_raw):
        logger.warning(
            "rna_seq.csv not found in %s. "
            "Please download TCGA RNA-seq TPM data from "
            "https://portal.gdc.cancer.gov/ and save as rna_seq.csv "
            "with rows=samples, columns=genes, plus a 'patient_id' column.",
            root,
        )
        TCGARNASeqDataset._create_placeholder_csvs(
            rnaseq_out, clinical_out, n_bins
        )
        return

    logger.info("Loading RNA-seq expression matrix")
    df = pd.read_csv(rnaseq_raw)
    if "patient_id" in df.columns:
        df = df.set_index("patient_id", drop=True)

    # Use one case-insensitive rule both for excluding the cohort column
    # from the genes and for carrying it over, so a "Cohort" header is not
    # silently dropped from the output.
    cohort_like = [c for c in df.columns if c.lower() == "cohort"]
    if "cohort" in cohort_like:
        cohort_name = "cohort"
    else:
        cohort_name = cohort_like[0] if cohort_like else None
    gene_cols = [c for c in df.columns if c.lower() != "cohort"]
    expr = df[gene_cols].astype(float)

    # A single NaN (or a value below -1 after the log) would turn the row
    # maximum into NaN and collapse every token of that sample to bin 0.
    invalid = expr.isna() | (expr < 0)
    n_invalid = int(invalid.to_numpy().sum())
    if n_invalid:
        logger.warning(
            "Replacing %d missing or negative expression values with 0.",
            n_invalid,
        )
        expr = expr.mask(invalid, 0.0)

    if n_genes is not None and n_genes < len(gene_cols):
        variances = expr.var(axis=0)
        top_genes = variances.nlargest(n_genes).index.tolist()
        expr = expr[top_genes]

    gene_names = expr.columns.tolist()

    gene_file = os.path.join(root, "tcga_rnaseq_genes.txt")
    with open(gene_file, "w", encoding="utf-8") as f:
        f.write("\n".join(gene_names))

    logger.info("Applying log10(1+x) transformation")
    expr_log = np.log10(1.0 + expr.to_numpy())

    logger.info("Applying max-normalization")
    row_max = expr_log.max(axis=1, keepdims=True)
    row_max[row_max == 0] = 1.0  # all-zero samples: avoid 0/0
    expr_norm = expr_log / row_max

    logger.info("Discretizing into %s bins.", n_bins)
    expr_binned = np.floor(expr_norm * n_bins).astype(int)
    # The row maximum maps to exactly ``n_bins``; fold it into the top bin.
    expr_binned = np.clip(expr_binned, 0, n_bins - 1)

    out_df = pd.DataFrame(expr_binned, index=expr.index, columns=gene_names)
    out_df.index.name = "patient_id"
    out_df = out_df.reset_index()
    if cohort_name is not None:
        # Rows are still in the order of ``df`` (only columns were selected),
        # so a positional copy is aligned. Reindexing by patient id would
        # raise as soon as an id occurs more than once.
        out_df.insert(1, "cohort", df[cohort_name].to_numpy())
    out_df.to_csv(rnaseq_out, index=False)
    logger.info("Saved tokenized RNA-seq to %s", rnaseq_out)

    if os.path.exists(clinical_raw):
        clin = pd.read_csv(clinical_raw)
        # Source column -> canonical column, in priority order.
        aliases = {
            "bcr_patient_barcode": "patient_id",
            "submitter_id": "patient_id",
            "project_id": "cohort",
        }
        taken = set(clin.columns)
        mapping = {}
        for source, target in aliases.items():
            # Never rename onto a name that already exists: that would
            # produce two columns with the same label.
            if source in clin.columns and target not in taken:
                mapping[source] = target
                taken.add(target)
        clin = clin.rename(columns=mapping)
        if "patient_id" not in clin.columns:
            clin.insert(0, "patient_id", clin.index.astype(str))
        clin.to_csv(clinical_out, index=False)
        logger.info("Saved clinical data to %s", clinical_out)
    else:
        logger.warning(
            "clinical.csv not found in %s. "
            "Survival tasks will not be available without clinical data.",
            root,
        )
        TCGARNASeqDataset._create_placeholder_clinical(clinical_out)
@staticmethod
def _create_placeholder_clinical(clinical_out: str) -> None:
    """Write a header-only clinical CSV.

    The file carries the clinical column names and no data rows.

    Args:
        clinical_out: Output path for clinical CSV.
    """
    header = (
        "patient_id",
        "cohort",
        "vital_status",
        "days_to_death",
        "days_to_last_follow_up",
    )
    empty_table = pd.DataFrame(columns=list(header))
    empty_table.to_csv(clinical_out, index=False)
+ +Paper: Gélard et al., "BulkRNABert: Cancer prognosis from bulk RNA-seq +based language models", bioRxiv 2024. + +The model pre-trains a BERT-style encoder on bulk RNA-seq data via +Masked Language Modeling (MLM), then fine-tunes lightweight MLP heads +for cancer type classification or survival prediction. +""" + +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..datasets import SampleDataset +from .base_model import BaseModel + + +class BulkRNABert(BaseModel): + """BulkRNABert: Transformer encoder for bulk RNA-seq cancer prognosis. + + This model encodes a tokenized bulk RNA-seq sample (integer bin IDs per gene) + into a fixed-size embedding via a BERT-style transformer encoder with gene + embeddings. + + Gene expression is permutation-invariant, so standard positional + encodings are replaced by learned gene embeddings (one per gene position + in the fixed gene panel). + + Args: + dataset: PyHealth SampleDataset with ``token_ids`` input and + either ``cancer_type`` or ``survival_time`` / ``event`` outputs. + n_genes: Number of genes in the input panel. Defaults to 19042. + n_bins: Number of expression bins (vocabulary size). Defaults to 64. + embedding_dim: Transformer embedding dimension. Defaults to 256. + n_layers: Number of transformer encoder layers. Defaults to 4. + n_heads: Number of attention heads. Defaults to 8. + ffn_dim: Feed-forward network hidden dimension. Defaults to 512. + dropout: Dropout rate. Defaults to 0.1. + mlp_hidden: Hidden layer sizes for the task MLP head. + Defaults to (256, 128) for classification. + mode: Task mode. One of ``"classification"`` or ``"survival"``. + Defaults to ``"classification"``. + n_classes: Number of output classes for classification. Required + when ``mode="classification"``. + use_ia3: Whether to apply IA3 parameter-efficient fine-tuning. + Adds learned rescaling vectors to attention keys, values, and + feed-forward activations. 
Defaults to False. + + Attributes: + gene_embedding: Learned embedding matrix of shape + ``(n_genes, embedding_dim)``. + expr_embedding: Embedding for discretized expression bins of shape + ``(n_bins, embedding_dim)``. + encoder: Transformer encoder stack. + task_head: MLP head for the downstream task. + + Examples: + >>> import torch + >>> from pyhealth.models import BulkRNABert + >>> model = BulkRNABert( + ... dataset=None, + ... n_genes=100, + ... n_bins=64, + ... embedding_dim=64, + ... n_layers=2, + ... n_heads=4, + ... mode="classification", + ... n_classes=5, + ... ) + >>> token_ids = torch.randint(0, 64, (2, 100)) + >>> out = model(token_ids=token_ids) + >>> print(out["logit"].shape) + torch.Size([2, 5]) + """ + + def __init__( + self, + dataset: Optional[SampleDataset], + n_genes: int = 19042, + n_bins: int = 64, + embedding_dim: int = 256, + n_layers: int = 4, + n_heads: int = 8, + ffn_dim: int = 512, + dropout: float = 0.1, + mlp_hidden: Tuple[int, ...] = (256, 128), + mode: str = "classification", + n_classes: int = 33, + use_ia3: bool = False, + ) -> None: + super().__init__(dataset) + + self.n_genes = n_genes + self.n_bins = n_bins + self.embedding_dim = embedding_dim + self.task_mode = mode + self.n_classes = n_classes + self.use_ia3 = use_ia3 + + self.expr_embedding = nn.Embedding(n_bins, embedding_dim) + + self.gene_embedding = nn.Embedding(n_genes, embedding_dim) + + # Transformer encoder + encoder_layer = nn.TransformerEncoderLayer( + d_model=embedding_dim, + nhead=n_heads, + dim_feedforward=ffn_dim, + dropout=dropout, + batch_first=True, + norm_first=False, + ) + self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers) + + if use_ia3: + head_dim = embedding_dim // n_heads + self.ia3_lk = nn.ParameterList([ + nn.Parameter(torch.ones(head_dim * n_heads)) + for _ in range(n_layers) + ]) + self.ia3_lv = nn.ParameterList([ + nn.Parameter(torch.ones(head_dim * n_heads)) + for _ in range(n_layers) + ]) + self.ia3_lff = 
def encode(self, token_ids: torch.Tensor) -> torch.Tensor:
    """Encode tokenized RNA-seq profiles into mean-pooled sample embeddings.

    Each position's expression-bin embedding is added to the learned
    embedding of that gene position, the sum is passed through the
    transformer encoder, and the encoder output is averaged over the
    sequence axis.

    Args:
        token_ids: Bin IDs of shape ``(batch_size, sequence_length)``.
            ``sequence_length`` may not exceed the number of rows in the
            gene embedding table.

    Returns:
        Mean-pooled embeddings of shape ``(batch_size, embedding_dim)``.

    Raises:
        ValueError: If ``token_ids`` is not 2-D, or if it has more
            positions than the gene embedding table has rows.
    """
    if token_ids.dim() != 2:
        raise ValueError(
            "token_ids must have shape (batch_size, sequence_length), "
            f"got {tuple(token_ids.shape)}"
        )

    seq_len = token_ids.shape[1]
    max_genes = self.gene_embedding.num_embeddings
    if seq_len > max_genes:
        # Fail early with a readable message instead of an opaque index
        # error raised from inside the embedding lookup.
        raise ValueError(
            f"token_ids has {seq_len} positions but the gene embedding "
            f"table only has {max_genes} rows"
        )

    # One index per gene position; the resulting (seq, dim) embedding
    # broadcasts over the batch axis, so no (batch, seq) index grid is built.
    gene_idx = torch.arange(seq_len, device=token_ids.device)

    x = self.expr_embedding(token_ids) + self.gene_embedding(gene_idx)
    x = self.encoder(x)
    return x.mean(dim=1)
+ event: Optional event indicators of shape (batch,) where 1 = death, 0 = censored. + Returns: + A dictionary with the following keys: + - logit: Raw output logits. + - y_prob: Probabilities (classification) or risk scores (survival). + - loss: Scalar loss tensor (if labels provided). + - y_true: True labels (if labels provided). + """ + + if isinstance(token_ids, tuple): + token_ids = token_ids[0] + + token_ids = token_ids.long() + z = self.encode(token_ids) + logit = self.task_head(z) + + results: Dict[str, torch.Tensor] = {} + + if self.task_mode == "classification": + y_prob = torch.softmax(logit, dim=-1) + results["logit"] = logit + results["y_prob"] = y_prob + + if cancer_type is not None: + if isinstance(cancer_type, tuple): + cancer_type = cancer_type[0] + results["loss"] = F.cross_entropy(logit, cancer_type.long()) + results["y_true"] = cancer_type + + elif self.task_mode == "survival": + log_risk = logit.squeeze(-1) + results["logit"] = log_risk + results["y_prob"] = log_risk + + if event is not None and survival_time is not None: + if isinstance(event, tuple): + event = event[0] + if isinstance(survival_time, tuple): + survival_time = survival_time[0] + results["loss"] = cox_partial_likelihood_loss( + log_risk, survival_time.float(), event.float() + ) + results["y_true"] = event + + return results + + def forward_mlm( + self, + token_ids: torch.Tensor, + mask: torch.Tensor, + targets: torch.Tensor, + ) -> torch.Tensor: + """Forward pass for MLM pre-training. + Args: + token_ids: Tokenized RNA-seq tensor of shape (batch_size, sequence_length). + mask: Boolean mask of shape (batch_size, sequence_length) where True indicates a masked position. + targets: Original token IDs of shape (batch_size, sequence_length). + Returns: + Scalar MLM cross-entropy loss over masked positions. 
def cox_partial_likelihood_loss(
    log_risk: torch.Tensor,
    survival_time: torch.Tensor,
    event: torch.Tensor,
) -> torch.Tensor:
    """Negative Cox partial log-likelihood with Breslow handling of ties.

    Subjects are sorted by descending survival time so that a running
    log-sum-exp over the log-risks gives, at each position, the log of the
    summed risk of everyone listed so far. Subjects that share a survival
    time all use the value at the end of their tie group, so every tied
    subject is part of every other tied subject's risk set and the result
    does not depend on how the sort orders the ties.

    Note:
        If no events are observed, the returned zero is computed from
        ``log_risk`` so it stays attached to the autograd graph;
        ``backward()`` then yields zero gradients.

    Args:
        log_risk: Predicted log-risk scores, shape ``(batch,)``.
        survival_time: Observed survival times, shape ``(batch,)``.
        event: Event indicators, shape ``(batch,)``; non-zero = event
            observed, zero = censored.
        All three inputs are flattened, so ``(batch, 1)`` column vectors
        are accepted as well.

    Returns:
        Scalar tensor: the negative partial log-likelihood averaged over
        the observed events.

    Raises:
        ValueError: If the three inputs do not hold the same number of
            elements.
    """
    log_risk = log_risk.reshape(-1)
    survival_time = survival_time.reshape(-1)
    event = event.reshape(-1)
    if not (log_risk.numel() == survival_time.numel() == event.numel()):
        raise ValueError(
            "log_risk, survival_time and event must have the same number of "
            f"elements, got {log_risk.numel()}, {survival_time.numel()} "
            f"and {event.numel()}"
        )

    observed = event.bool()
    if not bool(observed.any()):
        # Zero loss that keeps the dtype, device and graph of log_risk.
        return log_risk.sum() * 0.0

    order = torch.argsort(survival_time, descending=True)
    sorted_risk = log_risk[order]
    sorted_time = survival_time[order]
    sorted_observed = observed[order]

    log_cumsum_risk = torch.logcumsumexp(sorted_risk, dim=0)

    # Map every subject to the last index of its run of equal survival
    # times; the cumulative value there covers the whole tie group.
    _, group_of, group_sizes = torch.unique_consecutive(
        sorted_time, return_inverse=True, return_counts=True
    )
    group_last = torch.cumsum(group_sizes, dim=0) - 1
    log_risk_set = log_cumsum_risk[group_last[group_of]]

    return -(sorted_risk[sorted_observed] - log_risk_set[sorted_observed]).mean()
def __call__(self, patient: Any) -> List[Dict[str, Any]]:
    """Turn one patient into at most one cancer-type classification sample.

    Only the patient's first ``rnaseq`` event is used; its ``cohort``
    attribute becomes the label and its gene-bin attributes become the
    token sequence.

    Args:
        patient: Object exposing ``get_events(event_type=...)`` and
            ``patient_id``.

    Returns:
        A one-element list holding a dict with ``patient_id``,
        ``token_ids`` and ``cancer_type``. An empty list is returned when
        the patient has no ``rnaseq`` event, when the cohort is ``None`` or
        stringifies to ``"nan"``, when a cohort restriction is set and
        excludes it, or when no token IDs could be extracted.
    """
    events = patient.get_events(event_type="rnaseq")
    if len(events) < 1:
        return []

    first_event = events[0]
    label = getattr(first_event, "cohort", None)

    label_missing = label is None or str(label) == "nan"
    if label_missing:
        return []

    allowed = self.cohorts
    if not (allowed is None or label in allowed):
        return []

    tokens = _extract_token_ids(first_event)
    if len(tokens) < 1:
        return []

    sample = {
        "patient_id": patient.patient_id,
        "token_ids": tokens,
        "cancer_type": label,
    }
    return [sample]
+ + Examples: + >>> from pyhealth.tasks import TCGASurvivalTask + >>> task = TCGASurvivalTask() + >>> task_blca = TCGASurvivalTask(cohorts=["BLCA"]) + """ + + task_name: str = "TCGASurvivalTask" + input_schema: Dict[str, str] = {"token_ids": "sequence"} + output_schema: Dict[str, str] = { + "survival_time": "regression", + "event": "binary", + } + + def __init__(self, cohorts: Optional[List[str]] = None) -> None: + """Initialize task with optional cohort restriction. + + Args: + cohorts: Optional list of cohort abbreviations to include. + Default: all cohorts. + """ + + self.cohorts = set(cohorts) if cohorts is not None else None + + def __call__(self, patient: Any) -> List[Dict[str, Any]]: + """Process one patient into a survival prediction sample. + Survival time is ``days_to_death`` for deceased patients and + ``days_to_last_follow_up`` for censored patients. + + Args: + patient: PyHealth patient object with ``rnaseq`` and + ``clinical`` events. + + Returns: + List with a single sample dict containing ``token_ids``, + ``survival_time`` (float, days), and ``event`` (0 or 1), + or empty list if required fields are missing. 
# Compiled once; matches names such as "GENE12" / "gene12" and captures the digits.
_GENE_INDEX_PATTERN = re.compile(r"(?i)gene(\d+)$")


def _gene_sort_key(name: str) -> Tuple[int, Any]:
    """Return the sort key that fixes the order of gene attributes.

    Names of the form ``GENE<digits>`` (case-insensitive) sort first, by
    their integer suffix. Every other name sorts after them, by its
    lowercased spelling.

    Args:
        name: Attribute name.

    Returns:
        ``(0, index)`` for ``GENE<index>`` names, ``(1, lowercased_name)``
        otherwise. The leading element keeps the two groups apart, so an
        ``int`` is never compared with a ``str``.
    """
    match = _GENE_INDEX_PATTERN.match(name)
    if match:
        return (0, int(match.group(1)))
    return (1, name.lower())


def _extract_token_ids(event: Any) -> List[int]:
    """Extract integer token IDs from an RNA-seq event object.

    Attributes are read from ``event.attr_dict`` when that attribute
    exists; otherwise from the instance attributes that do not start with
    an underscore. Metadata attributes, ``None``, blank strings, values
    that cannot be converted to a number, and non-finite numbers (NaN,
    +/-inf) are skipped. Non-integer numbers are truncated with ``int``.

    Args:
        event: Event object carrying one attribute per gene.

    Returns:
        List of integer token IDs ordered by ``_gene_sort_key`` of the
        attribute name.
    """
    # Compared against the lowercased attribute name, so "Cohort" is skipped too.
    skip = frozenset(
        ("cohort", "patient_id", "timestamp", "visit_id", "record_id")
    )

    if hasattr(event, "attr_dict"):
        attr_items = event.attr_dict.items()
    else:
        attr_items = (
            (k, getattr(event, k))
            for k in vars(event)
            if not k.startswith("_")
        )

    pairs: List[Tuple[str, int]] = []
    for attr, val in attr_items:
        if attr.lower() in skip or val is None:
            continue
        if isinstance(val, (int, np.integer)):
            # Exact path for integers: no round trip through float.
            pairs.append((attr, int(val)))
            continue
        if isinstance(val, str) and not val.strip():
            continue
        try:
            number = float(val)
        except (TypeError, ValueError):
            continue
        if not np.isfinite(number):
            # int() raises on NaN (ValueError) and on +/-inf (OverflowError).
            continue
        pairs.append((attr, int(number)))

    pairs.sort(key=lambda kv: _gene_sort_key(kv[0]))
    return [v for _, v in pairs]
def _load_bulk_rna_bert_module():
    """Import ``pyhealth.models.bulk_rna_bert`` without running ``models/__init__.py``.

    If ``pyhealth.models`` is not loaded yet, a bare stand-in package whose
    ``__path__`` points at the models directory is registered just long
    enough for the submodule import to resolve. The stand-in is removed
    again afterwards, so a later ``import pyhealth.models`` elsewhere in
    the same test session gets the real package instead of the empty stub.

    Returns:
        The imported ``pyhealth.models.bulk_rna_bert`` module.
    """
    if "pyhealth.models" in sys.modules:
        return importlib.import_module("pyhealth.models.bulk_rna_bert")

    import pyhealth  # noqa: F401  (loads the parent package before the stub is registered)

    stub = types.ModuleType("pyhealth.models")
    stub.__path__ = [str(_MODELS_DIR)]
    sys.modules["pyhealth.models"] = stub
    try:
        return importlib.import_module("pyhealth.models.bulk_rna_bert")
    finally:
        # Only drop the entry if it is still our stub.
        if sys.modules.get("pyhealth.models") is stub:
            del sys.modules["pyhealth.models"]
test_instantiation_survival(self): + assert self._make_model("survival") is not None + + def test_invalid_mode_raises(self): + with pytest.raises(ValueError, match="mode"): + BulkRNABert(dataset=None, n_genes=10, mode="invalid") + + def test_encode_output_shape(self): + model = self._make_model() + z = model.encode(self._token_ids()) + assert z.shape == (self.BATCH, self.EMB) + + def test_classification_forward_output_shapes(self): + model = self._make_model("classification") + labels = torch.randint(0, self.N_CLASSES, (self.BATCH,)) + out = model(token_ids=self._token_ids(), cancer_type=labels) + assert out["logit"].shape == (self.BATCH, self.N_CLASSES) + assert out["y_prob"].shape == (self.BATCH, self.N_CLASSES) + assert out["loss"].ndim == 0 + + def test_classification_forward_no_labels(self): + model = self._make_model("classification") + out = model(token_ids=self._token_ids()) + assert "logit" in out + assert "loss" not in out + + def test_survival_forward_output_shapes(self): + model = self._make_model("survival") + times = torch.rand(self.BATCH) * 1000 + events = torch.randint(0, 2, (self.BATCH,)).float() + out = model(token_ids=self._token_ids(), survival_time=times, event=events) + assert out["logit"].shape == (self.BATCH,) + + def test_gradients_flow_classification(self): + model = self._make_model("classification") + labels = torch.randint(0, self.N_CLASSES, (self.BATCH,)) + out = model(token_ids=self._token_ids(), cancer_type=labels) + out["loss"].backward() + for name, param in model.named_parameters(): + if param.requires_grad and param.grad is not None: + assert not torch.isnan(param.grad).any(), f"NaN grad in {name}" + + def test_gradients_flow_survival(self): + model = self._make_model("survival") + times = torch.tensor([500.0, 200.0]) + events = torch.tensor([1.0, 0.0]) + out = model(token_ids=self._token_ids(), survival_time=times, event=events) + out["loss"].backward() + for name, param in model.named_parameters(): + if param.requires_grad and 
param.grad is not None: + assert not torch.isnan(param.grad).any(), f"NaN grad in {name}" + + def test_y_prob_sums_to_one(self): + model = self._make_model("classification") + out = model(token_ids=self._token_ids()) + prob_sum = out["y_prob"].sum(dim=-1) + assert torch.allclose(prob_sum, torch.ones(self.BATCH), atol=1e-5) + + def test_mlm_forward(self): + model = self._make_model("classification") + token_ids = self._token_ids() + mask = torch.zeros(self.BATCH, self.N_GENES, dtype=torch.bool) + mask[:, :5] = True + targets = torch.randint(0, self.N_BINS, (self.BATCH, self.N_GENES)) + loss = model.forward_mlm(token_ids, mask, targets) + assert loss.ndim == 0 + assert not torch.isnan(loss) + + +class TestCoxLoss: + + def test_loss_is_scalar(self): + loss = cox_partial_likelihood_loss( + torch.tensor([0.5, -0.3, 1.0, -1.0]), + torch.tensor([400.0, 300.0, 200.0, 100.0]), + torch.tensor([1.0, 0.0, 1.0, 0.0]), + ) + assert loss.ndim == 0 + + def test_loss_is_finite(self): + loss = cox_partial_likelihood_loss( + torch.randn(8), + torch.rand(8) * 1000 + 1, + torch.randint(0, 2, (8,)).float(), + ) + assert torch.isfinite(loss) + + def test_no_events_returns_zero(self): + loss = cox_partial_likelihood_loss( + torch.randn(4), + torch.tensor([100.0, 200.0, 300.0, 400.0]), + torch.zeros(4), + ) + assert loss.item() == pytest.approx(0.0) + + def test_correct_ranking_reduces_loss(self): + times = torch.tensor([100.0, 200.0]) + events = torch.tensor([1.0, 1.0]) + loss_good = cox_partial_likelihood_loss( + torch.tensor([2.0, 1.0]), times, events + ) + loss_bad = cox_partial_likelihood_loss( + torch.tensor([1.0, 2.0]), times, events + ) + assert loss_good.item() < loss_bad.item() \ No newline at end of file diff --git a/tests/test_tcga_rnaseq.py b/tests/test_tcga_rnaseq.py new file mode 100644 index 000000000..5567418db --- /dev/null +++ b/tests/test_tcga_rnaseq.py @@ -0,0 +1,278 @@ +"""Tests for TCGARNASeqDataset and TCGA RNA-seq tasks. 
+ +Uses synthetic data only — no real TCGA downloads required. +All tests complete in milliseconds. +""" + +import os +import tempfile + +import numpy as np +import pandas as pd +import pytest + +N_PATIENTS = 4 +N_GENES = 20 +GENE_NAMES = [f"GENE{i}" for i in range(N_GENES)] +COHORTS = ["BRCA", "LUAD", "BRCA", "BLCA"] + + +def _make_synthetic_rnaseq(root: str) -> str: + np.random.seed(42) + expr = np.random.exponential(scale=10.0, size=(N_PATIENTS, N_GENES)) + df = pd.DataFrame(expr, columns=GENE_NAMES) + df.insert(0, "patient_id", [f"TCGA-{i:02d}" for i in range(N_PATIENTS)]) + df.insert(1, "cohort", COHORTS) + path = os.path.join(root, "rna_seq.csv") + df.to_csv(path, index=False) + return path + + +def _make_synthetic_clinical(root: str) -> str: + df = pd.DataFrame({ + "patient_id": [f"TCGA-{i:02d}" for i in range(N_PATIENTS)], + "cohort": COHORTS, + "vital_status": ["dead", "alive", "dead", "alive"], + "days_to_death": [365.0, None, 180.0, None], + "days_to_last_follow_up": [None, 700.0, None, 500.0], + }) + path = os.path.join(root, "clinical.csv") + df.to_csv(path, index=False) + return path + + +class TestTCGARNASeqPreprocessing: + + def test_log_transform_and_binning(self, tmp_path): + root = str(tmp_path) + _make_synthetic_rnaseq(root) + _make_synthetic_clinical(root) + from pyhealth.datasets.tcga_rnaseq import TCGARNASeqDataset + rnaseq_out = os.path.join(root, "tcga_rnaseq_tokenized-pyhealth.csv") + clinical_out = os.path.join(root, "tcga_rnaseq_clinical-pyhealth.csv") + TCGARNASeqDataset._prepare_metadata(root, 64, None, rnaseq_out, clinical_out) + assert os.path.exists(rnaseq_out) + df = pd.read_csv(rnaseq_out) + gene_cols = [c for c in df.columns if c not in ("patient_id", "cohort")] + values = df[gene_cols].values + assert values.min() >= 0 + assert values.max() < 64 + + def test_gene_file_written(self, tmp_path): + root = str(tmp_path) + _make_synthetic_rnaseq(root) + _make_synthetic_clinical(root) + from pyhealth.datasets.tcga_rnaseq import 
TCGARNASeqDataset + rnaseq_out = os.path.join(root, "tcga_rnaseq_tokenized-pyhealth.csv") + clinical_out = os.path.join(root, "tcga_rnaseq_clinical-pyhealth.csv") + TCGARNASeqDataset._prepare_metadata(root, 64, None, rnaseq_out, clinical_out) + gene_file = os.path.join(root, "tcga_rnaseq_genes.txt") + assert os.path.exists(gene_file) + with open(gene_file) as f: + genes = [l.strip() for l in f if l.strip()] + assert len(genes) == N_GENES + + def test_placeholder_created_when_no_raw(self, tmp_path): + root = str(tmp_path) + from pyhealth.datasets.tcga_rnaseq import TCGARNASeqDataset + rnaseq_out = os.path.join(root, "tcga_rnaseq_tokenized-pyhealth.csv") + clinical_out = os.path.join(root, "tcga_rnaseq_clinical-pyhealth.csv") + TCGARNASeqDataset._prepare_metadata(root, 64, None, rnaseq_out, clinical_out) + assert os.path.exists(rnaseq_out) + assert os.path.exists(clinical_out) + + def test_n_genes_filtering(self, tmp_path): + root = str(tmp_path) + _make_synthetic_rnaseq(root) + _make_synthetic_clinical(root) + from pyhealth.datasets.tcga_rnaseq import TCGARNASeqDataset + rnaseq_out = os.path.join(root, "tcga_rnaseq_tokenized-pyhealth.csv") + clinical_out = os.path.join(root, "tcga_rnaseq_clinical-pyhealth.csv") + TCGARNASeqDataset._prepare_metadata(root, 64, 10, rnaseq_out, clinical_out) + df = pd.read_csv(rnaseq_out) + gene_cols = [c for c in df.columns if c not in ("patient_id", "cohort")] + assert len(gene_cols) == 10 + + +class TestTCGARNASeqDatasetIntegration: + + def test_instantiate_and_runtime_config(self, tmp_path): + root = str(tmp_path) + _make_synthetic_rnaseq(root) + _make_synthetic_clinical(root) + from pyhealth.datasets.tcga_rnaseq import TCGARNASeqDataset + + cache = os.path.join(root, "ds_cache") + ds = TCGARNASeqDataset( + root=root, + n_bins=32, + n_genes=8, + cache_dir=cache, + num_workers=1, + dev=False, + ) + assert os.path.isfile(os.path.join(root, "tcga_rnaseq_pyhealth_config.yaml")) + assert len(ds.gene_names) == 8 + assert 
len(ds.unique_patient_ids) == N_PATIENTS + + def test_get_patient_rnaseq_event_token_order(self, tmp_path): + root = str(tmp_path) + _make_synthetic_rnaseq(root) + _make_synthetic_clinical(root) + from pyhealth.datasets.tcga_rnaseq import TCGARNASeqDataset + from pyhealth.tasks.tcga_rnaseq_tasks import _extract_token_ids + + cache = os.path.join(root, "cache_evt") + ds = TCGARNASeqDataset( + root=root, + n_bins=64, + n_genes=6, + cache_dir=cache, + num_workers=1, + dev=False, + ) + pid = ds.unique_patient_ids[0] + patient = ds.get_patient(pid) + rnaseq_events = patient.get_events(event_type="rnaseq") + assert len(rnaseq_events) >= 1 + toks = _extract_token_ids(rnaseq_events[0]) + assert len(toks) == 6 + assert all(0 <= t < 64 for t in toks) + + def test_set_task_default_cancer_type_smoke(self, tmp_path): + root = str(tmp_path) + _make_synthetic_rnaseq(root) + _make_synthetic_clinical(root) + from pyhealth.datasets.tcga_rnaseq import TCGARNASeqDataset + + cache = os.path.join(root, "cache_task") + ds = TCGARNASeqDataset( + root=root, + n_bins=64, + n_genes=5, + cache_dir=cache, + num_workers=1, + dev=False, + ) + samples = ds.set_task(num_workers=1) + assert len(samples) >= 1 + row = samples[0] + assert "token_ids" in row + assert "cancer_type" in row + + +class _FakeEvent: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + +class _FakePatient: + def __init__(self, patient_id, rnaseq_events, clinical_events): + self.patient_id = patient_id + self._events = {"rnaseq": rnaseq_events, "clinical": clinical_events} + + def get_events(self, event_type): + return self._events.get(event_type, []) + + +def _make_rnaseq_event(cohort="BRCA", n_genes=10, n_bins=64): + kwargs = {"cohort": cohort} + for i in range(n_genes): + kwargs[f"GENE{i}"] = np.random.randint(0, n_bins) + return _FakeEvent(**kwargs) + + +class TestTCGACancerTypeTask: + + def test_returns_sample_with_token_ids(self): + from pyhealth.tasks.tcga_rnaseq_tasks import 
TCGACancerTypeTask + task = TCGACancerTypeTask() + event = _make_rnaseq_event("BRCA") + patient = _FakePatient("P1", [event], []) + samples = task(patient) + assert len(samples) == 1 + assert "token_ids" in samples[0] + assert samples[0]["cancer_type"] == "BRCA" + + def test_cohort_filter_excludes_other_cohorts(self): + from pyhealth.tasks.tcga_rnaseq_tasks import TCGACancerTypeTask + task = TCGACancerTypeTask(cohorts=["LUAD"]) + event = _make_rnaseq_event("BRCA") + patient = _FakePatient("P1", [event], []) + assert task(patient) == [] + + def test_cohort_filter_includes_matching_cohort(self): + from pyhealth.tasks.tcga_rnaseq_tasks import TCGACancerTypeTask + task = TCGACancerTypeTask(cohorts=["BRCA", "LUAD"]) + event = _make_rnaseq_event("LUAD") + patient = _FakePatient("P1", [event], []) + assert len(task(patient)) == 1 + + def test_missing_rnaseq_returns_empty(self): + from pyhealth.tasks.tcga_rnaseq_tasks import TCGACancerTypeTask + task = TCGACancerTypeTask() + assert task(_FakePatient("P1", [], [])) == [] + + def test_missing_cohort_returns_empty(self): + from pyhealth.tasks.tcga_rnaseq_tasks import TCGACancerTypeTask + task = TCGACancerTypeTask() + event = _FakeEvent(cohort=None, GENE0=5, GENE1=10) + assert task(_FakePatient("P1", [event], [])) == [] + + +class TestTCGASurvivalTask: + + def _make_clinical(self, vital="dead", days_death=365.0, days_follow=None): + return _FakeEvent( + vital_status=vital, + days_to_death=days_death, + days_to_last_follow_up=days_follow, + ) + + def test_deceased_patient_returns_sample(self): + from pyhealth.tasks.tcga_rnaseq_tasks import TCGASurvivalTask + task = TCGASurvivalTask() + patient = _FakePatient( + "P1", [_make_rnaseq_event("BRCA")], [self._make_clinical("dead", 365.0)] + ) + samples = task(patient) + assert len(samples) == 1 + assert samples[0]["event"] == 1 + assert samples[0]["survival_time"] == pytest.approx(365.0) + + def test_censored_patient_returns_sample(self): + from pyhealth.tasks.tcga_rnaseq_tasks 
import TCGASurvivalTask + task = TCGASurvivalTask() + patient = _FakePatient( + "P1", + [_make_rnaseq_event("LUAD")], + [self._make_clinical("alive", None, 700.0)], + ) + samples = task(patient) + assert len(samples) == 1 + assert samples[0]["event"] == 0 + assert samples[0]["survival_time"] == pytest.approx(700.0) + + def test_missing_clinical_returns_empty(self): + from pyhealth.tasks.tcga_rnaseq_tasks import TCGASurvivalTask + task = TCGASurvivalTask() + assert task(_FakePatient("P1", [_make_rnaseq_event()], [])) == [] + + def test_unknown_vital_status_returns_empty(self): + from pyhealth.tasks.tcga_rnaseq_tasks import TCGASurvivalTask + task = TCGASurvivalTask() + patient = _FakePatient( + "P1", + [_make_rnaseq_event()], + [self._make_clinical("unknown", None, None)], + ) + assert task(patient) == [] + + def test_cohort_filter(self): + from pyhealth.tasks.tcga_rnaseq_tasks import TCGASurvivalTask + task = TCGASurvivalTask(cohorts=["BLCA"]) + patient = _FakePatient( + "P1", [_make_rnaseq_event("BRCA")], [self._make_clinical("dead", 200.0)] + ) + assert task(patient) == [] \ No newline at end of file