[1/3][Refactor]: File reorg; deprecate ParallelDraft#1296
Conversation
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
📝 Walkthrough: This PR splits the prior unified Transformers speculative plugin (`transformers.py`) into separate `hf_eagle.py`, `hf_medusa.py`, `modeling_eagle.py`, and `modeling_dflash.py` modules.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks: ✅ 3 passed | ❌ 1 failed (warning)
Codecov Report

```
@@            Coverage Diff             @@
##             main    #1296      +/-   ##
==========================================
- Coverage   75.87%   71.51%   -4.37%
==========================================
  Files         462      465       +3
  Lines       49747    49731      -16
==========================================
- Hits        37745    35564    -2181
- Misses      12002    14167    +2165
```

Flags with carried forward coverage won't be shown. ☔ View full report in Codecov by Sentry.
[AI Review] Generated by Claude — reference only, not a substitute for human review.

Verdict: pure file-reorg + rename + the documented deprecations.

Verified:

Please double-check:
Actionable comments posted: 4
🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/hf_medusa.py (1)
147-156: Consider avoiding in-place mutation of the `labels` parameter.

The loop mutates `labels` on each iteration (`labels = labels[..., 1:].contiguous()`), which shadows the input parameter and could be confusing. Using a separate variable would improve clarity.

♻️ Suggested improvement
```diff
 if labels is not None:
     loss = 0
     loss_fct = CrossEntropyLoss()
     # Base model loss
     if not freeze_base_model:
         loss_logits = logits.view(-1, logits.shape[-1])
         loss_labels = labels.view(-1)
         base_model_loss = loss_fct(loss_logits, loss_labels)
         loss += base_model_loss
     # Medusa loss
+    shifted_labels = labels
     for i in range(self.medusa_num_heads):
-        labels = labels[..., 1:].contiguous()
+        shifted_labels = shifted_labels[..., 1:].contiguous()
         loss_logits = medusa_logits[i][:, : -(1 + i)].contiguous()
         loss_logits = loss_logits.view(-1, loss_logits.shape[-1])
-        loss_labels = labels.view(-1)
+        loss_labels = shifted_labels.view(-1)
         loss += (
             loss_fct(loss_logits, loss_labels)
             * medusa_decay_coefficient**i
             * medusa_heads_coefficient
         )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_medusa.py` around lines 147 - 156, The loop in function(s) using medusa_num_heads currently mutates the input parameter labels with labels = labels[..., 1:].contiguous() each iteration; replace this in-place mutation by computing a fresh shifted slice per head (e.g. shifted_labels = labels[..., 1 + i:].contiguous() or compute start = i+1 and use labels[..., start:]) and then use shifted_labels to form loss_labels; keep the rest of the computation (loss_logits from medusa_logits[i], view/reshape, loss_fct, medusa_decay_coefficient**i, medusa_heads_coefficient) unchanged so the original labels argument is not shadowed or modified.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/speculative/plugins/hf_medusa.py`:
- Line 122: Replace the incorrect keyword argument name rcache_position with the
correct cache_position where the model forward call is constructed so the
HuggingFace PreTrainedModel.forward() receives the cache position correctly;
locate the call using the rcache_position symbol in hf_medusa plugin code and
rename that argument to cache_position (and ensure any downstream uses expect
cache_position, not rcache_position).
In `@modelopt/torch/speculative/plugins/modeling_dflash.py`:
- Around line 25-31: Top-level hard imports of transformers/Qwen3 symbols
(ALL_ATTENTION_FUNCTIONS, Qwen3MLP/_MLP_CLS, Qwen3RMSNorm/_NORM_CLS,
Qwen3RotaryEmbedding/_ROTARY_CLS, _rotate_half) must be removed from
modeling_dflash.py and instead acquired via the plugin lazy loader; replace
those module-level imports with calls to the plugin system (import_plugin()) and
perform the transformers/Qwen3 imports inside the plugin initialization or
inside the functions that need them so they only run when the hf_dflash plugin
is loaded, ensuring modeling_dflash.py can be imported in environments without
the optional transformers integration.
- Around line 36-47: The function build_target_layer_ids is producing duplicate
layer indices for shallow target models (e.g., build_target_layer_ids(4,2) ->
[1,1]); update it to detect when the computed interior window is too small to
yield unique indices and either (a) fall back to using the full layer span (0
through num_target_layers-1) to select uniformly spaced indices, or (b) raise a
ValueError for unsupported configs; implement the check inside
build_target_layer_ids (use the computed start/end/span and compare span to
num_draft_layers - 1) and then either compute indices across the full range or
raise, so hf_dflash.py will not receive duplicate hidden-state indices.
In `@modelopt/torch/speculative/plugins/modeling_eagle.py`:
- Line 49: Fix the typo in the comment that reads "Their values depend on
specific tokenzier and calibrate dataset, and should be set in training script."
by changing "tokenzier" to "tokenizer" so the comment reads "Their values depend
on specific tokenizer and calibrate dataset, and should be set in training
script."; update the text in modeling_eagle.py wherever that exact comment
appears.
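The lazy-import recommendation for `modeling_dflash.py` above can be sketched with a generic deferred-import helper. This is a stdlib-only illustration of the pattern, not ModelOpt's actual plugin loader (`import_plugin`'s real shape may differ):

```python
import importlib

_SYMBOL_CACHE: dict[tuple[str, str], object] = {}


def get_symbol(module_name: str, attr: str):
    """Resolve `module_name.attr` on first use and cache it, so merely
    importing this file never triggers the optional dependency."""
    key = (module_name, attr)
    if key not in _SYMBOL_CACHE:
        # The optional package is imported only when the symbol is requested,
        # i.e. only when the plugin that needs it is actually loaded.
        module = importlib.import_module(module_name)
        _SYMBOL_CACHE[key] = getattr(module, attr)
    return _SYMBOL_CACHE[key]


# Inside the plugin's init/forward path one would then write, e.g.:
# Qwen3RMSNorm = get_symbol("transformers.models.qwen3.modeling_qwen3", "Qwen3RMSNorm")
```

With this, environments without `transformers` can still import `modeling_dflash.py`; the `ImportError` only surfaces if the Qwen3 symbols are actually requested.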
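The `build_target_layer_ids` fix described earlier might look like the following sketch. The function name, the two-argument signature, and the interior window `[1, num_target_layers - 2]` are assumptions drawn from the review comment; the production helper in `modeling_dflash.py` may differ:

```python
def build_target_layer_ids(num_target_layers: int, num_draft_layers: int) -> list[int]:
    """Pick uniformly spaced, unique target-layer indices for draft hidden states."""
    start, end = 1, num_target_layers - 2
    if end - start < num_draft_layers - 1:
        # Interior window too small for unique indices: fall back to the full span.
        start, end = 0, num_target_layers - 1
        if end - start < num_draft_layers - 1:
            raise ValueError(
                f"cannot pick {num_draft_layers} unique layers from {num_target_layers}"
            )
    if num_draft_layers == 1:
        return [(start + end) // 2]
    # Spacing >= 1 is guaranteed by the span check, so rounded ids are unique.
    step = (end - start) / (num_draft_layers - 1)
    ids = [round(start + i * step) for i in range(num_draft_layers)]
    if len(set(ids)) != num_draft_layers:  # defensive; should be unreachable
        raise ValueError(f"duplicate layer ids {ids}")
    return ids
```

Under this sketch the reported failure case `build_target_layer_ids(4, 2)` yields `[1, 2]` instead of the duplicated `[1, 1]`, and configs that cannot produce unique indices raise `ValueError` rather than handing `hf_dflash.py` duplicate hidden-state indices.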
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: 67d25ab7-428e-47d9-b671-09260bc72cf7
📒 Files selected for processing (13)

- .pre-commit-config.yaml
- examples/speculative_decoding/eagle_utils.py
- examples/speculative_decoding/scripts/ar_validate.py
- modelopt/torch/export/plugins/hf_spec_export.py
- modelopt/torch/speculative/eagle/default_config.py
- modelopt/torch/speculative/plugins/__init__.py
- modelopt/torch/speculative/plugins/hf_dflash.py
- modelopt/torch/speculative/plugins/hf_eagle.py
- modelopt/torch/speculative/plugins/hf_medusa.py
- modelopt/torch/speculative/plugins/modeling_dflash.py
- modelopt/torch/speculative/plugins/modeling_eagle.py
- modelopt/torch/speculative/utils.py
- tests/unit/torch/export/test_hf_spec_rope_export.py
💤 Files with no reviewable changes (1)
- modelopt/torch/speculative/eagle/default_config.py
What does this PR do?
Type of change: refactoring
Part 1 of a 3-PR series splitting #1271 (`ParallelDraft`, `HFSpecDecMixin`).

Changes:
- File reorg: `transformers.py` → `hf_eagle.py`; extract `HFMedusaModel` → `hf_medusa.py`; extract `EagleModule`/`EagleBaseModelOutput` → `modeling_eagle.py`; extract `DFlashModule`/`DFlashAttention`/`DFlashDecoderLayer`/`build_target_layer_ids`/`apply_rotary_pos_emb` → `modeling_dflash.py`.
- Deprecate `ParallelDraft`: remove `parallel_draft_step`, `parallel_draft_heads_num_layers`, and the `ParallelDraft` module from HF Eagle; remove the `EagleMedusaExporter` branch from `HFEagleModel.get_exporter()` (the `EagleMedusaExporter` class itself still lives in `hf_spec_export.py` for Megatron parity).
- Rename `_draft_model_config` → `eagle_config` in the export plugin.
- Update `examples/speculative_decoding/` and `modelopt/torch/speculative/utils.py` to follow the module rename.

Testing
Validated with existing Eagle and DFlash training scripts (re-run after the 9ae5302729 revert behavior change).

Before your PR is "Ready for review"
- Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
- Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
- Breaking changes: `modelopt.torch.speculative.plugins.transformers` → `.hf_eagle`; removes `parallel_draft_step`/`parallel_draft_heads_num_layers` from Eagle config; renames `_draft_model_config` → `eagle_config` in export plugin.
- CONTRIBUTING.md: N/A
- `test_hf_spec_rope_export.py` assertions were also corrected to reflect the actual production path (the old assertions were masked by `MagicMock` not invoking the `_draft_model_config` `@property`).

Additional Information
Breaking changes:
- `modelopt.torch.speculative.plugins.transformers` → `.hf_eagle`
- `parallel_draft_step`/`parallel_draft_heads_num_layers` removed from Eagle config
- `_draft_model_config` → `eagle_config` in export plugin

Summary by CodeRabbit
Refactoring
New Features
Chores
Tests
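For consumers hit by the `plugins.transformers` → `plugins.hf_eagle` module rename listed in the breaking changes, a hypothetical migration helper is sketched below. The function name and the blunt textual replace are assumptions for illustration; review the rewrite before committing:

```python
from pathlib import Path

# Old and new module paths from this PR's breaking-change list.
OLD = "modelopt.torch.speculative.plugins.transformers"
NEW = "modelopt.torch.speculative.plugins.hf_eagle"


def migrate_imports(root: str) -> int:
    """Rewrite the renamed module path in every .py file under `root`.

    Returns the number of files changed. Purely textual: it will also touch
    strings and comments containing the old path, so inspect the diff.
    """
    changed = 0
    for py_file in Path(root).rglob("*.py"):
        text = py_file.read_text()
        if OLD in text:
            py_file.write_text(text.replace(OLD, NEW))
            changed += 1
    return changed
```

Running it twice is safe: the second pass finds no occurrences of the old path and changes nothing.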