From 393a91f16b295a7fa1ab5fb8cbb8431253b19d1c Mon Sep 17 00:00:00 2001 From: guofei <52460041+gfwm2013@users.noreply.github.com> Date: Mon, 11 Jan 2021 14:44:43 +0800 Subject: [PATCH] Quantization supports 2.0 APIs (#30036) (#30257) * Quantization supports 2.0 APIs * Fix the error of save_quantized_model --- .../slim/quantization/imperative/qat.py | 100 ++++++++++++------ .../slim/tests/test_imperative_out_scale.py | 35 +++--- 2 files changed, 85 insertions(+), 50 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py index 58bfc58dcc..b543a91372 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import collections import logging import numpy as np import sys @@ -20,8 +21,8 @@ import paddle from paddle.fluid import dygraph, core, framework from paddle.fluid.executor import Executor from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX -from paddle.nn import Linear, Conv2D -from paddle.fluid.dygraph.nn import BatchNorm, Pool2D, Conv2DTranspose +from paddle.nn import Linear, Conv2D, Conv2DTranspose, MaxPool2D, MaxPool1D, BatchNorm1D, BatchNorm2D, BatchNorm3D +from paddle.fluid.dygraph.nn import BatchNorm, Pool2D from paddle.fluid.io import load_inference_model, save_inference_model from paddle.nn.layer.activation import ReLU, LeakyReLU, Sigmoid, ReLU6, Tanh, Softmax, PReLU, Swish from paddle.fluid.log_helper import get_logger @@ -263,6 +264,7 @@ class ImperativeQuantAware(object): parent = obj quant_layer = self._get_quantized_counterpart(layer) + setattr(quant_layer, "layer_name", layer.full_name()) setattr(obj, target, quant_layer) self._out_scale.calc_out_scale(model) @@ -306,10 +308,11 @@ class ImperativeCalcOutScale(object): super(ImperativeCalcOutScale, self).__init__() self._moving_rate = moving_rate self._out_scale_layer_type_list = ( - BatchNorm, Conv2D, Conv2DTranspose, LeakyReLU, Linear, PReLU, - Pool2D, ReLU, ReLU6, Sigmoid, Softmax, Tanh, Swish) + BatchNorm, BatchNorm1D, BatchNorm2D, BatchNorm3D, Conv2D, + Conv2DTranspose, LeakyReLU, Linear, PReLU, Pool2D, MaxPool1D, + MaxPool2D, ReLU, ReLU6, Sigmoid, Softmax, Tanh, Swish) self._register_hook_handle_list = [] - self._out_scale_dict = {} + self._out_scale_dict = collections.OrderedDict() def calc_out_scale(self, model): """ @@ -325,7 +328,8 @@ class ImperativeCalcOutScale(object): model, dygraph.Layer), "model must be the instance of dygraph.Layer" for _, layer in model.named_sublayers(): if not isinstance(layer, self._out_scale_layer_type_list): - continue + if 'quantized_' not in layer.full_name(): + continue forward_post_hook_handle = layer.register_forward_post_hook( self._forward_post_hook) self._register_hook_handle_list.append(forward_post_hook_handle) @@ -364,12 +368,12 @@ class ImperativeCalcOutScale(object): self._out_scale_dict[key] = float(self._out_scale_dict[key] .numpy()) - paddle.jit.save(layer=layer, path=path, input_spec=input_spec, **config) - if paddle.in_dynamic_mode(): is_dynamic_mode = True paddle.enable_static() + paddle.jit.save(layer=layer, path=path, input_spec=input_spec, **config) + if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) else: @@ -391,40 +395,54 @@ class ImperativeCalcOutScale(object): # Traverse all ops in the program and find out the op matching 
# the Layer in the dynamic graph. layer_var_dict = {} + ops_list = [key for key, _ in self._out_scale_dict.items()] + op_count = 0 for block in inference_program.blocks: for op in block.ops: if op.type in _op_real_in_out_name: - output_var_names = quantization_pass._get_op_output_var_names( - op) - for output_var_name in output_var_names: - output_var_tensor = block.var(output_var_name) - if output_var_tensor.dtype not in [ - core.VarDesc.VarType.FP64, - core.VarDesc.VarType.FP32 - ]: + if op.type in ["batch_norm", "pool2d"]: + if op.type == "pool2d" and op.attr( + "pooling_type") != "max": continue - # Because the Layer in dygraph may correspond to multiple ops - # in static program after being saved. To ensure correctness, - # the outscale collected for output of dygraph Layer can only - # be set to the last op in the corresponding ops in static program. - # - # We can judge the execution order of the ops which corresponding - # to dygraph Layer by the name of output. And use dict to save - # the corresponding relationship between the dygraph Layer and the - # static graph op that needs to set the outscale attribute. - if '.' not in output_var_name: + op_count = self.op_match(op, ops_list, op_count) + if op_count >= len(ops_list): continue - dynamic_layer_name, var_name_suffix = output_var_name.split( - ".") - if dynamic_layer_name in layer_var_dict: - if layer_var_dict[dynamic_layer_name][ - 0] < var_name_suffix: + op._set_attr('out_threshold', + self._out_scale_dict[ops_list[op_count]]) + op_count += 1 + else: + output_var_names = quantization_pass._get_op_output_var_names( + op) + for output_var_name in output_var_names: + output_var_tensor = block.var(output_var_name) + if output_var_tensor.dtype not in [ + core.VarDesc.VarType.FP64, + core.VarDesc.VarType.FP32 + ]: + continue + # Because the Layer in dygraph may correspond to multiple ops + # in static program after being saved. To ensure correctness, + # the outscale collected for output of dygraph Layer can only + # be set to the last op in the corresponding ops in static program. + # + # We can judge the execution order of the ops which corresponding + # to dygraph Layer by the name of output. And use dict to save + # the corresponding relationship between the dygraph Layer and the + # static graph op that needs to set the outscale attribute. + if '.' not in output_var_name: + continue + dynamic_layer_name, var_name_suffix = output_var_name.split( + ".") + if dynamic_layer_name in layer_var_dict: + if layer_var_dict[dynamic_layer_name][ + 0] < var_name_suffix: + layer_var_dict[dynamic_layer_name] = [ + var_name_suffix, op + ] + else: layer_var_dict[dynamic_layer_name] = [ var_name_suffix, op ] - else: - layer_var_dict[ - dynamic_layer_name] = [var_name_suffix, op] # Because the naming styles of static and dynamic graph are different, # in order to avoid mistakes, we unify the name here. 
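Note on the hunk above: the new batch_norm/pool2d branch writes the 'out_threshold' attribute by position. It walks the ops of the saved program in execution order and advances an index through ops_list, the keys of _out_scale_dict (now an OrderedDict, so insertion order matches execution order); the op_match helper it calls is added in a later hunk of this patch. A framework-free sketch of that positional matching, with illustrative layer names and scale values and a plain == string comparison, looks like this:

import collections

# Scales recorded per dygraph layer, keyed by full_name(), in execution order.
# The names and values here are illustrative only.
out_scale_dict = collections.OrderedDict([
    ("conv2d_0", 0.91),
    ("batch_norm2d_0", 3.42),
    ("re_lu_0", 3.42),
    ("max_pool2d_0", 3.42),
])
ops_list = list(out_scale_dict)  # same order the layers ran in

def op_match(op_type, pooling_type, ops_list, op_count):
    # Skip recorded layers whose names do not contain this op type,
    # e.g. step over "conv2d_0" and "re_lu_0" when matching a batch_norm op.
    while op_count < len(ops_list) and op_type not in ops_list[op_count]:
        op_count += 1
    # If a non-max pool2d ever reached here, this walks past the remaining
    # entries so the caller's bounds check skips the op (the caller above
    # also filters such ops out before calling op_match).
    while (op_count < len(ops_list) and op_type == "pool2d" and
           pooling_type != "max"):
        op_count += 1
    return op_count

# The batch_norm/pool2d ops of the saved program, in execution order.
program_ops = [("batch_norm", None), ("pool2d", "max")]

op_count = 0
for op_type, pooling_type in program_ops:
    op_count = op_match(op_type, pooling_type, ops_list, op_count)
    if op_count >= len(ops_list):
        continue
    print(op_type, "-> out_threshold =", out_scale_dict[ops_list[op_count]])
    op_count += 1

For these op types the matching has no link between a dygraph layer and its static-graph op other than type and position, which is why the patch keeps _out_scale_dict ordered and attaches layer_name to the quantized counterparts.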
@@ -451,6 +469,14 @@ class ImperativeCalcOutScale(object): if is_dynamic_mode: paddle.disable_static() + def op_match(self, op, ops_list, op_count): + while op_count < len(ops_list) and op.type not in ops_list[op_count]: + op_count += 1 + while op_count < len(ops_list) and op.type is "pool2d" and op.attr( + "pooling_type") != "max": + op_count += 1 + return op_count + def _forward_post_hook(self, layer, input, output): assert isinstance( output, (core.VarBase, framework.Variable) @@ -463,4 +489,8 @@ class ImperativeCalcOutScale(object): layer._out_scale = quant_nn.MovingAverageAbsMaxScale( output.name, self._moving_rate, output.dtype) scale_out = layer._out_scale(output) - self._out_scale_dict[layer.full_name()] = scale_out + if hasattr(layer, 'layer_name'): + layer_name = layer.layer_name + else: + layer_name = layer.full_name() + self._out_scale_dict[layer_name] = scale_out diff --git a/python/paddle/fluid/contrib/slim/tests/test_imperative_out_scale.py b/python/paddle/fluid/contrib/slim/tests/test_imperative_out_scale.py index a900096a99..47e21910b4 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_imperative_out_scale.py +++ b/python/paddle/fluid/contrib/slim/tests/test_imperative_out_scale.py @@ -30,9 +30,10 @@ from paddle.fluid.contrib.slim.quantization import OutScaleForTrainingPass, OutS from paddle.fluid.dygraph.container import Sequential from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX from paddle.nn.layer import ReLU, LeakyReLU, Sigmoid, Softmax, ReLU6 -from paddle.nn import Linear, Conv2D, Softmax, BatchNorm +from paddle.nn import Linear, Conv2D, Softmax, BatchNorm2D, MaxPool2D from paddle.fluid.dygraph.nn import Pool2D from paddle.fluid.log_helper import get_logger +from paddle.fluid.dygraph import nn paddle.enable_static() @@ -50,7 +51,6 @@ def StaticLenet(data, num_classes=10, classifier_activation='softmax'): fc_w1_attr = fluid.ParamAttr(name="fc_w_1") fc_w2_attr = fluid.ParamAttr(name="fc_w_2") fc_w3_attr = fluid.ParamAttr(name="fc_w_3") - conv2d_b1_attr = fluid.ParamAttr(name="conv2d_b_1") conv2d_b2_attr = fluid.ParamAttr(name="conv2d_b_2") fc_b1_attr = fluid.ParamAttr(name="fc_b_1") fc_b2_attr = fluid.ParamAttr(name="fc_b_2") @@ -62,7 +62,7 @@ def StaticLenet(data, num_classes=10, classifier_activation='softmax'): stride=1, padding=1, param_attr=conv2d_w1_attr, - bias_attr=conv2d_b1_attr) + bias_attr=False) batch_norm1 = layers.batch_norm(conv1) relu1 = layers.relu(batch_norm1) pool1 = fluid.layers.pool2d( @@ -99,14 +99,13 @@ def StaticLenet(data, num_classes=10, classifier_activation='softmax'): class ImperativeLenet(fluid.dygraph.Layer): - def __init__(self, num_classes=10, classifier_activation='softmax'): + def __init__(self, num_classes=10): super(ImperativeLenet, self).__init__() conv2d_w1_attr = fluid.ParamAttr(name="conv2d_w_1") conv2d_w2_attr = fluid.ParamAttr(name="conv2d_w_2") fc_w1_attr = fluid.ParamAttr(name="fc_w_1") fc_w2_attr = fluid.ParamAttr(name="fc_w_2") fc_w3_attr = fluid.ParamAttr(name="fc_w_3") - conv2d_b1_attr = fluid.ParamAttr(name="conv2d_b_1") conv2d_b2_attr = fluid.ParamAttr(name="conv2d_b_2") fc_b1_attr = fluid.ParamAttr(name="fc_b_1") fc_b2_attr = fluid.ParamAttr(name="fc_b_2") @@ -119,8 +118,8 @@ class ImperativeLenet(fluid.dygraph.Layer): stride=1, padding=1, weight_attr=conv2d_w1_attr, - bias_attr=conv2d_b1_attr), - BatchNorm(6), + bias_attr=False), + BatchNorm2D(6), ReLU(), Pool2D( pool_size=2, pool_type='max', pool_stride=2), @@ -132,10 +131,10 @@ class ImperativeLenet(fluid.dygraph.Layer): padding=0, 
weight_attr=conv2d_w2_attr, bias_attr=conv2d_b2_attr), - BatchNorm(16), + BatchNorm2D(16), ReLU6(), - Pool2D( - pool_size=2, pool_type='max', pool_stride=2)) + MaxPool2D( + kernel_size=2, stride=2)) self.fc = Sequential( Linear( @@ -188,10 +187,10 @@ class TestImperativeOutSclae(unittest.TestCase): reader = paddle.batch( paddle.dataset.mnist.test(), batch_size=32, drop_last=True) weight_quantize_type = 'abs_max' - activation_quant_type = 'moving_average_abs_max' + activation_quantize_type = 'moving_average_abs_max' param_init_map = {} seed = 1000 - lr = 0.1 + lr = 0.001 dynamic_out_scale_list = [] static_out_scale_list = [] @@ -199,7 +198,9 @@ class TestImperativeOutSclae(unittest.TestCase): _logger.info( "--------------------------dynamic graph qat--------------------------" ) - imperative_out_scale = ImperativeQuantAware() + imperative_out_scale = ImperativeQuantAware( + weight_quantize_type=weight_quantize_type, + activation_quantize_type=activation_quantize_type) with fluid.dygraph.guard(): np.random.seed(seed) @@ -282,14 +283,18 @@ class TestImperativeOutSclae(unittest.TestCase): with fluid.scope_guard(scope): exe.run(startup) for param in main.all_parameters(): + if "batch_norm" in param.name: + param_name = param.name.replace("norm", "norm2d") + else: + param_name = param.name param_tensor = scope.var(param.name).get_tensor() - param_tensor.set(param_init_map[param.name], place) + param_tensor.set(param_init_map[param_name], place) main_graph = IrGraph(core.Graph(main.desc), for_test=False) infer_graph = IrGraph(core.Graph(infer.desc), for_test=True) transform_pass = QuantizationTransformPass( scope=scope, place=place, - activation_quantize_type=activation_quant_type, + activation_quantize_type=activation_quantize_type, weight_quantize_type=weight_quantize_type, quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul']) transform_pass.apply(main_graph) -- GitLab
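On the collection side of the qat.py changes, calc_out_scale registers a forward post-hook on every matching sublayer, and _forward_post_hook stores the running scale under layer.layer_name if present, otherwise under layer.full_name(). A minimal standalone sketch of that pattern, using Paddle 2.0 layer classes and a plain abs-max in place of quant_nn.MovingAverageAbsMaxScale (the toy model and input shape are made up for illustration):

import collections

import numpy as np
import paddle

out_scales = collections.OrderedDict()

def record_out_scale(layer, inputs, output):
    # Stand-in for MovingAverageAbsMaxScale: track the abs-max of the output.
    value = float(np.abs(output.numpy()).max())
    key = getattr(layer, "layer_name", layer.full_name())
    out_scales[key] = max(value, out_scales.get(key, 0.0))

model = paddle.nn.Sequential(
    paddle.nn.Conv2D(1, 6, 3, padding=1),
    paddle.nn.BatchNorm2D(6),
    paddle.nn.ReLU(),
    paddle.nn.MaxPool2D(kernel_size=2, stride=2))

hooks = [sublayer.register_forward_post_hook(record_out_scale)
         for _, sublayer in model.named_sublayers()]

model(paddle.rand([4, 1, 28, 28]))  # one forward pass fills the dict
for hook in hooks:
    hook.remove()

# Keys come out in execution order (e.g. conv2d_0, batch_norm2d_0, ...),
# which is what the positional op matching in save_quantized_model relies on.
print(list(out_scales.keys()))

The getattr(layer, "layer_name", ...) fallback mirrors the hasattr check added to _forward_post_hook, so quantized wrappers report the name of the layer they replaced rather than their own.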