diff --git a/python/paddle/static/quantization/__init__.py b/python/paddle/static/quantization/__init__.py index 28d76d8cd6772337f9e98ac3074ad7def72e557b..0e94831516afc97998251c33520b3da18ca81aa4 100644 --- a/python/paddle/static/quantization/__init__.py +++ b/python/paddle/static/quantization/__init__.py @@ -64,3 +64,7 @@ from .post_training_quantization import ( from .post_training_quantization import ( WeightQuantization, ) +from .quanter import ( + quant_aware, + convert, +) diff --git a/python/paddle/static/quantization/quanter.py b/python/paddle/static/quantization/quanter.py new file mode 100644 index 0000000000000000000000000000000000000000..a5baf899060c4c91ad3568e88c406f88626ca604 --- /dev/null +++ b/python/paddle/static/quantization/quanter.py @@ -0,0 +1,523 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import copy +import json +import logging +import os + +import paddle + +from ...fluid.framework import IrGraph, core +from ..log_helper import get_logger +from .quantization_pass import ( + AddQuantDequantPass, + ConvertToInt8Pass, + OutScaleForInferencePass, + OutScaleForTrainingPass, + QuantizationFreezePass, + QuantizationTransformPass, +) + +_logger = get_logger(__name__, level=logging.INFO) + +from . import quant_config +from .post_training_quantization import PostTrainingQuantizationProgram +from .quantization_pass import ( + AddQuantDequantForInferencePass, + AddQuantDequantPassV2, + QuantizationTransformPassV2, + QuantWeightPass, +) + +WEIGHT_QUANTIZATION_TYPES = [ + 'abs_max', + 'channel_wise_abs_max', + 'range_abs_max', + 'moving_average_abs_max', +] +WEIGHT_QUANTIZATION_TYPES_TENSORRT = ['channel_wise_abs_max'] + +ACTIVATION_QUANTIZATION_TYPES = [ + 'abs_max', + 'range_abs_max', + 'moving_average_abs_max', +] + +ACTIVATION_QUANTIZATION_TYPES_TENSORRT = [ + 'range_abs_max', + 'moving_average_abs_max', +] + +VALID_DTYPES = ['int8'] + +TRANSFORM_PASS_OP_TYPES = list( + quant_config.SUPPORT_WEIGHT_QUANTIZATION_OP_DICT.keys() +) +QUANT_DEQUANT_PASS_OP_TYPES = list( + quant_config.SUPPORT_ACT_QUANTIZATION_OP_DICT.keys() +) + +TENSORRT_OP_TYPES = [ + 'mul', + 'conv2d', + 'pool2d', + 'depthwise_conv2d', + 'elementwise_add', + 'leaky_relu', +] + +VARS_MAPPING_TABLE = './mapping_table_for_saving_inference_model' + +_quant_config_default = { + # weight quantize type, default is 'channel_wise_abs_max' + 'weight_quantize_type': 'channel_wise_abs_max', + # activation quantize type, default is 'moving_average_abs_max' + 'activation_quantize_type': 'moving_average_abs_max', + # weight quantize bit num, default is 8 + 'weight_bits': 8, + # activation quantize bit num, default is 8 + 'activation_bits': 8, + # ops of name_scope in not_quant_pattern list, will not be quantized + 'not_quant_pattern': ['skip_quant'], + # ops of type in quantize_op_types, will be quantized + 'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'], + # data type after quantization, such as 'uint8', 'int8', etc. 
default is 'int8'
+    'dtype': 'int8',
+    # window size for 'range_abs_max' quantization. default is 10000
+    'window_size': 10000,
+    # The decay coefficient of moving average, default is 0.9
+    'moving_rate': 0.9,
+    # if True, 'quantize_op_types' will be TENSORRT_OP_TYPES
+    'for_tensorrt': False,
+    # if True, 'quantize_op_types' will be TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES
+    'is_full_quantize': False,
+    # if True, use onnx format to quant.
+    'onnx_format': True,
+    # quant post to get initial scale for quant_aware
+    'quant_post_first': False,
+    # whether the scale is trainable
+    'scale_trainable': True,
+}
+
+
+def load_dict():
+    with open(VARS_MAPPING_TABLE, 'r') as file:
+        data = file.read()
+        data = json.loads(data)
+        return data
+
+
+def save_dict(table):
+    with open(VARS_MAPPING_TABLE, 'w') as file:
+        file.write(json.dumps(table))
+
+
+def _parse_configs(user_config):
+    """
+    Check whether the user's configs are valid.
+    Args:
+        user_config(dict): user's config.
+    Return:
+        configs(dict): the final configs that will be used.
+    """
+
+    configs = copy.deepcopy(_quant_config_default)
+    configs.update(user_config)
+
+    assert isinstance(configs['for_tensorrt'], bool) and isinstance(
+        configs['is_full_quantize'], bool
+    ), "'for_tensorrt' and 'is_full_quantize' must both be bool"
+
+    # check if configs is valid
+    if configs['for_tensorrt']:
+        weight_types = WEIGHT_QUANTIZATION_TYPES_TENSORRT
+        activation_types = ACTIVATION_QUANTIZATION_TYPES_TENSORRT
+        platform = 'TensorRT'
+    else:
+        weight_types = WEIGHT_QUANTIZATION_TYPES
+        activation_types = ACTIVATION_QUANTIZATION_TYPES
+        platform = 'PaddleLite'
+    assert (
+        configs['weight_quantize_type'] in weight_types
+    ), "Unknown weight_quantize_type: {}. {} only supports {}".format(
+        configs['weight_quantize_type'], platform, weight_types
+    )
+
+    assert (
+        configs['activation_quantize_type'] in activation_types
+    ), "Unknown activation_quantize_type: {}. {} only supports {}".format(
+        configs['activation_quantize_type'], platform, activation_types
+    )
+
+    assert isinstance(
+        configs['weight_bits'], int
+    ), "weight_bits must be an int value."
+
+    assert (
+        configs['weight_bits'] >= 1 and configs['weight_bits'] <= 16
+    ), "weight_bits should be between 1 and 16."
+
+    assert isinstance(
+        configs['activation_bits'], int
+    ), "activation_bits must be an int value."
+
+    assert (
+        configs['activation_bits'] >= 1 and configs['activation_bits'] <= 16
+    ), "activation_bits should be between 1 and 16."
+
+    assert isinstance(
+        configs['not_quant_pattern'], (list, str)
+    ), "not_quant_pattern must be a list or str"
+
+    assert isinstance(
+        configs['quantize_op_types'], list
+    ), "quantize_op_types must be a list"
+
+    if configs['for_tensorrt']:
+        configs['quantize_op_types'] = TENSORRT_OP_TYPES
+    elif configs['is_full_quantize']:
+        configs['quantize_op_types'] = (
+            TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES
+        )
+    else:
+        for op_type in configs['quantize_op_types']:
+            assert (op_type in QUANT_DEQUANT_PASS_OP_TYPES) or (
+                op_type in TRANSFORM_PASS_OP_TYPES
+            ), "{} is not supported, \
+                the supported op types are {}".format(
+                op_type, TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES
+            )
+
+    assert isinstance(configs['dtype'], str), "dtype must be a str."
+
+    assert configs['dtype'] in VALID_DTYPES, "dtype can only be " + " ".join(
+        VALID_DTYPES
+    )
+
+    assert isinstance(
+        configs['window_size'], int
+    ), "window_size must be an int value, the window size for 'range_abs_max' quantization, default is 10000."
+
+    assert isinstance(
+        configs['moving_rate'], float
+    ), "moving_rate must be a float value, the decay coefficient of the moving average, default is 0.9."
+
+    return configs
+
+
+def quant_aware(
+    program,
+    place,
+    config=None,
+    scope=None,
+    for_test=False,
+    weight_quantize_func=None,
+    act_quantize_func=None,
+    weight_preprocess_func=None,
+    act_preprocess_func=None,
+    optimizer_func=None,
+    executor=None,
+    return_program=False,
+    calib_config={},
+    draw_graph=False,
+    return_scale_dict=False,
+    scale_dict=None,
+    model_type=None,
+    pattern_ops=None,
+):
+    """Add quantization and dequantization operators to ``program``
+    for quantization training or testing.
+    Args:
+        program(paddle.static.Program): training or testing ``program``.
+        place(paddle.CPUPlace or paddle.CUDAPlace): This parameter represents
+            which device the executor runs on.
+        config(dict, optional): configs for quantization. If None, the default config will be used.
+            Default: None.
+        scope(paddle.static.Scope): Scope records the mapping between variable names and variables,
+            similar to brackets in programming languages. Usually users can use
+            `paddle.static.global_scope `_.
+            When ``None``, `paddle.static.global_scope() `_ will be used.
+            Default: ``None``.
+        for_test(bool): If the 'program' parameter is a test program, this parameter should be set to ``True``.
+            Otherwise, set it to ``False``. Default: False.
+        weight_quantize_func(function): Function that defines how to quantize weights. Using this,
+            users can quickly test whether their quantization method works. This function should
+            define both the quantization and the dequantization process, that is, its input
+            is the non-quantized weight and it returns the dequantized weight. If None, the
+            quantization op defined by 'weight_quantize_type' will be used.
+            Default is None.
+        act_quantize_func(function): Function that defines how to quantize activations. Using this,
+            users can quickly test whether their quantization method works. This function should
+            define both the quantization and the dequantization process, that is, its input
+            is the non-quantized activation and it returns the dequantized activation. If None, the
+            quantization op defined by 'activation_quantize_type' will be used.
+            Default is None.
+        weight_preprocess_func(function): Function that defines how to preprocess weights before quantization. Using this,
+            users can quickly test whether their preprocess method works. Its input
+            is the non-quantized weight and it returns the processed weight to be quantized. If None, the weight will
+            be quantized directly.
+            Default is None.
+        act_preprocess_func(function): Function that defines how to preprocess activations before quantization. Using this,
+            users can quickly test whether their preprocess method works. Its input
+            is the non-quantized activation and it returns the processed activation to be quantized. If None, the activation will
+            be quantized directly.
+            Default is None.
+        optimizer_func(function): Function that returns an optimizer. When 'is_test' is False and the user wants to use a self-defined
+            quantization function and preprocess function, this function must be set. Default is None.
+        executor(paddle.static.Executor): If the user wants to use a self-defined quantization function and preprocess function,
+            executor must be set for initialization. Default is None.
+        return_program(bool): If the user wants the return value to be a Program rather than a CompiledProgram, this argument should be set to True.
+            Default is False.
+ draw_graph(bool): whether to draw graph when quantization is initialized. In order to prevent cycle, + the ERNIE model needs to be set to True. Default is False. + return_scale_dict(bool): If user want to return scale dict, model_type and pattern_ops, this argument should be set True. + Default is False. + scale_dict(dict): Use scale dict to initialize scales in program. Default is None. + model_type(str): Model type can be 'transformer' or 'non-transformer'. If model type is transformer, patterns will be analyzed. + Default is None. + pattern_ops(dict): Pattern_ops contain pattern name and corresponding ops. Default is None. + Returns: + paddle.static.CompiledProgram | paddle.static.Program: Program with quantization and dequantization ``operators`` + """ + + scope = paddle.static.global_scope() if not scope else scope + if config is None: + config = _quant_config_default + else: + assert isinstance(config, dict), "config must be dict" + config = _parse_configs(config) + _logger.info(f"quant_aware config {config}") + + skip_tensor_list = [] + same_scale_tensor_list = [] + + is_test = True if for_test else not config['scale_trainable'] + if config['quant_post_first'] and for_test: + if 'quantizable_op_type' not in calib_config: + calib_config['quantizable_op_type'] = config['quantize_op_types'] + exe = paddle.static.Executor() if executor is None else executor + post_training_quantization = PostTrainingQuantizationProgram( + exe, + program, + freeze_model=False, + skip_tensor_list=skip_tensor_list, + same_scale_tensor_list=same_scale_tensor_list, + batch_nums=10, + scale_dict=scale_dict, + return_graph=True, + **calib_config, + ) + main_graph = post_training_quantization.quantize() + scale_dict = post_training_quantization._scale_dict + sub_graphs = list(main_graph.all_sub_graphs()) + else: + main_graph = IrGraph(core.Graph(program.desc), for_test=for_test) + sub_graphs = list(main_graph.all_sub_graphs()) + transform_pass_ops = [] + quant_dequant_ops = [] + if 'quant_config' in config and config['quant_config']: + transform_pass_ops = config[ + 'quant_config' + ].weight_quant_operation_types + quant_dequant_ops = config[ + 'quant_config' + ].activation_quant_operation_types + else: + for op_type in config['quantize_op_types']: + if op_type in TRANSFORM_PASS_OP_TYPES: + transform_pass_ops.append(op_type) + elif op_type in QUANT_DEQUANT_PASS_OP_TYPES: + quant_dequant_ops.append(op_type) + if len(transform_pass_ops) > 0: + transform_func = ( + QuantizationTransformPassV2 + if config['onnx_format'] + else QuantizationTransformPass + ) + transform_pass = transform_func( + scope=scope, + place=place, + weight_bits=config['weight_bits'], + activation_bits=config['activation_bits'], + activation_quantize_type=config['activation_quantize_type'], + weight_quantize_type=config['weight_quantize_type'], + window_size=config['window_size'], + moving_rate=config['moving_rate'], + quantizable_op_type=transform_pass_ops, + skip_pattern=config['not_quant_pattern'], + weight_quantize_func=weight_quantize_func, + act_quantize_func=act_quantize_func, + weight_preprocess_func=weight_preprocess_func, + act_preprocess_func=act_preprocess_func, + optimizer_func=optimizer_func, + executor=executor, + is_test=is_test, + ) + + for sub_graph in sub_graphs: + transform_pass.apply(sub_graph) + + if len(quant_dequant_ops) > 0: + qdq_func = ( + AddQuantDequantPassV2 + if config['onnx_format'] + else AddQuantDequantPass + ) + quant_dequant_pass = qdq_func( + scope=scope, + place=place, + 
moving_rate=config['moving_rate'],
+            quant_bits=config['activation_bits'],
+            skip_pattern=config['not_quant_pattern'],
+            quantizable_op_type=quant_dequant_ops,
+            is_test=is_test,
+        )
+
+        for sub_graph in sub_graphs:
+            quant_dequant_pass.apply(sub_graph)
+
+    out_scale_training_pass = OutScaleForTrainingPass(
+        scope=scope,
+        place=place,
+        moving_rate=config['moving_rate'],
+        is_test=is_test,
+        scale_dict=scale_dict,
+    )
+
+    for sub_graph in sub_graphs:
+        out_scale_training_pass.apply(sub_graph)
+
+    if (
+        (weight_preprocess_func is not None or act_preprocess_func is not None)
+        and not for_test
+        and not config['onnx_format']
+    ):
+        _logger.info(
+            "When a preprocess_func is used in quant_aware, a mapping table needs to be saved to match variable names in the convert phase."
+        )
+        _logger.info(f"The mapping table is saved as '{VARS_MAPPING_TABLE}'.")
+        for sub_graph in sub_graphs:
+            save_dict(sub_graph.out_node_mapping_table)
+
+    # TODO: remove it.
+    if draw_graph:
+        main_graph.draw('./', 'graph.pdf')
+
+    if for_test or return_program:
+        quant_program = main_graph.to_program()
+    else:
+        quant_program = paddle.static.CompiledProgram(main_graph.graph)
+
+    if return_scale_dict:
+        return quant_program, scale_dict, model_type, pattern_ops
+    else:
+        return quant_program
+
+
+def convert(program, place, config=None, scope=None, save_int8=False):
+    """
+    Convert a quantized and well-trained ``program`` to the final quantized
+    ``program`` that can be used to save an ``inference model``.
+
+    Args:
+        program(paddle.static.Program): quantized and well-trained ``test program``.
+        place(paddle.CPUPlace or paddle.CUDAPlace): This parameter represents
+            which device the executor runs on.
+        config(dict, optional): configs for convert. If set to None, the
+            default config will be used. It must be the same config that was
+            used in 'quant_aware'. Default is None.
+        scope(paddle.static.Scope, optional): Scope records the mapping between
+            variable names and variables, similar to brackets in
+            programming languages. Usually users can use
+            `paddle.static.global_scope `_.
+            When ``None``,
+            `paddle.static.global_scope() `_
+            will be used. Default: ``None``.
+        save_int8(bool): Whether to also return a ``program`` whose model parameters
+            are of dtype ``int8``. This parameter can only be used to
+            get the model size. Default: ``False``.
+    Returns:
+        Tuple: the frozen program(s) that can be used for inference.
+            When ``save_int8`` is False, return ``freezed_program(paddle.static.Program)``.
+            When ``save_int8`` is True, return ``freezed_program(paddle.static.Program)``
+            and ``freezed_program_int8(paddle.static.Program)``.
+    """
+    scope = paddle.static.global_scope() if not scope else scope
+
+    if config is None:
+        config = _quant_config_default
+    else:
+        assert isinstance(config, dict), "config must be dict"
+        config = _parse_configs(config)
+    _logger.info(f"convert config {config}")
+    test_graph = IrGraph(core.Graph(program.desc), for_test=True)
+
+    if config['onnx_format']:
+        quant_weight_pass = QuantWeightPass(scope, place)
+        for sub_graph in test_graph.all_sub_graphs():
+            quant_weight_pass.apply(sub_graph)
+        out_scale_infer_pass = AddQuantDequantForInferencePass(
+            scope=scope, place=place, quant_bits=config['activation_bits']
+        )
+        for sub_graph in test_graph.all_sub_graphs():
+            out_scale_infer_pass.apply(sub_graph)
+    else:
+        out_scale_infer_pass = OutScaleForInferencePass(scope=scope)
+        for sub_graph in test_graph.all_sub_graphs():
+            out_scale_infer_pass.apply(sub_graph)
+        # Freeze the graph after training by adjusting the quantize
+        # operators' order for the inference.
+        freeze_pass = QuantizationFreezePass(
+            scope=scope,
+            place=place,
+            weight_bits=config['weight_bits'],
+            activation_bits=config['activation_bits'],
+            weight_quantize_type=config['weight_quantize_type'],
+        )
+        if os.path.exists(VARS_MAPPING_TABLE):
+            test_graph.out_node_mapping_table = load_dict()
+        for sub_graph in test_graph.all_sub_graphs():
+            freeze_pass.apply(sub_graph)
+
+    freezed_program = test_graph.to_program()
+
+    # Move persistable vars in sub blocks to the global block
+    global_block = freezed_program.global_block()
+    for _op in global_block.ops:
+        if _op.type == "while":
+            _block_id = _op.attr("sub_block").id
+            _block = freezed_program.block(_block_id)
+            persistables = []
+            for _name, _var in _block.vars.items():
+                if _var.persistable:
+                    global_block._clone_variable(_var)
+                    persistables.append(_name)
+            for _name in persistables:
+                _block._remove_var(_name)
+            persistables.extend(_op.input('X'))
+            _op.desc.set_input("X", persistables)
+
+    assert not (
+        save_int8 and config['onnx_format']
+    ), "When onnx_format=True, the int8 weights have already been saved, so save_int8=True is not allowed."
+ if save_int8: + convert_int8_pass = ConvertToInt8Pass(scope=scope, place=place) + for sub_graph in test_graph.all_sub_graphs(): + convert_int8_pass.apply(sub_graph) + freezed_program_int8 = test_graph.to_program() + return freezed_program, freezed_program_int8 + else: + return freezed_program diff --git a/test/quantization/CMakeLists.txt b/test/quantization/CMakeLists.txt index eb17fac27e069307475c88bdac84b327e03fec39..57a0d7c20007251fcb817f1da35594a2e83fa49f 100644 --- a/test/quantization/CMakeLists.txt +++ b/test/quantization/CMakeLists.txt @@ -227,6 +227,10 @@ if(WIN32) list(REMOVE_ITEM TEST_OPS test_imperative_qat_amp) list(REMOVE_ITEM TEST_OPS test_imperative_qat_lsq) list(REMOVE_ITEM TEST_OPS test_imperative_qat_matmul) + list(REMOVE_ITEM TEST_OPS test_quant_aware) + list(REMOVE_ITEM TEST_OPS test_quant_post_quant_aware) + list(REMOVE_ITEM TEST_OPS test_quant_aware_user_defined) + list(REMOVE_ITEM TEST_OPS test_quant_aware_config) endif() @@ -484,6 +488,10 @@ if(NOT WIN32) set_tests_properties(test_imperative_ptq PROPERTIES TIMEOUT 120) set_tests_properties(test_weight_quantization_mobilenetv1 PROPERTIES TIMEOUT 120) + set_tests_properties(test_quant_aware PROPERTIES TIMEOUT 900) + set_tests_properties(test_quant_post_quant_aware PROPERTIES TIMEOUT 900) + set_tests_properties(test_quant_aware_user_defined PROPERTIES TIMEOUT 900) + set_tests_properties(test_quant_aware_config PROPERTIES TIMEOUT 900) endif() set_tests_properties(test_graph PROPERTIES TIMEOUT 120) diff --git a/test/quantization/test_quant_aware.py b/test/quantization/test_quant_aware.py new file mode 100644 index 0000000000000000000000000000000000000000..775ab656b49e7dd789cdf38e42f41fc8dec668fd --- /dev/null +++ b/test/quantization/test_quant_aware.py @@ -0,0 +1,413 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import unittest + +import numpy as np + +import paddle +from paddle.nn.initializer import KaimingUniform +from paddle.static.quantization.quanter import convert, quant_aware + +train_parameters = { + "input_size": [3, 224, 224], + "input_mean": [0.485, 0.456, 0.406], + "input_std": [0.229, 0.224, 0.225], + "learning_strategy": { + "name": "piecewise_decay", + "batch_size": 256, + "epochs": [10, 16, 30], + "steps": [0.1, 0.01, 0.001, 0.0001], + }, +} + + +class MobileNet: + def __init__(self): + self.params = train_parameters + + def net(self, input, class_dim=1000, scale=1.0): + # conv1: 112x112 + input = self.conv_bn_layer( + input, + filter_size=3, + channels=3, + num_filters=int(32 * scale), + stride=2, + padding=1, + name="conv1", + ) + + # 56x56 + input = self.depthwise_separable( + input, + num_filters1=32, + num_filters2=64, + num_groups=32, + stride=1, + scale=scale, + name="conv2_1", + ) + + input = self.depthwise_separable( + input, + num_filters1=64, + num_filters2=128, + num_groups=64, + stride=2, + scale=scale, + name="conv2_2", + ) + + # 28x28 + input = self.depthwise_separable( + input, + num_filters1=128, + num_filters2=128, + num_groups=128, + stride=1, + scale=scale, + name="conv3_1", + ) + + input = self.depthwise_separable( + input, + num_filters1=128, + num_filters2=256, + num_groups=128, + stride=2, + scale=scale, + name="conv3_2", + ) + + # 14x14 + input = self.depthwise_separable( + input, + num_filters1=256, + num_filters2=256, + num_groups=256, + stride=1, + scale=scale, + name="conv4_1", + ) + + input = self.depthwise_separable( + input, + num_filters1=256, + num_filters2=512, + num_groups=256, + stride=2, + scale=scale, + name="conv4_2", + ) + + # 14x14 + for i in range(5): + input = self.depthwise_separable( + input, + num_filters1=512, + num_filters2=512, + num_groups=512, + stride=1, + scale=scale, + name="conv5" + "_" + str(i + 1), + ) + # 7x7 + input = self.depthwise_separable( + input, + num_filters1=512, + num_filters2=1024, + num_groups=512, + stride=2, + scale=scale, + name="conv5_6", + ) + + input = self.depthwise_separable( + input, + num_filters1=1024, + num_filters2=1024, + num_groups=1024, + stride=1, + scale=scale, + name="conv6", + ) + + input = paddle.nn.functional.adaptive_avg_pool2d(input, 1) + with paddle.static.name_scope('last_fc'): + output = paddle.static.nn.fc( + input, + class_dim, + weight_attr=paddle.ParamAttr( + initializer=KaimingUniform(), name="fc7_weights" + ), + bias_attr=paddle.ParamAttr(name="fc7_offset"), + ) + + return output + + def conv_bn_layer( + self, + input, + filter_size, + num_filters, + stride, + padding, + channels=None, + num_groups=1, + act='relu', + use_cudnn=True, + name=None, + ): + conv = paddle.static.nn.conv2d( + input=input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=padding, + groups=num_groups, + act=None, + use_cudnn=use_cudnn, + param_attr=paddle.ParamAttr( + initializer=KaimingUniform(), name=name + "_weights" + ), + bias_attr=False, + ) + bn_name = name + "_bn" + return paddle.static.nn.batch_norm( + input=conv, + act=act, + param_attr=paddle.ParamAttr(name=bn_name + "_scale"), + bias_attr=paddle.ParamAttr(name=bn_name + "_offset"), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance', + ) + + def depthwise_separable( + self, + input, + num_filters1, + num_filters2, + num_groups, + stride, + scale, + name=None, + ): + depthwise_conv = self.conv_bn_layer( + input=input, + filter_size=3, + num_filters=int(num_filters1 * 
scale), + stride=stride, + padding=1, + num_groups=int(num_groups * scale), + use_cudnn=False, + name=name + "_dw", + ) + + pointwise_conv = self.conv_bn_layer( + input=depthwise_conv, + filter_size=1, + num_filters=int(num_filters2 * scale), + stride=1, + padding=0, + name=name + "_sep", + ) + return pointwise_conv + + +class StaticCase(unittest.TestCase): + def setUp(self): + # switch mode + paddle.enable_static() + + +class TestQuantAwareCase(StaticCase): + def test_accuracy(self): + image = paddle.static.data( + name='image', shape=[None, 1, 28, 28], dtype='float32' + ) + label = paddle.static.data(name='label', shape=[None, 1], dtype='int64') + model = MobileNet() + out = model.net(input=image, class_dim=10) + cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label) + avg_cost = paddle.mean(x=cost) + acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1) + acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5) + optimizer = paddle.optimizer.Momentum( + momentum=0.9, + learning_rate=0.01, + weight_decay=paddle.regularizer.L2Decay(4e-5), + ) + optimizer.minimize(avg_cost) + main_prog = paddle.static.default_main_program() + val_prog = paddle.static.default_main_program().clone(for_test=True) + + place = ( + paddle.CUDAPlace(0) + if paddle.is_compiled_with_cuda() + else paddle.CPUPlace() + ) + exe = paddle.static.Executor(place) + exe.run(paddle.static.default_startup_program()) + + def transform(x): + return np.reshape(x, [1, 28, 28]) + + train_dataset = paddle.vision.datasets.MNIST( + mode='train', backend='cv2', transform=transform + ) + test_dataset = paddle.vision.datasets.MNIST( + mode='test', backend='cv2', transform=transform + ) + batch_size = 64 if os.environ.get('DATASET') == 'full' else 8 + train_loader = paddle.io.DataLoader( + train_dataset, + places=place, + feed_list=[image, label], + drop_last=True, + return_list=False, + batch_size=batch_size, + ) + valid_loader = paddle.io.DataLoader( + test_dataset, + places=place, + feed_list=[image, label], + batch_size=batch_size, + return_list=False, + ) + + def train(program): + iter = 0 + stop_iter = None if os.environ.get('DATASET') == 'full' else 10 + for data in train_loader(): + cost, top1, top5 = exe.run( + program, + feed=data, + fetch_list=[avg_cost, acc_top1, acc_top5], + ) + iter += 1 + if iter % 100 == 0: + print( + 'train iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'.format( + iter, cost, top1, top5 + ) + ) + if stop_iter is not None and iter == stop_iter: + break + + def test(program): + iter = 0 + stop_iter = None if os.environ.get('DATASET') == 'full' else 10 + result = [[], [], []] + for data in valid_loader(): + cost, top1, top5 = exe.run( + program, + feed=data, + fetch_list=[avg_cost, acc_top1, acc_top5], + ) + iter += 1 + if iter % 100 == 0: + print( + 'eval iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'.format( + iter, cost, top1, top5 + ) + ) + result[0].append(cost) + result[1].append(top1) + result[2].append(top5) + if stop_iter is not None and iter == stop_iter: + break + print( + ' avg loss {}, acc_top1 {}, acc_top5 {}'.format( + np.mean(result[0]), np.mean(result[1]), np.mean(result[2]) + ) + ) + return np.mean(result[1]), np.mean(result[2]) + + train(main_prog) + top1_1, top5_1 = test(main_prog) + + config = { + 'weight_quantize_type': 'channel_wise_abs_max', + 'activation_quantize_type': 'moving_average_abs_max', + 'quantize_op_types': ['depthwise_conv2d', 'mul', 'conv2d'], + } + quant_train_prog = quant_aware(main_prog, place, config, for_test=False) + quant_eval_prog 
= quant_aware(val_prog, place, config, for_test=True) + op_nums_1, quant_op_nums_1 = self.get_op_number(quant_eval_prog) + # test quant_aware op numbers + self.assertEqual(op_nums_1 * 2, quant_op_nums_1) + + train(quant_train_prog) + convert_eval_prog = convert(quant_eval_prog, place, config) + + top1_2, top5_2 = test(convert_eval_prog) + # values before quantization and after quantization should be close + print(f"before quantization: top1: {top1_1}, top5: {top5_1}") + print(f"after quantization: top1: {top1_2}, top5: {top5_2}") + + convert_op_nums_1, convert_quant_op_nums_1 = self.get_convert_op_number( + convert_eval_prog + ) + # test convert op numbers + self.assertEqual(convert_op_nums_1 + 25, convert_quant_op_nums_1) + + config['not_quant_pattern'] = ['last_fc'] + quant_prog_2 = quant_aware( + main_prog, place, config=config, for_test=True + ) + op_nums_2, quant_op_nums_2 = self.get_op_number(quant_prog_2) + convert_prog_2 = convert(quant_prog_2, place, config=config) + convert_op_nums_2, convert_quant_op_nums_2 = self.get_convert_op_number( + convert_prog_2 + ) + + self.assertEqual(op_nums_1, op_nums_2) + # test skip_quant + self.assertEqual(quant_op_nums_1 - 2, quant_op_nums_2) + + # The following assert will fail and is waiting for investigation. + # self.assertEqual(convert_quant_op_nums_1, convert_quant_op_nums_2) + + def get_op_number(self, prog): + graph = paddle.fluid.framework.IrGraph( + paddle.framework.core.Graph(prog.desc), for_test=False + ) + quant_op_nums = 0 + op_nums = 0 + for op in graph.all_op_nodes(): + if op.name() in ['conv2d', 'depthwise_conv2d', 'mul']: + op_nums += 1 + elif op.name() == 'quantize_linear': + quant_op_nums += 1 + return op_nums, quant_op_nums + + def get_convert_op_number(self, prog): + graph = paddle.fluid.framework.IrGraph( + paddle.framework.core.Graph(prog.desc), for_test=True + ) + quant_op_nums = 0 + op_nums = 0 + dequant_num = 0 + for op in graph.all_op_nodes(): + if op.name() not in ['quantize_linear', 'dequantize_linear']: + op_nums += 1 + elif op.name() == 'quantize_linear': + quant_op_nums += 1 + return op_nums, quant_op_nums + + +if __name__ == '__main__': + unittest.main() diff --git a/test/quantization/test_quant_aware_config.py b/test/quantization/test_quant_aware_config.py new file mode 100644 index 0000000000000000000000000000000000000000..0b73571a977bc81d5b4ff48e7351d5e18bee54bc --- /dev/null +++ b/test/quantization/test_quant_aware_config.py @@ -0,0 +1,216 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +import unittest + +import numpy as np +from test_quant_aware import MobileNet + +import paddle +from paddle.static.quantization.quanter import convert, quant_aware + + +class TestQuantAwareBase(unittest.TestCase): + def setUp(self): + paddle.enable_static() + + def get_save_int8(self): + return False + + def generate_config(self): + config = { + 'weight_quantize_type': 'channel_wise_abs_max', + 'activation_quantize_type': 'moving_average_abs_max', + 'quantize_op_types': ['depthwise_conv2d', 'mul', 'conv2d'], + 'onnx_format': False, + } + return config + + def test_accuracy(self): + main_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog): + image = paddle.static.data( + name='image', shape=[None, 1, 28, 28], dtype='float32' + ) + label = paddle.static.data( + name='label', shape=[None, 1], dtype='int64' + ) + model = MobileNet() + out = model.net(input=image, class_dim=10) + cost = paddle.nn.functional.loss.cross_entropy( + input=out, label=label + ) + avg_cost = paddle.mean(x=cost) + acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1) + acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5) + optimizer = paddle.optimizer.Momentum( + momentum=0.9, + learning_rate=0.01, + weight_decay=paddle.regularizer.L2Decay(4e-5), + ) + optimizer.minimize(avg_cost) + val_prog = main_prog.clone(for_test=True) + + place = ( + paddle.CUDAPlace(0) + if paddle.is_compiled_with_cuda() + else paddle.CPUPlace() + ) + exe = paddle.static.Executor(place) + exe.run(paddle.static.default_startup_program()) + + def transform(x): + return np.reshape(x, [1, 28, 28]) + + train_dataset = paddle.vision.datasets.MNIST( + mode='train', backend='cv2', transform=transform + ) + test_dataset = paddle.vision.datasets.MNIST( + mode='test', backend='cv2', transform=transform + ) + batch_size = 64 if os.environ.get('DATASET') == 'full' else 8 + train_loader = paddle.io.DataLoader( + train_dataset, + places=place, + feed_list=[image, label], + drop_last=True, + return_list=False, + batch_size=batch_size, + ) + valid_loader = paddle.io.DataLoader( + test_dataset, + places=place, + feed_list=[image, label], + batch_size=batch_size, + return_list=False, + ) + + def train(program): + iter = 0 + stop_iter = None if os.environ.get('DATASET') == 'full' else 10 + for data in train_loader(): + cost, top1, top5 = exe.run( + program, + feed=data, + fetch_list=[avg_cost, acc_top1, acc_top5], + ) + iter += 1 + if iter % 100 == 0: + print( + 'train iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'.format( + iter, cost, top1, top5 + ) + ) + if stop_iter is not None and iter == stop_iter: + break + + def test(program): + iter = 0 + stop_iter = None if os.environ.get('DATASET') == 'full' else 10 + result = [[], [], []] + for data in valid_loader(): + cost, top1, top5 = exe.run( + program, + feed=data, + fetch_list=[avg_cost, acc_top1, acc_top5], + ) + iter += 1 + if iter % 100 == 0: + print( + 'eval iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'.format( + iter, cost, top1, top5 + ) + ) + result[0].append(cost) + result[1].append(top1) + result[2].append(top5) + if stop_iter is not None and iter == stop_iter: + break + print( + ' avg loss {}, acc_top1 {}, acc_top5 {}'.format( + np.mean(result[0]), np.mean(result[1]), np.mean(result[2]) + ) + ) + return np.mean(result[1]), np.mean(result[2]) + + train(main_prog) + top1_1, top5_1 = test(main_prog) + + config = self.generate_config() + quant_train_prog = quant_aware(main_prog, place, config, for_test=False) + quant_eval_prog = 
quant_aware(val_prog, place, config, for_test=True) + + train(quant_train_prog) + save_int8 = self.get_save_int8() + if save_int8: + convert_eval_prog, _ = convert( + quant_eval_prog, place, config, save_int8=save_int8 + ) + else: + convert_eval_prog = convert( + quant_eval_prog, place, config, save_int8=save_int8 + ) + + top1_2, top5_2 = test(convert_eval_prog) + # values before quantization and after quantization should be close + print(f"before quantization: top1: {top1_1}, top5: {top5_1}") + print(f"after quantization: top1: {top1_2}, top5: {top5_2}") + + +class TestQuantAwareNone(TestQuantAwareBase): + def generate_config(self): + config = None + return config + + +class TestQuantAwareTRT(TestQuantAwareBase): + def generate_config(self): + config = { + 'weight_quantize_type': 'channel_wise_abs_max', + 'activation_quantize_type': 'moving_average_abs_max', + 'quantize_op_types': ['depthwise_conv2d', 'mul', 'conv2d'], + 'onnx_format': False, + 'for_tensorrt': True, + } + return config + + +class TestQuantAwareFullQuantize(TestQuantAwareBase): + def generate_config(self): + config = { + 'weight_quantize_type': 'channel_wise_abs_max', + 'activation_quantize_type': 'moving_average_abs_max', + 'quantize_op_types': ['depthwise_conv2d', 'mul', 'conv2d'], + 'onnx_format': False, + 'is_full_quantize': True, + } + return config + + +class TestQuantAwareSaveInt8(TestQuantAwareBase): + def generate_config(self): + config = { + 'weight_quantize_type': 'channel_wise_abs_max', + 'activation_quantize_type': 'moving_average_abs_max', + 'quantize_op_types': ['depthwise_conv2d', 'mul', 'conv2d'], + 'onnx_format': False, + } + return config + + def get_save_int8(self): + return True + + +if __name__ == '__main__': + unittest.main() diff --git a/test/quantization/test_quant_aware_user_defined.py b/test/quantization/test_quant_aware_user_defined.py new file mode 100644 index 0000000000000000000000000000000000000000..cb1c25a5fed3ec59b085b99da5c9061bb94db530 --- /dev/null +++ b/test/quantization/test_quant_aware_user_defined.py @@ -0,0 +1,192 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +import unittest + +import numpy as np +from test_quant_aware import MobileNet, StaticCase + +import paddle +from paddle.static.quantization.quanter import convert, quant_aware + + +def pact(x): + helper = paddle.fluid.layer_helper.LayerHelper("pact", **locals()) + dtype = 'float32' + init_thres = 20 + u_param_attr = paddle.ParamAttr( + name=x.name + '_pact', + initializer=paddle.nn.initializer.Constant(value=init_thres), + regularizer=paddle.regularizer.L2Decay(0.0001), + learning_rate=1, + ) + u_param = helper.create_parameter(attr=u_param_attr, shape=[1], dtype=dtype) + + part_a = paddle.nn.functional.relu(x - u_param) + part_b = paddle.nn.functional.relu(-u_param - x) + x = x - part_a + part_b + return x + + +def get_optimizer(): + return paddle.optimizer.Momentum(0.0001, 0.9) + + +class TestQuantAwareCase1(StaticCase): + def get_model(self): + image = paddle.static.data( + name='image', shape=[None, 1, 28, 28], dtype='float32' + ) + label = paddle.static.data(name='label', shape=[None, 1], dtype='int64') + model = MobileNet() + out = model.net(input=image, class_dim=10) + cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label) + avg_cost = paddle.mean(x=cost) + startup_prog = paddle.static.default_startup_program() + train_prog = paddle.static.default_main_program() + return startup_prog, train_prog + + def test_accuracy(self): + image = paddle.static.data( + name='image', shape=[None, 1, 28, 28], dtype='float32' + ) + image.stop_gradient = False + label = paddle.static.data(name='label', shape=[None, 1], dtype='int64') + model = MobileNet() + out = model.net(input=image, class_dim=10) + cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label) + avg_cost = paddle.mean(x=cost) + acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1) + acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5) + optimizer = paddle.optimizer.Momentum( + momentum=0.9, + learning_rate=0.01, + weight_decay=paddle.regularizer.L2Decay(4e-5), + ) + optimizer.minimize(avg_cost) + main_prog = paddle.static.default_main_program() + val_prog = main_prog.clone(for_test=True) + + place = ( + paddle.CUDAPlace(0) + if paddle.is_compiled_with_cuda() + else paddle.CPUPlace() + ) + exe = paddle.static.Executor(place) + exe.run(paddle.static.default_startup_program()) + + def transform(x): + return np.reshape(x, [1, 28, 28]) + + train_dataset = paddle.vision.datasets.MNIST( + mode='train', backend='cv2', transform=transform + ) + test_dataset = paddle.vision.datasets.MNIST( + mode='test', backend='cv2', transform=transform + ) + batch_size = 64 if os.environ.get('DATASET') == 'full' else 8 + train_loader = paddle.io.DataLoader( + train_dataset, + places=place, + feed_list=[image, label], + drop_last=True, + return_list=False, + batch_size=batch_size, + ) + valid_loader = paddle.io.DataLoader( + test_dataset, + places=place, + feed_list=[image, label], + batch_size=batch_size, + return_list=False, + ) + + def train(program): + iter = 0 + stop_iter = None if os.environ.get('DATASET') == 'full' else 10 + for data in train_loader(): + cost, top1, top5 = exe.run( + program, + feed=data, + fetch_list=[avg_cost, acc_top1, acc_top5], + ) + iter += 1 + if iter % 100 == 0: + print( + 'train iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'.format( + iter, cost, top1, top5 + ) + ) + if stop_iter is not None and iter == stop_iter: + break + + def test(program): + iter = 0 + stop_iter = None if os.environ.get('DATASET') == 'full' else 10 + result = [[], [], []] + for data in 
valid_loader(): + cost, top1, top5 = exe.run( + program, + feed=data, + fetch_list=[avg_cost, acc_top1, acc_top5], + ) + iter += 1 + if iter % 100 == 0: + print( + 'eval iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'.format( + iter, cost, top1, top5 + ) + ) + result[0].append(cost) + result[1].append(top1) + result[2].append(top5) + if stop_iter is not None and iter == stop_iter: + break + print( + ' avg loss {}, acc_top1 {}, acc_top5 {}'.format( + np.mean(result[0]), np.mean(result[1]), np.mean(result[2]) + ) + ) + return np.mean(result[1]), np.mean(result[2]) + + train(main_prog) + top1_1, top5_1 = test(main_prog) + + config = { + 'weight_quantize_type': 'channel_wise_abs_max', + 'activation_quantize_type': 'moving_average_abs_max', + 'quantize_op_types': ['depthwise_conv2d', 'mul', 'conv2d'], + 'onnx_format': False, + } + quant_train_prog_pact = quant_aware( + main_prog, + place, + config, + for_test=False, + act_preprocess_func=pact, + optimizer_func=get_optimizer, + executor=exe, + ) + + quant_eval_prog = quant_aware(val_prog, place, config, for_test=True) + train(quant_train_prog_pact) + quant_eval_prog = convert(quant_eval_prog, place, config) + top1_2, top5_2 = test(quant_eval_prog) + # values before quantization and after quantization should be close + print(f"before quantization: top1: {top1_1}, top5: {top5_1}") + print(f"after quantization: top1: {top1_2}, top5: {top5_2}") + + +if __name__ == '__main__': + unittest.main() diff --git a/test/quantization/test_quant_post_quant_aware.py b/test/quantization/test_quant_post_quant_aware.py new file mode 100644 index 0000000000000000000000000000000000000000..0983e86732cf2ab60a3c8fca851814adde7853d2 --- /dev/null +++ b/test/quantization/test_quant_post_quant_aware.py @@ -0,0 +1,188 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import random +import unittest + +import numpy as np +from test_quant_aware import StaticCase + +import paddle +from paddle.static.quantization.quanter import convert, quant_aware + +np.random.seed(0) +random.seed(0) +paddle.seed(0) + + +class RandomDataset(paddle.io.Dataset): + def __init__(self, num_samples): + self.num_samples = num_samples + + def __getitem__(self, idx): + enc_input = np.random.random([4, 128]).astype('float32') + attn_mask = np.random.random([2, 4, 4]).astype('float32') + label = np.random.randint(0, 2, (1,)).astype('int64') + return enc_input, attn_mask, label + + def __len__(self): + return self.num_samples + + +class TestQuantPostQuantAwareCase1(StaticCase): + def test_accuracy(self): + def simple_transformer(enc_input, attn_mask): + encoder_layer = paddle.nn.TransformerEncoderLayer(128, 2, 512) + encoder = paddle.nn.TransformerEncoder(encoder_layer, 2) + encoder_output = encoder(enc_input, attn_mask) + first_token = encoder_output[:, 0] + bias = paddle.full(shape=[1, 128], fill_value=1e-6) + linear = paddle.nn.Linear(128, 2) + logits = linear(first_token + bias) + return logits + + enc_input = paddle.static.data( + name='enc_input', shape=[None, 4, 128], dtype='float32' + ) + attn_mask = paddle.static.data( + name='attn_mask', shape=[None, 2, 4, 4], dtype='float32' + ) + label = paddle.static.data(name='label', shape=[None, 1], dtype='int64') + out = simple_transformer(enc_input, attn_mask) + cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label) + avg_cost = paddle.mean(x=cost) + acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1) + optimizer = paddle.optimizer.Momentum( + momentum=0.9, + learning_rate=0.01, + weight_decay=paddle.regularizer.L2Decay(4e-5), + ) + optimizer.minimize(avg_cost) + main_prog = paddle.static.default_main_program() + val_prog = main_prog.clone(for_test=True) + + place = ( + paddle.CUDAPlace(0) + if paddle.is_compiled_with_cuda() + else paddle.CPUPlace() + ) + exe = paddle.static.Executor(place) + exe.run(paddle.static.default_startup_program()) + + train_dataset = RandomDataset(100) + test_dataset = RandomDataset(50) + train_loader = paddle.io.DataLoader( + train_dataset, + places=place, + feed_list=[enc_input, attn_mask, label], + drop_last=True, + return_list=False, + batch_size=10, + ) + valid_loader = paddle.io.DataLoader( + test_dataset, + places=place, + feed_list=[enc_input, attn_mask, label], + batch_size=10, + return_list=False, + ) + + def train(program): + iter = 0 + for data in train_loader(): + cost, top1 = exe.run( + program, feed=data, fetch_list=[avg_cost, acc_top1] + ) + iter += 1 + if iter % 100 == 0: + print( + 'train iter={}, avg loss {}, acc_top1 {}'.format( + iter, cost, top1 + ) + ) + + def test(program): + iter = 0 + result = [[], []] + for data in valid_loader(): + cost, top1 = exe.run( + program, feed=data, fetch_list=[avg_cost, acc_top1] + ) + iter += 1 + if iter % 100 == 0: + print( + 'eval iter={}, avg loss {}, acc_top1 {}'.format( + iter, cost, top1 + ) + ) + result[0].append(cost) + result[1].append(top1) + print( + ' avg loss {}, acc_top1 {}'.format( + np.mean(result[0]), np.mean(result[1]) + ) + ) + return np.mean(result[1]) + + train(main_prog) + top1_1 = test(main_prog) + + config = { + 'weight_quantize_type': 'channel_wise_abs_max', + 'activation_quantize_type': 'moving_average_abs_max', + 'quantize_op_types': [ + 'conv2d', + 'depthwise_conv2d', + 'mul', + 'matmul', + 'elementwise_add', + ], + 'quant_post_first': True, + 'scale_trainable': True, + } + calib_config = { + 
'data_loader': valid_loader, + 'algo': 'abs_max', + 'feed_list': ['enc_input', 'attn_mask', 'label'], + 'fetch_list': [avg_cost, acc_top1], + } + quant_eval_prog, scale_dict, _, _ = quant_aware( + val_prog, + place, + config, + for_test=True, + calib_config=calib_config, + model_type='transformer', + return_scale_dict=True, + ) + quant_train_prog = quant_aware( + main_prog, + place, + config, + for_test=False, + calib_config=calib_config, + return_program=True, + scale_dict=scale_dict, + model_type='transformer', + ) + train(quant_train_prog) + quant_eval_prog = convert(quant_eval_prog, place, config) + top1_2 = test(quant_eval_prog) + # values before quantization and after quantization should be close + print(f"before quantization: top1: {top1_1}") + print(f"after quantization: top1: {top1_2}") + + +if __name__ == '__main__': + unittest.main()
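
Taken together, the quanter API added in this patch follows a three-step flow: insert fake quant/dequant ops into both the training and the test program with quant_aware, fine-tune the quantized training program, then freeze the test program with convert. The sketch below illustrates that flow end to end; the toy convolutional model and the random batch are placeholders for illustration only and are not part of this patch.

import numpy as np
import paddle
from paddle.static.quantization.quanter import convert, quant_aware

paddle.enable_static()

# Toy static-graph model (placeholder, not part of the patch).
image = paddle.static.data(name='image', shape=[None, 1, 28, 28], dtype='float32')
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
feat = paddle.static.nn.conv2d(image, num_filters=8, filter_size=3, act='relu')
out = paddle.static.nn.fc(feat, size=10)
loss = paddle.mean(paddle.nn.functional.loss.cross_entropy(input=out, label=label))
paddle.optimizer.Momentum(learning_rate=0.01, momentum=0.9).minimize(loss)

main_prog = paddle.static.default_main_program()
val_prog = main_prog.clone(for_test=True)

place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())

config = {
    'weight_quantize_type': 'channel_wise_abs_max',
    'activation_quantize_type': 'moving_average_abs_max',
    'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
}

# Step 1: insert fake quant/dequant ops for training and evaluation.
quant_train_prog = quant_aware(main_prog, place, config, for_test=False)
quant_eval_prog = quant_aware(val_prog, place, config, for_test=True)

# Step 2: fine-tune the quantized training program (a single random batch here).
feed = {
    'image': np.random.rand(8, 1, 28, 28).astype('float32'),
    'label': np.random.randint(0, 10, size=(8, 1)).astype('int64'),
}
exe.run(quant_train_prog, feed=feed, fetch_list=[loss])

# Step 3: freeze the quantized evaluation program for inference.
inference_prog = convert(quant_eval_prog, place, config)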