Unverified · Commit e7a02b5c authored by X XGZhang, committed by GitHub

changed post-quant methods (#713)

Parent 8fad8d41
@@ -43,7 +43,7 @@ python quant_post_static.py --model_path ./inference_model/MobileNet --save_path
 After running the above command, the quantized model and parameter files can be found under ``${save_path}``.
-> The quantization algorithm used is ``'KL'``, and 160 images from the training set are used to calibrate the quantization parameters.
+> The quantization algorithm used is ``'hist'``, and 32 images from the training set are used to calibrate the quantization parameters.
 ### Test accuracy
@@ -67,6 +67,6 @@ python eval.py --model_path ./quant_model_train/MobileNet --model_name __model__
 The accuracy output is
 ```
-top1_acc/top5_acc= [0.70141864 0.89086477]
+top1_acc/top5_acc= [0.70328485 0.89183184]
 ```
-From the accuracy comparison above, after post-training quantization of the ``mobilenet`` classification model on ``imagenet``, the ``top1`` accuracy loss is ``0.77%`` and the ``top5`` accuracy loss is ``0.46%``.
+From the accuracy comparison above, after post-training quantization of the ``mobilenet`` classification model on ``imagenet``, the ``top1`` accuracy loss is ``0.59%`` and the ``top5`` accuracy loss is ``0.36%``.
@@ -19,13 +19,15 @@ _logger = get_logger(__name__, level=logging.INFO)
 parser = argparse.ArgumentParser(description=__doc__)
 add_arg = functools.partial(add_arguments, argparser=parser)
 # yapf: disable
-add_arg('batch_size', int, 16, "Minibatch size.")
-add_arg('batch_num', int, 10, "Batch number")
+add_arg('batch_size', int, 32, "Minibatch size.")
+add_arg('batch_num', int, 1, "Batch number")
 add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
 add_arg('model_path', str, "./inference_model/MobileNet/", "model dir")
 add_arg('save_path', str, "./quant_model/MobileNet/", "model dir to save quanted model")
 add_arg('model_filename', str, None, "model file name")
 add_arg('params_filename', str, None, "params file name")
+add_arg('algo', str, 'hist', "calibration algorithm")
+add_arg('hist_percent', float, 0.9999, "The percentile of algo:hist")
 # yapf: enable
@@ -46,7 +48,9 @@ def quantize(args):
         model_filename=args.model_filename,
         params_filename=args.params_filename,
         batch_size=args.batch_size,
-        batch_nums=args.batch_num)
+        batch_nums=args.batch_num,
+        algo=args.algo,
+        hist_percent=args.hist_percent)
 def main():
......
@@ -313,7 +313,9 @@ def quant_post_static(
         batch_size=16,
         batch_nums=None,
         scope=None,
-        algo='KL',
+        algo='hist',
+        hist_percent=0.9999,
+        bias_correction=False,
         quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
         is_full_quantize=False,
         weight_bits=8,
@@ -358,9 +360,15 @@ def quant_post_static(
             generated by sample_generator as calibrate data.
         scope(paddle.static.Scope, optional): The scope to run program, use it to load
             and save variables. If scope is None, will use paddle.static.global_scope().
-        algo(str, optional): If algo='KL', use the KL-divergence method to
-            get a more precise scale factor. If algo='direct', use the
-            abs_max method to get the scale factor. Default: 'KL'.
+        algo(str, optional): If algo='KL', use the KL-divergence method to
+            get the scale factor. If algo='hist', use the hist_percent percentile of the
+            histogram to get the scale factor. If algo='mse', search for the scale factor
+            that minimizes the MSE loss; one batch of data is enough for mse. If
+            algo='avg', use the average of the abs_max values to get the scale factor. If
+            algo='abs_max', use the abs_max method to get the scale factor. Default: 'hist'.
+        hist_percent(float, optional): The percentile of the histogram for algo='hist'. Default: 0.9999.
+        bias_correction(bool, optional): Whether to apply the bias correction method of
+            https://arxiv.org/abs/1810.05723. Default: False.
         quantizable_op_type(list[str], optional): The list of op types
             that will be quantized. Default: ["conv2d", "depthwise_conv2d",
             "mul"].
@@ -397,6 +405,8 @@ def quant_post_static(
         batch_nums=batch_nums,
         scope=scope,
         algo=algo,
+        hist_percent=hist_percent,
+        bias_correction=bias_correction,
         quantizable_op_type=quantizable_op_type,
         is_full_quantize=is_full_quantize,
         weight_bits=weight_bits,
......