From d31a202a4ac5c571d4ecbe09d14e060b3b1c4ea4 Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Tue, 29 Mar 2022 16:07:01 +0800 Subject: [PATCH] add adaround post quant method (#1023) --- demo/quant/quant_post/quant_post.py | 2 ++ paddleslim/quant/quanter.py | 5 +++++ 2 files changed, 7 insertions(+) diff --git a/demo/quant/quant_post/quant_post.py b/demo/quant/quant_post/quant_post.py index aeb337fc..c7a682df 100755 --- a/demo/quant/quant_post/quant_post.py +++ b/demo/quant/quant_post/quant_post.py @@ -29,6 +29,7 @@ add_arg('save_path', str, "./quant_model/MobileNet/", "model dir to save add_arg('model_filename', str, None, "model file name") add_arg('params_filename', str, None, "params file name") add_arg('algo', str, 'hist', "calibration algorithm") +add_arg('round_type', str, 'round', "The method of converting the quantized weights.") add_arg('hist_percent', float, 0.9999, "The percentile of algo:hist") add_arg('bias_correction', bool, False, "Whether to use bias correction") add_arg('ce_test', bool, False, "Whether to CE test.") @@ -74,6 +75,7 @@ def quantize(args): batch_size=args.batch_size, batch_nums=args.batch_num, algo=args.algo, + round_type=args.round_type, hist_percent=args.hist_percent, bias_correction=args.bias_correction) diff --git a/paddleslim/quant/quanter.py b/paddleslim/quant/quanter.py index a5ebbcb6..b77efa15 100755 --- a/paddleslim/quant/quanter.py +++ b/paddleslim/quant/quanter.py @@ -325,6 +325,7 @@ def quant_post_static( batch_nums=None, scope=None, algo='hist', + round_type='round', hist_percent=0.9999, bias_correction=False, quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"], @@ -380,6 +381,9 @@ def quant_post_static( makes the mse loss minimal. Use one batch of data for mse is enough. If algo='avg', use the average of abs_max values to get the scale factor. If algo='abs_max', use abs_max method to get the scale factor. Default: 'hist'. + round_type(str, optional): The method of converting the quantized weights value + from float to int. Currently supports ['round', 'adaround'] methods. + Default is `round`, which is rounding nearest to the nearest whole number. hist_percent(float, optional): The percentile of histogram for algo hist.Default:0.9999. bias_correction(bool, optional): Bias correction method of https://arxiv.org/abs/1810.05723. Default: False. @@ -420,6 +424,7 @@ def quant_post_static( batch_nums=batch_nums, scope=scope, algo=algo, + round_type=round_type, hist_percent=hist_percent, bias_correction=bias_correction, quantizable_op_type=quantizable_op_type, -- GitLab