diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 8d9a59d21..404d734ae 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -241,6 +241,7 @@ Available Datasets datasets/pyhealth.datasets.PhysioNetDeIDDataset datasets/pyhealth.datasets.TUABDataset datasets/pyhealth.datasets.TUEVDataset + datasets/pyhealth.datasets.CCEPECoGDataset datasets/pyhealth.datasets.ClinVarDataset datasets/pyhealth.datasets.COSMICDataset datasets/pyhealth.datasets.TCGAPRADDataset diff --git a/docs/api/datasets/pyhealth.datasets.CCEPECoGDataset.rst b/docs/api/datasets/pyhealth.datasets.CCEPECoGDataset.rst new file mode 100644 index 000000000..9c8cfcb19 --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.CCEPECoGDataset.rst @@ -0,0 +1,9 @@ +pyhealth.datasets.CCEPECoGDataset +================================= + +The open CCEP ECoG dataset of electrocorticography (ECoG) recordings from patients undergoing epilepsy surgery, refer to `OpenNeuro <https://openneuro.org/datasets/ds004080>`_ for more information. + +.. autoclass:: pyhealth.datasets.CCEPECoGDataset + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/models.rst b/docs/api/models.rst index 7c3ac7c4b..bf586db4c 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -206,3 +206,5 @@ API Reference models/pyhealth.models.BIOT models/pyhealth.models.unified_multimodal_embedding_docs models/pyhealth.models.califorest + models/pyhealth.models.SPESResNet + models/pyhealth.models.SPESTransformer \ No newline at end of file diff --git a/docs/api/models/pyhealth.models.SPESResNet.rst b/docs/api/models/pyhealth.models.SPESResNet.rst new file mode 100644 index 000000000..0f157e6b8 --- /dev/null +++ b/docs/api/models/pyhealth.models.SPESResNet.rst @@ -0,0 +1,9 @@ +pyhealth.models.SPESResNet +========================== + +Multi-scale 1D ResNet for electrode-level seizure onset zone localization from CCEP ECoG data. + +.. 
autoclass:: pyhealth.models.SPESResNet + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/models/pyhealth.models.SPESTransformer.rst b/docs/api/models/pyhealth.models.SPESTransformer.rst new file mode 100644 index 000000000..a29a0018a --- /dev/null +++ b/docs/api/models/pyhealth.models.SPESTransformer.rst @@ -0,0 +1,9 @@ +pyhealth.models.SPESTransformer +================================ + +CNN-Transformer hybrid for electrode-level seizure onset zone localization from CCEP ECoG data. + +.. autoclass:: pyhealth.models.SPESTransformer + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 23a4e06e5..f22294a43 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -213,6 +213,7 @@ Available Tasks DKA Prediction (MIMIC-IV) Drug Recommendation Length of Stay Prediction + Localize Seizure Onset Zone (SOZ) Medical Transcriptions Classification Mortality Prediction (Next Visit) Mortality Prediction (StageNet MIMIC-IV) diff --git a/docs/api/tasks/pyhealth.tasks.LocalizeSOZ.rst b/docs/api/tasks/pyhealth.tasks.LocalizeSOZ.rst new file mode 100644 index 000000000..858144a6c --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.LocalizeSOZ.rst @@ -0,0 +1,7 @@ +pyhealth.tasks.LocalizeSOZ +========================== + +.. 
autoclass:: pyhealth.tasks.localize_soz.LocalizeSOZ + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/ccep_ecog_localize_soz_spes.py b/examples/ccep_ecog_localize_soz_spes.py new file mode 100644 index 000000000..be9a9b422 --- /dev/null +++ b/examples/ccep_ecog_localize_soz_spes.py @@ -0,0 +1,568 @@ +# ============================================================================= +# Experimental Setup +# +# Commands run (5-fold patient-level cross-validation, seed=42): +# +# python examples/ccep_ecog_localize_soz_spes.py --model cnn-divergent +# python examples/ccep_ecog_localize_soz_spes.py --model cnn-convergent +# python examples/ccep_ecog_localize_soz_spes.py --model cnn-transformer +# python examples/ccep_ecog_localize_soz_spes.py --model cnn-transformer --lr 1e-4 --dropout 0.5 +# python examples/ccep_ecog_localize_soz_spes.py --model cnn-transformer-ablation +# +# cnn-transformer is run twice: once with paper-tuned hyperparameters, and once with hyperparameters
# matched to the ablation (lr=1e-4, dropout=0.5) to enable a fair comparison. +# cnn-transformer-ablation uses the same matched hyperparameters but removes the std response mode +# and the MLP prefix from the convergent encoder, isolating their contribution. 
+# ============================================================================= +import argparse +from functools import partial + +import numpy as np +from sklearn.model_selection import KFold +from sklearn.metrics import roc_auc_score, roc_curve +import torch +from torch.utils.data import DataLoader + +from pyhealth.datasets import CCEPECoGDataset +from pyhealth.models import SPESResNet, SPESTransformer +from pyhealth.tasks.localize_soz import LocalizeSOZ +from pyhealth.trainer import Trainer + + +ROOT = "./data/ds004080" +CACHE_DIR = "./ccep_ecog" + + +# NOTE: MODEL_PRESETS hyperparameters for cnn-divergent, cnn-convergent, and cnn-transformer +# were optimized via Optuna in the original paper (https://proceedings.mlr.press/v252/norris24a.html) +# and pulled directly from the paper's codebase (https://github.com/norrisjamie23/Localising_SOZ_from_SPES). +MODEL_PRESETS = { + "cnn-divergent": { + "learning_rate": 0.003962831229235175, + "model_kwargs": { + "dropout_rate": 0.21763415739071962, + "input_channels": 49, + }, + }, + "cnn-convergent": { + "learning_rate": 0.001289042623854371, + "model_kwargs": { + "dropout_rate": 0.44374819954858546, + "input_channels": 37, + }, + }, + "cnn-transformer": { + "learning_rate": 0.003368045116199473, + "model_kwargs": { + "dropout_rate": 0.4391902174353594, + "embedding_dim": 2**4, + "num_layers": 2**1, + }, + }, + # Ablation: removes the std response mode and MLP prefix to isolate + # the contribution of trial variability and the hybrid embedding. 
+ "cnn-transformer-ablation": { + "learning_rate": 1e-4, + "model_kwargs": { + "dropout_rate": 0.5, + "embedding_dim": 2**4, + "num_layers": 2**1, + }, + }, +} + + +def pad_tensor_to_shape(value, shape): + pad = [] + for current, target in zip(reversed(value.shape), reversed(shape)): + pad.extend([0, target - current]) + return torch.nn.functional.pad(value, pad) + + +def compute_norm_stats(dataset, keys=("X_stim", "X_recording")): + """Compute normalization statistics from the training dataset only. + + This must be called exclusively on the training split and the resulting + statistics applied to all splits (train, val, test). Computing stats on + the full dataset would leak val/test distribution information into + normalization, invalidating the evaluation. + + Distances are stored at position 0 of the last dim in cached tensors. + Stats are computed using the paper's per-sample averaging approach: mean + and std are averaged across samples rather than computed globally, to avoid + samples with more channels dominating the statistics. 
+ """ + dist_values = [] + ts_sample_means = [] + ts_sample_stds = [] + + for sample in dataset: + for key in keys: + if key not in sample: + continue + x = sample[key] # (modes, chans, T+1), distance at position 0 + dist = x[0, :, 0] + valid = dist > 0 + if valid.any(): + dist_values.extend(dist[valid].tolist()) + ts = x[:, :, 1:] # (modes, chans, T) + ts_std = ts.std(dim=-1) # (modes, chans) + nonzero = ts_std > 0 + if nonzero.any(): + ts_sample_means.append(ts[nonzero.unsqueeze(-1).expand_as(ts)].mean().item()) + ts_sample_stds.append(ts_std[nonzero].mean().item()) + + return { + "mean_dist": float(np.mean(dist_values)) if dist_values else 0.0, + "std_dist": float(np.std(dist_values)) if len(dist_values) > 1 else 1.0, + "mean_ts": float(np.mean(ts_sample_means)) if ts_sample_means else 0.0, + "std_ts": float(np.mean(ts_sample_stds)) if ts_sample_stds else 1.0, + } + + +def normalize_spes_tensor(x, norm_stats): + """Apply z-score normalization to a single SPES tensor using pre-computed stats. + + Distances (position 0 of last dim) and time series (positions 1:) are + normalized separately, since they have different units and scales. Only + non-zero-padded entries are normalized; padded channels (zero distance, + zero-std time series) are left as zero so the model can distinguish real + data from padding. + """ + x = x.clone() + dist_mask = x[..., 0] > 0 + if dist_mask.any(): + x[..., 0][dist_mask] = ( + x[..., 0][dist_mask] - norm_stats["mean_dist"] + ) / norm_stats["std_dist"] + ts = x[..., 1:] + ts_std = ts.std(dim=-1) + ts_mask = (ts_std > 0).unsqueeze(-1).expand_as(ts) + if ts_mask.any(): + x[..., 1:][ts_mask] = ( + x[..., 1:][ts_mask] - norm_stats["mean_ts"] + ) / norm_stats["std_ts"] + return x + + +def collate_spes_batch(batch, norm_stats=None): + """Collate a batch of SPES samples, padding variable-length tensors and applying normalization. 
+ + NOTE: PyHealth's built-in get_dataloader (pyhealth.datasets.utils) is not + used here for two reasons that are fundamental to this task: + + 1. Variable-shape inputs: X_stim and X_recording have a variable number + of trials (rows) per electrode, determined by how many stimulation + events were recorded for each channel. This varies across electrodes + and patients, so samples within a batch cannot be stacked without + padding. PyHealth's default collate_fn_dict_with_padding does not + handle this multi-dimensional, field-specific padding. + + 2. Per-fold normalization at collate time: z-score statistics are + computed from the training split only and must be injected into the + collate function via functools.partial. PyHealth's get_dataloader + accepts no such hook, so normalization would have to happen elsewhere + (e.g., in the model or dataset), breaking the clean separation between + preprocessing and model logic. + """ + collated = {} + for key in batch[0].keys(): + values = [sample[key] for sample in batch] + if key in {"X_stim", "X_recording"}: + max_shape = tuple(max(value.shape[dim] for value in values) for dim in range(values[0].dim())) + stacked = torch.stack( + [pad_tensor_to_shape(value, max_shape) for value in values] + ) + if norm_stats is not None: + stacked = normalize_spes_tensor(stacked, norm_stats) + collated[key] = stacked + elif isinstance(values[0], torch.Tensor): + if all(value.shape == values[0].shape for value in values): + collated[key] = torch.stack(values) + elif values[0].dim() == 0: + collated[key] = torch.stack(values) + else: + max_shape = tuple(max(value.shape[dim] for value in values) for dim in range(values[0].dim())) + collated[key] = torch.stack( + [pad_tensor_to_shape(value, max_shape) for value in values] + ) + else: + collated[key] = values + return collated + + +def get_spes_dataloader(dataset, batch_size, shuffle=False, norm_stats=None): + """Create a DataLoader with SPES-specific collation and normalization. 
+ + norm_stats must be computed from the training split only (via + compute_norm_stats) and passed to all splits so that val and test are + normalized using training-set statistics, preventing data leakage. + """ + return DataLoader( + dataset, + batch_size=batch_size, + shuffle=shuffle, + collate_fn=partial(collate_spes_batch, norm_stats=norm_stats), + ) + + +def build_model(model_name, sample_dataset, pos_weight=None, dropout_rate=None): + model_kwargs = dict(MODEL_PRESETS[model_name]["model_kwargs"]) + if dropout_rate is not None: + model_kwargs["dropout_rate"] = dropout_rate + if model_name == "cnn-divergent": + return SPESResNet( + dataset=sample_dataset, + input_type="divergent", + pos_weight=pos_weight, + **model_kwargs, + ) + if model_name == "cnn-convergent": + return SPESResNet( + dataset=sample_dataset, + input_type="convergent", + pos_weight=pos_weight, + **model_kwargs, + ) + if model_name == "cnn-transformer": + return SPESTransformer( + dataset=sample_dataset, + net_configs=[ + {"type": "convergent", "mean": True, "std": True}, + ], + pos_weight=pos_weight, + **model_kwargs, + ) + if model_name == "cnn-transformer-ablation": + return SPESTransformer( + dataset=sample_dataset, + net_configs=[ + {"type": "convergent", "mean": True, "std": False}, + ], + mlp_embedding=False, + pos_weight=pos_weight, + **model_kwargs, + ) + raise ValueError(f"Unknown model: {model_name}") + + +def compute_pos_weight(dataset): + labels = np.array([int(sample["soz"].item()) for sample in dataset]) + positives = int((labels == 1).sum()) + negatives = int((labels == 0).sum()) + if positives == 0: + return 1.0 + return negatives / positives + + +def split_by_patient_kfold(dataset, fold=0, n_splits=5, seed=0): + patient_ids = np.array(sorted(dataset.patient_to_index.keys())) + if len(patient_ids) < n_splits: + raise ValueError( + f"Need at least n_splits patients; got {len(patient_ids)} patients " + f"and n_splits={n_splits}." 
+ ) + + kfold = KFold(n_splits=n_splits, shuffle=True, random_state=seed) + splits = list(kfold.split(patient_ids)) + fold = fold % n_splits + + test_patient_idx = splits[fold][1] + val_patient_idx = splits[(fold + 1) % n_splits][1] + train_patient_idx = np.array( + sorted(set(splits[fold][0]) - set(val_patient_idx)) + ) + + def patient_indices_to_sample_indices(patient_idx): + sample_indices = [] + for index in patient_idx: + sample_indices.extend(dataset.patient_to_index[patient_ids[index]]) + return sample_indices + + train_dataset = dataset.subset(patient_indices_to_sample_indices(train_patient_idx)) + val_dataset = dataset.subset(patient_indices_to_sample_indices(val_patient_idx)) + test_dataset = dataset.subset(patient_indices_to_sample_indices(test_patient_idx)) + return train_dataset, val_dataset, test_dataset + + +def youden_score(y_true, y_pred): + y_true = np.asarray(y_true).reshape(-1).astype(int) + y_pred = np.asarray(y_pred).reshape(-1).astype(int) + + true_positive = int(((y_true == 1) & (y_pred == 1)).sum()) + true_negative = int(((y_true == 0) & (y_pred == 0)).sum()) + false_positive = int(((y_true == 0) & (y_pred == 1)).sum()) + false_negative = int(((y_true == 1) & (y_pred == 0)).sum()) + + sensitivity = ( + true_positive / (true_positive + false_negative) + if true_positive + false_negative > 0 + else float("nan") + ) + specificity = ( + true_negative / (true_negative + false_positive) + if true_negative + false_positive > 0 + else float("nan") + ) + return sensitivity, specificity, sensitivity + specificity - 1 + + +def select_youden_threshold(y_true, y_prob): + y_true = np.asarray(y_true).reshape(-1).astype(int) + y_prob = np.asarray(y_prob).reshape(-1) + if len(np.unique(y_true)) < 2: + return 0.5 + + false_positive_rate, true_positive_rate, thresholds = roc_curve(y_true, y_prob) + youden_values = true_positive_rate - false_positive_rate + threshold = thresholds[int(np.argmax(youden_values))] + if np.isfinite(threshold): + return 
float(threshold) + return 0.5 + + +def safe_roc_auc(y_true, y_prob): + y_true = np.asarray(y_true).reshape(-1).astype(int) + y_prob = np.asarray(y_prob).reshape(-1) + if len(np.unique(y_true)) < 2: + return float("nan") + return roc_auc_score(y_true, y_prob) + + +def compute_soz_metrics(y_true, y_prob, patient_ids, threshold=0.5): + y_true = np.asarray(y_true).reshape(-1).astype(int) + y_prob = np.asarray(y_prob).reshape(-1) + patient_ids = np.asarray(patient_ids) + + patient_aucs = [] + baselines = [] + youdens = [] + specificities = [] + sensitivities = [] + + for patient_id in np.unique(patient_ids): + patient_mask = patient_ids == patient_id + patient_true = y_true[patient_mask] + patient_prob = y_prob[patient_mask] + patient_pred = patient_prob > threshold + + patient_aucs.append(safe_roc_auc(patient_true, patient_prob)) + baselines.append(np.mean(patient_true)) + + sensitivity, specificity, youden = youden_score(patient_true, patient_pred) + sensitivities.append(sensitivity) + specificities.append(specificity) + youdens.append(youden) + + return { + "Baseline": np.nanmean(baselines), + "AUROC": np.nanmean(patient_aucs), + "Specificity": np.nanmean(specificities), + "Sensitivity": np.nanmean(sensitivities), + "Youden": np.nanmean(youdens), + "Youden threshold": threshold, + } + + +def run_fold(args, sample_dataset, fold): + train_dataset, val_dataset, test_dataset = split_by_patient_kfold( + sample_dataset, + fold=fold, + n_splits=args.n_splits, + seed=args.seed, + ) + print("\nPatient-level k-fold split sizes:") + print(f" fold: {fold} / {args.n_splits}") + print(f" train: {len(train_dataset)}") + print(f" val: {len(val_dataset)}") + print(f" test: {len(test_dataset)}") + + if min(len(train_dataset), len(val_dataset), len(test_dataset)) == 0: + print("\nSkipping fold because at least one split is empty.") + return None + + pos_weight = compute_pos_weight(train_dataset) + model = build_model(args.model, sample_dataset, pos_weight=pos_weight, 
dropout_rate=args.dropout) + print(f"\nInitialized model: {model.__class__.__name__} ({args.model})") + learning_rate = ( + args.lr + if args.lr is not None + else MODEL_PRESETS[args.model]["learning_rate"] + ) + print(f"Using learning rate: {learning_rate}") + print(f"Using positive class weight: {pos_weight}") + + print("\nComputing normalization statistics from training set...") + norm_stats = compute_norm_stats(train_dataset) + print(f" mean_dist={norm_stats['mean_dist']:.4f}, std_dist={norm_stats['std_dist']:.4f}") + print(f" mean_ts={norm_stats['mean_ts']:.6f}, std_ts={norm_stats['std_ts']:.6f}") + + train_loader = get_spes_dataloader( + train_dataset, + batch_size=min(args.batch_size, len(train_dataset)), + shuffle=True, + norm_stats=norm_stats, + ) + val_loader = get_spes_dataloader( + val_dataset, + batch_size=min(args.batch_size, len(val_dataset)), + shuffle=False, + norm_stats=norm_stats, + ) + test_loader = get_spes_dataloader( + test_dataset, + batch_size=min(args.batch_size, len(test_dataset)), + shuffle=False, + norm_stats=norm_stats, + ) + + trainer = Trainer( + model=model, + device=args.device, + metrics=["roc_auc"], + ) + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=args.epochs, + optimizer_class=torch.optim.AdamW, + optimizer_params={"lr": learning_rate}, + monitor="roc_auc", + monitor_criterion="max", + patience=args.patience, + ) + + val_y_true, val_y_prob, _ = trainer.inference(val_loader) + threshold = select_youden_threshold(val_y_true, val_y_prob) + print(f"Using decision threshold: {threshold}") + + y_true, y_prob, test_loss, patient_ids = trainer.inference( + test_loader, + return_patient_ids=True, + ) + results = compute_soz_metrics( + y_true, + y_prob, + patient_ids, + threshold=threshold, + ) + results["loss"] = test_loss + print("\nTest results:") + for metric, value in results.items(): + print(f" {metric}: {value:.4f}") + + results["model"] = args.model + results["seed"] = args.seed + 
results["fold"] = fold + return results + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model", + type=str, + default="cnn-transformer", + choices=[ + "cnn-divergent", + "cnn-convergent", + "cnn-transformer", + "cnn-transformer-ablation", + ], + ) + parser.add_argument("--batch-size", type=int, default=8) + parser.add_argument("--epochs", type=int, default=100) + parser.add_argument("--patience", type=int, default=10) + parser.add_argument( + "--lr", + type=float, + default=None, + help="Override the selected model preset learning rate.", + ) + parser.add_argument( + "--dropout", + type=float, + default=None, + help="Override the selected model preset dropout rate.", + ) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument( + "--fold", + type=int, + default=None, + help="Run one fold only. Defaults to all folds.", + ) + parser.add_argument("--n-splits", type=int, default=5) + parser.add_argument("--device", type=str, default=None) + args = parser.parse_args() + + dataset = CCEPECoGDataset( + root=ROOT, + dev=False, + cache_dir=CACHE_DIR, + num_workers=1, + ) + + print("\nBuilding PyHealth LocalizeSOZ sample dataset...") + sample_dataset = dataset.set_task(LocalizeSOZ(), num_workers=1) + + print(f"PyHealth task name: {sample_dataset.task_name}") + print(f"Total electrode samples: {len(sample_dataset)}") + + positive_task_patients = set() + positive_electrodes = 0 + for sample in sample_dataset: + patient_id = sample["patient_id"] + if int(sample["soz"].item()) == 1: + positive_task_patients.add(patient_id) + positive_electrodes += 1 + + print(f"Patients with positive SOZ electrode samples: {len(positive_task_patients)}") + print(f"Positive SOZ electrode samples: {positive_electrodes}") + + if len(sample_dataset): + sample = sample_dataset[0] + print("\nFirst processed PyHealth sample:") + print(f" patient_id: {sample['patient_id']}") + print(f" record_id: {sample['record_id']}") + print(f" channel: 
{sample['channel']}") + print(f" soz shape/value: {tuple(sample['soz'].shape)} / {sample['soz'].tolist()}") + print(f" electrode_lobes shape: {tuple(sample['electrode_lobes'].shape)}") + print(f" electrode_coords shape: {tuple(sample['electrode_coords'].shape)}") + print(f" X_stim shape: {tuple(sample['X_stim'].shape)}") + print(f" X_recording shape: {tuple(sample['X_recording'].shape)}") + print(" X_stim mode axis: 0=mean, 1=std") + print(" X_recording mode axis: 0=mean, 1=std") + + if not len(sample_dataset): + return + + folds = [args.fold % args.n_splits] if args.fold is not None else range(args.n_splits) + fold_results = [] + for fold in folds: + print(f"\nSeed {args.seed}, fold {fold + 1}") + result = run_fold(args, sample_dataset, fold) + if result is not None: + fold_results.append(result) + + if not fold_results: + print("\nNo fold results to summarize.") + return + + metric_names = [ + "Baseline", + "AUROC", + "Specificity", + "Sensitivity", + "Youden", + "loss", + ] + if len(fold_results) > 1: + print("\nMean and standard deviation across folds:") + for metric in metric_names: + values = np.array([result[metric] for result in fold_results], dtype=float) + print(f" {metric}: {np.nanmean(values):.4f} +/- {np.nanstd(values):.4f}") + + +if __name__ == "__main__": + main() diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 50b1b3887..c279873f6 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -48,6 +48,7 @@ def __init__(self, *args, **kwargs): from .base_dataset import BaseDataset from .cardiology import CardiologyDataset +from .ccep_ecog import CCEPECoGDataset from .chestxray14 import ChestXray14Dataset from .clinvar import ClinVarDataset from .cosmic import COSMICDataset diff --git a/pyhealth/datasets/ccep_ecog.py b/pyhealth/datasets/ccep_ecog.py new file mode 100644 index 000000000..5fa5b8ee7 --- /dev/null +++ b/pyhealth/datasets/ccep_ecog.py @@ -0,0 +1,211 @@ +""" +PyHealth dataset for the CCEP 
ECoG dataset. + +Dataset link: + https://openneuro.org/datasets/ds004080 +""" + +import logging +import os +import shutil +import tempfile +from pathlib import Path +from typing import Optional + +import mne_bids +import pandas as pd + +from pyhealth.datasets import BaseDataset +from pyhealth.tasks.localize_soz import LocalizeSOZ + +logger = logging.getLogger(__name__) + + +class CCEPECoGDataset(BaseDataset): + """Dataset class for the CCEP ECoG dataset. + + Dataset is organized in BIDS format. This class parses and labels subjects who have + all electrodes labeled, including at least one electrode in the Seizure Onset Zone (SOZ). + + The raw BIDS directory should contain patient folders like `sub-`. + + Attributes: + root (str): Root directory of the raw data. + dataset_name (str): Name of the dataset. + config_path (str): Path to the configuration file. + """ + + def __init__( + self, + root: str = ".", + config_path: Optional[str] = str(Path(__file__).parent / "configs" / "ccep_ecog.yaml"), + **kwargs, + ) -> None: + """Initializes the CCEP ECoG dataset. + + Args: + root (str): Root directory of the raw data. Defaults to the working directory. + config_path (Optional[str]): Path to the configuration file. Defaults to "configs/ccep_ecog.yaml". + + Raises: + FileNotFoundError: If the dataset path does not exist. + ValueError: If the dataset does not adhere to the expected BIDS structure. + + Example:: + >>> dataset = CCEPECoGDataset(root="./data/ds004080") + """ + self._verify_data(root) + self._tmp_dir = tempfile.mkdtemp(prefix="pyhealth_ccep_ecog_") + self._index_data(root, self._tmp_dir) + + super().__init__( + root=self._tmp_dir, + tables=["ecog"], + dataset_name="ccep_ecog", + config_path=config_path, + **kwargs, + ) + + def __del__(self) -> None: + if hasattr(self, "_tmp_dir"): + shutil.rmtree(self._tmp_dir, ignore_errors=True) + + def _verify_data(self, root: str) -> None: + """Verifies the presence and structure of the dataset directory. 
+ + Ensures the root path exists, verifies the presence of subject directories as well as + at least one header file and electrode file. + + Args: + root (str): Root directory of the raw data. + + Raises: + FileNotFoundError: If the dataset path does not exist. + ValueError: If the dataset lacks subjects or core BIDS files. + """ + if not os.path.exists(root): + msg = f"Dataset path '{root}' does not exist" + logger.error(msg) + raise FileNotFoundError(msg) + + # Check for presence of subjects + subjects = list(Path(root).glob("sub-*")) + if not subjects: + msg = f"BIDS root '{root}' contains no 'sub-*' subject folders" + logger.error(msg) + raise ValueError(msg) + + # Check for at least one recording + if not any(Path(root).rglob("*.vhdr")): + msg = f"BIDS root '{root}' contains no '.vhdr' files" + logger.error(msg) + raise ValueError(msg) + + # Check for at least one electrode file + if not any(Path(root).rglob("*_electrodes.tsv")): + msg = f"BIDS root '{root}' contains no 'electrodes.tsv' file" + logger.error(msg) + raise ValueError(msg) + + # Check for at least one channels file + if not any(Path(root).rglob("*_channels.tsv")): + msg = f"BIDS root '{root}' contains no 'channels.tsv' files" + logger.error(msg) + raise ValueError(msg) + + # Check for at least one events file + if not any(Path(root).rglob("*_events.tsv")): + msg = f"BIDS root '{root}' contains no 'events.tsv' files" + logger.error(msg) + raise ValueError(msg) + + def _index_data(self, root: str, output_dir: str) -> pd.DataFrame: + """Parses and indexes metadata for all available patients in the dataset. + + Args: + root (str): Root directory of the raw data. + output_dir (str): Directory where the metadata CSV will be written. + + Returns: + pd.DataFrame: Table of patient ECoG signal metadata. 
+ """ + try: + subjects = mne_bids.get_entity_vals(root, "subject") + except FileNotFoundError: + subjects = [] + + rows = [] + root_path = Path(root) + + for sub in subjects: + has_soz = False + patient_dir = root_path / f"sub-{sub}" + + for tsv_file in patient_dir.rglob("*electrodes.tsv"): + try: + df = pd.read_csv(tsv_file, sep="\t") + cols = [c.lower() for c in df.columns] + if "soz" in cols: + col_series = df["soz"].str.lower() + # Verify that there is at least one electrode in the SOZ and all electrodes are labeled + if (col_series == "yes").any() and col_series.isin(["yes", "no"]).all(): + has_soz = True + break + except Exception as e: + logger.warning( + f"Skipping metadata file {tsv_file} due to error: {e}" + ) + continue + + for header_file in patient_dir.rglob("*.vhdr"): + # Single electrodes.tsv file per session + elec_match = list(header_file.parent.glob("*electrodes.tsv")) + electrodes_file = str(elec_match[0]) if elec_match else "" + + # Multiple channels.tsv and events.tsv files per session + # header_file has the same base name as channels.tsv and events.tsv + base_name = header_file.name.replace("_ieeg.vhdr", "") + + chan_path = header_file.parent / f"{base_name}_channels.tsv" + channels_file = str(chan_path) if chan_path.exists() else "" + + evt_path = header_file.parent / f"{base_name}_events.tsv" + events_file = str(evt_path) if evt_path.exists() else "" + + entities = mne_bids.get_entities_from_fname(str(header_file)) + + rows.append( + { + "patient_id": sub, + "session_id": entities.get("session", ""), + "task_id": entities.get("task", ""), + "run_id": entities.get("run", ""), + "header_file": str(header_file), + "electrodes_file": electrodes_file, + "channels_file": channels_file, + "events_file": events_file, + "has_soz": has_soz, + } + ) + + if not rows: + logger.warning( + "No valid BIDS ECoG header files (.vhdr) were found for any subjects. " + "Ensure your root directory matches the BIDS structure (sub-*/ses-*/ieeg/*.vhdr)." 
+ ) + + df = pd.DataFrame(rows) + if not df.empty: + df.sort_values(["patient_id"], inplace=True) + df.reset_index(drop=True, inplace=True) + + output_path = os.path.join(output_dir, "ccep_ecog-metadata-pyhealth.csv") + df.to_csv(output_path, index=False) + logger.info(f"Wrote metadata to {output_path}") + + return df + + @property + def default_task(self) -> LocalizeSOZ: + """Returns the default task for this dataset.""" + return LocalizeSOZ() diff --git a/pyhealth/datasets/configs/ccep_ecog.yaml b/pyhealth/datasets/configs/ccep_ecog.yaml new file mode 100644 index 000000000..fc35c0dbe --- /dev/null +++ b/pyhealth/datasets/configs/ccep_ecog.yaml @@ -0,0 +1,15 @@ +version: "1.0" +tables: + ecog: + file_path: "ccep_ecog-metadata-pyhealth.csv" + patient_id: "patient_id" + timestamp: null + attributes: + - "session_id" + - "task_id" + - "run_id" + - "header_file" + - "electrodes_file" + - "channels_file" + - "events_file" + - "has_soz" diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index 4c168d3e3..09d5d6fbf 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -24,6 +24,7 @@ from .rnn import MultimodalRNN, RNN, RNNLayer from .safedrug import SafeDrug, SafeDrugLayer from .sparcnet import DenseBlock, DenseLayer, SparcNet, TransitionLayer +from .spes import SPESResNet, SPESTransformer from .stagenet import StageNet, StageNetLayer from .stagenet_mha import StageAttentionNet, StageNetAttentionLayer from .tcn import TCN, TCNLayer @@ -45,4 +46,4 @@ from .sdoh import SdohClassifier from .medlink import MedLink from .unified_embedding import UnifiedMultimodalEmbeddingModel, SinusoidalTimeEmbedding -from .califorest import CaliForest \ No newline at end of file +from .califorest import CaliForest diff --git a/pyhealth/models/spes.py b/pyhealth/models/spes.py new file mode 100644 index 000000000..81f142f80 --- /dev/null +++ b/pyhealth/models/spes.py @@ -0,0 +1,716 @@ +import random +from typing import Dict, Iterable, List, 
Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from pyhealth.datasets import SampleDataset +from pyhealth.models import BaseModel + + +def _conv1d( + in_planes: int, + out_planes: int, + kernel_size: int, + stride: int = 1, + padding: Optional[int] = None, +) -> nn.Conv1d: + if padding is None: + padding = kernel_size // 2 + return nn.Conv1d( + in_planes, + out_planes, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=False, + ) + + +class _BasicBlock1D(nn.Module): + expansion = 1 + + def __init__( + self, + inplanes: int, + planes: int, + kernel_size: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + ): + super().__init__() + self.conv1 = _conv1d(inplanes, planes, kernel_size=kernel_size, stride=stride) + self.bn1 = nn.BatchNorm1d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = _conv1d(planes, planes, kernel_size=kernel_size) + self.bn2 = nn.BatchNorm1d(planes) + self.downsample = downsample + + @staticmethod + def _align_time(left: torch.Tensor, right: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + length = min(left.shape[-1], right.shape[-1]) + return left[..., :length], right[..., :length] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + # Trim to the same length before adding since strided convolutions can produce off-by-one mismatches + residual, out = self._align_time(residual, out) + out = out + residual + out = self.relu(out) + return out + + +class MultiScaleResNet1D(nn.Module): + """Multi-scale 1D ResNet backbone used by the SPES models. + + Three parallel residual branches (kernel sizes 3, 5, and 7) each + produce a 256-dimensional pooled feature; their concatenation yields a + 768-dimensional output embedding. 
+ + Args: + input_channel: Number of input channels (signal modes, e.g. 1 or 2). + layers: Number of residual blocks per stage in each branch. + Defaults to ``[1, 1, 1, 1]``. + dropout_rate: Dropout probability applied after the final pooling. + Default is 0.2. + """ + + output_dim = 256 * 3 + + def __init__( + self, + input_channel: int, + layers: Optional[Iterable[int]] = None, + dropout_rate: float = 0.2, + ): + super().__init__() + layers = list(layers or [1, 1, 1, 1]) + + # Track inplanes per kernel size separately so each branch can build its own downsample projections + self.inplanes = {3: 64, 5: 64, 7: 64} + + self.conv1 = nn.Conv1d( + input_channel, 64, kernel_size=7, stride=2, padding=3, bias=False + ) + self.bn1 = nn.BatchNorm1d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) + + self.layer3x3_1 = self._make_layer(3, 64, layers[0], stride=2) + self.layer3x3_2 = self._make_layer(3, 128, layers[1], stride=2) + self.layer3x3_3 = self._make_layer(3, 256, layers[2], stride=2) + + self.layer5x5_1 = self._make_layer(5, 64, layers[0], stride=2) + self.layer5x5_2 = self._make_layer(5, 128, layers[1], stride=2) + self.layer5x5_3 = self._make_layer(5, 256, layers[2], stride=2) + + self.layer7x7_1 = self._make_layer(7, 64, layers[0], stride=2) + self.layer7x7_2 = self._make_layer(7, 128, layers[1], stride=2) + self.layer7x7_3 = self._make_layer(7, 256, layers[2], stride=2) + + self.pool = nn.AdaptiveAvgPool1d(1) + self.drop = nn.Dropout(p=dropout_rate) + + def _make_layer( + self, + kernel_size: int, + planes: int, + blocks: int, + stride: int = 2, + ) -> nn.Sequential: + downsample = None + inplanes = self.inplanes[kernel_size] + if stride != 1 or inplanes != planes * _BasicBlock1D.expansion: + downsample = nn.Sequential( + nn.Conv1d( + inplanes, + planes * _BasicBlock1D.expansion, + kernel_size=1, + stride=stride, + bias=False, + ), + nn.BatchNorm1d(planes * _BasicBlock1D.expansion), + ) + + layers: 
List[nn.Module] = [ + _BasicBlock1D(inplanes, planes, kernel_size, stride, downsample) + ] + self.inplanes[kernel_size] = planes * _BasicBlock1D.expansion + for _ in range(1, blocks): + layers.append( + _BasicBlock1D(self.inplanes[kernel_size], planes, kernel_size) + ) + + return nn.Sequential(*layers) + + def _forward_branch(self, x: torch.Tensor, layers: Iterable[nn.Module]) -> torch.Tensor: + for layer in layers: + x = layer(x) + return self.pool(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + out3 = self._forward_branch( + x, (self.layer3x3_1, self.layer3x3_2, self.layer3x3_3) + ) + out5 = self._forward_branch( + x, (self.layer5x5_1, self.layer5x5_2, self.layer5x5_3) + ) + out7 = self._forward_branch( + x, (self.layer7x7_1, self.layer7x7_2, self.layer7x7_3) + ) + + # Concatenate the three branch embeddings along the feature dimension + out = torch.cat([out3, out5, out7], dim=1) + # Squeeze the trailing dim from AdaptiveAvgPool1d output + out = out[:, :, 0] + return self.drop(out) + + +class SPESResponseEncoder(nn.Module): + """Channel-wise CCEP response encoder combining a ResNet and Transformer. + + Each channel's response (mean and/or std across stimulation trials) is + independently embedded by an optional multi-scale 1D ResNet and/or a + flattened MLP prefix, then aggregated by a Transformer encoder whose + class token serves as the output representation. + + Args: + mean: If ``True``, include the mean CCEP response as an input mode. + std: If ``True``, include the std CCEP response as an input mode. + At least one of ``mean`` or ``std`` must be ``True``. + conv_embedding: If ``True``, embed each channel via + :class:`MultiScaleResNet1D`. Default is ``True``. + mlp_embedding: If ``True`` (and ``conv_embedding`` is ``True``), + prepend a flattened MLP prefix to the ResNet embedding. + Default is ``True``. + dropout_rate: Dropout probability. Default is 0.5. 
+ num_layers: Number of Transformer encoder layers. Default is 2. + embedding_dim: Dimension of the per-channel embedding passed to the + Transformer. Default is 64. + random_channels: If set, randomly sub-sample this many channels per + forward pass. ``None`` uses all channels. Default is ``None``. + noise_std: Std of Gaussian noise injected during training. Default is 0.1. + max_mlp_timesteps: Maximum number of timesteps kept for the MLP prefix. + Default is 155. + expected_timesteps: Expected signal length when ``conv_embedding=False``. + Default is 509. + """ + + def __init__( + self, + mean: bool, + std: bool, + conv_embedding: bool = True, + mlp_embedding: bool = True, + dropout_rate: float = 0.5, + num_layers: int = 2, + embedding_dim: int = 64, + random_channels: Optional[int] = None, + noise_std: float = 0.1, + max_mlp_timesteps: int = 155, + expected_timesteps: int = 509, + ): + super().__init__() + if not (mean or std): + raise ValueError("Either mean or std, or both, must be enabled.") + + self.mean = mean + self.std = std + self.conv_embedding = conv_embedding + self.mlp_embedding = mlp_embedding + self.random_channels = random_channels + self.noise_std = noise_std + self.max_mlp_timesteps = max_mlp_timesteps + self.expected_timesteps = expected_timesteps + + mode_count = int(self.mean) + int(self.std) + if conv_embedding: + self.msresnet = MultiScaleResNet1D( + input_channel=mode_count, dropout_rate=dropout_rate + ) + embedding_in = MultiScaleResNet1D.output_dim + if mlp_embedding: + embedding_in += mode_count * max_mlp_timesteps + else: + embedding_in = mode_count * expected_timesteps + + self.patch_to_embedding = nn.Linear(embedding_in, embedding_dim) + self.dropout = nn.Dropout(dropout_rate) + self.class_token = nn.Parameter( + nn.init.xavier_normal_(torch.empty(1, 1, embedding_dim)) + ) + + # Pick the largest nhead that evenly divides embedding_dim, up to embedding_dim // 8 + nhead = max( + head + for head in range(1, max(1, embedding_dim // 8) + 
1) + if embedding_dim % head == 0 + ) + encoder_layer = nn.TransformerEncoderLayer( + d_model=embedding_dim, + nhead=nhead, + dim_feedforward=embedding_dim * 2, + dropout=dropout_rate, + batch_first=True, + ) + self.transformer_encoder = nn.TransformerEncoder( + encoder_layer, num_layers=num_layers + ) + + def _add_noise(self, x: torch.Tensor) -> torch.Tensor: + if self.noise_std <= 0: + return x + return x + torch.randn_like(x) * self.noise_std + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.training: + x = self.apply_noise_and_zero_channels(x) + + if self.random_channels is not None: + x = self.select_random_channels(x, self.random_channels) + + # Zero distance means the channel slot is padding and should be masked out + distances = x[:, 0, :, 0] + key_padding_mask = self.create_key_padding_mask(distances) + + channel_features = self.prepare_channels(x) + x = self.dropout(self.patch_to_embedding(channel_features)) + + # Prepend the learnable class token and aggregate via the Transformer + class_token = self.class_token.repeat(x.shape[0], 1, 1) + x = torch.cat((class_token, x), dim=1) + x = self.transformer_encoder(x, src_key_padding_mask=key_padding_mask) + # Return only the class token output as the sequence-level embedding + return x[:, 0] + + def apply_noise_and_zero_channels(self, x: torch.Tensor) -> torch.Tensor: + x = x.clone() + # Identify non-padded channels by summing distances across the batch + valid_columns = torch.where(x[:, 0, :, 0].sum(dim=0) != 0)[0] + if len(valid_columns) > 0: + # Randomly zero out up to half the valid channels for regularization + sample_size = random.randint(0, len(valid_columns) // 2) + if sample_size > 0: + random_indices = torch.randperm( + len(valid_columns), device=x.device + )[:sample_size] + x[:, :, valid_columns[random_indices], :] = 0 + + # Add noise only to the time series (positions 1:), not the distance (position 0) + x[:, :, :, 1:] = self._add_noise(x[:, :, :, 1:]) + return x + + def 
select_random_channels( + self, x: torch.Tensor, num_channels: int + ) -> torch.Tensor: + all_x = [] + distances = x[:, 0, :, 0] + for single_sample, distance in zip(x, distances): + valid_rows = torch.where(distance != 0)[0] + if len(valid_rows) == 0: + raise ValueError("SPES input contains a sample with no valid channels.") + replacement = len(valid_rows) < num_channels + p = torch.ones(len(valid_rows), device=x.device) / len(valid_rows) + idx = p.multinomial(num_samples=num_channels, replacement=replacement) + channels = valid_rows[idx].sort()[0] + all_x.append(single_sample[:, channels]) + return torch.stack(all_x, dim=0) + + @staticmethod + def create_key_padding_mask(distances: torch.Tensor) -> torch.Tensor: + key_padding_mask = distances == 0 + # Prepend a False column for the class token, which is never masked + false_column = torch.zeros( + distances.size(0), 1, dtype=torch.bool, device=distances.device + ) + return torch.cat([false_column, key_padding_mask], dim=1) + + def _selected_modes(self, x: torch.Tensor) -> torch.Tensor: + if self.mean and self.std: + return x[:, :, :, 1:] + if self.mean: + return x[:, :1, :, 1:] + return x[:, 1:, :, 1:] + + def _selected_modes_with_distance(self, x: torch.Tensor) -> torch.Tensor: + if self.mean and self.std: + return x + if self.mean: + return x[:, :1] + return x[:, 1:] + + def prepare_channels(self, x: torch.Tensor) -> torch.Tensor: + mode_count = int(self.mean) + int(self.std) + if self.conv_embedding: + conv_input = self._selected_modes(x) + batch_size, modes, channels, timesteps = conv_input.shape + # Merge batch and channel dims so each channel is processed independently + conv_input = conv_input.reshape(-1, modes, timesteps) + + late_output = self.msresnet(conv_input) + late_output = late_output.reshape(batch_size, channels, -1) + + if not self.mlp_embedding: + return late_output + + # Prepend a short flattened prefix to the ResNet embedding + prefix = self._selected_modes_with_distance(x)[ + :, :, :, : 
class SPESResNet(BaseModel):
    """Multi-scale 1D CNN classifier for CCEP SPES SOZ localization.

    A fixed number of channels is randomly sub-sampled from the input tensor
    and fed through a :class:`MultiScaleResNet1D` backbone followed by a
    linear head. ``input_type="divergent"`` consumes the stimulation-channel
    view (``X_stim``); ``input_type="convergent"`` consumes the
    recording-channel view (``X_recording``).

    Args:
        dataset: The dataset to train the model on. Used to infer label keys
            and output size.
        input_type: Either ``"divergent"`` (stimulation view) or
            ``"convergent"`` (recording view). Default is ``"divergent"``.
        input_channels: Number of channels randomly sampled per forward pass.
            Default is 40.
        stim_key: Key in the sample batch for the stimulation tensor.
            Default is ``"X_stim"``.
        recording_key: Key in the sample batch for the recording tensor.
            Default is ``"X_recording"``.
        noise_std: Std of Gaussian noise added to non-distance features during
            training. Default is 0.1.
        dropout_rate: Dropout probability inside the ResNet. Default is 0.2.
        pos_weight: Optional positive-class weight for
            ``binary_cross_entropy_with_logits``. Default is ``None``.
        **kwargs: Additional keyword arguments forwarded to
            :class:`MultiScaleResNet1D`.

    Examples:
        >>> import numpy as np
        >>> from pyhealth.datasets import create_sample_dataset
        >>> n_ch, n_t = 30, 509
        >>> samples = [
        ...     {
        ...         "patient_id": f"p{i}",
        ...         "visit_id": f"v{i}",
        ...         "X_stim": np.random.randn(2, n_ch, n_t).astype(np.float32),
        ...         "X_recording": np.random.randn(2, n_ch, n_t).astype(np.float32),
        ...         "electrode_lobes": np.array([i % 7], dtype=np.int64),
        ...         "electrode_coords": np.random.randn(3).astype(np.float32),
        ...         "soz": i % 2,
        ...     }
        ...     for i in range(4)
        ... ]
        >>> dataset = create_sample_dataset(
        ...     samples=samples,
        ...     input_schema={"X_stim": "tensor", "X_recording": "tensor",
        ...                   "electrode_lobes": "tensor", "electrode_coords": "tensor"},
        ...     output_schema={"soz": "binary"},
        ...     dataset_name="test",
        ... )
        >>> from pyhealth.models import SPESResNet
        >>> model = SPESResNet(dataset=dataset, input_channels=10)
        >>> from pyhealth.datasets import get_dataloader
        >>> loader = get_dataloader(dataset, batch_size=2, shuffle=False)
        >>> batch = next(iter(loader))
        >>> ret = model(**batch)
    """

    def __init__(
        self,
        dataset: SampleDataset,
        input_type: str = "divergent",
        input_channels: int = 40,
        stim_key: str = "X_stim",
        recording_key: str = "X_recording",
        noise_std: float = 0.1,
        dropout_rate: float = 0.2,
        pos_weight: Optional[float] = None,
        **kwargs,
    ):
        super().__init__(dataset=dataset)
        if input_type not in {"divergent", "convergent"}:
            raise ValueError("input_type must be 'divergent' or 'convergent'.")
        if len(self.label_keys) != 1:
            raise ValueError("SPESResNet supports exactly one label key.")

        self.input_type = input_type
        self.input_channels = input_channels
        self.stim_key = stim_key
        self.recording_key = recording_key
        self.noise_std = noise_std
        self.label_key = self.label_keys[0]

        # A buffer (rather than a plain attribute) moves with the module
        # across devices; None keeps the default BCE loss untouched.
        if pos_weight is None:
            self.pos_weight = None
        else:
            self.register_buffer(
                "pos_weight",
                torch.tensor([pos_weight], dtype=torch.float32),
            )

        self.msresnet = MultiScaleResNet1D(
            input_channel=input_channels,
            dropout_rate=dropout_rate,
            **kwargs,
        )
        self.fc = nn.Linear(MultiScaleResNet1D.output_dim, self.get_output_size())

    @property
    def feature_key(self) -> str:
        """Batch key selected by ``input_type``."""
        if self.input_type == "divergent":
            return self.stim_key
        return self.recording_key

    def _add_noise(self, x: torch.Tensor) -> torch.Tensor:
        """Add zero-mean Gaussian noise with std ``self.noise_std`` (no-op if <= 0)."""
        if self.noise_std <= 0:
            return x
        return x + torch.randn_like(x) * self.noise_std

    def _sample_channels(self, x: torch.Tensor) -> torch.Tensor:
        """Uniformly sample ``input_channels`` valid channels per batch item.

        Keeps only the mean mode (index 0) and drops the distance entry at
        position 0 of the time axis; returns ``(batch, input_channels, T-1)``.
        """
        # Distance is stored at position 0 of the last dim; zero marks padding.
        distances = x[:, 0, :, 0]
        sampled = []

        for sample, distance in zip(x, distances):
            valid_rows = torch.where(distance != 0)[0]
            if len(valid_rows) == 0:
                raise ValueError("SPES input contains a sample with no valid channels.")
            # Sample with replacement when fewer valid channels exist than
            # requested; preferable to discarding low-channel patients entirely.
            with_replacement = len(valid_rows) < self.input_channels
            weights = torch.ones(len(valid_rows), device=x.device) / len(valid_rows)
            idx = weights.multinomial(
                num_samples=self.input_channels, replacement=with_replacement
            )
            chosen = valid_rows[idx].sort()[0]
            # Use only the mean mode (index 0) and skip the distance (last dim index 0)
            sampled.append(sample[0, chosen, 1:])

        return torch.stack(sampled, dim=0)

    def _compute_loss(
        self,
        logits: torch.Tensor,
        y_true: torch.Tensor,
    ) -> torch.Tensor:
        """Dataset-default loss, or weighted BCE when ``pos_weight`` is set."""
        if self.pos_weight is None:
            return self.get_loss_function()(logits, y_true)
        return F.binary_cross_entropy_with_logits(
            logits,
            y_true,
            pos_weight=self.pos_weight.to(logits.device),
        )

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Run one step; returns loss, probabilities, labels, and logits."""
        signal = kwargs[self.feature_key].to(self.device)
        if self.training:
            # Clone so the augmentation never mutates the caller's batch;
            # noise is applied to the time series only, not the distance.
            signal = signal.clone()
            signal[:, :, :, 1:] = self._add_noise(signal[:, :, :, 1:])

        emb = self.msresnet(self._sample_channels(signal))
        logits = self.fc(emb)

        y_true = kwargs[self.label_key].to(self.device)
        output = {
            "loss": self._compute_loss(logits, y_true),
            "y_prob": self.prepare_y_prob(logits),
            "y_true": y_true,
            "logit": logits,
        }
        if kwargs.get("embed", False):
            output["embed"] = emb
        return output
