diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py index afe8a3de6673f0869e2a1cb588bffe0167b69b8d..04aec158eace6b1951b0d55045ad930983899cdc 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py @@ -25,101 +25,99 @@ from paddle.fluid.executor import Executor from paddle.fluid.param_attr import ParamAttr from paddle.fluid.initializer import Constant from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX -from paddle.nn import Linear, Conv2D, Conv2DTranspose, MaxPool2D, MaxPool1D, BatchNorm1D, BatchNorm2D, BatchNorm3D, SyncBatchNorm +from paddle.nn import Linear, Conv2D, Conv2DTranspose, MaxPool2D, MaxPool1D +from paddle.nn import BatchNorm1D, BatchNorm2D, BatchNorm3D, SyncBatchNorm from paddle.fluid.dygraph.nn import BatchNorm, Pool2D from paddle.fluid.io import load_inference_model, save_inference_model -from paddle.nn.layer.activation import ReLU, LeakyReLU, Sigmoid, ReLU6, Tanh, Softmax, PReLU, Swish +from paddle.nn.layer.activation import ReLU, LeakyReLU, Sigmoid, ReLU6 +from paddle.nn.layer.activation import Tanh, Softmax, PReLU, Swish from paddle.fluid.log_helper import get_logger from . import quant_nn from .. import quantization_pass +from . import utils -__all__ = ['ImperativeQuantAware', 'ImperativeCalcOutScale'] +__all__ = ['ImperativeQuantAware'] _logger = get_logger( __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') -_op_real_in_out_name = { - "conv2d": [["Input", "Filter"], ["Output"]], - "depthwise_conv2d": [["Input", "Filter"], ["Output"]], - "pool2d": [["X"], ["Out"]], - "elementwise_add": [["X", "Y"], ["Out"]], - "softmax": [["X"], ["Out"]], - "relu": [["X"], ["Out"]], - "relu6": [["X"], ["Out"]], - "leaky_relu": [["X"], ["Out"]], - "prelu": [["X"], ["Out"]], - "tanh": [["X"], ["Out"]], - "batch_norm": [["X"], ["Y"]], - "sigmoid": [["X"], ["Out"]], - "swish": [["X"], ["Out"]], -} - class ImperativeQuantAware(object): """ - Add the fake quant logic for given quantizable layers, namely add the quant_dequant - computational logic both for activation inputs and weight inputs. + Applying quantization aware training (QAT) to dgraph model. """ def __init__(self, - weight_bits=8, - activation_bits=8, + quantizable_layer_type=['Conv2D', 'Linear'], weight_quantize_type='abs_max', activation_quantize_type='moving_average_abs_max', + weight_bits=8, + activation_bits=8, moving_rate=0.9, - quantizable_layer_type=['Conv2D', 'Linear'], weight_preprocess_layer=None, act_preprocess_layer=None, weight_quantize_layer=None, act_quantize_layer=None): - r""" + """ The constructor for ImperativeQuantAware. Args: - weight_bits(int): quantization bit number for weights, - whereas the bias is not quantized. - activation_bits(int): quantization bit number for activations. + quantizable_layer_type(list[str]): List the type of layers that + will be quantized. Default is ['Conv2D', 'Linear']. + The quantizable_op_type in QuantizationFreezePass and + ConvertToInt8Pass must be the same as this. weight_quantize_type(str): quantization type for weights, which supports 'abs_max' now. The 'moving_average_abs_max' - usually is not used for weights, since weights are fixed once the - model is well trained. + usually is not used for weights, since weights are fixed + once the model is well trained. activation_quantize_type(str): quantization type for activations, which supports 'abs_max' and 'moving_average_abs_max' now. - If using 'abs_max' mode, the quantization scale will be calculated - dynamically each step in both training and testing period. If using - 'moving_average_abs_max', the static quantization scale will be calculated - during training and used in inference. - moving_rate(float): the parameter for 'moving_average_abs_max' quantization. - quantizable_layer_type(list[str]): List the type of layers that will be quantized. - Default is ['Conv2D', 'Linear']. The quantizable_op_type in - QuantizationFreezePass and ConvertToInt8Pass must be the same as this. - weight_preprocess_layer(paddle.nn.Layer, optional): A paddle Layer that defines how to preprocess - weight before quantization. Using this can quickly test if user's - preprocess method works or not. The input is non-quantized - weight and function returns processed weight to be quantized. - If None, the weight will be quantized directly. Default is None. - act_preprocess_layer(paddle.nn.Layer, optional): A paddle Layer that defines how to preprocess - activation before quantization. Using this can quickly test if user's - preprocess method works or not. The input is non-quantized - activation and function returns processed activation to be quantized. - If None, the activation will be quantized directly. Default is None. - weight_quantize_layer(paddle.nn.Layer, optional): A paddle Layer that defines how to quantize weight. + If using 'abs_max' mode, the quantization scale will be + calculated dynamically each step in both training and testing + period. If using 'moving_average_abs_max', the static + quantization scale will be calculated during training and + used in inference. + weight_bits(int): quantization bit number for weights, + whereas the bias is not quantized. + activation_bits(int): quantization bit number for activations. + moving_rate(float): the parameter for 'moving_average_abs_max' + quantization. + weight_preprocess_layer(paddle.nn.Layer, optional): A paddle + Layer that defines how to preprocess weight before quantization. + Using this can quickly test if user's preprocess method works + or not. The input is non-quantized weight and function returns + processed weight to be quantized. + If None, the weight will be quantized directly. + Default is None. + act_preprocess_layer(paddle.nn.Layer, optional): A paddle Layer + that defines how to preprocess activation before quantization. + Using this can quickly test if user's preprocess method works + or not. The input is non-quantized activation and function returns + processed activation to be quantized. + If None, the activation will be quantized directly. + Default is None. + weight_quantize_layer(paddle.nn.Layer, optional): A paddle Layer that + defines how to quantize weight. Using this can quickly test if user's quantization method works or not. In this layer, user should both define quantization method and dequantization method, that is, the function's input is non-quantized - weight and returns dequantized weight. If None, will use - quantization op defined by 'weight_quantize_type'. Default is None. - act_quantize_layer(paddle.nn.Layer, optional): A paddle Layer that defines how to quantize activation. + weight and returns dequantized weight. + If None, will use uantization op defined by 'weight_quantize_type'. + Default is None. + act_quantize_layer(paddle.nn.Layer, optional): A paddle Layer that defines + how to quantize activation. Using this can quickly test if user's quantization method works or not. In this layer, user should both define quantization method and dequantization method, that is, the function's input is non-quantized - activation and returns dequantized activation. If None, will use - quantization op defined by 'activation_quantize_type'. Default is None. + activation and returns dequantized activation. + If None, will use quantization op defined by 'activation_quantize_type'. + Default is None. Note: - If user sets attribute 'skip_quant' to a Layer that support dynamic quantization and sets - it to true, the layer would not be quantized during training. If this attribute is not sets - or the attribute is false, the Layer would be qunatized in training. + If user sets attribute 'skip_quant' to a Layer that support dynamic + quantization and sets it to true, the layer would not be quantized + during training. If this attribute is not sets or the attribute is + false, the Layer would be qunatized in training. Examples 1: .. code-block:: python @@ -196,141 +194,175 @@ class ImperativeQuantAware(object): model_path="./imperative_model_qat") """ super(ImperativeQuantAware, self).__init__() - self._weight_bits = weight_bits - self._activation_bits = activation_bits - self._moving_rate = moving_rate - self._activation_quantize_type = activation_quantize_type - self._weight_quantize_type = weight_quantize_type - - self._weight_pre_layer = weight_preprocess_layer - self._act_pre_layer = act_preprocess_layer - self._weight_quant_layer = weight_quantize_layer - self._act_quant_layer = act_quantize_layer - self._out_scale = ImperativeCalcOutScale() - - t_check = lambda method: method is None or issubclass(method, dygraph.layers.Layer) - assert t_check( - self._weight_pre_layer), "weight_preprocess should be nn.Layer" - assert t_check(self._act_pre_layer), "act_preprocess should be nn.Layer" - assert t_check( - self._weight_quant_layer), "weight_quantize should be nn.Layer" - assert t_check(self._act_quant_layer), "act_quantize should be nn.Layer" - - quant_type = { - 'abs_max', 'moving_average_abs_max', 'channel_wise_abs_max' - } - assert activation_quantize_type != 'channel_wise_abs_max', \ - "The activation quantization type does not support 'channel_wise_abs_max'." - if activation_quantize_type not in quant_type: - raise ValueError( - "Unknown activation_quantize_type : '%s'. It can only be " - "'abs_max' or 'moving_average_abs_max' now." % - (str(activation_quantize_type))) - if weight_quantize_type not in quant_type: - raise ValueError( - "Unknown weight_quantize_type: '%s'. It can only be " - "'abs_max' or 'moving_average_abs_max' or 'channel_wise_abs_max' now." - % (str(weight_quantize_type))) - - self._quant_layers_map = { - 'Conv2D': Conv2D, - 'Linear': Linear, - 'Pool2D': Pool2D, - 'ReLU': ReLU, - 'LeakyReLU': LeakyReLU, - 'ReLU6': ReLU6, - 'Softmax': Softmax, - 'Tanh': Tanh, - 'Swish': Swish + kwargs = { + "quantizable_layer_type": quantizable_layer_type, + "weight_quantize_type": weight_quantize_type, + "activation_quantize_type": activation_quantize_type, + "weight_bits": weight_bits, + "activation_bits": activation_bits, + "moving_rate": moving_rate, + "weight_preprocess_layer": weight_preprocess_layer, + "act_preprocess_layer": act_preprocess_layer, + "weight_quantize_layer": weight_quantize_layer, + "act_quantize_layer": act_quantize_layer } - self._quantizable_layer_type = tuple( - self._quant_layers_map[layer] - if layer in self._quant_layers_map else layer - for layer in quantizable_layer_type) - for layer in self._quantizable_layer_type: - assert not isinstance( - layer, str), "{} is unspported to be quantized.".format(layer) + + self._quantize_inputs = ImperativeQuantizeInputs(**kwargs) + + self._calc_output_scale = ImperativeCalcOutputScale() def quantize(self, model): """ - According to weights' and activations' quantization types, the model will be added some fake - quant ops, such as fake_quantize_dequantize_moving_average_abs_max, fake_quantize_dequantize_abs_max - and so on. At the same time, the out_scale value of outputs would be calculated. + According to weights' and activations' quantization types, + the model will be added some fake quant ops, such as + fake_quantize_dequantize_moving_average_abs_max, + fake_quantize_dequantize_abs_max and so on. At the same time, + the out_scale value of outputs would be calculated. Args: model(fluid.dygraph.Layer): the model to be quantized. Returns: None """ + assert isinstance(model, dygraph.Layer), \ + "The model must be the instance of dygraph.Layer." + self._quantize_inputs.apply(model) + self._calc_output_scale.apply(model) + + def save_quantized_model(self, layer, path, input_spec=None, **config): + self._calc_output_scale.save_quantized_model(layer, path, input_spec, + **config) + + +class ImperativeQuantizeInputs(object): + """ + Based on the input params, add the quant_dequant computational + logic both for activation inputs and weight inputs. + """ + + def __init__(self, + quantizable_layer_type=['Conv2D', 'Linear'], + weight_quantize_type='abs_max', + activation_quantize_type='moving_average_abs_max', + weight_bits=8, + activation_bits=8, + moving_rate=0.9, + weight_preprocess_layer=None, + act_preprocess_layer=None, + weight_quantize_layer=None, + act_quantize_layer=None): + """ + The constructor for ImperativeQuantizeInputs. + + Please refer to the args of ImperativeQuantAware. + """ + super(ImperativeQuantizeInputs, self).__init__() + + self._quantizable_layer_type = tuple( + utils._quant_layers_map[layer] + if layer in utils._quant_layers_map else layer + for layer in quantizable_layer_type) + for layer in self._quantizable_layer_type: + assert not isinstance(layer, str), \ + "%s is unspported to be quantized." % layer + + quantize_type = { + 'abs_max', 'moving_average_abs_max', 'channel_wise_abs_max' + } + assert weight_quantize_type in quantize_type, \ + "Unsupported weight_quantize_type: %s. It can only " \ + "be abs_max or moving_average_abs_max or " \ + "channel_wise_abs_max." % weight_quantize_type + assert activation_quantize_type != 'channel_wise_abs_max' \ + and activation_quantize_type in quantize_type, \ + "Unsupported activation_quantize_type: %s. It can " \ + "only be abs_max or moving_average_abs_max now." \ + % activation_quantize_type + + bits_check = lambda bits: isinstance(bits, int) \ + and bits >= 0 and bits <= 16 + assert bits_check(weight_bits), \ + "weight_bits should be 1, 2,... or 16." + assert bits_check(activation_bits), \ + "activation_bits should be 1, 2,... or 16." + + layer_check = lambda method: method is None or \ + issubclass(method, dygraph.layers.Layer) + assert layer_check(weight_preprocess_layer), \ + "weight_preprocess should be nn.Layer." + assert layer_check(act_preprocess_layer), \ + "act_preprocess should be nn.Layer." + assert layer_check(weight_quantize_layer), \ + "weight_quantize should be nn.Layer." + assert layer_check(act_quantize_layer), \ + "act_quantize should be nn.Layer." + + self._kwargs = { + "weight_quantize_type": weight_quantize_type, + "activation_quantize_type": activation_quantize_type, + "weight_bits": weight_bits, + "activation_bits": activation_bits, + "moving_rate": moving_rate, + "weight_pre_layer": weight_preprocess_layer, + "act_pre_layer": act_preprocess_layer, + "weight_quant_layer": weight_quantize_layer, + "act_quant_layer": act_quantize_layer + } + + def apply(self, model): + assert isinstance(model, dygraph.Layer), \ + "The model must be the instance of dygraph.Layer." + for name, layer in model.named_sublayers(): - if not isinstance(layer, self._quantizable_layer_type): - continue - if hasattr(layer, "skip_quant") and layer.skip_quant == True: + if not isinstance(layer, self._quantizable_layer_type) \ + or (hasattr(layer, "skip_quant") \ + and layer.skip_quant == True): continue + # TODO(jc): optimize this module last_idx = 0 idx = 0 obj = model - parent = model - while idx < len(name): if (name[idx] == '.'): - if hasattr(parent, name[last_idx:idx]): + if hasattr(obj, name[last_idx:idx]): obj = getattr(obj, name[last_idx:idx]) - parent = obj last_idx = idx + 1 idx += 1 target = name[last_idx:idx] - quant_layer = self._get_quantized_counterpart(layer) + quant_layer = self._get_quantized_layer(layer) setattr(quant_layer, "layer_name", layer.full_name()) setattr(obj, target, quant_layer) - self._out_scale.calc_out_scale(model) - - def _get_quantized_counterpart(self, layer): - quant_layers = tuple(self._quant_layers_map.values()) - quantized_counterpart = tuple('Quantized' + k - for k in self._quant_layers_map.keys()) - - predicate = lambda value: isinstance(layer, value) - index_generator = (i for i, v in enumerate(quant_layers) - if predicate(v)) - - try: - index = next(index_generator) - except StopIteration: - _logger.fatal("The layer {} is unsupported to be quantized.".format( - layer.full_name())) - sys.exit(-1) + def _get_quantized_layer(self, layer): + quant_layer_name = None + for key, value in utils._quant_layers_map.items(): + if isinstance(layer, value): + quant_layer_name = 'Quantized' + key + break + assert quant_layer_name is not None, \ + "The layer %s is unsupported to be quantized." \ + % layer.full_name() layer_with_weight = ['QuantizedConv2D', 'QuantizedLinear'] - if quantized_counterpart[index] not in layer_with_weight: - quant_layer_class_name = 'QuantizedNoweightLayer' - else: - quant_layer_class_name = quantized_counterpart[index] - quantized_layer = quant_nn.__dict__[quant_layer_class_name]( - layer, self._weight_bits, self._activation_bits, self._moving_rate, - self._weight_quantize_type, self._activation_quantize_type, - self._weight_pre_layer, self._act_pre_layer, - self._weight_quant_layer, self._act_quant_layer) - return quantized_layer + if quant_layer_name not in layer_with_weight: + quant_layer_name = 'QuantizedNoweightLayer' - def save_quantized_model(self, layer, path, input_spec=None, **config): - self._out_scale.save_quantized_model(layer, path, input_spec, **config) + return quant_nn.__dict__[quant_layer_name](layer, **self._kwargs) -class ImperativeCalcOutScale(object): +class ImperativeCalcOutputScale(object): def __init__(self, moving_rate=0.9): """ - Add the logic of calculating and setting output quantization scales of some layers. - These output quantization scales may be used by tensorRT or some other inference engines. + Add the logic of calculating and setting output scales of some layers. Args: - moving_rate(float): The decay coefficient of moving average. The default value is 0.9. + moving_rate(float): The decay coefficient of moving average. + The default value is 0.9. """ - super(ImperativeCalcOutScale, self).__init__() + super(ImperativeCalcOutputScale, self).__init__() self._moving_rate = moving_rate self._out_scale_layer_type_list = ( BatchNorm, BatchNorm1D, BatchNorm2D, BatchNorm3D, Conv2D, LeakyReLU, @@ -339,83 +371,22 @@ class ImperativeCalcOutScale(object): self._register_hook_handle_list = [] self._out_scale_dict = collections.OrderedDict() - # Determine whether layer supports calculation out_scale - def _is_matched_layer(self, layer): - if not isinstance(layer, self._out_scale_layer_type_list): - if 'quantized_' not in layer.full_name(): - return False - return True - - # When inferenc model is saved, the logic in hook would not be executed - # in program translation, so that some parameters can not created in - # __init__, which would cause the model to fail to save. Therefore, the - # parameters creation in the hook is advanced to be exected outside the hook. - def _add_new_parameters(self, layer, name=None): - dtype = layer._dtype if layer._dtype is not None else "float32" - if dtype not in ["float32", "float64"]: - return - scale_prefix = '{}.scale'.format(name) if name else 'outscale.scale' - scale_name = unique_name.generate(scale_prefix) - scale_attr = ParamAttr( - name=scale_name, initializer=Constant(1), trainable=False) - layer._quant_out_scale = layer.create_parameter( - shape=[1], attr=scale_attr, dtype=dtype) - layer._quant_out_scale.stop_gradient = True - - state_prefix = "{}.state".format(name) if name else 'outscale.state' - state_attr = ParamAttr( - name=unique_name.generate(state_prefix), - initializer=Constant(1), - trainable=False) - layer._quant_out_state = layer.create_parameter( - shape=[1], attr=state_attr, dtype=dtype) - layer._quant_out_state.stop_gradient = True - - accum_prefix = "{}.accum".format(name) if name else 'outscale.accum' - accum_attr = ParamAttr( - name=unique_name.generate(accum_prefix), - initializer=Constant(1), - trainable=False) - layer._quant_out_accum = layer.create_parameter( - shape=[1], attr=accum_attr, dtype=dtype) - layer._quant_out_accum.stop_gradient = True - - # Judge whether the op in program matches the Layer in dynamic model - def _is_op_matched(self, layer_name, op, block): - output_var_names = quantization_pass._get_op_output_var_names(op) - for output_var_name in output_var_names: - output_var_tensor = block.var(output_var_name) - if output_var_tensor.dtype not in [ - core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32 - ]: - return False - - # Because the naming styles of static and dynamic graph are different, - # in order to avoid mistakes, we unify the name here. - op_type = output_var_names[0].split(".")[0] - op_type = op_type.rsplit("_", 1)[0] - if op_type == 'depthwise_conv2d': - op_type = 'conv2d' - if 'prelu' in op_type: - op_type = op_type.replace('prelu', 'p_re_lu') - if 'relu' in op_type: - op_type = op_type.replace('relu', 're_lu') - return op_type in layer_name - - def calc_out_scale(self, model): + def apply(self, model): """ - Insert the `moving_average_abs_max_scale` op to calculate output scale of Specific layers in model. + Insert the `moving_average_abs_max_scale` op to calculate output + scale of specific layers in model. Args: - model(fluid.dygraph.Layer): The target model which would be calculate the output quantization scale. + model(fluid.dygraph.Layer): The target model which would be + calculate the output quantization scale. Returns: None """ - assert isinstance( - model, dygraph.Layer), "model must be the instance of dygraph.Layer" + assert isinstance(model, dygraph.Layer), \ + "The model must be the instance of dygraph.Layer." for _, layer in model.named_sublayers(): - if self._is_matched_layer(layer): + if self._is_target_layer(layer): self._add_new_parameters(layer) forward_post_hook_handle = layer.register_forward_post_hook( self._forward_post_hook) @@ -459,7 +430,7 @@ class ImperativeCalcOutScale(object): .numpy()) else: for _, sub_layer in self._layer.named_sublayers(): - if self._is_matched_layer(sub_layer): + if self._is_target_layer(sub_layer): layer_name = sub_layer.full_name() if hasattr(sub_layer, "layer_name"): layer_name = sub_layer.layer_name @@ -510,7 +481,7 @@ class ImperativeCalcOutScale(object): forward_op = None for block in inference_program.blocks: for op in block.ops: - if op.type in _op_real_in_out_name: + if op.type in utils._op_real_in_out_name: if op_count > len(ops_list): warnings.warn( "The number of Layer which has out_threshold attribute should be bigger than the op in inference model" @@ -567,6 +538,66 @@ class ImperativeCalcOutScale(object): if is_dynamic_mode: paddle.disable_static() + def _is_target_layer(self, layer): + return isinstance(layer, self._out_scale_layer_type_list) \ + or 'quantized_' in layer.full_name() + + # When inferenc model is saved, the logic in hook would not be executed + # in program translation, so that some parameters can not created in + # __init__, which would cause the model to fail to save. Therefore, the + # parameters creation in the hook is advanced to be exected outside the hook. + def _add_new_parameters(self, layer, name=None): + dtype = layer._dtype if layer._dtype is not None else "float32" + if dtype not in ["float32", "float64"]: + return + scale_prefix = '{}.scale'.format(name) if name else 'outscale.scale' + scale_name = unique_name.generate(scale_prefix) + scale_attr = ParamAttr( + name=scale_name, initializer=Constant(1), trainable=False) + layer._quant_out_scale = layer.create_parameter( + shape=[1], attr=scale_attr, dtype=dtype) + layer._quant_out_scale.stop_gradient = True + + state_prefix = "{}.state".format(name) if name else 'outscale.state' + state_attr = ParamAttr( + name=unique_name.generate(state_prefix), + initializer=Constant(1), + trainable=False) + layer._quant_out_state = layer.create_parameter( + shape=[1], attr=state_attr, dtype=dtype) + layer._quant_out_state.stop_gradient = True + + accum_prefix = "{}.accum".format(name) if name else 'outscale.accum' + accum_attr = ParamAttr( + name=unique_name.generate(accum_prefix), + initializer=Constant(1), + trainable=False) + layer._quant_out_accum = layer.create_parameter( + shape=[1], attr=accum_attr, dtype=dtype) + layer._quant_out_accum.stop_gradient = True + + # Judge whether the op in program matches the Layer in dynamic model + def _is_op_matched(self, layer_name, op, block): + output_var_names = quantization_pass._get_op_output_var_names(op) + for output_var_name in output_var_names: + output_var_tensor = block.var(output_var_name) + if output_var_tensor.dtype not in [ + core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32 + ]: + return False + + # Because the naming styles of static and dynamic graph are different, + # in order to avoid mistakes, we unify the name here. + op_type = output_var_names[0].split(".")[0] + op_type = op_type.rsplit("_", 1)[0] + if op_type == 'depthwise_conv2d': + op_type = 'conv2d' + if 'prelu' in op_type: + op_type = op_type.replace('prelu', 'p_re_lu') + if 'relu' in op_type: + op_type = op_type.replace('relu', 're_lu') + return op_type in layer_name + def _forward_post_hook(self, layer, input, output): assert isinstance( output, (core.VarBase, framework.Variable) diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py b/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a732181db7d64e6ed39be2f45a6805d2bd4ab02a --- /dev/null +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py @@ -0,0 +1,46 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.nn import Linear, Conv2D +from paddle.fluid.dygraph.nn import Pool2D +from paddle.nn.layer.activation import ReLU, LeakyReLU, Sigmoid, ReLU6 +from paddle.nn.layer.activation import Tanh, Softmax, PReLU, Swish + +_op_real_in_out_name = { + "conv2d": [["Input", "Filter"], ["Output"]], + "depthwise_conv2d": [["Input", "Filter"], ["Output"]], + "pool2d": [["X"], ["Out"]], + "elementwise_add": [["X", "Y"], ["Out"]], + "softmax": [["X"], ["Out"]], + "relu": [["X"], ["Out"]], + "relu6": [["X"], ["Out"]], + "leaky_relu": [["X"], ["Out"]], + "prelu": [["X"], ["Out"]], + "tanh": [["X"], ["Out"]], + "batch_norm": [["X"], ["Y"]], + "sigmoid": [["X"], ["Out"]], + "swish": [["X"], ["Out"]], +} + +_quant_layers_map = { + 'Conv2D': Conv2D, + 'Linear': Linear, + 'Pool2D': Pool2D, + 'ReLU': ReLU, + 'LeakyReLU': LeakyReLU, + 'ReLU6': ReLU6, + 'Softmax': Softmax, + 'Tanh': Tanh, + 'Swish': Swish +}