From 62f455e023eb8dcbbcf288a8f31c6f1ecb20444d Mon Sep 17 00:00:00 2001
From: cc <52520497+juncaipeng@users.noreply.github.com>
Date: Wed, 30 Dec 2020 14:17:13 +0800
Subject: [PATCH] Support quantizing program_desc (#29526)

* Support quantizing program_desc, test=develop
---
 .../quantization/quantize_transpiler_v2.py    | 177 ++++++++++++++++++
 .../fluid/contrib/slim/tests/CMakeLists.txt   |   5 +-
 .../slim/tests/test_quantize_transpiler_v2.py | 163 ++++++++++++++++
 3 files changed, 343 insertions(+), 2 deletions(-)
 create mode 100644 python/paddle/fluid/contrib/slim/quantization/quantize_transpiler_v2.py
 create mode 100644 python/paddle/fluid/contrib/slim/tests/test_quantize_transpiler_v2.py

diff --git a/python/paddle/fluid/contrib/slim/quantization/quantize_transpiler_v2.py b/python/paddle/fluid/contrib/slim/quantization/quantize_transpiler_v2.py
new file mode 100644
index 00000000000..cde3d991a7f
--- /dev/null
+++ b/python/paddle/fluid/contrib/slim/quantization/quantize_transpiler_v2.py
@@ -0,0 +1,177 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections
+import logging
+import numpy as np
+from .... import core
+from ....framework import Program, Operator, Variable, program_guard
+from .... import unique_name
+from ....layer_helper import LayerHelper
+from ....param_attr import ParamAttr
+from ....initializer import Constant
+from ....log_helper import get_logger
+
+_logger = get_logger(
+    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
+
+
+class QuantizeTranspilerV2(object):
+    def __init__(self,
+                 weight_bits=8,
+                 activation_bits=8,
+                 weight_quantize_type='abs_max',
+                 activation_quantize_type='abs_max',
+                 quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul'],
+                 skip_pattern=['skip_quant']):
+        """
+        Insert a quant_dequant op before each quantizable op to quantize the
+        fluid Program. It is a patch for distributed quantization; support for
+        other modules will be added later.
+
+        Args:
+            weight_bits(int): the bit width used to quantize weights.
+            activation_bits(int): the bit width used to quantize activations.
+            weight_quantize_type(str): the quantization type for weights.
+                Only 'abs_max' is supported for now.
+            activation_quantize_type(str): the quantization type for activations.
+                Only 'abs_max' is supported for now.
+            quantizable_op_type(list[str]): the op types to be quantized.
+            skip_pattern(str|list): the user-defined quantization skip pattern,
+                which is matched against the name scope of an op. When a skip
+                pattern is detected in an op's name scope, that op is not quantized.
+        """
+        self._weight_bits = weight_bits
+        self._activation_bits = activation_bits
+
+        assert activation_quantize_type == "abs_max", \
+            "activation_quantize_type should be abs_max for now."
+        assert weight_quantize_type == "abs_max", \
+            "weight_quantize_type should be abs_max for now."
+
+        self._activation_quantize_type = activation_quantize_type
+        self._weight_quantize_type = weight_quantize_type
+
+        self._quantizable_ops = quantizable_op_type
+        self._quantizable_grad_ops = [
+            '%s_grad' % (op) for op in self._quantizable_ops
+        ]
+
+        self._skip_pattern = skip_pattern
+        self.helper = LayerHelper(self.__class__.__name__)
+
+    def apply(self, program, startup_program):
+        """
+        Apply quantization to the fluid Program.
+
+        Args:
+            program(Program): the train or test program to be quantized.
+            startup_program(Program): the corresponding startup_program.
+        Returns:
+            None
+        """
+        assert isinstance(program, Program), \
+            "program must be an instance of Program"
+        assert isinstance(startup_program, Program), \
+            "startup_program must be an instance of Program"
+
+        quant_dequant_vars = [
+            collections.OrderedDict() for _ in range(len(program.blocks))
+        ]
+        with program_guard(program, startup_program):
+            for block in program.blocks:
+                ops = list(block.ops)
+                for op in ops:
+                    if op.type in self._quantizable_ops and \
+                            (not self._is_skip_quant(op)):
+                        self._transform_forward(block, op, quant_dequant_vars)
+            for block in program.blocks:
+                ops = list(block.ops)
+                for op in ops:
+                    if op.type in self._quantizable_grad_ops and \
+                            (not self._is_skip_quant(op)):
+                        self._transform_backward(block, op, quant_dequant_vars)
+
+    def _is_skip_quant(self, op):
+        """
+        Analyse whether the op should skip quantization or not.
+        """
+        user_skipped = False
+        if isinstance(self._skip_pattern, list):
+            user_skipped = op.has_attr("op_namescope") and \
+                any(pattern in op.attr("op_namescope")
+                    for pattern in self._skip_pattern)
+        elif isinstance(self._skip_pattern, str):
+            user_skipped = op.has_attr("op_namescope") and \
+                op.attr("op_namescope").find(
+                    self._skip_pattern) != -1
+        return user_skipped
+
+    def _transform_forward(self, block, op, quant_dequant_vars):
+        op._set_attr("quantization_type", "qat_with_weight")
+        idx = block.ops.index(op)
+        block_id = block.idx
+        for in_name in op.input_arg_names:
+            if in_name in quant_dequant_vars[block_id]:
+                quant_dequant_var = quant_dequant_vars[block_id][in_name]
+            else:
+                in_var = block.var(in_name)
+                quant_bits = self._weight_bits if in_var.persistable \
+                    else self._activation_bits
+                quant_type = self._weight_quantize_type if in_var.persistable \
+                    else self._activation_quantize_type
+                if quant_type == "abs_max":
+                    quant_dequant_var = self._insert_quant_dequant_abs_max_op(
+                        block, idx, in_var, quant_bits)
+                else:
+                    _logger.error("quant_type only supports abs_max for now.")
+                quant_dequant_vars[block_id][in_name] = quant_dequant_var
+            op._rename_input(in_name, quant_dequant_var.name)
+
+    def _transform_backward(self, block, op, quant_dequant_vars):
+        block_id = block.idx
+        no_dequanted_input_vars = True
+        for name in op.input_arg_names:
+            if name in quant_dequant_vars[block_id]:
+                dequant_var = quant_dequant_vars[block_id][name]
+                op._rename_input(name, dequant_var.name)
+                no_dequanted_input_vars = False
+        if no_dequanted_input_vars:
+            raise ValueError("There are no dequantized inputs for op %s." %
+                             (op.type))
+
+    def _insert_quant_dequant_abs_max_op(self, block, idx, in_var, quant_bits):
+        quant_dequant_var = block.create_var(
+            type=in_var.type,
+            name="{}.quant_dequant".format(in_var.name),
+            shape=in_var.shape,
+            dtype=in_var.dtype)
+        scale_var = self.helper.create_parameter(
+            attr=ParamAttr(
+                name="{}.quant_dequant.scale".format(in_var.name),
+                initializer=Constant(0.001),
+                trainable=False),
+            shape=[1],
+            dtype=in_var.dtype)
+        scale_var.stop_gradient = True
+
+        inputs = {'X': in_var}
+        outputs = {'Out': quant_dequant_var, 'OutScale': scale_var}
+        attrs = {'bit_length': quant_bits}
+        block._insert_op(
+            idx,
+            type='fake_quantize_dequantize_abs_max',
+            attrs=attrs,
+            inputs=inputs,
+            outputs=outputs)
+        return quant_dequant_var
diff --git a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt
index f24a82f4fd9..25141de63f5 100644
--- a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt
+++ b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt
@@ -123,8 +123,9 @@ if(WIN32)
     list(REMOVE_ITEM TEST_OPS test_light_nas)
     list(REMOVE_ITEM TEST_OPS test_post_training_quantization_mnist)
     list(REMOVE_ITEM TEST_OPS test_post_training_quantization_mobilenetv1)
-    list(REMOVE_ITEM TEST_OPS test_post_training_quantization_resnet50)
-    list(REMOVE_ITEM TEST_OPS test_weight_quantization_mobilenetv1)
+    list(REMOVE_ITEM TEST_OPS test_post_training_quantization_resnet50)
+    list(REMOVE_ITEM TEST_OPS test_weight_quantization_mobilenetv1)
+    list(REMOVE_ITEM TEST_OPS test_quantize_transpiler_v2)
 endif()
 
 if(LINUX AND WITH_MKLDNN)
diff --git a/python/paddle/fluid/contrib/slim/tests/test_quantize_transpiler_v2.py b/python/paddle/fluid/contrib/slim/tests/test_quantize_transpiler_v2.py
new file mode 100644
index 00000000000..00f2b597d93
--- /dev/null
+++ b/python/paddle/fluid/contrib/slim/tests/test_quantize_transpiler_v2.py
@@ -0,0 +1,163 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import unittest
+import random
+import numpy as np
+import six
+import paddle.fluid as fluid
+import paddle
+from paddle.fluid.framework import IrGraph
+from paddle.fluid.contrib.slim.quantization.quantize_transpiler_v2 import QuantizeTranspilerV2
+from paddle.fluid import core
+
+paddle.enable_static()
+
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+os.environ["CPU_NUM"] = "1"
+
+
+def conv_net(img, label):
+    conv_pool_1 = fluid.nets.simple_img_conv_pool(
+        input=img,
+        filter_size=5,
+        num_filters=20,
+        pool_size=2,
+        pool_stride=2,
+        pool_type='max',
+        act="relu")
+    conv_pool_2 = fluid.nets.simple_img_conv_pool(
+        input=conv_pool_1,
+        filter_size=5,
+        num_filters=50,
+        pool_size=2,
+        pool_stride=2,
+        pool_type='avg',
+        act="relu")
+    with fluid.name_scope("skip_quant"):
+        hidden = fluid.layers.fc(input=conv_pool_1, size=100, act='relu')
+    prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')
+    loss = fluid.layers.cross_entropy(input=prediction, label=label)
+    avg_loss = fluid.layers.mean(loss)
+    return avg_loss
+
+
+class TestQuantizeProgramPass(unittest.TestCase):
+    def quantize_program(self,
+                         use_cuda,
+                         seed,
+                         activation_quant_type='abs_max',
+                         weight_quant_type='abs_max',
+                         for_ci=False):
+        def build_program(main, startup, is_test):
+            main.random_seed = seed
+            startup.random_seed = seed
+            with fluid.unique_name.guard():
+                with fluid.program_guard(main, startup):
+                    img = fluid.layers.data(
+                        name='image', shape=[1, 28, 28], dtype='float32')
+                    label = fluid.layers.data(
+                        name='label', shape=[1], dtype='int64')
+                    loss = conv_net(img, label)
+                    if not is_test:
+                        opt = fluid.optimizer.Adam(learning_rate=0.0001)
+                        opt.minimize(loss)
+            return [img, label], loss
+
+        random.seed(0)
+        np.random.seed(0)
+
+        train_program = fluid.Program()
+        startup_program = fluid.Program()
+        test_program = fluid.Program()
+        feeds, loss = build_program(train_program, startup_program, False)
+        build_program(test_program, startup_program, True)
+        test_program = test_program.clone(for_test=True)
+
+        if not for_ci:
+            train_graph = IrGraph(
+                core.Graph(train_program.desc), for_test=False)
+            train_graph.draw('.', 'train_program_1')
+            test_graph = IrGraph(core.Graph(test_program.desc), for_test=True)
+            test_graph.draw('.', 'test_program_1')
+
+        qt = QuantizeTranspilerV2(
+            activation_quantize_type=activation_quant_type,
+            weight_quantize_type=weight_quant_type,
+            quantizable_op_type=[
+                'conv2d', 'depthwise_conv2d', 'mul', 'pool2d'
+            ])
+        qt.apply(train_program, startup_program)
+        qt.apply(test_program, startup_program)
+
+        place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
+        exe = fluid.Executor(place)
+        scope = fluid.Scope()
+        with fluid.scope_guard(scope):
+            exe.run(startup_program)
+        if not for_ci:
+            train_graph = IrGraph(
+                core.Graph(train_program.desc), for_test=False)
+            train_graph.draw('.', 'train_program_2')
+            test_graph = IrGraph(core.Graph(test_program.desc), for_test=True)
+            test_graph.draw('.', 'test_program_2')
+
+        build_strategy = fluid.BuildStrategy()
+        build_strategy.memory_optimize = False
+        build_strategy.enable_inplace = False
+        build_strategy.fuse_all_reduce_ops = False
+        binary = fluid.CompiledProgram(train_program).with_data_parallel(
+            loss_name=loss.name, build_strategy=build_strategy)
+        iters = 2
+        batch_size = 8
+
+        train_reader = paddle.batch(
+            paddle.dataset.mnist.train(), batch_size=batch_size)
+        feeder = fluid.DataFeeder(feed_list=feeds, place=place)
+        with fluid.scope_guard(scope):
+            for _ in range(iters):
+                data = next(train_reader())
+                loss_v = exe.run(binary,
+                                 feed=feeder.feed(data),
+                                 fetch_list=[loss])
+                if not for_ci:
+                    print('{}: {}'.format('loss', loss_v))
+
+        if not for_ci:
+            with fluid.scope_guard(scope):
+                fluid.io.save_inference_model('./infer_model',
+                                              ['image', 'label'], [loss], exe,
+                                              test_program)
+
+    def test_quantize_program_gpu(self):
+        if fluid.core.is_compiled_with_cuda():
+            self.quantize_program(
+                use_cuda=True,
+                seed=1,
+                activation_quant_type='abs_max',
+                weight_quant_type='abs_max',
+                for_ci=True)
+
+    def test_quantize_program_cpu(self):
+        self.quantize_program(
+            use_cuda=False,
+            seed=2,
+            activation_quant_type='abs_max',
+            weight_quant_type='abs_max',
+            for_ci=True)
+
+
+if __name__ == '__main__':
+    unittest.main()
-- 
GitLab
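
Usage sketch (not part of the patch): the snippet below shows how the new
transpiler is expected to be driven, based only on the API added above. The
toy network and all variable names are illustrative assumptions; only
QuantizeTranspilerV2 and its apply() method come from this patch, everything
else is stock paddle.fluid.

    import paddle
    import paddle.fluid as fluid
    from paddle.fluid.contrib.slim.quantization.quantize_transpiler_v2 import \
        QuantizeTranspilerV2

    paddle.enable_static()

    main_prog = fluid.Program()
    startup_prog = fluid.Program()
    with fluid.program_guard(main_prog, startup_prog):
        # A toy network: the fc layer lowers to a 'mul' op, which is in the
        # default quantizable_op_type list.
        img = fluid.layers.data(name='image', shape=[784], dtype='float32')
        pred = fluid.layers.fc(input=img, size=10, act='softmax')

    # Rewrites main_prog in place: a fake_quantize_dequantize_abs_max op is
    # inserted before each quantizable op, and a '<var>.quant_dequant.scale'
    # parameter is created for every quantized input.
    qt = QuantizeTranspilerV2(weight_bits=8, activation_bits=8)
    qt.apply(main_prog, startup_prog)

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(startup_prog)  # initializes the inserted scale parameters

Note that apply() mutates the passed programs rather than returning new ones,
which is why the test above calls it once for the train program and once for
the cloned test program.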