提交 540935a8 编写于 作者: M Michał Gallus 提交者: Zhen Wang

[Bug-fix][1.6] Improve QAT accuracy (#20174)

* Leave fake quantization around mul

* Replace Fake with Real Quantized Mul

* Gather all scales from fake_quantize_ops

* Enable uint8 in conv_relu tensors

* Disable int8 mul and restore fake mul

* Fix buf for running QAT on VGG16 and 19
上级 a73e1f68
......@@ -321,10 +321,11 @@ class FakeQAT2MkldnnINT8PerfPass(object):
graph = self._gather_scales(graph)
graph = self._remove_fake_ops(graph)
graph = self._update_pooling_scales(graph)
graph = self._dequantize_weights(graph)
graph = self._optimize_fp32_graph(graph)
graph = self._compute_weight_scales(graph)
graph = self._update_conv_relu_scales(graph)
graph = self._update_pooling_scales(graph)
graph = self._quantize_fp32_graph(graph)
graph = self._remove_unused_var_nodes(graph)
return graph
......@@ -350,6 +351,8 @@ class FakeQAT2MkldnnINT8PerfPass(object):
use_unsigned_int = False
self._var_quant_scales[input_name] = (use_unsigned_int,
lod_tensor)
self._var_quant_scales[scale_name.replace(".scale", "")] = (
use_unsigned_int, lod_tensor)
if op.name() in self._fake_dequantize_types:
input_name = op.input("X")[0]
......@@ -378,13 +381,13 @@ class FakeQAT2MkldnnINT8PerfPass(object):
next_op = op_out.outputs[0]
if next_op.name() not in self._mul_ops:
self._remove_fake_quantize(graph, op)
else:
quant_op = self._transform_to_quantize_mkldnn(graph, op)
self._transform_to_mul_mkldnn(graph, next_op, quant_op)
for op in graph.all_op_nodes():
if op.name() in self._fake_dequantize_types:
self._remove_fake_dequantize(graph, op)
op_in = graph._find_node_by_name(op.inputs, op.input("X")[0])
prev_op = op_in.inputs[0]
if prev_op.name() not in self._mul_ops:
self._remove_fake_dequantize(graph, op)
return graph
def _remove_fake_quantize(self, graph, op):
......@@ -530,7 +533,7 @@ class FakeQAT2MkldnnINT8PerfPass(object):
if op.name() in self._pool_ops:
if op.op().attr("pooling_type") == "avg":
ids.append(op.id())
return set(ids)
return set(ids) if len(ids) else set([-1])
def _transform_to_quantize_mkldnn(self, graph, op_node):
"""
......@@ -557,13 +560,16 @@ class FakeQAT2MkldnnINT8PerfPass(object):
graph.safe_remove_nodes(op_node)
return quant_op_node
def _transform_to_mul_mkldnn(self, graph, op_node, quantize_node):
input_name = op_node.input("X")[0]
scale_in = quantize_node.op().attr("Scale")
op_node.set_attr("scale_y", [1.0])
op_node.set_attr("scale_x", scale_in)
op_node.set_attr("scale_out", 1.0)
op_node.set_attr("force_fp32_output", True)
def _update_conv_relu_scales(self, graph):
for op in graph.all_op_nodes():
if op.name() in self._conv_ops:
out_name = op.output("Output")[0]
if out_name in self._var_quant_scales and \
op.op().attr("fuse_activation") == 'relu' and \
op.op().attr("fuse_residual_connection") == False:
_, tensor = self._var_quant_scales[out_name]
self._var_quant_scales[out_name] = (True, tensor)
return graph
def _quantize_fp32_graph(self, graph):
ir_pass = self._core.get_pass('cpu_quantize_placement_pass')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册