Unverified commit 430f8449, authored by guofei, committed by GitHub

Fix the error of save_quantized_model (#30583)

* Fix the error of save_quantized_model
Parent 10271ddf
@@ -36,7 +36,7 @@ _logger = get_logger(
 _op_real_in_out_name = {
     "conv2d": [["Input", "Filter"], ["Output"]],
-    "conv2d_transpose": [["Input", "Filter"], ["Output"]],
+    "depthwise_conv2d": [["Input", "Filter"], ["Output"]],
     "pool2d": [["X"], ["Out"]],
     "elementwise_add": [["X", "Y"], ["Out"]],
     "softmax": [["X"], ["Out"]],
@@ -329,9 +329,9 @@ class ImperativeCalcOutScale(object):
         super(ImperativeCalcOutScale, self).__init__()
         self._moving_rate = moving_rate
         self._out_scale_layer_type_list = (
-            BatchNorm, BatchNorm1D, BatchNorm2D, BatchNorm3D, Conv2D,
-            Conv2DTranspose, LeakyReLU, Linear, PReLU, Pool2D, MaxPool1D,
-            MaxPool2D, ReLU, ReLU6, Sigmoid, Softmax, Tanh, Swish)
+            BatchNorm, BatchNorm1D, BatchNorm2D, BatchNorm3D, Conv2D, LeakyReLU,
+            Linear, PReLU, Pool2D, MaxPool1D, MaxPool2D, ReLU, ReLU6, Sigmoid,
+            Softmax, Tanh, Swish)
         self._register_hook_handle_list = []
         self._out_scale_dict = collections.OrderedDict()
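
This hunk drops Conv2DTranspose from the layer tuple, matching the removal of "conv2d_transpose" from _op_real_in_out_name in the first hunk. A hedged sketch of the surrounding mechanism, assuming paddle>=2.0: ImperativeCalcOutScale attaches a forward post-hook to every sublayer whose type is in the tuple and keeps the handles for later removal. The function and hook names below are illustrative, not the real implementation:

    import paddle

    def attach_out_scale_hooks(model, layer_types):
        handles = []

        def record_out_scale(layer, inputs, outputs):
            # The real hook updates a moving-average output scale from `outputs`.
            pass

        # Register the hook on every sublayer whose type is in `layer_types`.
        for _, sublayer in model.named_sublayers():
            if isinstance(sublayer, layer_types):
                handles.append(
                    sublayer.register_forward_post_hook(record_out_scale))
        return handles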
@@ -415,9 +415,11 @@ class ImperativeCalcOutScale(object):
         # Traverse all ops in the program and find out the op matching
         # the Layer in the dynamic graph.
-        layer_var_dict = {}
+        layer_var_dict = collections.OrderedDict()
         ops_list = [key for key, _ in self._out_scale_dict.items()]
         op_count = 0
+        conv_count = 0
         for block in inference_program.blocks:
             for op in block.ops:
                 if op.type in _op_real_in_out_name:
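
A hedged sketch of the traversal pattern in this hunk: walk every op in every block of the static inference program and keep only op types that appear in _op_real_in_out_name. Switching layer_var_dict to collections.OrderedDict makes the positional matching iterate in deterministic insertion order on every Python version (plain dicts are only insertion-ordered from CPython 3.7). The generator below is an illustration, not the real code:

    def iter_quantizable_ops(inference_program, op_name_map):
        # `inference_program` is a static-graph (fluid) Program.
        for block in inference_program.blocks:
            for op in block.ops:
                if op.type in op_name_map:
                    yield op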
@@ -472,6 +474,9 @@ class ImperativeCalcOutScale(object):
                     layer_name = layer_name.replace('prelu', 'p_re_lu')
                 if 'relu' in layer_name:
                     layer_name = layer_name.replace('relu', 're_lu')
+                if 'conv2d' in layer_name:
+                    layer_name = 'conv2d_' + str(conv_count)
+                    conv_count = conv_count + 1
                 if layer_name not in self._out_scale_dict:
                     continue
                 var_name_op_list[1]._set_attr('out_threshold',
......
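
A hedged sketch of the renaming fix added here: dygraph conv layer names are rewritten to 'conv2d_<N>' with a running counter so that the N-th conv layer lines up with the N-th conv2d op in the static program, after which the recorded scale is written into the matched op's 'out_threshold' attribute. Standalone illustration; `normalize_layer_name` is not a function from the Paddle codebase:

    def normalize_layer_name(layer_name, conv_count):
        # Mirror the name fixups from the hunk above.
        if 'prelu' in layer_name:
            layer_name = layer_name.replace('prelu', 'p_re_lu')
        if 'relu' in layer_name:
            layer_name = layer_name.replace('relu', 're_lu')
        if 'conv2d' in layer_name:
            layer_name = 'conv2d_' + str(conv_count)
            conv_count += 1
        return layer_name, conv_count

    print(normalize_layer_name('features.conv2d_5', 0))  # ('conv2d_0', 1)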