class SPESTransformer(BaseModel):
    """Transformer-based classifier over SPES CCEP response-channel embeddings.

    One or more :class:`SPESResponseEncoder` networks process the stimulation
    and/or recording tensors; their class-token outputs are concatenated and
    passed through a linear classifier.

    Args:
        dataset: The dataset to train the model on. Used to infer label keys
            and output size.
        net_configs: List of dicts, each specifying one encoder. Required keys:

            - ``"type"`` (``"divergent"`` or ``"convergent"``): selects
              ``X_stim`` or ``X_recording`` as input.
            - ``"mean"`` (bool): include the mean response mode.
            - ``"std"`` (bool): include the std response mode.

        dropout_rate: Dropout probability applied in each encoder and before
            the final linear layer. Default is 0.5.
        stim_key: Key in the sample batch for the stimulation tensor.
            Default is ``"X_stim"``.
        recording_key: Key in the sample batch for the recording tensor.
            Default is ``"X_recording"``.
        pos_weight: Optional positive-class weight for
            ``binary_cross_entropy_with_logits``. Default is ``None``.
        **kwargs: Additional keyword arguments forwarded to each
            :class:`SPESResponseEncoder` (e.g. ``embedding_dim``,
            ``num_layers``, ``random_channels``).

    Examples:
        >>> import numpy as np
        >>> from pyhealth.datasets import create_sample_dataset
        >>> n_ch, n_t = 30, 509
        >>> samples = [
        ...     {
        ...         "patient_id": f"p{i}",
        ...         "visit_id": f"v{i}",
        ...         "X_stim": np.random.randn(2, n_ch, n_t).astype(np.float32),
        ...         "X_recording": np.random.randn(2, n_ch, n_t).astype(np.float32),
        ...         "electrode_lobes": np.array([i % 7], dtype=np.int64),
        ...         "electrode_coords": np.random.randn(3).astype(np.float32),
        ...         "soz": i % 2,
        ...     }
        ...     for i in range(4)
        ... ]
        >>> dataset = create_sample_dataset(
        ...     samples=samples,
        ...     input_schema={"X_stim": "tensor", "X_recording": "tensor",
        ...                   "electrode_lobes": "tensor", "electrode_coords": "tensor"},
        ...     output_schema={"soz": "binary"},
        ...     dataset_name="test",
        ... )
        >>> from pyhealth.models import SPESTransformer
        >>> net_configs = [{"type": "divergent", "mean": True, "std": True}]
        >>> model = SPESTransformer(dataset=dataset, net_configs=net_configs,
        ...                         embedding_dim=32, random_channels=10)
        >>> from pyhealth.datasets import get_dataloader
        >>> loader = get_dataloader(dataset, batch_size=2, shuffle=False)
        >>> batch = next(iter(loader))
        >>> ret = model(**batch)
    """

    def __init__(
        self,
        dataset: SampleDataset,
        net_configs: list[dict],
        dropout_rate: float = 0.5,
        stim_key: str = "X_stim",
        recording_key: str = "X_recording",
        pos_weight: Optional[float] = None,
        **kwargs,
    ):
        super().__init__(dataset=dataset)
        if len(self.label_keys) != 1:
            raise ValueError("SPESTransformer supports exactly one label key.")

        self.net_configs = net_configs
        self.stim_key = stim_key
        self.recording_key = recording_key
        self.label_key = self.label_keys[0]

        # A buffer (rather than a plain attribute) moves with the module
        # across devices; None keeps the default loss untouched.
        if pos_weight is None:
            self.pos_weight = None
        else:
            self.register_buffer(
                "pos_weight",
                torch.tensor([pos_weight], dtype=torch.float32),
            )

        self.eegnets = nn.ModuleList(
            [
                SPESResponseEncoder(
                    mean=cfg["mean"],
                    std=cfg["std"],
                    dropout_rate=dropout_rate,
                    **kwargs,
                )
                for cfg in net_configs
            ]
        )

        embedding_dim = kwargs.get("embedding_dim", 64)
        total_feature_size = embedding_dim * len(net_configs)
        self.fc = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(total_feature_size, self.get_output_size()),
        )

        # Explicitly initialize the classifier head
        nn.init.xavier_uniform_(self.fc[1].weight)
        if self.fc[1].bias is not None:
            nn.init.zeros_(self.fc[1].bias)

    def _get_input(self, net_config: dict, kwargs: dict) -> torch.Tensor:
        """Select the batch tensor named by ``net_config["type"]``."""
        input_type = net_config["type"]
        if input_type == "divergent":
            return kwargs[self.stim_key].to(self.device)
        if input_type == "convergent":
            return kwargs[self.recording_key].to(self.device)
        raise ValueError(
            f"Invalid type '{input_type}' in net_configs; expected 'convergent' or 'divergent'."
        )

    def _compute_loss(
        self,
        logits: torch.Tensor,
        y_true: torch.Tensor,
    ) -> torch.Tensor:
        """Dataset-default loss, or weighted BCE when ``pos_weight`` is set."""
        if self.pos_weight is None:
            return self.get_loss_function()(logits, y_true)
        return F.binary_cross_entropy_with_logits(
            logits,
            y_true,
            pos_weight=self.pos_weight.to(logits.device),
        )

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Run one step; returns loss, probabilities, labels, and logits."""
        encoded = [
            encoder(self._get_input(cfg, kwargs))
            for cfg, encoder in zip(self.net_configs, self.eegnets)
        ]
        emb = torch.cat(encoded, dim=1)
        logits = self.fc(emb)

        y_true = kwargs[self.label_key].to(self.device)
        output = {
            "loss": self._compute_loss(logits, y_true),
            "y_prob": self.prepare_y_prob(logits),
            "y_true": y_true,
            "logit": logits,
        }
        if kwargs.get("embed", False):
            output["embed"] = emb
        return output
def pad_and_stack(arrays, max_rows, pad_value=0):
    """Zero-pad each 2D array along axis 0 to ``max_rows`` and stack them.

    Args:
        arrays: Iterable of 2D arrays sharing the same number of columns.
        max_rows: Target row count; arrays with fewer rows are padded at the
            bottom with ``pad_value``.
        pad_value: Fill value for the added rows. Default is 0.

    Returns:
        3D array of shape ``(len(arrays), max_rows, n_cols)``.
    """
    padded = []
    for arr in arrays:
        deficit = max_rows - arr.shape[0]
        if deficit > 0:
            arr = np.pad(
                arr,
                ((0, deficit), (0, 0)),
                "constant",
                constant_values=(pad_value,),
            )
        padded.append(arr)
    return np.stack(padded)


def process_stimulation_sites(site):
    """Canonicalize a bipolar stimulation site string.

    Sorting the two electrode names makes ``"E2-E1"`` and ``"E1-E2"``
    compare equal.
    """
    return "-".join(sorted(site.split("-")))


def is_overlap(event, artifact):
    """Return True when *event* and *artifact* overlap in sample space."""
    starts_before_artifact_ends = event["sample_start"] < artifact["sample_end"]
    artifact_starts_before_event_ends = artifact["sample_start"] < event["sample_end"]
    return starts_before_artifact_ends and artifact_starts_before_event_ends


def combine_stats(group):
    """Combine mean/std rows for duplicated stimulation-recording groups.

    When a stimulation-recording pair appears more than once, the means are
    averaged and the stds are pooled (assuming n=10 trials per row);
    otherwise the single row passes through unchanged.

    Returns:
        Tuple ``(mean_df, std_df)`` of single-row DataFrames carrying the
        group's identifying columns and a ``metric`` tag.
    """
    mean_rows = group[group["metric"] == "mean"].select_dtypes(include=[np.number])
    std_rows = group[group["metric"] == "std"].select_dtypes(include=[np.number])

    # NOTE(review): assumes 10 trials contributed to every row — confirm
    # against the epoch-extraction minimum of 5 trials.
    n = 10
    if len(mean_rows) > 1 or len(std_rows) > 1:
        total_samples = n * len(mean_rows)
        combined_mean = mean_rows.mean()
        combined_std = np.sqrt(
            (std_rows**2).mean()
            + ((mean_rows - combined_mean) ** 2).sum() / (total_samples - 1)
        )
    else:
        combined_mean = mean_rows.iloc[0]
        combined_std = std_rows.iloc[0]

    mean_df = pd.DataFrame(combined_mean).transpose()
    std_df = pd.DataFrame(combined_std).transpose()

    # Re-attach the identifying (non-numeric) columns dropped by select_dtypes.
    for col in ["subject", "recording", "stim_1", "stim_2"]:
        mean_df[col] = group[col].iloc[0]
        std_df[col] = group[col].iloc[0]

    mean_df["metric"] = "mean"
    std_df["metric"] = "std"

    return mean_df, std_df


regions = ["Frontal", "Insula", "Limbic", "Occipital", "Parietal", "Temporal", "Unknown"]
region_to_index = {region: index for index, region in enumerate(regions)}
+# Left hemisphere (labels 1–74) followed by right hemisphere (labels 75–148). +# Source: https://github.com/norrisjamie23/Localising_SOZ_from_SPES/destrieux.rda (MIT licence) +_DESTRIEUX_LOBES = ( + "Limbic", "Limbic", "Limbic", "Frontal", "Occipital", "Parietal", "Parietal", "Frontal", + "Limbic", "Limbic", "Occipital", "Frontal", "Frontal", "Frontal", "Frontal", "Frontal", + "Insula", "Insula", "Occipital", "Occipital", "Occipital", "Occipital", "Temporal", "Frontal", + "Parietal", "Parietal", "Parietal", "Parietal", "Frontal", "Parietal", "Frontal", "Limbic", + "Temporal", "Temporal", "Temporal", "Temporal", "Temporal", "Temporal", "Frontal", "Frontal", + "Insula", "Occipital", "Temporal", "Occipital", "Frontal", "Limbic", "Insula", "Insula", + "Insula", "Temporal", "Occipital", "Frontal", "Frontal", "Frontal", "Parietal", "Parietal", + "Occipital", "Occipital", "Occipital", "Occipital", "Occipital", "Frontal", "Frontal", "Frontal", + "Parietal", "Limbic", "Parietal", "Frontal", "Frontal", "Frontal", "Parietal", "Temporal", + "Temporal", "Temporal", + "Limbic", "Limbic", "Limbic", "Frontal", "Occipital", "Parietal", "Parietal", "Frontal", + "Limbic", "Limbic", "Occipital", "Frontal", "Frontal", "Frontal", "Frontal", "Frontal", + "Insula", "Insula", "Occipital", "Occipital", "Occipital", "Occipital", "Temporal", "Frontal", + "Parietal", "Parietal", "Parietal", "Parietal", "Frontal", "Parietal", "Frontal", "Limbic", + "Temporal", "Temporal", "Temporal", "Temporal", "Temporal", "Temporal", "Frontal", "Frontal", + "Insula", "Occipital", "Temporal", "Occipital", "Frontal", "Limbic", "Insula", "Insula", + "Insula", "Temporal", "Occipital", "Frontal", "Frontal", "Frontal", "Parietal", "Parietal", + "Occipital", "Occipital", "Occipital", "Occipital", "Occipital", "Frontal", "Frontal", "Frontal", + "Parietal", "Limbic", "Parietal", "Frontal", "Frontal", "Frontal", "Parietal", "Temporal", + "Temporal", "Temporal", +) + + +def get_destrieux_lobe(label): + """Return the 
class StimulationDataProcessor:
    """Extracts per-stimulation-site EEG epochs and computes mean/std responses.

    Args:
        tmin: Start of the epoch window in seconds relative to the stimulus.
        tmax: End of the epoch window in seconds relative to the stimulus.
    """

    def __init__(self, tmin, tmax):
        """
        Args:
            tmin: Start of the epoch window in seconds relative to the stimulus.
            tmax: End of the epoch window in seconds relative to the stimulus.
        """
        self.tmin = tmin
        self.tmax = tmax

    def process_run_data(self, eeg, events_df, channels_df, subject):
        """Process and extract stimulation response data for a single EEG run.

        Filters for valid electrical-stimulation events (excluding artifacts and
        seizures), epochs the data around each stimulation site, and returns a
        DataFrame with mean and std responses across trials.

        Args:
            eeg: MNE Raw object for the run.
            events_df: DataFrame of BIDS events (``*_events.tsv``).
            channels_df: DataFrame of channel metadata (``*_channels.tsv``).
            subject: Subject identifier string used in warning messages.

        Returns:
            DataFrame with columns ``[subject, recording, stim_1, stim_2,
            metric, <timestep columns>]``, or ``None`` if no valid stimulation
            events were found.
        """
        # Keep only channels whose metadata marks them as "included"
        # (selection is by positional index into the channels table).
        chans_to_use = channels_df[channels_df.status_description == "included"].index.tolist()
        eeg.pick(chans_to_use)

        # Band-pass filter the EEG data (1-150 Hz FIR)
        eeg.filter(1, 150, n_jobs=1, method='fir', fir_design='firwin')

        # Filter for electrical stimulation events
        stim_events = events_df[events_df.trial_type == "electrical_stimulation"].copy()

        # Drop events whose onset sample lies past the end of the recording
        before = stim_events.shape[0]
        stim_events = stim_events[stim_events['sample_start'] < eeg.n_times]
        after = stim_events.shape[0]

        if before != after:
            logger.warning("Subject %s: dropped %d out-of-bounds stim events (%d -> %d)", subject, before - after, before, after)

        # Artifact events affecting all channels
        artifacts = events_df[np.logical_and(events_df.trial_type == "artifact",
                                             events_df.electrodes_involved_onset == "all")].copy()

        # Seizure events
        seizures = events_df[events_df.trial_type == "seizure"].copy()

        # Mark stimulation events that overlap an all-channel artifact
        artifact_mask = []
        for _, stim_event in stim_events.iterrows():
            overlap = any(is_overlap(stim_event, artifact) for _, artifact in artifacts.iterrows())
            artifact_mask.append(overlap)

        # Mark stimulation events that overlap a seizure
        seizure_mask = []
        for _, stim_event in stim_events.iterrows():
            overlap = any(is_overlap(stim_event, seizure) for _, seizure in seizures.iterrows())
            seizure_mask.append(overlap)

        # Keep only events free of both artifacts and seizures
        valid_mask = np.logical_not(np.logical_or(artifact_mask, seizure_mask))

        # Filter for valid events
        stim_events = stim_events[valid_mask]

        # Artifact events occurring on only some channels (handled per-channel below)
        focal_artifacts = events_df[np.logical_and(events_df.trial_type == "artifact",
                                                   events_df.electrodes_involved_onset != "all")].copy()

        # Sort the 'electrical_stimulation_site' column - this ensures that E2-E1 and E1-E2 are treated the same
        stim_events['electrical_stimulation_site'] = stim_events['electrical_stimulation_site'].apply(process_stimulation_sites)

        # Step 1: Convert 'electrical_stimulation_site' to a categorical datatype
        stim_events['electrical_stimulation_site'] = stim_events['electrical_stimulation_site'].astype('category')

        # Get mapping of categories to integer codes for 'electrical_stimulation_site'
        category_mapping = dict(enumerate(stim_events['electrical_stimulation_site'].cat.categories))

        # Create a new column 'electrical_stimulation_site_cat' with integer codes for 'electrical_stimulation_site'
        stim_events['electrical_stimulation_site_cat'] = stim_events['electrical_stimulation_site'].cat.codes

        # Step 2: Add a column 'zero_column' to the DataFrame with all values set to 0
        # (MNE's event array format expects a middle column of zeros)
        stim_events['zero_column'] = 0

        # Step 3: Extract the columns 'sample_start', 'zero_column', and 'electrical_stimulation_site_cat' and convert to a NumPy array
        result_array = stim_events[['sample_start', 'zero_column', 'electrical_stimulation_site_cat']].values

        # For each stim event record which channels (if any) a focal artifact
        # implicates, so they can be excluded for that site's epochs.
        overlap_list = []

        for _, stim_event in stim_events.iterrows():
            overlap_found = False
            for _, artifact in focal_artifacts.iterrows():
                if is_overlap(stim_event, artifact):
                    overlap_list.append(artifact.electrodes_involved_onset)
                    overlap_found = True
                    break  # Stop checking after the first overlap is found
            if not overlap_found:
                overlap_list.append(None)
        overlap_list = np.array(overlap_list)

        response_dfs = []

        for event_id in np.unique(result_array[:, 2]):

            # Channels to remove for this site due to focal artifacts
            # (entries are comma-separated channel-name strings).
            remove_chans = overlap_list[result_array[:, 2] == event_id]
            remove_chans = np.unique([chan for chan_list in remove_chans if chan_list is not None for chan in chan_list.split(',')])

            response_df = self._extract_epochs(eeg,
                                               result_array,
                                               event_id,
                                               category_mapping,
                                               subject,
                                               remove_chans)
            if response_df is not None:
                response_dfs.append(response_df)

        try:
            response_dfs = pd.concat(response_dfs)
        except ValueError:
            # pd.concat raises ValueError on an empty list
            logger.warning("No stimulation events found for subject %s", subject)
            return None

        return response_dfs

    def _extract_epochs(self, eeg, result_array, event_id, category_mapping, subject, remove_chans):
        """Extract epochs for one stimulation site and return mean/std DataFrame.

        Args:
            eeg: MNE Raw object for the run.
            result_array: Array of shape ``(n_events, 3)`` with columns
                ``[sample_start, zero, event_id]``.
            event_id: Integer code identifying the stimulation site to epoch.
            category_mapping: Dict mapping event_id codes to site strings.
            subject: Subject identifier string inserted into the output DataFrame.
            remove_chans: Array of channel names to exclude (focal artifacts).

        Returns:
            DataFrame with mean and std rows for this stimulation site, or
            ``None`` if fewer than 5 artifact-free trials remain.
        """
        stimulated_electrodes = category_mapping[event_id].split('-')

        # When polarity is reversed, ensure no duplication
        stimulated_electrodes.sort()

        # Get list of electrodes excluding stimulated electrodes
        recording_channels = [chan for chan in eeg.info['ch_names'] if chan not in stimulated_electrodes]

        # Remove channels that are artifacted
        recording_channels = [chan for chan in recording_channels if chan not in remove_chans]

        try:
            # Epoch from one second before tmin so the pre-stimulus interval
            # can serve as the baseline (None .. -0.1 s), then crop it away.
            epochs = mne.Epochs(eeg, result_array, event_id=event_id, tmin=self.tmin - 1
                                , tmax=self.tmax, picks=recording_channels, preload=True, baseline=(None, -0.1))
            if len(epochs) < 5:
                return None
            epochs.crop(tmin=self.tmin)
            epochs.resample(512)
        except RuntimeError as e:
            # MNE raises RuntimeError when all epochs are dropped / none fit;
            # treat that as "no usable trials" and re-raise anything else.
            if "empty" in str(e).lower() or "epochs were dropped" in str(e).lower():
                return None
            raise

        # If less than 5 trials (after drops), return None
        if (epochs._data.shape[0]) < 5:
            return None

        # Calculate mean and standard deviation across trials
        mean_response = epochs.average()._data  # Shape: (channels, time steps)
        std_response = epochs._data.std(axis=0)  # Shape: (channels, time steps)

        # Create DataFrame for mean response
        df_mean_response = pd.DataFrame(mean_response).astype('float32')
        df_mean_response.insert(0, 'subject', [subject] * mean_response.shape[0])
        df_mean_response.insert(1, 'recording', epochs.info['ch_names'])
        df_mean_response.insert(2, 'stim_1', stimulated_electrodes[0])
        df_mean_response.insert(3, 'stim_2', stimulated_electrodes[1])
        df_mean_response.insert(4, 'metric', 'mean')

        # Create DataFrame for std response
        df_std_response = pd.DataFrame(std_response).astype('float32')
        df_std_response.insert(0, 'subject', [subject] * std_response.shape[0])
        df_std_response.insert(1, 'recording', epochs.info['ch_names'])
        df_std_response.insert(2, 'stim_1', stimulated_electrodes[0])
        df_std_response.insert(3, 'stim_2', stimulated_electrodes[1])
        df_std_response.insert(4, 'metric', 'std')

        # Concatenate the two DataFrames
        response_df = pd.concat([df_mean_response, df_std_response], ignore_index=True)

        return response_df
