提交 9de67725 编写于 作者: B bingyanghuang 提交者: Tao Luo

Follow comment of Merged QAT PR 18970 (#19979)

* Follow Wangzhen's comment in PR 18970, test=develop

* Review comments, test=develop

* Leave fake quantization around mul

test=develop

* Replace Fake with Real Quantized Mul

test=develop

* Fix bug in quantize placement pass

Nodes in the graph are now matched by their op type instead of their node name when they are marked for quantization. test=develop
上级 c92348c3
...@@ -36,7 +36,7 @@ void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const { ...@@ -36,7 +36,7 @@ void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const {
if (op_types_list.empty()) { if (op_types_list.empty()) {
op->SetAttr("use_quantizer", true); op->SetAttr("use_quantizer", true);
} else if (std::find(op_types_list.begin(), op_types_list.end(), } else if (std::find(op_types_list.begin(), op_types_list.end(),
n->Name()) != op_types_list.end()) { op->Type()) != op_types_list.end()) {
op->SetAttr("use_quantizer", true); op->SetAttr("use_quantizer", true);
} }
} }
......
...@@ -375,11 +375,15 @@ class FakeQAT2MkldnnINT8PerfPass(object): ...@@ -375,11 +375,15 @@ class FakeQAT2MkldnnINT8PerfPass(object):
if op.name() in self._fake_quantize_types: if op.name() in self._fake_quantize_types:
op_out = graph._find_node_by_name(op.outputs, op_out = graph._find_node_by_name(op.outputs,
op.output("Out")[0]) op.output("Out")[0])
next_op = op_out.outputs[0]
if next_op.name() not in self._mul_ops:
self._remove_fake_quantize(graph, op) self._remove_fake_quantize(graph, op)
else:
quant_op = self._transform_to_quantize_mkldnn(graph, op)
self._transform_to_mul_mkldnn(graph, next_op, quant_op)
for op in graph.all_op_nodes(): for op in graph.all_op_nodes():
if op.name() in self._fake_dequantize_types: if op.name() in self._fake_dequantize_types:
op_in = graph._find_node_by_name(op.inputs, op.input("X")[0])
self._remove_fake_dequantize(graph, op) self._remove_fake_dequantize(graph, op)
return graph return graph
...@@ -426,8 +430,6 @@ class FakeQAT2MkldnnINT8PerfPass(object): ...@@ -426,8 +430,6 @@ class FakeQAT2MkldnnINT8PerfPass(object):
for op in graph.all_op_nodes(): for op in graph.all_op_nodes():
if op.name() in self._conv_ops: if op.name() in self._conv_ops:
self._dequantize_conv_weights(graph, op) self._dequantize_conv_weights(graph, op)
elif op.name() in self._mul_ops:
self._dequantize_mul_weights(graph, op)
return graph return graph
def _dequantize_conv_weights(self, graph, op_node): def _dequantize_conv_weights(self, graph, op_node):
...@@ -463,22 +465,20 @@ class FakeQAT2MkldnnINT8PerfPass(object): ...@@ -463,22 +465,20 @@ class FakeQAT2MkldnnINT8PerfPass(object):
graph = self._apply_pass(graph, 'conv_elementwise_add_mkldnn_fuse_pass') graph = self._apply_pass(graph, 'conv_elementwise_add_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_relu_mkldnn_fuse_pass') graph = self._apply_pass(graph, 'conv_relu_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_relu6_mkldnn_fuse_pass') graph = self._apply_pass(graph, 'conv_relu6_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'fc_fuse_pass')
return graph return graph
def _apply_pass(self, graph, pass_name, attrs=None, attr_values=None): def _apply_pass(self, graph, pass_name, attrs=None, attr_values=None):
ir_pass = core.get_pass(pass_name) ir_pass = core.get_pass(pass_name)
inference_program = graph.to_program() cpp_graph = graph.graph
ir_graph = core.Graph(inference_program.desc) if not cpp_graph.has('__param_scope__'):
ir_graph.set_not_owned('__param_scope__', self._scope) cpp_graph.set_not_owned('__param_scope__', self._scope)
if attrs: if attrs:
assert attr_values and len(attrs) == len( assert attr_values and len(attrs) == len(
attr_values attr_values
), "Different number of pass attributes and their values." ), "Different number of pass attributes and their values."
for attr, value in zip(attrs, attr_values): for attr, value in zip(attrs, attr_values):
ir_pass.set(attr, value) ir_pass.set(attr, value)
ir_pass.apply(ir_graph) ir_pass.apply(cpp_graph)
graph = IrGraph(ir_graph, for_test=True)
if self._debug: if self._debug:
graph.draw('.', 'qat_fp32_{}'.format(pass_name), graph.draw('.', 'qat_fp32_{}'.format(pass_name),
graph.all_op_nodes()) graph.all_op_nodes())
...@@ -532,15 +532,46 @@ class FakeQAT2MkldnnINT8PerfPass(object): ...@@ -532,15 +532,46 @@ class FakeQAT2MkldnnINT8PerfPass(object):
ids.append(op.id()) ids.append(op.id())
return set(ids) return set(ids)
def _transform_to_quantize_mkldnn(self, graph, op_node):
    """
    Swap a fake_quantize_xx op node for an MKL-DNN `quantize` op node.

    The new op reuses the fake op's "X" input variable node and "Out"
    output variable node, so the surrounding graph wiring is preserved.
    Its `Scale` attribute is `self._s8_max` divided by the first element
    of the parameter named by the fake op's "InScale" input.

    Args:
        graph: graph object providing `_find_node_by_name`,
            `create_op_node`, `link_to` and `safe_remove_nodes`.
        op_node: the fake quantize op node to be replaced.

    Returns:
        The newly created `quantize` op node.
    """
    # Resolve the variable nodes the fake op is attached to.
    src_name = op_node.input("X")[0]
    src_var = graph._find_node_by_name(op_node.inputs, src_name)
    dst_name = op_node.output("Out")[0]
    dst_var = graph._find_node_by_name(op_node.outputs, dst_name)

    # Derive the quantization scale from the stored input scale.
    in_scale_name = op_node.input("InScale")[0]
    scale_in = self._s8_max / self._load_param(self._scope,
                                               in_scale_name)[0]

    quantize_attrs = {
        'data_format': 'MKLDNNLAYOUT',
        'use_mkldnn': 1,
        'Scale': scale_in,
        'is_negative_input': 1
    }
    new_op = graph.create_op_node(
        op_type='quantize',
        attrs=quantize_attrs,
        inputs={'Input': src_var},
        outputs={'Output': dst_var})

    # Wire the new op between the existing variables, then drop the
    # fake op it replaces.
    graph.link_to(src_var, new_op)
    graph.link_to(new_op, dst_var)
    graph.safe_remove_nodes(op_node)
    return new_op
def _transform_to_mul_mkldnn(self, graph, op_node, quantize_node):
    """
    Set the quantization attributes on an op node fed by a quantize op.

    The input scale is copied from the `Scale` attribute of
    `quantize_node`; the weight scale and output scale are fixed to
    1.0 and the op is asked to emit FP32 output.

    Args:
        graph: unused; kept so the signature stays compatible with
            existing callers.
        op_node: the op node to annotate (mutated in place via
            `set_attr`).
        quantize_node: the quantize op node whose `Scale` attribute is
            used as `scale_x`.
    """
    # The scale applied by the preceding quantize op is the scale of
    # this op's X input.
    scale_in = quantize_node.op().attr("Scale")
    op_node.set_attr("scale_y", [1.0])
    op_node.set_attr("scale_x", scale_in)
    op_node.set_attr("scale_out", 1.0)
    op_node.set_attr("force_fp32_output", True)
def _quantize_fp32_graph(self, graph): def _quantize_fp32_graph(self, graph):
ir_pass = self._core.get_pass('cpu_quantize_placement_pass') ir_pass = self._core.get_pass('cpu_quantize_placement_pass')
inference_program = graph.to_program() cpp_graph = graph.graph
ir_graph = self._core.Graph(inference_program.desc)
ir_pass.set('quantize_enabled_op_types', {'conv2d', 'pool2d'}) ir_pass.set('quantize_enabled_op_types', {'conv2d', 'pool2d'})
ir_pass.set('quantize_excluded_op_ids', ir_pass.set('quantize_excluded_op_ids',
self._find_avg_pooling_ids(graph)) self._find_avg_pooling_ids(graph))
ir_pass.apply(ir_graph) ir_pass.apply(cpp_graph)
graph = IrGraph(ir_graph, for_test=True)
if self._debug: if self._debug:
graph.draw('.', 'qat_int8_{}'.format(ir_pass.type()), graph.draw('.', 'qat_int8_{}'.format(ir_pass.type()),
graph.all_op_nodes()) graph.all_op_nodes())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册