diff --git a/paddleslim/dygraph/quant/__init__.py b/paddleslim/dygraph/quant/__init__.py
index f95bb683e7a9525508954b66c86d198b50ad0f6a..b9436bd2a6a6ec1774597d032bcf0b4841ffc7a5 100644
--- a/paddleslim/dygraph/quant/__init__.py
+++ b/paddleslim/dygraph/quant/__init__.py
@@ -12,5 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .quanter import QAT
-__all__ = ['QAT']
+from . import qat
+from . import ptq
+
+from .qat import *
+from .ptq import *
+
+__all__ = []
+__all__ += qat.__all__
+__all__ += ptq.__all__
diff --git a/paddleslim/dygraph/quant/ptq.py b/paddleslim/dygraph/quant/ptq.py
new file mode 100644
index 0000000000000000000000000000000000000000..49c0f70fb20f11c9a30676f23d002d7489294f2f
--- /dev/null
+++ b/paddleslim/dygraph/quant/ptq.py
@@ -0,0 +1,105 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import logging
+
+import paddle
+import paddle.fluid.contrib.slim.quantization as Q
+from paddle.fluid.contrib.slim.quantization import AbsmaxQuantizer
+from paddle.fluid.contrib.slim.quantization import HistQuantizer
+from paddle.fluid.contrib.slim.quantization import KLQuantizer
+from paddle.fluid.contrib.slim.quantization import PerChannelAbsmaxQuantizer
+from ...common import get_logger
+
+_logger = get_logger(__name__, level=logging.INFO)
+
+__all__ = [
+    'PTQ',
+    'AbsmaxQuantizer',
+    'HistQuantizer',
+    'KLQuantizer',
+    'PerChannelAbsmaxQuantizer',
+]
+
+
+class PTQ(object):
+    """
+    Static post training quantization.
+    """
+
+    def __init__(self,
+                 activation_quantizer=Q.KLQuantizer(),
+                 weight_quantizer=Q.PerChannelAbsmaxQuantizer()):
+        """
+        Args:
+            activation_quantizer(Quantizer): The quantizer method for activation.
+                Default: KLQuantizer.
+            weight_quantizer(Quantizer): The quantizer method for weight.
+                Default: PerChannelAbsmaxQuantizer.
+        """
+        assert isinstance(activation_quantizer, tuple(Q.SUPPORT_ACT_QUANTIZERS))
+        assert isinstance(weight_quantizer, tuple(Q.SUPPORT_WT_QUANTIZERS))
+
+        quant_config = Q.PTQConfig(
+            activation_quantizer=activation_quantizer,
+            weight_quantizer=weight_quantizer)
+
+        self.ptq = Q.ImperativePTQ(quant_config=quant_config)
+
+    def quantize(self, model, inplace=False):
+        """
+        Quantize the input model.
+
+        Args:
+            model(paddle.nn.Layer): The model to be quantized.
+            inplace(bool): Whether to apply quantization to the input model.
+                Default: False.
+        Returns:
+            quantized_model(paddle.nn.Layer): The quantized model.
+        """
+        assert isinstance(model, paddle.nn.Layer), \
+            "The model must be an instance of paddle.nn.Layer."
+
+        return self.ptq.quantize(model=model, inplace=inplace)
+
+    def save_quantized_model(self, model, path, input_spec=None):
+        """
+        Save the quantized inference model.
+
+        Args:
+            model (Layer): The model to be saved.
+            path (str): The path prefix to save the model. The format is
+                ``dirname/file_prefix`` or ``file_prefix``.
+            input_spec (list[InputSpec|Tensor], optional): Describes the input
+                of the saved model's forward method, which can be described by
+                InputSpec or example Tensor. If None, all input variables of
+                the original Layer's forward method would be the inputs of
+                the saved model. Default: None.
+
+        Returns:
+            None
+        """
+        assert isinstance(model, paddle.nn.Layer), \
+            "The model must be an instance of paddle.nn.Layer."
+
+        training = model.training
+        if training:
+            model.eval()
+
+        self.ptq.save_quantized_model(
+            model=model, path=path, input_spec=input_spec)
+
+        if training:
+            model.train()
diff --git a/paddleslim/dygraph/quant/quanter.py b/paddleslim/dygraph/quant/qat.py
similarity index 87%
rename from paddleslim/dygraph/quant/quanter.py
rename to paddleslim/dygraph/quant/qat.py
index 9a47565772fcf2593483495ca5685067261fa4c7..cf418e78051c3f385ec9b5c5a7bdbc16d1b449f1 100644
--- a/paddleslim/dygraph/quant/quanter.py
+++ b/paddleslim/dygraph/quant/qat.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License"
 # you may not use this file except in compliance with the License.
@@ -126,7 +126,7 @@ class PACT(paddle.nn.Layer):
         super(PACT, self).__init__()
         alpha_attr = paddle.ParamAttr(
             name=self.full_name() + ".pact",
-            initializer=paddle.nn.initializer.Constant(value=20),
+            initializer=paddle.nn.initializer.Constant(value=100),
             learning_rate=1000.0)
 
         self.alpha = self.create_parameter(
@@ -207,11 +207,48 @@ class QAT(object):
             weight_quantize_layer=self.weight_quantize,
             act_quantize_layer=self.act_quantize)
 
-    def quantize(self, model):
+    def quantize(self, model, inplace=False):
+        """
+        Quantize the input model.
+
+        Args:
+            model(paddle.nn.Layer): The model to be quantized.
+            inplace(bool): Whether to apply quantization to the input model.
+                Default: False.
+        Returns:
+            quantized_model(paddle.nn.Layer): The quantized model.
+        """
+        assert isinstance(model, paddle.nn.Layer), \
+            "The model must be an instance of paddle.nn.Layer."
+
         self._model = copy.deepcopy(model)
-        self.imperative_qat.quantize(model)
+
+        if inplace:
+            self.imperative_qat.quantize(model)
+            quant_model = model
+        else:
+            quant_model = copy.deepcopy(model)
+            self.imperative_qat.quantize(quant_model)
+
+        return quant_model
 
     def save_quantized_model(self, model, path, input_spec=None):
+        """
+        Save the quantized inference model.
+
+        Args:
+            model (Layer): The model to be saved.
+            path (str): The path prefix to save the model. The format is
+                ``dirname/file_prefix`` or ``file_prefix``.
+            input_spec (list[InputSpec|Tensor], optional): Describes the input
+                of the saved model's forward method, which can be described by
+                InputSpec or example Tensor. If None, all input variables of
+                the original Layer's forward method would be the inputs of
+                the saved model. Default: None.
+
+        Returns:
+            None
+        """
         if self.weight_preprocess is not None or self.act_preprocess is not None:
             training = model.training
             model = self._remove_preprocess(model)
diff --git a/tests/dygraph/test_ptq.py b/tests/dygraph/test_ptq.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ca38d71e6e185b05e6d62992d3ddcf380ffcf04
--- /dev/null
+++ b/tests/dygraph/test_ptq.py
@@ -0,0 +1,194 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import sys
+sys.path.append("../../")
+import unittest
+import logging
+import paddle
+import paddle.nn as nn
+import paddle.fluid as fluid
+from paddle.fluid.log_helper import get_logger
+import paddle.vision.transforms as T
+
+from paddleslim import PTQ
+
+_logger = get_logger(
+    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
+
+
+class ImperativeLenet(nn.Layer):
+    def __init__(self, num_classes=10, classifier_activation='softmax'):
+        super(ImperativeLenet, self).__init__()
+        self.features = paddle.nn.Sequential(
+            paddle.nn.Conv2D(
+                in_channels=1,
+                out_channels=6,
+                kernel_size=3,
+                stride=1,
+                padding=1),
+            paddle.nn.AvgPool2D(
+                kernel_size=2, stride=2),
+            paddle.nn.Conv2D(
+                in_channels=6,
+                out_channels=16,
+                kernel_size=5,
+                stride=1,
+                padding=0),
+            paddle.nn.AvgPool2D(
+                kernel_size=2, stride=2))
+
+        self.fc = paddle.nn.Sequential(
+            paddle.nn.Linear(
+                in_features=400, out_features=120),
+            paddle.nn.Linear(
+                in_features=120, out_features=84),
+            paddle.nn.Linear(
+                in_features=84, out_features=num_classes), )
+
+    def forward(self, inputs):
+        x = self.features(inputs)
+
+        x = paddle.flatten(x, 1)
+        x = self.fc(x)
+        return x
+
+
+class TestPTQ(unittest.TestCase):
+    """
+    Test dygraph post training quantization.
+    """
+
+    def calibrate(self, model, test_reader, batch_num=10):
+        model.eval()
+        for batch_id, data in enumerate(test_reader):
+            img = paddle.to_tensor(data[0])
+            img = paddle.reshape(img, [-1, 1, 28, 28])
+
+            out = model(img)
+
+            if batch_id + 1 >= batch_num:
+                break
+
+    def model_test(self, model, test_reader):
+        model.eval()
+        avg_acc = [[], []]
+        for batch_id, data in enumerate(test_reader):
+            img = paddle.to_tensor(data[0])
+            img = paddle.reshape(img, [-1, 1, 28, 28])
+            label = paddle.to_tensor(data[1])
+            label = paddle.reshape(label, [-1, 1])
+
+            out = model(img)
+
+            acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
+            acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
+            avg_acc[0].append(acc_top1.numpy())
+            avg_acc[1].append(acc_top5.numpy())
+
+            if batch_id % 100 == 0:
+                _logger.info("Test | step {}: acc1 = {:}, acc5 = {:}".format(
+                    batch_id, acc_top1.numpy(), acc_top5.numpy()))
+
+        _logger.info("Test | Average: acc_top1 {}, acc_top5 {}".format(
+            np.mean(avg_acc[0]), np.mean(avg_acc[1])))
+        return np.mean(avg_acc[0]), np.mean(avg_acc[1])
+
+    def model_train(self, model, train_reader):
+        adam = paddle.optimizer.Adam(
+            learning_rate=0.0001, parameters=model.parameters())
+        epoch_num = 1
+        for epoch in range(epoch_num):
+            model.train()
+            for batch_id, data in enumerate(train_reader):
+                img = paddle.to_tensor(data[0])
+                label = paddle.to_tensor(data[1])
+                img = paddle.reshape(img, [-1, 1, 28, 28])
+                label = paddle.reshape(label, [-1, 1])
+
+                out = model(img)
+                acc = paddle.metric.accuracy(out, label)
+                loss = paddle.nn.functional.loss.cross_entropy(out, label)
+                avg_loss = paddle.mean(loss)
+                avg_loss.backward()
+                adam.minimize(avg_loss)
+                model.clear_gradients()
+                if batch_id % 100 == 0:
+                    _logger.info(
+                        "Train | At epoch {} step {}: loss = {:}, acc= {:}".
+                        format(epoch, batch_id, avg_loss.numpy(), acc.numpy()))
+
+    def test_ptq(self):
+        seed = 1
+        np.random.seed(seed)
+        paddle.static.default_main_program().random_seed = seed
+        paddle.static.default_startup_program().random_seed = seed
+
+        _logger.info("create the fp32 model")
+        fp32_lenet = ImperativeLenet()
+
+        _logger.info("prepare data")
+        batch_size = 64
+        transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
+        train_dataset = paddle.vision.datasets.MNIST(
+            mode='train', backend='cv2', transform=transform)
+        val_dataset = paddle.vision.datasets.MNIST(
+            mode='test', backend='cv2', transform=transform)
+
+        place = paddle.CUDAPlace(0) \
+            if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
+        train_reader = paddle.io.DataLoader(
+            train_dataset,
+            drop_last=True,
+            places=place,
+            batch_size=batch_size,
+            return_list=True)
+        test_reader = paddle.io.DataLoader(
+            val_dataset, places=place, batch_size=batch_size, return_list=True)
+
+        _logger.info("train the fp32 model")
+        self.model_train(fp32_lenet, train_reader)
+
+        _logger.info("test fp32 model")
+        fp32_top1, fp32_top5 = self.model_test(fp32_lenet, test_reader)
+
+        _logger.info("quantize the fp32 model")
+        quanter = PTQ()
+        quant_lenet = quanter.quantize(fp32_lenet)
+
+        _logger.info("calibrate")
+        self.calibrate(quant_lenet, test_reader)
+
+        _logger.info("save and test the quantized model")
+        save_path = "./tmp/model"
+        input_spec = paddle.static.InputSpec(
+            shape=[None, 1, 28, 28], dtype='float32')
+        quanter.save_quantized_model(
+            quant_lenet, save_path, input_spec=[input_spec])
+        quant_top1, quant_top5 = self.model_test(quant_lenet, test_reader)
+
+        _logger.info("FP32 acc: top1: {}, top5: {}".format(fp32_top1,
+                                                           fp32_top5))
+        _logger.info("Int acc: top1: {}, top5: {}".format(quant_top1,
+                                                          quant_top5))
+
+        diff = 0.002
+        self.assertTrue(
+            fp32_top1 - quant_top1 < diff,
+            msg="The accuracy of the quantized model is much lower than that of the fp32 model")
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/dygraph/test_dygraph_quant_aware.py b/tests/dygraph/test_qat.py
similarity index 55%
rename from tests/dygraph/test_dygraph_quant_aware.py
rename to tests/dygraph/test_qat.py
index 57823537abe7ded5a58c1ed567fb5ba3f182b500..9e2284a793986afad20ba9af09a025269b32c7f0 100644
--- a/tests/dygraph/test_dygraph_quant_aware.py
+++ b/tests/dygraph/test_qat.py
@@ -66,20 +66,30 @@ class ImperativeLenet(nn.Layer):
         return x
 
 
-class TestImperativeQatDefaultConfig(unittest.TestCase):
+class TestQAT(unittest.TestCase):
     """
    QAT = quantization-aware training
    This test case uses defualt quantization config, weight_quantize_type
    is channel_wise_abs_max
    """
 
+    def set_seed(self):
+        seed = 1
+        np.random.seed(seed)
+        paddle.static.default_main_program().random_seed = seed
+        paddle.static.default_startup_program().random_seed = seed
+
+    def prepare(self):
+        self.quanter = QAT()
+
     def test_qat_acc(self):
-        lenet = ImperativeLenet()
-        quanter = QAT()
-        quanter.quantize(lenet)
+        self.prepare()
+        self.set_seed()
+
+        fp32_lenet = ImperativeLenet()
 
-        place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
-        ) else paddle.CPUPlace()
+        place = paddle.CUDAPlace(0) \
+            if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
 
         transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
 
@@ -141,130 +151,46 @@
                         "Test | step {}: acc1 = {:}, acc5 = {:}".format(
                             batch_id, acc_top1.numpy(), acc_top5.numpy()))
 
-            _logger.info("Test |Average: acc_top1 {}, acc_top5 {}".format(
+            _logger.info("Test | Average: acc_top1 {}, acc_top5 {}".format(
                 np.mean(avg_acc[0]), np.mean(avg_acc[1])))
             return np.mean(avg_acc[0]), np.mean(avg_acc[1])
 
-        train(lenet)
-        top1_1, top5_1 = test(lenet)
+        train(fp32_lenet)
+        top1_1, top5_1 = test(fp32_lenet)
 
-        lenet.__init__()
-        train(lenet)
-        top1_2, top5_2 = test(lenet)
+        fp32_lenet.__init__()
+        quant_lenet = self.quanter.quantize(fp32_lenet)
+        train(quant_lenet)
+        top1_2, top5_2 = test(quant_lenet)
+        self.quanter.save_quantized_model(
+            quant_lenet,
+            './tmp/qat',
+            input_spec=[
+                paddle.static.InputSpec(
+                    shape=[None, 1, 28, 28], dtype='float32')
+            ])
 
         # values before quantization and after quantization should be close
-        _logger.info("Before quantization: top1: {}, top5: {}".format(top1_2,
-                                                                      top5_2))
-        _logger.info("After quantization: top1: {}, top5: {}".format(top1_1,
-                                                                     top5_1))
+        _logger.info("Before quantization: top1: {}, top5: {}".format(top1_1,
+                                                                      top5_1))
+        _logger.info("After quantization: top1: {}, top5: {}".format(top1_2,
+                                                                     top5_2))
+        _logger.info("\n")
+
+        diff = 0.002
+        self.assertTrue(
+            top1_1 - top1_2 < diff,
+            msg="The accuracy of the quantized model is much lower than that of the fp32 model")
 
 
-class TestImperativeQatPACT(unittest.TestCase):
+class TestQATWithPACT(TestQAT):
     """
-    QAT = quantization-aware training
     This test case is for testing user defined quantization.
     """
 
-    def test_qat_acc(self):
-        lenet = ImperativeLenet()
-        quant_config = {
-            'activation_preprocess_type': 'PACT',
-            'quantizable_layer_type': ['Conv2D', 'Linear'],
-        }
-        quanter = QAT(config=quant_config)
-        quanter.quantize(lenet)
-
-        place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
-        ) else paddle.CPUPlace()
-
-        transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
-
-        train_dataset = paddle.vision.datasets.MNIST(
-            mode='train', backend='cv2', transform=transform)
-        val_dataset = paddle.vision.datasets.MNIST(
-            mode='test', backend='cv2', transform=transform)
-        train_reader = paddle.io.DataLoader(
-            train_dataset, drop_last=True, places=place, batch_size=64)
-        test_reader = paddle.io.DataLoader(
-            val_dataset, places=place, batch_size=64)
-
-        def train(model):
-            adam = paddle.optimizer.Adam(
-                learning_rate=0.001, parameters=model.parameters())
-            epoch_num = 1
-            for epoch in range(epoch_num):
-                model.train()
-                for batch_id, data in enumerate(train_reader):
-                    img = paddle.to_tensor(data[0])
-                    label = paddle.to_tensor(data[1])
-                    img = paddle.reshape(img, [-1, 1, 28, 28])
-                    label = paddle.reshape(label, [-1, 1])
-
-                    out = model(img)
-                    acc = paddle.metric.accuracy(out, label)
-                    loss = paddle.nn.functional.loss.cross_entropy(out, label)
-                    avg_loss = paddle.mean(loss)
-                    avg_loss.backward()
-                    adam.minimize(avg_loss)
-                    model.clear_gradients()
-                    if batch_id % 100 == 0:
-                        _logger.info(
-                            "Train | At epoch {} step {}: loss = {:}, acc= {:}".
-                            format(epoch, batch_id,
-                                   avg_loss.numpy(), acc.numpy()))
-
-        def test(model):
-            model.eval()
-            avg_acc = [[], []]
-            for batch_id, data in enumerate(test_reader):
-                img = paddle.to_tensor(data[0])
-                label = paddle.to_tensor(data[1])
-                img = paddle.reshape(img, [-1, 1, 28, 28])
-                label = paddle.reshape(label, [-1, 1])
-
-                out = model(img)
-                acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
-                acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
-                avg_acc[0].append(acc_top1.numpy())
-                avg_acc[1].append(acc_top5.numpy())
-                if batch_id % 100 == 0:
-                    _logger.info(
-                        "Test | step {}: acc1 = {:}, acc5 = {:}".format(
-                            batch_id, acc_top1.numpy(), acc_top5.numpy()))
-
-            _logger.info("Test |Average: acc_top1 {}, acc_top5 {}".format(
-                np.mean(avg_acc[0]), np.mean(avg_acc[1])))
-            return np.mean(avg_acc[0]), np.mean(avg_acc[1])
-
-        train(lenet)
-        top1_1, top5_1 = test(lenet)
-        quanter.save_quantized_model(
-            lenet,
-            './dygraph_qat',
-            input_spec=[
-                paddle.static.InputSpec(
-                    shape=[None, 1, 28, 28], dtype='float32')
-            ])
-
-        lenet.__init__()
-        train(lenet)
-        top1_2, top5_2 = test(lenet)
-
-        # values before quantization and after quantization should be close
-        _logger.info("Before quantization: top1: {}, top5: {}".format(top1_2,
-                                                                      top5_2))
-        _logger.info("After quantization: top1: {}, top5: {}".format(top1_1,
-                                                                     top5_1))
-
-        # test for saving model in train mode
-        lenet.train()
-        quanter.save_quantized_model(
-            lenet,
-            './dygraph_qat',
-            input_spec=[
-                paddle.static.InputSpec(
-                    shape=[None, 1, 28, 28], dtype='float32')
-            ])
+    def prepare(self):
+        quant_config = {'activation_preprocess_type': 'PACT', }
+        self.quanter = QAT(config=quant_config)
 
 
 if __name__ == '__main__':
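
For reviewers, a minimal usage sketch of the dygraph post-training quantization API this patch introduces, mirroring tests/dygraph/test_ptq.py. The toy Sequential network, the random calibration batches, and the "./ptq_out/model" prefix below are illustrative placeholders and not part of the patch; only the PTQ constructor, quantize(), and save_quantized_model() calls come from the code added above.

```python
# Sketch only: exercises the new PTQ API from paddleslim/dygraph/quant/ptq.py.
# The tiny network and random calibration data are placeholders; any trained
# paddle.nn.Layer and a real calibration DataLoader would be used in practice.
import paddle
from paddleslim import PTQ

model = paddle.nn.Sequential(
    paddle.nn.Conv2D(1, 6, 3, padding=1),
    paddle.nn.Flatten(),
    paddle.nn.Linear(6 * 28 * 28, 10))

# Defaults: KLQuantizer for activations, PerChannelAbsmaxQuantizer for weights.
# Other exported quantizers (HistQuantizer, AbsmaxQuantizer, ...) can be passed
# via PTQ(activation_quantizer=..., weight_quantizer=...).
ptq = PTQ()

# quantize() returns a quantized copy; pass inplace=True to modify `model` itself.
quant_model = ptq.quantize(model)

# Run a few forward passes so the quantizers can collect activation statistics
# (the test calibrates on ~10 MNIST batches; random tensors stand in here).
quant_model.eval()
for _ in range(10):
    quant_model(paddle.rand([8, 1, 28, 28]))

# Export the calibrated model for inference; "./ptq_out/model" is a placeholder prefix.
ptq.save_quantized_model(
    quant_model,
    "./ptq_out/model",
    input_spec=[paddle.static.InputSpec(shape=[None, 1, 28, 28], dtype="float32")])
```

The reworked QAT class follows the same pattern, as exercised in tests/dygraph/test_qat.py: construct `QAT()` (or `QAT(config={'activation_preprocess_type': 'PACT'})`), call `quant_model = quanter.quantize(fp32_model)`, train `quant_model`, then `quanter.save_quantized_model(...)`.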