diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py index e371338e904..fb64b27b49e 100644 --- a/backends/vulkan/custom_ops_lib.py +++ b/backends/vulkan/custom_ops_lib.py @@ -356,6 +356,71 @@ def linear_q8ta_q8csw( lib.impl(name, linear_q8ta_q8csw, "CompositeExplicitAutograd") qa_q8csw_linear = getattr(getattr(torch.ops, namespace), name) +################## +## q8ta_linear ## +################## + + +def q8ta_linear( + x: torch.Tensor, + input_scale: float, + input_zero_point: int, + weights: torch.Tensor, + weight_sums: torch.Tensor, + weight_scales: torch.Tensor, + output_scale: float, + output_zero_point: int, + bias: Optional[torch.Tensor] = None, + activation: str = "none", +): + weight_zeros = torch.zeros_like(weight_scales, dtype=torch.int32) + weights = torch.ops.quantized_decomposed.dequantize_per_channel( + weights, + weight_scales, + weight_zeros, + 0, + -127, + 127, + torch.int8, + ) + + x = torch.ops.quantized_decomposed.dequantize_per_tensor( + x, input_scale, input_zero_point, -128, 127, x.dtype + ) + + out = torch.nn.functional.linear(x, weights) + if bias is not None: + out = out + bias + + if activation == "relu": + out = torch.nn.functional.relu(out) + + out = torch.ops.quantized_decomposed.quantize_per_tensor( + out, output_scale, output_zero_point, -128, 127, torch.int8 + ) + + return out + + +name = "q8ta_linear" +lib.define( + f""" + {name}( + Tensor x, + float input_scale, + int input_zero_point, + Tensor weights, + Tensor weight_sums, + Tensor weight_scales, + float output_scale, + int output_zero_point, + Tensor? bias = None, + str activation = "none") -> Tensor + """ +) +lib.impl(name, q8ta_linear, "CompositeExplicitAutograd") +q8ta_linear_op = getattr(getattr(torch.ops, namespace), name) + ################### ## q8ta_conv2d_* ## ################### diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 853ba5d3777..48fac18bc56 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -830,6 +830,34 @@ def register_q8ta_conv2d_ops(): ) +# ============================================================================= +# Q8taLinear.cpp +# ============================================================================= + + +@update_features(exir_ops.edge.et_vk.q8ta_linear.default) +def register_q8ta_linear(): + return OpFeatures( + inputs_storage=[ + utils.PACKED_INT8_4H4W_BUFFER, # input + utils.NO_STORAGE, # input_scale (non tensor) + utils.NO_STORAGE, # input_zero_point (non tensor) + utils.NO_STORAGE, # weight (prepacked) + utils.NO_STORAGE, # weight_sums (prepacked) + utils.NO_STORAGE, # weight_scales (prepacked) + utils.NO_STORAGE, # output_scale (non tensor) + utils.NO_STORAGE, # output_zero_point (non tensor) + utils.NO_STORAGE, # bias (prepacked) + utils.NO_STORAGE, # activation (non tensor) + ], + outputs_storage=[ + utils.PACKED_INT8_4H4W_BUFFER, + ], + supports_resize=False, + supports_prepacking=True, + ) + + # ============================================================================= # SDPA.cpp # ============================================================================= diff --git a/backends/vulkan/patterns/quantized_linear.py b/backends/vulkan/patterns/quantized_linear.py index 374e29c634d..fefad0eaf8a 100644 --- a/backends/vulkan/patterns/quantized_linear.py +++ b/backends/vulkan/patterns/quantized_linear.py @@ -31,7 +31,7 @@ class QuantizedLinearMatch(PatternMatch): - def __init__(self, mm_node: torch.fx.Node) -> None: + def __init__(self, mm_node: torch.fx.Node) 
-> None: # noqa: C901 self.anchor_node = mm_node self.match_found = False self.all_nodes = [self.anchor_node] @@ -111,10 +111,17 @@ def __init__(self, mm_node: torch.fx.Node) -> None: self.bias_node = None if self.anchor_node.target == exir_ops.edge.aten.addmm.default: self.bias_node, arg_chain = utils.trace_args_until_placeholder( - self.anchor_node.args[2] + self.anchor_node.args[0] ) assert self.bias_node is not None self.all_nodes.extend(arg_chain) + elif self.anchor_node.target == exir_ops.edge.aten.linear.default: + if len(self.anchor_node.args) > 2 and self.anchor_node.args[2] is not None: + self.bias_node, arg_chain = utils.trace_args_until_placeholder( + self.anchor_node.args[2] + ) + if self.bias_node is not None: + self.all_nodes.extend(arg_chain) # If input is not quantized, then we are done if self.quantize_input_node is None: @@ -143,11 +150,36 @@ def __init__(self, mm_node: torch.fx.Node) -> None: ] ) + # Check if the output is also quantized (q → dq → linear → q pattern) + # Also handle fused linear+relu (q → dq → linear → relu → q pattern) + self.quantize_output_node = None + self.output_scales_node = None + self.output_zeros_node = None + self.relu_node = None + if len(self.output_node.users) == 1: + cur_node = list(self.output_node.users)[0] + if cur_node.target == exir_ops.edge.aten.relu.default: + self.relu_node = cur_node + if len(cur_node.users) == 1: + cur_node = list(cur_node.users)[0] + else: + cur_node = None + if cur_node is not None and utils.is_quant_node(cur_node): + self.quantize_output_node = cur_node + self.output_scales_node = self.quantize_output_node.args[1] + self.output_zeros_node = self.quantize_output_node.args[2] + self.match_found = True def is_weight_only_quantized(self) -> bool: return self.quantize_input_node is None + def has_output_quantization(self) -> bool: + return ( + hasattr(self, "quantize_output_node") + and self.quantize_output_node is not None + ) + def is_weight_pergroup_quantized(self) -> bool: weight_shape = self.weight_node.meta["val"].shape scales_shape = self.weight_scales_node.meta["val"].shape @@ -454,6 +486,49 @@ def make_linear_q8ta_q8csw_custom_op( match.output_node.replace_all_uses_with(qlinear_node) +def make_q8ta_linear_custom_op( + ep: ExportedProgram, + graph_module: torch.fx.GraphModule, + match: QuantizedLinearMatch, + weight_tensor: torch.Tensor, +): + first_graph_node = list(graph_module.graph.nodes)[0] + with graph_module.graph.inserting_before(first_graph_node): + weight_tensor_name = utils.get_tensor_name(ep, match.weight_node) + sum_per_output_channel = weight_tensor.sum(dim=1).to(torch.int32).contiguous() + sums_name = weight_tensor_name + "_sums" + sums_name = sums_name.replace(".", "_") + + weight_sums_node = create_constant_placeholder( + exp_program=ep, + graph=graph_module.graph, + kind=InputKind.PARAMETER, + name=sums_name, + data=sum_per_output_channel, + ) + + with graph_module.graph.inserting_before(match.output_node): + qlinear_node = graph_module.graph.create_node( + "call_function", + exir_ops.edge.et_vk.q8ta_linear.default, + args=( + match.quantize_input_node, + match.input_scales_node, + match.input_zeros_node, + match.weight_node, + weight_sums_node, + match.weight_scales_node, + match.output_scales_node, + match.output_zeros_node, + match.bias_node, + "relu" if match.relu_node is not None else "none", + ), + ) + + qlinear_node.meta["val"] = match.quantize_output_node.meta["val"] + match.quantize_output_node.replace_all_uses_with(qlinear_node) + + 
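For reference, the arithmetic that `weight_sums` enables in the fused op, mirroring the `q8ta_linear` reference in custom_ops_lib.py and the shader's `accumulate_out_tile_with_int_accum` path. This is a minimal sketch only; the function and variable names below are local to this example and are not part of the change:

import torch

def q8ta_linear_sketch(x_q, x_scale, x_zp, w_q, w_scales, bias=None):
    # x_q: [M, K] int8 activations (affine, zero point x_zp, scale x_scale)
    # w_q: [N, K] int8 weights (symmetric, per-output-channel scales w_scales)
    x_i32 = x_q.to(torch.int32)
    w_i32 = w_q.to(torch.int32)
    # Integer accumulation: acc[m, n] = sum_k x_q[m, k] * w_q[n, k]
    acc = (x_i32.unsqueeze(1) * w_i32.unsqueeze(0)).sum(dim=-1)
    # weight_sums[n] = sum_k w_q[n, k], which make_q8ta_linear_custom_op prepacks above
    weight_sums = w_i32.sum(dim=-1)
    # Zero-point correction: sum_k (x_q - x_zp) * w_q == acc - x_zp * weight_sums,
    # so the activation zero point never has to be subtracted inside the K loop.
    corrected = acc - x_zp * weight_sums
    out = corrected.to(torch.float32) * (x_scale * w_scales)
    if bias is not None:
        out = out + bias
    return out

The real op then requantizes `out` with `output_scale` / `output_zero_point` (the shader folds this in via `output_inv_scale`), which is why the pattern is only rewritten when `has_output_quantization()` is true.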
@register_pattern_replacement("quantized_linear") def replace_quantized_linear_patterns( ep: ExportedProgram, @@ -472,11 +547,20 @@ def replace_quantized_linear_patterns( weight_zeros_tensor = get_param_tensor(ep, match.weight_zeros_node) assert weight_zeros_tensor is not None - # Biases not supported at the moment + # Route to appropriate custom op. + # q8ta_linear supports bias, so check it first before the bias guard. + if ( + match.is_input_static_per_tensor_quantized() + and match.is_weight_perchannel_quantized() + and match.has_output_quantization() + ): + make_q8ta_linear_custom_op(ep, graph_module, match, weight_tensor) + return + + # Remaining ops do not support bias if match.bias_node is not None: return - # Route to appropriate custom op if ( match.is_weight_only_quantized() and match.is_weight_pergroup_quantized() diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear.glsl new file mode 100644 index 00000000000..87a3d539297 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear.glsl @@ -0,0 +1,160 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +${define_required_extensions("buffer", DTYPE)} + +#extension GL_EXT_control_flow_attributes : require +#extension GL_EXT_integer_dot_product : require + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, "buffer")} +#define T int + +#define PACKED_INT8_INPUT_BUFFER +$if WEIGHT_STORAGE == "buffer": + #define WEIGHT_BUFFER + +#define TILE_M4 ${TILE_M4} +#define TILE_K4 ${TILE_K4} +#define TILE_N4 ${TILE_N4} + +#define TILE_M ${TILE_M4 * 4} +#define TILE_K ${TILE_K4 * 4} +#define TILE_N ${TILE_N4 * 4} + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} + +${layout_declare_spec_const(C, "int", "apply_bias", "0")} +${layout_declare_spec_const(C, "int", "activation_type", "0")} + +${layout_declare_ubo(B, "ivec4", "output_sizes")} +${layout_declare_ubo(B, "ivec4", "input_sizes")} + +layout(push_constant) uniform restrict Block { + float input_scale; + int input_zp; + float output_inv_scale; + int output_zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "common.glslh" +#include "linear_int8_input_tile_load.glslh" +#include "linear_int8_weight_tile_load.glslh" +#include "linear_fp_output_tile_int8_int8_compute.glslh" +#include "linear_fp_weight_scales_load.glslh" +#include "linear_int_weight_sums_load.glslh" +#include "linear_fp_bias_load.glslh" + +void main() { + const int out_tile_x = int(gl_GlobalInvocationID.x); + const int out_tile_y = int(gl_GlobalInvocationID.y); + + const int n = out_tile_x * TILE_N; + const int m = out_tile_y * TILE_M; + + const int n4 = div_4(n); + const int m4 = div_4(m); + + if (n >= output_sizes.x || m >= output_sizes.y) { + return; 
+ } + + const int M = output_sizes.y; + const int K4 = div_up_4(input_sizes.x); + const int N4 = div_up_4(output_sizes.x); + + Int32Accum out_accum; + initialize(out_accum); + + Int8InputTile int8_in_tile; + Int8WeightTile int8_weight_tile; + + for (int k4 = 0; k4 < K4; k4 += TILE_K4) { + load_int8_input_tile(int8_in_tile, k4, m4, K4); + load_int8_weight_tile(int8_weight_tile, n4, k4, N4); + + int_accumulate_with_int8_weight(out_accum, int8_in_tile, int8_weight_tile); + } + + FPPerOutChannelParams weight_scales_tile; + load_weight_scales_tile(weight_scales_tile, n4); + + IntPerOutChannelParams weight_sums_tile; + load_weight_sums_tile(weight_sums_tile, n4); + + FPOutTile out_tile; + initialize(out_tile); + + if (apply_bias > 0) { + FPPerOutChannelParams bias_tile; + load_bias_tile(bias_tile, n4); + + accumulate_out_tile_with_int_accum( + out_tile, + out_accum, + input_scale, + input_zp, + weight_sums_tile, + weight_scales_tile, + bias_tile); + } + else { + accumulate_out_tile_with_int_accum( + out_tile, + out_accum, + input_scale, + input_zp, + weight_sums_tile, + weight_scales_tile); + } + + // Apply ReLU if enabled + if (activation_type > 0) { + [[unroll]] for (int tile_m = 0; tile_m < TILE_M; ++tile_m) { + [[unroll]] for (int tile_n4 = 0; tile_n4 < TILE_N4; ++tile_n4) { + out_tile.data[tile_m][tile_n4] = max(out_tile.data[tile_m][tile_n4], vec4(0.0)); + } + } + } + + // Quantize float output tile to int8 and write in PACKED_INT8_4H4W format + const int M4 = div_up_4(M); + + [[unroll]] for (int tile_m4 = 0; tile_m4 < TILE_M4; ++tile_m4) { + if (m4 + tile_m4 >= M4) { + break; + } + [[unroll]] for (int tile_n4 = 0; tile_n4 < TILE_N4; ++tile_n4) { + if (n4 + tile_n4 >= N4) { + break; + } + ivec4 packed_block; + [[unroll]] for (int i = 0; i < 4; ++i) { + const int tile_m = tile_m4 * 4 + i; + if (m + tile_m < M) { + packed_block[i] = quantize_and_pack( + out_tile.data[tile_m][tile_n4], output_inv_scale, output_zp); + } else { + packed_block[i] = 0; + } + } + t_packed_int8_output[(m4 + tile_m4) * N4 + n4 + tile_n4] = packed_block; + } + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear.yaml b/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear.yaml new file mode 100644 index 00000000000..c7836c60477 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_linear.yaml @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +q8ta_linear: + parameter_names_with_default_values: + DTYPE: float + WEIGHT_STORAGE: texture2d + TILE_M4: 1 + TILE_N4: 2 + TILE_K4: 1 + generate_variant_forall: + DTYPE: + - VALUE: float + shader_variants: + - NAME: q8ta_linear diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taLinear.cpp new file mode 100644 index 00000000000..45366fbf044 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taLinear.cpp @@ -0,0 +1,207 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include + +#include +#include +#include +#include +#include + +namespace vkcompute { + +bool q8ta_linear_check_packed_dim_info(const api::PackedDimInfo& info) { + return info.packed_dim == WHCN::kWidthDim && + info.packed_dim_block_size == 4 && + info.outer_packed_dim == WHCN::kHeightDim && + info.outer_packed_dim_block_size == 4; +} + +// +// Workgroup size selection +// + +utils::uvec3 q8ta_linear_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + + const ValueRef out = args.at(0).refs.at(0); + + std::vector out_sizes = graph->sizes_of(out); + const uint32_t N = utils::val_at(-1, out_sizes); + const uint32_t M = utils::val_at(-2, out_sizes); + + // Each output tile contains 8 columns (TILE_N4=2 -> 8 output channels) + const uint32_t N_per_tile = 8; + const uint32_t M_per_tile = 4; + + const uint32_t num_N_tiles = utils::div_up(N, N_per_tile); + const uint32_t num_M_tiles = utils::div_up(M, M_per_tile); + + return {num_N_tiles, num_M_tiles, 1}; +} + +utils::uvec3 q8ta_linear_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + return pick_hw_square_wg_size( + graph, shader, global_workgroup_size, args, resize_args); +} + +// +// Dispatch node +// + +void add_q8ta_linear_node( + ComputeGraph& graph, + const ValueRef packed_int8_input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef packed_weight, + const ValueRef packed_weight_sums, + const ValueRef packed_weight_scales, + const ValueRef output_scale, + const ValueRef output_zp, + const ValueRef bias_data, + const ValueRef packed_bias, + const uint32_t activation_type, + const ValueRef packed_int8_output) { + // Validate packed dim info matches 4H4W layout + VK_CHECK_COND(q8ta_linear_check_packed_dim_info( + graph.packed_dim_info_of(packed_int8_input))); + VK_CHECK_COND(q8ta_linear_check_packed_dim_info( + graph.packed_dim_info_of(packed_int8_output))); + + float input_scale_val = graph.extract_scalar(input_scale); + int32_t input_zp_val = graph.extract_scalar(input_zp); + + float output_inv_scale_val = 1.0f / graph.extract_scalar(output_scale); + int32_t output_zp_val = graph.extract_scalar(output_zp); + + uint32_t apply_bias = 1; + if (graph.val_is_none(bias_data)) { + apply_bias = 0; + } + + std::string kernel_name = "q8ta_linear"; + add_dtype_suffix(kernel_name, graph.dtype_of(packed_weight_scales)); + + vkapi::ParamsBindList param_buffers = { + graph.sizes_ubo(packed_int8_output), graph.sizes_ubo(packed_int8_input)}; + + std::vector push_constants = { + PushConstantDataInfo(&input_scale_val, sizeof(input_scale_val)), + PushConstantDataInfo(&input_zp_val, sizeof(input_zp_val)), + PushConstantDataInfo(&output_inv_scale_val, sizeof(output_inv_scale_val)), + PushConstantDataInfo(&output_zp_val, sizeof(output_zp_val)), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + q8ta_linear_global_wg_size, + q8ta_linear_local_wg_size, + // Inputs and Outputs + {{packed_int8_output, vkapi::kWrite}, + {{packed_int8_input, + packed_weight, + packed_weight_sums, + packed_weight_scales, + packed_bias}, + vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + {apply_bias, activation_type}, + // Resize args + {}, + // 
Resizing Logic + nullptr)); +} + +// +// High level operator impl +// + +void q8ta_linear(ComputeGraph& graph, const std::vector& args) { + int32_t idx = 0; + const ValueRef packed_int8_input = args.at(idx++); + const ValueRef input_scale = args.at(idx++); + const ValueRef input_zp = args.at(idx++); + const ValueRef weight_data = args.at(idx++); + const ValueRef weight_sums_data = args.at(idx++); + const ValueRef weight_scales_data = args.at(idx++); + const ValueRef output_scale = args.at(idx++); + const ValueRef output_zp = args.at(idx++); + const ValueRef bias_data = args.at(idx++); + const ValueRef activation = args.at(idx++); + const ValueRef packed_int8_output = args.at(idx++); + + const int64_t K = graph.size_at(-1, packed_int8_input); + VK_CHECK_COND(K % 4 == 0); + + QuantizationConfig weight_quant_config(8, kPerChannel, {K}); + + // Prepack weight data + const ValueRef packed_weight = + prepack_quantized_linear_weight(graph, weight_quant_config, weight_data); + const ValueRef packed_weight_scales = prepack_standard( + graph, weight_scales_data, utils::kBuffer, utils::kWidthPacked); + const ValueRef packed_weight_sums = prepack_standard( + graph, weight_sums_data, utils::kBuffer, utils::kWidthPacked); + + // Prepack bias data + TmpTensor dummy_bias( + &graph, + {}, + graph.dtype_of(packed_weight_scales), + utils::kBuffer, + utils::kWidthPacked); + + ValueRef packed_bias = dummy_bias.vref; + if (graph.val_is_not_none(bias_data)) { + packed_bias = + prepack_standard(graph, bias_data, utils::kBuffer, utils::kWidthPacked); + } + + uint32_t activation_type_val = static_cast( + activation_type_from_string(graph.extract_string(activation))); + + add_q8ta_linear_node( + graph, + packed_int8_input, + input_scale, + input_zp, + packed_weight, + packed_weight_sums, + packed_weight_scales, + output_scale, + output_zp, + bias_data, + packed_bias, + activation_type_val, + packed_int8_output); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(et_vk.q8ta_linear.default, q8ta_linear); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taLinear.h b/backends/vulkan/runtime/graph/ops/impl/Q8taLinear.h new file mode 100644 index 00000000000..9f975525324 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taLinear.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +namespace vkcompute { + +void add_q8ta_linear_node( + ComputeGraph& graph, + const ValueRef packed_int8_input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef packed_weight, + const ValueRef packed_weight_sums, + const ValueRef packed_weight_scales, + const ValueRef output_scale, + const ValueRef output_zp, + const ValueRef bias_data, + const ValueRef packed_bias, + const uint32_t activation_type, + const ValueRef packed_int8_output); + +} // namespace vkcompute diff --git a/backends/vulkan/test/TARGETS b/backends/vulkan/test/TARGETS index ee9021768b6..7517f7d66f3 100644 --- a/backends/vulkan/test/TARGETS +++ b/backends/vulkan/test/TARGETS @@ -35,6 +35,7 @@ python_unittest( "//caffe2:torch", "//executorch/backends/vulkan/_passes:vulkan_passes", "//executorch/backends/vulkan:vulkan_preprocess", + "//executorch/backends/xnnpack/quantizer:xnnpack_quantizer", "//pytorch/ao:torchao", # @manual ] ) diff --git a/backends/vulkan/test/custom_ops/impl/TestQ8taLinear.cpp b/backends/vulkan/test/custom_ops/impl/TestQ8taLinear.cpp new file mode 100644 index 00000000000..d0803fe746b --- /dev/null +++ b/backends/vulkan/test/custom_ops/impl/TestQ8taLinear.cpp @@ -0,0 +1,76 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +namespace vkcompute { + +void test_q8ta_linear(ComputeGraph& graph, const std::vector& args) { + int32_t idx = 0; + const ValueRef fp_input = args.at(idx++); + const ValueRef input_scale = args.at(idx++); + const ValueRef input_zp = args.at(idx++); + const ValueRef weight_data = args.at(idx++); + const ValueRef weight_sums_data = args.at(idx++); + const ValueRef weight_scales_data = args.at(idx++); + const ValueRef output_scale = args.at(idx++); + const ValueRef output_zp = args.at(idx++); + const ValueRef bias_data = args.at(idx++); + const ValueRef activation = args.at(idx++); + const ValueRef fp_output = args.at(idx++); + + // Create temporary packed int8 tensors for input and output + // Input uses 4H4W layout to match the linear shader's ivec4 reading pattern + // where each ivec4 contains data from 4 rows + TmpTensor packed_int8_input( + &graph, + graph.sizes_of(fp_input), + vkapi::kInt8x4, + utils::kBuffer, + utils::kPackedInt8_4H4W); + + // Output uses 4H4W layout to match the linear shader's ivec4 writing pattern + TmpTensor packed_int8_output( + &graph, + graph.sizes_of(fp_output), + vkapi::kInt8x4, + utils::kBuffer, + utils::kPackedInt8_4H4W); + + // Quantize floating point input to packed int8 + add_q8ta_quantize_node( + graph, fp_input, input_scale, input_zp, packed_int8_input); + + // Call the q8ta_linear operator + std::vector linear_args = { + packed_int8_input, + input_scale, + input_zp, + weight_data, + weight_sums_data, + weight_scales_data, + output_scale, + output_zp, + bias_data, + activation, + packed_int8_output}; + VK_GET_OP_FN("et_vk.q8ta_linear.default")(graph, linear_args); + + // Dequantize packed int8 output to floating point + add_q8ta_dequantize_node( + graph, packed_int8_output, output_scale, output_zp, fp_output); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(test_etvk.test_q8ta_linear.default, test_q8ta_linear); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/targets.bzl b/backends/vulkan/test/custom_ops/targets.bzl index 
73b1e343bbe..badba5666fa 100644 --- a/backends/vulkan/test/custom_ops/targets.bzl +++ b/backends/vulkan/test/custom_ops/targets.bzl @@ -97,3 +97,4 @@ def define_common_targets(is_fbcode = False): define_custom_op_test_binary("test_q8ta_conv2d") define_custom_op_test_binary("test_q8ta_conv2d_pw") define_custom_op_test_binary("test_q8ta_conv2d_dw") + define_custom_op_test_binary("test_q8ta_linear") diff --git a/backends/vulkan/test/custom_ops/test_q8ta_linear.cpp b/backends/vulkan/test/custom_ops/test_q8ta_linear.cpp new file mode 100644 index 00000000000..faec638059c --- /dev/null +++ b/backends/vulkan/test/custom_ops/test_q8ta_linear.cpp @@ -0,0 +1,335 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include + +#include +#include + +#include + +#include "utils.h" + +using namespace executorch::vulkan::prototyping; + +using namespace vkcompute; + +static constexpr int64_t kRefDimSizeLimit = 300; + +struct LinearConfig { + int64_t M; + int64_t K; + int64_t N; + bool has_bias = true; + std::string test_case_name = "placeholder"; +}; + +static TestCase create_test_case_from_config( + const LinearConfig& config, + vkapi::ScalarType input_dtype) { + TestCase test_case; + + std::string dtype_str = (input_dtype == vkapi::kFloat) ? "Float" : "Half"; + + std::string test_name = config.test_case_name + "_Buffer_" + dtype_str; + test_case.set_name(test_name); + + test_case.set_operator_name("test_etvk.test_q8ta_linear.default"); + + std::vector input_size = {config.M, config.K}; + std::vector weight_size = {config.N, config.K}; + + // Input tensor (float) - [M, K] + ValueSpec input_tensor( + input_size, + input_dtype, + utils::kBuffer, + utils::kWidthPacked, + DataGenType::RANDOM); + + if (debugging()) { + print_valuespec_data(input_tensor, "input_tensor"); + } + + float input_scale_val = 0.008f; + ValueSpec input_scale(input_scale_val); + + int32_t input_zero_point_val = -2; + ValueSpec input_zero_point(input_zero_point_val); + + // Quantized weight tensor (int8) - [N, K] + ValueSpec quantized_weight( + weight_size, + vkapi::kChar, + utils::kBuffer, + utils::kWidthPacked, + DataGenType::RANDINT8); + quantized_weight.set_constant(true); + + if (debugging()) { + print_valuespec_data(quantized_weight, "weight_tensor"); + } + + // Weight quantization scales (float, per-channel) + ValueSpec weight_scales( + {config.N}, + input_dtype, + utils::kBuffer, + utils::kWidthPacked, + DataGenType::RANDOM_SCALES); + weight_scales.set_constant(true); + + ValueSpec weight_sums( + {config.N}, + vkapi::kInt, + utils::kBuffer, + utils::kWidthPacked, + DataGenType::ZEROS); + weight_sums.set_constant(true); + + // Compute weight_sums data based on quantized weights + compute_weight_sums(weight_sums, quantized_weight, config.N, config.K); + + // Output quantization parameters + float output_scale_val = 0.05314f; + ValueSpec output_scale(output_scale_val); + + int32_t output_zero_point_val = -1; + ValueSpec output_zero_point(output_zero_point_val); + + // Bias (optional, float) - [N] + ValueSpec bias( + {config.N}, + input_dtype, + utils::kBuffer, + utils::kWidthPacked, + DataGenType::RANDOM); + bias.set_constant(true); + if (!config.has_bias) { + bias.set_none(true); + } + + // Output tensor (float) - [M, N] + ValueSpec output( + {config.M, config.N}, + input_dtype, + utils::kBuffer, + utils::kWidthPacked, + DataGenType::ZEROS); + + // Add 
all specs to test case + test_case.add_input_spec(input_tensor); + test_case.add_input_spec(input_scale); + test_case.add_input_spec(input_zero_point); + test_case.add_input_spec(quantized_weight); + test_case.add_input_spec(weight_sums); + test_case.add_input_spec(weight_scales); + test_case.add_input_spec(output_scale); + test_case.add_input_spec(output_zero_point); + test_case.add_input_spec(bias); + + // Activation (none = no activation) + ValueSpec activation = ValueSpec::make_string("none"); + test_case.add_input_spec(activation); + + test_case.add_output_spec(output); + + test_case.set_abs_tolerance(output_scale_val + 1e-4f); + + // Filter out quantize/dequantize shaders from timing measurements + test_case.set_shader_filter({ + "nchw_to", + "to_nchw", + "q8ta_quantize", + "q8ta_dequantize", + }); + + return test_case; +} + +// Generate test cases for q8ta_linear operation +static std::vector generate_q8ta_linear_test_cases() { + std::vector test_cases; + if (!vkcompute::api::context()->adapter_ptr()->supports_int8_dot_product()) { + return test_cases; + } + + std::vector configs = { + {4, 64, 32}, + {4, 128, 64}, + {4, 256, 128}, + {32, 64, 32}, + {32, 128, 64}, + {32, 256, 128}, + // No bias tests + {32, 128, 64, false}, + {32, 256, 128, false}, + // Performance cases + {256, 2048, 2048}, + {512, 2048, 2048}, + {1024, 2048, 2048}, + }; + + for (auto config : configs) { + bool is_performance = config.M >= kRefDimSizeLimit || + config.K >= kRefDimSizeLimit || config.N >= kRefDimSizeLimit; + + std::string prefix = is_performance ? "performance_" : "correctness_"; + std::string generated_test_case_name = prefix + std::to_string(config.M) + + "_" + std::to_string(config.K) + "_" + std::to_string(config.N); + if (!config.has_bias) { + generated_test_case_name += "_no_bias"; + } + + config.test_case_name = generated_test_case_name; + + test_cases.push_back(create_test_case_from_config(config, vkapi::kFloat)); + } + + return test_cases; +} + +// Reference implementation for q8ta_linear (activation+weight+output quantized) +static void q8ta_linear_reference_impl(TestCase& test_case) { + int32_t idx = 0; + const ValueSpec& input_spec = test_case.inputs()[idx++]; + const ValueSpec& input_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& input_zeros_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_sums_spec = test_case.inputs()[idx++]; + (void)weight_sums_spec; + const ValueSpec& weight_scales_spec = test_case.inputs()[idx++]; + const ValueSpec& output_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& output_zeros_spec = test_case.inputs()[idx++]; + const ValueSpec& bias_spec = test_case.inputs()[idx++]; + + ValueSpec& output_spec = test_case.outputs()[0]; + + auto input_sizes = input_spec.get_tensor_sizes(); + auto weight_sizes = weight_spec.get_tensor_sizes(); + + int64_t batch_size = input_sizes[0]; + int64_t in_features = input_sizes[1]; + int64_t out_features = weight_sizes[0]; + + if (batch_size > kRefDimSizeLimit || in_features > kRefDimSizeLimit || + out_features > kRefDimSizeLimit) { + throw std::invalid_argument( + "One or more dimensions exceed the allowed limit for reference implementation."); + } + + if (input_spec.dtype != vkapi::kFloat) { + throw std::invalid_argument("Unsupported dtype"); + } + + auto& input_data = input_spec.get_float_data(); + const float input_scale = input_scale_spec.get_float_value(); + const int32_t input_zero_point = input_zeros_spec.get_int_value(); + + auto& 
weight_data = weight_spec.get_int8_data(); + auto& weight_scales_data = weight_scales_spec.get_float_data(); + auto& bias_data = bias_spec.get_float_data(); + + const float output_scale = output_scale_spec.get_float_value(); + const int32_t output_zero_point = output_zeros_spec.get_int_value(); + + int64_t num_output_elements = batch_size * out_features; + + auto& ref_data = output_spec.get_ref_float_data(); + ref_data.resize(num_output_elements); + + for (int64_t b = 0; b < batch_size; ++b) { + for (int64_t out_f = 0; out_f < out_features; ++out_f) { + int32_t int_sum = 0; + int32_t weight_sum = 0; + + for (int64_t in_f = 0; in_f < in_features; ++in_f) { + int64_t input_idx = b * in_features + in_f; + + float quant_input_f = + std::round(input_data[input_idx] / input_scale) + input_zero_point; + quant_input_f = std::min(std::max(quant_input_f, -128.0f), 127.0f); + int8_t quantized_input = static_cast(quant_input_f); + + int64_t weight_idx = out_f * in_features + in_f; + int8_t quantized_weight = weight_data[weight_idx]; + + int_sum += static_cast(quantized_input) * + static_cast(quantized_weight); + + weight_sum += static_cast(quantized_weight); + } + + int32_t zero_point_correction = input_zero_point * weight_sum; + int32_t accum_adjusted = int_sum - zero_point_correction; + + float float_result = + accum_adjusted * input_scale * weight_scales_data[out_f]; + + if (!bias_spec.is_none()) { + float_result += bias_data[out_f]; + } + + // Quantize the output to int8 + float quant_output_f = + std::round(float_result / output_scale) + output_zero_point; + quant_output_f = std::min(std::max(quant_output_f, -128.0f), 127.0f); + int8_t quantized_output = static_cast(quant_output_f); + + // Dequantize back to float (this is what the test wrapper does) + float dequant_output = + (static_cast(quantized_output) - output_zero_point) * + output_scale; + + int64_t output_idx = b * out_features + out_f; + ref_data[output_idx] = dequant_output; + } + } +} + +static void reference_impl(TestCase& test_case) { + q8ta_linear_reference_impl(test_case); +} + +static int64_t q8ta_linear_flop_calculator(const TestCase& test_case) { + const auto& input_sizes = test_case.inputs()[0].get_tensor_sizes(); + const auto& weight_sizes = test_case.inputs()[3].get_tensor_sizes(); + + int64_t batch_size = input_sizes[0]; + int64_t in_features = input_sizes[1]; + int64_t out_features = weight_sizes[0]; + + int64_t output_elements = batch_size * out_features; + int64_t ops_per_output = in_features; + + int64_t flop = output_elements * ops_per_output; + + return flop; +} + +int main(int argc, char* argv[]) { + set_debugging(false); + set_print_output(false); + set_print_latencies(false); + set_use_gpu_timestamps(true); + + print_performance_header(); + std::cout << "Q8ta Linear Operation Prototyping Framework" << std::endl; + print_separator(); + + ReferenceComputeFunc ref_fn = reference_impl; + + auto results = execute_test_cases( + generate_q8ta_linear_test_cases, + q8ta_linear_flop_calculator, + "Q8taLinear", + 3, + 10, + ref_fn); + + return 0; +} diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index 2c0bc12b7cc..7c9f31b720c 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -2364,6 +2364,7 @@ def apply_quantization(self): quantized_linear_module_gemm, sample_inputs_gemm, atol=1e-2, rtol=1e-2 ) + @unittest.skip("Cannot run on swiftshader due to no integer dot product support") def 
test_vulkan_backend_xnnpack_pt2e_quantized_linear_sequence(self): """ Test a sequence of linear layers quantized with XNNPACK quantization config. diff --git a/backends/vulkan/test/test_vulkan_passes.py b/backends/vulkan/test/test_vulkan_passes.py index 438126a179f..bbab1535954 100644 --- a/backends/vulkan/test/test_vulkan_passes.py +++ b/backends/vulkan/test/test_vulkan_passes.py @@ -191,3 +191,49 @@ def _reshape_for_broadcast(self, freqs_cis: torch.Tensor, x: torch.Tensor): # We expect at least one custom op to be created self.assertGreater(custom_op_count, 0) + + def test_fuse_q8ta_linear(self): + """Test that sequential quantized linears fuse into q8ta_linear when output quantization is present.""" + from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( + get_symmetric_quantization_config, + XNNPACKQuantizer, + ) + + class TwoLinearModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(128, 64, bias=False) + self.linear2 = torch.nn.Linear(64, 32, bias=False) + + def forward(self, x): + return self.linear2(self.linear1(x)) + + model = TwoLinearModule() + sample_inputs = (torch.randn(4, 128),) + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config( + is_per_channel=True, + is_dynamic=False, + ) + quantizer.set_global(operator_config) + + edge_program = quantize_and_lower_module(model, sample_inputs, quantizer) + + ep = edge_program._edge_programs["forward"] + fuse_pass = FusePatternsPass() + fuse_pass._exported_program = ep + result = fuse_pass.call(ep.graph_module) + + self.assertTrue(result.modified) + + gm = ep.graph_module + + # The first linear should fuse to q8ta_linear (has output quantization + # from the second linear's input quantize node) + q8ta_linear_count = op_node_count(gm, "q8ta_linear.default") + self.assertGreaterEqual( + q8ta_linear_count, + 1, + "Expected at least one q8ta_linear op from output-quantized linear fusion", + ) diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index b276ffd16f5..db1211883c7 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -164,6 +164,7 @@ def preprocess( # noqa: C901 [ AddmmToLinearTransform(), FuseBatchNormPass(program), + AddmmToLinearTransform(), FusePatternsPass(), FuseClampPass(), RemoveRedundantOpsTransform(),
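For context, a minimal end-to-end sketch of producing the q -> dq -> linear -> q graphs that this fusion targets, mirroring test_fuse_q8ta_linear above. The PT2E capture entry point (torch.export.export_for_training) is an assumption and may differ across torch versions; the final lowering step is only described in the trailing comment:

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)


class TwoLinearModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(128, 64, bias=False)
        self.linear2 = torch.nn.Linear(64, 32, bias=False)

    def forward(self, x):
        return self.linear2(self.linear1(x))


model = TwoLinearModule().eval()
sample_inputs = (torch.randn(4, 128),)

quantizer = XNNPACKQuantizer()
quantizer.set_global(
    get_symmetric_quantization_config(is_per_channel=True, is_dynamic=False)
)

# Capture, calibrate, and convert to a q/dq reference graph.
captured = torch.export.export_for_training(model, sample_inputs).module()
prepared = prepare_pt2e(captured, quantizer)
prepared(*sample_inputs)
converted = convert_pt2e(prepared)

# Lowering `converted` through the Vulkan delegate (the quantize_and_lower_module
# helper in the test above does this) should then rewrite at least the first linear,
# whose output feeds the second linear's quantize node, into et_vk.q8ta_linear.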