[https://nvbugs/6316980][fix] Added a runtime guard in FlashInferTrtllmGenAttention.is_supported using the…#15496
📝 Walkthrough
Adds a TMA overflow guard for FlashInfer trtllm-gen.
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~8 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 inconclusive)
✅ Passed checks (4 passed)
🧹 Nitpick comments (2)
tensorrt_llm/_torch/attention_backend/trtllm_gen.py (2)
638-646: ⚡ Quick win: Clarify the error message to reference the actual threshold value.
The error message mentions "flashinfer 2^32 TMA shape limit", but the actual guard threshold is MAX_TMA_GUARD_THRESHOLD (256 × 16384 = 4,194,304), which is significantly smaller than 2³². While the underlying root cause is flashinfer's 2³² constraint, the immediate reason for rejection is exceeding the conservative threshold. Including the actual threshold in the message would help users and maintainers understand the specific limit being enforced.
📝 Proposed improvement to error message
  if meta.max_num_requests * meta.max_seq_len > self.MAX_TMA_GUARD_THRESHOLD:
      return (
          False,
-         f"engine maxima (max_num_requests={meta.max_num_requests}, "
-         f"max_seq_len={meta.max_seq_len}) would overflow flashinfer 2^32 TMA shape limit.",
+         f"engine maxima product (max_num_requests={meta.max_num_requests} × "
+         f"max_seq_len={meta.max_seq_len} = {meta.max_num_requests * meta.max_seq_len}) "
+         f"exceeds threshold ({self.MAX_TMA_GUARD_THRESHOLD}) to prevent flashinfer TMA overflow.",
      )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/attention_backend/trtllm_gen.py` around lines 638 - 646, The error message in the return statement of the condition checking against MAX_TMA_GUARD_THRESHOLD currently only references flashinfer's 2^32 TMA shape limit, but it should clarify the actual threshold value being enforced. Update the error message string (in the f-string starting with "engine maxima") to include the actual MAX_TMA_GUARD_THRESHOLD constant value so users understand the specific conservative limit that triggered the rejection, not just the underlying flashinfer constraint.
431-436: 💤 Low value: Consider clarifying why the threshold is 256×16384 rather than closer to 2³².
The comment mentions that flashinfer's buildNdTmaDescriptor enforces a 2³² (4,294,967,296) limit, but the actual threshold is 256 × 16384 = 4,194,304. While the PR context indicates this value matches the existing JIT-warmup threshold, future maintainers would benefit from a brief explanation of why this specific conservative value was chosen (e.g., empirical testing, safety margin, or other flashinfer internals).
📝 Suggested comment enhancement
- # flashinfer's buildNdTmaDescriptor (kernelParams.h:598) enforces shapes[ii] <= 2^32.
- # When max_batch_size * max_seq_len exceeds this, the K/V cache pool's TMA shape
- # overflows on one rank and the rest hang on the next NCCL collective
+ # flashinfer's buildNdTmaDescriptor (kernelParams.h:598) enforces shapes[ii] <= 2^32.
+ # This conservative threshold (256 × 16384 = 4,194,304) prevents TMA shape overflow
+ # in the K/V cache pool when max_batch_size * max_seq_len exceeds it, avoiding
+ # one-rank abort + NCCL hang on subsequent collectives
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/attention_backend/trtllm_gen.py` around lines 431 - 436, The constant MAX_TMA_GUARD_THRESHOLD is set to 256 * 16384 which is significantly lower than the 2^32 limit mentioned in the comment. Enhance the comment above MAX_TMA_GUARD_THRESHOLD to clarify why this specific conservative threshold value was chosen instead of being closer to the actual 2^32 limit enforced by flashinfer's buildNdTmaDescriptor, such as whether it comes from empirical testing, a safety margin, alignment with the JIT-warmup threshold, or other flashinfer internals. This will help future maintainers understand the design decision.
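For scale, the arithmetic behind the two limits can be checked directly. The sketch below uses only figures quoted on this page (256 × 16384, 2^32, and the 720 × 131072 config from the commit message); the variable names are illustrative and are not taken from the TensorRT-LLM source.

```python
# Illustrative arithmetic only: names are hypothetical, values are quoted from this PR.
GUARD_THRESHOLD = 256 * 16384    # conservative guard threshold
TMA_DIM_LIMIT = 2 ** 32          # per-dimension limit cited for buildNdTmaDescriptor

# max_batch_size * max_seq_len for the config reported in the commit message
failing_product = 720 * 131072

# How far below the 2^32 limit the guard threshold sits
headroom_factor = TMA_DIM_LIMIT // GUARD_THRESHOLD

print(GUARD_THRESHOLD)                    # 4194304
print(headroom_factor)                    # 1024
print(failing_product)                    # 94371840
print(failing_product > GUARD_THRESHOLD)  # True
print(failing_product < TMA_DIM_LIMIT)    # True
```

The threshold sits a factor of 1024 below 2^32, and the reported config exceeds the threshold while its bare product stays below 2^32. That gap is what the comment above asks the code to explain.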
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 7c2bce37-fcf2-42d4-8cbc-3c6d0fade52c
⛔ Files ignored due to path filters (1)
tests/integration/defs/examples/visual_gen/golden/visual_gen_lpips/visual_gen_lpips_golden_media.zip is excluded by !**/*.zip
📒 Files selected for processing (2)
tensorrt_llm/_torch/attention_backend/trtllm_gen.py
tensorrt_llm/_torch/pyexecutor/model_engine.py
…ngine configs Signed-off-by: tensorrt-cicd <90828364+tensorrt-cicd@users.noreply.github.com>
…ne configs

flashinfer's buildNdTmaDescriptor at kernelParams.h:598 enforces shapes[ii] <= 2^32 per TMA dim. The previous fix (commit 9a05ee8 / nvbugs/6275959) skipped only the C++ TRTLLM-Gen FMHA JIT warmup grid; the runtime trtllm-gen forward path that runs immediately after (inside the general warmup) is still unguarded. With max_batch_size=720 and max_seq_len=131072 on GPT-OSS-120B, the runtime decode call into flashinfer aborts on one rank and the rest hang on the next NCCL collective, producing the silent multi-hour hang we saw post-fix.

Add a runtime guard in FlashInferTrtllmGenAttention.is_supported that falls back to the legacy thop.attention path (C++ TRT-LLM runtime, no flashinfer TMA descriptors) when meta.max_num_requests * meta.max_seq_len exceeds the same 256*16384 threshold used by the JIT-warmup skip. The two paths now agree on what 'oversized' means and the legacy runtime takes over end-to-end for those configs.

Signed-off-by: tensorrt-cicd <90828364+tensorrt-cicd@users.noreply.github.com>
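The guard-and-fallback behavior described in this commit message can be sketched roughly as follows. This is a minimal illustration, not the code in trtllm_gen.py: `Meta`, the free function `is_supported`, and `pick_backend` are hypothetical stand-ins, and only the two metadata fields, the threshold value, and the original error message are taken from this page.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class Meta:
    """Hypothetical stand-in for the attention metadata; models only the two fields named in the PR."""
    max_num_requests: int
    max_seq_len: int


# Value quoted in the PR (same threshold as the JIT-warmup skip).
MAX_TMA_GUARD_THRESHOLD = 256 * 16384


def is_supported(meta: Meta) -> Tuple[bool, str]:
    """Return (supported, reason); reason is empty when the config is supported."""
    if meta.max_num_requests * meta.max_seq_len > MAX_TMA_GUARD_THRESHOLD:
        return (
            False,
            f"engine maxima (max_num_requests={meta.max_num_requests}, "
            f"max_seq_len={meta.max_seq_len}) would overflow flashinfer 2^32 TMA shape limit.",
        )
    return True, ""


def pick_backend(meta: Meta) -> str:
    """Hypothetical selector: use trtllm-gen when supported, else the legacy path."""
    supported, _reason = is_supported(meta)
    return "trtllm-gen" if supported else "thop.attention"


print(pick_backend(Meta(max_num_requests=720, max_seq_len=131072)))  # thop.attention
print(pick_backend(Meta(max_num_requests=256, max_seq_len=16384)))   # trtllm-gen
```

Because the comparison is strict (`>`), a config whose product equals the threshold exactly still takes the trtllm-gen path in this sketch.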
5bb0569 to 2f3e040
Summary
Test plan
Links
Summary by CodeRabbit
Bug Fixes