Unverified commit be558e8d, authored by Guanghua Yu, committed by GitHub

Add deploy_backend in PTQ and QAT & make the new ONNX format the default (#1601)

Parent commit: e682f51b
@@ -890,8 +890,7 @@ class AutoCompression:
             test_program,
             self._places,
             self._quant_config,
-            scope=paddle.static.global_scope(),
-            save_clip_ranges_path=self.final_dir)
+            scope=paddle.static.global_scope())
         feed_vars = [
             test_program.global_block().var(name)
......
@@ -3,14 +3,17 @@ import paddle
 from paddle.fluid.framework import IrGraph
 from paddle.framework import core
 from paddle.static.quantization import QuantizationTransformPass, QuantizationTransformPassV2, AddQuantDequantPass, AddQuantDequantPassV2, QuantizationFreezePass, QuantWeightPass
+from paddle.static.quantization import utils
 try:
-    from paddle.static.quantization import utils
+    from paddle.static.quantization import quant_config
+    TRANSFORM_PASS_OP_TYPES = list(
+        quant_config.SUPPORT_WEIGHT_QUANTIZATION_OP_DICT.keys())
+    QUANT_DEQUANT_PASS_OP_TYPES = list(
+        quant_config.SUPPORT_ACT_QUANTIZATION_OP_DICT.keys())
+except:
     TRANSFORM_PASS_OP_TYPES = utils._weight_supported_quantizable_op_type
     QUANT_DEQUANT_PASS_OP_TYPES = utils._act_supported_quantizable_op_type
-except:
-    TRANSFORM_PASS_OP_TYPES = QuantizationTransformPass._supported_quantizable_op_type
-    QUANT_DEQUANT_PASS_OP_TYPES = AddQuantDequantPass._supported_quantizable_op_type
 from ...common.load_model import load_inference_model
@@ -155,7 +158,9 @@ def post_quant_fake(executor,
     for block_id in range(len(_program.blocks)):
         for op in _program.blocks[block_id].ops:
-            if op.type in (_quantizable_op_type + utils._out_scale_op_list):
+            if op.type in (
+                    _quantizable_op_type +
+                    list(quant_config.SUPPORT_QUANTIZATION_OP_DICT.keys())):
                 out_var_names = utils._get_op_output_var_names(op)
                 for var_name in out_var_names:
                     analysis_and_save_info(op, var_name)
......
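The version guard above reduces to a small pattern that can be checked on its own; a minimal sketch, assuming PaddlePaddle 2.4+ for `utils` and the develop branch for `quant_config`:

    # Prefer the new quant_config registries (Paddle develop), fall back to utils (Paddle 2.4+).
    try:
        from paddle.static.quantization import quant_config
        weight_quant_ops = list(quant_config.SUPPORT_WEIGHT_QUANTIZATION_OP_DICT.keys())
        act_quant_ops = list(quant_config.SUPPORT_ACT_QUANTIZATION_OP_DICT.keys())
    except ImportError:
        from paddle.static.quantization import utils
        weight_quant_ops = utils._weight_supported_quantizable_op_type
        act_quant_ops = utils._act_supported_quantizable_op_type

    print("weight-quantizable ops:", weight_quant_ops)
    print("activation-quantizable ops:", act_quant_ops)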
@@ -57,8 +57,8 @@ _quant_config_default = {
     'quantizable_layer_type': ['Conv2D', 'Linear'],
     # whether fuse conv and bn before QAT
     'fuse_conv_bn': False,
-    # Whether to export the quantized model with format of ONNX. Default is False.
-    'onnx_format': False,
+    # Whether to export the quantized model in ONNX format. Default is True.
+    'onnx_format': True,
 }
......
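With this default flipped, dygraph QAT exports the new ONNX-style quantized model unless the user opts out; a minimal usage sketch, assuming PaddleSlim's `QAT` API and a user-defined `paddle.nn.Layer` called `net`:

    import paddle
    from paddleslim import QAT

    quant_config = {
        'quantizable_layer_type': ['Conv2D', 'Linear'],
        'fuse_conv_bn': False,
        'onnx_format': True,  # now the default; set False to keep the legacy fake_quant format
    }
    quanter = QAT(config=quant_config)
    quanter.quantize(net)  # net is a user-defined paddle.nn.Layer
    # ... train the quantized net, then export:
    quanter.save_quantized_model(
        net, './qat_model',
        input_spec=[paddle.static.InputSpec(shape=[None, 1, 28, 28], dtype='float32')])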
@@ -311,9 +311,7 @@ def export_quant_infermodel(
     # operators' order for the inference.
     # The dtype of float_program's weights is float32, but in int8 range.
     ############################################################################################################
-    float_program, int8_program = convert(test_program, place, quant_config, \
-                                          scope=scope, \
-                                          save_int8=True)
+    float_program = convert(test_program, place, quant_config, scope=scope)
     ############################################################################################################
     # 4. Save inference model
     ############################################################################################################
......
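Because the new format already freezes int8 weights into the saved model, `convert` now yields a single program; a minimal sketch of the updated export step, where `exe`, `feed_vars`, `fetch_targets` and `save_dir` are assumed to come from the surrounding export code:

    import os
    import paddle

    # One program out; its weights stay float32 but already lie in the int8 range.
    float_program = convert(test_program, place, quant_config, scope=scope)
    paddle.static.save_inference_model(
        path_prefix=os.path.join(save_dir, 'model'),
        feed_vars=feed_vars,
        fetch_vars=fetch_targets,
        executor=exe,
        program=float_program)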
@@ -41,9 +41,10 @@ try:
     from paddle.static.quantization import AddQuantDequantPassV2
     from paddle.static.quantization import PostTrainingQuantizationProgram
     from paddle.static.quantization import AddQuantDequantForInferencePass
+    from paddle.static.quantization import quant_config
 except:
     _logger.warning(
-        "Some functions fail to import, please update PaddlePaddle version to 2.4+"
+        "Some functions failed to import, please update PaddlePaddle to the latest develop version."
     )
 
 WEIGHT_QUANTIZATION_TYPES = [
@@ -61,12 +62,14 @@ ACTIVATION_QUANTIZATION_TYPES_TENSORRT = [
 VALID_DTYPES = ['int8']
 try:
+    TRANSFORM_PASS_OP_TYPES = list(
+        quant_config.SUPPORT_WEIGHT_QUANTIZATION_OP_DICT.keys())
+    QUANT_DEQUANT_PASS_OP_TYPES = list(
+        quant_config.SUPPORT_ACT_QUANTIZATION_OP_DICT.keys())
+except:
     from paddle.static.quantization import utils
     TRANSFORM_PASS_OP_TYPES = utils._weight_supported_quantizable_op_type
     QUANT_DEQUANT_PASS_OP_TYPES = utils._act_supported_quantizable_op_type
-except:
-    TRANSFORM_PASS_OP_TYPES = QuantizationTransformPass._supported_quantizable_op_type
-    QUANT_DEQUANT_PASS_OP_TYPES = AddQuantDequantPass._supported_quantizable_op_type
 
 TENSORRT_OP_TYPES = [
     'mul', 'conv2d', 'pool2d', 'depthwise_conv2d', 'elementwise_add',
@@ -99,11 +102,13 @@ _quant_config_default = {
     # if True, 'quantoze_op_types' will be TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES
     'is_full_quantize': False,
     # if True, use onnx format to quant.
-    'onnx_format': False,
+    'onnx_format': True,
     # quant post to get initial scale for quant_aware
     'quant_post_first': False,
     # whether scale can be train
-    'scale_trainable': True
+    'scale_trainable': True,
+    # Deploy backend, it could be: None, TensorRT, MKLDNN, ARM
+    'deploy_backend': None
 }
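A user-side config for static-graph QAT that exercises the two new keys could look like this; only a sketch, the keys mirror `_quant_config_default` above:

    config = {
        'weight_quantize_type': 'channel_wise_abs_max',
        'activation_quantize_type': 'moving_average_abs_max',
        'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
        'onnx_format': True,           # quantize_linear/dequantize_linear export, now the default
        'deploy_backend': 'TensorRT',  # or 'MKLDNN', 'ARM', or None for the generic quantizer
    }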
@@ -195,6 +200,32 @@ def _parse_configs(user_config):
     assert isinstance(configs['moving_rate'], float), \
         "moving_rate must be float value, The decay coefficient of moving average, default is 0.9."
 
+    deploy_backend = configs['deploy_backend']
+    assert not deploy_backend or deploy_backend.lower() in [
+        'tensorrt', 'mkldnn', 'arm'
+    ], "Deploy backend {} is not supported, please choose None, tensorrt, mkldnn or arm.".format(
+        deploy_backend)
+    try:
+        if not deploy_backend:
+            configs['quant_config'] = quant_config.BaseQuantizer(
+                quantizable_op_type=configs['quantize_op_types'],
+                quant_bits=configs['weight_bits'], )
+        elif deploy_backend.lower() == "tensorrt":
+            configs['quant_config'] = quant_config.TensorRTQuantizer(
+                quantizable_op_type=configs['quantize_op_types'],
+                quant_bits=configs['weight_bits'], )
+        elif deploy_backend.lower() == "mkldnn":
+            configs['quant_config'] = quant_config.MKLDNNQuantizer(
+                quantizable_op_type=configs['quantize_op_types'],
+                quant_bits=configs['weight_bits'], )
+        elif deploy_backend.lower() == "arm":
+            configs['quant_config'] = quant_config.ARMCPUQuantizer(
+                quantizable_op_type=configs['quantize_op_types'],
+                quant_bits=configs['weight_bits'], )
+    except:
+        _logger.warning(
+            "Failed to set deploy_backend, please update PaddlePaddle to the develop version.")
     return configs
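After `_parse_configs`, the backend choice is carried as a quantizer object whose op-type lists are consumed by `quant_aware` below; a hedged sketch of what that object exposes, assuming Paddle develop provides `quant_config`:

    from paddle.static.quantization import quant_config

    quantizer = quant_config.TensorRTQuantizer(
        quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul'], quant_bits=8)
    # These two lists decide which ops get the transform pass vs. the add-quant-dequant pass.
    print(quantizer.weight_quant_operation_types)
    print(quantizer.activation_quant_operation_types)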
@@ -388,11 +419,17 @@ def quant_aware(program,
     sub_graphs = [sub_graph for sub_graph in main_graph.all_sub_graphs()]
     transform_pass_ops = []
     quant_dequant_ops = []
-    for op_type in config['quantize_op_types']:
-        if op_type in TRANSFORM_PASS_OP_TYPES:
-            transform_pass_ops.append(op_type)
-        elif op_type in QUANT_DEQUANT_PASS_OP_TYPES:
-            quant_dequant_ops.append(op_type)
+    if 'quant_config' in config and config['quant_config']:
+        transform_pass_ops = config[
+            'quant_config'].weight_quant_operation_types
+        quant_dequant_ops = config[
+            'quant_config'].activation_quant_operation_types
+    else:
+        for op_type in config['quantize_op_types']:
+            if op_type in TRANSFORM_PASS_OP_TYPES:
+                transform_pass_ops.append(op_type)
+            elif op_type in QUANT_DEQUANT_PASS_OP_TYPES:
+                quant_dequant_ops.append(op_type)
     if len(transform_pass_ops) > 0:
         trannsform_func = 'QuantizationTransformPassV2' if config[
             'onnx_format'] else 'QuantizationTransformPass'
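The pass choice that follows still hinges on `onnx_format`; a small illustrative helper (`pick_passes` is not part of the codebase) restating that dispatch:

    from paddle.static.quantization import (
        AddQuantDequantPass, AddQuantDequantPassV2, QuantizationTransformPass,
        QuantizationTransformPassV2)

    def pick_passes(onnx_format):
        # onnx_format=True selects the V2 passes, which emit quantize_linear /
        # dequantize_linear ops instead of the legacy fake_quant ops.
        if onnx_format:
            return QuantizationTransformPassV2, AddQuantDequantPassV2
        return QuantizationTransformPass, AddQuantDequantPass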
@@ -486,8 +523,8 @@ def quant_post_static(executor,
                       hist_percent=0.9999,
                       bias_correction=False,
                       quantizable_op_type=[
-                          "conv2d", "depthwise_conv2d", "mul", "matmul",
-                          "matmul_v2"
+                          "conv2d", "depthwise_conv2d", "conv2d_transpose",
+                          "mul", "matmul", "matmul_v2"
                       ],
                       is_full_quantize=False,
                       weight_bits=8,
@@ -495,10 +532,9 @@ def quant_post_static(executor,
                       activation_quantize_type='range_abs_max',
                       weight_quantize_type='channel_wise_abs_max',
                       optimize_model=False,
-                      onnx_format=False,
+                      onnx_format=True,
                       skip_tensor_list=None,
-                      is_use_cache_file=False,
-                      cache_dir="./temp_post_training"):
+                      deploy_backend=None):
     """
     The function utilizes static post training quantization method to
     quantize the fp32 model. It uses calibrate data to calculate the
@@ -568,10 +604,11 @@ def quant_post_static(executor,
         optimize_model(bool, optional): If set optimize_model as True, it applies some
                 passes to optimize the model before quantization. So far, the place of
                 executor must be cpu it supports fusing batch_norm into convs.
-        onnx_format(bool): Whether to export the quantized model with format of ONNX. Default is False.
+        onnx_format(bool): Whether to export the quantized model in ONNX format. Default is True.
         skip_tensor_list(list): List of skip quant tensor name.
-        is_use_cache_file(bool): This param is deprecated.
-        cache_dir(str): This param is deprecated.
+        deploy_backend(str): Deploy backend, which can be None, TensorRT, MKLDNN or ARM.
+                More backends will be supported in the future. Default is None, which means
+                the general quantization configuration is used.
     Returns:
         None
@@ -599,8 +636,10 @@ def quant_post_static(executor,
             activation_quantize_type=activation_quantize_type,
             weight_quantize_type=weight_quantize_type,
             onnx_format=onnx_format,
-            skip_tensor_list=skip_tensor_list,  # support in Paddle >= 2.3.1
-            optimize_model=optimize_model)
+            skip_tensor_list=skip_tensor_list,
+            optimize_model=optimize_model,
+            deploy_backend=deploy_backend,  # supported in Paddle develop
+            )
     except:
         post_training_quantization = PostTrainingQuantization(
             executor=executor,
@@ -624,6 +663,7 @@ def quant_post_static(executor,
             activation_quantize_type=activation_quantize_type,
             weight_quantize_type=weight_quantize_type,
             onnx_format=onnx_format,
+            skip_tensor_list=skip_tensor_list,
             optimize_model=optimize_model)
     post_training_quantization.quantize()
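A minimal post-training quantization call that exercises the new arguments might look like this; a sketch assuming PaddleSlim with this change, with the model paths and the calibration reader (`sample_generator`) as placeholders:

    import paddle
    from paddleslim.quant import quant_post_static

    paddle.enable_static()
    exe = paddle.static.Executor(paddle.CPUPlace())
    quant_post_static(
        executor=exe,
        model_dir='./fp32_infer_model',            # placeholder path
        quantize_model_path='./int8_infer_model',  # placeholder path
        sample_generator=sample_generator,         # user-provided calibration data reader
        onnx_format=True,                          # export quantize_linear/dequantize_linear ops (default)
        deploy_backend='TensorRT')                 # or 'MKLDNN', 'ARM', or None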
@@ -639,12 +679,7 @@ def quant_post_static(executor,
 quant_post = quant_post_static
 
-def convert(program,
-            place,
-            config=None,
-            scope=None,
-            save_int8=False,
-            save_clip_ranges_path='./'):
+def convert(program, place, config=None, scope=None, save_int8=False):
     """
     convert quantized and well-trained ``program`` to final quantized
     ``program`` that can be used to save ``inference model``.
@@ -666,7 +701,6 @@ def convert(program,
         save_int8: Whether to return ``program`` which model parameters'
                 dtype is ``int8``. This parameter can only be used to
                 get model size. Default: ``False``.
-        save_clip_ranges_path: If config.onnx_format=True, quantization clip ranges will be saved locally.
     Returns:
         Tuple : freezed program which can be used for inference.
@@ -732,6 +766,9 @@ def convert(program,
             persistables.extend(_op.input('X'))
             _op.desc.set_input("X", persistables)
 
+    assert not (
+        save_int8 and config['onnx_format']
+    ), "When onnx_format=True, the int8 weights are already saved, so save_int8=True is not supported."
     if save_int8:
         convert_int8_pass = ConvertToInt8Pass(scope=scope, place=place)
         for sub_graph in test_graph.all_sub_graphs():
......
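Taken together, the QAT flow under the new defaults is quant_aware, then training, then convert, with convert returning a single program; a hedged sketch, where `train_prog`, `val_prog` and `place` come from user code:

    from paddleslim.quant import quant_aware, convert

    config = {'onnx_format': True, 'deploy_backend': None}
    quant_train_prog = quant_aware(train_prog, place, config, for_test=False)
    quant_eval_prog = quant_aware(val_prog, place, config, for_test=True)
    # ... train quant_train_prog ...
    # With onnx_format=True the int8 weights are already embedded, so
    # save_int8=True is rejected by the new assert above.
    inference_prog = convert(quant_eval_prog, place, config)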
@@ -180,10 +180,10 @@ class ReconstructionQuantization(PostTrainingQuantization):
         # save out_threshold for quantized ops.
         self._save_output_threshold()
 
-        if any(op_type in self._quantizable_op_type
+        if any(op_type in self.quant_config.activation_quant_operation_types
                for op_type in self._dynamic_quantize_op_type):
             self._collect_dynamic_quantize_op_threshold(
-                self._dynamic_quantize_op_type, )
+                self._dynamic_quantize_op_type)
 
         # Move sub blocks persistable var to global block
         global_block = self._program.global_block()
......
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import sys
 sys.path.append("../")
 import unittest
@@ -23,70 +24,7 @@ from layers import conv_bn_layer
 import numpy as np
 
-class TestQuantAwareCase1(StaticCase):
-    def get_model(self):
-        image = paddle.static.data(
-            name='image', shape=[None, 1, 28, 28], dtype='float32')
-        label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
-        model = MobileNet()
-        out = model.net(input=image, class_dim=10)
-        cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label)
-        avg_cost = paddle.mean(x=cost)
-        startup_prog = paddle.static.default_startup_program()
-        train_prog = paddle.static.default_main_program()
-        return startup_prog, train_prog
-
-    def get_op_number(self, prog):
-        graph = paddle.fluid.framework.IrGraph(
-            paddle.framework.core.Graph(prog.desc), for_test=False)
-        quant_op_nums = 0
-        op_nums = 0
-        for op in graph.all_op_nodes():
-            if op.name() in ['conv2d', 'depthwise_conv2d', 'mul']:
-                op_nums += 1
-            elif 'fake_' in op.name():
-                quant_op_nums += 1
-        return op_nums, quant_op_nums
-
-    def test_quant_op(self):
-        startup_prog, train_prog = self.get_model()
-        place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
-        ) else paddle.CPUPlace()
-        exe = paddle.static.Executor(place)
-        exe.run(startup_prog)
-        config_1 = {
-            'weight_quantize_type': 'channel_wise_abs_max',
-            'activation_quantize_type': 'moving_average_abs_max',
-            'quantize_op_types': ['depthwise_conv2d', 'mul', 'conv2d'],
-        }
-        quant_prog_1 = quant_aware(
-            train_prog, place, config=config_1, for_test=True)
-        op_nums_1, quant_op_nums_1 = self.get_op_number(quant_prog_1)
-        convert_prog_1 = convert(quant_prog_1, place, config=config_1)
-        convert_op_nums_1, convert_quant_op_nums_1 = self.get_op_number(
-            convert_prog_1)
-
-        config_1['not_quant_pattern'] = ['last_fc']
-        quant_prog_2 = quant_aware(
-            train_prog, place, config=config_1, for_test=True)
-        op_nums_2, quant_op_nums_2 = self.get_op_number(quant_prog_2)
-        convert_prog_2 = convert(quant_prog_2, place, config=config_1)
-        convert_op_nums_2, convert_quant_op_nums_2 = self.get_op_number(
-            convert_prog_2)
-
-        self.assertTrue(op_nums_1 == op_nums_2)
-        # test quant_aware op numbers
-        self.assertTrue(op_nums_1 * 4 == quant_op_nums_1)
-        # test convert op numbers
-        self.assertTrue(convert_op_nums_1 * 2 == convert_quant_op_nums_1)
-        # test skip_quant
-        self.assertTrue(quant_op_nums_1 - 4 == quant_op_nums_2)
-        self.assertTrue(convert_quant_op_nums_1 - 2 == convert_quant_op_nums_2)
-
-
-class TestQuantAwareCase2(StaticCase):
+class TestQuantAwareCase(StaticCase):
     def test_accuracy(self):
         image = paddle.static.data(
             name='image', shape=[None, 1, 28, 28], dtype='float32')
@@ -103,7 +41,7 @@ class TestQuantAwareCase2(StaticCase):
             weight_decay=paddle.regularizer.L2Decay(4e-5))
         optimizer.minimize(avg_cost)
         main_prog = paddle.static.default_main_program()
-        val_prog = main_prog.clone(for_test=True)
+        val_prog = paddle.static.default_main_program().clone(for_test=True)
 
         place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
         ) else paddle.CPUPlace()
@@ -173,14 +111,61 @@ class TestQuantAwareCase2(StaticCase):
         }
         quant_train_prog = quant_aware(main_prog, place, config, for_test=False)
         quant_eval_prog = quant_aware(val_prog, place, config, for_test=True)
+        op_nums_1, quant_op_nums_1 = self.get_op_number(quant_eval_prog)
+        # test quant_aware op numbers
+        self.assertTrue(op_nums_1 * 2 == quant_op_nums_1)
         train(quant_train_prog)
-        quant_eval_prog, int8_prog = convert(
-            quant_eval_prog, place, config, save_int8=True)
-        top1_2, top5_2 = test(quant_eval_prog)
+        convert_eval_prog = convert(quant_eval_prog, place, config)
+        top1_2, top5_2 = test(convert_eval_prog)
         # values before quantization and after quantization should be close
         print("before quantization: top1: {}, top5: {}".format(top1_1, top5_1))
         print("after quantization: top1: {}, top5: {}".format(top1_2, top5_2))
+        convert_op_nums_1, convert_quant_op_nums_1 = self.get_convert_op_number(
+            convert_eval_prog)
+        # test convert op numbers
+        self.assertTrue(convert_op_nums_1 + 25 == convert_quant_op_nums_1)
+
+        config['not_quant_pattern'] = ['last_fc']
+        quant_prog_2 = quant_aware(
+            main_prog, place, config=config, for_test=True)
+        op_nums_2, quant_op_nums_2 = self.get_op_number(quant_prog_2)
+        convert_prog_2 = convert(quant_prog_2, place, config=config)
+        convert_op_nums_2, convert_quant_op_nums_2 = self.get_convert_op_number(
+            convert_prog_2)
+
+        self.assertTrue(op_nums_1 == op_nums_2)
+        # test skip_quant
+        self.assertTrue(quant_op_nums_1 - 2 == quant_op_nums_2)
+        self.assertTrue(convert_quant_op_nums_1 == convert_quant_op_nums_2)
+
+    def get_op_number(self, prog):
+        graph = paddle.fluid.framework.IrGraph(
+            paddle.framework.core.Graph(prog.desc), for_test=False)
+        quant_op_nums = 0
+        op_nums = 0
+        for op in graph.all_op_nodes():
+            if op.name() in ['conv2d', 'depthwise_conv2d', 'mul']:
+                op_nums += 1
+            elif op.name() == 'quantize_linear':
+                quant_op_nums += 1
+        return op_nums, quant_op_nums
+
+    def get_convert_op_number(self, prog):
+        graph = paddle.fluid.framework.IrGraph(
+            paddle.framework.core.Graph(prog.desc), for_test=True)
+        quant_op_nums = 0
+        op_nums = 0
+        dequant_num = 0
+        for op in graph.all_op_nodes():
+            if op.name() not in ['quantize_linear', 'dequantize_linear']:
+                op_nums += 1
+            elif op.name() == 'quantize_linear':
+                quant_op_nums += 1
+        return op_nums, quant_op_nums
+
 
 if __name__ == '__main__':
     unittest.main()
@@ -153,8 +153,7 @@ class TestQuantAwareCase1(StaticCase):
         quant_eval_prog = quant_aware(val_prog, place, config, for_test=True)
         train(quant_train_prog_pact)
-        quant_eval_prog, int8_prog = convert(
-            quant_eval_prog, place, config, save_int8=True)
+        quant_eval_prog = convert(quant_eval_prog, place, config)
         top1_2, top5_2 = test(quant_eval_prog)
         # values before quantization and after quantization should be close
         print("before quantization: top1: {}, top5: {}".format(top1_1, top5_1))
......
@@ -130,7 +130,8 @@ class TestQuantAwareWithInferModelCase1(StaticCase):
             'weight_quantize_type': 'channel_wise_abs_max',
             'activation_quantize_type': 'moving_average_abs_max',
             'not_quant_pattern': ['skip_quant'],
-            'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul']
+            'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
+            'onnx_format': False
         }
         train_config = {
             "num_epoch": 1,  # training epoch num
......
@@ -143,8 +143,7 @@ class TestQuantPostQuantAwareCase1(StaticCase):
             scale_dict=scale_dict,
             model_type='transformer')
         train(quant_train_prog)
-        quant_eval_prog, int8_prog = convert(
-            quant_eval_prog, place, config, save_int8=True)
+        quant_eval_prog = convert(quant_eval_prog, place, config)
         top1_2 = test(quant_eval_prog)
         # values before quantization and after quantization should be close
         print("before quantization: top1: {}".format(top1_1))
......