Add support and optimizations for Seko AR model#1091
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces support for autoregressive (AR) audio-driven video generation via the "seko_talk_ar" model and runner. Key technical additions include a ring-buffer based RollingKVCachePool with CPU offloading, specialized CUDA and Triton kernels for FP4 and int4/int8 KV cache quantization (OScaR), and an asynchronous VAE chunk decoder to improve throughput. Feedback focuses on optimizing the new offloading and synchronization logic, specifically recommending the use of record_stream for safer asynchronous memory handling and suggesting more granular synchronization to avoid pipeline stalls during step updates. The reviewer also noted that the stateful nature of the VAE decoder currently limits the potential gains from asynchronous execution.
| k_cpu[phys_s : phys_s + n].copy_(k[ks:ke], non_blocking=True) | ||
| v_cpu[phys_s : phys_s + n].copy_(v[ks:ke], non_blocking=True) |
There was a problem hiding this comment.
When performing a non-blocking D2H copy from a GPU tensor that might be reused or freed immediately after store_kv returns, it is safer to use record_stream on the source tensor to ensure the allocator doesn't reclaim the memory before the copy is complete. While PyTorch's stream semantics often handle this if the copy is on the current stream, explicit record_stream is defensive best practice for asynchronous offloading.
| k_cpu[phys_s : phys_s + n].copy_(k[ks:ke], non_blocking=True) | |
| v_cpu[phys_s : phys_s + n].copy_(v[ks:ke], non_blocking=True) | |
| k_cpu[phys_s : phys_s + n].copy_(k[ks:ke], non_blocking=True) | |
| v_cpu[phys_s : phys_s + n].copy_(v[ks:ke], non_blocking=True) | |
| k.record_stream(torch.cuda.current_stream()) | |
| v.record_stream(torch.cuda.current_stream()) |
| if self._kv_offload and hasattr(self, "_prefetch_stream"): | ||
| self.sync_all() | ||
| self._current_step = value | ||
| self._reset_offload_state() | ||
| return |
There was a problem hiding this comment.
Calling self.sync_all() every time current_step is updated might introduce significant pipeline stalls, especially in autoregressive decoding where this happens multiple times per chunk. Consider if a more granular synchronization (e.g., using events per step) could achieve the same safety with less performance impact.
| # chunk, which is where the overlap is gained. | ||
| if self._prev_done is not None: | ||
| t0 = time.perf_counter() | ||
| self._prev_done.synchronize() |
There was a problem hiding this comment.
The use of self._prev_done.synchronize() on the CPU side partially defeats the purpose of asynchronous execution if the DiT work for the next chunk is faster than the VAE decode. However, given the stateful nature of the Wan VAE decoder mentioned in the comments, this might be a necessary constraint. If the decoder can be made stateless or if multiple decoder instances can be used, this bottleneck could be removed.
0c683c3 to
9861558
Compare
Summary
This PR adds and optimizes KV cache support for Seko autoregressive inference, with a focus on making the Seko AR model practical on consumer GPUs.
Key changes include:
[sink | chunk1 | chunk2 | chunk3]into[sink | chunk2 | chunk3 | chunk4], the cache can now keep a ring layout such as[sink | chunk4 | chunk2 | chunk3], reducing memory movement.