未验证 提交 cb6c0e21 编写于 作者: X XGZhang 提交者: GitHub

Fix a quantization bug (#36982)

* Fix a quantization bug: derive `x_num_col_dims` from the output rank for matmul/matmul_v2/mul ops
上级 be4eaba0
...@@ -1292,10 +1292,11 @@ class QuantizationFreezePass(object): ...@@ -1292,10 +1292,11 @@ class QuantizationFreezePass(object):
var_type=output_var_node.type(), var_type=output_var_node.type(),
shape=output_var_node.shape(), shape=output_var_node.shape(),
var_dtype=output_var_node.dtype()) var_dtype=output_var_node.dtype())
x_num_col_dims = 1
if op_node.name() in ['matmul', 'matmul_v2', 'mul']:
x_num_col_dims = len(op_node.outputs[0].shape()) - 1
if op_node.op().has_attr("x_num_col_dims"): if op_node.op().has_attr("x_num_col_dims"):
x_num_col_dims = op_node.op().attr("x_num_col_dims") x_num_col_dims = op_node.op().attr("x_num_col_dims")
else:
x_num_col_dims = 1
dequant_op_node = graph.create_op_node( dequant_op_node = graph.create_op_node(
op_type='fake_channel_wise_dequantize_max_abs', op_type='fake_channel_wise_dequantize_max_abs',
attrs={ attrs={
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册