未验证 提交 430f8449 编写于 作者: G guofei 提交者: GitHub

Fix the error of save_quantized_model (#30583)

* Fix the error of save_quantized_model
上级 10271ddf
......@@ -36,7 +36,7 @@ _logger = get_logger(
_op_real_in_out_name = {
"conv2d": [["Input", "Filter"], ["Output"]],
"conv2d_transpose": [["Input", "Filter"], ["Output"]],
"depthwise_conv2d": [["Input", "Filter"], ["Output"]],
"pool2d": [["X"], ["Out"]],
"elementwise_add": [["X", "Y"], ["Out"]],
"softmax": [["X"], ["Out"]],
......@@ -329,9 +329,9 @@ class ImperativeCalcOutScale(object):
super(ImperativeCalcOutScale, self).__init__()
self._moving_rate = moving_rate
self._out_scale_layer_type_list = (
BatchNorm, BatchNorm1D, BatchNorm2D, BatchNorm3D, Conv2D,
Conv2DTranspose, LeakyReLU, Linear, PReLU, Pool2D, MaxPool1D,
MaxPool2D, ReLU, ReLU6, Sigmoid, Softmax, Tanh, Swish)
BatchNorm, BatchNorm1D, BatchNorm2D, BatchNorm3D, Conv2D, LeakyReLU,
Linear, PReLU, Pool2D, MaxPool1D, MaxPool2D, ReLU, ReLU6, Sigmoid,
Softmax, Tanh, Swish)
self._register_hook_handle_list = []
self._out_scale_dict = collections.OrderedDict()
......@@ -415,9 +415,11 @@ class ImperativeCalcOutScale(object):
# Traverse all ops in the program and find out the op matching
# the Layer in the dynamic graph.
layer_var_dict = {}
layer_var_dict = collections.OrderedDict()
ops_list = [key for key, _ in self._out_scale_dict.items()]
op_count = 0
conv_count = 0
for block in inference_program.blocks:
for op in block.ops:
if op.type in _op_real_in_out_name:
......@@ -472,6 +474,9 @@ class ImperativeCalcOutScale(object):
layer_name = layer_name.replace('prelu', 'p_re_lu')
if 'relu' in layer_name:
layer_name = layer_name.replace('relu', 're_lu')
if 'conv2d' in layer_name:
layer_name = 'conv2d_' + str(conv_count)
conv_count = conv_count + 1
if layer_name not in self._out_scale_dict:
continue
var_name_op_list[1]._set_attr('out_threshold',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册