未验证 提交 814b38e3 编写于 作者: W Wojciech Uss 提交者: GitHub

update scale collection and propagation algorithm (#31783)

上级 513641e1
...@@ -62,9 +62,8 @@ class Quant2Int8MkldnnPass(object): ...@@ -62,9 +62,8 @@ class Quant2Int8MkldnnPass(object):
self._ops_to_quantize = _ops_to_quantize self._ops_to_quantize = _ops_to_quantize
self._op_ids_to_skip = _op_ids_to_skip if _op_ids_to_skip is not None else set( self._op_ids_to_skip = _op_ids_to_skip if _op_ids_to_skip is not None else set(
[-1]) [-1])
self._scale_immutable_ops = [ self._scale_immutable_ops = ['transpose2', 'reshape2', 'pool2d']
'transpose2', 'reshape2', 'pool2d', 'scale' self._scale_ops = ['scale']
]
self._conv_ops = ['conv2d', 'depthwise_conv2d'] self._conv_ops = ['conv2d', 'depthwise_conv2d']
self._pool_ops = ['pool2d'] self._pool_ops = ['pool2d']
self._mul_ops = ['mul'] self._mul_ops = ['mul']
...@@ -87,8 +86,8 @@ class Quant2Int8MkldnnPass(object): ...@@ -87,8 +86,8 @@ class Quant2Int8MkldnnPass(object):
self._reset_pass_idx_and_group('int8') self._reset_pass_idx_and_group('int8')
graph = self._label_skip_quantized_op(graph) graph = self._label_skip_quantized_op(graph)
graph = self._gather_weight_thresholds_from_fake(graph) graph = self._gather_weight_thresholds_from_fake(graph)
graph = self._gather_output_scales_from_attr(graph)
graph = self._gather_input_scales_from_fake(graph) graph = self._gather_input_scales_from_fake(graph)
graph = self._gather_output_scales_from_attr(graph)
graph = self._remove_fake_ops(graph) graph = self._remove_fake_ops(graph)
graph = self._dequantize_weights(graph) graph = self._dequantize_weights(graph)
graph = self._optimize_fp32_graph(graph) graph = self._optimize_fp32_graph(graph)
...@@ -160,12 +159,16 @@ class Quant2Int8MkldnnPass(object): ...@@ -160,12 +159,16 @@ class Quant2Int8MkldnnPass(object):
op_node.op()._set_attr("skip_quant", True) op_node.op()._set_attr("skip_quant", True)
return graph return graph
def _gather_input_scales_from_fake(self, graph): def _add_scale_for_vars(self, var_names, use_unsigned_int, lod_tensor):
def _add_scale_for_vars(var_names, use_unsigned_int, lod_tensor): """
scales = self._var_quant_scales Save quantization scales for variables. Do not overwrite.
for var_name in var_names: """
scales = self._var_quant_scales
for var_name in var_names:
if var_name not in scales:
scales[var_name] = (use_unsigned_int, lod_tensor) scales[var_name] = (use_unsigned_int, lod_tensor)
def _gather_input_scales_from_fake(self, graph):
# fake_quantize_dequantize_abs_max doesn't have scale value # fake_quantize_dequantize_abs_max doesn't have scale value
fake_ops = ['fake_quantize_dequantize_moving_average_abs_max'] fake_ops = ['fake_quantize_dequantize_moving_average_abs_max']
fake_ops.extend(self._fake_quantize_types) fake_ops.extend(self._fake_quantize_types)
...@@ -185,8 +188,8 @@ class Quant2Int8MkldnnPass(object): ...@@ -185,8 +188,8 @@ class Quant2Int8MkldnnPass(object):
scale[scale == np.Inf] = 0.0 scale[scale == np.Inf] = 0.0
lod_tensor = self._convert_scale2tensor(scale) lod_tensor = self._convert_scale2tensor(scale)
use_unsigned_int = False use_unsigned_int = False
_add_scale_for_vars([input_name, output_name], use_unsigned_int, self._add_scale_for_vars([input_name, output_name],
lod_tensor) use_unsigned_int, lod_tensor)
return graph return graph
...@@ -219,8 +222,8 @@ class Quant2Int8MkldnnPass(object): ...@@ -219,8 +222,8 @@ class Quant2Int8MkldnnPass(object):
use_unsigned_int = False use_unsigned_int = False
for output_name in op.op().outputs(): for output_name in op.op().outputs():
for out_var_name in op.op().output(output_name): for out_var_name in op.op().output(output_name):
self._var_quant_scales[out_var_name] = ( self._add_scale_for_vars(
use_unsigned_int, scale_lod_tensor) [out_var_name], use_unsigned_int, scale_lod_tensor)
return graph return graph
...@@ -239,24 +242,21 @@ class Quant2Int8MkldnnPass(object): ...@@ -239,24 +242,21 @@ class Quant2Int8MkldnnPass(object):
output_name = op.output("Out")[0] output_name = op.output("Out")[0]
tensor_names = [input_name, output_name] tensor_names = [input_name, output_name]
# Scale is not quantized, so if it doesn't have any scales if all(name not in self._var_quant_scales
# to propagate, its tensors won't be added to the waiting list. for name in tensor_names):
if all(name not in self._var_quant_scales for name in tensor_names) \
and op.name() != 'scale':
waiting_for_scale.update(tensor_names) waiting_for_scale.update(tensor_names)
continue continue
elif input_name in self._var_quant_scales:
if input_name in self._var_quant_scales:
self._var_quant_scales[ self._var_quant_scales[
output_name] = self._var_quant_scales[input_name] output_name] = self._var_quant_scales[input_name]
elif output_name in self._var_quant_scales: elif output_name in self._var_quant_scales:
if op.name() == 'scale': self._var_quant_scales[
_update_scale_op_in_scale(op, input_name, input_name] = self._var_quant_scales[output_name]
output_name) elif op.name() in self._scale_ops:
else: input_name = op.input("X")[0]
self._var_quant_scales[ output_name = op.output("Out")[0]
input_name] = self._var_quant_scales[ if output_name in self._var_quant_scales:
output_name] _update_scale_op_in_scale(op, input_name, output_name)
return waiting_for_scale return waiting_for_scale
waiting_for_scale = _update_scales(graph) waiting_for_scale = _update_scales(graph)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册