Skip to content

[https://nvbugs/6322045][fix] In triton_context, when max_q_len == max_kv_len (so cache_lens=0) and…#15479

Open
tensorrt-cicd wants to merge 2 commits into
NVIDIA:mainfrom
tensorrt-cicd:repair-bot-bug6322045
Open

[https://nvbugs/6322045][fix] In triton_context, when max_q_len == max_kv_len (so cache_lens=0) and…#15479
tensorrt-cicd wants to merge 2 commits into
NVIDIA:mainfrom
tensorrt-cicd:repair-bot-bug6322045

Conversation

@tensorrt-cicd

@tensorrt-cicd tensorrt-cicd commented Jun 18, 2026

Copy link
Copy Markdown
Collaborator

Summary

  • Root cause: SDPA with explicit attn_mask + enable_gqa=True forces the math backend, which materializes an 8 GiB fp32 score tensor at the 8192-token Llama-3.1-8B prefill warmup and OOMs on the 44 GiB L40S.
  • Fix: In triton_context, when max_q_len == max_kv_len (so cache_lens=0) and projected scratch > 2 GiB, drop attn_mask and use is_causal=True; keep the explicit-mask path otherwise so numerics are unchanged for short prefills and any cache-reuse case.
  • Automated fix generated by repair-bot

Test plan

  • Verify fix on the same GPU type as the original failure
  • Check for regressions in related tests

Links

Summary by CodeRabbit

  • Bug Fixes

    • Resolved failing test case for Llama 3.1 8B auto-dtype configuration with Triton inference.
  • Performance Improvements

    • Optimized attention computation to dynamically use more efficient kernel execution paths based on memory constraints.

@coderabbitai

coderabbitai Bot commented Jun 18, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

📝 Walkthrough

Walkthrough

In triton_context's SDPA path, a new use_is_causal flag is computed from max_q_len == max_kv_len and an SDPA scratch-size threshold. When true, attn_mask is set to None and is_causal=True is passed to SDPA. Otherwise, the mask combines a KV-length padding gate with a causal <= constraint. A previously-waived L40S accuracy test for this path is removed from waives.txt.

Changes

SDPA Causal Masking and Test Waiver

Layer / File(s) Summary
Conditional is_causal SDPA masking in triton_context
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py, tests/integration/test_lists/waives.txt
Introduces use_is_causal flag based on max_q_len == max_kv_len and scratch-size threshold; passes attn_mask=None, is_causal=True when flag is set, otherwise builds a combined KV-padding and causal <= mask. Removes the L40S test_auto_dtype[triton-False-1] waiver entry that was covering this failure.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Possibly related PRs

  • NVIDIA/TensorRT-LLM#15389: Added the waiver for TestLlama3_1_8B::test_auto_dtype[triton-False-1] in waives.txt that this PR now removes.
  • NVIDIA/TensorRT-LLM#15395: Also modifies tests/integration/test_lists/waives.txt CI SKIP/waiver entries in the same file.

Suggested reviewers

  • bmarimuthu-nv
  • tcherckez-nvidia
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 inconclusive)

Check name Status Explanation Resolution
Title check ❓ Inconclusive The title is incomplete and truncated with an ellipsis, making it impossible to evaluate the full summary of changes or verify its relevance to the changeset. Complete the title by removing the ellipsis and providing the full, coherent title that summarizes the main change.
✅ Passed checks (4 passed)
Check name Status Explanation
Description check ✅ Passed The pull request description covers the root cause, fix, and testing, but lacks completion of the PR checklist and is missing detailed sections on test coverage specifics as outlined in the template.
Docstring Coverage ✅ Passed Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py`:
- Around line 1181-1182: The sdpa_scratch_bytes calculation on line 1181
incorrectly uses q.element_size() which reflects the query tensor's dtype (e.g.,
2 bytes for fp16/bf16), but the actual SDPA scratch memory is allocated for fp32
score tensors which use 4 bytes per element. To fix this, replace the
q.element_size() call with a hardcoded 4 bytes (or equivalent like the size of a
float32 element) in the sdpa_scratch_bytes calculation to accurately estimate
the memory requirement and prevent premature selection of the explicit-mask
backend that could cause OOM on memory-constrained GPUs.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 07606216-ae91-4e99-93cf-6accb9a55b37

📥 Commits

Reviewing files that changed from the base of the PR and between c25fa74 and 6e60bc4.

📒 Files selected for processing (2)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py
  • tests/integration/test_lists/waives.txt
💤 Files with no reviewable changes (1)
  • tests/integration/test_lists/waives.txt

Comment on lines +1181 to +1182
sdpa_scratch_bytes = num_seq * n_heads * max_q_len * max_kv_len * q.element_size()
use_is_causal = max_q_len == max_kv_len and sdpa_scratch_bytes > 2 * 1024**3

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# Find PyTorch version pinning
find . -type f \( -name "requirements*.txt" -o -name "pyproject.toml" -o -name "setup.py" \) | head -20

Repository: NVIDIA/TensorRT-LLM

Length of output: 952


🏁 Script executed:

# Read the specific file segment
cat -n tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py | sed -n '1174,1200p'

Repository: NVIDIA/TensorRT-LLM

Length of output: 1788


🏁 Script executed:

# Also check surrounding comments for context on SDPA behavior
cat -n tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py | sed -n '1160,1185p'

Repository: NVIDIA/TensorRT-LLM

Length of output: 1464


🏁 Script executed:

# Find main project configuration
cat pyproject.toml | grep -A 10 "torch\|pytorch" | head -20

Repository: NVIDIA/TensorRT-LLM

Length of output: 1219


🏁 Script executed:

# Also check if there's documentation about SDPA backend behavior
rg -i "sdpa|scaled_dot_product" --type py -B 2 -A 2 | grep -i "dtype\|fp32\|float32" | head -30

Repository: NVIDIA/TensorRT-LLM

Length of output: 3723


🏁 Script executed:

# Check PyTorch documentation link or comments that might clarify the backend
rg "math backend\|flash\|fused kernel" tensorrt_llm/_torch/auto_deploy/custom_ops/attention/ -B 2 -A 2

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

# Check the main pyproject.toml for torch version
cat pyproject.toml | grep -i "torch" | head -20

Repository: NVIDIA/TensorRT-LLM

Length of output: 1219


🏁 Script executed:

# Look for PyTorch version in the base requirements
find . -name "requirements.txt" -exec grep -l "torch" {} \; | head -5

Repository: NVIDIA/TensorRT-LLM

Length of output: 259


🏁 Script executed:

# Search for any tests or documentation about SDPA math backend behavior
find . -path "./tests/*" -name "*.py" -exec rg "scaled_dot_product_attention" {} + | head -20

Repository: NVIDIA/TensorRT-LLM

Length of output: 3178


🏁 Script executed:

# Check if there are related PRs or comments about fp32 score tensors
rg "fp32.*score|score.*fp32" --type py tensorrt_llm/_torch/

Repository: NVIDIA/TensorRT-LLM

Length of output: 497


🏁 Script executed:

# Check main requirements.txt for torch version
cat requirements.txt | grep -i torch

Repository: NVIDIA/TensorRT-LLM

Length of output: 624


🏁 Script executed:

# Look for test that might reveal SDPA backend behavior with fp16/bf16
cat tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_attention.py | grep -A 20 "def test\|torch.float16\|torch.bfloat16" | head -80

Repository: NVIDIA/TensorRT-LLM

Length of output: 3109


🏁 Script executed:

# Check if there's documentation about SDPA math backend in the repo
rg "torch.nn.functional.scaled_dot_product_attention" tensorrt_llm/_torch/ -B 5 -A 5 | grep -E "dtype|float32|fp32|math.*backend" | head -20

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

# Look at the actual test to see if fp16/bf16 queries are tested
cat tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_attention.py | sed -n '1,150p'

Repository: NVIDIA/TensorRT-LLM

Length of output: 5217


🏁 Script executed:

# Search for PyTorch SDPA documentation or implementation details in the codebase
rg "math.*backend|cuda_flash\|efficient_attention" tensorrt_llm/_torch/ -B 2 -A 5 | head -40

Repository: NVIDIA/TensorRT-LLM

Length of output: 3371


🏁 Script executed:

# Look for any comments about SDPA dtype behavior in tests or implementations
rg -i "sdpa.*dtype|score.*dtype|attention.*math" tensorrt_llm/_torch/auto_deploy/custom_ops/attention/ -B 3 -A 3

Repository: NVIDIA/TensorRT-LLM

Length of output: 6500


🏁 Script executed:

# Check if there are any references to PyTorch SDPA backend selection
cat tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py | sed -n '1100,1200p' | grep -i "backend\|dtype\|flash"

Repository: NVIDIA/TensorRT-LLM

Length of output: 611


🏁 Script executed:

# Verify the proposed constants follow naming guidelines
# Check if similar constants are defined in the file
rg "THRESHOLD|BYTES|SIZE" tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 1681


🏁 Script executed:

# Verify the fix formatting matches the codebase style
# Check how constants are defined elsewhere in the file
cat tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py | grep -E "^[A-Z_]+\s*=" | head -10

Repository: NVIDIA/TensorRT-LLM

Length of output: 82


🏁 Script executed:

# Verify there are no existing fp32/fp64 element size constants already defined
rg "element_size\|ELEMENT_SIZE" tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

# Check the context of the _MIN_TL_DOT_K constant to see naming style
cat tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py | grep -B 5 -A 5 "_MIN_TL_DOT_K"

Repository: NVIDIA/TensorRT-LLM

Length of output: 730


🏁 Script executed:

# Check the exact lines and verify the fix makes sense with surrounding code
cat tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py | sed -n '1175,1200p'

Repository: NVIDIA/TensorRT-LLM

Length of output: 1526


Use fp32 bytes for the SDPA scratch threshold.

Line 1181 uses q.element_size(), but the preceding comment describes an fp32 score tensor. For fp16/bf16 queries, this halves the estimate, so the explicit-mask math backend can still be selected when the real fp32 score exceeds the 2 GiB guard—defeating the OOM prevention on memory-tight GPUs like L40S.

Proposed fix
-        sdpa_scratch_bytes = num_seq * n_heads * max_q_len * max_kv_len * q.element_size()
-        use_is_causal = max_q_len == max_kv_len and sdpa_scratch_bytes > 2 * 1024**3
+        FP32_ELEMENT_SIZE_BYTES = 4
+        SDPA_SCRATCH_THRESHOLD_BYTES = 2 * 1024**3
+        sdpa_scratch_bytes = (
+            num_seq * n_heads * max_q_len * max_kv_len * FP32_ELEMENT_SIZE_BYTES
+        )
+        use_is_causal = (
+            max_q_len == max_kv_len and sdpa_scratch_bytes > SDPA_SCRATCH_THRESHOLD_BYTES
+        )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py`
around lines 1181 - 1182, The sdpa_scratch_bytes calculation on line 1181
incorrectly uses q.element_size() which reflects the query tensor's dtype (e.g.,
2 bytes for fp16/bf16), but the actual SDPA scratch memory is allocated for fp32
score tensors which use 4 bytes per element. To fix this, replace the
q.element_size() call with a hardcoded 4 bytes (or equivalent like the size of a
float32 element) in the sdpa_scratch_bytes calculation to accurately estimate
the memory requirement and prevent premature selection of the explicit-mask
backend that could cause OOM on memory-constrained GPUs.

…attn

The triton context-attention SDPA fast path passes an explicit attn_mask to
torch.nn.functional.scaled_dot_product_attention, which forces the math
backend and materializes [num_seq, n_heads, max_q_len, max_kv_len] fp32
score and softmax tensors. During piecewise cudagraph warmup at 8192 tokens
for Llama-3.1-8B (1*32*8192*8192*4 = 8 GiB per tensor), this OOMs on
memory-tight GPUs such as L40S (44 GiB).

When the explicit mask reduces to a plain lower-triangular mask (every
sequence has the same q_len under the existing all_same_q_len gate, and
max_q_len == max_kv_len implies all cache_lens are 0), and the projected
scratch would exceed 2 GiB, drop the mask and dispatch via is_causal=True
so SDPA can use the fused flash kernel instead. The original explicit-mask
path is preserved for shorter sequences and any case with cache reuse, so
numerics are unchanged outside the OOM-prone regime.

Signed-off-by: tensorrt-cicd <90828364+tensorrt-cicd@users.noreply.github.com>
@tensorrt-cicd tensorrt-cicd force-pushed the repair-bot-bug6322045 branch from 6e60bc4 to 6e13e27 Compare June 18, 2026 17:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants