From 8bbae468388411c3f9d9e925e421090d7c83e14a Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Fri, 6 Jan 2023 17:52:56 +0800 Subject: [PATCH] Add observer attribute in qdq node & Add quant config for different backends. (#46887) --- paddle/fluid/operators/quantize_linear_op.cc | 6 + paddle/fluid/operators/quantize_linear_op.h | 39 ++- .../passes/auto_parallel_quantization.py | 10 +- .../post_training_quantization.py | 145 +++++--- .../static/quantization/quant_config.py | 327 ++++++++++++++++++ .../static/quantization/quantization_pass.py | 107 ++++-- ..._post_training_quantization_mobilenetv1.py | 130 +++++++ .../test_post_training_quantization_while.py | 1 - python/paddle/static/quantization/utils.py | 224 +----------- 9 files changed, 681 insertions(+), 308 deletions(-) create mode 100644 python/paddle/static/quantization/quant_config.py diff --git a/paddle/fluid/operators/quantize_linear_op.cc b/paddle/fluid/operators/quantize_linear_op.cc index f143bc3a50..cc8e0bccf6 100644 --- a/paddle/fluid/operators/quantize_linear_op.cc +++ b/paddle/fluid/operators/quantize_linear_op.cc @@ -200,6 +200,12 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker { "(bool, default false) Set to true for inference only, false " "for training. Some layers may run faster when this is true.") .SetDefault(true); + AddAttr( + "only_observer", + "(bool, default false) Whether to only observer or not. If " + "only_observer=false, it will calculate fake quant or dequant output. " + "If only_observer=true, it will only calibrate scale information.") + .SetDefault(false); AddComment(R"DOC( The scale of QuantizeLinear operator is a vector. In detail, each channel of the input X has a scale value. diff --git a/paddle/fluid/operators/quantize_linear_op.h b/paddle/fluid/operators/quantize_linear_op.h index 3461b9de0a..4769a96ca0 100644 --- a/paddle/fluid/operators/quantize_linear_op.h +++ b/paddle/fluid/operators/quantize_linear_op.h @@ -61,6 +61,7 @@ class QuantizeLinearKernel : public framework::OpKernel { int bin_cnt = std::pow(2, bit_length - 1) - 1; int quant_axis = context.Attr("quant_axis"); bool is_test = context.Attr("is_test"); + bool only_observer = context.Attr("only_observer"); auto& dev_ctx = context.template device_context(); if (quant_axis < 0) { @@ -91,11 +92,19 @@ class QuantizeLinearKernel : public framework::OpKernel { out_state, out_accum, out_scale); - ClipAndFakeQuantFunctor()( - dev_ctx, *in, *out_scale, bin_cnt, round_type, out); + if (only_observer) { + framework::TensorCopy(*in, context.GetPlace(), dev_ctx, out); + } else { + ClipAndFakeQuantFunctor()( + dev_ctx, *in, *out_scale, bin_cnt, round_type, out); + } } else { - ClipAndFakeQuantFunctor()( - dev_ctx, *in, *in_scale, bin_cnt, round_type, out); + if (only_observer) { + framework::TensorCopy(*in, context.GetPlace(), dev_ctx, out); + } else { + ClipAndFakeQuantFunctor()( + dev_ctx, *in, *in_scale, bin_cnt, round_type, out); + } } } else { if (!is_test) { @@ -103,11 +112,19 @@ class QuantizeLinearKernel : public framework::OpKernel { T* out_scale_data = out_scale->mutable_data(context.GetPlace()); FindChannelAbsMaxFunctor()( dev_ctx, *in, quant_axis, out_scale_data); - ChannelClipAndFakeQuantFunctor()( - dev_ctx, *in, *out_scale, bin_cnt, round_type, quant_axis, out); + if (only_observer) { + framework::TensorCopy(*in, context.GetPlace(), dev_ctx, out); + } else { + ChannelClipAndFakeQuantFunctor()( + dev_ctx, *in, *out_scale, bin_cnt, round_type, quant_axis, out); + } } else { - 
ChannelClipAndFakeQuantFunctor()( - dev_ctx, *in, *in_scale, bin_cnt, round_type, quant_axis, out); + if (only_observer) { + framework::TensorCopy(*in, context.GetPlace(), dev_ctx, out); + } else { + ChannelClipAndFakeQuantFunctor()( + dev_ctx, *in, *in_scale, bin_cnt, round_type, quant_axis, out); + } } } } @@ -132,6 +149,12 @@ class DeQuantizeLinearKernel : public framework::OpKernel { int bit_length = context.Attr("bit_length"); auto quant_axis = context.Attr("quant_axis"); dev_ctx.template Alloc(out, out->numel() * sizeof(D)); + bool only_observer = context.Attr("only_observer"); + + if (only_observer) { + framework::TensorCopy(*in, context.GetPlace(), dev_ctx, out); + return; + } if (quant_axis < 0) { float max_range = (std::pow(2, bit_length - 1) - 1); diff --git a/python/paddle/distributed/passes/auto_parallel_quantization.py b/python/paddle/distributed/passes/auto_parallel_quantization.py index 8f75c90880..9019f3b0cc 100644 --- a/python/paddle/distributed/passes/auto_parallel_quantization.py +++ b/python/paddle/distributed/passes/auto_parallel_quantization.py @@ -24,15 +24,19 @@ from paddle.static.quantization import ( AddQuantDequantPassV2, OutScaleForTrainingPass, QuantizationTransformPassV2, - utils, + quant_config, ) from ..auto_parallel.converter import Converter from ..auto_parallel.dist_attribute import OperatorDistAttr, TensorDistAttr from .pass_base import PassBase, register_pass -TRANSFORM_PASS_OP_TYPES = utils._weight_supported_quantizable_op_type -QUANT_DEQUANT_PASS_OP_TYPES = utils._act_supported_quantizable_op_type +TRANSFORM_PASS_OP_TYPES = list( + quant_config.SUPPORT_WEIGHT_QUANTIZATION_OP_DICT.keys() +) +QUANT_DEQUANT_PASS_OP_TYPES = list( + quant_config.SUPPORT_ACT_QUANTIZATION_OP_DICT.keys() +) def _node_id(node): diff --git a/python/paddle/static/quantization/post_training_quantization.py b/python/paddle/static/quantization/post_training_quantization.py index f11ae10948..024c227bca 100644 --- a/python/paddle/static/quantization/post_training_quantization.py +++ b/python/paddle/static/quantization/post_training_quantization.py @@ -35,7 +35,15 @@ from ..log_helper import get_logger from . import utils from .adaround import run_adaround from .cal_kl_threshold import cal_kl_threshold +from .quant_config import ( + SUPPORT_QUANTIZATION_OP_DICT, + ARMCPUQuantizer, + BaseQuantizer, + MKLDNNQuantizer, + TensorRTQuantizer, +) from .quantization_pass import ( + AddQuantDequantForInferencePass, AddQuantDequantPass, AddQuantDequantPassV2, QuantizationFreezePass, @@ -127,7 +135,7 @@ class PostTrainingQuantization: batch_nums=None, algo="KL", hist_percent=0.99999, - quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"], + quantizable_op_type=[], round_type='round', learning_rate=0.001, is_full_quantize=False, @@ -145,6 +153,7 @@ class PostTrainingQuantization: cache_dir=None, scale_dict=None, return_graph=False, + deploy_backend=None, ): ''' Constructor. @@ -190,8 +199,9 @@ class PostTrainingQuantization: hist_percent(float, optional): The threshold of algo 'hist' for activations. Default is 0.99999. quantizable_op_type(list[str], optional): List the type of ops - that will be quantized. Default is ["conv2d", "depthwise_conv2d", - "mul"]. + that will be quantized. Default is []. If quantizable_op_type is [], + it will use the default quantization op type of the qunat config in + the current deploy_backend. round_type(str, optional): The method of converting the quantized weights value float->int. Currently supports ['round', 'adaround'] methods. 
Default is `round`, which is rounding nearest to the integer. @@ -199,8 +209,8 @@ class PostTrainingQuantization: learning_rate(float, optional): The learning rate of adaround method. is_full_quantized(bool, optional): If set is_full_quantized as True, apply quantization to all supported quantizable op type. If set - is_full_quantized as False, only apply quantization to the op type - according to the input quantizable_op_type. + is_full_quantized as False, it will apply quantization to the op type + according to the input quantizable_op_type or quant config of deploy_backend. bias_correction(bool, optional): If set as True, use the bias correction method of https://arxiv.org/abs/1810.05723. Default is False. activation_bits(int): quantization bit number for activation. @@ -234,6 +244,9 @@ class PostTrainingQuantization: quantization. Default False. is_use_cache_file(bool, optional): This param is deprecated. cache_dir(str, optional): This param is deprecated. + deploy_backend(str, optional): Deploy backend, it can be None, `TensorRT`, + `MKLDNN`, `ARM`. And it will extend the new backend. Default is None, + which means to use the default general quantization configuration. Returns: None @@ -294,13 +307,6 @@ class PostTrainingQuantization: self._round_type = round_type self._learning_rate = learning_rate self._dynamic_quantize_op_type = ['lstm'] - self._support_quantize_op_type = list( - set( - utils._weight_supported_quantizable_op_type - + utils._act_supported_quantizable_op_type - + self._dynamic_quantize_op_type - ) - ) # Check inputs assert executor is not None, "The executor cannot be None." @@ -355,15 +361,6 @@ class PostTrainingQuantization: self._onnx_format = onnx_format self._clip_extra = True if self._onnx_format else False self._skip_tensor_list = skip_tensor_list - self._is_full_quantize = is_full_quantize - if is_full_quantize: - self._quantizable_op_type = self._support_quantize_op_type - else: - self._quantizable_op_type = quantizable_op_type - for op_type in self._quantizable_op_type: - assert op_type in self._support_quantize_op_type, ( - op_type + " is not supported for quantization." - ) self._optimize_model = optimize_model # Define variables @@ -373,7 +370,6 @@ class PostTrainingQuantization: self._fetch_list = None self._data_loader = data_loader - self._out_scale_op_list = utils.QUANT_SUPPORTED_OP_TYPE_LIST self._quantized_weight_var_name = set() self._quantized_act_var_name = set() self._weight_op_pairs = {} @@ -403,6 +399,43 @@ class PostTrainingQuantization: if self._program is not None: self.FLAG = True + self._is_full_quantize = is_full_quantize + if is_full_quantize: + quantizable_op_type = list(SUPPORT_QUANTIZATION_OP_DICT.keys()) + elif quantizable_op_type: + for op_type in quantizable_op_type: + assert op_type in list(SUPPORT_QUANTIZATION_OP_DICT.keys()), ( + op_type + " is not supported for quantization." + ) + assert ( + activation_bits == weight_bits + ), "activation_bits and weight_bits must be the same, other cases are not supported." 
+ support_deploy_backend = [None, "tensorrt", "mkldnn", "arm"] + if not deploy_backend: + self.quant_config = BaseQuantizer( + quantizable_op_type=quantizable_op_type, + quant_bits=weight_bits, + ) + elif deploy_backend.lower() == "tensorrt": + self.quant_config = TensorRTQuantizer( + quantizable_op_type=quantizable_op_type, + quant_bits=weight_bits, + ) + elif deploy_backend.lower() == "mkldnn": + self.quant_config = MKLDNNQuantizer( + quantizable_op_type=quantizable_op_type, + quant_bits=weight_bits, + ) + elif deploy_backend.lower() == "arm": + self.quant_config = ARMCPUQuantizer( + quantizable_op_type=quantizable_op_type, + quant_bits=weight_bits, + ) + else: + assert "Deploy Backend {} not support, please choose one of {}.".format( + deploy_backend, support_deploy_backend + ) + def quantize(self): ''' Load the FP32 model, and use the calibrate data to calculate the forward-stage. @@ -486,7 +519,7 @@ class PostTrainingQuantization: self._save_output_threshold() if any( - op_type in self._quantizable_op_type + op_type in self.quant_config.activation_quant_operation_types for op_type in self._dynamic_quantize_op_type ): self._collect_dynamic_quantize_op_threshold( @@ -652,9 +685,8 @@ class PostTrainingQuantization: op._set_attr("op_namescope", "skip_quant") op_type = op.type - if ( - self._is_full_quantize - and op_type not in self._quantizable_op_type + if self._is_full_quantize and op_type not in list( + SUPPORT_QUANTIZATION_OP_DICT.keys() ): _logger.warning( op_type + " is not supported for quantization." @@ -664,7 +696,12 @@ class PostTrainingQuantization: in persistable_var_names ) # For quantized ops, sample inputs and outputs - if op_type in self._quantizable_op_type or is_conv1d_quant: + if ( + op_type in self.quant_config.weight_quant_operation_types + or op_type + in self.quant_config.activation_quant_operation_types + or is_conv1d_quant + ): collect_var_name( utils._get_op_input_var_names(op), persistable_var_names, @@ -683,7 +720,7 @@ class PostTrainingQuantization: in_var_name ] = out_var_name # For other op, only sample output scale - elif op_type in self._out_scale_op_list: + elif op_type in self.quant_config.observer_operation_types: collect_var_name( utils._get_op_output_var_names(op), persistable_var_names, @@ -1034,7 +1071,11 @@ class PostTrainingQuantization: ), "The algo should be min_max to save input threshold." 
for block_id in range(len(self._program.blocks)): for op in self._program.blocks[block_id].ops: - if op.type in self._quantizable_op_type: + if ( + op.type in self.quant_config.weight_quant_operation_types + or op.type + in self.quant_config.activation_quant_operation_types + ): for var_name in utils._get_op_input_var_names(op): assert var_name in self._quantized_var_min assert var_name in self._quantized_var_max @@ -1142,10 +1183,6 @@ class PostTrainingQuantization: graph = IrGraph(core.Graph(self._program.desc), for_test=True) # use QuantizationTransformPass to insert fake_quant/fake_dequantize op - major_quantizable_op_types = [] - for op_type in utils._weight_supported_quantizable_op_type: - if op_type in self._quantizable_op_type: - major_quantizable_op_types.append(op_type) if not self._onnx_format: transform_pass = QuantizationTransformPass( scope=self._scope, @@ -1154,7 +1191,7 @@ class PostTrainingQuantization: activation_bits=self._activation_bits, activation_quantize_type=self._activation_quantize_type, weight_quantize_type=self._weight_quantize_type, - quantizable_op_type=major_quantizable_op_types, + quantizable_op_type=self.quant_config.weight_quant_operation_types, ) else: transform_pass = QuantizationTransformPassV2( @@ -1164,7 +1201,7 @@ class PostTrainingQuantization: activation_bits=self._activation_bits, activation_quantize_type=self._activation_quantize_type, weight_quantize_type=self._weight_quantize_type, - quantizable_op_type=major_quantizable_op_types, + quantizable_op_type=self.quant_config.weight_quant_operation_types, ) for sub_graph in graph.all_sub_graphs(): @@ -1174,22 +1211,17 @@ class PostTrainingQuantization: transform_pass.apply(sub_graph) # use AddQuantDequantPass to insert fake_quant_dequant op - minor_quantizable_op_types = [] - for op_type in utils._act_supported_quantizable_op_type: - if op_type in self._quantizable_op_type: - minor_quantizable_op_types.append(op_type) if not self._onnx_format: add_quant_dequant_pass = AddQuantDequantPass( scope=self._scope, place=self._place, - quantizable_op_type=minor_quantizable_op_types, + quantizable_op_type=self.quant_config.activation_quant_operation_types, ) else: add_quant_dequant_pass = AddQuantDequantPassV2( scope=self._scope, place=self._place, - quantizable_op_type=minor_quantizable_op_types, - is_full_quantized=True, + quantizable_op_type=self.quant_config.activation_quant_operation_types, ) for sub_graph in graph.all_sub_graphs(): @@ -1283,7 +1315,7 @@ class PostTrainingQuantization: round_type=self._round_type, activation_bits=self._activation_bits, weight_quantize_type=self._weight_quantize_type, - quantizable_op_type=major_quantizable_op_types, + quantizable_op_type=self.quant_config.weight_quant_operation_types, ) for sub_graph in graph.all_sub_graphs(): @@ -1295,6 +1327,22 @@ class PostTrainingQuantization: sub_graph._for_test = True quant_weight_pass.apply(sub_graph) + infer_pass_quant_op_types = ( + self.quant_config.weight_quant_operation_types + + self.quant_config.activation_quant_operation_types + + self.quant_config.observer_operation_types + ) + out_scale_infer_pass = AddQuantDequantForInferencePass( + scope=self._scope, + place=self._place, + quant_bits=self._activation_bits, + quantizable_op_type=infer_pass_quant_op_types, + calibration_range_dict=self._scale_dict, + ) + for sub_graph in graph.all_sub_graphs(): + sub_graph._for_test = True + out_scale_infer_pass.apply(sub_graph) + self._program = graph.to_program() def _save_output_threshold(self): @@ -1339,7 +1387,12 @@ class 
PostTrainingQuantization: threshold_map[out_var_name], ) op_node._set_attr("with_quant_attr", True) - if op_node.type in self._quantizable_op_type: + if ( + op_node.type + in self.quant_config.weight_quant_operation_types + or op_node.type + in self.quant_config.activation_quant_operation_types + ): op._set_attr("quantization_type", quantized_type) def analysis_and_save_info(op_node, out_var_name): @@ -1387,7 +1440,9 @@ class PostTrainingQuantization: for block_id in range(len(self._program.blocks)): for op in self._program.blocks[block_id].ops: if op.type in ( - self._quantizable_op_type + self._out_scale_op_list + self.quant_config.weight_quant_operation_types + + self.quant_config.activation_quant_operation_types + + self.quant_config.observer_operation_types ): out_var_names = utils._get_op_output_var_names(op) for var_name in out_var_names: diff --git a/python/paddle/static/quantization/quant_config.py b/python/paddle/static/quantization/quant_config.py new file mode 100644 index 0000000000..5ddb9b9b2d --- /dev/null +++ b/python/paddle/static/quantization/quant_config.py @@ -0,0 +1,327 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# A dict of operators that contain weights and support quantization, +# including operator names, actual input and output names. +SUPPORT_WEIGHT_QUANTIZATION_OP_DICT = { + "conv2d": [["Input", "Filter"], ["Output"]], + "depthwise_conv2d": [["Input", "Filter"], ["Output"]], + "conv2d_transpose": [["Input", "Filter"], ["Output"]], + "mul": [["X", "Y"], ["Out"]], + "matmul": [["X", "Y"], ["Out"]], + "matmul_v2": [["X", "Y"], ["Out"]], +} + +# A dict of operators that supports quantization and has only activation inputs, +# including operator names, actual input and output names. 
+SUPPORT_ACT_QUANTIZATION_OP_DICT = { + "mul": [["X", "Y"], ["Out"]], + "matmul": [["X", "Y"], ["Out"]], + "matmul_v2": [["X", "Y"], ["Out"]], + "pool2d": [["X"], ["Out"]], + "elementwise_add": [["X", "Y"], ["Out"]], + "concat": [["X"], ["Out"]], + "softmax": [["X"], ["Out"]], + "argmax": [["X"], ["Out"]], + "transpose": [["X"], ["Out"]], + "equal": [["X", "Y"], ["Out"]], + "gather": [["X"], ["Out"]], + "greater_equal": [["X", "Y"], ["Out"]], + "greater_than": [["X", "Y"], ["Out"]], + "less_equal": [["X", "Y"], ["Out"]], + "less_than": [["X", "Y"], ["Out"]], + "mean": [["X"], ["Out"]], + "not_equal": [["X", "Y"], ["Out"]], + "reshape": [["X"], ["Out"]], + "reshape2": [["X"], ["Out"]], + "transpose2": [["X"], ["Out"]], + "nearest_interp": [["X"], ["Out"]], + "trilinear_interp": [["X"], ["Out"]], + "slice": [["Input"], ["Out"]], + "squeeze": [["X"], ["Out"]], + "elementwise_sub": [["X", "Y"], ["Out"]], + "relu": [["X"], ["Out"]], + "relu6": [["X"], ["Out"]], + "leaky_relu": [["X"], ["Out"]], + "prelu": [["X", "Alpha"], ["Out"]], + "tanh": [["X"], ["Out"]], + "swish": [["X"], ["Out"]], + "dropout": [["X"], ["Out"]], + "batch_norm": [["X"], ["Y"]], + "layer_norm": [["X"], ["Y"]], + "sigmoid": [["X"], ["Out"]], + "elementwise_mul": [["X", "Y"], ["Out"]], + "elementwise_pow": [["X", "Y"], ["Out"]], + "hard_swish": [["X"], ["Out"]], + "hard_sigmoid": [["X"], ["Out"]], + "gru": [["Input", "Weight"], ["Hidden"]], + "lstm": [["Input", "Weight"], ["Hidden"]], + "pad2d": [["X"], ["Out"]], + "pad3d": [["X"], ["Out"]], + "flatten": [["X"], ["Out"]], + "flatten2": [["X"], ["Out"]], + "unsqueeze2": [["X"], ["Out"]], + "flatten_contiguous_range": [["X"], ["Out"]], + "split": [["X"], ["Out"]], + "squeeze2": [["X"], ["Out"]], + "nearest_interp_v2": [["X"], ["Out"]], + "bilinear_interp": [["X"], ["Out"]], + "bilinear_interp_v2": [["X"], ["Out"]], + "fill_constant_batch_size_like": [["Input"], ["Out"]], + "arg_max": [["X"], ["Out"]], + "abs": [["X"], ["Out"]], + "assign": [["X"], ["Out"]], + "cast": [["X"], ["Out"]], + "clip": [["X"], ["Out"]], + "box_coder": [["PriorBox"], ["OutputBox"]], + "crop": [["X"], ["Out"]], + "cumsum": [["X"], ["Out"]], + "expand_v2": [["X"], ["Out"]], + "fill_any_like": [["X"], ["Out"]], + "fill_constant": [[], ["Out"]], + "gelu": [["X"], ["Out"]], + "instance_norm": [["X"], ["Y"]], + "lookup_table": [["W", "Ids"], ["Out"]], + "lookup_table_v2": [["W", "Ids"], ["Out"]], + "norm": [["X"], ["Norm"]], + "p_norm": [["X"], ["Out"]], + "pow": [["X"], ["Out"]], + "reduce_mean": [["X"], ["Out"]], + "stack": [["X"], ["Y"]], + "top_k_v2": [["X"], ["Out", "Indices"]], + "logical_and": [["X", "Y"], ["Out"]], + "logical_not": [["X"], ["Out"]], + "meshgrid": [["X"], ["Out"]], + "roi_align": [["X", "ROIs"], ["Out"]], + "strided_slice": [["Input"], ["Out"]], + "where": [["Condition", "X", "Y"], ["Out"]], + "grid_sampler": [["X", "Grid"], ["Output"]], + "tile": [["X"], ["Out"]], + "group_norm": [["X"], ["Y", "Mean", "Variance"]], + "reduce_sum": [["X"], ["Out"]], + "square": [["X"], ["Out"]], + "softplus": [["X"], ["Out"]], + "shuffle_channel": [["X"], ["Out"]], + "reduce_max": [["X"], ["Out"]], + "scale": [["X"], ["Out"]], +} + +# A full dict of operators that supports quantization, +# including operator names, actual input and output names. 
+SUPPORT_QUANTIZATION_OP_DICT = SUPPORT_WEIGHT_QUANTIZATION_OP_DICT.copy() +SUPPORT_QUANTIZATION_OP_DICT.update(SUPPORT_ACT_QUANTIZATION_OP_DICT) + + +class BaseQuantizer: + """ + Basic quantization configuration class, which configures some hyperparameters + required for quantization, including the list of op types to be quantized, + quantization bit number for weight and activation and the range of quantization values. + Args: + quantizable_op_type(list[str], optional): List the type of ops + that will be quantized. Default is []. If quantizable_op_type is [], + it will use the default quantization op type of the qunat config in + the current Quantizer. + quant_bits(int, optional): Quantization bit number for weight and activation. + Default is 8. + """ + + def __init__( + self, + quantizable_op_type=[], + quant_bits=8, + ): + self._quantizable_op_type = quantizable_op_type + self._quant_bits = quant_bits + self._quant_min = -128 + self._quant_max = 127 + + @property + def weight_quant_operation_types(self): + """ + Operation type list which should support weight quantization. + And before these ops, quant dequant nodes will be inserted. + """ + base_weight_op_type_list = list( + SUPPORT_WEIGHT_QUANTIZATION_OP_DICT.keys() + ) + if self._quantizable_op_type: + weight_list = [] + for _op_type in self._quantizable_op_type: + if _op_type in base_weight_op_type_list: + weight_list.append(_op_type) + return weight_list + else: + return base_weight_op_type_list + + @property + def activation_quant_operation_types(self): + """ + Operation type list which should support activation quantization. + And before these ops, quant dequant nodes will be inserted. + """ + base_act_op_type_list = list(SUPPORT_ACT_QUANTIZATION_OP_DICT.keys()) + act_quant_op_list = [] + if self._quantizable_op_type: + for _op_type in self._quantizable_op_type: + if _op_type in base_act_op_type_list: + act_quant_op_list.append(_op_type) + else: + act_quant_op_list = [ + 'mul', + 'matmul', + 'matmul_v2', + ] + return act_quant_op_list + + @property + def observer_operation_types(self): + """ + Operation type list for observer in quantization. These nodes only count the + calibration boundary scale and do not participate in the fake quantization. + In order to facilitate the deployment of the prediction engine, quant + and dequant nodes will be inserted after these ops when exporting the model. + """ + return list(SUPPORT_ACT_QUANTIZATION_OP_DICT.keys()) + + +class TensorRTQuantizer(BaseQuantizer): + """ + TensorRT quantization configuration class. + Args: + quantizable_op_type(list[str], optional): List the type of ops + that will be quantized. Default is []. If quantizable_op_type is [], + it will use the default quantization op type of the qunat config in + the current Quantizer. + quant_bits(int, optional): Quantization bit number for weight and activation. + Default is 8. + """ + + def __init__( + self, + quantizable_op_type=[], + quant_bits=8, + ): + super().__init__() + self._quantizable_op_type = quantizable_op_type + self._quant_bits = quant_bits + self._quant_min = -128 + self._quant_max = 127 + + @property + def activation_quant_operation_types(self): + """ + Operation type list which should support activation quantization. + And before these ops, quant dequant nodes will be inserted. 
+ """ + return [ + "pool2d", + "elementwise_add", + "elementwise_sub", + "elementwise_mul", + "elementwise_pow", + "concat", + "softmax", + "argmax", + "mean", + "relu", + "relu6", + "leaky_relu", + "tanh", + "swish", + "softplus", + "gelu", + "hard_sigmoid", + "hard_swish", + "sigmoid", + "layer_norm", + "matmul_v2", + "split", + "bilinear_interp", + "nearest_interp", + "trilinear_interp", + "nearest_interp_v2", + "bilinear_interp", + "bilinear_interp_v2", + "clip", + "pow", + "reduce_mean", + "reduce_sum", + "reduce_max", + ] + + +class MKLDNNQuantizer(BaseQuantizer): + """ + MKLDNN quantization configuration class. + Args: + quantizable_op_type(list[str], optional): List the type of ops + that will be quantized. Default is []. If quantizable_op_type is [], + it will use the default quantization op type of the qunat config in + the current Quantizer. + quant_bits(int, optional): Quantization bit number for weight and activation. + Default is 8. + """ + + def __init__( + self, + quantizable_op_type=[], + quant_bits=8, + ): + super().__init__() + self._quantizable_op_type = quantizable_op_type + self._quant_bits = quant_bits + self._quant_min = -128 + self._quant_max = 127 + + @property + def activation_quant_operation_types(self): + """ + Operation type list which should support activation quantization. + And before these ops, quant dequant nodes will be inserted. + """ + return [ + "pool2d", + "elementwise_add", + "elementwise_mul", + "concat", + "nearest_interp", + "nearest_interp_v2", + "split", + ] + + +class ARMCPUQuantizer(BaseQuantizer): + """ + ARM CPU with Paddle Lite quantization configuration class. + Args: + quantizable_op_type(list[str], optional): List the type of ops + that will be quantized. Default is []. If quantizable_op_type is [], + it will use the default quantization op type of the qunat config in + the current Quantizer. + quant_bits(int, optional): Quantization bit number for weight and activation. + Default is 8. + """ + + def __init__( + self, + quantizable_op_type=[], + quant_bits=8, + ): + super().__init__() + self._quantizable_op_type = quantizable_op_type + self._quant_bits = quant_bits + self._quant_min = -127 + self._quant_max = 127 diff --git a/python/paddle/static/quantization/quantization_pass.py b/python/paddle/static/quantization/quantization_pass.py index 1198d1c2cf..fc7ab7689e 100644 --- a/python/paddle/static/quantization/quantization_pass.py +++ b/python/paddle/static/quantization/quantization_pass.py @@ -28,6 +28,11 @@ from ...framework import _get_paddle_place, core from ...static import Program, data, program_guard, scope_guard from ...utils import unique_name from . import utils +from .quant_config import ( + SUPPORT_ACT_QUANTIZATION_OP_DICT, + SUPPORT_QUANTIZATION_OP_DICT, + SUPPORT_WEIGHT_QUANTIZATION_OP_DICT, +) _fake_quant_op_list = [ 'fake_quantize_abs_max', @@ -231,7 +236,7 @@ class QuantizationTransformPass: self._quantizable_ops = quantizable_op_type for op in self._quantizable_ops: - assert op in utils._weight_supported_quantizable_op_type, ( + assert op in list(SUPPORT_WEIGHT_QUANTIZATION_OP_DICT.keys()), ( op + " is not supported for quantization." 
) self._quantizable_grad_ops = [ @@ -1594,7 +1599,7 @@ class OutScaleForTrainingPass: self._place = _get_paddle_place(place) self._moving_rate = moving_rate self._is_test = is_test - self._teller_set = utils.QUANT_SUPPORTED_OP_TYPE_LIST + self._teller_set = list(SUPPORT_QUANTIZATION_OP_DICT.keys()) self._scale_dict = scale_dict def apply(self, graph): @@ -1749,7 +1754,7 @@ class OutScaleForInferencePass: scope(static.Scope): The scope is used to initialize these new parameters. """ self._scope = scope - self._teller_set = utils.QUANT_SUPPORTED_OP_TYPE_LIST + self._teller_set = list(SUPPORT_QUANTIZATION_OP_DICT.keys()) def apply(self, graph): """ @@ -1830,7 +1835,6 @@ class AddQuantDequantPass: quant_bits=8, skip_pattern=["skip_quant"], quantizable_op_type=["elementwise_add", "pool2d"], - is_full_quantized=False, is_test=None, scale_dict=None, ): @@ -1851,10 +1855,6 @@ class AddQuantDequantPass: Default is 'skip_quant'. quantizable_op_type(list[str], optional): List the type of ops that will be quantized. Default is ["elementwise_add", "pool2d"]. - is_full_quantized(bool, optional): If set is_full_quantized as True, apply - quantization to all supported quantizable op type. If set is_full_quantized - as False, only apply quantization to the op type according to the input - quantizable_op_type. """ self._scope = scope self._place = _get_paddle_place(place) @@ -1864,14 +1864,11 @@ class AddQuantDequantPass: self._skip_pattern = skip_pattern self._scale_dict = scale_dict - if is_full_quantized: - self._quantizable_op_type = utils._act_supported_quantizable_op_type - else: - self._quantizable_op_type = quantizable_op_type - for op_type in quantizable_op_type: - assert op_type in utils._act_supported_quantizable_op_type, ( - op_type + " is not supported for quantization." - ) + self._quantizable_op_type = quantizable_op_type + for op_type in self._quantizable_op_type: + assert op_type in list(SUPPORT_ACT_QUANTIZATION_OP_DICT.keys()), ( + op_type + " is not supported for quantization." + ) self._quantizable_grad_op_type = [ '%s_grad' % (op) for op in self._quantizable_op_type ] @@ -2485,7 +2482,7 @@ class QuantizationTransformPassV2(QuantizationTransformPass): self._quantizable_ops = quantizable_op_type for op in self._quantizable_ops: - assert op in utils._weight_supported_quantizable_op_type, ( + assert op in list(SUPPORT_WEIGHT_QUANTIZATION_OP_DICT.keys()), ( op + " is not supported for quantization." ) self._quantizable_grad_ops = [ @@ -2763,7 +2760,6 @@ class AddQuantDequantPassV2: quant_bits=8, skip_pattern=["skip_quant"], quantizable_op_type=["elementwise_add", "pool2d"], - is_full_quantized=False, is_test=None, scale_dict=None, ): @@ -2782,10 +2778,6 @@ class AddQuantDequantPassV2: Default is 'skip_quant'. quantizable_op_type(list[str], optional): List the type of ops that will be quantized. Default is ["elementwise_add", "pool2d"]. - is_full_quantized(bool, optional): If set is_full_quantized as True, apply - quantization to all supported quantizable op type. If set is_full_quantized - as False, only apply quantization to the op type according to the input - quantizable_op_type. scale_dict(dict, optional): calibration ranges of tensors output. 
Examples: @@ -2811,14 +2803,11 @@ class AddQuantDequantPassV2: self._skip_pattern = skip_pattern self._scale_dict = scale_dict - if is_full_quantized: - self._quantizable_op_type = utils._act_supported_quantizable_op_type - else: - self._quantizable_op_type = quantizable_op_type - for op_type in quantizable_op_type: - assert op_type in utils._act_supported_quantizable_op_type, ( - op_type + " is not supported for quantization." - ) + self._quantizable_op_type = quantizable_op_type + for op_type in self._quantizable_op_type: + assert op_type in list(SUPPORT_ACT_QUANTIZATION_OP_DICT.keys()), ( + op_type + " is not supported for quantization." + ) self._quantizable_grad_op_type = [ '%s_grad' % (op) for op in self._quantizable_op_type ] @@ -3243,7 +3232,15 @@ class AddQuantDequantForInferencePass: When export quant model, it will traverse to find the output of each op, and then insert the quant/dequant op after it. """ - def __init__(self, scope, place, quant_bits=8): + def __init__( + self, + scope, + place, + quant_bits=8, + quantizable_op_type=[], + calibration_range_dict=None, + only_observer=True, + ): """ Args: scope(static.Scope): The scope is used to initialize these new parameters. @@ -3254,7 +3251,13 @@ class AddQuantDequantForInferencePass: self._scope = scope self._place = place self._quant_bits = quant_bits - self._teller_set = utils.QUANT_SUPPORTED_OP_TYPE_LIST + self._only_observer = only_observer + self._teller_set = ( + quantizable_op_type + if quantizable_op_type + else list(SUPPORT_QUANTIZATION_OP_DICT.keys()) + ) + self._calibration_range_dict = calibration_range_dict def apply(self, graph): """ @@ -3321,9 +3324,31 @@ class AddQuantDequantForInferencePass: shape=var_node.shape(), var_dtype=var_node.dtype(), ) - scale_var_node = graph._find_node_by_name( - graph.all_persistable_nodes(), self._scale_name(var_name) - ) + if not self._calibration_range_dict: + scale_var_node = graph._find_node_by_name( + graph.all_persistable_nodes(), self._scale_name(var_name) + ) + elif var_name in self._calibration_range_dict: + scale_value = self._calibration_range_dict[var_name] + scale_var_node = graph.create_persistable_node( + name=self._scale_name(var_name), + var_type=var_node.type(), + shape=[1], + var_dtype=var_node.dtype(), + ) + data_type = ( + 'float64' + if var_node.dtype() == core.VarDesc.VarType.FP64 + else 'float32' + ) + _init_var_node( + scale_var_node, + np.array(scale_value, dtype=data_type), + self._scope, + self._place, + ) + else: + return None try: zero_point_node = graph._find_node_by_name( graph.all_persistable_nodes(), @@ -3347,7 +3372,11 @@ class AddQuantDequantForInferencePass: if zero_point_node is not None: inputs["ZeroPoint"] = zero_point_node - attrs = {"quant_axis": quant_axis, "bit_length": self._quant_bits} + attrs = { + "quant_axis": quant_axis, + "bit_length": self._quant_bits, + "only_observer": self._only_observer, + } attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward outputs = {"Y": quant_var_node} @@ -3376,7 +3405,11 @@ class AddQuantDequantForInferencePass: if zero_point_node is not None: inputs["ZeroPoint"] = zero_point_node - attrs = {"quant_axis": -1, "bit_length": self._quant_bits} + attrs = { + "quant_axis": -1, + "bit_length": self._quant_bits, + "only_observer": self._only_observer, + } attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward dequant_op_node = graph.create_op_node( diff --git a/python/paddle/static/quantization/tests/test_post_training_quantization_mobilenetv1.py 
b/python/paddle/static/quantization/tests/test_post_training_quantization_mobilenetv1.py index bdb80cd3d3..cfef61d51a 100644 --- a/python/paddle/static/quantization/tests/test_post_training_quantization_mobilenetv1.py +++ b/python/paddle/static/quantization/tests/test_post_training_quantization_mobilenetv1.py @@ -277,6 +277,7 @@ class TestPostTrainingQuantization(unittest.TestCase): is_optimize_model=False, batch_nums=10, onnx_format=False, + deploy_backend=None, ): try: os.system("mkdir " + self.int8_model) @@ -305,6 +306,7 @@ class TestPostTrainingQuantization(unittest.TestCase): optimize_model=is_optimize_model, onnx_format=onnx_format, is_use_cache_file=is_use_cache_file, + deploy_backend=deploy_backend, ) ptq.quantize() ptq.save_quantized_model( @@ -329,6 +331,7 @@ class TestPostTrainingQuantization(unittest.TestCase): diff_threshold, onnx_format=False, batch_nums=10, + deploy_backend=None, ): infer_iterations = self.infer_iterations batch_size = self.batch_size @@ -361,6 +364,7 @@ class TestPostTrainingQuantization(unittest.TestCase): is_optimize_model, batch_nums, onnx_format, + deploy_backend, ) print( @@ -571,5 +575,131 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization): ) +class TestPostTrainingAvgONNXFormatForMobilenetv1TensorRT( + TestPostTrainingQuantization +): + def test_post_training_onnx_format_mobilenetv1_tensorrt(self): + model = "MobileNet-V1" + algo = "avg" + round_type = "round" + data_urls = [ + 'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV1_infer.tar' + ] + data_md5s = ['5ee2b1775b11dc233079236cdc216c2e'] + quantizable_op_type = [ + "conv2d", + "depthwise_conv2d", + "mul", + ] + is_full_quantize = False + is_use_cache_file = False + is_optimize_model = False + onnx_format = True + diff_threshold = 0.05 + batch_nums = 10 + deploy_backend = "tensorrt" + self.run_test( + model, + 'inference.pdmodel', + 'inference.pdiparams', + algo, + round_type, + data_urls, + data_md5s, + quantizable_op_type, + is_full_quantize, + is_use_cache_file, + is_optimize_model, + diff_threshold, + onnx_format=onnx_format, + batch_nums=batch_nums, + deploy_backend=deploy_backend, + ) + + +class TestPostTrainingKLONNXFormatForMobilenetv1MKLDNN( + TestPostTrainingQuantization +): + def test_post_training_onnx_format_mobilenetv1_mkldnn(self): + model = "MobileNet-V1" + algo = "ptf" + round_type = "round" + data_urls = [ + 'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV1_infer.tar' + ] + data_md5s = ['5ee2b1775b11dc233079236cdc216c2e'] + quantizable_op_type = [ + "conv2d", + "depthwise_conv2d", + "mul", + ] + is_full_quantize = False + is_use_cache_file = False + is_optimize_model = False + onnx_format = True + diff_threshold = 0.05 + batch_nums = 2 + deploy_backend = "mkldnn" + self.run_test( + model, + 'inference.pdmodel', + 'inference.pdiparams', + algo, + round_type, + data_urls, + data_md5s, + quantizable_op_type, + is_full_quantize, + is_use_cache_file, + is_optimize_model, + diff_threshold, + onnx_format=onnx_format, + batch_nums=batch_nums, + deploy_backend=deploy_backend, + ) + + +class TestPostTrainingAvgONNXFormatForMobilenetv1ARMCPU( + TestPostTrainingQuantization +): + def test_post_training_onnx_format_mobilenetv1_armcpu(self): + model = "MobileNet-V1" + algo = "avg" + round_type = "round" + data_urls = [ + 'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV1_infer.tar' + ] + data_md5s = ['5ee2b1775b11dc233079236cdc216c2e'] + quantizable_op_type = [ + 
"conv2d", + "depthwise_conv2d", + "mul", + ] + is_full_quantize = False + is_use_cache_file = False + is_optimize_model = True + onnx_format = True + diff_threshold = 0.05 + batch_nums = 3 + deploy_backend = "arm" + self.run_test( + model, + 'inference.pdmodel', + 'inference.pdiparams', + algo, + round_type, + data_urls, + data_md5s, + quantizable_op_type, + is_full_quantize, + is_use_cache_file, + is_optimize_model, + diff_threshold, + onnx_format=onnx_format, + batch_nums=batch_nums, + deploy_backend=deploy_backend, + ) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/static/quantization/tests/test_post_training_quantization_while.py b/python/paddle/static/quantization/tests/test_post_training_quantization_while.py index 71482209b3..389465df89 100644 --- a/python/paddle/static/quantization/tests/test_post_training_quantization_while.py +++ b/python/paddle/static/quantization/tests/test_post_training_quantization_while.py @@ -188,7 +188,6 @@ class TestPostTrainingQuantization(unittest.TestCase): ): origin_model_path = self.download_model(data_url, data_md5, model_name) - # origin_model_path = os.path.join(origin_model_path, model_name) print( "Start FP32 inference for {0} on {1} images ...".format( diff --git a/python/paddle/static/quantization/utils.py b/python/paddle/static/quantization/utils.py index 05e197bb75..39c4db33aa 100644 --- a/python/paddle/static/quantization/utils.py +++ b/python/paddle/static/quantization/utils.py @@ -17,115 +17,7 @@ import sys import numpy as np from ...fluid.framework import IrNode, Operator - -_weight_supported_quantizable_op_type = [ - 'conv2d', - 'depthwise_conv2d', - 'conv2d_transpose', - 'mul', - 'matmul', - 'matmul_v2', -] - -_act_supported_quantizable_op_type = [ - "pool2d", - "elementwise_add", - "concat", - "softmax", - "argmax", - "transpose", - "equal", - "gather", - "greater_equal", - "greater_than", - "less_equal", - "less_than", - "mean", - "not_equal", - "reshape", - "reshape2", - "dropout", - "bilinear_interp", - "nearest_interp", - "trilinear_interp", - "slice", - "squeeze", - "elementwise_sub", - "mul", - "matmul", - "relu", - "relu6", - "leaky_relu", - "tanh", - "swish", - "transpose", - "transpose2", - "sigmoid", - "pad2d", - "flatten", - "flatten2", - "batch_norm", - "layer_norm", - "matmul_v2", - "split", - "flatten_contiguous_range", - "squeeze2", - "nearest_interp_v2", - "bilinear_interp", - "bilinear_interp_v2", - "fill_constant_batch_size_like", - "arg_max", - "abs", - "assign", - "cast", - "clip", - "box_coder", - "crop", - "cumsum", - "elementwise_mul", - "elementwise_pow", - "expand_v2", - "fill_any_like", - "fill_constant", - "gelu", - "hard_sigmoid", - "hard_swish", - "instance_norm", - "lookup_table", - "lookup_table_v2", - "norm", - "p_norm", - "pad3d", - "pow", - "prelu", - "reduce_mean", - "unsqueeze", - "unsqueeze2", - "logical_and", - "logical_not", - "meshgrid", - "roi_align", - "strided_slice", - "where", - "grid_sampler", - "tile", - "group_norm", - "reduce_sum", - "square", - "softplus", - "shuffle_channel", - "reduce_max", - "scale", -] - -QUANT_SUPPORTED_OP_TYPE_LIST = list( - set( - _weight_supported_quantizable_op_type - + _act_supported_quantizable_op_type - ) -) - -_out_scale_op_list = QUANT_SUPPORTED_OP_TYPE_LIST +from .quant_config import SUPPORT_QUANTIZATION_OP_DICT _channelwise_quant_axis1_ops = [ 'conv2d_transpose', @@ -134,102 +26,6 @@ _channelwise_quant_axis1_ops = [ 'matmul_v2', ] -# list op real input and output names, to avoid processing input such as AxisTensor. 
-_op_real_in_out_name = { - "conv2d": [["Input", "Filter"], ["Output"]], - "depthwise_conv2d": [["Input", "Filter"], ["Output"]], - "conv2d_transpose": [["Input", "Filter"], ["Output"]], - "mul": [["X", "Y"], ["Out"]], - "matmul": [["X", "Y"], ["Out"]], - "matmul_v2": [["X", "Y"], ["Out"]], - "pool2d": [["X"], ["Out"]], - "elementwise_add": [["X", "Y"], ["Out"]], - "concat": [["X"], ["Out"]], - "softmax": [["X"], ["Out"]], - "argmax": [["X"], ["Out"]], - "transpose": [["X"], ["Out"]], - "equal": [["X", "Y"], ["Out"]], - "gather": [["X"], ["Out"]], - "greater_equal": [["X", "Y"], ["Out"]], - "greater_than": [["X", "Y"], ["Out"]], - "less_equal": [["X", "Y"], ["Out"]], - "less_than": [["X", "Y"], ["Out"]], - "mean": [["X"], ["Out"]], - "not_equal": [["X", "Y"], ["Out"]], - "reshape": [["X"], ["Out"]], - "reshape2": [["X"], ["Out"]], - "transpose2": [["X"], ["Out"]], - "nearest_interp": [["X"], ["Out"]], - "trilinear_interp": [["X"], ["Out"]], - "slice": [["Input"], ["Out"]], - "squeeze": [["X"], ["Out"]], - "elementwise_sub": [["X", "Y"], ["Out"]], - "relu": [["X"], ["Out"]], - "relu6": [["X"], ["Out"]], - "leaky_relu": [["X"], ["Out"]], - "prelu": [["X", "Alpha"], ["Out"]], - "tanh": [["X"], ["Out"]], - "swish": [["X"], ["Out"]], - "dropout": [["X"], ["Out"]], - "batch_norm": [["X"], ["Y"]], - "layer_norm": [["X"], ["Y"]], - "sigmoid": [["X"], ["Out"]], - "elementwise_mul": [["X", "Y"], ["Out"]], - "elementwise_pow": [["X", "Y"], ["Out"]], - "hard_swish": [["X"], ["Out"]], - "hard_sigmoid": [["X"], ["Out"]], - "gru": [["Input", "Weight"], ["Hidden"]], - "lstm": [["Input", "Weight"], ["Hidden"]], - "pad2d": [["X"], ["Out"]], - "pad3d": [["X"], ["Out"]], - "flatten": [["X"], ["Out"]], - "flatten2": [["X"], ["Out"]], - "unsqueeze2": [["X"], ["Out"]], - "flatten_contiguous_range": [["X"], ["Out"]], - "split": [["X"], ["Out"]], - "squeeze2": [["X"], ["Out"]], - "nearest_interp_v2": [["X"], ["Out"]], - "bilinear_interp": [["X"], ["Out"]], - "bilinear_interp_v2": [["X"], ["Out"]], - "fill_constant_batch_size_like": [["Input"], ["Out"]], - "arg_max": [["X"], ["Out"]], - "abs": [["X"], ["Out"]], - "assign": [["X"], ["Out"]], - "cast": [["X"], ["Out"]], - "clip": [["X"], ["Out"]], - "box_coder": [["PriorBox"], ["OutputBox"]], - "crop": [["X"], ["Out"]], - "cumsum": [["X"], ["Out"]], - "expand_v2": [["X"], ["Out"]], - "fill_any_like": [["X"], ["Out"]], - "fill_constant": [[], ["Out"]], - "gelu": [["X"], ["Out"]], - "instance_norm": [["X"], ["Y"]], - "lookup_table": [["W", "Ids"], ["Out"]], - "lookup_table_v2": [["W", "Ids"], ["Out"]], - "norm": [["X"], ["Norm"]], - "p_norm": [["X"], ["Out"]], - "pow": [["X"], ["Out"]], - "reduce_mean": [["X"], ["Out"]], - "stack": [["X"], ["Y"]], - "top_k_v2": [["X"], ["Out", "Indices"]], - "logical_and": [["X", "Y"], ["Out"]], - "logical_not": [["X"], ["Out"]], - "meshgrid": [["X"], ["Out"]], - "roi_align": [["X", "ROIs"], ["Out"]], - "strided_slice": [["Input"], ["Out"]], - "where": [["Condition", "X", "Y"], ["Out"]], - "grid_sampler": [["X", "Grid"], ["Output"]], - "tile": [["X"], ["Out"]], - "group_norm": [["X"], ["Y", "Mean", "Variance"]], - "reduce_sum": [["X"], ["Out"]], - "square": [["X"], ["Out"]], - "softplus": [["X"], ["Out"]], - "shuffle_channel": [["X"], ["Out"]], - "reduce_max": [["X"], ["Out"]], - "scale": [["X"], ["Out"]], -} - def _get_op_input_var_names(op): """ @@ -244,10 +40,10 @@ def _get_op_input_var_names(op): ), "The input op should be IrNode or Operator." 
var_names = [] op_name = op.name() if isinstance(op, IrNode) else op.type - if op_name not in _op_real_in_out_name: + if op_name not in SUPPORT_QUANTIZATION_OP_DICT: return [] - name_list = _op_real_in_out_name[op_name][0] + name_list = SUPPORT_QUANTIZATION_OP_DICT[op_name][0] for name in name_list: var_name = op.input(name) if isinstance(var_name, list): @@ -264,10 +60,10 @@ def _get_op_output_var_names(op): ), "The input op should be IrNode or Operator." var_names = [] op_name = op.name() if isinstance(op, IrNode) else op.type - if op_name not in _op_real_in_out_name: + if op_name not in SUPPORT_QUANTIZATION_OP_DICT: return [] - name_list = _op_real_in_out_name[op_name][1] + name_list = SUPPORT_QUANTIZATION_OP_DICT[op_name][1] for name in name_list: var_name = op.output(name) if isinstance(var_name, list): @@ -283,11 +79,11 @@ def _get_input_name_index(op, input_var_name): op, (IrNode, Operator) ), "The input op should be IrNode or Operator." op_name = op.name() if isinstance(op, IrNode) else op.type - if op_name not in _op_real_in_out_name: + if op_name not in SUPPORT_QUANTIZATION_OP_DICT: return None res = None - for argname in _op_real_in_out_name[op_name][0]: + for argname in SUPPORT_QUANTIZATION_OP_DICT[op_name][0]: var_names = op.input(argname) for index, name in enumerate(var_names): if name == input_var_name: @@ -301,10 +97,10 @@ def _get_output_name_index(op, output_var_name): op, (IrNode, Operator) ), "The input op should be IrNode or Operator." op_name = op.name() if isinstance(op, IrNode) else op.type - if op_name not in _op_real_in_out_name: + if op_name not in SUPPORT_QUANTIZATION_OP_DICT: return None - name_list = _op_real_in_out_name[op_name][1] + name_list = SUPPORT_QUANTIZATION_OP_DICT[op_name][1] res = None for name in name_list: var_name = op.output(name) @@ -347,7 +143,7 @@ def quant_tensor(x, scale, quant_axis=0, weight_bits=8, onnx_format=False): if isinstance(scale, list) and len(scale) == 1: scale = scale[0] if isinstance(scale, list): - assert quant_axis in [0, 1], 'quant_axis should be 0 or 1 for now.' + assert quant_axis in [-1, 0, 1], 'quant_axis should be -1, 0 or 1 for now.' for i, s in enumerate(scale): if s == 0.0: s = 1e-8 -- GitLab
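
Usage sketch for the new `deploy_backend` argument, modeled on how the MobileNetV1 tests in this patch drive `PostTrainingQuantization`. This is a minimal, hedged example: the model directory, file names, and the random calibration reader are placeholders, and the import path assumes the `paddle.static.quantization` package layout used by this patch.

```python
import numpy as np
import paddle
from paddle.static.quantization import PostTrainingQuantization

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())

# Placeholder calibration reader: yields one sample (a list of input arrays) per call.
def sample_generator():
    for _ in range(32):
        yield [np.random.random((3, 224, 224)).astype("float32")]

ptq = PostTrainingQuantization(
    executor=exe,
    model_dir="./fp32_model",              # hypothetical saved inference model directory
    model_filename="inference.pdmodel",
    params_filename="inference.pdiparams",
    sample_generator=sample_generator,
    batch_size=16,
    batch_nums=10,
    algo="avg",
    quantizable_op_type=[],                # empty list: fall back to the backend's quant config
    onnx_format=True,
    deploy_backend="tensorrt",             # or "mkldnn", "arm", or None for the general config
)
ptq.quantize()
ptq.save_quantized_model(
    "./int8_model",
    model_filename="inference.pdmodel",
    params_filename="inference.pdiparams",
)
```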
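The quantizer configs added in `quant_config.py` resolve three op-type lists from `SUPPORT_WEIGHT_QUANTIZATION_OP_DICT` and `SUPPORT_ACT_QUANTIZATION_OP_DICT`. The sketch below shows how an explicit `quantizable_op_type` list is split into weight-quant and activation-quant op types, and how a backend class such as `TensorRTQuantizer` overrides the activation list; the printed values follow from the properties in that file, and only the import path is an assumption.

```python
from paddle.static.quantization.quant_config import (
    BaseQuantizer,
    TensorRTQuantizer,
)

# Empty quantizable_op_type: BaseQuantizer quantizes every weight-bearing op type
# and falls back to ['mul', 'matmul', 'matmul_v2'] for activation quantization.
base_cfg = BaseQuantizer(quantizable_op_type=[], quant_bits=8)
print(base_cfg.weight_quant_operation_types)      # ['conv2d', 'depthwise_conv2d', ...]
print(base_cfg.activation_quant_operation_types)  # ['mul', 'matmul', 'matmul_v2']

# A user-supplied list is split: entries found in SUPPORT_WEIGHT_QUANTIZATION_OP_DICT go to
# the weight list, entries found in SUPPORT_ACT_QUANTIZATION_OP_DICT go to the activation list.
cfg = BaseQuantizer(quantizable_op_type=["conv2d", "mul", "pool2d"])
print(cfg.weight_quant_operation_types)           # ['conv2d', 'mul']
print(cfg.activation_quant_operation_types)       # ['mul', 'pool2d']

# Backend-specific quantizers keep the same weight handling but replace the activation
# list with a fixed backend-specific set of op types, e.g. TensorRTQuantizer:
trt_cfg = TensorRTQuantizer(quantizable_op_type=["conv2d", "depthwise_conv2d"])
print(trt_cfg.weight_quant_operation_types)       # ['conv2d', 'depthwise_conv2d']
print("pool2d" in trt_cfg.activation_quant_operation_types)  # True
```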
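The `only_observer` attribute added to `quantize_linear`/`dequantize_linear` makes those ops copy their input through and only carry calibration scales, which is what `AddQuantDequantForInferencePass` relies on when exporting observer-style quant/dequant pairs (it defaults to `only_observer=True` and can seed scales from `calibration_range_dict`). A rough sketch of applying the pass outside `PostTrainingQuantization` follows; the tensor name in `scale_dict` is made up, and the `IrGraph`/`core` import mirrors how the quantization modules import them, so it may need adjusting for your Paddle version.

```python
import paddle
from paddle.framework import IrGraph, core
from paddle.static.quantization.quantization_pass import (
    AddQuantDequantForInferencePass,
)

def insert_observer_qdq(program, scale_dict):
    """Insert observer-only quant/dequant pairs into an already calibrated program."""
    graph = IrGraph(core.Graph(program.desc), for_test=True)
    out_scale_infer_pass = AddQuantDequantForInferencePass(
        scope=paddle.static.global_scope(),
        place=paddle.CPUPlace(),
        quant_bits=8,
        quantizable_op_type=["conv2d", "depthwise_conv2d", "mul", "pool2d"],
        calibration_range_dict=scale_dict,  # e.g. {"conv2d_0.tmp_0": 0.127} (made-up tensor name)
        only_observer=True,                 # inserted qdq ops pass data through, keeping scales
    )
    for sub_graph in graph.all_sub_graphs():
        sub_graph._for_test = True
        out_scale_infer_pass.apply(sub_graph)
    return graph.to_program()
```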