diff --git a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt index 14f1e7f912cc0ae3f0ddde49630dc44c0964ab21..f544154a22073d78751b6b5c594c76ae7a45532b 100755 --- a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt +++ b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt @@ -253,6 +253,8 @@ if(WIN32) list(REMOVE_ITEM TEST_OPS test_quantize_transpiler_v2) list(REMOVE_ITEM TEST_OPS test_imperative_qat_amp) list(REMOVE_ITEM TEST_OPS test_imperative_qat_lsq) + list(REMOVE_ITEM TEST_OPS test_imperative_qat_matmul) + endif() if(LINUX AND WITH_MKLDNN) @@ -507,6 +509,7 @@ if(WIN32) test_imperative_qat_channelwise test_imperative_qat test_imperative_qat_lsq + test_imperative_qat_matmul test_imperative_out_scale test_graph) list(REMOVE_ITEM TEST_OPS ${SINGLE_CARD_TEST_OPS}) @@ -547,6 +550,7 @@ set_tests_properties(test_imperative_qat_fuse PROPERTIES TIMEOUT 200) set_tests_properties(test_imperative_out_scale PROPERTIES TIMEOUT 200) set_tests_properties(test_imperative_qat_user_defined PROPERTIES TIMEOUT 200) set_tests_properties(test_imperative_qat_lsq PROPERTIES TIMEOUT 300) +set_tests_properties(test_imperative_qat_matmul PROPERTIES TIMEOUT 300) if(LINUX AND WITH_MKLDNN) set_tests_properties(test_quant2_int8_mobilenetv1_mkldnn PROPERTIES TIMEOUT diff --git a/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_matmul.py b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_matmul.py new file mode 100644 index 0000000000000000000000000000000000000000..482f4a49efbe06f16688e4fd28af9bf724eb0ff8 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_matmul.py @@ -0,0 +1,234 @@ +# copyright (c) 2022 paddlepaddle authors. all rights reserved. +# +# licensed under the apache license, version 2.0 (the "license"); +# you may not use this file except in compliance with the license. +# you may obtain a copy of the license at +# +# http://www.apache.org/licenses/license-2.0 +# +# unless required by applicable law or agreed to in writing, software +# distributed under the license is distributed on an "as is" basis, +# without warranties or conditions of any kind, either express or implied. +# see the license for the specific language governing permissions and +# limitations under the license. + +import os +import numpy as np +import random +import time +import tempfile +import unittest +import logging + +import paddle +import paddle.fluid as fluid +from paddle.fluid import core +from paddle.fluid.optimizer import ( + SGDOptimizer, + AdamOptimizer, + MomentumOptimizer, +) +from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware +from paddle.nn import Sequential +from paddle.nn import ReLU, ReLU6, LeakyReLU, Sigmoid, Softmax, PReLU +from paddle.nn import Linear, Conv2D, Softmax, BatchNorm2D, MaxPool2D +from paddle.fluid.log_helper import get_logger +from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX +from paddle.nn.quant.quant_layers import ( + QuantizedConv2D, + QuantizedMatmul, +) +from paddle.fluid.framework import _test_eager_guard +from imperative_test_utils import fix_model_dict + +paddle.enable_static() + +os.environ["CPU_NUM"] = "1" +if core.is_compiled_with_cuda(): + fluid.set_flags({"FLAGS_cudnn_deterministic": True}) + +_logger = get_logger( + __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s' +) + + +class ImperativeLenet(fluid.dygraph.Layer): + def __init__(self, num_classes=10): + super().__init__() + conv2d_w1_attr = fluid.ParamAttr(name="conv2d_w_1") + conv2d_w2_attr = fluid.ParamAttr(name="conv2d_w_2") + fc_w1_attr = fluid.ParamAttr(name="fc_w_1") + fc_w2_attr = fluid.ParamAttr(name="fc_w_2") + fc_w3_attr = fluid.ParamAttr(name="fc_w_3") + conv2d_b2_attr = fluid.ParamAttr(name="conv2d_b_2") + fc_b1_attr = fluid.ParamAttr(name="fc_b_1") + fc_b2_attr = fluid.ParamAttr(name="fc_b_2") + fc_b3_attr = fluid.ParamAttr(name="fc_b_3") + self.features = Sequential( + Conv2D( + in_channels=1, + out_channels=6, + kernel_size=3, + stride=1, + padding=1, + weight_attr=conv2d_w1_attr, + bias_attr=False, + ), + BatchNorm2D(6), + ReLU(), + MaxPool2D(kernel_size=2, stride=2), + Conv2D( + in_channels=6, + out_channels=16, + kernel_size=5, + stride=1, + padding=0, + weight_attr=conv2d_w2_attr, + bias_attr=conv2d_b2_attr, + ), + BatchNorm2D(16), + PReLU(), + MaxPool2D(kernel_size=2, stride=2), + ) + self.matmul = QuantizedMatmul() + self.fc = Sequential( + Linear( + in_features=400, + out_features=120, + weight_attr=fc_w1_attr, + bias_attr=fc_b1_attr, + ), + LeakyReLU(), + Linear( + in_features=120, + out_features=84, + weight_attr=fc_w2_attr, + bias_attr=fc_b2_attr, + ), + Sigmoid(), + Linear( + in_features=84, + out_features=num_classes, + weight_attr=fc_w3_attr, + bias_attr=fc_b3_attr, + ), + Softmax(), + ) + + def forward(self, inputs): + inputs = self.features(inputs) + inputs = self.matmul(inputs, inputs, transpose_y=True) + inputs = paddle.flatten(inputs, 1) + x = self.fc(inputs) + return x + + +class TestImperativeQatMatmul(unittest.TestCase): + def set_vars(self): + self.weight_quantize_type = 'abs_max' + self.activation_quantize_type = 'moving_average_abs_max' + self.onnx_format = True + self.fuse_conv_bn = False + + def func_qat(self): + self.set_vars() + + imperative_qat = ImperativeQuantAware( + weight_quantize_type=self.weight_quantize_type, + activation_quantize_type=self.activation_quantize_type, + fuse_conv_bn=self.fuse_conv_bn, + ) + + seed = 100 + np.random.seed(seed) + fluid.default_main_program().random_seed = seed + fluid.default_startup_program().random_seed = seed + paddle.disable_static() + lenet = ImperativeLenet() + lenet = fix_model_dict(lenet) + imperative_qat.quantize(lenet) + + optimizer = MomentumOptimizer( + learning_rate=0.1, parameter_list=lenet.parameters(), momentum=0.9 + ) + + train_reader = paddle.batch( + paddle.dataset.mnist.train(), batch_size=64, drop_last=True + ) + test_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=32) + epoch_num = 1 + for epoch in range(epoch_num): + lenet.train() + for batch_id, data in enumerate(train_reader()): + x_data = np.array( + [x[0].reshape(1, 28, 28) for x in data] + ).astype('float32') + y_data = ( + np.array([x[1] for x in data]) + .astype('int64') + .reshape(-1, 1) + ) + + img = fluid.dygraph.to_variable(x_data) + label = fluid.dygraph.to_variable(y_data) + out = lenet(img) + acc = paddle.static.accuracy(out, label) + loss = fluid.layers.cross_entropy(out, label) + avg_loss = paddle.mean(loss) + + avg_loss.backward() + optimizer.minimize(avg_loss) + lenet.clear_gradients() + + if batch_id % 100 == 0: + _logger.info( + "Train | At epoch {} step {}: loss = {:}, acc= {:}".format( + epoch, batch_id, avg_loss.numpy(), acc.numpy() + ) + ) + + lenet.eval() + eval_acc_top1_list = [] + with paddle.no_grad(): + for batch_id, data in enumerate(test_reader()): + + x_data = np.array( + [x[0].reshape(1, 28, 28) for x in data] + ).astype('float32') + y_data = ( + np.array([x[1] for x in data]) + .astype('int64') + .reshape(-1, 1) + ) + img = fluid.dygraph.to_variable(x_data) + label = fluid.dygraph.to_variable(y_data) + + out = lenet(img) + acc_top1 = paddle.static.accuracy( + input=out, label=label, k=1 + ) + acc_top5 = paddle.static.accuracy( + input=out, label=label, k=5 + ) + + if batch_id % 100 == 0: + eval_acc_top1_list.append(float(acc_top1.numpy())) + _logger.info( + "Test | At epoch {} step {}: acc1 = {:}, acc5 = {:}".format( + epoch, + batch_id, + acc_top1.numpy(), + acc_top5.numpy(), + ) + ) + + # check eval acc + eval_acc_top1 = sum(eval_acc_top1_list) / len(eval_acc_top1_list) + print('eval_acc_top1', eval_acc_top1) + + def test_qat(self): + self.func_qat() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/nn/quant/quant_layers.py b/python/paddle/nn/quant/quant_layers.py index 9cb2db000d5311edf9bbc840d9ed221bbbaa6ffa..257009a8ff1fbf886bed3b08bef0998728ccd3e3 100644 --- a/python/paddle/nn/quant/quant_layers.py +++ b/python/paddle/nn/quant/quant_layers.py @@ -39,6 +39,7 @@ __all__ = [ 'QuantStub', 'QuantizedRowParallelLinear', 'QuantizedColumnParallelLinear', + 'QuantizedMatmul', ] _logger = get_logger( @@ -999,6 +1000,65 @@ class QuantizedRowParallelLinear(Layer): return output +class QuantizedMatmul(Layer): + """ + The computational logic of QuantizedMatmul is the same with Matmul. + The only difference is that its inputs are all fake quantized. + """ + + def __init__( + self, + layer=None, + weight_bits=8, + activation_bits=8, + moving_rate=0.9, + weight_quantize_type='abs_max', + activation_quantize_type='abs_max', + weight_pre_layer=None, + act_pre_layer=None, + weight_quant_layer=None, + act_quant_layer=None, + ): + super().__init__() + + # For FakeQuant + if act_quant_layer is not None: + self._fake_quant_x = act_quant_layer() + self._fake_quant_y = act_quant_layer() + else: + self._fake_quant_x = _get_fake_quant_type( + activation_quantize_type, + moving_rate=moving_rate, + quant_bits=activation_bits, + quant_on_weight=False, + ) + self._fake_quant_y = _get_fake_quant_type( + activation_quantize_type, + moving_rate=moving_rate, + quant_bits=activation_bits, + quant_on_weight=False, + ) + + self._act_preprocess_x = ( + act_pre_layer() if act_pre_layer is not None else None + ) + self._act_preprocess_y = ( + act_pre_layer() if act_pre_layer is not None else None + ) + + def forward(self, x, y, transpose_x=False, transpose_y=False, name=None): + if self._act_preprocess_x is not None: + x = self._act_preprocess_x(x) + quant_x = self._fake_quant_x(x) + + if self._act_preprocess_y is not None: + y = self._act_preprocess_y(y) + quant_y = self._fake_quant_y(y) + + out = paddle.matmul(quant_x, quant_y, transpose_x, transpose_y, name) + return out + + class MAOutputScaleLayer(Layer): """ Add MovingAverageMaxScale layer to the behind of the input layer.