diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py
index c5ee9ea6751003c19ef5b43f1af0f09093bded89..afe8a3de6673f0869e2a1cb588bffe0167b69b8d 100644
--- a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py
+++ b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py
@@ -17,11 +17,15 @@ import logging
 import numpy as np
 import sys
 import os
+import warnings
+
 import paddle
-from paddle.fluid import dygraph, core, framework
+from paddle.fluid import dygraph, core, framework, unique_name
 from paddle.fluid.executor import Executor
+from paddle.fluid.param_attr import ParamAttr
+from paddle.fluid.initializer import Constant
 from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
-from paddle.nn import Linear, Conv2D, Conv2DTranspose, MaxPool2D, MaxPool1D, BatchNorm1D, BatchNorm2D, BatchNorm3D
+from paddle.nn import Linear, Conv2D, Conv2DTranspose, MaxPool2D, MaxPool1D, BatchNorm1D, BatchNorm2D, BatchNorm3D, SyncBatchNorm
 from paddle.fluid.dygraph.nn import BatchNorm, Pool2D
 from paddle.fluid.io import load_inference_model, save_inference_model
 from paddle.nn.layer.activation import ReLU, LeakyReLU, Sigmoid, ReLU6, Tanh, Softmax, PReLU, Swish
@@ -331,10 +335,73 @@ class ImperativeCalcOutScale(object):
         self._out_scale_layer_type_list = (
             BatchNorm, BatchNorm1D, BatchNorm2D, BatchNorm3D, Conv2D, LeakyReLU,
             Linear, PReLU, Pool2D, MaxPool1D, MaxPool2D, ReLU, ReLU6, Sigmoid,
-            Softmax, Tanh, Swish)
+            Softmax, SyncBatchNorm, Tanh, Swish)
         self._register_hook_handle_list = []
         self._out_scale_dict = collections.OrderedDict()

+    # Determine whether the layer supports calculating an out_scale
+    def _is_matched_layer(self, layer):
+        if not isinstance(layer, self._out_scale_layer_type_list):
+            if 'quantized_' not in layer.full_name():
+                return False
+        return True
+
+    # When the inference model is saved, the logic in the hook is not executed
+    # during program translation, so parameters created there would be missing
+    # and the model would fail to save. Therefore, the parameter creation that
+    # used to happen in the hook is done here, in advance, outside the hook.
+    def _add_new_parameters(self, layer, name=None):
+        dtype = layer._dtype if layer._dtype is not None else "float32"
+        if dtype not in ["float32", "float64"]:
+            return
+        scale_prefix = '{}.scale'.format(name) if name else 'outscale.scale'
+        scale_name = unique_name.generate(scale_prefix)
+        scale_attr = ParamAttr(
+            name=scale_name, initializer=Constant(1), trainable=False)
+        layer._quant_out_scale = layer.create_parameter(
+            shape=[1], attr=scale_attr, dtype=dtype)
+        layer._quant_out_scale.stop_gradient = True
+
+        state_prefix = "{}.state".format(name) if name else 'outscale.state'
+        state_attr = ParamAttr(
+            name=unique_name.generate(state_prefix),
+            initializer=Constant(1),
+            trainable=False)
+        layer._quant_out_state = layer.create_parameter(
+            shape=[1], attr=state_attr, dtype=dtype)
+        layer._quant_out_state.stop_gradient = True
+
+        accum_prefix = "{}.accum".format(name) if name else 'outscale.accum'
+        accum_attr = ParamAttr(
+            name=unique_name.generate(accum_prefix),
+            initializer=Constant(1),
+            trainable=False)
+        layer._quant_out_accum = layer.create_parameter(
+            shape=[1], attr=accum_attr, dtype=dtype)
+        layer._quant_out_accum.stop_gradient = True
+
+    # Judge whether the op in the program matches the Layer in the dynamic model
+    def _is_op_matched(self, layer_name, op, block):
+        output_var_names = quantization_pass._get_op_output_var_names(op)
+        for output_var_name in output_var_names:
+            output_var_tensor = block.var(output_var_name)
+            if output_var_tensor.dtype not in [
+                    core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32
+            ]:
+                return False
+
+        # Because the naming styles of the static and dynamic graphs differ,
+        # the name is unified here to avoid mismatches.
+        op_type = output_var_names[0].split(".")[0]
+        op_type = op_type.rsplit("_", 1)[0]
+        if op_type == 'depthwise_conv2d':
+            op_type = 'conv2d'
+        if 'prelu' in op_type:
+            op_type = op_type.replace('prelu', 'p_re_lu')
+        if 'relu' in op_type:
+            op_type = op_type.replace('relu', 're_lu')
+        return op_type in layer_name
+
     def calc_out_scale(self, model):
         """
         Insert the `moving_average_abs_max_scale` op to calculate output scale of Specific layers in model.
@@ -348,12 +415,11 @@ class ImperativeCalcOutScale(object):
         assert isinstance(
             model, dygraph.Layer), "model must be the instance of dygraph.Layer"
         for _, layer in model.named_sublayers():
-            if not isinstance(layer, self._out_scale_layer_type_list):
-                if 'quantized_' not in layer.full_name():
-                    continue
-            forward_post_hook_handle = layer.register_forward_post_hook(
-                self._forward_post_hook)
-            self._register_hook_handle_list.append(forward_post_hook_handle)
+            if self._is_matched_layer(layer):
+                self._add_new_parameters(layer)
+                forward_post_hook_handle = layer.register_forward_post_hook(
+                    self._forward_post_hook)
+                self._register_hook_handle_list.append(forward_post_hook_handle)

     def save_quantized_model(self, layer, path, input_spec=None, **config):
         """
@@ -380,14 +446,26 @@ class ImperativeCalcOutScale(object):
         assert isinstance(
             layer, dygraph.Layer), "model must be the instance of dygraph.Layer"

+        self._layer = layer
         is_dynamic_mode = False
         with dygraph.guard():
-            layer.eval()
-            for handle in self._register_hook_handle_list:
-                handle.remove()
-            for key in self._out_scale_dict:
-                self._out_scale_dict[key] = float(self._out_scale_dict[key]
-                                                  .numpy())
+            self._layer.eval()
+            if self._register_hook_handle_list is not None:
+                for handle in self._register_hook_handle_list:
+                    handle.remove()
+            if self._out_scale_dict:
+                for key in self._out_scale_dict:
+                    self._out_scale_dict[key] = float(self._out_scale_dict[key]
+                                                      .numpy())
+            else:
+                for _, sub_layer in self._layer.named_sublayers():
+                    if self._is_matched_layer(sub_layer):
+                        layer_name = sub_layer.full_name()
+                        if hasattr(sub_layer, "layer_name"):
+                            layer_name = sub_layer.layer_name
+                        if hasattr(sub_layer, "_quant_out_scale"):
+                            self._out_scale_dict[layer_name] = float(
+                                sub_layer._quant_out_scale)

         if paddle.in_dynamic_mode():
             is_dynamic_mode = True
@@ -413,74 +491,68 @@ class ImperativeCalcOutScale(object):
                 model_filename=model_filename,
                 params_filename=params_filename))

-        # Traverse all ops in the program and find out the op matching
-        # the Layer in the dynamic graph.
-        layer_var_dict = collections.OrderedDict()
-        ops_list = [key for key, _ in self._out_scale_dict.items()]
+        check_behind_op = False
         op_count = 0
-        conv_count = 0
-
-        for block in inference_program.blocks:
-            for op in block.ops:
-                if op.type in _op_real_in_out_name:
-                    if op.type in ["batch_norm", "pool2d"]:
-                        if op.type == "pool2d" and op.attr(
-                                "pooling_type") != "max":
-                            continue
-                        op_count = self.op_match(op, ops_list, op_count)
-                        if op_count >= len(ops_list):
-                            continue
-                        op._set_attr('out_threshold',
-                                     self._out_scale_dict[ops_list[op_count]])
-                        op_count += 1
-                    else:
-                        output_var_names = quantization_pass._get_op_output_var_names(
-                            op)
-                        for output_var_name in output_var_names:
-                            output_var_tensor = block.var(output_var_name)
-                            if output_var_tensor.dtype not in [
-                                    core.VarDesc.VarType.FP64,
-                                    core.VarDesc.VarType.FP32
-                            ]:
-                                continue
-                            # Because the Layer in dygraph may correspond to multiple ops
-                            # in static program after being saved. To ensure correctness,
-                            # the outscale collected for output of dygraph Layer can only
-                            # be set to the last op in the corresponding ops in static program.
-                            #
-                            # We can judge the execution order of the ops which corresponding
-                            # to dygraph Layer by the name of output. And use dict to save
-                            # the corresponding relationship between the dygraph Layer and the
-                            # static graph op that needs to set the outscale attribute.
-                            if '.' not in output_var_name:
-                                continue
-                            dynamic_layer_name, var_name_suffix = output_var_name.split(
-                                ".")
-                            if dynamic_layer_name in layer_var_dict:
-                                if layer_var_dict[dynamic_layer_name][
-                                        0] < var_name_suffix:
-                                    layer_var_dict[dynamic_layer_name] = [
-                                        var_name_suffix, op
-                                    ]
-                            else:
-                                layer_var_dict[dynamic_layer_name] = [
-                                    var_name_suffix, op
-                                ]
-
-        # Because the naming styles of static and dynamic graph are different,
-        # in order to avoid mistakes, we unify the name here.
-        for (layer_name, var_name_op_list) in layer_var_dict.items():
-            if 'prelu' in layer_name:
-                layer_name = layer_name.replace('prelu', 'p_re_lu')
-            if 'relu' in layer_name:
-                layer_name = layer_name.replace('relu', 're_lu')
-            if 'conv2d' in layer_name:
-                layer_name = 'conv2d_' + str(conv_count)
-                conv_count = conv_count + 1
-            if layer_name not in self._out_scale_dict:
-                continue
-            var_name_op_list[1]._set_attr('out_threshold',
-                                          self._out_scale_dict[layer_name])
+        ops_list = [key for key, _ in self._out_scale_dict.items()]
+        if len(ops_list) == 0:
+            warnings.warn(
+                "Warning: No Layer of the model to be saved contains the out_threshold attribute, "
+                "so the generated inference model would not contain the out_threshold."
+            )
+        else:
+            # Because a Layer in dygraph may correspond to multiple ops in
+            # the static program after being saved, the out_scale collected
+            # for the output of a dygraph Layer can only be set on the last
+            # of its corresponding ops in the static program.
+            #
+            # The execution order of the ops corresponding to a dygraph
+            # Layer is tracked via check_behind_op.
+            forward_op = None
+            for block in inference_program.blocks:
+                for op in block.ops:
+                    if op.type in _op_real_in_out_name:
+                        if op_count > len(ops_list):
+                            warnings.warn(
+                                "The number of Layers with the out_threshold attribute is smaller than the number of matching ops in the inference model."
+                            )
+                            break
+                        if check_behind_op:
+                            check_behind_op = False
+                            if op.type == "elementwise_add":
+                                if self._is_op_matched(ops_list[op_count], op,
+                                                       block):
+                                    op._set_attr("out_threshold",
+                                                 self._out_scale_dict[ops_list[
+                                                     op_count]])
+                                    op_count += 1
+                                    forward_op = None
+                                continue
+                            else:
+                                if forward_op is None:
+                                    raise ValueError(
+                                        "forward_op should not be None")
+                                if self._is_op_matched(ops_list[op_count],
+                                                       forward_op, block):
+                                    forward_op._set_attr(
+                                        "out_threshold", self._out_scale_dict[
+                                            ops_list[op_count]])
+                                    op_count += 1
+                                    forward_op = None
+
+                        if op.type in ["conv2d", "depthwise_conv2d", "matmul"]:
+                            check_behind_op = True
+                            forward_op = op
+                            continue
+                        if op_count >= len(ops_list):
+                            warnings.warn(
+                                "The number of Layers with the out_threshold attribute is smaller than the number of matching ops in the inference model."
+                            )
+                            break
+                        if self._is_op_matched(ops_list[op_count], op, block):
+                            op._set_attr(
+                                "out_threshold",
+                                self._out_scale_dict[ops_list[op_count]])
+                            op_count += 1

         # Save the processed program.
         save_inference_model(
@@ -495,14 +567,6 @@ class ImperativeCalcOutScale(object):
         if is_dynamic_mode:
             paddle.disable_static()

-    def op_match(self, op, ops_list, op_count):
-        while op_count < len(ops_list) and op.type not in ops_list[op_count]:
-            op_count += 1
-        while op_count < len(ops_list) and op.type is "pool2d" and op.attr(
-                "pooling_type") != "max":
-            op_count += 1
-        return op_count
-
     def _forward_post_hook(self, layer, input, output):
         assert isinstance(
             output, (core.VarBase, framework.Variable)
@@ -512,9 +576,9 @@
         ]:
             return
         if not hasattr(layer, "_out_scale"):
-            layer._out_scale = quant_nn.MovingAverageAbsMaxScale(
-                output.name, self._moving_rate, output.dtype)
-        scale_out = layer._out_scale(output)
+            self._out_scale = quant_nn.MovingAverageAbsMaxScale(
+                layer, output.name, self._moving_rate, output.dtype)
+        scale_out = self._out_scale(output)
         if hasattr(layer, 'layer_name'):
             layer_name = layer.layer_name
         else:
diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/quant_nn.py b/python/paddle/fluid/contrib/slim/quantization/imperative/quant_nn.py
index 0469de7aef20704682e62e6d2af0f5f471113942..0b052d5dd0da62b2c746fc61938c952b2b7de5d1 100644
--- a/python/paddle/fluid/contrib/slim/quantization/imperative/quant_nn.py
+++ b/python/paddle/fluid/contrib/slim/quantization/imperative/quant_nn.py
@@ -503,7 +503,7 @@ class QuantizedNoweightLayer(layers.Layer):


 class MovingAverageAbsMaxScale(layers.Layer):
-    def __init__(self, name=None, moving_rate=0.9, dtype='float32'):
+    def __init__(self, layer=None, name=None, moving_rate=0.9, dtype='float32'):
         r"""
         MovingAverageMaxScale layer is used to calculating the output quantization scale of Layer.
         Its computational formula is described as below:
@@ -514,33 +514,48 @@
         super(MovingAverageAbsMaxScale, self).__init__()
         self._moving_rate = moving_rate
         self._dtype = dtype
+        self._layer = layer

-        scale_prefix = '{}.scale'.format(name) if name else 'outscale.scale'
-        name = unique_name.generate(scale_prefix)
-        scale_attr = ParamAttr(
-            name=name, initializer=Constant(1), trainable=False)
-        self._scale = self.create_parameter(
-            shape=[1], attr=scale_attr, dtype=self._dtype)
-        self._scale.stop_gradient = True
+        if self._layer is None or not hasattr(self._layer, "_quant_out_scale"):
+            scale_prefix = '{}.scale'.format(name) if name else 'outscale.scale'
+            scale_name = unique_name.generate(scale_prefix)
+            scale_attr = ParamAttr(
+                name=scale_name, initializer=Constant(1), trainable=False)
+            self._scale = self.create_parameter(
+                shape=[1], attr=scale_attr, dtype=self._dtype)
+            self._scale.stop_gradient = True
+            if self._layer is not None:
+                setattr(self._layer, "_quant_out_scale", self._scale)
+        else:
+            self._scale = self._layer._quant_out_scale

-        state_prefix = "{}.state".format(name) if name else 'outscale.state'
-        state_attr = ParamAttr(
-            name=unique_name.generate(state_prefix),
-            initializer=Constant(1),
-            trainable=False)
-        self._state = self.create_parameter(
-            shape=[1], attr=state_attr, dtype=self._dtype)
-        self._state.stop_gradient = True
+        if self._layer is None or not hasattr(self._layer, "_quant_out_state"):
+            state_prefix = "{}.state".format(name) if name else 'outscale.state'
+            state_attr = ParamAttr(
+                name=unique_name.generate(state_prefix),
+                initializer=Constant(1),
+                trainable=False)
+            self._state = self.create_parameter(
+                shape=[1], attr=state_attr, dtype=self._dtype)
+            self._state.stop_gradient = True
+            if self._layer is not None:
+                setattr(self._layer, "_quant_out_state", self._state)
+        else:
+            self._state = self._layer._quant_out_state

-        accum_prefix = "{}.accum".format(name) if name else 'outscale.accum'
-        accum_attr = ParamAttr(
-            name=unique_name.generate(accum_prefix),
-            initializer=Constant(1),
-            trainable=False)
-        self._accum = self.create_parameter(
-            shape=[1], attr=accum_attr, dtype=self._dtype)
-        self._accum.stop_gradient = True
-        MovingAverageAbsMaxScale._has_create = True
+        if self._layer is None or not hasattr(self._layer, "_quant_out_accum"):
+            accum_prefix = "{}.accum".format(name) if name else 'outscale.accum'
+            accum_attr = ParamAttr(
+                name=unique_name.generate(accum_prefix),
+                initializer=Constant(1),
+                trainable=False)
+            self._accum = self.create_parameter(
+                shape=[1], attr=accum_attr, dtype=self._dtype)
+            self._accum.stop_gradient = True
+            if self._layer is not None:
+                setattr(self._layer, "_quant_out_accum", self._accum)
+        else:
+            self._accum = self._layer._quant_out_accum

     def forward(self, input):
         if in_dygraph_mode():
@@ -549,18 +564,17 @@
             state = self._state if self.training else None
             accum = self._accum if self.training else None

-            out_scale, _, _ = core.ops.moving_average_abs_max_scale(
+            self._scale, _, _ = core.ops.moving_average_abs_max_scale(
                 input, accum, state, self._scale, state, accum, *attrs)
-            return out_scale
+            return self._scale

         check_variable_and_dtype(input, 'input', ['float32', 'float64'],
                                  'MovingAverageAbsMaxScale')

-        scale_out = self._scale
         attrs = {'moving_rate': self._moving_rate, 'is_test': not self.training}
         inputs = {"X": [input]}
-        outputs = {"OutScale": [scale_out]}
+        outputs = {"OutScale": [self._scale]}

         if self.training:
             inputs['InState'] = [self._state]
@@ -574,4 +588,4 @@
             outputs=outputs,
             attrs=attrs)

-        return scale_out
+        return self._scale
diff --git a/python/paddle/fluid/contrib/slim/tests/test_imperative_out_scale.py b/python/paddle/fluid/contrib/slim/tests/test_imperative_out_scale.py
index 47e21910b48dfd4a367ea744de2ffbfaf07b9df2..83ddac41965c516043d4e8074570c7c78b79d89f 100644
--- a/python/paddle/fluid/contrib/slim/tests/test_imperative_out_scale.py
+++ b/python/paddle/fluid/contrib/slim/tests/test_imperative_out_scale.py
@@ -19,6 +19,8 @@ import numpy as np
 import random
 import unittest
 import logging
+import warnings
+
 import paddle
 import paddle.fluid as fluid
 import paddle.fluid.layers as layers
@@ -29,7 +31,7 @@ from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware
 from paddle.fluid.contrib.slim.quantization import OutScaleForTrainingPass, OutScaleForInferencePass, QuantizationTransformPass
 from paddle.fluid.dygraph.container import Sequential
 from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
-from paddle.nn.layer import ReLU, LeakyReLU, Sigmoid, Softmax, ReLU6
+from paddle.nn.layer import ReLU, LeakyReLU, Sigmoid, Softmax, PReLU
 from paddle.nn import Linear, Conv2D, Softmax, BatchNorm2D, MaxPool2D
 from paddle.fluid.dygraph.nn import Pool2D
 from paddle.fluid.log_helper import get_logger
@@ -45,6 +47,14 @@ _logger = get_logger(
     __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')


+def get_valid_warning_num(warning, w):
+    num = 0
+    for i in range(len(w)):
+        if warning in str(w[i].message):
+            num += 1
+    return num
+
+
 def StaticLenet(data, num_classes=10, classifier_activation='softmax'):
     conv2d_w1_attr = fluid.ParamAttr(name="conv2d_w_1")
     conv2d_w2_attr = fluid.ParamAttr(name="conv2d_w_2")
@@ -76,9 +86,9 @@ def StaticLenet(data, num_classes=10, classifier_activation='softmax'):
         param_attr=conv2d_w2_attr,
         bias_attr=conv2d_b2_attr)
     batch_norm2 = layers.batch_norm(conv2)
-    relu6_1 = layers.relu6(batch_norm2)
+    prelu1 = layers.prelu(batch_norm2, mode='all')
     pool2 = fluid.layers.pool2d(
-        relu6_1, pool_size=2, pool_type='max', pool_stride=2)
+        prelu1, pool_size=2, pool_type='max', pool_stride=2)

     fc1 = fluid.layers.fc(input=pool2,
                           size=120,
@@ -132,7 +142,7 @@ class ImperativeLenet(fluid.dygraph.Layer):
                 weight_attr=conv2d_w2_attr,
                 bias_attr=conv2d_b2_attr),
             BatchNorm2D(16),
-            ReLU6(),
+            PReLU(),
             MaxPool2D(
                 kernel_size=2, stride=2))

@@ -246,6 +256,10 @@ class TestImperativeOutSclae(unittest.TestCase):

         lenet.eval()

+        param_save_path = "test_save_quantized_model/lenet.pdparams"
+        save_dict = lenet.state_dict()
+        paddle.save(save_dict, param_save_path)
+
         path = "./dynamic_outscale_infer_model/lenet"
         dynamic_save_dir = "./dynamic_outscale_infer_model"

@@ -285,6 +299,8 @@ class TestImperativeOutSclae(unittest.TestCase):
         for param in main.all_parameters():
             if "batch_norm" in param.name:
                 param_name = param.name.replace("norm", "norm2d")
+            elif 'prelu' in param.name:
+                param_name = param.name.replace("prelu", 'p_re_lu')
             else:
                 param_name = param.name
             param_tensor = scope.var(param.name).get_tensor()
@@ -384,5 +400,94 @@ class TestImperativeOutSclae(unittest.TestCase):
                                 static_ops[i].attr("out_threshold"))


+class TestSaveQuantizedModelFromCheckPoint(unittest.TestCase):
+    def test_save_quantized_model(self):
+        weight_quantize_type = 'abs_max'
+        activation_quantize_type = 'moving_average_abs_max'
+        load_param_path = "test_save_quantized_model/lenet.pdparams"
+        path = "./dynamic_outscale_infer_model_from_checkpoint/lenet"
+        dynamic_model_save_dir = "./dynamic_outscale_infer_model_from_checkpoint"
+        static_model_save_dir = "./static_outscale_infer_model"
+
+        imperative_out_scale = ImperativeQuantAware(
+            weight_quantize_type=weight_quantize_type,
+            activation_quantize_type=activation_quantize_type)
+
+        with fluid.dygraph.guard():
+            lenet = ImperativeLenet()
+            load_dict = paddle.load(load_param_path)
+            imperative_out_scale.quantize(lenet)
+            lenet.set_dict(load_dict)
+
+            imperative_out_scale.save_quantized_model(
+                layer=lenet,
+                path=path,
+                input_spec=[
+                    paddle.static.InputSpec(
+                        shape=[None, 1, 28, 28], dtype='float32')
+                ])
+
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+        else:
+            place = core.CPUPlace()
+        exe = fluid.Executor(place)
+
+        # load dynamic model
+        [dynamic_inference_program, feed_target_names, fetch_targets] = (
+            fluid.io.load_inference_model(
+                dirname=dynamic_model_save_dir,
+                executor=exe,
+                model_filename="lenet" + INFER_MODEL_SUFFIX,
+                params_filename="lenet" + INFER_PARAMS_SUFFIX))
+        # load static model
+        [static_inference_program, feed_target_names, fetch_targets] = (
+            fluid.io.load_inference_model(
+                dirname=static_model_save_dir,
+                executor=exe,
+                model_filename="lenet" + INFER_MODEL_SUFFIX,
+                params_filename="lenet" + INFER_PARAMS_SUFFIX))
+
+        dynamic_ops = dynamic_inference_program.global_block().ops
+        static_ops = static_inference_program.global_block().ops
+
+        for op in dynamic_ops[:]:
+            if op.type == "flatten2" or 'fake' in op.type:
+                dynamic_ops.remove(op)
+
+        for op in static_ops[:]:
+            if 'fake' in op.type:
+                static_ops.remove(op)
+
+        for i in range(len(dynamic_ops)):
+            if dynamic_ops[i].has_attr("out_threshold"):
+                self.assertTrue(dynamic_ops[i].type == static_ops[i].type)
+                self.assertTrue(dynamic_ops[i].attr("out_threshold") ==
+                                static_ops[i].attr("out_threshold"))
+
+
+class TestSaveQuantizedModel_Warning(unittest.TestCase):
+    def test_warning(self):
+        path = "./dynamic_outscale_infer_model_with_warnings/lenet"
+        imperative_out_scale = ImperativeQuantAware()
+        with fluid.dygraph.guard():
+            lenet = ImperativeLenet()
+
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            imperative_out_scale.save_quantized_model(
+                layer=lenet,
+                path=path,
+                input_spec=[
+                    paddle.static.InputSpec(
+                        shape=[None, 1, 28, 28], dtype='float32')
+                ])
+
+        warning_message = "Warning: No Layer of the model to be saved contains the out_threshold attribute, " \
+                          "so the generated inference model would not contain the out_threshold."
+        num = get_valid_warning_num(warning_message, w)
+        assert num == 1
+
+
 if __name__ == '__main__':
     unittest.main()
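
Usage note (not part of the patch): the sketch below is a minimal example of the workflow this change enables, mirroring TestSaveQuantizedModelFromCheckPoint above -- quantize a dygraph model, checkpoint it with paddle.save, then rebuild and reload it later and export the inference model directly. The model class MyLenet, the paths, and the input shape are illustrative assumptions only.

    import paddle
    import paddle.fluid as fluid
    from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware

    quanter = ImperativeQuantAware(
        weight_quantize_type='abs_max',
        activation_quantize_type='moving_average_abs_max')

    # Training run: quantize the dygraph model and save a plain checkpoint.
    with fluid.dygraph.guard():
        model = MyLenet()  # any dygraph model (assumption)
        quanter.quantize(model)
        # ... train for some steps ...
        paddle.save(model.state_dict(), "checkpoint/model.pdparams")

    # Later run: rebuild, quantize, load the checkpoint and export directly.
    # Because _add_new_parameters now creates the _quant_out_scale parameters
    # outside the forward hooks, save_quantized_model can recover the
    # out_threshold values even though no forward pass has run in this process.
    with fluid.dygraph.guard():
        model = MyLenet()
        quanter.quantize(model)
        model.set_dict(paddle.load("checkpoint/model.pdparams"))
        quanter.save_quantized_model(
            layer=model,
            path="quant_infer_model/model",
            input_spec=[
                paddle.static.InputSpec(
                    shape=[None, 1, 28, 28], dtype='float32')
            ])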