Commit feab2624 authored by S slf12

update quant_post

Parent df8e1779
@@ -23,8 +23,13 @@ from paddle.fluid.contrib.slim.quantization import TransformForMobilePass
 from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
 from paddle.fluid import core
-WEIGHT_QUANTIZATION_TYPES=['abs_max', 'channel_wise_abs_max', 'range_abs_max', 'moving_average_abs_max']
-ACTIVATION_QUANTIZATION_TYPES=['abs_max','range_abs_max', 'moving_average_abs_max']
+WEIGHT_QUANTIZATION_TYPES = [
+    'abs_max', 'channel_wise_abs_max', 'range_abs_max',
+    'moving_average_abs_max'
+]
+ACTIVATION_QUANTIZATION_TYPES = [
+    'abs_max', 'range_abs_max', 'moving_average_abs_max'
+]
 VALID_DTYPES = ['int8']
 _quant_config_default = {
@@ -154,19 +159,19 @@ def quant_aware(program, place, config, scope=None, for_test=False):
     return quant_program
-def quant_post(executor,
-               model_path,
+def quant_post(executor,
+               model_dir,
                quantize_model_path,
-               data_reader,
-               batch_size=10,
+               sample_generator,
+               model_filename=None,
+               params_filename=None,
+               batch_size=16,
                batch_nums=None,
                scope=None,
                algo='KL',
-               quantizable_op_type=[
-                   "conv2d", "depthwise_conv2d",
-                   "mul", "pool2d", "elementwise_add"]):
+               quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"]):
     """
-    The class utilizes post training quantization methon to quantize the
+    The function utilizes post training quantization method to quantize the
     fp32 model. It uses calibrate data to calculate the scale factor of
     quantized variables, and inserts fake quant/dequant op to obtain the
     quantized model.
@@ -174,33 +179,47 @@ def quant_post(executor,
     Args:
         executor(fluid.Executor): The executor to load, run and save the
             quantized model.
-        model_path(str): The path of fp32 model that will be quantized(
-            load_inference_model).
-        quantize_model_path(str): The path to save quantized model.
-        data_reader(Reader): The data reader generates a sample every time,
-            and it provides calibrate data for DataLoader.
-        batch_size(int, optional): The batch size of DataLoader, default is 10.
+        model_dir(str): The path of the fp32 model to be quantized; the model
+            and params files saved by fluid.io.save_inference_model are
+            under this path.
+        quantize_model_path(str): The path where the quantized model is saved
+            by fluid.io.save_inference_model.
+        sample_generator(Python Generator): The sample generator provides
+            calibrate data for DataLoader; it returns one sample at a time.
+        model_filename(str, optional): The name of the model file used to
+            load the inference program. If parameters were saved in separate
+            files, set it to None. Default is None.
+        params_filename(str, optional): The name of the params file used to
+            load all parameters. If all parameters were saved in a single
+            file, set it to that filename; if parameters were saved in
+            separate files, set it to None. Default is None.
+        batch_size(int, optional): The batch size of DataLoader. Default is 16.
         batch_nums(int, optional): If batch_nums is set, the number of calibrate
-            data is batch_size*batch_nums. If batch_nums=None, use all data
-            provided by data_reader as calibrate data.
-        scope(fluid.Scope, optional): The scope of the program, use it to load
-            and save variables. If scope=None, get scope by global_scope().
+            samples is batch_size * batch_nums. If batch_nums is None, all data
+            generated by sample_generator are used as calibrate data.
+        scope(fluid.Scope, optional): The scope in which the program runs, used
+            to load and save variables. If scope is None, fluid.global_scope()
+            is used.
         algo(str, optional): If algo='KL', use the KL-divergence method to
             get a more precise scale factor. If algo='direct', use the
-            abs_max methon to get the scale factor. Default is KL.
-        quantizable_op_type(list[str], optional): List the type of ops
+            abs_max method to get the scale factor. Default is 'KL'.
+        quantizable_op_type(list[str], optional): The list of op types
             that will be quantized. Default is ["conv2d", "depthwise_conv2d",
-            "mul", "pool2d", "elementwise_add"].
+            "mul"].
     Returns:
         None
     """
     post_training_quantization = PostTrainingQuantization(
-        executor=executor,
-        model_path=model_path,
-        data_reader=data_reader,
-        batch_size=batch_size,
-        batch_nums=batch_nums,
-        scope=scope,
-        algo=algo,
-        quantizable_op_type=quantizable_op_type)
+        executor=executor,
+        sample_generator=sample_generator,
+        model_dir=model_dir,
+        model_filename=model_filename,
+        params_filename=params_filename,
+        batch_size=batch_size,
+        batch_nums=batch_nums,
+        scope=scope,
+        algo=algo,
+        quantizable_op_type=quantizable_op_type,
+        is_full_quantize=False)
     post_training_quantization.quantize()
     post_training_quantization.save_quantized_model(quantize_model_path)
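For orientation, the snippet below sketches how the updated quant_post API could be called. It is an illustration only, not part of this commit: the import path, the model directory './fp32_model', the output path './quant_model', the input shape, and the random calibrate data are all assumptions to be replaced with real values.

import numpy as np
import paddle.fluid as fluid
from paddleslim.quant import quant_post  # assumed import path for this function

def sample_generator():
    # Hypothetical calibrate data: yields one fake sample per call.
    # Replace with real preprocessed inputs that match the model's input.
    for _ in range(160):
        yield np.random.random((3, 224, 224)).astype('float32')

exe = fluid.Executor(fluid.CPUPlace())
quant_post(
    executor=exe,
    model_dir='./fp32_model',             # saved by fluid.io.save_inference_model
    quantize_model_path='./quant_model',  # where the quantized model is written
    sample_generator=sample_generator,
    model_filename=None,                  # parameters stored in separate files
    params_filename=None,
    batch_size=16,
    batch_nums=10,                        # calibrate on 16 * 10 = 160 samples
    algo='KL')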