Skip to content

Fix #16: keep_all_cls_pred with heterogeneous head sizes#17

Closed
jkobject wants to merge 1 commit into
mainfrom
fix/keep-all-cls-pred
Closed

Fix #16: keep_all_cls_pred with heterogeneous head sizes#17
jkobject wants to merge 1 commit into
mainfrom
fix/keep-all-cls-pred

Conversation

@jkobject
Copy link
Copy Markdown
Owner

Fixes #16. Three related bugs in the Embedder path when keep_all_cls_pred=True with multiple labels in pred_embedding (e.g. cell_type_ontology_term_id + disease_ontology_term_id).

Bug 1 — torch.stack with mismatched head sizes

scprint/model/model.py:_predict stacked per-class logits along a new dim, but heads have different n_classes (424 vs 62 in the report), so:

RuntimeError: stack expects each tensor to be equal size,
but got [32, 424] at entry 0 and [32, 62] at entry 1

Fix: when keep_all_cls_pred=True, store per-head logits in a dict[clsname -> tensor] and concatenate per head across batches (instead of stacking). The argmax branch (keep_all_cls_pred=False) is unchanged.

Also patched on_validation_epoch_end to all_gather the dict shape.

Bug 2 — CUDA tensor in pd.DataFrame

scprint/tasks/cell_emb.py wrapped a CUDA tensor in pd.DataFrame

TypeError: can't convert cuda:0 device type tensor to numpy.

Fix: move to CPU + numpy before the DataFrame call.

Bug 3 — pd.concat positional args

adata.obs = pd.concat(adata.obs, allclspred) is a positional bug — the second arg is interpreted as axis.
Fix: pd.concat([adata.obs, allclspred], axis=1).

Test plan

  • AST-checked.
  • Functional smoke test requires a checkpoint + GPU; would appreciate @Prachi-Priyam re-running their reproducer with this branch installed:
    pip install git+https://github.com/jkobject/scPRINT.git@fix/keep-all-cls-pred

Output shape note

With keep_all_cls_pred=True, all per-head probability columns are concatenated into adata.obs (one column per class, named via label_decoders). For 424 + 62 classes this is ~486 columns — that's intentional but worth flagging.

Bumped to 1.1.4.

CC @Prachi-Priyam — thanks for the very detailed report (root cause + proposed fix locations were spot on).

Three related bugs reported in #16:

1. scPrint._predict: with keep_all_cls_pred=True, the code stacked the
   per-class logits with torch.stack, but classification heads have
   different n_classes (e.g. 424 for cell_type vs 62 for disease), so
   stacking raises 'stack expects each tensor to be equal size'. Store
   per-class logits in a dict {clsname: tensor} instead, and concatenate
   per head across batches.

2. on_validation_epoch_end: handle the dict shape when all_gather'ing
   self.pred.

3. Embedder.__call__ (cell_emb.py):
   - move logits to CPU/numpy before wrapping in pd.DataFrame (CUDA
     tensors cannot be converted directly).
   - fix pd.concat(adata.obs, allclspred) -> pd.concat([adata.obs,
     allclspred], axis=1) (positional misuse, second arg was being
     interpreted as 'axis').

The keep_all_cls_pred=False (argmax) path is unchanged.

Co-authored-by: Prachi-Priyam <noreply@github.com>
Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Fixes failures in the embedding/prediction path when keep_all_cls_pred=True and multiple classification heads have heterogeneous numbers of classes, ensuring per-head outputs can be collected and written into adata.obs.

Changes:

  • Store per-head classification logits as a dict[clsname -> tensor] instead of stacking when keep_all_cls_pred=True, and concatenate per head across batches.
  • Update validation epoch end gathering logic to support dict-shaped predictions.
  • Fix pandas integration for per-class predictions by moving tensors to CPU/NumPy and correcting pd.concat usage; bump version to 1.1.4.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.

File Description
scprint/tasks/cell_emb.py Builds per-head DataFrames from model.pred (dict), ensures CPU/NumPy conversion, and correctly concatenates into adata.obs.
scprint/model/model.py Changes _predict to store logits per head in a dict when keep_all_cls_pred=True and adjusts on_validation_epoch_end to all_gather dict values.
pyproject.toml Bumps package version to 1.1.4.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread scprint/tasks/cell_emb.py
Comment on lines 237 to +253
if self.keep_all_cls_pred:
allclspred = model.pred
columns = []
# model.pred is a dict[clsname -> tensor[n_cells, n_classes_cl]]
# (heads have different n_classes), so concatenate per head.
dfs = []
for cl in model.classes:
n = model.label_counts[cl]
columns += [model.label_decoders[cl][i] for i in range(n)]
allclspred = pd.DataFrame(
allclspred, columns=columns, index=adata.obs.index
)
adata.obs = pd.concat(adata.obs, allclspred)
columns = [model.label_decoders[cl][i] for i in range(n)]
tensor = model.pred[cl]
if hasattr(tensor, "detach"):
tensor = tensor.detach().cpu().numpy()
dfs.append(
pd.DataFrame(
tensor, columns=columns, index=adata.obs.index
)
)
allclspred = pd.concat(dfs, axis=1)
adata.obs = pd.concat([adata.obs, allclspred], axis=1)
@jkobject
Copy link
Copy Markdown
Owner Author

Closing in favour of cantinilab#52cantinilab/scPRINT is the canonical user-facing repo, and jkobject/scPRINT/main is stale relative to it. Same fix, rebased onto cantinilab/scPRINT:main.

@jkobject jkobject closed this May 20, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

RuntimeError: stack expects each tensor to be equal size during multi-task prediction when keep_all_cls_pred=True

2 participants