未验证 提交 83d5d128 编写于 作者: L Liufang Sang 提交者: GitHub

update quant_aware and quant_post for paddle version 2.0 (#244)

上级 a521a961
......@@ -46,8 +46,8 @@ ACTIVATION_QUANTIZATION_TYPES_TENSORRT = [
VALID_DTYPES = ['int8']
TRANSFORM_PASS_OP_TYPES = QuantizationTransformPass._supported_quantizable_op_type
QUANT_DEQUANT_PASS_OP_TYPES = AddQuantDequantPass._supported_quantizable_op_type + \
AddQuantDequantPass._activation_type
QUANT_DEQUANT_PASS_OP_TYPES = AddQuantDequantPass._supported_quantizable_op_type
TENSORRT_OP_TYPES = [
'mul', 'conv2d', 'pool2d', 'depthwise_conv2d', 'elementwise_add',
'leaky_relu'
......@@ -230,9 +230,12 @@ def quant_aware(program, place, config=None, scope=None, for_test=False):
def quant_post(executor,
model_dir,
quantize_model_path,
sample_generator,
batch_generator=None,
sample_generator=None,
model_filename=None,
params_filename=None,
save_model_filename='__model__',
save_params_filename='__params__',
batch_size=16,
batch_nums=None,
scope=None,
......@@ -241,6 +244,8 @@ def quant_post(executor,
is_full_quantize=False,
weight_bits=8,
activation_bits=8,
activation_quantize_type='range_abs_max',
weight_quantize_type='channel_wise_abs_max',
is_use_cache_file=False,
cache_dir="./temp_post_training"):
"""
......@@ -257,6 +262,10 @@ def quant_post(executor,
are under the path.
quantize_model_path(str): The path to save quantized model using api
``fluid.io.save_inference_model``.
batch_generator(Python Generator): The batch generator provides
calibrate data for DataLoader, and it returns a batch every
time. For sample_generator and batch_generator, only one
can be set. Beisdes, batch_generator supports lod tensor.
sample_generator(Python Generator): The sample generator provides
calibrate data for DataLoader, and it only returns a sample every time.
model_filename(str, optional): The name of model file. If parameters
......@@ -265,6 +274,9 @@ def quant_post(executor,
When all parameters are saved in a single file, set it
as filename. If parameters are saved in separate files,
set it as 'None'. Default : 'None'.
save_model_filename(str): The name of model file to save the quantized inference program. Default: '__model__'.
save_params_filename(str): The name of file to save all related parameters.
If it is set None, parameters will be saved in separate files. Default: '__params__'.
batch_size(int, optional): The batch size of DataLoader, default is 16.
batch_nums(int, optional): If batch_nums is not None, the number of calibrate
data is 'batch_size*batch_nums'. If batch_nums is None, use all data
......@@ -279,6 +291,15 @@ def quant_post(executor,
"mul"].
weight_bits(int, optional): quantization bit number for weights.
activation_bits(int): quantization bit number for activation.
activation_quantize_type(str): quantization type for activation,
now support 'range_abs_max', 'moving_average_abs_max' and 'abs_max'.
This parameter only specifies the fake ops in quantized model.
If it is 'range_abs_max' or 'moving_average_abs_max', we save the scale
obtained by post training quantization in fake ops. If it
is 'abs_max', the scale will not be saved in fake ops.
weight_quantize_type(str): quantization type for weights,
support 'abs_max' and 'channel_wise_abs_max'. Compared to 'abs_max',
the model accuracy is usually higher when using 'channel_wise_abs_max'.
is_full_quantize(bool): if True, apply quantization to all supported quantizable op type.
If False, only apply quantization to the input quantizable_op_type. Default is False.
is_use_cache_file(bool): If False, all temp data will be saved in memory. If True,
......@@ -291,6 +312,7 @@ def quant_post(executor,
post_training_quantization = PostTrainingQuantization(
executor=executor,
sample_generator=sample_generator,
batch_generator=batch_generator,
model_dir=model_dir,
model_filename=model_filename,
params_filename=params_filename,
......@@ -302,10 +324,15 @@ def quant_post(executor,
is_full_quantize=is_full_quantize,
weight_bits=weight_bits,
activation_bits=activation_bits,
activation_quantize_type=activation_quantize_type,
weight_quantize_type=weight_quantize_type,
is_use_cache_file=is_use_cache_file,
cache_dir=cache_dir)
post_training_quantization.quantize()
post_training_quantization.save_quantized_model(quantize_model_path)
post_training_quantization.save_quantized_model(
quantize_model_path,
model_filename=save_model_filename,
params_filename=save_params_filename)
def convert(program, place, config=None, scope=None, save_int8=False):
......@@ -338,10 +365,6 @@ def convert(program, place, config=None, scope=None, save_int8=False):
_logger.info("convert config {}".format(config))
test_graph = IrGraph(core.Graph(program.desc), for_test=True)
support_op_types = []
for op in config['quantize_op_types']:
if op in QuantizationFreezePass._supported_quantizable_op_type:
support_op_types.append(op)
# Freeze the graph after training by adjusting the quantize
# operators' order for the inference.
......@@ -350,8 +373,8 @@ def convert(program, place, config=None, scope=None, save_int8=False):
place=place,
weight_bits=config['weight_bits'],
activation_bits=config['activation_bits'],
weight_quantize_type=config['weight_quantize_type'],
quantizable_op_type=support_op_types)
weight_quantize_type=config['weight_quantize_type'])
freeze_pass.apply(test_graph)
freezed_program = test_graph.to_program()
......
......@@ -101,12 +101,15 @@ class TestQuantAwareCase1(unittest.TestCase):
exe,
'./test_quant_post',
'./test_quant_post_inference',
paddle.dataset.mnist.test(),
sample_generator=paddle.dataset.mnist.test(),
model_filename='model',
params_filename='params',
batch_nums=10)
quant_post_prog, feed_target_names, fetch_targets = fluid.io.load_inference_model(
dirname='./test_quant_post_inference', executor=exe)
dirname='./test_quant_post_inference',
executor=exe,
model_filename='__model__',
params_filename='__params__')
top1_2, top5_2 = test(quant_post_prog, fetch_targets)
print("before quantization: top1: {}, top5: {}".format(top1_1, top5_1))
print("after quantization: top1: {}, top5: {}".format(top1_2, top5_2))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册