提交 540935a8 编写于 作者: M Michał Gallus 提交者: Zhen Wang

[Bug-fix][1.6] Improve QAT accuracy (#20174)

* Leave fake quantization around mul

* Replace Fake with Real Quantized Mul

* Gather all scales from fake_quantize_ops

* Enable uint8 in conv_relu tensors

* Disable int8 mul and restore fake mul

* Fix buf for running QAT on VGG16 and 19
上级 a73e1f68
...@@ -321,10 +321,11 @@ class FakeQAT2MkldnnINT8PerfPass(object): ...@@ -321,10 +321,11 @@ class FakeQAT2MkldnnINT8PerfPass(object):
graph = self._gather_scales(graph) graph = self._gather_scales(graph)
graph = self._remove_fake_ops(graph) graph = self._remove_fake_ops(graph)
graph = self._update_pooling_scales(graph)
graph = self._dequantize_weights(graph) graph = self._dequantize_weights(graph)
graph = self._optimize_fp32_graph(graph) graph = self._optimize_fp32_graph(graph)
graph = self._compute_weight_scales(graph) graph = self._compute_weight_scales(graph)
graph = self._update_conv_relu_scales(graph)
graph = self._update_pooling_scales(graph)
graph = self._quantize_fp32_graph(graph) graph = self._quantize_fp32_graph(graph)
graph = self._remove_unused_var_nodes(graph) graph = self._remove_unused_var_nodes(graph)
return graph return graph
...@@ -350,6 +351,8 @@ class FakeQAT2MkldnnINT8PerfPass(object): ...@@ -350,6 +351,8 @@ class FakeQAT2MkldnnINT8PerfPass(object):
use_unsigned_int = False use_unsigned_int = False
self._var_quant_scales[input_name] = (use_unsigned_int, self._var_quant_scales[input_name] = (use_unsigned_int,
lod_tensor) lod_tensor)
self._var_quant_scales[scale_name.replace(".scale", "")] = (
use_unsigned_int, lod_tensor)
if op.name() in self._fake_dequantize_types: if op.name() in self._fake_dequantize_types:
input_name = op.input("X")[0] input_name = op.input("X")[0]
...@@ -378,12 +381,12 @@ class FakeQAT2MkldnnINT8PerfPass(object): ...@@ -378,12 +381,12 @@ class FakeQAT2MkldnnINT8PerfPass(object):
next_op = op_out.outputs[0] next_op = op_out.outputs[0]
if next_op.name() not in self._mul_ops: if next_op.name() not in self._mul_ops:
self._remove_fake_quantize(graph, op) self._remove_fake_quantize(graph, op)
else:
quant_op = self._transform_to_quantize_mkldnn(graph, op)
self._transform_to_mul_mkldnn(graph, next_op, quant_op)
for op in graph.all_op_nodes(): for op in graph.all_op_nodes():
if op.name() in self._fake_dequantize_types: if op.name() in self._fake_dequantize_types:
op_in = graph._find_node_by_name(op.inputs, op.input("X")[0])
prev_op = op_in.inputs[0]
if prev_op.name() not in self._mul_ops:
self._remove_fake_dequantize(graph, op) self._remove_fake_dequantize(graph, op)
return graph return graph
...@@ -530,7 +533,7 @@ class FakeQAT2MkldnnINT8PerfPass(object): ...@@ -530,7 +533,7 @@ class FakeQAT2MkldnnINT8PerfPass(object):
if op.name() in self._pool_ops: if op.name() in self._pool_ops:
if op.op().attr("pooling_type") == "avg": if op.op().attr("pooling_type") == "avg":
ids.append(op.id()) ids.append(op.id())
return set(ids) return set(ids) if len(ids) else set([-1])
def _transform_to_quantize_mkldnn(self, graph, op_node): def _transform_to_quantize_mkldnn(self, graph, op_node):
""" """
...@@ -557,13 +560,16 @@ class FakeQAT2MkldnnINT8PerfPass(object): ...@@ -557,13 +560,16 @@ class FakeQAT2MkldnnINT8PerfPass(object):
graph.safe_remove_nodes(op_node) graph.safe_remove_nodes(op_node)
return quant_op_node return quant_op_node
def _transform_to_mul_mkldnn(self, graph, op_node, quantize_node): def _update_conv_relu_scales(self, graph):
input_name = op_node.input("X")[0] for op in graph.all_op_nodes():
scale_in = quantize_node.op().attr("Scale") if op.name() in self._conv_ops:
op_node.set_attr("scale_y", [1.0]) out_name = op.output("Output")[0]
op_node.set_attr("scale_x", scale_in) if out_name in self._var_quant_scales and \
op_node.set_attr("scale_out", 1.0) op.op().attr("fuse_activation") == 'relu' and \
op_node.set_attr("force_fp32_output", True) op.op().attr("fuse_residual_connection") == False:
_, tensor = self._var_quant_scales[out_name]
self._var_quant_scales[out_name] = (True, tensor)
return graph
def _quantize_fp32_graph(self, graph): def _quantize_fp32_graph(self, graph):
ir_pass = self._core.get_pass('cpu_quantize_placement_pass') ir_pass = self._core.get_pass('cpu_quantize_placement_pass')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册