
[1/3][Refactor]: File reorg; deprecate ParallelDraft#1296

Open
h-guo18 wants to merge 3 commits into main from haoguo/spec-file-reorg

Conversation

@h-guo18
Contributor

@h-guo18 h-guo18 commented Apr 19, 2026

What does this PR do?

Type of change: refactoring

Part 1 of a 3-PR series splitting #1271:

Changes:

  • File reorg: transformers.py → hf_eagle.py; extract HFMedusaModel → hf_medusa.py; extract EagleModule / EagleBaseModelOutput → modeling_eagle.py; extract DFlashModule / DFlashAttention / DFlashDecoderLayer / build_target_layer_ids / apply_rotary_pos_emb → modeling_dflash.py.
  • Deprecate ParallelDraft: remove parallel_draft_step, parallel_draft_heads_num_layers, and the ParallelDraft module from HF Eagle; remove the EagleMedusaExporter branch from HFEagleModel.get_exporter() (the EagleMedusaExporter class itself still lives in hf_spec_export.py for Megatron parity).
  • Rename: _draft_model_config → eagle_config in export plugin.
  • Update imports in examples/speculative_decoding/ and modelopt/torch/speculative/utils.py to follow the module rename.
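For downstream users, the reorg implies import updates along these lines. This is a hypothetical migration map: the new module paths come from the description above, but exactly which symbols land where is an assumption to verify against the new modules.

```python
# Hypothetical old -> new import paths implied by this reorg; the module
# paths on the right come from the PR description, the symbol placement
# is an assumption.
IMPORT_MOVES = {
    "modelopt.torch.speculative.plugins.transformers.HFEagleModel":
        "modelopt.torch.speculative.plugins.hf_eagle.HFEagleModel",
    "modelopt.torch.speculative.plugins.transformers.HFMedusaModel":
        "modelopt.torch.speculative.plugins.hf_medusa.HFMedusaModel",
    "modelopt.torch.speculative.plugins.transformers.EagleModule":
        "modelopt.torch.speculative.plugins.modeling_eagle.EagleModule",
    "modelopt.torch.speculative.plugins.hf_dflash.DFlashModule":
        "modelopt.torch.speculative.plugins.modeling_dflash.DFlashModule",
}
```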

Testing

Validated with existing Eagle and DFlash training scripts (re-run after the behavior change reverted in 9ae5302729).

Before your PR is "Ready for review"

Make sure you have read and followed the Contributor guidelines and that your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ❌ — renames modelopt.torch.speculative.plugins.transformers → hf_eagle; removes parallel_draft_step / parallel_draft_heads_num_layers from Eagle config; renames _draft_model_config → eagle_config in export plugin.
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: N/A
  • Did you write any new necessary tests?: N/A — pure refactor; existing tests updated for the rename. test_hf_spec_rope_export.py assertions were also corrected to reflect the actual production path (the old assertions were masked by MagicMock not invoking the _draft_model_config @property).
  • Did you update Changelog?: ❌

Additional Information

Breaking changes:

  • modelopt.torch.speculative.plugins.transformers → hf_eagle
  • parallel_draft_step / parallel_draft_heads_num_layers removed from Eagle config
  • _draft_model_config → eagle_config in export plugin

Summary by CodeRabbit

  • Refactoring

    • Restructured speculative decoding plugins and simplified EAGLE draft-token generation and exporter behavior.
    • Delegated DFlash implementation into a shared modeling component.
  • New Features

    • Added dedicated Medusa speculative-decoding plugin with configurable medusa heads and combined-loss support.
    • Introduced new modeling modules for DFlash and EAGLE decoder stacks.
  • Chores

    • Removed unused default configuration keys and updated pre-commit header exclusions.
  • Tests

    • Adjusted export tests to validate rope-scaling fallback behavior.

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
@copy-pr-bot

copy-pr-bot Bot commented Apr 19, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.


@coderabbitai
Contributor

coderabbitai Bot commented Apr 19, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: 64f91af2-20c1-4ddd-b771-aef3f04d22cf

📥 Commits

Reviewing files that changed from the base of the PR and between 9ae5302 and 4eaafb2.

📒 Files selected for processing (1)
  • modelopt/torch/speculative/plugins/modeling_eagle.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/speculative/plugins/modeling_eagle.py

📝 Walkthrough

Walkthrough

This PR splits the prior unified Transformers speculative plugin into separate hf_eagle and hf_medusa plugins, relocates draft-model implementations into dedicated modeling modules (DFlash/Eagle), removes EAGLE parallel-draft-step behavior, and updates related exports, flags, examples, and tests.

Changes

Cohort / File(s) Summary
Pre-commit & plugin wiring
.pre-commit-config.yaml, modelopt/torch/speculative/plugins/__init__.py
Replaced transformers exclusion with hf_medusa in pre-commit; plugin import wiring now conditionally imports hf_eagle and hf_medusa instead of transformers.
EAGLE plugin & exporter changes
modelopt/torch/speculative/plugins/hf_eagle.py, modelopt/torch/speculative/eagle/default_config.py, modelopt/torch/export/plugins/hf_spec_export.py
Removed Medusa-related code and parallel draft-step behavior from hf_eagle.py (single-logit path only); removed parallel_draft_* keys from default configs; exporter now reads draft config from model.eagle_config.
New Medusa plugin
modelopt/torch/speculative/plugins/hf_medusa.py
Added new HFMedusaModel implementing Medusa-specific modify/forward logic, medusa heads construction, loss aggregation, and registration in MedusaDMRegistry.
DFlash relocation & modeling
modelopt/torch/speculative/plugins/hf_dflash.py, modelopt/torch/speculative/plugins/modeling_dflash.py
Moved in-file DFlash attention/decoder/draft implementations out of hf_dflash.py into new modeling_dflash.py; hf_dflash.py now imports implementations from the shared modeling module.
Eagle modeling module
modelopt/torch/speculative/plugins/modeling_eagle.py
Added EagleModule and EagleBaseModelOutput dataclass building decoder stack, optional lm_head/draft routing, auxiliary hidden-state routing, rotary init, and pre-hook handling for inputs_embeds.
Utility & CP TTT patch flag usage
modelopt/torch/speculative/utils.py, examples/speculative_decoding/eagle_utils.py
Switched CP TTT patch flag imports/usage from transformers to hf_eagle, updating the context manager and conditional checks that enable the patch.
Examples & scripts
examples/speculative_decoding/eagle_utils.py, examples/speculative_decoding/scripts/ar_validate.py
Updated feature-flag and validator imports to reference hf_eagle instead of transformers; adjusted conditional attention/kwargs behavior to use new flag location.
Tests
tests/unit/torch/export/test_hf_spec_rope_export.py
Adjusted exporter test fixture to set model.eagle_config.rope_theta = None and renamed/updated tests to assert rope-scaling fallback semantics from training config.
New modeling exports & utilities
modelopt/torch/speculative/plugins/modeling_dflash.py, modelopt/torch/speculative/plugins/modeling_eagle.py
Added and exported DFlash and Eagle modeling utilities: build_target_layer_ids, DFlashModule, DFlashAttention, EagleModule, and related helpers/dataclasses.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 78.72% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (3 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title '[1/3][Refactor]: File reorg; deprecate ParallelDraft' accurately summarizes the main changes: file reorganization, module extraction, and ParallelDraft deprecation. It is specific, clear, and directly reflects the primary refactoring work described in the PR objectives.
Security Anti-Patterns ✅ Passed Comprehensive security scan found no torch.load with weights_only=False, numpy.load with allow_pickle=True, hardcoded trust_remote_code=True, eval/exec calls, or nosec comments bypassing security checks.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.

@github-actions
Contributor

github-actions Bot commented Apr 19, 2026

PR Preview Action v1.8.1


🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1296/

Built to branch gh-pages at 2026-04-19 23:33 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

@codecov

codecov Bot commented Apr 19, 2026

Codecov Report

❌ Patch coverage is 74.75728% with 78 lines in your changes missing coverage. Please review.
✅ Project coverage is 71.51%. Comparing base (26ae8da) to head (4eaafb2).

Files with missing lines Patch % Lines
...delopt/torch/speculative/plugins/modeling_eagle.py 50.60% 41 Missing ⚠️
modelopt/torch/speculative/plugins/hf_medusa.py 35.41% 31 Missing ⚠️
modelopt/torch/speculative/plugins/hf_eagle.py 90.38% 5 Missing ⚠️
...elopt/torch/speculative/plugins/modeling_dflash.py 99.15% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1296      +/-   ##
==========================================
- Coverage   75.87%   71.51%   -4.37%     
==========================================
  Files         462      465       +3     
  Lines       49747    49731      -16     
==========================================
- Hits        37745    35564    -2181     
- Misses      12002    14167    +2165     
Flag Coverage Δ
examples 40.34% <23.30%> (-1.70%) ⬇️
gpu 50.33% <53.39%> (-9.07%) ⬇️
regression 14.89% <53.39%> (+0.10%) ⬆️
unit 52.52% <60.51%> (+0.02%) ⬆️

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.

@h-guo18 h-guo18 changed the title from "reorg files" to "[1/3][Refactor]: File reorg; deprecate ParallelDraft" Apr 19, 2026
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
@h-guo18 h-guo18 marked this pull request as ready for review April 19, 2026 23:15
@h-guo18 h-guo18 requested review from a team as code owners April 19, 2026 23:15
@h-guo18 h-guo18 requested a review from ChenhanYu April 19, 2026 23:15
@h-guo18
Contributor Author

h-guo18 commented Apr 19, 2026

[AI Review] Generated by Claude — reference only, not a substitute for human review.
Diffed against origin/main at PR tip 9ae5302729.

Verdict: pure file-reorg + rename + the documented ParallelDraft deprecation. No hidden behavior changes.

Verified:

  • DFlash forward() / pseudo_speculative_generate() and Eagle _find_base_model_parts() are byte-identical to main.
  • Extracted modules (hf_medusa.py, modeling_eagle.py, modeling_dflash.py) match main's originals byte-for-byte, modulo removed ParallelDraft and added docstrings.
  • Rename-only files (utils.py, __init__.py, .pre-commit-config.yaml, examples/**, hf_spec_export.py) are string substitutions; _draft_model_config → eagle_config preserves production behavior because the former was a @property returning the latter on main.
  • Non-trivial equivalences: _maybe_init_rope moved inside EagleModule.forward (idempotent, same device); Eagle TTT loop rewrite num_ttt_steps = eagle_ttt_steps if training else 1 produces identical iteration count and slice as the old for ... if not training: break + train_accs[:, :ttt_step+1].
  • Test updates in test_hf_spec_rope_export.py exercise the real production path that was masked by MagicMock in the old fixture — not a behavior change.

Please double-check:

  1. Existing Eagle and DFlash training scripts were re-run after 9ae5302729 (the revert touches DFlash forward + generate).
  2. EagleMedusaExporter in hf_spec_export.py is now dead code for HF Eagle — keep for Megatron parity or delete?
  3. Leftover parallel_draft_step entries in hf_spec_configs.py — intentional back-compat, or stale?
  4. CHANGELOG update for three breaking changes: module rename, Eagle config key removal, _draft_model_config rename.

@h-guo18 h-guo18 self-assigned this Apr 19, 2026
Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 4

🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/hf_medusa.py (1)

147-156: Consider avoiding in-place mutation of labels parameter.

The loop mutates labels on each iteration (labels = labels[..., 1:].contiguous()), which shadows the input parameter and could be confusing. Using a separate variable would improve clarity.

♻️ Suggested improvement
         if labels is not None:
             loss = 0
             loss_fct = CrossEntropyLoss()
             # Base model loss
             if not freeze_base_model:
                 loss_logits = logits.view(-1, logits.shape[-1])
                 loss_labels = labels.view(-1)
                 base_model_loss = loss_fct(loss_logits, loss_labels)
                 loss += base_model_loss
             # Medusa loss
+            shifted_labels = labels
             for i in range(self.medusa_num_heads):
-                labels = labels[..., 1:].contiguous()
+                shifted_labels = shifted_labels[..., 1:].contiguous()
                 loss_logits = medusa_logits[i][:, : -(1 + i)].contiguous()
                 loss_logits = loss_logits.view(-1, loss_logits.shape[-1])
-                loss_labels = labels.view(-1)
+                loss_labels = shifted_labels.view(-1)
                 loss += (
                     loss_fct(loss_logits, loss_labels)
                     * medusa_decay_coefficient**i
                     * medusa_heads_coefficient
                 )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_medusa.py` around lines 147 - 156, The
loop in function(s) using medusa_num_heads currently mutates the input parameter
labels with labels = labels[..., 1:].contiguous() each iteration; replace this
in-place mutation by computing a fresh shifted slice per head (e.g.
shifted_labels = labels[..., 1 + i:].contiguous() or compute start = i+1 and use
labels[..., start:]) and then use shifted_labels to form loss_labels; keep the
rest of the computation (loss_logits from medusa_logits[i], view/reshape,
loss_fct, medusa_decay_coefficient**i, medusa_heads_coefficient) unchanged so
the original labels argument is not shadowed or modified.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/speculative/plugins/hf_medusa.py`:
- Line 122: Replace the incorrect keyword argument name rcache_position with the
correct cache_position where the model forward call is constructed so the
HuggingFace PreTrainedModel.forward() receives the cache position correctly;
locate the call using the rcache_position symbol in hf_medusa plugin code and
rename that argument to cache_position (and ensure any downstream uses expect
cache_position, not rcache_position).

In `@modelopt/torch/speculative/plugins/modeling_dflash.py`:
- Around line 25-31: Top-level hard imports of transformers/Qwen3 symbols
(ALL_ATTENTION_FUNCTIONS, Qwen3MLP/_MLP_CLS, Qwen3RMSNorm/_NORM_CLS,
Qwen3RotaryEmbedding/_ROTARY_CLS, _rotate_half) must be removed from
modeling_dflash.py and instead acquired via the plugin lazy loader; replace
those module-level imports with calls to the plugin system (import_plugin()) and
perform the transformers/Qwen3 imports inside the plugin initialization or
inside the functions that need them so they only run when the hf_dflash plugin
is loaded, ensuring modeling_dflash.py can be imported in environments without
the optional transformers integration.
- Around line 36-47: The function build_target_layer_ids is producing duplicate
layer indices for shallow target models (e.g., build_target_layer_ids(4,2) ->
[1,1]); update it to detect when the computed interior window is too small to
yield unique indices and either (a) fall back to using the full layer span (0
through num_target_layers-1) to select uniformly spaced indices, or (b) raise a
ValueError for unsupported configs; implement the check inside
build_target_layer_ids (use the computed start/end/span and compare span to
num_draft_layers - 1) and then either compute indices across the full range or
raise, so hf_dflash.py will not receive duplicate hidden-state indices.

In `@modelopt/torch/speculative/plugins/modeling_eagle.py`:
- Line 49: Fix the typo in the comment that reads "Their values depend on
specific tokenzier and calibrate dataset, and should be set in training script."
by changing "tokenzier" to "tokenizer" so the comment reads "Their values depend
on specific tokenizer and calibrate dataset, and should be set in training
script."; update the text in modeling_eagle.py wherever that exact comment
appears.

---

Nitpick comments:
In `@modelopt/torch/speculative/plugins/hf_medusa.py`:
- Around line 147-156: The loop in function(s) using medusa_num_heads currently
mutates the input parameter labels with labels = labels[..., 1:].contiguous()
each iteration; replace this in-place mutation by computing a fresh shifted
slice per head (e.g. shifted_labels = labels[..., 1 + i:].contiguous() or
compute start = i+1 and use labels[..., start:]) and then use shifted_labels to
form loss_labels; keep the rest of the computation (loss_logits from
medusa_logits[i], view/reshape, loss_fct, medusa_decay_coefficient**i,
medusa_heads_coefficient) unchanged so the original labels argument is not
shadowed or modified.
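The build_target_layer_ids guard suggested earlier (option (b): raise on shallow targets) can be sketched like this. The index formula is a hypothetical reconstruction, not the real modeling_dflash.py implementation.

```python
def build_target_layer_ids(num_target_layers: int, num_draft_layers: int) -> list[int]:
    # Hypothetical: evenly space draft taps over the interior of the target
    # stack (skipping the first and last layers), raising instead of
    # silently returning duplicate indices.
    start, end = 1, num_target_layers - 2
    span = end - start
    if num_draft_layers > 1 and span < num_draft_layers - 1:
        raise ValueError(
            f"target model too shallow: cannot pick {num_draft_layers} unique "
            f"layer ids from interior layers {start}..{end}"
        )
    if num_draft_layers == 1:
        return [start]
    step = span / (num_draft_layers - 1)
    return [round(start + i * step) for i in range(num_draft_layers)]
```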

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: 67d25ab7-428e-47d9-b671-09260bc72cf7

📥 Commits

Reviewing files that changed from the base of the PR and between 26ae8da and 9ae5302.

📒 Files selected for processing (13)
  • .pre-commit-config.yaml
  • examples/speculative_decoding/eagle_utils.py
  • examples/speculative_decoding/scripts/ar_validate.py
  • modelopt/torch/export/plugins/hf_spec_export.py
  • modelopt/torch/speculative/eagle/default_config.py
  • modelopt/torch/speculative/plugins/__init__.py
  • modelopt/torch/speculative/plugins/hf_dflash.py
  • modelopt/torch/speculative/plugins/hf_eagle.py
  • modelopt/torch/speculative/plugins/hf_medusa.py
  • modelopt/torch/speculative/plugins/modeling_dflash.py
  • modelopt/torch/speculative/plugins/modeling_eagle.py
  • modelopt/torch/speculative/utils.py
  • tests/unit/torch/export/test_hf_spec_rope_export.py
💤 Files with no reviewable changes (1)
  • modelopt/torch/speculative/eagle/default_config.py

Comment thread modelopt/torch/speculative/plugins/hf_medusa.py
Comment thread modelopt/torch/speculative/plugins/modeling_dflash.py
Comment thread modelopt/torch/speculative/plugins/modeling_dflash.py
Comment thread modelopt/torch/speculative/plugins/modeling_eagle.py Outdated
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
@h-guo18 h-guo18 requested a review from yeyu-nvidia April 19, 2026 23:49
