diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py index 7fc177e7ad7654e03374b7cb6a190b1cbc2627db..cae241772326759b04bd1c0b5b38663d134f1e14 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py @@ -59,7 +59,11 @@ class ImperativeQuantAware(object): weight_quantize_type='abs_max', activation_quantize_type='moving_average_abs_max', moving_rate=0.9, - quantizable_layer_type=['Conv2D', 'Linear']): + quantizable_layer_type=['Conv2D', 'Linear'], + weight_preprocess_layer=None, + act_preprocess_layer=None, + weight_quantize_layer=None, + act_quantize_layer=None): """ The constructor for ImperativeQuantAware. @@ -81,7 +85,28 @@ class ImperativeQuantAware(object): quantizable_op_type(list[str]): List the type of layers that will be quantized. Default is ['Conv2D', 'Linear']. The quantizable_op_type in QuantizationFreezePass and ConvertToInt8Pass must be the same as this. - + weight_preprocess_layer(paddle.nn.Layer, optional): A paddle Layer that defines how to preprocess + weight before quantization. Using this can quickly test if user's + preprocess method works or not. The input is non-quantized + weight and function returns processed weight to be quantized. + If None, the weight will be quantized directly. Default is None. + act_preprocess_layer(paddle.nn.Layer, optional): A paddle Layer that defines how to preprocess + activation before quantization. Using this can quickly test if user's + preprocess method works or not. The input is non-quantized + activation and function returns processed activation to be quantized. + If None, the activation will be quantized directly. Default is None. + weight_quantize_layer(paddle.nn.Layer, optional): A paddle Layer that defines how to quantize weight. + Using this can quickly test if user's quantization method works or not. + In this layer, user should both define quantization method and + dequantization method, that is, the function's input is non-quantized + weight and returns dequantized weight. If None, will use + quantization op defined by 'weight_quantize_type'. Default is None. + act_quantize_layer(paddle.nn.Layer, optional): A paddle Layer that defines how to quantize activation. + Using this can quickly test if user's quantization method works or not. + In this layer, user should both define quantization method and + dequantization method, that is, the function's input is non-quantized + activation and returns dequantized activation. If None, will use + quantization op defined by 'activation_quantize_type'. Default is None. Examples: .. code-block:: python @@ -118,6 +143,19 @@ class ImperativeQuantAware(object): self._activation_bits = activation_bits self._moving_rate = moving_rate + self._weight_pre_layer = weight_preprocess_layer + self._act_pre_layer = act_preprocess_layer + self._weight_quant_layer = weight_quantize_layer + self._act_quant_layer = act_quantize_layer + + t_check = lambda method: method is None or issubclass(method, dygraph.layers.Layer) + assert t_check( + self._weight_pre_layer), "weight_preprocess should be nn.Layer" + assert t_check(self._act_pre_layer), "act_preprocess should be nn.Layer" + assert t_check( + self._weight_quant_layer), "weight_quantize should be nn.Layer" + assert t_check(self._act_quant_layer), "act_quantize should be nn.Layer" + quant_type = { 'abs_max', 'moving_average_abs_max', 'channel_wise_abs_max' } @@ -189,7 +227,9 @@ class ImperativeQuantAware(object): quantized_layer = quant_nn.__dict__[quantized_counterpart[index]]( layer, self._weight_bits, self._activation_bits, self._moving_rate, - self._weight_quantize_type, self._activation_quantize_type) + self._weight_quantize_type, self._activation_quantize_type, + self._weight_pre_layer, self._act_pre_layer, + self._weight_quant_layer, self._act_quant_layer) return quantized_layer diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/quant_nn.py b/python/paddle/fluid/contrib/slim/quantization/imperative/quant_nn.py index bbaae56439eb662fefa64b5fd258c37e0c6fc664..79138febd0ce87d7d006700c9494c30f53691742 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/quant_nn.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/quant_nn.py @@ -332,7 +332,11 @@ class QuantizedConv2D(layers.Layer): activation_bits=8, moving_rate=0.9, weight_quantize_type='abs_max', - activation_quantize_type='abs_max'): + activation_quantize_type='abs_max', + weight_pre_layer=None, + act_pre_layer=None, + weight_quant_layer=None, + act_quant_layer=None): super(QuantizedConv2D, self).__init__() # For Conv2D self._groups = getattr(layer, '_groups') @@ -347,26 +351,44 @@ class QuantizedConv2D(layers.Layer): self.bias = getattr(layer, 'bias') # For FakeQuant self._conv2d_quant_axis = 0 - self._fake_quant_weight = _get_fake_quant_type( - weight_quantize_type, - name=self.weight.name, - moving_rate=moving_rate, - quant_bits=weight_bits, - dtype=self._dtype, - quant_on_weight=True, - channel_num=self.weight.shape[self._conv2d_quant_axis], - quant_axis=self._conv2d_quant_axis) - self._fake_quant_input = _get_fake_quant_type( - activation_quantize_type, - name=layer.full_name(), - moving_rate=moving_rate, - quant_bits=activation_bits, - dtype=self._dtype, - quant_on_weight=False) + + if weight_quant_layer is not None: + self._fake_quant_weight = weight_quant_layer() + else: + self._fake_quant_weight = _get_fake_quant_type( + weight_quantize_type, + name=self.weight.name, + moving_rate=moving_rate, + quant_bits=weight_bits, + dtype=self._dtype, + quant_on_weight=True, + channel_num=self.weight.shape[self._conv2d_quant_axis], + quant_axis=self._conv2d_quant_axis) + if act_quant_layer is not None: + self._fake_quant_input = act_quant_layer() + else: + self._fake_quant_input = _get_fake_quant_type( + activation_quantize_type, + name=layer.full_name(), + moving_rate=moving_rate, + quant_bits=activation_bits, + dtype=self._dtype, + quant_on_weight=False) + + self._act_preprocess = act_pre_layer( + ) if act_pre_layer is not None else None + self._weight_preprocess = weight_pre_layer( + ) if weight_pre_layer is not None else None def forward(self, input): + if self._act_preprocess is not None: + input = self._act_preprocess(input) quant_input = self._fake_quant_input(input) - quant_weight = self._fake_quant_weight(self.weight) + + weight = self.weight + if self._weight_preprocess is not None: + weight = self._weight_preprocess(self.weight) + quant_weight = self._fake_quant_weight(weight) if in_dygraph_mode() and self._l_type == 'conv2d': attrs = ('strides', self._stride, 'paddings', self._padding, @@ -428,7 +450,11 @@ class QuantizedLinear(layers.Layer): activation_bits=8, moving_rate=0.9, weight_quantize_type='abs_max', - activation_quantize_type='abs_max'): + activation_quantize_type='abs_max', + weight_pre_layer=None, + act_pre_layer=None, + weight_quant_layer=None, + act_quant_layer=None): super(QuantizedLinear, self).__init__() # For Linear self._act = getattr(layer, '_act') @@ -437,26 +463,46 @@ class QuantizedLinear(layers.Layer): self.bias = getattr(layer, 'bias') # For FakeQuant self._linear_quant_axis = 1 - self._fake_quant_weight = _get_fake_quant_type( - weight_quantize_type, - name=self.weight.name, - moving_rate=moving_rate, - quant_bits=weight_bits, - dtype=self._dtype, - quant_on_weight=True, - channel_num=self.weight.shape[self._linear_quant_axis], - quant_axis=self._linear_quant_axis) - self._fake_quant_input = _get_fake_quant_type( - activation_quantize_type, - name=layer.full_name(), - moving_rate=moving_rate, - quant_bits=activation_bits, - dtype=self._dtype, - quant_on_weight=False) + + if weight_quant_layer is not None: + self._fake_quant_weight = weight_quant_layer() + else: + self._fake_quant_weight = _get_fake_quant_type( + weight_quantize_type, + name=self.weight.name, + moving_rate=moving_rate, + quant_bits=weight_bits, + dtype=self._dtype, + quant_on_weight=True, + channel_num=self.weight.shape[self._linear_quant_axis], + quant_axis=self._linear_quant_axis) + + if act_quant_layer is not None: + self._fake_quant_input = act_quant_layer() + else: + self._fake_quant_input = _get_fake_quant_type( + activation_quantize_type, + name=layer.full_name(), + moving_rate=moving_rate, + quant_bits=activation_bits, + dtype=self._dtype, + quant_on_weight=False) + + self._act_preprocess = act_pre_layer( + ) if act_pre_layer is not None else None + self._weight_preprocess = weight_pre_layer( + ) if weight_pre_layer is not None else None def forward(self, input): + if self._act_preprocess is not None: + input = self._act_preprocess(input) quant_input = self._fake_quant_input(input) - quant_weight = self._fake_quant_weight(self.weight) + + weight = self.weight + if self._weight_preprocess is not None: + weight = self._weight_preprocess(self.weight) + quant_weight = self._fake_quant_weight(weight) + if in_dygraph_mode(): pre_bias = _varbase_creator(dtype=input.dtype) core.ops.matmul(quant_input, quant_weight, pre_bias, 'transpose_X', diff --git a/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_user_defined.py b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_user_defined.py new file mode 100644 index 0000000000000000000000000000000000000000..29b69bbe0f8ea27344f4af32ac1437c97433f8cb --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_user_defined.py @@ -0,0 +1,248 @@ +# copyright (c) 2020 paddlepaddle authors. all rights reserved. +# +# licensed under the apache license, version 2.0 (the "license"); +# you may not use this file except in compliance with the license. +# you may obtain a copy of the license at +# +# http://www.apache.org/licenses/license-2.0 +# +# unless required by applicable law or agreed to in writing, software +# distributed under the license is distributed on an "as is" basis, +# without warranties or conditions of any kind, either express or implied. +# see the license for the specific language governing permissions and +# limitations under the license. + +from __future__ import print_function + +import os +import numpy as np +import random +import unittest +import logging +import paddle +import paddle.nn as nn +from paddle.optimizer import Adam +from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware +from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass +from paddle.nn import Sequential +from paddle.fluid.dygraph import Conv2D +from paddle.nn import Pool2D +from paddle.fluid.dygraph import Linear +from paddle.fluid.log_helper import get_logger + +os.environ["CPU_NUM"] = "1" + +_logger = get_logger( + __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') + + +class PACT(nn.Layer): + def __init__(self, init_value=20): + super(PACT, self).__init__() + alpha_attr = paddle.ParamAttr( + name=self.full_name() + ".pact", + initializer=paddle.nn.initializer.Constant(value=init_value)) + self.alpha = self.create_parameter( + shape=[1], attr=alpha_attr, dtype='float32') + + def forward(self, x): + out_left = paddle.nn.functional.relu(x - self.alpha) + out_right = paddle.nn.functional.relu(-self.alpha - x) + x = x - out_left + out_right + return x + + +class CustomQAT(nn.Layer): + def __init__(self): + super(CustomQAT, self).__init__() + attr = paddle.ParamAttr( + initializer=paddle.nn.initializer.Constant(value=1.0)) + self.u_param = self.create_parameter( + shape=[1], attr=attr, dtype='float32') + self.l_param = self.create_parameter( + shape=[1], attr=attr, dtype='float32') + self.alpha_param = self.create_parameter( + shape=[1], attr=attr, dtype='float32') + self.upper = self.create_parameter( + shape=[1], attr=attr, dtype='float32') + self.upper.stop_gradient = True + self.lower = self.create_parameter( + shape=[1], attr=attr, dtype='float32') + self.lower.stop_gradient = True + + def forward(self, x): + def clip(x, upper, lower): + x = x + paddle.nn.functional.relu(lower - x) + x = x - paddle.nn.functional.relu(x - upper) + return x + + def phi_function(x, mi, alpha, delta): + s = 1 / (1 - alpha) + k = paddle.log(2 / alpha - 1) * (1 / delta) + x = (paddle.tanh((x - mi) * k)) * s + return x + + def dequantize(x, lower_bound, delta, interval): + x = ((x + 1) / 2 + interval) * delta + lower_bound + return x + + bit = 8 + bit_range = 2**bit - 1 + + paddle.assign(self.upper * 0.9 + self.u_param * 0.1, self.upper) + paddle.assign(self.lower * 0.9 + self.l_param * 0.1, self.lower) + x = clip(x, self.upper, self.lower) + delta = (self.upper - self.lower) / bit_range + interval = (x - self.lower) / delta + mi = (interval + 0.5) * delta + self.l_param + x = phi_function(x, mi, self.alpha_param, delta) + x = dequantize(x, self.l_param, delta, interval) + return x + + +class ImperativeLenet(paddle.nn.Layer): + def __init__(self, num_classes=10, classifier_activation='softmax'): + super(ImperativeLenet, self).__init__() + self.features = Sequential( + Conv2D( + num_channels=1, + num_filters=6, + filter_size=3, + stride=1, + padding=1), + Pool2D( + pool_size=2, pool_type='max', pool_stride=2), + Conv2D( + num_channels=6, + num_filters=16, + filter_size=5, + stride=1, + padding=0), + Pool2D( + pool_size=2, pool_type='max', pool_stride=2)) + + self.fc = Sequential( + Linear( + input_dim=400, output_dim=120), + Linear( + input_dim=120, output_dim=84), + Linear( + input_dim=84, output_dim=num_classes, + act=classifier_activation)) + + def forward(self, inputs): + x = self.features(inputs) + + x = paddle.flatten(x, 1) + x = self.fc(x) + return x + + +class TestUserDefinedActPreprocess(unittest.TestCase): + def setUp(self): + _logger.info("test act_preprocess") + self.imperative_qat = ImperativeQuantAware(act_preprocess_layer=PACT) + + def test_quant_aware_training(self): + imperative_qat = self.imperative_qat + seed = 1 + np.random.seed(seed) + paddle.static.default_main_program().random_seed = seed + paddle.static.default_startup_program().random_seed = seed + lenet = ImperativeLenet() + fixed_state = {} + param_init_map = {} + for name, param in lenet.named_parameters(): + p_shape = param.numpy().shape + p_value = param.numpy() + if name.endswith("bias"): + value = np.zeros_like(p_value).astype('float32') + else: + value = np.random.normal( + loc=0.0, scale=0.01, + size=np.product(p_shape)).reshape(p_shape).astype('float32') + fixed_state[name] = value + param_init_map[param.name] = value + lenet.set_dict(fixed_state) + + imperative_qat.quantize(lenet) + adam = Adam(learning_rate=0.001, parameters=lenet.parameters()) + dynamic_loss_rec = [] + + def train(model): + adam = Adam(learning_rate=0.001, parameters=model.parameters()) + epoch_num = 1 + for epoch in range(epoch_num): + model.train() + for batch_id, data in enumerate(train_reader()): + x_data = np.array([x[0].reshape(1, 28, 28) + for x in data]).astype('float32') + y_data = np.array( + [x[1] for x in data]).astype('int64').reshape(-1, 1) + + img = paddle.to_tensor(x_data) + label = paddle.to_tensor(y_data) + out = model(img) + acc = paddle.metric.accuracy(out, label, k=1) + loss = nn.functional.loss.cross_entropy(out, label) + avg_loss = paddle.mean(loss) + avg_loss.backward() + adam.minimize(avg_loss) + model.clear_gradients() + if batch_id % 50 == 0: + _logger.info( + "Train | At epoch {} step {}: loss = {:}, acc= {:}". + format(epoch, batch_id, + avg_loss.numpy(), acc.numpy())) + break + + def test(model): + model.eval() + avg_acc = [[], []] + for batch_id, data in enumerate(test_reader()): + x_data = np.array([x[0].reshape(1, 28, 28) + for x in data]).astype('float32') + y_data = np.array( + [x[1] for x in data]).astype('int64').reshape(-1, 1) + + img = paddle.to_tensor(x_data) + label = paddle.to_tensor(y_data) + + out = model(img) + acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1) + acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5) + avg_acc[0].append(acc_top1.numpy()) + avg_acc[1].append(acc_top5.numpy()) + if batch_id % 100 == 0: + _logger.info( + "Test | step {}: acc1 = {:}, acc5 = {:}".format( + batch_id, acc_top1.numpy(), acc_top5.numpy())) + + train_reader = paddle.batch( + paddle.dataset.mnist.train(), batch_size=512, drop_last=True) + test_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=512) + train(lenet) + test(lenet) + + +class TestUserDefinedWeightPreprocess(TestUserDefinedActPreprocess): + def setUp(self): + _logger.info("test weight_preprocess") + self.imperative_qat = ImperativeQuantAware(weight_preprocess_layer=PACT) + + +class TestUserDefinedActQuantize(TestUserDefinedActPreprocess): + def setUp(self): + _logger.info("test act_quantize") + self.imperative_qat = ImperativeQuantAware(act_quantize_layer=CustomQAT) + + +class TestUserDefinedWeightQuantize(TestUserDefinedActPreprocess): + def setUp(self): + _logger.info("test weight_quantize") + self.imperative_qat = ImperativeQuantAware( + weight_quantize_layer=CustomQAT) + + +if __name__ == '__main__': + unittest.main()