fix linear att cpu cache offload load speed #1317
Code Review
This pull request optimizes the Triton kernel for linear attention CPU cache copying by implementing 128-bit alignment using torch.uint64 and refining the parallelization strategy for better performance. It also updates stream synchronization to use non-blocking wait calls and refactors the multi-turn benchmark tool from an asynchronous architecture to a threaded implementation using requests and ThreadPoolExecutor. The reviewer feedback identifies several improvement opportunities, including the removal of dead commented-out code, the correction of an inaccurate comment in the new speed test script, and the restoration of robust error handling in the benchmark tool to prevent the entire process from crashing due to individual request failures.
| # TODO: fa3 currently requires synchronous mode; remove this in the future | ||
| g_infer_context.get_overlap_stream().synchronize() | ||
| torch.cuda.current_stream().wait_stream(g_infer_context.get_overlap_stream()) | ||
| # g_infer_context.get_overlap_stream().synchronize() |
| if response.status_code != 200: | ||
| err = response.text | ||
| raise RuntimeError(f"stream_one_turn failed: status={response.status_code}, body={err[:200]}") |
Changing the error handling to raise a RuntimeError makes the benchmark tool fragile. A single request failure will now crash the entire benchmark process. It is better to log the error and return None to allow other concurrent sessions to continue, maintaining the previous robust behavior.
| if response.status_code != 200: | |
| err = response.text | |
| raise RuntimeError(f"stream_one_turn failed: status={response.status_code}, body={err[:200]}") | |
| if response.status_code != 200: | |
| err = response.text | |
| print(f"\n[turn failed] status={response.status_code} body={err[:200]}") | |
| return None |
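The behavior suggested above can be sketched in isolation. This is a minimal illustration, not the benchmark tool's actual code: `FakeResponse`, `run_turn`, and `run_sessions` are hypothetical names, and `FakeResponse` only stands in for the two `requests.Response` attributes the hunk uses (`status_code` and `text`).

```python
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class FakeResponse:
    """Hypothetical stand-in for requests.Response (only the fields used here)."""
    status_code: int
    text: str


def run_turn(response: FakeResponse) -> Optional[str]:
    """Return the body on success; log and return None on a non-200 status."""
    if response.status_code != 200:
        err = response.text
        print(f"\n[turn failed] status={response.status_code} body={err[:200]}")
        return None
    return response.text


def run_sessions(responses: List[FakeResponse], max_workers: int = 4) -> Tuple[List[str], int]:
    """Run all turns concurrently; failed turns are counted instead of raised."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # pool.map preserves input order, so results line up with responses.
        results = list(pool.map(run_turn, responses))
    succeeded = [r for r in results if r is not None]
    return succeeded, len(results) - len(succeeded)
```

Because a failed turn yields `None` rather than an exception, the driver can report a failure count alongside the metrics from the sessions that did complete.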
| if first_token_time is None: | ||
| return None | ||
| raise RuntimeError("stream_one_turn failed: no token received from stream") |
Raising a RuntimeError here will terminate the benchmark session abruptly. Consider returning None to allow the benchmark to continue with other sessions.
| if first_token_time is None: | |
| return None | |
| raise RuntimeError("stream_one_turn failed: no token received from stream") | |
| if first_token_time is None: | |
| print("\n[turn failed] no token received from stream") | |
| return None |
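As a rough sketch of the same idea for the streaming path, the function below measures time to first token and returns `None` when the stream produces nothing. The name `time_to_first_token` and the plain iterable of string chunks are assumptions for illustration; the real tool reads from an HTTP stream.

```python
import time
from typing import Iterable, Optional


def time_to_first_token(chunks: Iterable[str]) -> Optional[float]:
    """Return seconds until the first non-empty chunk, or None if none arrives."""
    start = time.perf_counter()
    first_token_time = None
    for chunk in chunks:
        # Empty chunks (e.g. keep-alive lines) do not count as a token.
        if chunk and first_token_time is None:
            first_token_time = time.perf_counter()
    if first_token_time is None:
        print("\n[turn failed] no token received from stream")
        return None
    return first_token_time - start
```

Callers can then skip `None` results when aggregating latency statistics, so one empty stream does not end the whole run.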
| # --------------------------------------------------------------------------- | ||
| # Step 2 – derive sizes from the config | ||
| # --------------------------------------------------------------------------- | ||
| big_page_token_num = _env_args["cpu_cache_token_page_size"] # 512 |
| # conv_shape = linear_config.get_conv_state_shape() | ||
| # cpu_kv_conv_state = torch.empty( | ||
| # (buffer_count, linear_config.linear_layer_num, *conv_shape), | ||
| # dtype=linear_config.conv_state_dtype, device="cuda", | ||
| # ) | ||
| | ||
| # ssm_shape = linear_config.get_ssm_state_shape() # (num_linear_v_heads, head_linear_k_dim, head_linear_v_dim) | ||
| # cpu_kv_ssm_state = torch.empty( | ||
| # (buffer_count, linear_config.linear_layer_num, *ssm_shape), | ||
| # dtype=linear_config.ssm_state_dtype, device="cuda", | ||
| # ) |
No description provided.