diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py index f2e4482c9b9..e371338e904 100644 --- a/backends/vulkan/custom_ops_lib.py +++ b/backends/vulkan/custom_ops_lib.py @@ -376,6 +376,7 @@ def q8ta_conv2d( padding: list, dilation: list, groups: int, + activation: str, ): x = torch.ops.quantized_decomposed.dequantize_per_tensor( x, input_scale, input_zero_point, -128, 127, x.dtype @@ -418,6 +419,9 @@ def q8ta_conv2d( x, weights, bias, stride, padding, dilation, groups ) + if activation == "relu": + out = torch.nn.functional.relu(out) + out = torch.ops.quantized_decomposed.quantize_per_tensor( out, output_scale, output_zero_point, -128, 127, torch.int8 ) @@ -442,7 +446,8 @@ def q8ta_conv2d( SymInt[] stride, SymInt[] padding, SymInt[] dilation, - SymInt groups) -> Tensor + SymInt groups, + str activation) -> Tensor """ ) lib.impl(name, q8ta_conv2d, "CompositeExplicitAutograd") @@ -466,7 +471,8 @@ def q8ta_conv2d( SymInt[] stride, SymInt[] padding, SymInt[] dilation, - SymInt groups) -> Tensor + SymInt groups, + str activation) -> Tensor """ ) lib.impl(name, q8ta_conv2d, "CompositeExplicitAutograd") @@ -488,6 +494,7 @@ def q8ta_conv2d_dw( padding: list, dilation: list, groups: int, + activation: str, ): x = torch.ops.quantized_decomposed.dequantize_per_tensor( x, input_scale, input_zero_point, -128, 127, x.dtype @@ -514,6 +521,9 @@ def q8ta_conv2d_dw( x, weights, bias, stride, padding, dilation, groups ) + if activation == "relu": + out = torch.nn.functional.relu(out) + out = torch.ops.quantized_decomposed.quantize_per_tensor( out, output_scale, output_zero_point, -128, 127, torch.int8 ) @@ -538,7 +548,8 @@ def q8ta_conv2d_dw( SymInt[] stride, SymInt[] padding, SymInt[] dilation, - SymInt groups) -> Tensor + SymInt groups, + str activation) -> Tensor """ ) lib.impl(name, q8ta_conv2d_dw, "CompositeExplicitAutograd") @@ -605,6 +616,41 @@ def q8ta_add_impl( lib.impl(name, q8ta_add_impl, "CompositeExplicitAutograd") q8ta_add_op = getattr(getattr(torch.ops, namespace), name) +######################## +## q8ta_relu ## +######################## + + +def q8ta_relu_impl( + input: torch.Tensor, + input_scale: float, + input_zero_point: int, + output_scale: float, + output_zero_point: int, +): + # Dequantize input to float + dequant = torch.ops.quantized_decomposed.dequantize_per_tensor( + input, input_scale, input_zero_point, -128, 127, input.dtype + ) + + # Apply ReLU + result = torch.nn.functional.relu(dequant) + + # Quantize the result back to int8 + quantized_result = torch.ops.quantized_decomposed.quantize_per_tensor( + result, output_scale, output_zero_point, -128, 127, torch.int8 + ) + + return quantized_result + + +name = "q8ta_relu" +lib.define( + f"{name}(Tensor input, float input_scale, int input_zero_point, float output_scale, int output_zero_point) -> Tensor" +) +lib.impl(name, q8ta_relu_impl, "CompositeExplicitAutograd") +q8ta_relu_op = getattr(getattr(torch.ops, namespace), name) + ############################# ## select_as_symint ## ############################# diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 55a92335bc7..853ba5d3777 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -514,7 +514,19 @@ def register_q8ta_add(): # ============================================================================= -# Reduce.cpp +# Q8taUnary.cpp +# ============================================================================= + + +@update_features(exir_ops.edge.et_vk.q8ta_relu.default) +def register_q8ta_relu(): + return OpFeatures( + inputs_storage=utils.PACKED_INT8_BUFFER, + supports_resize=True, + ) + + +# ============================================================================= # ============================================================================= @@ -1221,25 +1233,11 @@ def register_embedding(): @update_features(exir_ops.edge.aten._native_batch_norm_legit_no_training.default) def register_native_batch_norm_legit_no_training(): - def check_batch_norm_node(node: torch.fx.Node) -> bool: - x = node.args[0] - if not isinstance(x, torch.fx.Node): - return False - x_val = x.meta.get("val", None) - if x_val is None: - return False - x_shape = x_val.size() - # Only support 4-D input tensors since this is a restriction enforced by the - # operator implementation. - # TODO(ssjia): Add shape agnostic support for batch norm - return len(x_shape) == 4 - return OpFeatures( inputs_storage=utils.CHANNELS_PACKED_TEXTURE, inputs_dtypes=utils.FP_T, supports_prepacking=True, supports_resize=True, - are_node_inputs_supported_fn=check_batch_norm_node, ) diff --git a/backends/vulkan/patterns/BUCK b/backends/vulkan/patterns/BUCK index a7153b30967..711000f74ca 100644 --- a/backends/vulkan/patterns/BUCK +++ b/backends/vulkan/patterns/BUCK @@ -13,6 +13,7 @@ fbcode_target(_kind = runtime.python_library, "quantized_linear.py", "quantized_convolution.py", "quantized_binary.py", + "quantized_unary.py", "sdpa.py", "select_as_symint.py", ], diff --git a/backends/vulkan/patterns/__init__.py b/backends/vulkan/patterns/__init__.py index 9b875def944..050680b024d 100644 --- a/backends/vulkan/patterns/__init__.py +++ b/backends/vulkan/patterns/__init__.py @@ -12,6 +12,8 @@ import executorch.backends.vulkan.patterns.quantized_linear # noqa +import executorch.backends.vulkan.patterns.quantized_unary # noqa + import executorch.backends.vulkan.patterns.rope # noqa import executorch.backends.vulkan.patterns.sdpa # noqa diff --git a/backends/vulkan/patterns/quantized_convolution.py b/backends/vulkan/patterns/quantized_convolution.py index 93140e15341..12ebbd1a382 100644 --- a/backends/vulkan/patterns/quantized_convolution.py +++ b/backends/vulkan/patterns/quantized_convolution.py @@ -226,6 +226,16 @@ def make_q8ta_conv2d_custom_op( sum_per_output_channel = ( weight_tensor.sum(dim=1).to(torch.int32).contiguous() ) + # Pad weight sums to align OC to multiple of 4, matching the alignment + # applied to weight, weight_scales, and bias above. Without this, the + # GPU shader would read out-of-bounds when OC is not a multiple of 4. + oc = sum_per_output_channel.shape[0] + if oc % 4 != 0: + num_padding = 4 - (oc % 4) + sum_per_output_channel = torch.nn.functional.pad( + sum_per_output_channel, (0, num_padding) + ).contiguous() + sums_name = qweight_tensor_name + "_sums" # Sanitize the name sums_name = sums_name.replace(".", "_") @@ -271,6 +281,7 @@ def make_q8ta_conv2d_custom_op( match.padding, match.dilation, match.groups, + "relu" if match.relu_node is not None else "none", ), ) diff --git a/backends/vulkan/patterns/quantized_unary.py b/backends/vulkan/patterns/quantized_unary.py new file mode 100644 index 00000000000..28dc84b7997 --- /dev/null +++ b/backends/vulkan/patterns/quantized_unary.py @@ -0,0 +1,121 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import executorch.backends.vulkan.utils as utils + +import torch + +from executorch.backends.vulkan.patterns.pattern_registry import ( + PatternMatch, + register_pattern_detector, + register_pattern_replacement, +) + +from executorch.exir import ExportedProgram +from executorch.exir.dialects._ops import ops as exir_ops + + +class QuantizedUnaryMatch(PatternMatch): + def __init__(self, unary_node: torch.fx.Node) -> None: + self.anchor_node = unary_node + self.match_found = False + self.all_nodes = [self.anchor_node] + + # The unary op takes a single input which must be a dequantize node + if len(unary_node.args) < 1: + return + + input_node = unary_node.args[0] + assert isinstance(input_node, torch.fx.Node) + + if not utils.is_dequant_node(input_node): + return + + self.dequantize_input_node = input_node + + # Extract quantization parameters for the input + self.quantize_input_node = self.dequantize_input_node.args[0] + self.input_scales_node = self.dequantize_input_node.args[1] + self.input_zeros_node = self.dequantize_input_node.args[2] + + self.all_nodes.append(self.dequantize_input_node) + + # The unary op output must have exactly one user: a quantize node + self.output_node = self.anchor_node + + if len(self.output_node.users) != 1: + return + + cur_node = list(self.output_node.users)[0] + + if not utils.is_quant_node(cur_node): + return + + self.quantize_output_node = cur_node + self.output_scales_node = self.quantize_output_node.args[1] + self.output_zeros_node = self.quantize_output_node.args[2] + + self.all_nodes.append(self.quantize_output_node) + + self.match_found = True + + +# Unary operation anchor nodes that we support +unary_anchor_nodes = { + exir_ops.edge.aten.relu.default, +} + + +@register_pattern_detector("quantized_unary") +def find_quantized_unary_patterns( + node: torch.fx.Node, +) -> Optional[QuantizedUnaryMatch]: + if node.target not in unary_anchor_nodes: + return None + + matched_pattern = QuantizedUnaryMatch(node) + if matched_pattern.match_found: + return matched_pattern + + return None + + +## +## Pattern Replacement +## + + +@register_pattern_replacement("quantized_unary") +def make_q8ta_unary_custom_op( + ep: ExportedProgram, + graph_module: torch.fx.GraphModule, + match: QuantizedUnaryMatch, +): + op_target = None + if match.anchor_node.target == exir_ops.edge.aten.relu.default: + op_target = exir_ops.edge.et_vk.q8ta_relu.default + else: + raise NotImplementedError( + f"Unsupported unary operation: {match.anchor_node.target}" + ) + + with graph_module.graph.inserting_before(match.output_node): + qunary_node = graph_module.graph.create_node( + "call_function", + op_target, + args=( + match.quantize_input_node, + match.input_scales_node, + match.input_zeros_node, + match.output_scales_node, + match.output_zeros_node, + ), + ) + + qunary_node.meta["val"] = match.output_node.meta["val"] + match.quantize_output_node.replace_all_uses_with(qunary_node) diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_binary.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_binary.glsl index 60f437fbdce..be93e800436 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q8ta_binary.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_binary.glsl @@ -46,6 +46,7 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; ${layout_declare_spec_const(C, "int", "out_layout", "CONTIG_LAYOUT_INT")} ${layout_declare_spec_const(C, "int", "in_layout", "CONTIG_LAYOUT_INT")} +${layout_declare_spec_const(C, "int", "other_layout", "CONTIG_LAYOUT_INT")} ${layout_declare_spec_const(C, "int", "block_config", "0")} // Generate loading functions for input buffers @@ -71,7 +72,7 @@ void main() { ivec4 in_block_a = load_int8x4_block_from_t_in_a( in_a_meta, tidx, in_layout, block_outer_dim); ivec4 in_block_b = load_int8x4_block_from_t_in_b( - in_b_meta, tidx, in_layout, block_outer_dim); + in_b_meta, tidx, other_layout, block_outer_dim); ivec4 out_block; diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d.glsl index 623de3a5d9a..d693acbab3f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d.glsl @@ -47,6 +47,7 @@ layout(push_constant) uniform restrict Block { layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; ${layout_declare_spec_const(C, "int", "apply_bias", "1")} +${layout_declare_spec_const(C, "int", "activation_type", "0")} // Layout specialization constants ${layout_declare_spec_const(C, "int", "inp_layout", "CONTIG_LAYOUT_INT")} @@ -220,6 +221,13 @@ void main() { } } + // Apply ReLU if enabled + if (activation_type > 0) { + [[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) { + facc[subtile_w] = max(facc[subtile_w], vec4(0.0)); + } + } + // Compute base output texel index (for subtile_w=0) const int base_outp_texel_idx = tensor4d_idx_to_texel_idx(outp, outp_tidx, outp_layout); const int out_w_stride = int(outp.strides[0][0]); diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_dw.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_dw.glsl index e6be92e7ba1..7f4d03887df 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_dw.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_dw.glsl @@ -44,6 +44,7 @@ layout(push_constant) uniform restrict Block { layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; ${layout_declare_spec_const(C, "int", "apply_bias", "1")} +${layout_declare_spec_const(C, "int", "activation_type", "0")} // Layout specialization constants ${layout_declare_spec_const(C, "int", "inp_layout", "CONTIG_LAYOUT_INT")} @@ -197,6 +198,13 @@ void main() { } } + // Apply ReLU if enabled + if (activation_type > 0) { + [[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) { + facc[subtile_w] = max(facc[subtile_w], vec4(0.0)); + } + } + // Compute base output texel index (for subtile_w=0) const int base_outp_texel_idx = tensor4d_idx_to_texel_idx(outp, outp_tidx, outp_layout); const int out_w_stride = int(outp.strides[0][0]); diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.glsl index e0963dfcf48..ec41d933114 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.glsl @@ -57,6 +57,7 @@ layout(push_constant) uniform restrict Block { layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; ${layout_declare_spec_const(C, "int", "apply_bias", "1")} +${layout_declare_spec_const(C, "int", "activation_type", "0")} ${layout_declare_spec_const(C, "int", "conv2d_params_K4_per_group", "1")} // Layout specialization constants @@ -197,6 +198,10 @@ void main() { fma(vec4(accum_adjusted), vec4(weight_scales[n4]) * input_scale, vec4(bias[n4])); + // Apply ReLU if enabled + if (activation_type > 0) { + float_out_texel = max(float_out_texel, vec4(0.0)); + } // Requantize to int8 float_out_texel = round(float_out_texel * output_inv_scale) + output_zp; @@ -216,6 +221,10 @@ void main() { input_zp_vec * weight_sums[n4] + out_accum[m][n4]; vec4 float_out_texel = vec4(accum_adjusted) * vec4(weight_scales[n4] * input_scale); + // Apply ReLU if enabled + if (activation_type > 0) { + float_out_texel = max(float_out_texel, vec4(0.0)); + } // Requantize to int8 float_out_texel = round(float_out_texel * output_inv_scale) + output_zp; diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_unary.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_unary.glsl new file mode 100644 index 00000000000..e97d6d47877 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_unary.glsl @@ -0,0 +1,82 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +${define_active_storage_type("buffer")} + +#define op(X) ${OPERATOR} + +layout(std430) buffer; + +#include "indexing.glslh" +#include "common.glslh" +#include "block_indexing.glslh" +#include "block_int8x4_load.glslh" +#include "block_int8x4_store.glslh" + +// Output buffer: packed int8x4 values +${layout_declare_tensor(B, "w", "t_out", "int", "buffer")} +// Input buffer: packed int8x4 values +${layout_declare_tensor(B, "r", "t_in", "int", "buffer")} + +// Metadata for output and input tensors +${layout_declare_ubo(B, "BufferMetadata", "out_meta")} +${layout_declare_ubo(B, "BufferMetadata", "in_meta")} + +layout(push_constant) uniform restrict Block { + float input_scale; + int input_zp; + float output_inv_scale; + int output_zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "out_layout", "CONTIG_LAYOUT_INT")} +${layout_declare_spec_const(C, "int", "in_layout", "CONTIG_LAYOUT_INT")} +${layout_declare_spec_const(C, "int", "block_config", "0")} + +// Generate loading functions for input buffer +define_load_int8x4_buffer_fns(t_in) + +// Generate storing functions for output buffer +define_store_int8x4_buffer_fns(t_out) + +void main() { + // Buffer storage: use linear dispatch + const uint contig_block_idx = gl_GlobalInvocationID.x; + TensorIndex4D tidx = contiguous_block_idx_to_tensor4d_idx_with_block_config( + out_meta, contig_block_idx, block_config); + + if (out_of_bounds(tidx, out_meta)) { + return; + } + + const int block_outer_dim = get_block_outer_dim(block_config); + + // Load int8x4 block from input + ivec4 in_block = load_int8x4_block_from_t_in( + in_meta, tidx, in_layout, block_outer_dim); + + ivec4 out_block; + + for (int row = 0; row < 4; row++) { + vec4 in_texel = unpack_and_dequantize( + in_block[row], input_scale, input_zp); + + vec4 out_texel = op(in_texel); + out_block[row] = quantize_and_pack(out_texel, output_inv_scale, output_zp); + } + + // Store to output buffer + store_int8x4_block_to_t_out( + out_meta, tidx, out_layout, block_outer_dim, out_block); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_unary.yaml b/backends/vulkan/runtime/graph/ops/glsl/q8ta_unary.yaml new file mode 100644 index 00000000000..257f6a44205 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_unary.yaml @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +q8ta_unary: + parameter_names_with_default_values: + OPERATOR: X + shader_variants: + - NAME: q8ta_relu_buffer + OPERATOR: max(X, vec4(0.0)) diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taBinary.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taBinary.cpp index af934b9b521..05bdd9431c8 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Q8taBinary.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taBinary.cpp @@ -42,6 +42,7 @@ void add_q8ta_binary_node( VK_CHECK_COND(input_a_info.packed_dim == output_info.packed_dim); VK_CHECK_COND(input_b_info.packed_dim == output_info.packed_dim); + VK_CHECK_COND( input_a_info.packed_dim_block_size == output_info.packed_dim_block_size); VK_CHECK_COND( @@ -105,6 +106,7 @@ void add_q8ta_binary_node( // Specialization Constants {graph.hashed_layout_of(packed_int8_output), graph.hashed_layout_of(packed_int8_input_a), + graph.hashed_layout_of(packed_int8_input_b), block_config.as_packed_int()}, // Resize args {block_config_ref}, diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp index 4f047d414f8..33b7005a845 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp @@ -17,6 +17,15 @@ namespace vkcompute { +ActivationType activation_type_from_string(const std::string& activation) { + if (activation == "none") { + return ActivationType::kNone; + } else if (activation == "relu") { + return ActivationType::kRelu; + } + VK_THROW("Unknown activation type: ", activation); +} + bool q8ta_conv2d_check_packed_dim_info(const api::PackedDimInfo& info) { return info.packed_dim == WHCN::kChannelsDim && info.packed_dim_block_size == 4 && @@ -231,6 +240,7 @@ void add_q8ta_conv2d_node( const ValueRef padding, const ValueRef dilation, const ValueRef groups, + const uint32_t activation_type, const ValueRef packed_int8_output) { (void)packed_int8_input_im2col; // Not used in general shader @@ -288,9 +298,10 @@ void add_q8ta_conv2d_node( graph.buffer_meta_ubo(packed_int8_input), graph.create_params_buffer(conv_params)}; - // Build spec constants: apply_bias + layout constants + // Build spec constants: apply_bias, apply_relu + layout constants vkapi::SpecVarList spec_constants = { apply_bias, + activation_type, // Layout specialization constants graph.hashed_layout_of(packed_int8_input), graph.hashed_layout_of(packed_int8_output), @@ -341,8 +352,12 @@ void q8ta_conv2d_general( const ValueRef padding = args.at(idx++); const ValueRef dilation = args.at(idx++); const ValueRef groups = args.at(idx++); + const ValueRef activation = args.at(idx++); const ValueRef packed_int8_output = args.at(idx++); + uint32_t activation_type_val = static_cast( + activation_type_from_string(graph.extract_string(activation))); + QuantizationConfig weight_quant_config(8, kPerChannel, {}); // Prepack weight using the conv2d weight packing for the general shader @@ -397,6 +412,7 @@ void q8ta_conv2d_general( padding, dilation, groups, + activation_type_val, packed_int8_output); } diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h index 9686c873c1b..2779a7445a8 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h @@ -13,6 +13,13 @@ namespace vkcompute { +enum class ActivationType : uint32_t { + kNone = 0, + kRelu = 1, +}; + +ActivationType activation_type_from_string(const std::string& activation); + bool q8ta_conv2d_check_packed_dim_info(const api::PackedDimInfo& info); bool q8ta_conv2d_check_4w4c_packed_dim_info(const api::PackedDimInfo& info); @@ -58,6 +65,7 @@ void add_q8ta_conv2d_dw_node( const ValueRef padding, const ValueRef dilation, const ValueRef groups, + const uint32_t activation_type, const ValueRef packed_int8_output); void add_conv2d_dw_q8ta_q8csw_q8to_4w4c_node( @@ -97,6 +105,7 @@ void add_q8ta_conv2d_node( const ValueRef padding, const ValueRef dilation, const ValueRef groups, + const uint32_t activation_type, const ValueRef packed_int8_output); void add_q8ta_conv2d_pw_node( @@ -111,6 +120,7 @@ void add_q8ta_conv2d_pw_node( const ValueRef output_zp, const ValueRef bias_data, const ValueRef packed_bias, + const uint32_t activation_type, const ValueRef packed_int8_output); void q8ta_conv2d_im2col(ComputeGraph& graph, const std::vector& args); diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dDW.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dDW.cpp index d12bbc0574a..e690ff435a8 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dDW.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dDW.cpp @@ -281,6 +281,7 @@ void add_q8ta_conv2d_dw_node( const ValueRef padding, const ValueRef dilation, const ValueRef groups, + const uint32_t activation_type, const ValueRef packed_int8_output) { Conv2DParams conv_params = create_conv2d_params( graph, @@ -334,9 +335,10 @@ void add_q8ta_conv2d_dw_node( graph.buffer_meta_ubo(packed_int8_input), graph.create_params_buffer(conv_params)}; - // Build spec constants: apply_bias + layout constants + // Build spec constants: apply_bias, activation_type + layout constants vkapi::SpecVarList spec_constants = { apply_bias, + activation_type, // Layout specialization constants graph.hashed_layout_of(packed_int8_input), graph.hashed_layout_of(packed_int8_output), @@ -385,8 +387,12 @@ void q8ta_conv2d_dw(ComputeGraph& graph, const std::vector& args) { const ValueRef padding = args.at(idx++); const ValueRef dilation = args.at(idx++); const ValueRef groups = args.at(idx++); + const ValueRef activation = args.at(idx++); const ValueRef packed_int8_output = args.at(idx++); + uint32_t activation_type_val = static_cast( + activation_type_from_string(graph.extract_string(activation))); + QuantizationConfig weight_quant_config(8, kPerChannel, {}); // Prepack weight using depthwise-specific packing @@ -432,6 +438,7 @@ void q8ta_conv2d_dw(ComputeGraph& graph, const std::vector& args) { padding, dilation, groups, + activation_type_val, packed_int8_output); } diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dIm2Col.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dIm2Col.cpp index e89ebc92aba..161b5e8fc24 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dIm2Col.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dIm2Col.cpp @@ -197,6 +197,7 @@ void q8ta_conv2d_im2col( const ValueRef padding = args.at(idx++); const ValueRef dilation = args.at(idx++); const ValueRef groups = args.at(idx++); + const ValueRef activation = args.at(idx++); const ValueRef packed_int8_output = args.at(idx++); QuantizationConfig weight_quant_config(8, kPerChannel, {}); @@ -225,6 +226,9 @@ void q8ta_conv2d_im2col( prepack_standard(graph, bias_data, utils::kBuffer, utils::kWidthPacked); } + uint32_t activation_type_val = static_cast( + activation_type_from_string(graph.extract_string(activation))); + // Calculate im2col output sizes std::vector im2col_sizes = calculate_q8ta_im2col_sizes( &graph, packed_int8_input, packed_int8_output, kernel_size, groups); @@ -265,6 +269,7 @@ void q8ta_conv2d_im2col( output_zp, bias_data, packed_bias, + activation_type_val, packed_int8_output); } diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dPW.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dPW.cpp index fc883eefeef..b72f5b78f53 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dPW.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dPW.cpp @@ -199,6 +199,7 @@ void add_q8ta_conv2d_pw_node( const ValueRef output_zp, const ValueRef bias_data, const ValueRef packed_bias, + const uint32_t activation_type, const ValueRef packed_int8_output) { // Validate packed dim info for input and output tensors // To maximize performance, the input tensor must be in 4W4C layout @@ -242,9 +243,10 @@ void add_q8ta_conv2d_pw_node( graph.buffer_meta_ubo(packed_int8_output), graph.buffer_meta_ubo(packed_int8_input)}; - // Build spec constants: apply_bias + layout constants + // Build spec constants: apply_bias, activation_type + layout constants vkapi::SpecVarList spec_constants = { apply_bias, + activation_type, K4_per_group, // Layout specialization constants graph.hashed_layout_of(packed_int8_output), @@ -296,8 +298,12 @@ void q8ta_conv2d_pw(ComputeGraph& graph, const std::vector& args) { (void)args.at(idx++); // padding (void)args.at(idx++); // dilation (void)args.at(idx++); // groups + const ValueRef activation_ref = args.at(idx++); const ValueRef packed_int8_output = args.at(idx++); + uint32_t activation_type_val = static_cast( + activation_type_from_string(graph.extract_string(activation_ref))); + QuantizationConfig weight_quant_config(8, kPerChannel, {}); // Prepack weight using pointwise-specific packing @@ -342,6 +348,7 @@ void q8ta_conv2d_pw(ComputeGraph& graph, const std::vector& args) { output_zp, bias_data, packed_bias, + activation_type_val, packed_int8_output); } diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taUnary.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taUnary.cpp new file mode 100644 index 00000000000..f8b606f3dfa --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taUnary.cpp @@ -0,0 +1,124 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +namespace vkcompute { + +void resize_q8ta_unary_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + (void)extra_args; + const ValueRef out = args.at(0).refs.at(0); + const ValueRef self = args.at(1).refs.at(0); + graph->virtual_resize(out, graph->sizes_of(self)); +} + +// +// Dispatch nodes +// + +void add_q8ta_unary_node( + ComputeGraph& graph, + const ValueRef packed_int8_input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef output_scale, + const ValueRef output_zp, + const ValueRef packed_int8_output, + const std::string& op_name) { + const api::PackedDimInfo& output_info = + graph.packed_dim_info_of(packed_int8_output); + const api::PackedDimInfo& input_info = + graph.packed_dim_info_of(packed_int8_input); + + VK_CHECK_COND(input_info.packed_dim == output_info.packed_dim); + VK_CHECK_COND( + input_info.packed_dim_block_size == output_info.packed_dim_block_size); + + float input_scale_val = graph.extract_scalar(input_scale); + int32_t input_zp_val = graph.extract_scalar(input_zp); + + float output_inv_scale_val = 1.0f / graph.extract_scalar(output_scale); + int32_t output_zp_val = graph.extract_scalar(output_zp); + + std::string kernel_name = "q8ta_" + op_name; + add_storage_type_suffix( + kernel_name, graph.storage_type_of(packed_int8_output)); + + vkapi::ParamsBindList param_buffers; + param_buffers.append(graph.buffer_meta_ubo(packed_int8_output)); + param_buffers.append(graph.buffer_meta_ubo(packed_int8_input)); + + std::vector push_constants = { + PushConstantDataInfo(&input_scale_val, sizeof(input_scale_val)), + PushConstantDataInfo(&input_zp_val, sizeof(input_zp_val)), + PushConstantDataInfo(&output_inv_scale_val, sizeof(output_inv_scale_val)), + PushConstantDataInfo(&output_zp_val, sizeof(output_zp_val)), + }; + + const BlockConfig block_config = + create_block_config_for_tensor(graph, packed_int8_output); + + const ValueRef block_config_ref = + static_cast(block_config.as_packed_int()); + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + pick_linear_global_wg_with_block_config, + pick_square_local_wg_with_block_config, + // Inputs and Outputs + {{packed_int8_output, vkapi::kWrite}, {packed_int8_input, vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + {graph.hashed_layout_of(packed_int8_output), + graph.hashed_layout_of(packed_int8_input), + block_config.as_packed_int()}, + // Resize args + {block_config_ref}, + // Resizing Logic + resize_q8ta_unary_node)); +} + +// +// High level operator impl +// + +void q8ta_relu(ComputeGraph& graph, const std::vector& args) { + int32_t idx = 0; + const ValueRef packed_int8_input = args.at(idx++); + const ValueRef input_scale = args.at(idx++); + const ValueRef input_zp = args.at(idx++); + const ValueRef output_scale = args.at(idx++); + const ValueRef output_zp = args.at(idx++); + const ValueRef packed_int8_output = args.at(idx++); + + add_q8ta_unary_node( + graph, + packed_int8_input, + input_scale, + input_zp, + output_scale, + output_zp, + packed_int8_output, + "relu"); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(et_vk.q8ta_relu.default, q8ta_relu); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taUnary.h b/backends/vulkan/runtime/graph/ops/impl/Q8taUnary.h new file mode 100644 index 00000000000..2b68fa53c22 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taUnary.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace vkcompute { + +// +// Unary operations for int8x4 tensors +// + +void add_q8ta_unary_node( + ComputeGraph& graph, + const ValueRef packed_int8_input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef output_scale, + const ValueRef output_zp, + const ValueRef packed_int8_output, + const std::string& op_name); + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp index 1bfff6f1342..ebc276ee347 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp @@ -894,6 +894,7 @@ void add_conv2d_q8ta_q8csw_q8to_node( padding, dilation, groups, + static_cast(ActivationType::kNone), packed_int8_output); } } diff --git a/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp b/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp index 4fed7461ce6..679ac33d11b 100644 --- a/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp +++ b/backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp @@ -32,6 +32,7 @@ void test_q8ta_conv2d_dw( const ValueRef padding = args.at(idx++); const ValueRef dilation = args.at(idx++); const ValueRef groups = args.at(idx++); + const ValueRef activation = args.at(idx++); const ValueRef layout_int = args.at(idx++); const ValueRef impl_selector_str = args.at(idx++); const ValueRef fp_output = args.at(idx++); @@ -59,29 +60,43 @@ void test_q8ta_conv2d_dw( add_q8ta_quantize_node( graph, fp_input, input_scale, input_zp, packed_int8_input); - // Build args for conv operator - std::vector conv_args = { - packed_int8_input, - input_scale, - input_zp, - weight_data, - weight_sums_data, - weight_scales_data, - output_scale, - output_zp, - bias_data, - kernel_size, - stride, - padding, - dilation, - groups, - packed_int8_output}; - if (impl_selector == "legacy_4w4c") { - // Use the general quantized conv2d operator for legacy path + // Legacy path does not support activation + std::vector conv_args = { + packed_int8_input, + input_scale, + input_zp, + weight_data, + weight_sums_data, + weight_scales_data, + output_scale, + output_zp, + bias_data, + kernel_size, + stride, + padding, + dilation, + groups, + packed_int8_output}; VK_GET_OP_FN("et_vk.conv2d_q8ta_q8csw_q8to.default")(graph, conv_args); } else { - // Use the dedicated depthwise conv2d operator + std::vector conv_args = { + packed_int8_input, + input_scale, + input_zp, + weight_data, + weight_sums_data, + weight_scales_data, + output_scale, + output_zp, + bias_data, + kernel_size, + stride, + padding, + dilation, + groups, + activation, + packed_int8_output}; VK_GET_OP_FN("et_vk.q8ta_conv2d_dw.default")(graph, conv_args); } @@ -106,6 +121,7 @@ void test_q8ta_conv2d(ComputeGraph& graph, const std::vector& args) { const ValueRef padding = args.at(idx++); const ValueRef dilation = args.at(idx++); const ValueRef groups = args.at(idx++); + const ValueRef activation = args.at(idx++); const ValueRef layout_int = args.at(idx++); const ValueRef impl_selector_str = args.at(idx++); const ValueRef fp_output = args.at(idx++); @@ -133,36 +149,50 @@ void test_q8ta_conv2d(ComputeGraph& graph, const std::vector& args) { add_q8ta_quantize_node( graph, fp_input, input_scale, input_zp, packed_int8_input); - // Build args for conv operator - std::vector conv_args = { - packed_int8_input, - input_scale, - input_zp, - weight_data, - weight_sums_data, - weight_scales_data, - output_scale, - output_zp, - bias_data, - kernel_size, - stride, - padding, - dilation, - groups, - packed_int8_output}; - if (impl_selector == "legacy_4w4c") { - // Use the general quantized conv2d operator for legacy path + // Legacy path does not support activation + std::vector conv_args = { + packed_int8_input, + input_scale, + input_zp, + weight_data, + weight_sums_data, + weight_scales_data, + output_scale, + output_zp, + bias_data, + kernel_size, + stride, + padding, + dilation, + groups, + packed_int8_output}; VK_GET_OP_FN("et_vk.conv2d_q8ta_q8csw_q8to.default")(graph, conv_args); - } else if (impl_selector == "im2col") { - // Use the im2col-based conv2d operator - VK_GET_OP_FN("et_vk.q8ta_conv2d_im2col.default")(graph, conv_args); - } else if (impl_selector == "general") { - // Use the general q8ta_conv2d operator (no im2col dispatch) - VK_GET_OP_FN("et_vk.q8ta_conv2d_general.default")(graph, conv_args); } else { - // Use the new general q8ta_conv2d operator - VK_GET_OP_FN("et_vk.q8ta_conv2d.default")(graph, conv_args); + std::vector conv_args = { + packed_int8_input, + input_scale, + input_zp, + weight_data, + weight_sums_data, + weight_scales_data, + output_scale, + output_zp, + bias_data, + kernel_size, + stride, + padding, + dilation, + groups, + activation, + packed_int8_output}; + if (impl_selector == "im2col") { + VK_GET_OP_FN("et_vk.q8ta_conv2d_im2col.default")(graph, conv_args); + } else if (impl_selector == "general") { + VK_GET_OP_FN("et_vk.q8ta_conv2d_general.default")(graph, conv_args); + } else { + VK_GET_OP_FN("et_vk.q8ta_conv2d.default")(graph, conv_args); + } } // Dequantize packed int8 output to floating point @@ -188,6 +218,7 @@ void test_q8ta_conv2d_pw( const ValueRef padding = args.at(idx++); const ValueRef dilation = args.at(idx++); const ValueRef groups = args.at(idx++); + const ValueRef activation = args.at(idx++); const ValueRef layout_int = args.at(idx++); const ValueRef impl_selector_str = args.at(idx++); const ValueRef fp_output = args.at(idx++); @@ -219,27 +250,43 @@ void test_q8ta_conv2d_pw( add_q8ta_quantize_node( graph, fp_input, input_scale, input_zp, packed_int8_input); - // Build args for conv operator - std::vector conv_args = { - packed_int8_input, - input_scale, - input_zp, - weight_data, - weight_sums_data, - weight_scales_data, - output_scale, - output_zp, - bias_data, - kernel_size, - stride, - padding, - dilation, - groups, - packed_int8_output}; - if (impl_selector == "legacy_4w4c") { + // Legacy path does not support activation + std::vector conv_args = { + packed_int8_input, + input_scale, + input_zp, + weight_data, + weight_sums_data, + weight_scales_data, + output_scale, + output_zp, + bias_data, + kernel_size, + stride, + padding, + dilation, + groups, + packed_int8_output}; VK_GET_OP_FN("et_vk.conv2d_q8ta_q8csw_q8to.default")(graph, conv_args); } else { + std::vector conv_args = { + packed_int8_input, + input_scale, + input_zp, + weight_data, + weight_sums_data, + weight_scales_data, + output_scale, + output_zp, + bias_data, + kernel_size, + stride, + padding, + dilation, + groups, + activation, + packed_int8_output}; VK_GET_OP_FN("et_vk.q8ta_conv2d_pw.default")(graph, conv_args); } diff --git a/backends/vulkan/test/custom_ops/impl/TestQ8taUnary.cpp b/backends/vulkan/test/custom_ops/impl/TestQ8taUnary.cpp new file mode 100644 index 00000000000..6212216686f --- /dev/null +++ b/backends/vulkan/test/custom_ops/impl/TestQ8taUnary.cpp @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +namespace vkcompute { + +void q8ta_unary_test(ComputeGraph& graph, const std::vector& args) { + int32_t idx = 0; + const ValueRef fp_input = args.at(idx++); + const ValueRef input_scale = args.at(idx++); + const ValueRef input_zp = args.at(idx++); + const ValueRef output_scale = args.at(idx++); + const ValueRef output_zp = args.at(idx++); + const ValueRef quant_layout_int = args.at(idx++); + const ValueRef fp_output = args.at(idx++); + + int32_t layout_value = graph.extract_scalar(quant_layout_int); + utils::GPUMemoryLayout quant_layout = + static_cast(layout_value); + + // Create temporary tensor for quantized input + TmpTensor packed_int8_input( + &graph, + graph.sizes_of(fp_input), + vkapi::kInt8x4, + utils::kBuffer, + quant_layout); + + // Create temporary tensor for quantized output + TmpTensor packed_int8_output( + &graph, + graph.sizes_of(fp_input), + vkapi::kInt8x4, + utils::kBuffer, + quant_layout); + + // Quantize: FP -> int8x4 + add_q8ta_quantize_node( + graph, fp_input, input_scale, input_zp, packed_int8_input); + + // Unary op: int8x4 -> int8x4 + add_q8ta_unary_node( + graph, + packed_int8_input, + input_scale, + input_zp, + output_scale, + output_zp, + packed_int8_output, + "relu"); + + // Dequantize: int8x4 -> FP + add_q8ta_dequantize_node( + graph, packed_int8_output, output_scale, output_zp, fp_output); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(test_etvk.q8ta_unary_test.default, q8ta_unary_test); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp b/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp index 17dd7a0fc53..bc95cc724f5 100644 --- a/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp +++ b/backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp @@ -178,6 +178,10 @@ static TestCase create_test_case_from_config( test_case.add_input_spec(dilation); test_case.add_input_spec(groups); + // Activation (none = no activation) + ValueSpec activation = ValueSpec::make_string("none"); + test_case.add_input_spec(activation); + // Add memory layout parameter for the quantized tensors ValueSpec layout_int(static_cast(int8_memory_layout)); test_case.add_input_spec(layout_int); @@ -455,6 +459,8 @@ static void conv2d_q8ta_q8csw_q8to_reference_impl(TestCase& test_case) { const ValueSpec& padding_spec = test_case.inputs()[idx++]; const ValueSpec& dilation_spec = test_case.inputs()[idx++]; const ValueSpec& groups_spec = test_case.inputs()[idx++]; + const ValueSpec& activation_spec = test_case.inputs()[idx++]; + (void)activation_spec; // Not used in reference implementation const ValueSpec& layout_spec = test_case.inputs()[idx++]; (void)layout_spec; // Not used in reference implementation const ValueSpec& impl_selector_spec = test_case.inputs()[idx++]; diff --git a/backends/vulkan/test/custom_ops/test_q8ta_conv2d_dw.cpp b/backends/vulkan/test/custom_ops/test_q8ta_conv2d_dw.cpp index 7ef73d49802..0734e444d57 100644 --- a/backends/vulkan/test/custom_ops/test_q8ta_conv2d_dw.cpp +++ b/backends/vulkan/test/custom_ops/test_q8ta_conv2d_dw.cpp @@ -187,6 +187,10 @@ TestCase create_test_case_from_config( test_case.add_input_spec(dilation); test_case.add_input_spec(groups); + // Activation (none = no activation) + ValueSpec activation = ValueSpec::make_string("none"); + test_case.add_input_spec(activation); + // Add memory layout parameter for the quantized tensors ValueSpec layout_int(static_cast(int8_memory_layout)); test_case.add_input_spec(layout_int); diff --git a/backends/vulkan/test/custom_ops/test_q8ta_conv2d_pw.cpp b/backends/vulkan/test/custom_ops/test_q8ta_conv2d_pw.cpp index 51095c649b6..83b9f92fb3a 100644 --- a/backends/vulkan/test/custom_ops/test_q8ta_conv2d_pw.cpp +++ b/backends/vulkan/test/custom_ops/test_q8ta_conv2d_pw.cpp @@ -179,6 +179,10 @@ static TestCase create_test_case_from_config( test_case.add_input_spec(dilation); test_case.add_input_spec(groups); + // Activation (none = no activation) + ValueSpec activation = ValueSpec::make_string("none"); + test_case.add_input_spec(activation); + // Add memory layout parameter for the quantized tensors ValueSpec layout_int(static_cast(int8_memory_layout)); test_case.add_input_spec(layout_int); @@ -210,6 +214,28 @@ static std::vector generate_quantized_conv2d_pw_test_cases() { } std::vector configs = { + // OC < 4 cases to test edge cases with partial output channel blocks + {OutInChannels(1, 16), + InputSize2D(8, 8), + KernelSize(1, 1), + Stride(1, 1), + Padding(0, 0), + Dilation(1, 1), + 1}, + {OutInChannels(2, 16), + InputSize2D(8, 8), + KernelSize(1, 1), + Stride(1, 1), + Padding(0, 0), + Dilation(1, 1), + 1}, + {OutInChannels(3, 16), + InputSize2D(8, 8), + KernelSize(1, 1), + Stride(1, 1), + Padding(0, 0), + Dilation(1, 1), + 1}, // Pointwise convolutions: kernel size 1x1 {OutInChannels(32, 3), InputSize2D(64, 64), @@ -344,6 +370,8 @@ static void conv2d_q8ta_q8csw_q8to_reference_impl(TestCase& test_case) { const ValueSpec& padding_spec = test_case.inputs()[idx++]; const ValueSpec& dilation_spec = test_case.inputs()[idx++]; const ValueSpec& groups_spec = test_case.inputs()[idx++]; + const ValueSpec& activation_spec = test_case.inputs()[idx++]; + (void)activation_spec; // Not used in reference implementation const ValueSpec& layout_spec = test_case.inputs()[idx++]; (void)layout_spec; // Not used in reference implementation const ValueSpec& impl_selector_spec = test_case.inputs()[idx++]; diff --git a/backends/vulkan/test/custom_ops/test_q8ta_unary.cpp b/backends/vulkan/test/custom_ops/test_q8ta_unary.cpp new file mode 100644 index 00000000000..bc184c6c182 --- /dev/null +++ b/backends/vulkan/test/custom_ops/test_q8ta_unary.cpp @@ -0,0 +1,311 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include +#include +#include "utils.h" + +#include + +// #define DEBUG_MODE + +using namespace executorch::vulkan::prototyping; +using namespace vkcompute; + +static constexpr int64_t kRefDimSizeLimit = 512; + +struct Q8taUnaryConfig { + std::vector shape; + std::string test_case_name = "placeholder"; + std::string op_name = "q8ta_unary_test"; +}; + +TestCase create_test_case_from_config( + const Q8taUnaryConfig& config, + utils::StorageType storage_type, + vkapi::ScalarType input_dtype, + utils::GPUMemoryLayout fp_memory_layout, + utils::GPUMemoryLayout quant_layout) { + TestCase test_case; + + std::string shape_str = shape_string(config.shape); + std::string test_name = config.test_case_name + " I=" + shape_str + " " + + repr_str(storage_type, fp_memory_layout) + "->" + + repr_str(utils::kBuffer, quant_layout); + test_case.set_name(test_name); + + std::string operator_name = "test_etvk." + config.op_name + ".default"; + test_case.set_operator_name(operator_name); + + // Input tensor (float) + ValueSpec input_tensor( + config.shape, + input_dtype, + storage_type, + fp_memory_layout, + DataGenType::RANDOM); + + float scale_val = 0.007112; + ValueSpec input_scale(scale_val); + + int32_t zero_point_val = 0; + ValueSpec input_zero_point(zero_point_val); + + // For relu, output scale and zero point can differ from input + float output_scale_val = 0.007112; + ValueSpec output_scale(output_scale_val); + + int32_t output_zp_val = 0; + ValueSpec output_zero_point(output_zp_val); + + int32_t layout_int = static_cast(quant_layout); + ValueSpec layout_spec(layout_int); + + // Output tensor (float) - same shape as input + ValueSpec output_tensor( + config.shape, + input_dtype, + storage_type, + fp_memory_layout, + DataGenType::ZEROS); + + test_case.add_input_spec(input_tensor); + test_case.add_input_spec(input_scale); + test_case.add_input_spec(input_zero_point); + test_case.add_input_spec(output_scale); + test_case.add_input_spec(output_zero_point); + test_case.add_input_spec(layout_spec); + test_case.add_output_spec(output_tensor); + + test_case.set_abs_tolerance(scale_val + 1e-4); + + test_case.set_shader_filter({ + "nchw_to", + "to_nchw", + "q8ta_quantize", + "q8ta_dequantize", + }); + + return test_case; +} + +std::vector generate_q8ta_unary_easy_cases() { + std::vector test_cases; + + Q8taUnaryConfig config = { + {1, 16, 16, 16}, + "ACCU", + }; + + std::vector fp_layouts = { + utils::kWidthPacked, + utils::kChannelsPacked, + }; + + std::vector quant_layouts = { + utils::kPackedInt8_4W, + utils::kPackedInt8_4C, + utils::kPackedInt8_4W4C, + utils::kPackedInt8_4H4W, + utils::kPackedInt8_4C1W, + }; + + std::vector storage_types = {utils::kBuffer}; + std::vector float_types = {vkapi::kFloat}; + + for (const auto& fp_layout : fp_layouts) { + for (const auto& quant_layout : quant_layouts) { + for (const auto& storage_type : storage_types) { + for (const auto& input_dtype : float_types) { + test_cases.push_back(create_test_case_from_config( + config, storage_type, input_dtype, fp_layout, quant_layout)); + } + } + } + } + + return test_cases; +} + +std::vector generate_q8ta_unary_test_cases() { + std::vector test_cases; + + std::vector> shapes = { + {1, 3, 16, 16}, + {1, 8, 32, 32}, + {1, 16, 24, 24}, + {1, 32, 12, 12}, + {1, 1, 64, 64}, + {1, 3, 64, 64}, + {1, 4, 16, 16}, + + {1, 8, 20, 20}, + {1, 16, 14, 14}, + {1, 8, 28, 28}, + + // Odd tensor sizes + {1, 3, 15, 15}, + {1, 13, 31, 31}, + {1, 17, 23, 23}, + + // Larger tensors + {1, 64, 128, 128}, + {1, 32, 64, 64}, + {1, 128, 56, 56}, + {1, 128, 128, 128}, + }; + + std::vector fp_layouts = { + utils::kWidthPacked, + utils::kChannelsPacked, + }; + + std::vector quant_layouts = { + utils::kPackedInt8_4W, + utils::kPackedInt8_4C, + utils::kPackedInt8_4W4C, + utils::kPackedInt8_4H4W, + utils::kPackedInt8_4C1W, + }; + + std::vector storage_types = {utils::kBuffer}; + + for (const auto& shape : shapes) { + std::string prefix = "ACCU"; + for (const auto& dim : shape) { + if (dim > kRefDimSizeLimit) { + prefix = "PERF"; + break; + } + } + + for (const auto& fp_layout : fp_layouts) { + for (const auto& quant_layout : quant_layouts) { + for (const auto& storage_type : storage_types) { + Q8taUnaryConfig config; + config.shape = shape; + config.test_case_name = prefix; + + test_cases.push_back(create_test_case_from_config( + config, storage_type, vkapi::kFloat, fp_layout, quant_layout)); + } + } + } + } + + return test_cases; +} + +// Reference implementation: quantize -> relu -> dequantize +void q8ta_unary_reference_impl(TestCase& test_case) { + int32_t idx = 0; + const ValueSpec& input_spec = test_case.inputs()[idx++]; + const ValueSpec& input_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& input_zp_spec = test_case.inputs()[idx++]; + const ValueSpec& output_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& output_zp_spec = test_case.inputs()[idx++]; + const ValueSpec& layout_spec = test_case.inputs()[idx++]; + (void)layout_spec; + + ValueSpec& output_spec = test_case.outputs()[0]; + + auto input_sizes = input_spec.get_tensor_sizes(); + + int64_t num_elements = 1; + for (const auto& dim : input_sizes) { + num_elements *= dim; + } + + for (const auto& dim : input_sizes) { + if (dim > kRefDimSizeLimit) { + throw std::invalid_argument( + "One or more dimensions exceed the allowed limit for reference " + "implementation."); + } + } + + if (input_spec.dtype != vkapi::kFloat) { + throw std::invalid_argument("Unsupported dtype"); + } + + auto& input_data = input_spec.get_float_data(); + + float input_scale = input_scale_spec.get_float_value(); + int32_t input_zp = input_zp_spec.get_int_value(); + float output_scale = output_scale_spec.get_float_value(); + int32_t output_zp = output_zp_spec.get_int_value(); + int32_t quant_min = -128; + int32_t quant_max = 127; + + auto& ref_data = output_spec.get_ref_float_data(); + ref_data.resize(num_elements); + + for (int64_t i = 0; i < num_elements; ++i) { + float input_val = input_data[i]; + + // Quantize with input scale/zp + float quantized_float = std::round(input_val / input_scale) + input_zp; + quantized_float = std::max(quantized_float, static_cast(quant_min)); + quantized_float = std::min(quantized_float, static_cast(quant_max)); + int32_t quantized_int = static_cast(quantized_float); + + // Dequantize to float + float dequantized = (quantized_int - input_zp) * input_scale; + + // Apply ReLU + float activated = std::max(dequantized, 0.0f); + + // Requantize with output scale/zp + float requantized_float = std::round(activated / output_scale) + output_zp; + requantized_float = + std::max(requantized_float, static_cast(quant_min)); + requantized_float = + std::min(requantized_float, static_cast(quant_max)); + int32_t requantized_int = static_cast(requantized_float); + + // Dequantize back to float for comparison + ref_data[i] = (requantized_int - output_zp) * output_scale; + } +} + +int main(int argc, char* argv[]) { + set_debugging(false); + set_print_output(false); +#ifdef DEBUG_MODE + set_print_latencies(false); +#else + set_print_latencies(false); +#endif + set_use_gpu_timestamps(true); + + print_performance_header(); + std::cout << "Q8TA Unary (ReLU) Operation Prototyping Framework" << std::endl; + print_separator(); + + ReferenceComputeFunc ref_fn = q8ta_unary_reference_impl; + + auto results = execute_test_cases( +#ifdef DEBUG_MODE + generate_q8ta_unary_easy_cases, +#else + generate_q8ta_unary_test_cases, +#endif + "Q8taUnary", +#ifdef DEBUG_MODE + 0, + 1, +#else + 3, + 10, +#endif + ref_fn); + + return 0; +} diff --git a/backends/vulkan/test/custom_ops/utils.cpp b/backends/vulkan/test/custom_ops/utils.cpp index b23c288a58f..2a50e7b5ec1 100644 --- a/backends/vulkan/test/custom_ops/utils.cpp +++ b/backends/vulkan/test/custom_ops/utils.cpp @@ -2064,7 +2064,11 @@ void compute_weight_sums( auto& weight_sums_data = weight_sums.get_int32_data(); auto& quantized_weight_data = quantized_weight.get_int8_data(); - weight_sums_data.resize(out_features); + // Don't resize down - the buffer may be pre-allocated with aligned size. + // Only resize up if needed. + if (weight_sums_data.size() < static_cast(out_features)) { + weight_sums_data.resize(out_features); + } // For each output feature, compute the sum of quantized weights for (int64_t out_f = 0; out_f < out_features; ++out_f) { diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index 3ccbdc8ab85..b276ffd16f5 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -162,10 +162,10 @@ def preprocess( # noqa: C901 program = apply_passes( program, [ + AddmmToLinearTransform(), FuseBatchNormPass(program), FusePatternsPass(), FuseClampPass(), - AddmmToLinearTransform(), RemoveRedundantOpsTransform(), FuseQuantizedOpsTransform(), FoldQDQPass(),