
Add support and optimizations for Seko AR model #1091

gushiqiao merged 5 commits into main from gsq/dev-seko on May 25, 2026

@gushiqiao gushiqiao commented May 25, 2026

Summary

This PR adds and optimizes KV cache support for Seko autoregressive inference, with a focus on making the Seko AR model practical on consumer GPUs.

Key changes include:

  • Add support for Seko autoregressive model inference, mainly through the step-aware KV cache architecture.
  • Support sequence-parallel inference for the Seko AR path, including SP-aware cache indexing and RoPE position handling.
  • Refactor KV cache storage into a ring layout to avoid unnecessary copies during rolling-window updates. Instead of physically shifting [sink | chunk1 | chunk2 | chunk3] into [sink | chunk2 | chunk3 | chunk4], the cache can now keep a ring layout such as [sink | chunk4 | chunk2 | chunk3], reducing memory movement.
  • Optimize RoPE in the Seko AR path with a Triton kernel, including support for cache-range RoPE and sequence-parallel spatial position mapping.
  • Optimize KIVI KV cache dequantization by fusing unpack and dequant into a Triton kernel.
  • Add asynchronous autoregressive decode support to overlap and hide decode latency.
  • Add LongLive2 NVFP4 KV cache quantization support, including parallel dequantization across multiple chunks.
  • Refactor KV offload with a simpler single-buffer prefetch design, improving readability and inference speed.
  • Together, these optimizations make Seko AR inference practical on consumer-grade GPUs, while the original reference setup typically targets an 8-GPU environment with 80 GB of memory per GPU.
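The ring layout described in the list above can be illustrated with a toy cache. This is a minimal sketch in plain Python, not the PR's implementation: the class name `RingKVCache` and its methods are hypothetical, and strings stand in for key/value tensors. It assumes one pinned sink slot and a fixed number of chunk slots, and shows that appending a new chunk overwrites only the oldest slot while a separate logical view recovers oldest-to-newest order.

```python
class RingKVCache:
    """Toy rolling-window cache: a pinned sink slot plus a ring of chunk slots."""

    def __init__(self, num_chunk_slots):
        self.sink = None
        self.slots = [None] * num_chunk_slots  # physical storage, never shifted
        self.next_slot = 0                     # physical slot to overwrite next
        self.count = 0                         # number of chunks currently stored

    def set_sink(self, sink):
        self.sink = sink

    def append(self, chunk):
        # Overwrite the oldest slot in place; no other slot is copied or moved.
        self.slots[self.next_slot] = chunk
        self.next_slot = (self.next_slot + 1) % len(self.slots)
        self.count = min(self.count + 1, len(self.slots))

    def physical_layout(self):
        return [self.sink] + self.slots

    def logical_order(self):
        # Oldest-to-newest view; once the ring is full, the oldest chunk
        # lives at next_slot.
        n = len(self.slots)
        start = 0 if self.count < n else self.next_slot
        return [self.sink] + [self.slots[(start + i) % n] for i in range(self.count)]
```

With three chunk slots, appending chunk1 through chunk4 leaves the physical layout as [sink | chunk4 | chunk2 | chunk3], matching the example in the summary, while `logical_order()` still reports chunk2, chunk3, chunk4. A real cache would presumably need this logical mapping for position handling such as RoPE.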

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces support for autoregressive (AR) audio-driven video generation via the "seko_talk_ar" model and runner. Key technical additions include a ring-buffer based RollingKVCachePool with CPU offloading, specialized CUDA and Triton kernels for FP4 and int4/int8 KV cache quantization (OScaR), and an asynchronous VAE chunk decoder to improve throughput. Feedback focuses on optimizing the new offloading and synchronization logic, specifically recommending the use of record_stream for safer asynchronous memory handling and suggesting more granular synchronization to avoid pipeline stalls during step updates. The reviewer also noted that the stateful nature of the VAE decoder currently limits the potential gains from asynchronous execution.
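To make the int4 KV cache quantization mentioned above more concrete, here is a simplified sketch in plain Python. It is not the PR's Triton or CUDA kernel: the function names are hypothetical, and it assumes asymmetric per-group quantization with one scale and one zero point per group and two 4-bit codes per byte. The point is only that unpacking and dequantizing can happen in a single pass, without materializing an intermediate buffer of integer codes.

```python
def quantize_pack_int4(values):
    """Asymmetric 4-bit quantization of one group, packing two codes per byte."""
    lo, hi = min(values), max(values)
    scale = (hi - lo) / 15 or 1.0  # fall back to 1.0 for a constant group
    codes = [round((v - lo) / scale) for v in values]
    if len(codes) % 2:
        codes.append(0)            # pad to an even number of codes
    packed = bytes(codes[i] | (codes[i + 1] << 4) for i in range(0, len(codes), 2))
    return packed, scale, lo


def unpack_dequant_int4(packed, scale, zero, n):
    """Unpack and dequantize in one pass over the packed bytes."""
    out = []
    for byte in packed:
        out.append((byte & 0x0F) * scale + zero)  # low nibble
        out.append((byte >> 4) * scale + zero)    # high nibble
    return out[:n]                                # drop any padding element
```

Each restored value differs from the original by at most half a quantization step (`scale / 2`). A GPU kernel would apply the same per-byte arithmetic in parallel across the cache.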

Comment on lines +318 to +319
k_cpu[phys_s : phys_s + n].copy_(k[ks:ke], non_blocking=True)
v_cpu[phys_s : phys_s + n].copy_(v[ks:ke], non_blocking=True)

medium

When performing a non-blocking D2H copy from a GPU tensor that might be reused or freed immediately after store_kv returns, it is safer to use record_stream on the source tensor to ensure the allocator doesn't reclaim the memory before the copy is complete. While PyTorch's stream semantics often handle this if the copy is on the current stream, explicit record_stream is defensive best practice for asynchronous offloading.

Suggested change
k_cpu[phys_s : phys_s + n].copy_(k[ks:ke], non_blocking=True)
v_cpu[phys_s : phys_s + n].copy_(v[ks:ke], non_blocking=True)
k_cpu[phys_s : phys_s + n].copy_(k[ks:ke], non_blocking=True)
v_cpu[phys_s : phys_s + n].copy_(v[ks:ke], non_blocking=True)
k.record_stream(torch.cuda.current_stream())
v.record_stream(torch.cuda.current_stream())

Comment on lines +431 to +435
if self._kv_offload and hasattr(self, "_prefetch_stream"):
    self.sync_all()
    self._current_step = value
    self._reset_offload_state()
    return

medium

Calling self.sync_all() every time current_step is updated might introduce significant pipeline stalls, especially in autoregressive decoding where this happens multiple times per chunk. Consider if a more granular synchronization (e.g., using events per step) could achieve the same safety with less performance impact.
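The idea of per-step synchronization can be sketched with a CPU-side analogue. This example uses `threading.Event` from the Python standard library rather than CUDA events, and the `StepEvents` class is hypothetical, not part of the PR. It assumes each step's offload or prefetch work signals its own event, so a consumer blocks only on the step it actually needs instead of waiting for all outstanding work.

```python
import threading


class StepEvents:
    """One completion event per step, so a reader waits only for the step it needs."""

    def __init__(self):
        self._events = {}
        self._lock = threading.Lock()

    def _get(self, step):
        # Create the event lazily so producers and consumers can arrive in any order.
        with self._lock:
            return self._events.setdefault(step, threading.Event())

    def mark_done(self, step):
        self._get(step).set()

    def wait_for(self, step, timeout=None):
        # Returns True once the step has completed, False if the timeout expires.
        return self._get(step).wait(timeout)
```

On the GPU side, the analogous pattern would record a `torch.cuda.Event` after each step's copies and have the consuming stream wait on that one event. Whether this is safe for the PR's offload state machine is not something this sketch establishes.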

# chunk, which is where the overlap is gained.
if self._prev_done is not None:
    t0 = time.perf_counter()
    self._prev_done.synchronize()

medium

The use of self._prev_done.synchronize() on the CPU side partially defeats the purpose of asynchronous execution if the DiT work for the next chunk is faster than the VAE decode. However, given the stateful nature of the Wan VAE decoder mentioned in the comments, this might be a necessary constraint. If the decoder can be made stateless or if multiple decoder instances can be used, this bottleneck could be removed.
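The constraint described here can be shown with a small thread-based sketch. It is an illustration under stated assumptions, not the PR's decoder: `AsyncChunkDecoder` and `decode_fn` are hypothetical names, and a background thread stands in for a separate CUDA stream. Because the decode function carries state from one chunk to the next, each submission first waits for the previous decode, which plays the role of `self._prev_done.synchronize()`.

```python
from concurrent.futures import ThreadPoolExecutor


class AsyncChunkDecoder:
    """Runs a stateful decode function in the background, one chunk at a time."""

    def __init__(self, decode_fn):
        self._decode_fn = decode_fn
        self._pool = ThreadPoolExecutor(max_workers=1)
        self._prev = None   # future for the decode that is still in flight
        self.results = []

    def submit(self, latents):
        # The decoder keeps state between chunks, so the previous decode
        # must finish before the next one is allowed to start.
        self._drain()
        self._prev = self._pool.submit(self._decode_fn, latents)

    def _drain(self):
        if self._prev is not None:
            self.results.append(self._prev.result())
            self._prev = None

    def close(self):
        self._drain()
        self._pool.shutdown()
        return self.results
```

The caller can generate the next chunk between `submit` calls, so the overlap comes from that generation work running while the previous decode is in flight. If generation is faster than decoding, `submit` blocks in `_drain`, which is the stall the comment describes.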

@helloyongyang helloyongyang force-pushed the main branch 2 times, most recently from 0c683c3 to 9861558 May 25, 2026 06:36
@gushiqiao gushiqiao changed the title support seko ar model Add support and optimizations for Seko AR model May 25, 2026
@gushiqiao gushiqiao merged commit 4721f77 into main May 25, 2026
2 checks passed
@gushiqiao gushiqiao deleted the gsq/dev-seko branch May 25, 2026 08:26

2 participants