From 548cdbc544149c221d3cf522b16180afd78750a9 Mon Sep 17 00:00:00 2001 From: Zhen Wang Date: Sat, 11 Jul 2020 23:26:24 +0800 Subject: [PATCH] Quantization-aware training for dygraph (#24634) * Add the imperative quantization aware training. * This is the python part of Imperative QAT. test=develop --- paddle/fluid/operators/fake_quantize_op.cc | 17 +- .../contrib/slim/quantization/__init__.py | 3 + .../slim/quantization/imperative/__init__.py | 25 ++ .../slim/quantization/imperative/qat.py | 229 +++++++++++ .../slim/quantization/imperative/quant_nn.py | 375 ++++++++++++++++++ .../contrib/slim/tests/test_imperative_qat.py | 259 ++++++++++++ python/setup.py.in | 1 + 7 files changed, 899 insertions(+), 10 deletions(-) create mode 100644 python/paddle/fluid/contrib/slim/quantization/imperative/__init__.py create mode 100644 python/paddle/fluid/contrib/slim/quantization/imperative/qat.py create mode 100644 python/paddle/fluid/contrib/slim/quantization/imperative/quant_nn.py create mode 100644 python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py diff --git a/paddle/fluid/operators/fake_quantize_op.cc b/paddle/fluid/operators/fake_quantize_op.cc index 401cc448ac9..16a32a3f6cf 100644 --- a/paddle/fluid/operators/fake_quantize_op.cc +++ b/paddle/fluid/operators/fake_quantize_op.cc @@ -219,14 +219,14 @@ class FakeQuantOrWithDequantAbsMaxOpMaker bit_length)); }); AddComment(R"DOC( -This is a Base Op which support FakeQuantAbsMaxOpMaker and FakeQuantDequantAbsMaxOpMaker. +This is a Base Op which supports FakeQuantAbsMaxOpMaker and FakeQuantDequantAbsMaxOpMaker. FakeQuantAbsMaxOp operator is used in the dynamic quantization. $$scale = max(abs(X))$$ $$range = 2^{bit_length - 1} - 1$$ $$Out = round(X/scale * range)$$ -FakeQuantDequantAbsMaxOp operator do the abs_max quant and then dequant. +FakeQuantDequantAbsMaxOp operator does the abs_max quantization and then dequantization. $$scale = max(abs(X))$$ $$range = 2^{bit\_length - 1} - 1$$ @@ -423,14 +423,14 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker "for training. Some layers may run faster when this is true.") .SetDefault(false); AddComment(R"DOC( -This is a Base Op which support FakeQuantMovingAverageAbsMaxOp and FakeQuantDequantMovingAverageAbsMaxOp. +This is a Base Op which supports FakeQuantMovingAverageAbsMaxOp and FakeQuantDequantMovingAverageAbsMaxOp. FakeQuantMovingAverageAbsMaxOp operator is used in the static quantization. $$scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)$$ $$range = 2^{bit\_length - 1} - 1$$ $$Out = round(X/scale * range)$$ -FakeQuantDequantMovingAverageAbsMaxOp operator do the moving_average_abs_max quant and then dequant. +FakeQuantDequantMovingAverageAbsMaxOp operator does the moving_average_abs_max quant and then dequant. 
$$scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)$$ $$range = 2^{bit\_length - 1} - 1$$ @@ -505,15 +505,12 @@ class FakeQuantDequantGradOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override { auto out_grad_name = framework::GradVarName("Out"); + auto x_grad_name = framework::GradVarName("X"); OP_INOUT_CHECK(ctx->HasInput(out_grad_name), "Input", out_grad_name, "FakeQuantDequantGradOp"); + OP_INOUT_CHECK(ctx->HasOutput(x_grad_name), "Output", x_grad_name, + "FakeQuantDequantGradOp"); - auto x_grad_name = framework::GradVarName("X"); - PADDLE_ENFORCE_EQ( - ctx->HasOutput(x_grad_name), true, - platform::errors::PreconditionNotMet( - "FakeQuantDequantGradOp doesn't have the output named %s.", - x_grad_name)); ctx->SetOutputDim(x_grad_name, ctx->GetInputDim(out_grad_name)); } diff --git a/python/paddle/fluid/contrib/slim/quantization/__init__.py b/python/paddle/fluid/contrib/slim/quantization/__init__.py index 328983c70ec..ee7e6536f2e 100644 --- a/python/paddle/fluid/contrib/slim/quantization/__init__.py +++ b/python/paddle/fluid/contrib/slim/quantization/__init__.py @@ -26,9 +26,12 @@ from . import quant2_int8_mkldnn_pass from .quant2_int8_mkldnn_pass import * from . import post_training_quantization from .post_training_quantization import * +from . import imperative +from .imperative import * __all__ = quantization_pass.__all__ + quantization_strategy.__all__ __all__ += mkldnn_post_training_strategy.__all__ __all__ += quant_int8_mkldnn_pass.__all__ __all__ += quant2_int8_mkldnn_pass.__all__ __all__ += post_training_quantization.__all__ +__all__ += imperative.__all__ diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/__init__.py b/python/paddle/fluid/contrib/slim/quantization/imperative/__init__.py new file mode 100644 index 00000000000..7ea62b5f324 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +from . import quant_nn +from .quant_nn import * + +from . import qat +from .qat import * + +__all__ = [] +__all__ += quant_nn.__all__ +__all__ += qat.__all__ diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py new file mode 100644 index 00000000000..c77648ac7b5 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py @@ -0,0 +1,229 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import numpy as np
+import sys
+from paddle.fluid import dygraph
+from paddle.fluid.dygraph.nn import Conv2D
+from paddle.fluid.dygraph.nn import Linear
+from paddle.fluid.log_helper import get_logger
+from . import quant_nn
+
+__all__ = ['ImperativeQuantAware']
+
+_logger = get_logger(
+    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
+
+
+class ImperativeQuantAware(object):
+    """
+    Add the fake quant logic for the given quantizable layers, namely add the
+    quant_dequant computational logic for both activation inputs and weight inputs.
+    """
+
+    def __init__(self,
+                 weight_bits=8,
+                 activation_bits=8,
+                 weight_quantize_type='abs_max',
+                 activation_quantize_type='moving_average_abs_max',
+                 moving_rate=0.9,
+                 quantizable_layer_type=['Conv2D', 'Linear']):
+        """
+        The constructor for ImperativeQuantAware.
+
+        Args:
+            weight_bits(int): quantization bit number for weights,
+                whereas the bias is not quantized.
+            activation_bits(int): quantization bit number for activations.
+            weight_quantize_type(str): quantization type for weights,
+                which supports 'abs_max' now. 'moving_average_abs_max' is
+                usually not used for weights, since weights are fixed once
+                the model is well trained.
+            activation_quantize_type(str): quantization type for activations,
+                which supports 'abs_max' and 'moving_average_abs_max' now.
+                If using 'abs_max' mode, the quantization scale will be
+                calculated dynamically at each step in both the training and
+                testing periods. If using 'moving_average_abs_max', the static
+                quantization scale will be calculated during training and used
+                in inference.
+            moving_rate(float): the parameter for 'moving_average_abs_max' quantization.
+            quantizable_layer_type(list[str]): list of the layer types that will
+                be quantized. Default is ['Conv2D', 'Linear']. The
+                quantizable_op_type in QuantizationFreezePass and
+                ConvertToInt8Pass must be the same as this.
+
+        Examples:
+        .. code-block:: python
+
+            from paddle.fluid.contrib.slim.quantization \
+                import ImperativeQuantAware
+            from paddle.incubate.hapi.vision.models \
+                import resnet
+
+            model = resnet.resnet50(pretrained=True)
+
+            imperative_qat = ImperativeQuantAware(
+                weight_quantize_type='abs_max',
+                activation_quantize_type='moving_average_abs_max')
+
+            # Add the fake quant logic.
+            # The original model will be rewritten in place.
+            imperative_qat.quantize(model)
+
+            # Fine-tune the quantized model
+            # ...
+
+            # Save the quantized model for inference.
+            imperative_qat.save_quantized_model(
+                dirname="./resnet50_qat",
+                model=model,
+                input_shape=[(3, 224, 224)],
+                input_dtype=['float32'],
+                feed=[0],
+                fetch=[0])
+        """
+        super(ImperativeQuantAware, self).__init__()
+        self._weight_bits = weight_bits
+        self._activation_bits = activation_bits
+        self._moving_rate = moving_rate
+
+        quant_type = {'abs_max', 'moving_average_abs_max'}
+        if activation_quantize_type not in quant_type:
+            raise ValueError(
+                "Unknown activation_quantize_type: '%s'. It can only be "
+                "'abs_max' or 'moving_average_abs_max' now."
+                % (str(activation_quantize_type)))
+        if weight_quantize_type not in quant_type:
+            raise ValueError(
+                "Unknown weight_quantize_type: '%s'. It can only be "
+                "'abs_max' or 'moving_average_abs_max' now."
+                % (str(weight_quantize_type)))
+        self._activation_quantize_type = activation_quantize_type
+        self._weight_quantize_type = weight_quantize_type
+
+        self._quant_layers_map = {'Conv2D': Conv2D, 'Linear': Linear}
+        self._quantizable_layer_type = tuple(
+            self._quant_layers_map[layer]
+            if layer in self._quant_layers_map else layer
+            for layer in quantizable_layer_type)
+        for layer in self._quantizable_layer_type:
+            assert not isinstance(
+                layer, str), "{} is unsupported to be quantized.".format(layer)
+
+    def quantize(self, model):
+        """
+        According to the quantization types of weights and activations, fake
+        quant ops such as fake_quantize_dequantize_moving_average_abs_max and
+        fake_quantize_dequantize_abs_max will be added to the model.
+
+        Args:
+            model(fluid.dygraph.Layer): the model to be quantized.
+        Returns:
+            None
+        """
+        for name, layer in model.named_sublayers():
+            if not isinstance(layer, self._quantizable_layer_type):
+                continue
+
+            scopes = name.split('.')
+            target = scopes[-1]
+            obj = model
+            parent = model
+            for i in range(len(scopes) - 1):
+                obj = getattr(parent, scopes[i])
+                parent = obj
+
+            quant_layer = self._get_quantized_counterpart(layer)
+            setattr(obj, target, quant_layer)
+
+    def save_quantized_model(self,
+                             dirname,
+                             model,
+                             input_shape,
+                             input_dtype,
+                             feed,
+                             fetch,
+                             append_batch_size=True):
+        """
+        Save the quantized model for inference.
+
+        Args:
+            dirname(str): the directory to save the quantized model.
+            model(fluid.dygraph.Layer): the quantized model to be saved.
+            input_shape(list[tuple(int)]): the shape value for each input,
+                e.g. [(3, 224, 224)].
+            input_dtype(list[str]): the dtype value for each input,
+                e.g. ['float32'].
+            feed(list[int]): the indices of the input variables of the
+                imperative functions which will be saved as input variables
+                in the inference model.
+            fetch(list[int]): the indices of the returned variables of the
+                imperative functions which will be saved as output variables
+                in the inference model.
+            append_batch_size(bool, optional): if True, an extra batch-size
+                axis is prepended to each shape in input_shape, and in this
+                case input_shape should not contain the batch size dimension.
+                Otherwise, input_shape is used as-is. Default: True.
+        Returns:
+            None
+        """
+        assert isinstance(
+            input_shape, list), "The parameter `input_shape` should be a list."
+        assert isinstance(
+            input_dtype, list), "The parameter `input_dtype` should be a list."
+        assert isinstance(feed, list), "The parameter `feed` should be a list."
+        assert isinstance(fetch,
+                          list), "The parameter `fetch` should be a list."
+        assert len(input_shape) == len(
+            input_dtype
+        ), "The length of input_shape should be equal to input_dtype's."
+        assert len(input_dtype) == len(
+            feed), "The length of input_dtype should be equal to feed's."
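+
+        # For example: with input_shape=[(3, 224, 224)], input_dtype=['float32']
+        # and append_batch_size=True, a random float32 tensor of shape
+        # [1, 3, 224, 224] is generated below and fed through the traced
+        # forward; `feed` and `fetch` then select, by index, which inputs and
+        # outputs of the traced program are kept in the saved inference model.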
+ + def _convert(model, *args): + return model(*args) + + prog_trans = dygraph.ProgramTranslator() + with dygraph.guard(): + model.eval() + input_vars = [] + for shape, dtype in zip(input_shape, input_dtype): + raw_data = np.random.random(shape) + input_data = raw_data[np.newaxis, :].astype( + dtype) if append_batch_size else raw_data.astype(dtype) + input_var = dygraph.to_variable(input_data) + input_vars.append(input_var) + prog_trans.get_output(_convert, model, *input_vars) + prog_trans.save_inference_model(dirname, feed, fetch) + + def _get_quantized_counterpart(self, layer): + quant_layers = tuple(self._quant_layers_map.values()) + quantized_counterpart = tuple('Quantized' + k + for k in self._quant_layers_map.keys()) + + predicate = lambda value: isinstance(layer, value) + index_generator = (i for i, v in enumerate(quant_layers) + if predicate(v)) + + try: + index = next(index_generator) + except StopIteration: + _logger.fatal("The layer {} is unsupported to be quantized.".format( + layer.full_name())) + sys.exit(-1) + + quantized_layer = quant_nn.__dict__[quantized_counterpart[index]]( + layer, self._weight_bits, self._activation_bits, self._moving_rate, + self._weight_quantize_type, self._activation_quantize_type) + return quantized_layer diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/quant_nn.py b/python/paddle/fluid/contrib/slim/quantization/imperative/quant_nn.py new file mode 100644 index 00000000000..59dd9867abb --- /dev/null +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/quant_nn.py @@ -0,0 +1,375 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.fluid.dygraph import layers +from paddle.fluid import core +from paddle.fluid import dygraph_utils +from paddle.fluid import unique_name +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.framework import _varbase_creator +from paddle.fluid.framework import in_dygraph_mode +from paddle.fluid.initializer import Constant +from paddle.fluid.data_feeder import check_variable_and_dtype + +__all__ = [ + 'FakeQuantMovingAverage', 'FakeQuantAbsMax', 'QuantizedConv2D', + 'QuantizedLinear' +] + + +class FakeQuantMovingAverage(layers.Layer): + """ + FakeQuantMovingAverage layer does the moving_average_abs_max quant and then dequant. 
+ Its computational formula is described as below: + + :math:`scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)` + :math:`range = 2^{bit\_length - 1} - 1` + :math:`Out = round(X / scale * range) * scale / range` + """ + + def __init__(self, + name=None, + moving_rate=0.9, + quant_bits=8, + dtype='float32'): + super(FakeQuantMovingAverage, self).__init__() + self._moving_rate = moving_rate + self._quant_bits = quant_bits + + scale_prefix = "{}.scale".format( + name) if name else 'quant_dequant.scale' + scale_attr = ParamAttr( + name=unique_name.generate(scale_prefix), + initializer=Constant(0.001), + trainable=False) + self._scale = self.create_parameter( + shape=[1], attr=scale_attr, dtype=dtype) + self._scale.stop_gradient = True + + state_prefix = "{}.state".format( + name) if name else 'quant_dequant.state' + state_attr = ParamAttr( + name=unique_name.generate(state_prefix), + initializer=Constant(1), + trainable=False) + self._state = self.create_parameter( + shape=[1], attr=state_attr, dtype=dtype) + self._state.stop_gradient = True + + accum_prefix = "{}.accum".format( + name) if name else 'quant_dequant.accum' + accum_attr = ParamAttr( + name=unique_name.generate(accum_prefix), + initializer=Constant(1), + trainable=False) + self._accum = self.create_parameter( + shape=[1], attr=accum_attr, dtype=dtype) + self._accum.stop_gradient = True + + def forward(self, input): + if in_dygraph_mode(): + attrs = ('moving_rate', self._moving_rate, 'bit_length', + self._quant_bits, 'is_test', not self.training) + quant_out = _varbase_creator( + type=input.type, + name="{}.quantized.dequantized".format(input.name), + shape=input.shape, + dtype=input.dtype, + persistable=False) + state = self._state if self.training else None + accum = self._accum if self.training else None + + out, _, _, _ = core.ops.fake_quantize_dequantize_moving_average_abs_max( + input, self._scale, accum, state, quant_out, self._scale, state, + accum, *attrs) + return out + + check_variable_and_dtype(input, 'input', ['float32'], + "FakeQuantMovingAverage") + attrs = { + 'moving_rate': self._moving_rate, + 'bit_length': self._quant_bits, + 'is_test': not self.training + } + inputs = {"X": [input], "InScale": [self._scale]} + quant_out = self._helper.create_variable( + name="{}.quantized.dequantized".format(input.name), + dtype=input.dtype, + type=core.VarDesc.VarType.LOD_TENSOR, + persistable=False, + stop_gradient=False) + outputs = {"Out": [quant_out], "OutScale": [self._scale]} + + if self.training: + inputs['InState'] = [self._state] + inputs['InAccum'] = [self._accum] + outputs['OutState'] = [self._state] + outputs['OutAccum'] = [self._accum] + + self._helper.append_op( + type="fake_quantize_dequantize_moving_average_abs_max", + inputs=inputs, + outputs=outputs, + attrs=attrs) + + return quant_out + + +class FakeQuantAbsMax(layers.Layer): + """ + FakeQuantAbsMax layer does the abs_max quant and then dequant. 
+ Its computational formula is described as below: + + :math:`scale = max(abs(X))` + :math:`range = 2^{bit\_length - 1} - 1` + :math:`Out = round(X / scale * range) * scale / range` + """ + + def __init__(self, + name=None, + quant_bits=8, + dtype='float32', + quant_on_weight=False): + super(FakeQuantAbsMax, self).__init__() + self._quant_bits = quant_bits + self._dtype = dtype + self._name = name + scale_prefix = "{}.scale".format( + name) if name else 'quant_dequant.scale' + self._scale_name = unique_name.generate(scale_prefix) + if quant_on_weight: + scale_attr = ParamAttr( + name=self._scale_name, + initializer=Constant(0.0), + trainable=False) + self._scale = self.create_parameter( + shape=[1], attr=scale_attr, dtype=self._dtype) + self._scale.stop_gradient = True + else: + self._scale = None + + def forward(self, input): + if in_dygraph_mode(): + attrs = ('bit_length', self._quant_bits) + quant_out = _varbase_creator( + type=input.type, + name="{}.quantized.dequantized".format(input.name), + shape=input.shape, + dtype=input.dtype, + persistable=False) + out_scale = self._scale + if not out_scale: + out_scale = _varbase_creator( + type=core.VarDesc.VarType.LOD_TENSOR, + name=self._scale_name, + shape=[1], + dtype=self._dtype, + persistable=False) + out_scale.stop_gradient = True + out, _, = core.ops.fake_quantize_dequantize_abs_max( + input, quant_out, out_scale, *attrs) + return out + + check_variable_and_dtype(input, 'input', ['float32'], "FakeQuantAbsMax") + attrs = {'bit_length': self._quant_bits} + inputs = {"X": [input]} + quant_out = self._helper.create_variable( + name="{}.quantized.dequantized".format(input.name), + dtype=input.dtype, + type=core.VarDesc.VarType.LOD_TENSOR, + persistable=False, + stop_gradient=False) + out_scale = self._scale + if not out_scale: + out_scale = self._helper.create_variable( + name=self._scale_name, + dtype=self._dtype, + type=core.VarDesc.VarType.LOD_TENSOR, + persistable=False, + stop_gradient=True) + outputs = {"Out": [quant_out], "OutScale": [out_scale]} + + self._helper.append_op( + type="fake_quantize_dequantize_abs_max", + inputs=inputs, + outputs=outputs, + attrs=attrs) + + return quant_out + + +def _get_fake_quant_type(quant_type, name, moving_rate, quant_bits, dtype, + quant_on_weight): + fake_quant_map = { + 'abs_max': + lambda: FakeQuantAbsMax(name, quant_bits, dtype, quant_on_weight), + 'moving_average_abs_max': + lambda: FakeQuantMovingAverage(name, moving_rate, quant_bits, dtype) + } + return fake_quant_map[quant_type]() + + +class QuantizedConv2D(layers.Layer): + """ + The computational logic of QuantizedConv2D is the same with Conv2D. + The only difference is that its inputs are all fake quantized. 
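+
+    A rough sketch of the forward pass defined below (illustrative only,
+    not a public API):
+
+    .. code-block:: python
+
+        quant_input = self._fake_quant_input(input)          # fake quant-dequant of the activation
+        quant_weight = self._fake_quant_weight(self.weight)  # fake quant-dequant of the filter
+        out = conv2d(quant_input, quant_weight)               # then the usual conv + bias + activation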
+ """ + + def __init__(self, + layer, + weight_bits=8, + activation_bits=8, + moving_rate=0.9, + weight_quantize_type='abs_max', + activation_quantize_type='abs_max'): + super(QuantizedConv2D, self).__init__() + # For Conv2D + self._groups = getattr(layer, '_groups') + self._stride = getattr(layer, '_stride') + self._padding = getattr(layer, '_padding') + self._dilation = getattr(layer, '_dilation') + self._act = getattr(layer, '_act') + self._use_cudnn = getattr(layer, '_use_cudnn') + self._dtype = getattr(layer, '_dtype') + self._l_type = getattr(layer, '_l_type') + self.weight = getattr(layer, 'weight') + self.bias = getattr(layer, 'bias') + # For FakeQuant + self._fake_quant_weight = _get_fake_quant_type( + weight_quantize_type, self.weight.name, moving_rate, weight_bits, + self._dtype, True) + self._fake_quant_input = _get_fake_quant_type( + activation_quantize_type, + layer.full_name(), moving_rate, activation_bits, self._dtype, False) + + def forward(self, input): + quant_input = self._fake_quant_input(input) + quant_weight = self._fake_quant_weight(self.weight) + + if in_dygraph_mode() and self._l_type == 'conv2d': + attrs = ('strides', self._stride, 'paddings', self._padding, + 'dilations', self._dilation, 'groups', self._groups + if self._groups else 1, 'use_cudnn', self._use_cudnn) + pre_bias = core.ops.conv2d(quant_input, quant_weight, *attrs) + + pre_act = dygraph_utils._append_bias_in_dygraph(pre_bias, self.bias, + 1) + return dygraph_utils._append_activation_in_dygraph(pre_act, + self._act) + check_variable_and_dtype(quant_input, 'input', + ['float16', 'float32', 'float64'], + 'QuantizedConv2D') + attrs = { + 'strides': self._stride, + 'paddings': self._padding, + 'dilations': self._dilation, + 'groups': self._groups if self._groups else 1, + 'use_cudnn': self._use_cudnn, + 'use_mkldnn': False, + } + pre_bias = self._helper.create_variable_for_type_inference( + dtype=self._dtype) + + self._helper.append_op( + type=self._l_type, + inputs={ + 'Input': quant_input, + 'Filter': quant_weight, + }, + outputs={"Output": pre_bias}, + attrs=attrs) + + if self.bias is not None: + pre_act = self._helper.create_variable_for_type_inference( + dtype=self._dtype) + self._helper.append_op( + type='elementwise_add', + inputs={'X': [pre_bias], + 'Y': [self.bias]}, + outputs={'Out': [pre_act]}, + attrs={'axis': 1}) + else: + pre_act = pre_bias + + return self._helper.append_activation(pre_act, act=self._act) + + +class QuantizedLinear(layers.Layer): + """ + The computational logic of QuantizedLinear is the same with Linear. + The only difference is that its inputs are all fake quantized. 
+ """ + + def __init__(self, + layer, + weight_bits=8, + activation_bits=8, + moving_rate=0.9, + weight_quantize_type='abs_max', + activation_quantize_type='abs_max'): + super(QuantizedLinear, self).__init__() + # For Linear + self._act = getattr(layer, '_act') + self._dtype = getattr(layer, '_dtype') + self.weight = getattr(layer, 'weight') + self.bias = getattr(layer, 'bias') + # For FakeQuant + self._fake_quant_weight = _get_fake_quant_type( + weight_quantize_type, self.weight.name, moving_rate, weight_bits, + self._dtype, True) + self._fake_quant_input = _get_fake_quant_type( + activation_quantize_type, + layer.full_name(), moving_rate, activation_bits, self._dtype, False) + + def forward(self, input): + quant_input = self._fake_quant_input(input) + quant_weight = self._fake_quant_weight(self.weight) + if in_dygraph_mode(): + pre_bias = _varbase_creator(dtype=input.dtype) + core.ops.matmul(quant_input, quant_weight, pre_bias, 'transpose_X', + False, 'transpose_Y', False, "alpha", 1) + pre_act = dygraph_utils._append_bias_in_dygraph( + pre_bias, self.bias, axis=len(input.shape) - 1) + + return dygraph_utils._append_activation_in_dygraph(pre_act, + self._act) + + check_variable_and_dtype(input, 'input', + ['float16', 'float32', 'float64'], + "QuantizedLinear") + attrs = { + "transpose_X": False, + "transpose_Y": False, + "alpha": 1, + } + inputs = {"X": [quant_input], "Y": [quant_weight]} + mul_out = self._helper.create_variable_for_type_inference(self._dtype) + + self._helper.append_op( + type="matmul", + inputs=inputs, + outputs={"Out": [mul_out]}, + attrs=attrs) + if self.bias is not None: + pre_activation = self._helper.create_variable_for_type_inference( + dtype=self._dtype) + self._helper.append_op( + type='elementwise_add', + inputs={'X': [mul_out], + 'Y': [self.bias]}, + outputs={'Out': [pre_activation]}, + attrs={'axis': len(input.shape) - 1}) + else: + pre_activation = mul_out + return self._helper.append_activation(pre_activation, act=self._act) diff --git a/python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py new file mode 100644 index 00000000000..997e9ff3698 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py @@ -0,0 +1,259 @@ +# copyright (c) 2018 paddlepaddle authors. all rights reserved. +# +# licensed under the apache license, version 2.0 (the "license"); +# you may not use this file except in compliance with the license. +# you may obtain a copy of the license at +# +# http://www.apache.org/licenses/license-2.0 +# +# unless required by applicable law or agreed to in writing, software +# distributed under the license is distributed on an "as is" basis, +# without warranties or conditions of any kind, either express or implied. +# see the license for the specific language governing permissions and +# limitations under the license. 
+ +from __future__ import print_function + +import os +import numpy as np +import random +import unittest +import logging +import paddle +import paddle.fluid as fluid +from paddle.fluid import core +from paddle.fluid.optimizer import AdamOptimizer +from paddle.fluid.framework import IrGraph +from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware +from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass +from paddle.fluid.dygraph.container import Sequential +from paddle.fluid.dygraph.nn import Conv2D +from paddle.fluid.dygraph.nn import Pool2D +from paddle.fluid.dygraph.nn import Linear +from paddle.fluid.log_helper import get_logger + +os.environ["CPU_NUM"] = "1" +if core.is_compiled_with_cuda(): + fluid.set_flags({"FLAGS_cudnn_deterministic": True}) + +_logger = get_logger( + __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') + + +def StaticLenet(data, num_classes=10, classifier_activation='softmax'): + conv2d_w1_attr = fluid.ParamAttr(name="conv2d_w_1") + conv2d_w2_attr = fluid.ParamAttr(name="conv2d_w_2") + fc_w1_attr = fluid.ParamAttr(name="fc_w_1") + fc_w2_attr = fluid.ParamAttr(name="fc_w_2") + fc_w3_attr = fluid.ParamAttr(name="fc_w_3") + conv2d_b1_attr = fluid.ParamAttr(name="conv2d_b_1") + conv2d_b2_attr = fluid.ParamAttr(name="conv2d_b_2") + fc_b1_attr = fluid.ParamAttr(name="fc_b_1") + fc_b2_attr = fluid.ParamAttr(name="fc_b_2") + fc_b3_attr = fluid.ParamAttr(name="fc_b_3") + conv1 = fluid.layers.conv2d( + data, + num_filters=6, + filter_size=3, + stride=1, + padding=1, + param_attr=conv2d_w1_attr, + bias_attr=conv2d_b1_attr) + pool1 = fluid.layers.pool2d( + conv1, pool_size=2, pool_type='max', pool_stride=2) + conv2 = fluid.layers.conv2d( + pool1, + num_filters=16, + filter_size=5, + stride=1, + padding=0, + param_attr=conv2d_w2_attr, + bias_attr=conv2d_b2_attr) + pool2 = fluid.layers.pool2d( + conv2, pool_size=2, pool_type='max', pool_stride=2) + + fc1 = fluid.layers.fc(input=pool2, + size=120, + param_attr=fc_w1_attr, + bias_attr=fc_b1_attr) + fc2 = fluid.layers.fc(input=fc1, + size=84, + param_attr=fc_w2_attr, + bias_attr=fc_b2_attr) + fc3 = fluid.layers.fc(input=fc2, + size=num_classes, + act=classifier_activation, + param_attr=fc_w3_attr, + bias_attr=fc_b3_attr) + + return fc3 + + +class ImperativeLenet(fluid.dygraph.Layer): + def __init__(self, num_classes=10, classifier_activation='softmax'): + super(ImperativeLenet, self).__init__() + conv2d_w1_attr = fluid.ParamAttr(name="conv2d_w_1") + conv2d_w2_attr = fluid.ParamAttr(name="conv2d_w_2") + fc_w1_attr = fluid.ParamAttr(name="fc_w_1") + fc_w2_attr = fluid.ParamAttr(name="fc_w_2") + fc_w3_attr = fluid.ParamAttr(name="fc_w_3") + conv2d_b1_attr = fluid.ParamAttr(name="conv2d_b_1") + conv2d_b2_attr = fluid.ParamAttr(name="conv2d_b_2") + fc_b1_attr = fluid.ParamAttr(name="fc_b_1") + fc_b2_attr = fluid.ParamAttr(name="fc_b_2") + fc_b3_attr = fluid.ParamAttr(name="fc_b_3") + self.features = Sequential( + Conv2D( + num_channels=1, + num_filters=6, + filter_size=3, + stride=1, + padding=1, + param_attr=conv2d_w1_attr, + bias_attr=conv2d_b1_attr), + Pool2D( + pool_size=2, pool_type='max', pool_stride=2), + Conv2D( + num_channels=6, + num_filters=16, + filter_size=5, + stride=1, + padding=0, + param_attr=conv2d_w2_attr, + bias_attr=conv2d_b2_attr), + Pool2D( + pool_size=2, pool_type='max', pool_stride=2)) + + self.fc = Sequential( + Linear( + input_dim=400, + output_dim=120, + param_attr=fc_w1_attr, + bias_attr=fc_b1_attr), + Linear( + input_dim=120, + output_dim=84, + 
param_attr=fc_w2_attr, + bias_attr=fc_b2_attr), + Linear( + input_dim=84, + output_dim=num_classes, + act=classifier_activation, + param_attr=fc_w3_attr, + bias_attr=fc_b3_attr)) + + def forward(self, inputs): + x = self.features(inputs) + + x = fluid.layers.flatten(x, 1) + x = self.fc(x) + return x + + +class TestImperativeQat(unittest.TestCase): + """ + QAT = quantization-aware training + """ + + def test_qat_save(self): + imperative_qat = ImperativeQuantAware( + weight_quantize_type='abs_max', + activation_quantize_type='moving_average_abs_max') + + with fluid.dygraph.guard(): + lenet = ImperativeLenet() + imperative_qat.quantize(lenet) + adam = AdamOptimizer( + learning_rate=0.001, parameter_list=lenet.parameters()) + train_reader = paddle.batch( + paddle.dataset.mnist.train(), batch_size=32, drop_last=True) + test_reader = paddle.batch( + paddle.dataset.mnist.test(), batch_size=32) + + epoch_num = 1 + for epoch in range(epoch_num): + lenet.train() + for batch_id, data in enumerate(train_reader()): + x_data = np.array([x[0].reshape(1, 28, 28) + for x in data]).astype('float32') + y_data = np.array( + [x[1] for x in data]).astype('int64').reshape(-1, 1) + + img = fluid.dygraph.to_variable(x_data) + label = fluid.dygraph.to_variable(y_data) + + out = lenet(img) + acc = fluid.layers.accuracy(out, label) + loss = fluid.layers.cross_entropy(out, label) + avg_loss = fluid.layers.mean(loss) + avg_loss.backward() + adam.minimize(avg_loss) + lenet.clear_gradients() + if batch_id % 100 == 0: + _logger.info( + "Train | At epoch {} step {}: loss = {:}, acc= {:}". + format(epoch, batch_id, + avg_loss.numpy(), acc.numpy())) + + lenet.eval() + for batch_id, data in enumerate(test_reader()): + x_data = np.array([x[0].reshape(1, 28, 28) + for x in data]).astype('float32') + y_data = np.array( + [x[1] for x in data]).astype('int64').reshape(-1, 1) + + img = fluid.dygraph.to_variable(x_data) + label = fluid.dygraph.to_variable(y_data) + + out = lenet(img) + acc_top1 = fluid.layers.accuracy( + input=out, label=label, k=1) + acc_top5 = fluid.layers.accuracy( + input=out, label=label, k=5) + + if batch_id % 100 == 0: + _logger.info( + "Test | At epoch {} step {}: acc1 = {:}, acc5 = {:}". 
+ format(epoch, batch_id, + acc_top1.numpy(), acc_top5.numpy())) + + # save weights + model_dict = lenet.state_dict() + fluid.save_dygraph(model_dict, "save_temp") + + # test the correctness of `save_quantized_model` + data = next(test_reader()) + test_data = np.array([x[0].reshape(1, 28, 28) + for x in data]).astype('float32') + test_img = fluid.dygraph.to_variable(test_data) + lenet.eval() + before_save = lenet(test_img) + + # save inference quantized model + path = "./mnist_infer_model" + imperative_qat.save_quantized_model( + dirname=path, + model=lenet, + input_shape=[(1, 28, 28)], + input_dtype=['float32'], + feed=[0], + fetch=[0]) + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + else: + place = core.CPUPlace() + exe = fluid.Executor(place) + [inference_program, feed_target_names, fetch_targets] = ( + fluid.io.load_inference_model( + dirname=path, executor=exe)) + after_save, = exe.run(inference_program, + feed={feed_target_names[0]: test_data}, + fetch_list=fetch_targets) + + self.assertTrue( + np.allclose(after_save, before_save.numpy()), + msg='Failed to save the inference quantized model.') + + +if __name__ == '__main__': + unittest.main() diff --git a/python/setup.py.in b/python/setup.py.in index 613a0bf2b74..67db20ce14b 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -168,6 +168,7 @@ packages=['paddle', 'paddle.fluid.contrib.slim.graph', 'paddle.fluid.contrib.slim.prune', 'paddle.fluid.contrib.slim.quantization', + 'paddle.fluid.contrib.slim.quantization.imperative', 'paddle.fluid.contrib.slim.distillation', 'paddle.fluid.contrib.slim.nas', 'paddle.fluid.contrib.slim.searcher', -- GitLab
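For reference, a condensed end-to-end usage sketch of the new imperative QAT API, distilled from the ImperativeQuantAware docstring and the unit test above; the model, dataset, output directory and hyperparameters are illustrative rather than prescribed by this change:

    import numpy as np
    import paddle
    import paddle.fluid as fluid
    from paddle.fluid.optimizer import AdamOptimizer
    from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware

    imperative_qat = ImperativeQuantAware(
        weight_quantize_type='abs_max',
        activation_quantize_type='moving_average_abs_max')

    with fluid.dygraph.guard():
        # Any dygraph model built from Conv2D/Linear works; ImperativeLenet is
        # the LeNet defined in test_imperative_qat.py above.
        model = ImperativeLenet()
        # Rewrite the Conv2D/Linear sublayers into their quantized counterparts in place.
        imperative_qat.quantize(model)

        adam = AdamOptimizer(
            learning_rate=0.001, parameter_list=model.parameters())
        train_reader = paddle.batch(
            paddle.dataset.mnist.train(), batch_size=32, drop_last=True)

        model.train()
        for batch_id, data in enumerate(train_reader()):
            x = np.array([d[0].reshape(1, 28, 28) for d in data]).astype('float32')
            y = np.array([d[1] for d in data]).astype('int64').reshape(-1, 1)
            img = fluid.dygraph.to_variable(x)
            label = fluid.dygraph.to_variable(y)

            loss = fluid.layers.mean(fluid.layers.cross_entropy(model(img), label))
            loss.backward()
            adam.minimize(loss)
            model.clear_gradients()
            if batch_id >= 100:  # a short fine-tune is enough for this sketch
                break

        # Export an inference program that keeps the fake quant ops in the graph.
        imperative_qat.save_quantized_model(
            dirname="./lenet_qat_infer",  # illustrative output directory
            model=model,
            input_shape=[(1, 28, 28)],
            input_dtype=['float32'],
            feed=[0],
            fetch=[0])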