+ """ + mean_output = self.process_metric_for_analysis( + subject, electrodes_df, "mean", labels=True + ) + std_output = self.process_metric_for_analysis( + subject, electrodes_df, "std", labels=True + ) + if mean_output is None or std_output is None: + return None + + ( + mean_channels, + mean_lobes, + mean_y, + mean_coords, + mean_X_stim, + mean_X_recording, + ) = mean_output + ( + std_channels, + _std_lobes, + _std_y, + _std_coords, + std_X_stim, + std_X_recording, + ) = std_output + + std_channel_to_index = { + channel: index for index, channel in enumerate(std_channels) + } + common_channels = [ + channel for channel in mean_channels if channel in std_channel_to_index + ] + if len(common_channels) == 0: + logger.warning("No paired mean/std stimulation events found for subject %s", subject) + return None + + X_stim = [] + X_recording = [] + electrode_lobes = [] + electrode_coords = [] + y = [] + for mean_index, channel in enumerate(mean_channels): + if channel not in std_channel_to_index: + continue + std_index = std_channel_to_index[channel] + X_stim.append(np.stack([mean_X_stim[mean_index], std_X_stim[std_index]])) + X_recording.append( + np.stack([mean_X_recording[mean_index], std_X_recording[std_index]]) + ) + electrode_lobes.append(mean_lobes[mean_index]) + electrode_coords.append(mean_coords[mean_index]) + y.append(mean_y[mean_index]) + + return ( + common_channels, + electrode_lobes, + np.array(y, dtype=np.int32), + electrode_coords, + np.stack(X_stim).astype(np.float32), + np.stack(X_recording).astype(np.float32), + ) + + def process_metric_for_analysis(self, subject, electrodes_df, metric, labels=False): + """Build stimulus/recording tensors for one response metric. + + Computes Euclidean distances between stimulation and recording + electrode pairs, filters pairs closer than 13 mm, sorts by distance, + and stacks per-electrode arrays into padded tensors. + + Args: + subject: Subject identifier string. 
+ electrodes_df: DataFrame indexed by electrode name with columns + ``x``, ``y``, ``z``, ``soz``, and ``Destrieux_label``. + metric: Either ``"mean"`` or ``"std"``. + labels: If ``True``, extract SOZ labels from ``electrodes_df``. + Default is ``False``. + + Returns: + A 6-tuple ``(channels, electrode_lobes, y, electrode_coords, + X_stim, X_recording)`` where ``X_stim`` and ``X_recording`` are + float32 arrays of shape ``(n_electrodes, n_trials, n_timesteps)``, + or ``None`` if no valid stimulation events were found or all SOZ + labels are negative. + """ + response_df = self.response_df[self.response_df.subject == subject] + response_df = response_df[response_df.metric == metric] + + try: + # Calculate stimulation and recording coordinates + stim_1_coords = np.array([ + [electrodes_df[electrodes_df.index == stimulated_electrode].x, + electrodes_df[electrodes_df.index == stimulated_electrode].y, + electrodes_df[electrodes_df.index == stimulated_electrode].z + ] for stimulated_electrode in response_df.stim_1]) + stim_2_coords = np.array([ + [electrodes_df[electrodes_df.index == stimulated_electrode].x, + electrodes_df[electrodes_df.index == stimulated_electrode].y, + electrodes_df[electrodes_df.index == stimulated_electrode].z + ] for stimulated_electrode in response_df.stim_2]) + stim_coords = (stim_1_coords + stim_2_coords) / 2 + + recording_coords = np.array([ + [electrodes_df[electrodes_df.index == stimulated_electrode].x, + electrodes_df[electrodes_df.index == stimulated_electrode].y, + electrodes_df[electrodes_df.index == stimulated_electrode].z + ] for stimulated_electrode in response_df.recording]) + + # Calculate Euclidean distance + distances = np.sqrt(np.sum((stim_coords - recording_coords) ** 2, axis=1))[:, 0] + + # Keep only distances greater than 13mm + response_df = response_df[distances > 13] + distances = distances[distances > 13] + + # Add distances to DataFrame + response_df['distances'] = distances + + # Sort by distances - used for CNN method + 
response_df = response_df.sort_values(by='distances', ascending=True) + + except Exception as e: + logger.warning("Error processing subject %s: %s", subject, e) + return None + + # Get channels used for both stimulation and recording + recording_stim_channels = sorted( + set(response_df.recording.unique()).intersection( + set(response_df.stim_2.unique()).union(set(response_df.stim_1.unique())) + ) + ) + + channels_recording_trials, channels_stim_trials = [], [] + + if labels: + channel_soz = [] + + if len(recording_stim_channels) == 0: + logger.warning("No stimulation events found for subject %s", subject) + return None + + electrode_coords = [] + electrode_lobes = [] + + for channel in recording_stim_channels: + + # Current channel responses when other channels were stimulated + channel_recording_trials = _with_distance_first( + response_df[response_df.recording == channel].select_dtypes(include='number').copy() + ) + + # Other channel responses when current channel was stimulated + channel_stim_trials = _with_distance_first( + response_df[np.logical_or(response_df.stim_1 == channel, + response_df.stim_2 == channel)].select_dtypes(include='number').copy() + ) + + # Add to corresponding lists, except for recording/stim channel names (i.e., only time series) + channels_recording_trials.append(np.array(channel_recording_trials)) + channels_stim_trials.append(np.array(channel_stim_trials)) + + if labels: + # Add label for current channel + channel_soz.append(electrodes_df[electrodes_df.index == channel].soz.iloc[0] == "yes") + + electrode_coords.append([electrodes_df[electrodes_df.index == channel].x.iloc[0], + electrodes_df[electrodes_df.index == channel].y.iloc[0], + electrodes_df[electrodes_df.index == channel].z.iloc[0]]) + + electrode_lobe = electrodes_df[electrodes_df.index == channel].Destrieux_label.values[0] + electrode_lobe = get_destrieux_lobe(electrode_lobe) + electrode_lobe = region_to_index[electrode_lobe] + electrode_lobes.append(electrode_lobe) + + # 
For recording channels + max_recording_rows = max(array.shape[0] for array in channels_recording_trials) + X_recording = pad_and_stack(channels_recording_trials, max_recording_rows).astype(np.float32) + + # For stim channels + max_stim_rows = max(array.shape[0] for array in channels_stim_trials) + X_stim = pad_and_stack(channels_stim_trials, max_stim_rows).astype(np.float32) + + if labels: + y = np.array(channel_soz, dtype=np.int32) + if y.sum() == 0: + return None + + return recording_stim_channels, electrode_lobes, y, electrode_coords, X_stim, X_recording + + +class LocalizeSOZ(BaseTask): + """Electrode-level seizure onset zone (SOZ) localization from CCEP ECoG. + + Orchestrates two preprocessing stages and produces one ML sample per + candidate electrode per recording run: + + **Stage 1 — :class:`StimulationDataProcessor`** + Reads a raw BrainVision EEG file and its associated BIDS metadata. + The signal is bandpass-filtered (1–150 Hz), then segmented into + epochs locked to each electrical stimulation event. Events + contaminated by whole-recording artifacts or seizures are excluded. + For each unique stimulation site (bipolar pair), epochs are averaged + and their standard deviation is computed across trials, yielding a + DataFrame of mean and std CCEP responses with one row per + recording-channel per stimulation site. + + **Stage 2 — :class:`DatasetCreator`** + Converts the response DataFrame into analysis-ready arrays. For each + electrode that appears both as a stimulation site and as a recording + channel, two tensors are built: + + - ``X_stim``: responses of *other* channels when *this* electrode was + stimulated (divergent / outgoing connectivity). + - ``X_recording``: responses of *this* channel when *other* electrodes + were stimulated (convergent / incoming connectivity). + + Stimulation–recording pairs closer than 13 mm are discarded to + reduce stimulation artifact contamination. 
Remaining trials are sorted + by Euclidean distance and zero-padded so all electrodes within a + batch share the same tensor shape. Lobe labels are looked up from the + Destrieux atlas via :func:`get_destrieux_lobe`. + + Each returned sample corresponds to one candidate electrode. Multiple + samples from the same patient share the same ``patient_id``; downstream + train/test splits must group by ``patient_id`` to avoid leakage across + electrodes from the same patient. + + Attributes: + task_name: ``"LocalizeSOZ"`` + input_schema: Tensor inputs — ``X_stim`` and ``X_recording`` + (mean/std CCEP responses), ``electrode_lobes`` (Destrieux lobe + index), and ``electrode_coords`` (MNI xyz coordinates). + output_schema: Binary SOZ label (``1`` = in SOZ, ``0`` = not in SOZ). + + Examples: + >>> from pyhealth.datasets import CCEPECoGDataset + >>> from pyhealth.tasks import LocalizeSOZ + >>> dataset = CCEPECoGDataset(root="/path/to/ds004080") + >>> task = LocalizeSOZ() + >>> samples = dataset.set_task(task) + """ + + task_name: str = "LocalizeSOZ" + input_schema: Dict[str, str] = { + "X_stim": "tensor", + "X_recording": "tensor", + "electrode_lobes": "tensor", + "electrode_coords": "tensor", + } + output_schema: Dict[str, str] = {"soz": "binary"} + + def __call__(self, patient) -> List[Dict[str, Any]]: + samples: List[Dict[str, Any]] = [] + + for split in ("ecog",): + try: + events = patient.get_events(split) + except (AttributeError, KeyError): + continue + + for event in events: + pid = patient.patient_id + try: + header_file = event.header_file + events_file = event.events_file + channels_file = event.channels_file + electrodes_file = event.electrodes_file + except AttributeError: + continue + + if not all( + [pid, header_file, events_file, channels_file, electrodes_file] + ): + continue + + try: + eeg = mne.io.read_raw_brainvision( + header_file, + verbose=False, + preload=True, + ) + events_df = pd.read_csv(events_file, sep="\t", index_col=0) + channels_df = 
pd.read_csv(channels_file, sep="\t", index_col=0) + electrodes_df = pd.read_csv(electrodes_file, sep="\t", index_col=0) + + stim_processor = StimulationDataProcessor(tmin=0.009, tmax=1) + response_df = stim_processor.process_run_data( + eeg, + events_df, + channels_df, + pid, + ) + if response_df is None: + continue + + dataset_creator = DatasetCreator(response_df) + processed = dataset_creator.process_for_analysis(pid, electrodes_df) + if processed is None: + continue + + ( + electrode_channels, + electrode_lobes, + y, + electrode_coords, + X_stim, + X_recording, + ) = processed + except (ValueError, KeyError, IndexError, FileNotFoundError, OSError): + continue + + for electrode_idx, channel in enumerate(electrode_channels): + electrode_id = f"{pid}-{event.session_id}-{event.run_id}-{channel}" + samples.append( + { + "patient_id": pid, + "visit_id": electrode_id, + "record_id": electrode_id, + "session_id": event.session_id, + "task_id": event.task_id, + "run_id": event.run_id, + "channel": channel, + "electrode_index": electrode_idx, + "header_file": header_file, + "events_file": events_file, + "channels_file": channels_file, + "electrodes_file": electrodes_file, + "soz": int(y[electrode_idx]), + "X_stim": X_stim[electrode_idx], + "X_recording": X_recording[electrode_idx], + "electrode_lobes": np.array( + [electrode_lobes[electrode_idx]], dtype=np.int64 + ), + "electrode_coords": np.array( + electrode_coords[electrode_idx], dtype=np.float32 + ), + } + ) + + return samples diff --git a/pyproject.toml b/pyproject.toml index 98f88d47b..31b55998e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ "scikit-learn~=1.7.0", "networkx", "mne~=1.10.0", + "mne-bids>=0.18.0", "urllib3~=2.5.0", "numpy~=2.2.0", "tqdm", diff --git a/tests/core/test_ccep_ecog.py b/tests/core/test_ccep_ecog.py new file mode 100644 index 000000000..60b873f91 --- /dev/null +++ b/tests/core/test_ccep_ecog.py @@ -0,0 +1,344 @@ +from pathlib import Path +import shutil 
+import tempfile +from typing import List +import unittest +from unittest.mock import MagicMock, patch + +import numpy as np +import pandas as pd + +from pyhealth.datasets import CCEPECoGDataset +from pyhealth.tasks.localize_soz import LocalizeSOZ + + +class _DummyEvent: + """Minimal stand-in for a BIDS ECoG run event.""" + + def __init__( + self, + header_file="dummy.vhdr", + events_file="dummy_events.tsv", + channels_file="dummy_channels.tsv", + electrodes_file="dummy_electrodes.tsv", + session_id="1", + task_id="ecog", + run_id="1", + ): + self.header_file = header_file + self.events_file = events_file + self.channels_file = channels_file + self.electrodes_file = electrodes_file + self.session_id = session_id + self.task_id = task_id + self.run_id = run_id + + +class _DummyPatient: + """Minimal stand-in for a PyHealth Patient.""" + + def __init__(self, patient_id: str, events: List[_DummyEvent]): + self.patient_id = patient_id + self._events = events + + def get_events(self, event_type=None) -> List[_DummyEvent]: + # Return events only for the "ecog" split; treat all others as empty + # so each event is processed exactly once. 
+ if event_type in ("train", "eval"): + return [] + return self._events + + +class TestCCEPECoGDataset(unittest.TestCase): + """Tests for CCEPECoGDataset indexing and validation.""" + + def setUp(self): + """Generate a minimal synthetic BIDS directory.""" + self.temp_dir = tempfile.mkdtemp() + root = Path(self.temp_dir) + + for i in range(1, 3): + sub = f"{i:02d}" + session = "1" + task = "ecog" + run = "1" + + patient_dir = root / f"sub-{sub}" / f"ses-{session}" / "ieeg" + patient_dir.mkdir(parents=True, exist_ok=True) + prefix = f"sub-{sub}_ses-{session}_task-{task}_run-{run}" + + (patient_dir / f"{prefix}_ieeg.vhdr").write_text("Dummy VHDR file content") + + soz_vals = ["yes", "no"] if i == 1 else ["no", "no"] + pd.DataFrame({ + "name": ["PT01", "PT02"], + "x": [-45.2, -47.7], + "y": [-81.2, -80.2], + "z": [-1.6, 2.8], + "size": [4.2, 4.2], + "material": ["Platinum", "Platinum"], + "manufacturer": ["AdTech", "AdTech"], + "group": ["grid", "grid"], + "hemisphere": ["L", "L"], + "silicon": ["no", "no"], + "soz": soz_vals, + "resected": ["no", "no"], + "edge": ["no", "no"], + }).to_csv(patient_dir / f"sub-{sub}_ses-{session}_electrodes.tsv", sep="\t", index=False) + + pd.DataFrame({ + "name": ["PT01", "PT02"], + "type": ["ECOG", "ECOG"], + "units": ["µV", "µV"], + "low_cutoff": [232, 232], + "high_cutoff": [0.15, 0.15], + "reference": ["G2", "G2"], + "group": ["grid", "grid"], + "sampling_frequency": [512, 512], + "notch": ["n/a", "n/a"], + "status": ["good", "good"], + "status_description": ["included", "included"], + }).to_csv(patient_dir / f"{prefix}_channels.tsv", sep="\t", index=False) + + pd.DataFrame({ + "onset": [507.00585], + "duration": [0.23242], + "trial_type": ["artefact"], + "sub_type": ["n/a"], + "electrodes_involved_onset": ["all"], + "electrodes_involved_offset": ["all"], + "offset": [507.23828], + "sample_start": [259587], + "sample_end": [259706], + "electrical_stimulation_type": ["n/a"], + "electrical_stimulation_site": ["n/a"], + 
"electrical_stimulation_current": ["n/a"], + "electrical_stimulation_frequency": ["n/a"], + "electrical_stimulation_pulsewidth": ["n/a"], + "notes": ["n/a"], + }).to_csv(patient_dir / f"{prefix}_events.tsv", sep="\t", index=False) + + def tearDown(self): + shutil.rmtree(self.temp_dir) + + def test_index_data_positive_soz(self): + """CCEPECoGDataset correctly identifies positive SOZ cases.""" + dataset = CCEPECoGDataset(root=self.temp_dir, dev=True) + p1_events = dataset.get_patient("01").get_events() + self.assertEqual(len(p1_events), 1) + e1 = p1_events[0] + self.assertEqual(e1["session_id"], "1") + self.assertEqual(e1["task_id"], "ecog") + self.assertEqual(e1["run_id"], "1") + self.assertEqual(str(e1["has_soz"]), "True") + self.assertIn("sub-01_ses-1_task-ecog_run-1_ieeg.vhdr", e1["header_file"]) + self.assertIn("sub-01_ses-1_electrodes.tsv", e1["electrodes_file"]) + self.assertIn("sub-01_ses-1_task-ecog_run-1_channels.tsv", e1["channels_file"]) + self.assertIn("sub-01_ses-1_task-ecog_run-1_events.tsv", e1["events_file"]) + + def test_index_data_negative_soz(self): + """CCEPECoGDataset correctly identifies negative SOZ cases.""" + dataset = CCEPECoGDataset(root=self.temp_dir, dev=True) + p2_events = dataset.get_patient("02").get_events() + self.assertEqual(len(p2_events), 1) + e2 = p2_events[0] + self.assertEqual(e2["session_id"], "1") + self.assertEqual(e2["task_id"], "ecog") + self.assertEqual(e2["run_id"], "1") + self.assertEqual(str(e2["has_soz"]), "False") + self.assertIn("sub-02_ses-1_task-ecog_run-1_ieeg.vhdr", e2["header_file"]) + self.assertIn("sub-02_ses-1_electrodes.tsv", e2["electrodes_file"]) + self.assertIn("sub-02_ses-1_task-ecog_run-1_channels.tsv", e2["channels_file"]) + self.assertIn("sub-02_ses-1_task-ecog_run-1_events.tsv", e2["events_file"]) + + def test_default_task(self): + """CCEPECoGDataset.default_task returns a LocalizeSOZ instance.""" + dataset = CCEPECoGDataset(root=self.temp_dir, dev=True) + 
self.assertIsInstance(dataset.default_task, LocalizeSOZ) + + def test_verify_data_no_root(self): + """_verify_data raises FileNotFoundError for a non-existent root.""" + with self.assertRaises(FileNotFoundError): + CCEPECoGDataset(root="/tmp/non_existent_bids_root") + + def test_verify_data_no_subjects(self): + """_verify_data raises ValueError for an empty BIDS root.""" + with tempfile.TemporaryDirectory() as bad_dir: + with self.assertRaisesRegex(ValueError, "contains no 'sub-\\*' subject folders"): + CCEPECoGDataset(root=bad_dir) + + def test_verify_data_no_vhdr(self): + """_verify_data raises ValueError when .vhdr recordings are missing.""" + with tempfile.TemporaryDirectory() as bad_dir: + (Path(bad_dir) / "sub-01" / "ses-1" / "ieeg").mkdir(parents=True) + with self.assertRaisesRegex(ValueError, "contains no '.vhdr' files"): + CCEPECoGDataset(root=bad_dir) + + def test_verify_data_no_electrodes(self): + """_verify_data raises ValueError when electrodes.tsv is missing.""" + with tempfile.TemporaryDirectory() as temp_dir: + patient_dir = Path(temp_dir) / "sub-01" / "ses-1" / "ieeg" + patient_dir.mkdir(parents=True, exist_ok=True) + (patient_dir / "sub-01_ses-1_task-ecog_run-1_ieeg.vhdr").write_text("dummy") + with self.assertRaisesRegex(ValueError, "contains no 'electrodes.tsv' file"): + CCEPECoGDataset(root=temp_dir) + + def test_verify_data_no_channels(self): + """_verify_data raises ValueError when channels.tsv is missing.""" + with tempfile.TemporaryDirectory() as temp_dir: + patient_dir = Path(temp_dir) / "sub-01" / "ses-1" / "ieeg" + patient_dir.mkdir(parents=True, exist_ok=True) + (patient_dir / "sub-01_ses-1_task-ecog_run-1_ieeg.vhdr").write_text("dummy") + (patient_dir / "sub-01_ses-1_electrodes.tsv").write_text("dummy") + with self.assertRaisesRegex(ValueError, "contains no 'channels.tsv' files"): + CCEPECoGDataset(root=temp_dir) + + def test_verify_data_no_events(self): + """_verify_data raises ValueError when events.tsv is missing.""" + with 
tempfile.TemporaryDirectory() as temp_dir: + patient_dir = Path(temp_dir) / "sub-01" / "ses-1" / "ieeg" + patient_dir.mkdir(parents=True, exist_ok=True) + (patient_dir / "sub-01_ses-1_task-ecog_run-1_ieeg.vhdr").write_text("dummy") + (patient_dir / "sub-01_ses-1_electrodes.tsv").write_text("dummy") + (patient_dir / "sub-01_ses-1_task-ecog_run-1_channels.tsv").write_text("dummy") + with self.assertRaisesRegex(ValueError, "contains no 'events.tsv' files"): + CCEPECoGDataset(root=temp_dir) + + +class TestLocalizeSOZ(unittest.TestCase): + """Tests for the LocalizeSOZ task class.""" + + _EXPECTED_SAMPLE_KEYS = { + "patient_id", "visit_id", "record_id", "session_id", "task_id", + "run_id", "channel", "electrode_index", "header_file", "events_file", + "channels_file", "electrodes_file", "soz", "X_stim", "X_recording", + "electrode_lobes", "electrode_coords", + } + + @staticmethod + def _make_processed(n: int = 2): + """Return a fake process_for_analysis result with *n* electrodes.""" + channels = [f"CH{i:02d}" for i in range(n)] + lobes = list(range(n)) + y = np.zeros(n, dtype=np.int32) + if n > 0: + y[0] = 1 + coords = [[-45.2, -81.2, -1.6]] * n + X_stim = np.zeros((n, 2, 5, 10), dtype=np.float32) + X_recording = np.zeros((n, 2, 8, 10), dtype=np.float32) + return channels, lobes, y, coords, X_stim, X_recording + + def test_task_name(self): + self.assertEqual(LocalizeSOZ().task_name, "LocalizeSOZ") + + def test_input_schema_keys_and_types(self): + schema = LocalizeSOZ().input_schema + for key in ("X_stim", "X_recording", "electrode_lobes", "electrode_coords"): + with self.subTest(key=key): + self.assertIn(key, schema) + self.assertEqual(schema[key], "tensor") + + def test_output_schema(self): + schema = LocalizeSOZ().output_schema + self.assertIn("soz", schema) + self.assertEqual(schema["soz"], "binary") + + def test_call_empty_patient_returns_empty_list(self): + """Patient with no events returns an empty sample list.""" + samples = LocalizeSOZ()(_DummyPatient("sub-01", 
[])) + self.assertEqual(samples, []) + + @patch("pyhealth.tasks.localize_soz.DatasetCreator.process_for_analysis") + @patch("pyhealth.tasks.localize_soz.StimulationDataProcessor.process_run_data") + @patch("pyhealth.tasks.localize_soz.pd.read_csv") + @patch("pyhealth.tasks.localize_soz.mne.io.read_raw_brainvision") + def test_call_sample_count_matches_electrode_count(self, mock_eeg, mock_csv, mock_proc, mock_analysis): + """One sample is emitted per electrode.""" + mock_eeg.return_value = MagicMock() + mock_csv.return_value = pd.DataFrame() + mock_proc.return_value = pd.DataFrame() + mock_analysis.return_value = self._make_processed(n=3) + + samples = LocalizeSOZ()(_DummyPatient("sub-01", [_DummyEvent()])) + self.assertEqual(len(samples), 3) + + @patch("pyhealth.tasks.localize_soz.DatasetCreator.process_for_analysis") + @patch("pyhealth.tasks.localize_soz.StimulationDataProcessor.process_run_data") + @patch("pyhealth.tasks.localize_soz.pd.read_csv") + @patch("pyhealth.tasks.localize_soz.mne.io.read_raw_brainvision") + def test_call_sample_has_all_expected_keys(self, mock_eeg, mock_csv, mock_proc, mock_analysis): + """Every sample dict contains exactly the expected keys.""" + mock_eeg.return_value = MagicMock() + mock_csv.return_value = pd.DataFrame() + mock_proc.return_value = pd.DataFrame() + mock_analysis.return_value = self._make_processed(n=2) + + samples = LocalizeSOZ()(_DummyPatient("sub-01", [_DummyEvent()])) + self.assertEqual(set(samples[0].keys()), self._EXPECTED_SAMPLE_KEYS) + + @patch("pyhealth.tasks.localize_soz.DatasetCreator.process_for_analysis") + @patch("pyhealth.tasks.localize_soz.StimulationDataProcessor.process_run_data") + @patch("pyhealth.tasks.localize_soz.pd.read_csv") + @patch("pyhealth.tasks.localize_soz.mne.io.read_raw_brainvision") + def test_call_soz_labels_match_y_array(self, mock_eeg, mock_csv, mock_proc, mock_analysis): + """soz labels in samples exactly match the y array from processing.""" + mock_eeg.return_value = MagicMock() + 
mock_csv.return_value = pd.DataFrame() + mock_proc.return_value = pd.DataFrame() + channels, lobes, _, coords, X_stim, X_recording = self._make_processed(n=2) + y = np.array([1, 0], dtype=np.int32) + mock_analysis.return_value = (channels, lobes, y, coords, X_stim, X_recording) + + samples = LocalizeSOZ()(_DummyPatient("sub-01", [_DummyEvent()])) + self.assertEqual(samples[0]["soz"], 1) + self.assertEqual(samples[1]["soz"], 0) + + @patch("pyhealth.tasks.localize_soz.DatasetCreator.process_for_analysis") + @patch("pyhealth.tasks.localize_soz.StimulationDataProcessor.process_run_data") + @patch("pyhealth.tasks.localize_soz.pd.read_csv") + @patch("pyhealth.tasks.localize_soz.mne.io.read_raw_brainvision") + def test_call_visit_id_format(self, mock_eeg, mock_csv, mock_proc, mock_analysis): + """visit_id follows the {pid}-{session}-{run}-{channel} format.""" + mock_eeg.return_value = MagicMock() + mock_csv.return_value = pd.DataFrame() + mock_proc.return_value = pd.DataFrame() + channels, lobes, y, coords, X_stim, X_recording = self._make_processed(n=1) + channels = ["PT01"] + mock_analysis.return_value = (channels, lobes, y, coords, X_stim, X_recording) + + samples = LocalizeSOZ()(_DummyPatient("sub-01", [_DummyEvent(session_id="2", run_id="3")])) + self.assertEqual(samples[0]["visit_id"], "sub-01-2-3-PT01") + + @patch("pyhealth.tasks.localize_soz.DatasetCreator.process_for_analysis") + @patch("pyhealth.tasks.localize_soz.StimulationDataProcessor.process_run_data") + @patch("pyhealth.tasks.localize_soz.pd.read_csv") + @patch("pyhealth.tasks.localize_soz.mne.io.read_raw_brainvision") + def test_call_process_run_data_none_skips_event(self, mock_eeg, mock_csv, mock_proc, mock_analysis): + """When process_run_data returns None the event is skipped entirely.""" + mock_eeg.return_value = MagicMock() + mock_csv.return_value = pd.DataFrame() + mock_proc.return_value = None + + samples = LocalizeSOZ()(_DummyPatient("sub-01", [_DummyEvent()])) + self.assertEqual(samples, []) + 
mock_analysis.assert_not_called() + + @patch("pyhealth.tasks.localize_soz.DatasetCreator.process_for_analysis") + @patch("pyhealth.tasks.localize_soz.StimulationDataProcessor.process_run_data") + @patch("pyhealth.tasks.localize_soz.pd.read_csv") + @patch("pyhealth.tasks.localize_soz.mne.io.read_raw_brainvision") + def test_call_multiple_events_aggregate(self, mock_eeg, mock_csv, mock_proc, mock_analysis): + """Samples accumulate across multiple events for one patient.""" + mock_eeg.return_value = MagicMock() + mock_csv.return_value = pd.DataFrame() + mock_proc.return_value = pd.DataFrame() + mock_analysis.return_value = self._make_processed(n=2) + + events = [_DummyEvent(run_id="1"), _DummyEvent(run_id="2")] + samples = LocalizeSOZ()(_DummyPatient("sub-01", events)) + # 2 electrodes × 2 events = 4 samples + self.assertEqual(len(samples), 4) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/test_spes.py b/tests/core/test_spes.py new file mode 100644 index 000000000..56f99deeb --- /dev/null +++ b/tests/core/test_spes.py @@ -0,0 +1,300 @@ +import unittest + +import numpy as np +import torch + +from pyhealth.datasets import create_sample_dataset, get_dataloader +from pyhealth.models import SPESResNet, SPESTransformer +from pyhealth.models.spes import MultiScaleResNet1D, SPESResponseEncoder + + +_N_STIM_CH = 8 +_N_REC_CH = 10 +_N_TIMESTEPS = 33 + + +def _make_dataset(n_samples: int = 4, seed: int = 7): + """Return a minimal synthetic SPES SampleDataset.""" + rng = np.random.default_rng(seed) + samples = [] + for i in range(n_samples): + x_stim = rng.normal(size=(2, _N_STIM_CH, _N_TIMESTEPS)).astype(np.float32) + x_rec = rng.normal(size=(2, _N_REC_CH, _N_TIMESTEPS)).astype(np.float32) + # Column 0 of mode 0 stores Euclidean distance; 0 means padded/invalid. 
+ x_stim[:, :, 0] = [10, 20, 0, 30, 40, 0, 50, 60] + x_rec[:, :, 0] = [11, 0, 21, 31, 0, 41, 51, 61, 0, 71] + samples.append({ + "patient_id": f"p{i}", + "visit_id": f"v{i}", + "X_stim": x_stim, + "X_recording": x_rec, + "electrode_lobes": np.array([i % 7], dtype=np.int64), + "electrode_coords": rng.normal(size=3).astype(np.float32), + "soz": i % 2, + }) + return create_sample_dataset( + samples=samples, + input_schema={ + "X_stim": "tensor", + "X_recording": "tensor", + "electrode_lobes": "tensor", + "electrode_coords": "tensor", + }, + output_schema={"soz": "binary"}, + dataset_name="test_spes", + ) + + +class TestMultiScaleResNet1D(unittest.TestCase): + """Unit tests for the MultiScaleResNet1D backbone.""" + + def test_output_dim_constant(self): + """output_dim class attribute equals 256 * 3 = 768.""" + self.assertEqual(MultiScaleResNet1D.output_dim, 768) + + def test_output_shape_single_channel(self): + """Single-channel input produces (batch, output_dim) embedding.""" + model = MultiScaleResNet1D(input_channel=1) + with torch.no_grad(): + out = model(torch.randn(2, 1, 128)) + self.assertEqual(out.shape, (2, MultiScaleResNet1D.output_dim)) + + def test_output_shape_two_channels(self): + """Two-channel input produces (batch, output_dim) embedding.""" + model = MultiScaleResNet1D(input_channel=2) + with torch.no_grad(): + out = model(torch.randn(3, 2, 64)) + self.assertEqual(out.shape, (3, MultiScaleResNet1D.output_dim)) + + def test_short_signal_does_not_crash(self): + """Very short signals are handled without error in eval mode.""" + model = MultiScaleResNet1D(input_channel=1) + model.eval() + with torch.no_grad(): + out = model(torch.randn(1, 1, 16)) + self.assertEqual(out.shape[1], MultiScaleResNet1D.output_dim) + + +class TestSPESResponseEncoder(unittest.TestCase): + """Unit tests for the SPESResponseEncoder.""" + + def _enc(self, **kwargs): + defaults = dict( + mean=True, std=True, embedding_dim=16, num_layers=1, + dropout_rate=0.0, max_mlp_timesteps=16, 
expected_timesteps=32, + ) + defaults.update(kwargs) + return SPESResponseEncoder(**defaults) + + def _input(self, batch: int = 2): + x = torch.randn(batch, 2, _N_STIM_CH, _N_TIMESTEPS) + x[:, 0, :, 0] = torch.tensor([10, 20, 0, 30, 40, 0, 50, 60]) + return x + + def test_raises_if_neither_mean_nor_std(self): + """At least one of mean/std must be enabled.""" + with self.assertRaises(ValueError): + SPESResponseEncoder(mean=False, std=False) + + def test_output_shape_mean_and_std(self): + enc = self._enc(mean=True, std=True) + enc.eval() + with torch.no_grad(): + out = enc(self._input()) + self.assertEqual(out.shape, (2, 16)) + + def test_output_shape_mean_only(self): + enc = self._enc(mean=True, std=False) + enc.eval() + with torch.no_grad(): + out = enc(self._input()) + self.assertEqual(out.shape, (2, 16)) + + def test_conv_embedding_false(self): + """MLP-only path (no ResNet) produces correct output shape.""" + enc = self._enc(conv_embedding=False, expected_timesteps=32) + enc.eval() + with torch.no_grad(): + out = enc(self._input()) + self.assertEqual(out.shape, (2, 16)) + + def test_mlp_embedding_false(self): + """ResNet-only path (no MLP prefix) produces correct output shape.""" + enc = self._enc(mlp_embedding=False) + enc.eval() + with torch.no_grad(): + out = enc(self._input()) + self.assertEqual(out.shape, (2, 16)) + + def test_random_channels(self): + """random_channels sub-sampling produces correct output shape.""" + enc = self._enc(random_channels=4) + enc.eval() + with torch.no_grad(): + out = enc(self._input()) + self.assertEqual(out.shape, (2, 16)) + + def test_no_valid_channels_raises(self): + """All-zero distance column (all padding) raises ValueError.""" + enc = self._enc(random_channels=4) + x = torch.randn(1, 2, _N_STIM_CH, _N_TIMESTEPS) + x[:, 0, :, 0] = 0 + with self.assertRaises(ValueError): + enc(x) + + +class TestSPESResNet(unittest.TestCase): + """Tests for the SPESResNet model.""" + + @classmethod + def setUpClass(cls): + cls.dataset = 
_make_dataset() + cls.batch = next(iter(get_dataloader(cls.dataset, batch_size=4, shuffle=False))) + + def _model(self, **kwargs): + defaults = dict(dataset=self.dataset, input_channels=6, noise_std=0.0) + defaults.update(kwargs) + return SPESResNet(**defaults) + + def test_invalid_input_type_raises(self): + with self.assertRaises(ValueError): + SPESResNet(dataset=self.dataset, input_type="invalid") + + def test_divergent_output_keys(self): + """Forward pass returns loss, y_prob, y_true, and logit.""" + model = self._model(input_type="divergent") + model.eval() + with torch.no_grad(): + ret = model(**self.batch) + for key in ("loss", "y_prob", "y_true", "logit"): + with self.subTest(key=key): + self.assertIn(key, ret) + + def test_divergent_output_shapes(self): + model = self._model(input_type="divergent") + model.eval() + with torch.no_grad(): + ret = model(**self.batch) + self.assertEqual(ret["logit"].shape, (4, 1)) + self.assertEqual(ret["y_prob"].shape, (4, 1)) + self.assertEqual(ret["loss"].dim(), 0) + + def test_backward_gradients_flow(self): + model = self._model() + ret = model(**self.batch) + ret["loss"].backward() + self.assertTrue(any( + p.requires_grad and p.grad is not None for p in model.parameters() + )) + + def test_embed_returned_when_requested(self): + """embed=True adds an 'embed' key with shape (batch, output_dim).""" + model = self._model() + model.eval() + with torch.no_grad(): + ret = model(**self.batch, embed=True) + self.assertIn("embed", ret) + self.assertEqual(ret["embed"].shape, (4, MultiScaleResNet1D.output_dim)) + + def test_training_mode_runs(self): + """Training mode (noise injection + channel dropout) does not crash.""" + model = self._model(noise_std=0.1) + model.train() + ret = model(**self.batch) + self.assertEqual(ret["loss"].dim(), 0) + + def test_all_padding_channels_raises(self): + """Batch where all distance values are zero raises ValueError.""" + model = self._model() + bad_batch = {k: v.clone() if isinstance(v, 
torch.Tensor) else v + for k, v in self.batch.items()} + bad_batch["X_stim"][:, 0, :, 0] = 0 + with self.assertRaises(ValueError): + model(**bad_batch) + + +class TestSPESTransformer(unittest.TestCase): + """Tests for the SPESTransformer model.""" + + @classmethod + def setUpClass(cls): + cls.dataset = _make_dataset() + cls.batch = next(iter(get_dataloader(cls.dataset, batch_size=4, shuffle=False))) + + def _model(self, net_configs=None, **kwargs): + defaults = dict( + dataset=self.dataset, + net_configs=net_configs or [{"type": "divergent", "mean": True, "std": True}], + embedding_dim=16, + num_layers=1, + dropout_rate=0.0, + noise_std=0.0, + max_mlp_timesteps=16, + expected_timesteps=32, + ) + defaults.update(kwargs) + return SPESTransformer(**defaults) + + def test_output_keys(self): + """Forward pass returns loss, y_prob, y_true, and logit.""" + model = self._model() + model.eval() + with torch.no_grad(): + ret = model(**self.batch) + for key in ("loss", "y_prob", "y_true", "logit"): + with self.subTest(key=key): + self.assertIn(key, ret) + + def test_output_shapes(self): + model = self._model() + model.eval() + with torch.no_grad(): + ret = model(**self.batch) + self.assertEqual(ret["logit"].shape, (4, 1)) + self.assertEqual(ret["y_prob"].shape, (4, 1)) + self.assertEqual(ret["loss"].dim(), 0) + + def test_backward_gradients_flow(self): + model = self._model() + ret = model(**self.batch) + ret["loss"].backward() + self.assertTrue(any( + p.requires_grad and p.grad is not None for p in model.parameters() + )) + + def test_multiple_net_configs(self): + """Two encoders (convergent + divergent) concatenate without error.""" + model = self._model(net_configs=[ + {"type": "convergent", "mean": True, "std": True}, + {"type": "divergent", "mean": True, "std": False}, + ]) + model.eval() + with torch.no_grad(): + ret = model(**self.batch) + self.assertEqual(ret["logit"].shape, (4, 1)) + + def test_embed_returned_when_requested(self): + """embed=True adds an 'embed' key 
with shape (batch, total_embedding_dim).""" + model = self._model() + model.eval() + with torch.no_grad(): + ret = model(**self.batch, embed=True) + self.assertIn("embed", ret) + self.assertEqual(ret["embed"].shape, (4, 16)) + + def test_invalid_net_config_type_raises(self): + """Unrecognised input type in net_configs raises ValueError.""" + model = self._model(net_configs=[{"type": "invalid", "mean": True, "std": True}]) + with self.assertRaises(ValueError): + model(**self.batch) + + def test_training_mode_runs(self): + """Training mode (noise injection + channel dropout) does not crash.""" + model = self._model(noise_std=0.1) + model.train() + ret = model(**self.batch) + self.assertEqual(ret["loss"].dim(), 0) + + +if __name__ == "__main__": + unittest.main()