未验证 提交 814b38e3 编写于 作者: W Wojciech Uss 提交者: GitHub

update scale collection and propagation algorithm (#31783)

上级 513641e1
......@@ -62,9 +62,8 @@ class Quant2Int8MkldnnPass(object):
self._ops_to_quantize = _ops_to_quantize
self._op_ids_to_skip = _op_ids_to_skip if _op_ids_to_skip is not None else set(
[-1])
self._scale_immutable_ops = [
'transpose2', 'reshape2', 'pool2d', 'scale'
]
self._scale_immutable_ops = ['transpose2', 'reshape2', 'pool2d']
self._scale_ops = ['scale']
self._conv_ops = ['conv2d', 'depthwise_conv2d']
self._pool_ops = ['pool2d']
self._mul_ops = ['mul']
......@@ -87,8 +86,8 @@ class Quant2Int8MkldnnPass(object):
self._reset_pass_idx_and_group('int8')
graph = self._label_skip_quantized_op(graph)
graph = self._gather_weight_thresholds_from_fake(graph)
graph = self._gather_output_scales_from_attr(graph)
graph = self._gather_input_scales_from_fake(graph)
graph = self._gather_output_scales_from_attr(graph)
graph = self._remove_fake_ops(graph)
graph = self._dequantize_weights(graph)
graph = self._optimize_fp32_graph(graph)
......@@ -160,12 +159,16 @@ class Quant2Int8MkldnnPass(object):
op_node.op()._set_attr("skip_quant", True)
return graph
def _gather_input_scales_from_fake(self, graph):
def _add_scale_for_vars(var_names, use_unsigned_int, lod_tensor):
scales = self._var_quant_scales
for var_name in var_names:
def _add_scale_for_vars(self, var_names, use_unsigned_int, lod_tensor):
"""
Save quantization scales for variables. Do not overwrite.
"""
scales = self._var_quant_scales
for var_name in var_names:
if var_name not in scales:
scales[var_name] = (use_unsigned_int, lod_tensor)
def _gather_input_scales_from_fake(self, graph):
# fake_quantize_dequantize_abs_max doesn't have scale value
fake_ops = ['fake_quantize_dequantize_moving_average_abs_max']
fake_ops.extend(self._fake_quantize_types)
......@@ -185,8 +188,8 @@ class Quant2Int8MkldnnPass(object):
scale[scale == np.Inf] = 0.0
lod_tensor = self._convert_scale2tensor(scale)
use_unsigned_int = False
_add_scale_for_vars([input_name, output_name], use_unsigned_int,
lod_tensor)
self._add_scale_for_vars([input_name, output_name],
use_unsigned_int, lod_tensor)
return graph
......@@ -219,8 +222,8 @@ class Quant2Int8MkldnnPass(object):
use_unsigned_int = False
for output_name in op.op().outputs():
for out_var_name in op.op().output(output_name):
self._var_quant_scales[out_var_name] = (
use_unsigned_int, scale_lod_tensor)
self._add_scale_for_vars(
[out_var_name], use_unsigned_int, scale_lod_tensor)
return graph
......@@ -239,24 +242,21 @@ class Quant2Int8MkldnnPass(object):
output_name = op.output("Out")[0]
tensor_names = [input_name, output_name]
# Scale is not quantized, so if it doesn't have any scales
# to propagate, its tensors won't be added to the waiting list.
if all(name not in self._var_quant_scales for name in tensor_names) \
and op.name() != 'scale':
if all(name not in self._var_quant_scales
for name in tensor_names):
waiting_for_scale.update(tensor_names)
continue
if input_name in self._var_quant_scales:
elif input_name in self._var_quant_scales:
self._var_quant_scales[
output_name] = self._var_quant_scales[input_name]
elif output_name in self._var_quant_scales:
if op.name() == 'scale':
_update_scale_op_in_scale(op, input_name,
output_name)
else:
self._var_quant_scales[
input_name] = self._var_quant_scales[
output_name]
self._var_quant_scales[
input_name] = self._var_quant_scales[output_name]
elif op.name() in self._scale_ops:
input_name = op.input("X")[0]
output_name = op.output("Out")[0]
if output_name in self._var_quant_scales:
_update_scale_op_in_scale(op, input_name, output_name)
return waiting_for_scale
waiting_for_scale = _update_scales(graph)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册