未验证 提交 1753860d 编写于 作者: W Wojciech Uss 提交者: GitHub

Enable matmul and cleanup in QAT2 (#23657)

上级 4d0efee4
...@@ -15,10 +15,11 @@ ...@@ -15,10 +15,11 @@
import numpy as np import numpy as np
from .... import core from .... import core
from ....framework import IrGraph from ....framework import IrGraph
from ....framework import IrNode
__all__ = ['Qat2Int8MkldnnPass'] __all__ = ['Qat2Int8MkldnnPass']
OpRole = core.op_proto_and_checker_maker.OpRole
class Qat2Int8MkldnnPass(object): class Qat2Int8MkldnnPass(object):
""" """
...@@ -62,6 +63,7 @@ class Qat2Int8MkldnnPass(object): ...@@ -62,6 +63,7 @@ class Qat2Int8MkldnnPass(object):
self._pool_ops = ['pool2d'] self._pool_ops = ['pool2d']
self._mul_ops = ['mul'] self._mul_ops = ['mul']
self._fc_ops = ['fc'] self._fc_ops = ['fc']
self._matmul_ops = ['matmul']
self._weight_scales = {} self._weight_scales = {}
# Collect the Input and Output sclaes from Fake QAT models # Collect the Input and Output sclaes from Fake QAT models
self._var_quant_scales = {} self._var_quant_scales = {}
...@@ -79,9 +81,9 @@ class Qat2Int8MkldnnPass(object): ...@@ -79,9 +81,9 @@ class Qat2Int8MkldnnPass(object):
graph = self._compute_weight_scales(graph) graph = self._compute_weight_scales(graph)
graph = self._update_relu_output_scales(graph) graph = self._update_relu_output_scales(graph)
graph = self._propagate_scales(graph) graph = self._propagate_scales(graph)
graph = self._set_dummy_fc_out_scales(graph) graph = self._set_dummy_out_scales(graph)
graph = self._quantize_fp32_graph(graph) graph = self._quantize_fp32_graph(graph)
graph = self._remove_unused_var_nodes(graph) graph = self._cleanup(graph)
return graph return graph
def apply_fp32(self, graph): def apply_fp32(self, graph):
...@@ -92,7 +94,7 @@ class Qat2Int8MkldnnPass(object): ...@@ -92,7 +94,7 @@ class Qat2Int8MkldnnPass(object):
graph = self._remove_fake_ops(graph) graph = self._remove_fake_ops(graph)
graph = self._dequantize_weights(graph) graph = self._dequantize_weights(graph)
graph = self._optimize_fp32_graph(graph) graph = self._optimize_fp32_graph(graph)
graph = self._remove_unused_var_nodes(graph) graph = self._cleanup(graph)
return graph return graph
def _convert_scale2tensor(self, scale): def _convert_scale2tensor(self, scale):
...@@ -176,23 +178,29 @@ class Qat2Int8MkldnnPass(object): ...@@ -176,23 +178,29 @@ class Qat2Int8MkldnnPass(object):
return graph return graph
def _set_dummy_fc_out_scales(self, graph): def _set_dummy_out_scales(self, graph):
''' '''
For the output tensors of FC that do not have an assigned scale, For the output tensors of fc, conv2d and matmul ops that do not have an assigned scale,
assign a dummy scale (same scale as input), so that the quantize pass assign a dummy scale (same scale as input), so that the quantize pass
won't fail. In the end these scales aren't used, since FCs that won't fail. In the end these scales aren't used, since the ops that
have an unassigend output scale will have a force_fp32_output attr have an unassigend output scale will have a force_fp32_output attr
set to True. set to True.
''' '''
def _set_scale(op, op_types, input_names, output_name):
scales = self._var_quant_scales
should_set = op.name() in op_types \
and op.output(output_name)[0] not in scales \
and all(op.input(input_name)[0] in scales for input_name in input_names)
if should_set:
output_var_name = op.output(output_name)[0]
input_var_name = op.input(input_names[0])[0]
scales[output_var_name] = scales[input_var_name]
for op in graph.all_op_nodes(): for op in graph.all_op_nodes():
if op.name() in self._fc_ops: _set_scale(op, self._conv_ops, ["Input"], "Output")
input_name = op.input("Input")[0] _set_scale(op, self._fc_ops, ["Input"], "Out")
output_name = op.output("Out")[0] _set_scale(op, self._matmul_ops, ["X", "Y"], "Out")
if input_name in self._var_quant_scales and \
output_name not in self._var_quant_scales:
# use input scale as a "dummy" scale
self._var_quant_scales[
output_name] = self._var_quant_scales[input_name]
return graph return graph
...@@ -358,6 +366,15 @@ class Qat2Int8MkldnnPass(object): ...@@ -358,6 +366,15 @@ class Qat2Int8MkldnnPass(object):
self._remove_unused_var_nodes(graph) self._remove_unused_var_nodes(graph)
return graph return graph
def _cleanup(self, graph):
# remove dropout ops
graph = self._apply_pass(graph, 'simplify_with_basic_ops_pass')
# make some MKL-DNN ops working inplace
graph = self._apply_pass(graph, 'mkldnn_inplace_pass')
graph = self._remove_unused_var_nodes(graph)
graph = self._set_op_role_forward(graph)
return graph
def _remove_unused_var_nodes(self, graph): def _remove_unused_var_nodes(self, graph):
all_used_vars = set() all_used_vars = set()
ops = graph.all_op_nodes() ops = graph.all_op_nodes()
...@@ -376,8 +393,14 @@ class Qat2Int8MkldnnPass(object): ...@@ -376,8 +393,14 @@ class Qat2Int8MkldnnPass(object):
graph.safe_remove_nodes(all_unused_vars) graph.safe_remove_nodes(all_unused_vars)
return graph return graph
def _set_op_role_forward(self, graph):
ops = graph.all_op_nodes()
for op in ops:
op.set_attr("op_role", OpRole.Forward)
return graph
def _compute_weight_scales(self, graph): def _compute_weight_scales(self, graph):
def _compute_var_scales(ops, out_name, w_name, axis): def _compute_var_scales(ops, w_name, axis):
for op in graph.all_op_nodes(): for op in graph.all_op_nodes():
if op.op().type() in ops: if op.op().type() in ops:
weight_var_name = op.input(w_name)[0] weight_var_name = op.input(w_name)[0]
...@@ -394,8 +417,8 @@ class Qat2Int8MkldnnPass(object): ...@@ -394,8 +417,8 @@ class Qat2Int8MkldnnPass(object):
self._var_quant_scales[weight_var_name] = (use_unsigned_int, self._var_quant_scales[weight_var_name] = (use_unsigned_int,
lod_tensor) lod_tensor)
_compute_var_scales(self._conv_ops, "Output", "Filter", axis=1) _compute_var_scales(self._conv_ops, "Filter", axis=1)
_compute_var_scales(self._fc_ops, "Out", "W", axis=0) _compute_var_scales(self._fc_ops, "W", axis=0)
return graph return graph
def _find_avg_pooling_ids(self, graph): def _find_avg_pooling_ids(self, graph):
......
...@@ -231,7 +231,7 @@ if(LINUX AND WITH_MKLDNN) ...@@ -231,7 +231,7 @@ if(LINUX AND WITH_MKLDNN)
### QATv2 for NLP ### QATv2 for NLP
set(QAT2_NLP_QUANTIZED_OPS "fc,reshape2,transpose2") set(QAT2_NLP_QUANTIZED_OPS "fc,reshape2,transpose2,matmul")
set(NLP_DATA_ARCHIVE "Ernie_dataset.tar.gz") set(NLP_DATA_ARCHIVE "Ernie_dataset.tar.gz")
set(NLP_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/Ernie_dataset") set(NLP_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/Ernie_dataset")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册