diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py index 26fa0f0d484058b1e7b898a1365e67c387de021a..696168251e70d7d1e538c94df3e0212efefe0d57 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py @@ -36,7 +36,7 @@ _logger = get_logger( _op_real_in_out_name = { "conv2d": [["Input", "Filter"], ["Output"]], - "conv2d_transpose": [["Input", "Filter"], ["Output"]], + "depthwise_conv2d": [["Input", "Filter"], ["Output"]], "pool2d": [["X"], ["Out"]], "elementwise_add": [["X", "Y"], ["Out"]], "softmax": [["X"], ["Out"]], @@ -329,9 +329,9 @@ class ImperativeCalcOutScale(object): super(ImperativeCalcOutScale, self).__init__() self._moving_rate = moving_rate self._out_scale_layer_type_list = ( - BatchNorm, BatchNorm1D, BatchNorm2D, BatchNorm3D, Conv2D, - Conv2DTranspose, LeakyReLU, Linear, PReLU, Pool2D, MaxPool1D, - MaxPool2D, ReLU, ReLU6, Sigmoid, Softmax, Tanh, Swish) + BatchNorm, BatchNorm1D, BatchNorm2D, BatchNorm3D, Conv2D, LeakyReLU, + Linear, PReLU, Pool2D, MaxPool1D, MaxPool2D, ReLU, ReLU6, Sigmoid, + Softmax, Tanh, Swish) self._register_hook_handle_list = [] self._out_scale_dict = collections.OrderedDict() @@ -415,9 +415,10 @@ class ImperativeCalcOutScale(object): # Traverse all ops in the program and find out the op matching # the Layer in the dynamic graph. - layer_var_dict = {} + layer_var_dict = collections.OrderedDict() ops_list = [key for key, _ in self._out_scale_dict.items()] op_count = 0 + conv_count = 0 for block in inference_program.blocks: for op in block.ops: if op.type in _op_real_in_out_name: @@ -472,6 +473,9 @@ class ImperativeCalcOutScale(object): layer_name = layer_name.replace('prelu', 'p_re_lu') if 'relu' in layer_name: layer_name = layer_name.replace('relu', 're_lu') + if 'conv2d' in layer_name: + layer_name = 'conv2d_' + str(conv_count) + conv_count = conv_count + 1 if layer_name not in self._out_scale_dict: continue var_name_op_list[1]._set_attr('out_threshold',