未验证 提交 2f16e0c6 编写于 作者: C cc 提交者: GitHub

skip quantizing ops in cpu inference (#30342) (#30405)

上级 c07027e0
......@@ -42,6 +42,10 @@ void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const {
return;
}
if (op->Op()->GetAttrIfExists<int>("skip_quant") == 1) {
return;
}
if (op->Op()->HasAttr("mkldnn_data_type") ||
op->Op()->HasProtoAttr("mkldnn_data_type")) {
// use_quantizer is no longer used
......
......@@ -56,7 +56,8 @@ class Quant2Int8MkldnnPass(object):
]
self._fake_quantize_dequantize_types = [
'fake_quantize_dequantize_abs_max',
'fake_quantize_dequantize_moving_average_abs_max'
'fake_quantize_dequantize_moving_average_abs_max',
'fake_channel_wise_quantize_dequantize_abs_max'
]
self._ops_to_quantize = _ops_to_quantize
self._op_ids_to_skip = _op_ids_to_skip if _op_ids_to_skip is not None else set(
......@@ -71,7 +72,7 @@ class Quant2Int8MkldnnPass(object):
self._relu_ops = ['relu', 'relu6']
self._matmul_ops = ['matmul']
self._gru_ops = ['fusion_gru', 'multi_gru']
self._weight_scales = {}
self._weight_thresholds = {}
# Collect the Input and Output sclaes from Fake quant models
self._var_quant_scales = {}
self._max_range = {}
......@@ -84,7 +85,8 @@ class Quant2Int8MkldnnPass(object):
IrGraph), 'graph must be the instance of IrGraph.'
self._reset_pass_idx_and_group('int8')
graph = self._gather_weight_scales_from_fake(graph)
graph = self._label_skip_quantized_op(graph)
graph = self._gather_weight_thresholds_from_fake(graph)
graph = self._gather_output_scales_from_attr(graph)
graph = self._gather_input_scales_from_fake(graph)
graph = self._remove_fake_ops(graph)
......@@ -135,6 +137,30 @@ class Quant2Int8MkldnnPass(object):
def _is_fc_quantized(self, graph):
return self._is_any_of_op_types_quantized(self._fc_ops, graph)
def _label_skip_quantized_op(self, graph):
"""
For some ops(conv2d, depthwise_conv2d, mul, matml), find and label
the skip quantized ops. cpu_quantize_placement_pass will use the
label to identify it.
For static models, the skip quantized ops have `skip_quant` attr.
Therefore, it only needs to find and label the skip quantized ops for
dygraph models, in which the quantized ops don't have `quantization_type`
attr.
"""
target_ops = self._conv_ops + self._mul_ops + self._matmul_ops
for op_node in graph.all_op_nodes():
if op_node.name() in target_ops and \
not op_node.op().has_attr("quantization_type"):
is_quantized_op = True
for var_node in op_node.inputs:
for front_op_node in var_node.inputs:
if "fake_quantize_dequantize_" not in front_op_node.name(
):
is_quantized_op = False
if not is_quantized_op:
op_node.op()._set_attr("skip_quant", True)
return graph
def _gather_input_scales_from_fake(self, graph):
def _add_scale_for_vars(var_names, use_unsigned_int, lod_tensor):
scales = self._var_quant_scales
......@@ -165,19 +191,19 @@ class Quant2Int8MkldnnPass(object):
return graph
def _gather_weight_scales_from_fake(self, graph):
def _gather_weight_thresholds_from_fake(self, graph):
for op in graph.all_op_nodes():
if op.name() in self._fake_dequantize_types:
input_name = op.input("X")[0]
if op.op().has_attr("max_range"):
_max_range = np.array(op.op().attr("max_range")).astype(
np.float64)
self._weight_scales[input_name] = np.array(
self._weight_thresholds[input_name] = np.array(
self._s8_max * self._s8_max /
_max_range).astype(np.float64)
else:
scale_name = op.input("Scales")[0]
self._weight_scales[input_name] = np.array(
self._weight_thresholds[input_name] = np.array(
self._load_param(self._scope, scale_name)).astype(
np.float64)
......@@ -314,7 +340,7 @@ class Quant2Int8MkldnnPass(object):
weight_var_name = op_node.input(weight_name)[0]
output_var_name = op_node.output(output_name)[0]
# Convert int8 range weights to fp32 range weights
scales = self._weight_scales[output_var_name]
scales = self._weight_thresholds[output_var_name]
weight = self._load_param(self._scope, weight_var_name)
if scales.size == 1 or scales.size == weight.shape[0]:
w_fp32 = np.multiply(np.divide(weight, self._s8_max).T, scales.T).T
......
......@@ -180,7 +180,7 @@ class TestQuant2Int8MkldnnPass(unittest.TestCase):
_place=self.place,
_core=core,
_debug=False)
qpass._weight_scales["mul_output"] = self.mul_output_scale
qpass._weight_thresholds["mul_output"] = self.mul_output_scale
param = self.scope.var("mul_weights").get_tensor()
param.set(self.variables_mul["mul_weights"], self.place)
qpass._dequantize_op_weights(graph, op_node, "Y", "Out")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册