65 changes: 65 additions & 0 deletions backends/vulkan/custom_ops_lib.py
@@ -356,6 +356,71 @@ def linear_q8ta_q8csw(
lib.impl(name, linear_q8ta_q8csw, "CompositeExplicitAutograd")
qa_q8csw_linear = getattr(getattr(torch.ops, namespace), name)

##################
## q8ta_linear ##
##################


def q8ta_linear(
    x: torch.Tensor,
    input_scale: float,
    input_zero_point: int,
    weights: torch.Tensor,
    weight_sums: torch.Tensor,
    weight_scales: torch.Tensor,
    output_scale: float,
    output_zero_point: int,
    bias: Optional[torch.Tensor] = None,
    activation: str = "none",
):
    weight_zeros = torch.zeros_like(weight_scales, dtype=torch.int32)
    weights = torch.ops.quantized_decomposed.dequantize_per_channel(
        weights,
        weight_scales,
        weight_zeros,
        0,
        -127,
        127,
        torch.int8,
    )

    x = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x, input_scale, input_zero_point, -128, 127, x.dtype
    )

    out = torch.nn.functional.linear(x, weights)
    if bias is not None:
        out = out + bias

    if activation == "relu":
        out = torch.nn.functional.relu(out)

    out = torch.ops.quantized_decomposed.quantize_per_tensor(
        out, output_scale, output_zero_point, -128, 127, torch.int8
    )

    return out


name = "q8ta_linear"
lib.define(
    f"""
    {name}(
        Tensor x,
        float input_scale,
        int input_zero_point,
        Tensor weights,
        Tensor weight_sums,
        Tensor weight_scales,
        float output_scale,
        int output_zero_point,
        Tensor? bias = None,
        str activation = "none") -> Tensor
    """
)
lib.impl(name, q8ta_linear, "CompositeExplicitAutograd")
q8ta_linear_op = getattr(getattr(torch.ops, namespace), name)
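For reference, a minimal sketch of exercising the composite CPU fallback once this module has registered the op. The et_vk namespace, shapes, and quantization parameters below are illustrative assumptions, not values from this PR:

import torch

# Hypothetical sizes: 4 rows, 16 input features, 8 output channels.
M, K, N = 4, 16, 8
x_q = torch.randint(-128, 128, (M, K), dtype=torch.int8)
w_q = torch.randint(-127, 128, (N, K), dtype=torch.int8)
w_scales = torch.rand(N, dtype=torch.float32) * 0.01
# Per-output-channel weight sums; unused by the reference impl above,
# presumably consumed only by the Vulkan kernel.
w_sums = w_q.to(torch.int32).sum(dim=1)

out_q = torch.ops.et_vk.q8ta_linear(
    x_q, 0.05, 3, w_q, w_sums, w_scales, 0.1, -5, None, "relu"
)
print(out_q.dtype, out_q.shape)  # torch.int8, torch.Size([4, 8])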

###################
## q8ta_conv2d_* ##
###################
28 changes: 28 additions & 0 deletions backends/vulkan/op_registry.py
@@ -830,6 +830,34 @@ def register_q8ta_conv2d_ops():
)


# =============================================================================
# Q8taLinear.cpp
# =============================================================================


@update_features(exir_ops.edge.et_vk.q8ta_linear.default)
def register_q8ta_linear():
    return OpFeatures(
        inputs_storage=[
            utils.PACKED_INT8_4H4W_BUFFER,  # input
            utils.NO_STORAGE,  # input_scale (non tensor)
            utils.NO_STORAGE,  # input_zero_point (non tensor)
            utils.NO_STORAGE,  # weight (prepacked)
            utils.NO_STORAGE,  # weight_sums (prepacked)
            utils.NO_STORAGE,  # weight_scales (prepacked)
            utils.NO_STORAGE,  # output_scale (non tensor)
            utils.NO_STORAGE,  # output_zero_point (non tensor)
            utils.NO_STORAGE,  # bias (prepacked)
            utils.NO_STORAGE,  # activation (non tensor)
        ],
        outputs_storage=[
            utils.PACKED_INT8_4H4W_BUFFER,
        ],
        supports_resize=False,
        supports_prepacking=True,
    )


# =============================================================================
# SDPA.cpp
# =============================================================================
92 changes: 88 additions & 4 deletions backends/vulkan/patterns/quantized_linear.py
@@ -31,7 +31,7 @@


class QuantizedLinearMatch(PatternMatch):
    def __init__(self, mm_node: torch.fx.Node) -> None:
    def __init__(self, mm_node: torch.fx.Node) -> None:  # noqa: C901
        self.anchor_node = mm_node
        self.match_found = False
        self.all_nodes = [self.anchor_node]
@@ -111,10 +111,17 @@ def __init__(self, mm_node: torch.fx.Node) -> None:
        self.bias_node = None
        if self.anchor_node.target == exir_ops.edge.aten.addmm.default:
            self.bias_node, arg_chain = utils.trace_args_until_placeholder(
                self.anchor_node.args[2]
                self.anchor_node.args[0]
            )
            assert self.bias_node is not None
            self.all_nodes.extend(arg_chain)
        elif self.anchor_node.target == exir_ops.edge.aten.linear.default:
            if len(self.anchor_node.args) > 2 and self.anchor_node.args[2] is not None:
                self.bias_node, arg_chain = utils.trace_args_until_placeholder(
                    self.anchor_node.args[2]
                )
                if self.bias_node is not None:
                    self.all_nodes.extend(arg_chain)

        # If input is not quantized, then we are done
        if self.quantize_input_node is None:
@@ -143,11 +150,36 @@ def __init__(self, mm_node: torch.fx.Node) -> None:
            ]
        )

        # Check if the output is also quantized (q → dq → linear → q pattern)
        # Also handle fused linear+relu (q → dq → linear → relu → q pattern)
        self.quantize_output_node = None
        self.output_scales_node = None
        self.output_zeros_node = None
        self.relu_node = None
        if len(self.output_node.users) == 1:
            cur_node = list(self.output_node.users)[0]
            if cur_node.target == exir_ops.edge.aten.relu.default:
                self.relu_node = cur_node
                if len(cur_node.users) == 1:
                    cur_node = list(cur_node.users)[0]
                else:
                    cur_node = None
            if cur_node is not None and utils.is_quant_node(cur_node):
                self.quantize_output_node = cur_node
                self.output_scales_node = self.quantize_output_node.args[1]
                self.output_zeros_node = self.quantize_output_node.args[2]

        self.match_found = True

    def is_weight_only_quantized(self) -> bool:
        return self.quantize_input_node is None

    def has_output_quantization(self) -> bool:
        return (
            hasattr(self, "quantize_output_node")
            and self.quantize_output_node is not None
        )

    def is_weight_pergroup_quantized(self) -> bool:
        weight_shape = self.weight_node.meta["val"].shape
        scales_shape = self.weight_scales_node.meta["val"].shape
@@ -454,6 +486,49 @@ def make_linear_q8ta_q8csw_custom_op(
    match.output_node.replace_all_uses_with(qlinear_node)


def make_q8ta_linear_custom_op(
    ep: ExportedProgram,
    graph_module: torch.fx.GraphModule,
    match: QuantizedLinearMatch,
    weight_tensor: torch.Tensor,
):
    first_graph_node = list(graph_module.graph.nodes)[0]
    with graph_module.graph.inserting_before(first_graph_node):
        weight_tensor_name = utils.get_tensor_name(ep, match.weight_node)
        sum_per_output_channel = weight_tensor.sum(dim=1).to(torch.int32).contiguous()
        sums_name = weight_tensor_name + "_sums"
        sums_name = sums_name.replace(".", "_")

        weight_sums_node = create_constant_placeholder(
            exp_program=ep,
            graph=graph_module.graph,
            kind=InputKind.PARAMETER,
            name=sums_name,
            data=sum_per_output_channel,
        )

    with graph_module.graph.inserting_before(match.output_node):
        qlinear_node = graph_module.graph.create_node(
            "call_function",
            exir_ops.edge.et_vk.q8ta_linear.default,
            args=(
                match.quantize_input_node,
                match.input_scales_node,
                match.input_zeros_node,
                match.weight_node,
                weight_sums_node,
                match.weight_scales_node,
                match.output_scales_node,
                match.output_zeros_node,
                match.bias_node,
                "relu" if match.relu_node is not None else "none",
            ),
        )

    qlinear_node.meta["val"] = match.quantize_output_node.meta["val"]
    match.quantize_output_node.replace_all_uses_with(qlinear_node)
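A note on the precomputed weight_sums: with symmetric int8 weights (zero point 0) and an affine-quantized input, the raw integer accumulator can be corrected for the input zero point using only the per-output-channel sum of weights, which is presumably why the sums are prepacked here. A small self-contained check of that identity, using made-up values rather than anything from this PR:

import torch

x_q = torch.randint(-128, 128, (16,), dtype=torch.int32)  # one quantized input row
w_q = torch.randint(-127, 128, (16,), dtype=torch.int32)  # one output channel of weights
x_zp, x_scale, w_scale = 3, 0.05, 0.02

# Reference: dequantize the input explicitly, then take the dot product.
ref = ((x_q - x_zp) * w_q).sum() * x_scale * w_scale
# Same result from the raw int accumulator plus a weight-sum correction.
acc = (x_q * w_q).sum()
corrected = (acc - x_zp * w_q.sum()) * x_scale * w_scale
assert torch.allclose(ref, corrected)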


@register_pattern_replacement("quantized_linear")
def replace_quantized_linear_patterns(
    ep: ExportedProgram,
@@ -472,11 +547,20 @@ def replace_quantized_linear_patterns(
    weight_zeros_tensor = get_param_tensor(ep, match.weight_zeros_node)
    assert weight_zeros_tensor is not None

    # Biases not supported at the moment
    # Route to appropriate custom op.
    # q8ta_linear supports bias, so check it first before the bias guard.
    if (
        match.is_input_static_per_tensor_quantized()
        and match.is_weight_perchannel_quantized()
        and match.has_output_quantization()
    ):
        make_q8ta_linear_custom_op(ep, graph_module, match, weight_tensor)
        return

    # Remaining ops do not support bias
    if match.bias_node is not None:
        return

    # Route to appropriate custom op
    if (
        match.is_weight_only_quantized()
        and match.is_weight_pergroup_quantized()
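For context, the q → dq → linear → relu → q chain that the extended matcher and make_q8ta_linear_custom_op now fuse corresponds, at the decomposed-op level, to roughly the following eager computation. This is a sketch only; in practice the graph comes out of the PT2E quantization flow, and the scales and zero points here are placeholders:

import torch

def q8ta_linear_relu_pattern(x_q, w_q, w_scales, bias):
    # Per-channel symmetric weight dequant, per-tensor input dequant,
    # linear (+ bias) with relu, then per-tensor output requant.
    w_zeros = torch.zeros_like(w_scales, dtype=torch.int32)
    w = torch.ops.quantized_decomposed.dequantize_per_channel(
        w_q, w_scales, w_zeros, 0, -127, 127, torch.int8
    )
    x = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x_q, 0.05, 3, -128, 127, torch.int8
    )
    out = torch.nn.functional.relu(torch.nn.functional.linear(x, w, bias))
    return torch.ops.quantized_decomposed.quantize_per_tensor(
        out, 0.1, -5, -128, 127, torch.int8
    )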