From 745118183f4933aa91d2feddb71981cd5475beb4 Mon Sep 17 00:00:00 2001 From: huangxu96 <46740794+huangxu96@users.noreply.github.com> Date: Mon, 9 Nov 2020 20:15:33 +0800 Subject: [PATCH] Added dygraph quantization. (#435) * Added dygraph quantization. * Added save_quant_model function * Overload quant_aware, make dynamic and static model use the same interface. * Use singledispatch to overload the quant aware for dygraph quantization * Add unnitest for dygraph quant_aware * Use 2.0 api for supporting paddle 2.0 * Using pre-commit to modify the coding style. * Add unittest for user inputted quant_config. * Add singledispath into requirements.txt Co-authored-by: Bai Yifan --- paddleslim/quant/quanter.py | 66 +++++++- requirements.txt | 1 + tests/test_dygraph_quant_aware.py | 253 ++++++++++++++++++++++++++++++ 3 files changed, 319 insertions(+), 1 deletion(-) create mode 100644 tests/test_dygraph_quant_aware.py diff --git a/paddleslim/quant/quanter.py b/paddleslim/quant/quanter.py index 9324276b..bcf606a7 100755 --- a/paddleslim/quant/quanter.py +++ b/paddleslim/quant/quanter.py @@ -16,6 +16,7 @@ import os import copy import json import logging +from singledispatch import singledispatch import paddle from paddle.fluid.framework import IrGraph @@ -29,6 +30,8 @@ from paddle.fluid.contrib.slim.quantization import OutScaleForTrainingPass from paddle.fluid.contrib.slim.quantization import OutScaleForInferencePass from paddle.fluid import core from paddle.fluid.contrib.slim.quantization import WeightQuantization +# For Imperative graph quantization +from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware from ..common import get_logger _logger = get_logger(__name__, level=logging.INFO) @@ -79,7 +82,9 @@ _quant_config_default = { # if True, 'quantize_op_types' will be TENSORRT_OP_TYPES 'for_tensorrt': False, # if True, 'quantoze_op_types' will be TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES - 'is_full_quantize': False + 'is_full_quantize': False, + # for dygraph quantization, layers of type in quantizable_layer_type will be quantized + 'quantizable_layer_type': ['Conv2D', 'Linear'] } @@ -171,9 +176,12 @@ def _parse_configs(user_config): assert isinstance(configs['moving_rate'], float), \ "moving_rate must be float value, The decay coefficient of moving average, default is 0.9." + assert isinstance(configs['quantizable_layer_type'], list), \ + "quantizable_layer_type must be a list" return configs +@singledispatch def quant_aware(program, place, config=None, @@ -300,6 +308,62 @@ def quant_aware(program, return quant_program +@quant_aware.register(paddle.nn.Layer) +def _(model: paddle.nn.Layer, + config=None, + weight_quantize_func=None, + act_quantize_func=None, + weight_preprocess_func=None, + act_preprocess_func=None): + """ + This is function overload for dygraph model quant aware training. + Args: + model(nn.Layer) + config(dict, optional): configs for quantization. if None, will use default config. + Default: None. + weight_quantize_func(function): Function that defines how to quantize weight. Using this + can quickly test if user's quantization method works or not. In this function, user should + both define quantization function and dequantization function, that is, the function's input + is non-quantized weight and function returns dequantized weight. If None, will use + quantization op defined by 'weight_quantize_type'. + Default is None. + act_quantize_func(function): Function that defines how to quantize activation. Using this + can quickly test if user's quantization method works or not. In this function, user should + both define quantization and dequantization process, that is, the function's input + is non-quantized activation and function returns dequantized activation. If None, will use + quantization op defined by 'activation_quantize_type'. + Default is None. + weight_preprocess_func(function): Function that defines how to preprocess weight before quantization. Using this + can quickly test if user's preprocess method works or not. The function's input + is non-quantized weight and function returns processed weight to be quantized. If None, the weight will + be quantized directly. + Default is None. + act_preprocess_func(function): Function that defines how to preprocess activation before quantization. Using this + can quickly test if user's preprocess method works or not. The function's input + is non-quantized activation and function returns processed activation to be quantized. If None, the activation will + be quantized directly. + Default is None. + + Returns: + model(nn.Layer) | nn.layer: model with fake quantized layers + """ + + if config is None: + config = _quant_config_default + else: + assert isinstance(config, dict), "config must be dict" + config = _parse_configs(config) + + imperative_qat = ImperativeQuantAware( + weight_quantize_type=config['weight_quantize_type'], + activation_quantize_type=config['activation_quantize_type'], + quantizable_layer_type=config['quantizable_layer_type']) + + imperative_qat.quantize(model) + + return model + + def quant_post_static( executor, model_dir, diff --git a/requirements.txt b/requirements.txt index 8b4fa549..2f5bd7a5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ #paddlepaddle == 1.6.0rc0 tqdm pyzmq +singledispatch diff --git a/tests/test_dygraph_quant_aware.py b/tests/test_dygraph_quant_aware.py new file mode 100644 index 00000000..1705cfea --- /dev/null +++ b/tests/test_dygraph_quant_aware.py @@ -0,0 +1,253 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import sys +sys.path.append("../") +import unittest +import logging +import paddle +import paddle.nn as nn +import paddle.fluid as fluid +from paddle.fluid.optimizer import AdamOptimizer +from paddle.fluid.dygraph.container import Sequential +from paddle.fluid.dygraph.nn import Conv2D +from paddle.fluid.dygraph.nn import Pool2D +from paddle.fluid.dygraph.nn import Linear +from paddle.fluid.log_helper import get_logger + +from paddleslim.quant import quant_aware + +_logger = get_logger( + __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') + + +class ImperativeLenet(nn.Layer): + def __init__(self, num_classes=10, classifier_activation='softmax'): + super(ImperativeLenet, self).__init__() + self.features = Sequential( + Conv2D( + num_channels=1, + num_filters=6, + filter_size=3, + stride=1, + padding=1), + Pool2D( + pool_size=2, pool_type='max', pool_stride=2), + Conv2D( + num_channels=6, + num_filters=16, + filter_size=5, + stride=1, + padding=0), + Pool2D( + pool_size=2, pool_type='max', pool_stride=2)) + + self.fc = Sequential( + Linear( + input_dim=400, output_dim=120), + Linear( + input_dim=120, output_dim=84), + Linear( + input_dim=84, output_dim=num_classes, + act=classifier_activation)) + + def forward(self, inputs): + x = self.features(inputs) + + x = fluid.layers.flatten(x, 1) + x = self.fc(x) + return x + + +class TestImperativeQatDefaultConfig(unittest.TestCase): + """ + QAT = quantization-aware training + This test case uses defualt quantization config, weight_quantize_type + is channel_wise_abs_max + """ + + def test_qat_acc(self): + with fluid.dygraph.guard(): + lenet = ImperativeLenet() + quant_lenet = quant_aware(lenet) + + train_reader = paddle.batch( + paddle.dataset.mnist.train(), batch_size=32, drop_last=True) + test_reader = paddle.batch( + paddle.dataset.mnist.test(), batch_size=32) + + def train(model): + adam = AdamOptimizer( + learning_rate=0.001, parameter_list=model.parameters()) + epoch_num = 1 + for epoch in range(epoch_num): + model.train() + for batch_id, data in enumerate(train_reader()): + x_data = np.array( + [x[0].reshape(1, 28, 28) + for x in data]).astype('float32') + y_data = np.array( + [x[1] for x in data]).astype('int64').reshape(-1, 1) + + img = fluid.dygraph.to_variable(x_data) + label = fluid.dygraph.to_variable(y_data) + out = model(img) + acc = fluid.layers.accuracy(out, label) + loss = fluid.layers.cross_entropy(out, label) + avg_loss = fluid.layers.mean(loss) + avg_loss.backward() + adam.minimize(avg_loss) + model.clear_gradients() + if batch_id % 100 == 0: + _logger.info( + "Train | At epoch {} step {}: loss = {:}, acc= {:}". + format(epoch, batch_id, + avg_loss.numpy(), acc.numpy())) + + def test(model): + model.eval() + avg_acc = [[], []] + for batch_id, data in enumerate(test_reader()): + x_data = np.array([x[0].reshape(1, 28, 28) + for x in data]).astype('float32') + y_data = np.array( + [x[1] for x in data]).astype('int64').reshape(-1, 1) + + img = fluid.dygraph.to_variable(x_data) + label = fluid.dygraph.to_variable(y_data) + + out = model(img) + acc_top1 = fluid.layers.accuracy( + input=out, label=label, k=1) + acc_top5 = fluid.layers.accuracy( + input=out, label=label, k=5) + avg_acc[0].append(acc_top1.numpy()) + avg_acc[1].append(acc_top5.numpy()) + if batch_id % 100 == 0: + _logger.info( + "Test | step {}: acc1 = {:}, acc5 = {:}".format( + batch_id, acc_top1.numpy(), acc_top5.numpy())) + + _logger.info("Test |Average: acc_top1 {}, acc_top5 {}".format( + np.mean(avg_acc[0]), np.mean(avg_acc[1]))) + return np.mean(avg_acc[0]), np.mean(avg_acc[1]) + + train(lenet) + top1_1, top5_1 = test(lenet) + + quant_lenet.__init__() + train(quant_lenet) + top1_2, top5_2 = test(quant_lenet) + + # values before quantization and after quantization should be close + _logger.info("Before quantization: top1: {}, top5: {}".format( + top1_1, top5_1)) + _logger.info("After quantization: top1: {}, top5: {}".format( + top1_2, top5_2)) + + +class TestImperativeQatUserDefineConfig(unittest.TestCase): + """ + QAT = quantization-aware training + This test case is for testing user defined quantization config. + """ + + def test_qat_acc(self): + with fluid.dygraph.guard(): + lenet = ImperativeLenet() + quant_config = { + 'weight_quantize_type': 'abs_max', + 'activation_quantize_type': 'moving_average_abs_max', + 'quantizable_layer_type': ['Conv2D', 'Linear'] + } + quant_lenet = quant_aware(lenet, quant_config) + + train_reader = paddle.batch( + paddle.dataset.mnist.train(), batch_size=32, drop_last=True) + test_reader = paddle.batch( + paddle.dataset.mnist.test(), batch_size=32) + + def train(model): + adam = AdamOptimizer( + learning_rate=0.001, parameter_list=model.parameters()) + epoch_num = 1 + for epoch in range(epoch_num): + model.train() + for batch_id, data in enumerate(train_reader()): + x_data = np.array( + [x[0].reshape(1, 28, 28) + for x in data]).astype('float32') + y_data = np.array( + [x[1] for x in data]).astype('int64').reshape(-1, 1) + + img = fluid.dygraph.to_variable(x_data) + label = fluid.dygraph.to_variable(y_data) + out = model(img) + acc = fluid.layers.accuracy(out, label) + loss = fluid.layers.cross_entropy(out, label) + avg_loss = fluid.layers.mean(loss) + avg_loss.backward() + adam.minimize(avg_loss) + model.clear_gradients() + if batch_id % 100 == 0: + _logger.info( + "Train | At epoch {} step {}: loss = {:}, acc= {:}". + format(epoch, batch_id, + avg_loss.numpy(), acc.numpy())) + + def test(model): + model.eval() + avg_acc = [[], []] + for batch_id, data in enumerate(test_reader()): + x_data = np.array([x[0].reshape(1, 28, 28) + for x in data]).astype('float32') + y_data = np.array( + [x[1] for x in data]).astype('int64').reshape(-1, 1) + + img = fluid.dygraph.to_variable(x_data) + label = fluid.dygraph.to_variable(y_data) + + out = model(img) + acc_top1 = fluid.layers.accuracy( + input=out, label=label, k=1) + acc_top5 = fluid.layers.accuracy( + input=out, label=label, k=5) + avg_acc[0].append(acc_top1.numpy()) + avg_acc[1].append(acc_top5.numpy()) + if batch_id % 100 == 0: + _logger.info( + "Test | step {}: acc1 = {:}, acc5 = {:}".format( + batch_id, acc_top1.numpy(), acc_top5.numpy())) + + _logger.info("Test |Average: acc_top1 {}, acc_top5 {}".format( + np.mean(avg_acc[0]), np.mean(avg_acc[1]))) + return np.mean(avg_acc[0]), np.mean(avg_acc[1]) + + train(lenet) + top1_1, top5_1 = test(lenet) + + quant_lenet.__init__() + train(quant_lenet) + top1_2, top5_2 = test(quant_lenet) + + # values before quantization and after quantization should be close + _logger.info("Before quantization: top1: {}, top5: {}".format( + top1_1, top5_1)) + _logger.info("After quantization: top1: {}, top5: {}".format( + top1_2, top5_2)) + + +if __name__ == '__main__': + unittest.main() -- GitLab