model + task + dataset implementation, tested #1112

Open
sanyamahesh2 wants to merge 2 commits into sunlabuiuc:master from sanyamahesh2:bulkrnabert-sanyam2
Contributor: Sanya Mahesh (sanyam2@illinois.edu)

Contribution Type: Dataset + Task + Model (Full Pipeline)

Original Paper: Gélard et al., "BulkRNABert: Cancer prognosis from bulk RNA-seq based language models", bioRxiv 2024. https://doi.org/10.1101/2024.06.18.599483

Description
This PR implements a full PyHealth pipeline including dataset loading and preprocessing, two downstream tasks, and the BulkRNABert model.
The example script investigates three design choices that the original paper does not ablate: binning resolution (B ∈ {32, 64, 128}), fine-tuning strategy (frozen backbone vs. IA3 vs. full fine-tuning), and Cox loss behavior on censored cohorts.

Files to Review

Dataset

  1. pyhealth/datasets/tcga_rnaseq.py — TCGARNASeqDataset implementing the BulkRNABert preprocessing pipeline (log transform, max-norm, binning)
  2. pyhealth/datasets/configs/tcga_rnaseq.yaml — Dataset config defining rnaseq and clinical table loading
  3. pyhealth/datasets/__init__.py — Added import for TCGARNASeqDataset
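
The three preprocessing steps named above (log transform, max-norm, binning) can be sketched as follows. This is a minimal illustration, not the code in tcga_rnaseq.py: the function name preprocess_expression, the log10(1 + x) form, the per-sample maximum, and uniform bin edges are all assumptions made for the example.

```python
import math


def preprocess_expression(tpm, n_bins=64):
    """Log-transform, max-normalize, and bin one sample's expression values.

    tpm: list of non-negative expression values, one per gene.
    Returns a list of integer token ids in [0, n_bins - 1].
    """
    # Step 1: log transform. log10(1 + x) maps zero expression to zero.
    logged = [math.log10(1.0 + x) for x in tpm]

    # Step 2: max-normalization to [0, 1].
    # Assumption: the maximum is taken per sample; a dataset-wide
    # maximum is an equally plausible choice.
    max_val = max(logged)
    if max_val == 0.0:
        # All-zero sample: every gene falls into the first bin.
        return [0] * len(logged)
    scaled = [v / max_val for v in logged]

    # Step 3: uniform binning on [0, 1]. The value 1.0 would land in
    # bin n_bins, so it is clamped into the last valid bin.
    return [min(int(v * n_bins), n_bins - 1) for v in scaled]
```

Changing n_bins to 32 or 128 gives the other binning resolutions listed in the description; the token vocabulary size changes with it.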

Tasks

  1. pyhealth/tasks/tcga_rnaseq_tasks.py — Two tasks: TCGACancerTypeTask (pan-cancer or cohort-restricted classification) and TCGASurvivalTask (Cox survival prediction with right-censoring support)
  2. pyhealth/tasks/__init__.py — Added imports for TCGACancerTypeTask and TCGASurvivalTask
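
To make "right-censoring support" concrete, here is a hedged sketch of how a survival label could be derived from one clinical record. The helper survival_label is hypothetical, and the column names (vital_status, days_to_death, days_to_last_follow_up) are assumed to follow GDC clinical exports; the actual TCGASurvivalTask may use different fields.

```python
def survival_label(record):
    """Map one clinical record (a dict) to a (time_in_days, event) pair.

    event is 1 if death was observed and 0 if the patient is
    right-censored (still alive at last contact).
    Returns None when the required time field is missing.
    """
    # Assumed column names, following GDC clinical exports.
    if record.get("vital_status") == "Dead":
        time, event = record.get("days_to_death"), 1
    else:
        # Censored: only a lower bound on survival time is known.
        time, event = record.get("days_to_last_follow_up"), 0

    if time is None:
        return None
    return float(time), event
```

Records that return None would typically be dropped before training, since neither an event time nor a censoring time is available for them.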

Model

  1. pyhealth/models/bulk_rna_bert.py — BulkRNABert transformer model with gene embeddings, MLM pre-training head, classification and survival MLP heads, IA3 fine-tuning support, and Cox partial likelihood loss
  2. pyhealth/models/__init__.py — Added import for BulkRNABert
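
For readers unfamiliar with the Cox partial likelihood loss mentioned above, the following is a plain-Python sketch of the standard negative partial log-likelihood. It is not the PR's implementation, which presumably operates on tensors. The function name, the averaging over observed events, the handling of tied times (every patient with time >= t_i is in the risk set), and the return value of 0.0 for a fully censored batch are assumptions of this example.

```python
import math


def cox_partial_likelihood_loss(risk, time, event):
    """Negative Cox partial log-likelihood, averaged over observed events.

    risk:  predicted log-risk score per patient
    time:  follow-up time per patient
    event: 1 if the event was observed, 0 if right-censored
    """
    n_events = sum(event)
    if n_events == 0:
        # Fully censored batch: the partial likelihood has no terms,
        # so return 0.0 instead of dividing by zero (an assumption).
        return 0.0

    total = 0.0
    for r_i, t_i, e_i in zip(risk, time, event):
        if not e_i:
            # Censored patients contribute only through risk sets.
            continue
        # Risk set: everyone still under observation at time t_i.
        at_risk = [r_j for r_j, t_j in zip(risk, time) if t_j >= t_i]
        # Log-sum-exp with max subtraction for numerical stability.
        m = max(at_risk)
        log_denominator = m + math.log(sum(math.exp(r - m) for r in at_risk))
        total += r_i - log_denominator

    return -total / n_events
```

The loss decreases when patients who experience the event earlier receive higher risk scores than those still at risk, which is the ordering a survival head is trained to produce.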

Unit tests

  1. tests/test_tcga_rnaseq.py — 15 tests for dataset preprocessing and both tasks using synthetic data only
  2. tests/test_bulk_rna_bert.py — 17 tests for model forward pass, output shapes, gradient flow, and Cox loss using synthetic tensors only
  • 32 unit tests (all passing)

To run tests:

  1. Activate the Python 3.12 virtual environment: source venv312/bin/activate
  2. Install dependencies: pip install -e .
  3. Run the tests: pytest tests/test_bulk_rna_bert.py tests/test_tcga_rnaseq.py -v

Full pipeline/example

  1. examples/tcga_rnaseq_cancer_type_bulk_rna_bert.py — Full ablation study runnable on synthetic data covering binning resolution, fine-tuning strategy, and survival loss

Notes
  • All tests use synthetic data only; no real TCGA download is required.
  • The example script also runs entirely on synthetic data. To use real TCGA data, replace the make_synthetic_data call with code that loads your downloaded rna_seq.csv and clinical.csv.
  • conftest.py: if torch.uint16 is missing, it sets torch.uint16 = torch.int16 before any code imports litdata. This keeps the tests working on older stacks that lack that dtype.
  • TCGA data is available at https://portal.gdc.cancer.gov/
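
The conftest.py workaround described in the notes can be sketched roughly as below. The helper alias_missing_dtype is a hypothetical name introduced for this example, and the guarded torch import is only there so the snippet also runs where PyTorch is not installed; the PR's conftest.py may be written differently.

```python
def alias_missing_dtype(module, missing="uint16", fallback="int16"):
    """Alias module.<missing> to module.<fallback> if the former is absent.

    Returns True if the alias was added, False if nothing was changed.
    """
    if hasattr(module, missing):
        return False
    setattr(module, missing, getattr(module, fallback))
    return True


# This must run before anything imports litdata, so it belongs at the
# top of conftest.py, which pytest loads before collecting test modules.
try:
    import torch
except ImportError:
    torch = None  # PyTorch not installed; nothing to patch.

if torch is not None:
    alias_missing_dtype(torch)
```

Note that the alias only makes the attribute lookup succeed. torch.int16 is a signed type, so values above 32767 would not round-trip if code actually stored data through the aliased dtype.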