From fd85be80cfc589b75575d48119ddc7490a1ede28 Mon Sep 17 00:00:00 2001 From: cc <52520497+juncaipeng@users.noreply.github.com> Date: Fri, 9 Jul 2021 11:39:59 +0800 Subject: [PATCH] [PTQ ] wrap simulated layers and save the quantized model (#33962) * PTQ save quantized model * Wrap simulated layer * post process the inference model --- .../slim/quantization/imperative/ptq.py | 339 +++++++++++++++++- .../quantization/imperative/ptq_config.py | 9 +- .../slim/quantization/imperative/ptq_hooks.py | 8 +- .../quantization/imperative/ptq_quantizer.py | 23 +- .../quantization/imperative/ptq_registry.py | 56 ++- .../slim/quantization/imperative/qat.py | 8 +- .../slim/quantization/imperative/utils.py | 13 +- .../slim/tests/imperative_test_utils.py | 4 +- .../contrib/slim/tests/test_imperative_ptq.py | 163 ++++----- 9 files changed, 490 insertions(+), 133 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/ptq.py b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq.py index 13ca44d7f2..b85a4b6637 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/ptq.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq.py @@ -14,14 +14,18 @@ import logging import copy +import os import numpy as np import paddle +import paddle.nn.quant.quant_layers as quant_layers from paddle.fluid.log_helper import get_logger +from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX from . import utils from . import ptq_hooks from . import ptq_config +from . import ptq_quantizer from .ptq_registry import PTQRegistry __all__ = ['ImperativePTQ'] @@ -53,7 +57,7 @@ class ImperativePTQ(object): def quantize(self, model, inplace=False): """ - Add hook to the leaf layer to calculate the threshold of inputs and outputs. + Add quant config and hook to the target layer. Args: model(paddle.nn.Layer): The model to be quantized. @@ -70,10 +74,16 @@ class ImperativePTQ(object): for name, layer in new_model.named_sublayers(): if PTQRegistry.is_supported_layer(layer) \ - and utils.is_leaf_layer(layer): + and utils.is_leaf_layer(layer) \ + and not self._is_skip_layer(layer): + + # Add quant config quant_config = copy.deepcopy(self._quant_config) + if PTQRegistry.is_simulated_quant_layer(layer): + quant_config.enable_in_act_quantizer = True layer._quant_config = quant_config + # register hook hook = ptq_hooks.quant_forward_post_hook quant_hook_handle = layer.register_forward_post_hook(hook) quant_config.quant_hook_handle = quant_hook_handle @@ -82,35 +92,330 @@ class ImperativePTQ(object): return new_model - def convert(self, model): + def save_quantized_model(self, model, path, input_spec=None, **config): """ - Process the scales and remove the hooks. + 1. Convert the quantized model + 2. Call jit.save to save the inference model + 3. Load and postprocess the inference model. Args: - model(paddle.nn.Layer): The model to be quantized. + model (Layer): The model to be saved. + path (str): The path prefix to save model. The format is + ``dirname/file_prefix`` or ``file_prefix``. + input_spec (list[InputSpec|Tensor], optional): Describes the input + of the saved model's forward method, which can be described by + InputSpec or example Tensor. If None, all input variables of + the original Layer's forward method would be the inputs of + the saved model. Default None. + **configs (dict, optional): Other save configuration options for + compatibility. We do not recommend using these configurations, + they may be removed in the future. If not necessary, DO NOT use + them. Default None. + The following options are currently supported: + (1) output_spec (list[Tensor]): Selects the output targets of + the saved model. By default, all return variables of original + Layer's forward method are kept as the output of the saved model. + If the provided ``output_spec`` list is not all output variables, + the saved model will be pruned according to the given + ``output_spec`` list. + Returns: - converted_model(paddle.nn.Layer): The converted model. + None """ + assert isinstance(model, paddle.nn.Layer), \ - "The input model must be the instance of paddle.nn.Layer." + "The model must be the instance of paddle.nn.Layer." + + # Convert and save dygraph quantized model + self._convert(model) + + paddle.jit.save(layer=model, path=path, input_spec=input_spec, **config) + + # Load inference program + is_dynamic_mode = False + if paddle.in_dynamic_mode(): + is_dynamic_mode = True + paddle.enable_static() + + place = paddle.CPUPlace() + scope = paddle.static.global_scope() + exe = paddle.static.Executor(place) + + dirname = os.path.dirname(path) + basename = os.path.basename(path) + model_filename = basename + INFER_MODEL_SUFFIX + params_filename = basename + INFER_PARAMS_SUFFIX + + [infer_program, feed_target_names, fetch_targets] = ( + paddle.fluid.io.load_inference_model( + dirname=dirname, + executor=exe, + model_filename=model_filename, + params_filename=params_filename)) + + # Process inference program + self._clean_up(infer_program) + self._gather_input_thresholds(infer_program, scope) + self._remove_scale_op(infer_program) + + # Save final program + paddle.fluid.io.save_inference_model( + dirname=dirname, + feeded_var_names=feed_target_names, + target_vars=fetch_targets, + executor=exe, + main_program=infer_program.clone(), + model_filename=model_filename, + params_filename=params_filename) + + if is_dynamic_mode: + paddle.disable_static() + + def _convert(self, model): + """ + Convert the quantized model. + + Args: + model(paddle.nn.Layer): The quantized model. + inplace(bool): Whether apply conversion to the input model. + Default: False. + Returns: + None + """ for name, sub_layer in model.named_sublayers(): - if PTQRegistry.is_supported_layer(sub_layer) \ - and utils.is_leaf_layer(sub_layer): + if self._is_quant_layer(sub_layer): + sub_layer._quant_config.quant_hook_handle.remove() - assert hasattr(sub_layer, "_quant_config") + self._cal_thresholds(model) + + for name, sub_layer in model.named_sublayers(): + if self._is_quant_layer(sub_layer): + self._save_output_thresholds(sub_layer, sub_layer._quant_config) + + self._wrap_simulated_layers(model) + + def _cal_thresholds(self, model): + """ + Calculate the thresholds of inputs and outputs. + + Args: + model(paddle.nn.Layer): The quantized model. + Returns: + None + """ + assert isinstance(model, paddle.nn.Layer), \ + "The input model must be the instance of paddle.nn.Layer." + + for name, sub_layer in model.named_sublayers(): + if self._is_quant_layer(sub_layer): quant_config = sub_layer._quant_config - quant_config.quant_hook_handle.remove() - quant_config.in_act_quantizer.cal_thresholds() + if quant_config.enable_in_act_quantizer: + quant_config.in_act_quantizer.cal_thresholds() quant_config.out_act_quantizer.cal_thresholds() - # get weight thresholds - if isinstance(sub_layer, tuple(utils.fake_quant_input_layers)): + if PTQRegistry.is_simulated_quant_layer(sub_layer): weights = (sub_layer.weight, ) quant_config.wt_quantizer.sample_data(sub_layer, weights) + quant_config.wt_quantizer.cal_thresholds() + + def _save_output_thresholds(self, sub_layer, quant_config): + """ + Save the output thresholds to the layer. + + Args: + sub_layer(paddle.nn.Layer): The quantized layer. + quant_config(PTQConfig): the quant config for the layer. + Returns: + None + """ + assert isinstance(sub_layer, paddle.nn.Layer), \ + "The input model must be the instance of paddle.nn.Layer." + + layer_info = PTQRegistry.layer_info(sub_layer) + + output_names = layer_info.output_names + output_thresholds = quant_config.out_act_quantizer.thresholds + assert len(output_names) == 1 + assert len(output_thresholds) == 1 + save_name = output_names[0] + str(0) + "_threshold" + sub_layer._set_op_attrs({save_name: output_thresholds[0]}) + sub_layer._set_op_attrs({"out_threshold": output_thresholds[0]}) + + def _wrap_simulated_layers(self, model): + """ + Replace conv2d and linear with the quantized layers, and save + thresholds into the fake layers. + Args: + model(paddle.nn.Layer): The model to be quantized. + Returns: + None + """ + assert isinstance(model, paddle.nn.Layer), \ + "The input model must be the instance of paddle.nn.Layer." + + for name, sub_layer in model.named_sublayers(): + if self._is_quant_layer(sub_layer) \ + and PTQRegistry.is_simulated_quant_layer(sub_layer): + + quant_config = sub_layer._quant_config + assert quant_config.enable_in_act_quantizer == True + wt_quantizer = quant_config.wt_quantizer + in_act_quantizer = quant_config.in_act_quantizer + + # create layer + quant_layer_name = None + for key, value in utils.layer_name_map.items(): + if isinstance(sub_layer, value): + quant_layer_name = 'Quantized' + key + break + assert quant_layer_name is not None + + if isinstance(wt_quantizer, ptq_quantizer.AbsmaxQuantizer): + weight_quantize_type = "abs_max" + else: + weight_quantize_type = "channel_wise_abs_max" + kwargs = { + "weight_quantize_type": weight_quantize_type, + "activation_quantize_type": "moving_average_abs_max", + "weight_bits": wt_quantizer.quant_bits, + "activation_bits": in_act_quantizer.quant_bits, + } + + quant_layer = quant_layers.__dict__[quant_layer_name](sub_layer, + **kwargs) + + # save the input thresholds + assert hasattr(quant_layer, "_fake_quant_input") + assert hasattr(quant_layer._fake_quant_input, "_scale") + assert len(in_act_quantizer.thresholds) == 1 + input_threshold = np.array( + [in_act_quantizer.thresholds[0]], dtype=np.float32) + quant_layer._fake_quant_input._scale.set_value(input_threshold) + + assert hasattr(quant_layer, "_fake_quant_weight") + assert hasattr(quant_layer._fake_quant_weight, "_scale") + assert len(wt_quantizer.thresholds) == 1 + weight_threshold = wt_quantizer.thresholds[0] + if isinstance(weight_threshold, list): + weight_threshold = np.array( + weight_threshold, dtype=np.float32) + else: + weight_threshold = np.array( + [weight_threshold], dtype=np.float32) + quant_layer._fake_quant_weight._scale.set_value( + weight_threshold) + + # save the output thresholds + self._save_output_thresholds(quant_layer, quant_config) + + # replace the layer + parent_layer, sub_name = \ + utils.find_parent_layer_and_sub_name(model, name) + setattr(parent_layer, sub_name, quant_layer) + + def _gather_input_thresholds(self, program, scope): + """ + Get and save input thresholds from the front ops. + + Args: + program(Program): the input infer program. + scope(Scope): the corresponding scope for the program. + Returns: + None + """ + for op in utils.program_all_ops(program): + for in_var_name in utils._get_op_input_var_names(op): + previous_op = utils.find_previous_op(op.block, in_var_name) + if previous_op is None: + continue + + if "quantize_dequantize" in previous_op.type or \ + previous_op.type == "moving_average_abs_max_scale": + attr_name = previous_op.output('OutScale')[0] + in_threshold = utils.load_variable_data(scope, attr_name) + in_threshold = utils.fp_numpy_to_naive(in_threshold) + argname, index = utils._get_input_name_index(op, + in_var_name) + op._set_attr(argname + str(index) + "_threshold", + in_threshold) + else: + for out_var_name in utils._get_op_output_var_names( + previous_op): + if out_var_name != in_var_name: + continue + argname, index = utils._get_output_name_index( + previous_op, out_var_name) + attr_name = argname + str(index) + "_threshold" + if not previous_op.has_attr(attr_name): + continue + threshold = previous_op.attr(attr_name) + + argname, index = utils._get_input_name_index( + op, in_var_name) + attr_name = argname + str(index) + "_threshold" + op._set_attr(attr_name, threshold) + + def _clean_up(self, program): + """ + Remove useless thresholds which are added in jit.save. + + Args: + program(Program): the input infer program. + Returns: + None + """ + + def _helper(op, next_op, old_attr_name, new_attr_name): + if op.has_attr(old_attr_name) and next_op.has_attr(old_attr_name) \ + and op.attr(old_attr_name) == next_op.attr(old_attr_name): + threshold = op.attr(old_attr_name) + op._remove_attr(old_attr_name) + next_op._remove_attr(old_attr_name) + next_op._set_attr(new_attr_name, threshold) + + for op in utils.program_all_ops(program): + if "quantize_dequantize" in op.type: + # remove the thresholds in fake ops + for attr_name in op.attr_names: + if "_threshold" in attr_name: + op._remove_attr(attr_name) + elif op.type in ["conv2d", "matmul"]: + # change the thresholds in conv2d/matmul + eleadd + arg_name = "Output" if op.type == "conv2d" else "Out" + out_var_name = op.output(arg_name)[0] + next_ops = utils.find_next_ops(op.block, out_var_name) + if len(next_ops) > 1 or next_ops[0].type != "elementwise_add": + continue + next_op = next_ops[0] + + argname, index = utils._get_output_name_index(op, out_var_name) + old_attr_name = argname + str(index) + "_threshold" + + argname, index = utils._get_output_name_index( + next_op, next_op.output("Out")[0]) + new_attr_name = argname + str(index) + "_threshold" + + _helper(op, next_op, old_attr_name, new_attr_name) + _helper(op, next_op, "out_threshold", "out_threshold") + + def _remove_scale_op(self, program): + """ + Remove the moving_average_abs_max_scale op. + """ + for op in utils.program_all_ops(program): + if op.type == "moving_average_abs_max_scale": + in_var_name = op.input("X")[0] + out_var_name = op.output("Out")[0] + next_ops = utils.find_next_ops(op.block, out_var_name) + for next_op in next_ops: + next_op._rename_input(out_var_name, in_var_name) - # TODO (jc): - # save input activation threshold and quant bits + @staticmethod + def _is_skip_layer(layer): + return hasattr(layer, "skip_quant") and layer.skip_quant == True - return model + @staticmethod + def _is_quant_layer(layer): + return hasattr(layer, "_quant_config") diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_config.py b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_config.py index 4db311567a..1d089b3218 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_config.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_config.py @@ -39,9 +39,8 @@ class PTQConfig(object): It should be the instance of BaseQuantizer. """ super(PTQConfig, self).__init__() - - assert isinstance(activation_quantizer, BaseQuantizer) - assert isinstance(weight_quantizer, BaseQuantizer) + assert isinstance(activation_quantizer, tuple(SUPPORT_ACT_QUANTIZERS)) + assert isinstance(weight_quantizer, tuple(SUPPORT_WT_QUANTIZERS)) self.in_act_quantizer = copy.deepcopy(activation_quantizer) self.out_act_quantizer = copy.deepcopy(activation_quantizer) @@ -49,5 +48,9 @@ class PTQConfig(object): self.quant_hook_handle = None + # In order to wrap simulated layers, use in_act_quantizer + # to calculate the input thresholds for conv2d, linear and etc. + self.enable_in_act_quantizer = False + default_ptq_config = PTQConfig(AbsmaxQuantizer(), AbsmaxQuantizer()) diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_hooks.py b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_hooks.py index 82a277ad28..41c9b07195 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_hooks.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_hooks.py @@ -16,6 +16,7 @@ import paddle import math import numpy as np from . import ptq_config +from .ptq_registry import PTQRegistry def quant_forward_post_hook(layer, inputs, outputs): @@ -24,5 +25,8 @@ def quant_forward_post_hook(layer, inputs, outputs): """ assert hasattr(layer, '_quant_config'), \ "The layer should have _quant_config attr" - layer._quant_config.in_act_quantizer.sample_data(layer, inputs) - layer._quant_config.out_act_quantizer.sample_data(layer, (outputs, )) + + qc = layer._quant_config + if qc.enable_in_act_quantizer: + qc.in_act_quantizer.sample_data(layer, inputs) + qc.out_act_quantizer.sample_data(layer, (outputs, )) diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_quantizer.py b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_quantizer.py index 9999de6bd0..63b3578871 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_quantizer.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_quantizer.py @@ -24,11 +24,9 @@ from . import utils from ..cal_kl_threshold import cal_kl_threshold __all__ = [ - 'BaseQuantizer', - 'AbsmaxQuantizer', - 'PerChannelAbsmaxQuantizer', - 'KLQuantizer', - 'HistQuantizer', + 'BaseQuantizer', 'AbsmaxQuantizer', 'PerChannelAbsmaxQuantizer', + 'KLQuantizer', 'HistQuantizer', 'SUPPORT_ACT_QUANTIZERS', + 'SUPPORT_WT_QUANTIZERS' ] @@ -110,6 +108,7 @@ class BaseQuantizer(object): self.quant_bits = quant_bits + self.abs_max_vals = [] self.thresholds = [] @abc.abstractmethod @@ -133,10 +132,10 @@ class AbsmaxQuantizer(BaseQuantizer): assert isinstance(tensors, tuple) abs_max_vals = [abs_max_value(t) for t in tensors] - self.thresholds = merge_max_value(self.thresholds, abs_max_vals) + self.abs_max_vals = merge_max_value(self.abs_max_vals, abs_max_vals) def cal_thresholds(self): - pass + self.thresholds = self.abs_max_vals class PerChannelAbsmaxQuantizer(BaseQuantizer): @@ -164,10 +163,11 @@ class PerChannelAbsmaxQuantizer(BaseQuantizer): ] abs_max_vals_list.append(abs_max_vals) - self.thresholds = merge_max_value(self.thresholds, abs_max_vals_list) + self.abs_max_vals = merge_max_value(self.abs_max_vals, + abs_max_vals_list) def cal_thresholds(self): - pass + self.thresholds = self.abs_max_vals @six.add_metaclass(abc.ABCMeta) @@ -180,7 +180,6 @@ class BaseHistQuantizer(BaseQuantizer): self.bins = bins self.upsample_bins = upsample_bins - self.abs_max_vals = [] self.hists = [] def sample_data(self, layer, tensors): @@ -262,3 +261,7 @@ class KLQuantizer(BaseHistQuantizer): bin_width = abs_max_val / hist.shape[0] threshold = cal_kl_threshold(hist, bin_width, self.quant_bits) self.thresholds.append(threshold) + + +SUPPORT_ACT_QUANTIZERS = [AbsmaxQuantizer, HistQuantizer, KLQuantizer] +SUPPORT_WT_QUANTIZERS = [AbsmaxQuantizer, PerChannelAbsmaxQuantizer] diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_registry.py b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_registry.py index 973d66303e..a6b8033bc7 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_registry.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq_registry.py @@ -47,12 +47,22 @@ PTQ_LAYERS_INFO = [ LayerInfo(paddle.nn.quant.add, ['X', 'Y'], [], ['Out']), ] +QUANT_LAYERS_INFO = [ + LayerInfo(paddle.nn.quant.quant_layers.QuantizedConv2D, ['Input'], + ['Filter'], ['Output']), + LayerInfo(paddle.nn.quant.quant_layers.QuantizedLinear, ['X'], ['Y'], + ['Out']), +] + +SIMULATED_LAYERS = [paddle.nn.Conv2D, paddle.nn.Linear] + class PTQRegistry(object): """ Register the supported layers for PTQ and provide layers info. """ supported_layers_map = {} + registered_layers_map = {} is_inited = False def __init__(self): @@ -63,24 +73,62 @@ class PTQRegistry(object): if not cls.is_inited: for layer_info in PTQ_LAYERS_INFO: cls.supported_layers_map[layer_info.layer] = layer_info + + all_layers_info = PTQ_LAYERS_INFO + QUANT_LAYERS_INFO + for layer_info in all_layers_info: + cls.registered_layers_map[layer_info.layer] = layer_info cls.is_inited = True @classmethod def is_supported_layer(cls, layer): """ Analyze whether the layer supports quantization. + Args: + layer(Layer): The input layer can be a python class or an instance. + Returns: + flag(bool): Whther the layer is supported. """ cls._init() return layer in cls.supported_layers_map or \ isinstance(layer, tuple(cls.supported_layers_map.keys())) + @classmethod + def is_registered_layer(cls, layer): + """ + Analyze whether the layer is register layer_info. + Args: + layer(Layer): The input layer can be a python class or an instance. + Returns: + flag(bool): Wether the layer is register layer_info. + """ + cls._init() + return layer in cls.registered_layers_map or \ + isinstance(layer, tuple(cls.registered_layers_map.keys())) + + @classmethod + def is_simulated_quant_layer(cls, layer): + """ + Analyze whether the layer is simulated quant layer. + Args: + layer(Layer): The input layer can be a python class or an instance. + Returns: + flag(bool): Whther the layer is supported. + """ + return layer in SIMULATED_LAYERS or \ + isinstance(layer, tuple(SIMULATED_LAYERS)) + + @classmethod def layer_info(cls, layer): """ - Get the infomation for the supported layer. + Get the infomation for the layer. + Args: + layer(Layer): The input layer can be a python class or an instance. + Returns: + layer_info(LayerInfo): The layer info of the input layer. """ - assert cls.is_supported_layer( - layer), "The input layer is not supported." + assert cls.is_registered_layer(layer), \ + "The input layer is not register." - for layer_key, layer_info in cls.supported_layers_map.items(): + for layer_key, layer_info in cls.registered_layers_map.items(): if layer == layer_key or isinstance(layer, layer_key): return layer_info diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py index 3b4f9a7574..b8c0e47e9b 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py @@ -379,12 +379,12 @@ class ImperativeQuantizeOutputs(object): setattr(parent_layer, sub_name, cur_quant_layer) - def save_quantized_model(self, layer, path, input_spec=None, **config): + def save_quantized_model(self, model, path, input_spec=None, **config): """ Save the quantized model for the inference. Args: - layer (Layer): The Layer to be saved. + model (Layer): The model to be saved. path (str): The path prefix to save model. The format is ``dirname/file_prefix`` or ``file_prefix``. input_spec (list[InputSpec|Tensor], optional): Describes the input @@ -407,10 +407,10 @@ class ImperativeQuantizeOutputs(object): Returns: None """ - assert isinstance(layer, dygraph.Layer), \ + assert isinstance(model, dygraph.Layer), \ "The model must be the instance of dygraph.Layer." - paddle.jit.save(layer=layer, path=path, input_spec=input_spec, **config) + paddle.jit.save(layer=model, path=path, input_spec=input_spec, **config) is_dynamic_mode = False if paddle.in_dynamic_mode(): diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py b/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py index cae26a6dbd..a9d52c5a87 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py @@ -69,7 +69,7 @@ fake_quant_wrap_layers = [ ] # The weight format of these layers is Cin * Cout * H * W -spec_channel_axis_layers = [paddle.nn.Conv2D, paddle.nn.Conv2DTranspose] +spec_channel_axis_layers = [paddle.nn.Conv2DTranspose, paddle.nn.Linear] weight_op_types = [ "conv2d", "depthwise_conv2d", "matmul", "conv2d_transpose", @@ -139,6 +139,17 @@ def find_parent_layer_and_sub_name(model, name): return parent_layer, sub_name +def program_all_ops(program): + """ + Return all ops for the input program. + """ + all_ops = [] + for block in program.blocks: + for op in block.ops: + all_ops.append(op) + return all_ops + + def is_leaf_layer(layer): """ Whether the layer is leaf layer. diff --git a/python/paddle/fluid/contrib/slim/tests/imperative_test_utils.py b/python/paddle/fluid/contrib/slim/tests/imperative_test_utils.py index cc26f6a88f..5c91f01d0b 100644 --- a/python/paddle/fluid/contrib/slim/tests/imperative_test_utils.py +++ b/python/paddle/fluid/contrib/slim/tests/imperative_test_utils.py @@ -128,9 +128,11 @@ class ImperativeLenet(fluid.dygraph.Layer): bias_attr=fc_b3_attr), Softmax()) self.add = paddle.nn.quant.add() + self.quant_stub = paddle.nn.quant.QuantStub() def forward(self, inputs): - x = self.features(inputs) + x = self.quant_stub(inputs) + x = self.features(x) x = fluid.layers.flatten(x, 1) x = self.add(x, paddle.to_tensor(0.0)) # For CI diff --git a/python/paddle/fluid/contrib/slim/tests/test_imperative_ptq.py b/python/paddle/fluid/contrib/slim/tests/test_imperative_ptq.py index 236e4a823d..24ae75456a 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_imperative_ptq.py +++ b/python/paddle/fluid/contrib/slim/tests/test_imperative_ptq.py @@ -20,6 +20,7 @@ import random import shutil import time import unittest +import copy import logging import paddle @@ -59,7 +60,8 @@ class TestImperativePTQ(unittest.TestCase): @classmethod def tearDownClass(cls): try: - shutil.rmtree(cls.root_path) + pass + # shutil.rmtree(cls.root_path) except Exception as e: print("Failed to delete {} due to {}".format(cls.root_path, str(e))) @@ -84,8 +86,9 @@ class TestImperativePTQ(unittest.TestCase): self.batch_num = 10 self.batch_size = 10 - self.eval_acc_top1 = 0.99 + self.eval_acc_top1 = 0.95 + # the input, output and weight thresholds of quantized op self.gt_thresholds = { 'conv2d_0': [[1.0], [0.37673383951187134], [0.10933732241392136]], 'batch_norm2d_0': [[0.37673383951187134], [0.44249194860458374]], @@ -96,36 +99,6 @@ class TestImperativePTQ(unittest.TestCase): 'add_0': [[1.7058950662612915, 0.0], [1.7058950662612915]], } - def model_train(self, model, train_reader, max_step=-1): - model.train() - adam = paddle.optimizer.Adam( - learning_rate=0.001, parameters=model.parameters()) - - for batch_id, data in enumerate(train_reader()): - x_data = np.array([x[0].reshape(1, 28, 28) - for x in data]).astype('float32') - y_data = np.array( - [x[1] for x in data]).astype('int64').reshape(-1, 1) - - img = paddle.to_tensor(x_data) - label = paddle.to_tensor(y_data) - - out = model(img) - acc = fluid.layers.accuracy(out, label) - loss = fluid.layers.cross_entropy(out, label) - avg_loss = fluid.layers.mean(loss) - avg_loss.backward() - - adam.minimize(avg_loss) - model.clear_gradients() - - if batch_id % 100 == 0: - _logger.info("Train | step {}: loss = {:}, acc= {:}".format( - batch_id, avg_loss.numpy(), acc.numpy())) - - if max_step > 0 and batch_id > max_step: # For shortening CI time - break - def model_test(self, model, batch_num=-1, batch_size=8): model.eval() @@ -145,9 +118,9 @@ class TestImperativePTQ(unittest.TestCase): out = model(img) acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1) acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5) + eval_acc_top1_list.append(float(acc_top1.numpy())) - if batch_id % 100 == 0: - eval_acc_top1_list.append(float(acc_top1.numpy())) + if batch_id % 50 == 0: _logger.info("Test | At step {}: acc1 = {:}, acc5 = {:}".format( batch_id, acc_top1.numpy(), acc_top5.numpy())) @@ -158,80 +131,88 @@ class TestImperativePTQ(unittest.TestCase): return eval_acc_top1 - def check_thresholds(self, model): - check_num = 0 - for name, layer in model.named_sublayers(): - layer_name = layer.full_name() - if layer_name in self.gt_thresholds: - ref_val = self.gt_thresholds[layer_name] - assert hasattr(layer, '_quant_config') - - quant_config = layer._quant_config - in_val = quant_config.in_act_quantizer.thresholds - out_val = quant_config.out_act_quantizer.thresholds - wt_val = quant_config.wt_quantizer.thresholds - check_num += 1 - - self.assertTrue( - np.allclose( - ref_val[0], in_val, atol=1e-3), - "%s | The thresholds(%s) is different " - "from the ground truth(%s)." % - (layer_name, str(in_val), str(ref_val[0]))) - self.assertTrue( - np.allclose( - ref_val[1], out_val, atol=1e-3), - "%s | The thresholds(%s) is different " - "from the ground truth(%s)." % - (layer_name, str(out_val), str(ref_val[1]))) - if len(ref_val) > 2 and ref_val[2] != []: - self.assertTrue( - np.allclose( - ref_val[2], wt_val, atol=1e-3), - "%s | The thresholds(%s) is different " - "from the ground truth(%s)." % - (layer_name, str(wt_val), str(ref_val[2]))) - - self.assertTrue(check_num == len(self.gt_thresholds)) + def program_test(self, program_path, batch_num=-1, batch_size=8): + exe = paddle.static.Executor(paddle.CPUPlace()) + [inference_program, feed_target_names, fetch_targets] = ( + paddle.static.load_inference_model(program_path, exe)) + + test_reader = paddle.batch( + paddle.dataset.mnist.test(), batch_size=batch_size) + + top1_correct_num = 0. + total_num = 0. + for batch_id, data in enumerate(test_reader()): + img = np.array([x[0].reshape(1, 28, 28) + for x in data]).astype('float32') + label = np.array([x[1] for x in data]).astype('int64') + + feed = {feed_target_names[0]: img} + results = exe.run(inference_program, + feed=feed, + fetch_list=fetch_targets) + + pred = np.argmax(results[0], axis=1) + top1_correct_num += np.sum(np.equal(pred, label)) + total_num += len(img) + + if total_num % 50 == 49: + _logger.info("Test | Test num {}: acc1 = {:}".format( + total_num, top1_correct_num / total_num)) + + if batch_num > 0 and batch_id + 1 >= batch_num: + break + return top1_correct_num / total_num def test_ptq(self): start_time = time.time() self.set_vars() + # Load model params_path = self.download_model(self.lenet_url, self.lenet_md5, "lenet") params_path += "/lenet_pretrained/lenet.pdparams" - with fluid.dygraph.guard(): - model = ImperativeLenet() - model_state_dict = paddle.load(params_path) - model.set_state_dict(model_state_dict) - - quant_model = self.ptq.quantize(model) - - acc_top1 = self.model_test(quant_model, self.batch_num, - self.batch_size) - print('acc_top1: %s' % acc_top1) - self.assertTrue( - acc_top1 > self.eval_acc_top1, - msg="The test acc {%f} is less than {%f}." % - (acc_top1, self.eval_acc_top1)) - - final_model = self.ptq.convert(quant_model) + model = ImperativeLenet() + model_state_dict = paddle.load(params_path) + model.set_state_dict(model_state_dict) - self.check_thresholds(final_model) + # Quantize, calibrate and save + quant_model = self.ptq.quantize(model) + before_acc_top1 = self.model_test(quant_model, self.batch_num, + self.batch_size) input_spec = [ paddle.static.InputSpec( shape=[None, 1, 28, 28], dtype='float32') ] - paddle.jit.save( - layer=final_model, path=self.save_path, input_spec=input_spec) + self.ptq.save_quantized_model( + model=quant_model, path=self.save_path, input_spec=input_spec) print('Quantized model saved in {%s}' % self.save_path) + after_acc_top1 = self.model_test(quant_model, self.batch_num, + self.batch_size) + + paddle.enable_static() + infer_acc_top1 = self.program_test(self.save_path, self.batch_num, + self.batch_size) + paddle.disable_static() + + # Check + print('Before converted acc_top1: %s' % before_acc_top1) + print('After converted acc_top1: %s' % after_acc_top1) + print('Infer acc_top1: %s' % infer_acc_top1) + + self.assertTrue( + after_acc_top1 >= self.eval_acc_top1, + msg="The test acc {%f} is less than {%f}." % + (after_acc_top1, self.eval_acc_top1)) + self.assertTrue( + infer_acc_top1 >= after_acc_top1, + msg='The acc is lower after converting model.') + end_time = time.time() - print("total time: %ss" % (end_time - start_time)) + print("total time: %ss \n" % (end_time - start_time)) class TestImperativePTQHist(TestImperativePTQ): @@ -241,7 +222,7 @@ class TestImperativePTQHist(TestImperativePTQ): self.batch_num = 10 self.batch_size = 10 - self.eval_acc_top1 = 0.99 + self.eval_acc_top1 = 0.98 self.gt_thresholds = { 'conv2d_0': @@ -262,7 +243,7 @@ class TestImperativePTQKL(TestImperativePTQ): self.batch_num = 10 self.batch_size = 10 - self.eval_acc_top1 = 0.99 + self.eval_acc_top1 = 1.0 conv2d_1_wt_thresholds = [ 0.18116560578346252, 0.17079241573810577, 0.1702047884464264, -- GitLab