Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions tensorrt_llm/_torch/modules/mamba/gdn_mixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,14 +890,19 @@ def forward(

state_indices_p, state_indices_d = torch.split(state_indices, batch_split_size)
if num_prefills > 0:
# PyExecutor guarantees prefill requests are placed before decode requests
has_initial_states_p = has_initial_states[:num_prefills]
ssm_states[state_indices_p[~has_initial_states_p]] = torch.zeros(
(), dtype=ssm_states.dtype, device=ssm_states.device
)
conv_states[state_indices_p[~has_initial_states_p]] = torch.zeros(
(), dtype=conv_states.dtype, device=conv_states.device
)
if not mamba_metadata.use_initial_states:
# All prefills are fresh — zero every slot unconditionally.
for state in (ssm_states, conv_states):
state[state_indices_p] = 0
else:
# Use torch.where so the output shape is data-independent;
# boolean-mask indexing on a CUDA tensor would force a CPU sync.
# PyExecutor guarantees prefill requests are placed before decode requests.
has_initial_states_p = has_initial_states[:num_prefills]
for state in (ssm_states, conv_states):
kept = state[state_indices_p]
mask = has_initial_states_p.view(-1, *([1] * (kept.ndim - 1)))
state.index_copy_(0, state_indices_p, torch.where(mask, kept, 0))

is_target_verify = (
num_decodes > 0
Expand Down
10 changes: 10 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1676,6 +1676,16 @@ def add_dummy_requests(
num_extra_decoding_steps: int = 0,
draft_kv_cache_manager: Optional[KVCacheManager] = None,
) -> List[LlmRequest]:
# The caller's get_num_free_blocks check sees only the full-attention
# window, but the unified C++ pool also holds the recurrent-states
# window used by mamba layers. If a CUDA-graph warmup batch exceeds
# that smaller window, add_sequence_batch raises "No free block found"
# from C++, which deadlocks the ranks that are already past the
# collective. Return None so the caller skips this batch (the
# documented contract).
per_window_free = self.impl.get_kv_cache_stats(
).num_free_blocks_per_window_size
if per_window_free and len(request_ids) > min(per_window_free.values()):
return None
requests = super().add_dummy_requests(
request_ids=request_ids,
token_nums=token_nums,
Expand Down
16 changes: 16 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1128,6 +1128,22 @@ def _run_attention_warmup(self,
if not issubclass(self.attn_backend.Metadata, TrtllmAttentionMetadata):
return

# The C++ TRTLLM-Gen FMHA JIT warmup enumerates a (batchSize x seqLenKv)
# cartesian grid sized by engine maxima. For long-context configs such
# as Qwen3-Next-80B-A3B-Thinking tp4ep4 (max_batch_size=2048,
# max_seq_len=262144), the densified grid pushes warmup TMA descriptor
# shapes past the flashinfer 2^32 limit and hangs engine startup. Skip
# the warmup whenever the maxima product is too large; any kernel not
# pre-warmed JIT-compiles lazily on first request, which is correct
# (just slower for that one request). The 256 * 16384 threshold matches
# the effective grid size from before PR #15305.
if self.batch_size * self.max_seq_len > 256 * 16384:
logger.info(
f"Skipping TRTLLM-Gen FMHA JIT warmup: engine config "
f"(max_batch_size={self.batch_size}, max_seq_len={self.max_seq_len}) "
f"would produce too many warmup grid points")
return

@contextlib.contextmanager
def trtllm_gen_fmha_jit_warmup():
previous = self._trtllm_gen_jit_warmup
Expand Down
Loading