Fix #16: keep_all_cls_pred with heterogeneous head sizes#17
Closed
jkobject wants to merge 1 commit into
Closed
Conversation
Three related bugs reported in #16: 1. scPrint._predict: with keep_all_cls_pred=True, the code stacked the per-class logits with torch.stack, but classification heads have different n_classes (e.g. 424 for cell_type vs 62 for disease), so stacking raises 'stack expects each tensor to be equal size'. Store per-class logits in a dict {clsname: tensor} instead, and concatenate per head across batches. 2. on_validation_epoch_end: handle the dict shape when all_gather'ing self.pred. 3. Embedder.__call__ (cell_emb.py): - move logits to CPU/numpy before wrapping in pd.DataFrame (CUDA tensors cannot be converted directly). - fix pd.concat(adata.obs, allclspred) -> pd.concat([adata.obs, allclspred], axis=1) (positional misuse, second arg was being interpreted as 'axis'). The keep_all_cls_pred=False (argmax) path is unchanged. Co-authored-by: Prachi-Priyam <noreply@github.com>
There was a problem hiding this comment.
Pull request overview
Fixes failures in the embedding/prediction path when keep_all_cls_pred=True and multiple classification heads have heterogeneous numbers of classes, ensuring per-head outputs can be collected and written into adata.obs.
Changes:
- Store per-head classification logits as a
dict[clsname -> tensor]instead of stacking whenkeep_all_cls_pred=True, and concatenate per head across batches. - Update validation epoch end gathering logic to support dict-shaped predictions.
- Fix pandas integration for per-class predictions by moving tensors to CPU/NumPy and correcting
pd.concatusage; bump version to1.1.4.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
scprint/tasks/cell_emb.py |
Builds per-head DataFrames from model.pred (dict), ensures CPU/NumPy conversion, and correctly concatenates into adata.obs. |
scprint/model/model.py |
Changes _predict to store logits per head in a dict when keep_all_cls_pred=True and adjusts on_validation_epoch_end to all_gather dict values. |
pyproject.toml |
Bumps package version to 1.1.4. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Comment on lines
237
to
+253
| if self.keep_all_cls_pred: | ||
| allclspred = model.pred | ||
| columns = [] | ||
| # model.pred is a dict[clsname -> tensor[n_cells, n_classes_cl]] | ||
| # (heads have different n_classes), so concatenate per head. | ||
| dfs = [] | ||
| for cl in model.classes: | ||
| n = model.label_counts[cl] | ||
| columns += [model.label_decoders[cl][i] for i in range(n)] | ||
| allclspred = pd.DataFrame( | ||
| allclspred, columns=columns, index=adata.obs.index | ||
| ) | ||
| adata.obs = pd.concat(adata.obs, allclspred) | ||
| columns = [model.label_decoders[cl][i] for i in range(n)] | ||
| tensor = model.pred[cl] | ||
| if hasattr(tensor, "detach"): | ||
| tensor = tensor.detach().cpu().numpy() | ||
| dfs.append( | ||
| pd.DataFrame( | ||
| tensor, columns=columns, index=adata.obs.index | ||
| ) | ||
| ) | ||
| allclspred = pd.concat(dfs, axis=1) | ||
| adata.obs = pd.concat([adata.obs, allclspred], axis=1) |
Owner
Author
|
Closing in favour of cantinilab#52 — |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Fixes #16. Three related bugs in the Embedder path when
keep_all_cls_pred=Truewith multiple labels inpred_embedding(e.g.cell_type_ontology_term_id+disease_ontology_term_id).Bug 1 —
torch.stackwith mismatched head sizesscprint/model/model.py:_predictstacked per-class logits along a new dim, but heads have differentn_classes(424 vs 62 in the report), so:Fix: when
keep_all_cls_pred=True, store per-head logits in adict[clsname -> tensor]and concatenate per head across batches (instead of stacking). The argmax branch (keep_all_cls_pred=False) is unchanged.Also patched
on_validation_epoch_endtoall_gatherthe dict shape.Bug 2 — CUDA tensor in
pd.DataFramescprint/tasks/cell_emb.pywrapped a CUDA tensor inpd.DataFrame→Fix: move to CPU + numpy before the
DataFramecall.Bug 3 —
pd.concatpositional argsadata.obs = pd.concat(adata.obs, allclspred)is a positional bug — the second arg is interpreted asaxis.Fix:
pd.concat([adata.obs, allclspred], axis=1).Test plan
pip install git+https://github.com/jkobject/scPRINT.git@fix/keep-all-cls-predOutput shape note
With
keep_all_cls_pred=True, all per-head probability columns are concatenated intoadata.obs(one column per class, named vialabel_decoders). For 424 + 62 classes this is ~486 columns — that's intentional but worth flagging.Bumped to
1.1.4.CC @Prachi-Priyam — thanks for the very detailed report (root cause + proposed fix locations were spot on).