[https://nvbugs/6322045][fix] In triton_context, when max_q_len == max_kv_len (so cache_lens=0) and…#15479
[https://nvbugs/6322045][fix] In triton_context, when max_q_len == max_kv_len (so cache_lens=0) and…#15479tensorrt-cicd wants to merge 2 commits into
Conversation
📝 WalkthroughWalkthroughIn ChangesSDPA Causal Masking and Test Waiver
Estimated code review effort🎯 2 (Simple) | ⏱️ ~10 minutes Possibly related PRs
Suggested reviewers
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 inconclusive)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py`:
- Around line 1181-1182: The sdpa_scratch_bytes calculation on line 1181
incorrectly uses q.element_size() which reflects the query tensor's dtype (e.g.,
2 bytes for fp16/bf16), but the actual SDPA scratch memory is allocated for fp32
score tensors which use 4 bytes per element. To fix this, replace the
q.element_size() call with a hardcoded 4 bytes (or equivalent like the size of a
float32 element) in the sdpa_scratch_bytes calculation to accurately estimate
the memory requirement and prevent premature selection of the explicit-mask
backend that could cause OOM on memory-constrained GPUs.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 07606216-ae91-4e99-93cf-6accb9a55b37
📒 Files selected for processing (2)
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.pytests/integration/test_lists/waives.txt
💤 Files with no reviewable changes (1)
- tests/integration/test_lists/waives.txt
| sdpa_scratch_bytes = num_seq * n_heads * max_q_len * max_kv_len * q.element_size() | ||
| use_is_causal = max_q_len == max_kv_len and sdpa_scratch_bytes > 2 * 1024**3 |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
# Find PyTorch version pinning
find . -type f \( -name "requirements*.txt" -o -name "pyproject.toml" -o -name "setup.py" \) | head -20Repository: NVIDIA/TensorRT-LLM
Length of output: 952
🏁 Script executed:
# Read the specific file segment
cat -n tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py | sed -n '1174,1200p'Repository: NVIDIA/TensorRT-LLM
Length of output: 1788
🏁 Script executed:
# Also check surrounding comments for context on SDPA behavior
cat -n tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py | sed -n '1160,1185p'Repository: NVIDIA/TensorRT-LLM
Length of output: 1464
🏁 Script executed:
# Find main project configuration
cat pyproject.toml | grep -A 10 "torch\|pytorch" | head -20Repository: NVIDIA/TensorRT-LLM
Length of output: 1219
🏁 Script executed:
# Also check if there's documentation about SDPA backend behavior
rg -i "sdpa|scaled_dot_product" --type py -B 2 -A 2 | grep -i "dtype\|fp32\|float32" | head -30Repository: NVIDIA/TensorRT-LLM
Length of output: 3723
🏁 Script executed:
# Check PyTorch documentation link or comments that might clarify the backend
rg "math backend\|flash\|fused kernel" tensorrt_llm/_torch/auto_deploy/custom_ops/attention/ -B 2 -A 2Repository: NVIDIA/TensorRT-LLM
Length of output: 45
🏁 Script executed:
# Check the main pyproject.toml for torch version
cat pyproject.toml | grep -i "torch" | head -20Repository: NVIDIA/TensorRT-LLM
Length of output: 1219
🏁 Script executed:
# Look for PyTorch version in the base requirements
find . -name "requirements.txt" -exec grep -l "torch" {} \; | head -5Repository: NVIDIA/TensorRT-LLM
Length of output: 259
🏁 Script executed:
# Search for any tests or documentation about SDPA math backend behavior
find . -path "./tests/*" -name "*.py" -exec rg "scaled_dot_product_attention" {} + | head -20Repository: NVIDIA/TensorRT-LLM
Length of output: 3178
🏁 Script executed:
# Check if there are related PRs or comments about fp32 score tensors
rg "fp32.*score|score.*fp32" --type py tensorrt_llm/_torch/Repository: NVIDIA/TensorRT-LLM
Length of output: 497
🏁 Script executed:
# Check main requirements.txt for torch version
cat requirements.txt | grep -i torchRepository: NVIDIA/TensorRT-LLM
Length of output: 624
🏁 Script executed:
# Look for test that might reveal SDPA backend behavior with fp16/bf16
cat tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_attention.py | grep -A 20 "def test\|torch.float16\|torch.bfloat16" | head -80Repository: NVIDIA/TensorRT-LLM
Length of output: 3109
🏁 Script executed:
# Check if there's documentation about SDPA math backend in the repo
rg "torch.nn.functional.scaled_dot_product_attention" tensorrt_llm/_torch/ -B 5 -A 5 | grep -E "dtype|float32|fp32|math.*backend" | head -20Repository: NVIDIA/TensorRT-LLM
Length of output: 45
🏁 Script executed:
# Look at the actual test to see if fp16/bf16 queries are tested
cat tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_attention.py | sed -n '1,150p'Repository: NVIDIA/TensorRT-LLM
Length of output: 5217
🏁 Script executed:
# Search for PyTorch SDPA documentation or implementation details in the codebase
rg "math.*backend|cuda_flash\|efficient_attention" tensorrt_llm/_torch/ -B 2 -A 5 | head -40Repository: NVIDIA/TensorRT-LLM
Length of output: 3371
🏁 Script executed:
# Look for any comments about SDPA dtype behavior in tests or implementations
rg -i "sdpa.*dtype|score.*dtype|attention.*math" tensorrt_llm/_torch/auto_deploy/custom_ops/attention/ -B 3 -A 3Repository: NVIDIA/TensorRT-LLM
Length of output: 6500
🏁 Script executed:
# Check if there are any references to PyTorch SDPA backend selection
cat tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py | sed -n '1100,1200p' | grep -i "backend\|dtype\|flash"Repository: NVIDIA/TensorRT-LLM
Length of output: 611
🏁 Script executed:
# Verify the proposed constants follow naming guidelines
# Check if similar constants are defined in the file
rg "THRESHOLD|BYTES|SIZE" tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.pyRepository: NVIDIA/TensorRT-LLM
Length of output: 1681
🏁 Script executed:
# Verify the fix formatting matches the codebase style
# Check how constants are defined elsewhere in the file
cat tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py | grep -E "^[A-Z_]+\s*=" | head -10Repository: NVIDIA/TensorRT-LLM
Length of output: 82
🏁 Script executed:
# Verify there are no existing fp32/fp64 element size constants already defined
rg "element_size\|ELEMENT_SIZE" tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.pyRepository: NVIDIA/TensorRT-LLM
Length of output: 45
🏁 Script executed:
# Check the context of the _MIN_TL_DOT_K constant to see naming style
cat tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py | grep -B 5 -A 5 "_MIN_TL_DOT_K"Repository: NVIDIA/TensorRT-LLM
Length of output: 730
🏁 Script executed:
# Check the exact lines and verify the fix makes sense with surrounding code
cat tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py | sed -n '1175,1200p'Repository: NVIDIA/TensorRT-LLM
Length of output: 1526
Use fp32 bytes for the SDPA scratch threshold.
Line 1181 uses q.element_size(), but the preceding comment describes an fp32 score tensor. For fp16/bf16 queries, this halves the estimate, so the explicit-mask math backend can still be selected when the real fp32 score exceeds the 2 GiB guard—defeating the OOM prevention on memory-tight GPUs like L40S.
Proposed fix
- sdpa_scratch_bytes = num_seq * n_heads * max_q_len * max_kv_len * q.element_size()
- use_is_causal = max_q_len == max_kv_len and sdpa_scratch_bytes > 2 * 1024**3
+ FP32_ELEMENT_SIZE_BYTES = 4
+ SDPA_SCRATCH_THRESHOLD_BYTES = 2 * 1024**3
+ sdpa_scratch_bytes = (
+ num_seq * n_heads * max_q_len * max_kv_len * FP32_ELEMENT_SIZE_BYTES
+ )
+ use_is_causal = (
+ max_q_len == max_kv_len and sdpa_scratch_bytes > SDPA_SCRATCH_THRESHOLD_BYTES
+ )🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py`
around lines 1181 - 1182, The sdpa_scratch_bytes calculation on line 1181
incorrectly uses q.element_size() which reflects the query tensor's dtype (e.g.,
2 bytes for fp16/bf16), but the actual SDPA scratch memory is allocated for fp32
score tensors which use 4 bytes per element. To fix this, replace the
q.element_size() call with a hardcoded 4 bytes (or equivalent like the size of a
float32 element) in the sdpa_scratch_bytes calculation to accurately estimate
the memory requirement and prevent premature selection of the explicit-mask
backend that could cause OOM on memory-constrained GPUs.
…attn The triton context-attention SDPA fast path passes an explicit attn_mask to torch.nn.functional.scaled_dot_product_attention, which forces the math backend and materializes [num_seq, n_heads, max_q_len, max_kv_len] fp32 score and softmax tensors. During piecewise cudagraph warmup at 8192 tokens for Llama-3.1-8B (1*32*8192*8192*4 = 8 GiB per tensor), this OOMs on memory-tight GPUs such as L40S (44 GiB). When the explicit mask reduces to a plain lower-triangular mask (every sequence has the same q_len under the existing all_same_q_len gate, and max_q_len == max_kv_len implies all cache_lens are 0), and the projected scratch would exceed 2 GiB, drop the mask and dispatch via is_causal=True so SDPA can use the fused flash kernel instead. The original explicit-mask path is preserved for shorter sequences and any case with cache reuse, so numerics are unchanged outside the OOM-prone regime. Signed-off-by: tensorrt-cicd <90828364+tensorrt-cicd@users.noreply.github.com>
6e60bc4 to
6e13e27
Compare
Summary
Test plan
Links
Summary by CodeRabbit
Bug Fixes
Performance Improvements