diff --git a/demo/quant/quant_post_hpo/quant_post_hpo.py b/demo/quant/quant_post_hpo/quant_post_hpo.py
index f1c4089e1387595cbe6024abe42e6f8fa1dc8107..f96e869c312d4bc4358ca825cc2423aa8d92e4b2 100755
--- a/demo/quant/quant_post_hpo/quant_post_hpo.py
+++ b/demo/quant/quant_post_hpo/quant_post_hpo.py
@@ -55,7 +55,7 @@ def quantize(args):
         save_model_filename='__model__',
         save_params_filename='__params__',
         quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
-        weight_quantize_type='channel_wise_abs_max',
+        weight_quantize_type=['channel_wise_abs_max'],
         runcount_limit=args.max_model_quant_count)
 
 def main():
diff --git a/paddleslim/quant/quant_post_hpo.py b/paddleslim/quant/quant_post_hpo.py
index fb6cbe18edb44263856eb13d0a3d601f6a57b179..c12edbed31afc28c7d776cb914e1d21726017fee 100755
--- a/paddleslim/quant/quant_post_hpo.py
+++ b/paddleslim/quant/quant_post_hpo.py
@@ -189,10 +189,10 @@ def eval_quant_model():
     valid_data_num = 0
     max_eval_data_num = 200
     if g_quant_config.eval_sample_generator is not None:
-        feed_dict=False
+        feed_dict = False
         eval_dataloader = g_quant_config.eval_sample_generator
     else:
-        feed_dict=True
+        feed_dict = True
         eval_dataloader = g_quant_config.eval_dataloader
     for i, data in enumerate(eval_dataloader()):
         with paddle.static.scope_guard(float_scope):
@@ -236,12 +236,20 @@ def eval_quant_model():
 def quantize(cfg):
     """model quantize job"""
     algo = cfg["algo"] if 'algo' in cfg else g_quant_config.algo[0][0]
-    hist_percent = cfg["hist_percent"] if "hist_percent" in cfg else g_quant_config.hist_percent[0][0]
-    bias_correct = cfg["bias_correct"] if "bias_correct" in cfg else g_quant_config.bias_correct[0][0]
-    batch_size = cfg["batch_size"] if "batch_size" in cfg else g_quant_config.batch_size[0][0]
-    batch_num = cfg["batch_num"] if "batch_num" in cfg else g_quant_config.batch_num[0][0]
-    weight_quantize_type = cfg["weight_quantize_type"] if "weight_quantize_type" in cfg else g_quant_config.weight_quantize_type[0]
-    print(hist_percent, bias_correct, batch_size, batch_num, weight_quantize_type)
+    hist_percent = cfg[
+        "hist_percent"] if "hist_percent" in cfg else g_quant_config.hist_percent[
+            0][0]
+    bias_correct = cfg[
+        "bias_correct"] if "bias_correct" in cfg else g_quant_config.bias_correct[
+            0][0]
+    batch_size = cfg[
+        "batch_size"] if "batch_size" in cfg else g_quant_config.batch_size[0][
+            0]
+    batch_num = cfg[
+        "batch_num"] if "batch_num" in cfg else g_quant_config.batch_num[0][0]
+    weight_quantize_type = cfg[
+        "weight_quantize_type"] if "weight_quantize_type" in cfg else g_quant_config.weight_quantize_type[
+            0]
 
     quant_post( \
         executor=g_quant_config.executor, \
@@ -279,34 +287,35 @@ def quantize(cfg):
     return emd_loss
 
 
-def quant_post_hpo(executor,
-                   place,
-                   model_dir,
-                   quantize_model_path,
-                   train_sample_generator=None,
-                   eval_sample_generator=None,
-                   train_dataloader=None,
-                   eval_dataloader=None,
-                   eval_function=None,
-                   model_filename=None,
-                   params_filename=None,
-                   save_model_filename='__model__',
-                   save_params_filename='__params__',
-                   scope=None,
-                   quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
-                   is_full_quantize=False,
-                   weight_bits=8,
-                   activation_bits=8,
-                   weight_quantize_type=['channel_wise_abs_max'],
-                   algo=["KL", "hist", "avg", "mse"],
-                   bias_correct=[True, False],
-                   hist_percent=[0.98, 0.999],  ### uniform sample in list.
-                   batch_size=[10, 30],  ### uniform sample in list.
-                   batch_num=[10, 30],  ### uniform sample in list.
-                   optimize_model=False,
-                   is_use_cache_file=False,
-                   cache_dir="./temp_post_training",
-                   runcount_limit=30):
+def quant_post_hpo(
+        executor,
+        place,
+        model_dir,
+        quantize_model_path,
+        train_sample_generator=None,
+        eval_sample_generator=None,
+        train_dataloader=None,
+        eval_dataloader=None,
+        eval_function=None,
+        model_filename=None,
+        params_filename=None,
+        save_model_filename='__model__',
+        save_params_filename='__params__',
+        scope=None,
+        quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
+        is_full_quantize=False,
+        weight_bits=8,
+        activation_bits=8,
+        weight_quantize_type=['channel_wise_abs_max'],
+        algo=["KL", "hist", "avg", "mse"],
+        bias_correct=[True, False],
+        hist_percent=[0.98, 0.999],  ### uniform sample in list.
+        batch_size=[10, 30],  ### uniform sample in list.
+        batch_num=[10, 30],  ### uniform sample in list.
+        optimize_model=False,
+        is_use_cache_file=False,
+        cache_dir="./temp_post_training",
+        runcount_limit=30):
     """
     The function utilizes static post training quantization method to
     quantize the fp32 model. It uses calibrate data to calculate the
@@ -360,25 +369,27 @@ def quant_post_hpo(executor,
 
     global g_quant_config
     g_quant_config = QuantConfig(
-        executor, place, model_dir, quantize_model_path, algo, hist_percent,
+        executor, place, model_dir, quantize_model_path, algo, hist_percent,
         bias_correct, batch_size, batch_num, train_sample_generator,
         eval_sample_generator, train_dataloader, eval_dataloader, eval_function,
-        model_filename, params_filename,
-        save_model_filename, save_params_filename, scope, quantizable_op_type,
-        is_full_quantize, weight_bits, activation_bits, weight_quantize_type,
-        optimize_model, is_use_cache_file, cache_dir)
+        model_filename, params_filename, save_model_filename,
+        save_params_filename, scope, quantizable_op_type, is_full_quantize,
+        weight_bits, activation_bits, weight_quantize_type, optimize_model,
+        is_use_cache_file, cache_dir)
     cs = ConfigurationSpace()
 
     hyper_params = []
 
     if 'hist' in algo:
         hist_percent = UniformFloatHyperparameter(
-            "hist_percent", hist_percent[0], hist_percent[1], default_value=hist_percent[0])
+            "hist_percent",
+            hist_percent[0],
+            hist_percent[1],
+            default_value=hist_percent[0])
         hyper_params.append(hist_percent)
 
     if len(algo) > 1:
-        algo = CategoricalHyperparameter(
-            "algo", algo, default_value=algo[0])
+        algo = CategoricalHyperparameter("algo", algo, default_value=algo[0])
         hyper_params.append(algo)
     else:
         algo = algo[0]
@@ -397,7 +408,10 @@ def quant_post_hpo(executor,
         weight_quantize_type = weight_quantize_type[0]
     if len(batch_size) > 1:
         batch_size = UniformIntegerHyperparameter(
-            "batch_size", batch_size[0], batch_size[1], default_value=batch_size[0])
+            "batch_size",
+            batch_size[0],
+            batch_size[1],
+            default_value=batch_size[0])
         hyper_params.append(batch_size)
     else:
         batch_size = batch_size[0]
@@ -407,7 +421,7 @@ def quant_post_hpo(executor,
             "batch_num", batch_num[0], batch_num[1], default_value=batch_num[0])
         hyper_params.append(batch_num)
     else:
-        batch_num = batch_num[0] 
+        batch_num = batch_num[0]
 
     if len(hyper_params) == 0:
         quant_post( \
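
For reviewers who want to exercise the reformatted entry point, below is a minimal, hypothetical invocation sketch. The keyword arguments come from the `quant_post_hpo` signature in this diff; the model paths and the calibration reader are placeholders, and the import assumes the package layout shown here (`paddleslim/quant/quant_post_hpo.py` re-exported via `paddleslim.quant`). Note that `weight_quantize_type` is passed as a list, matching the demo fix above.

```python
import numpy as np
import paddle
from paddleslim.quant import quant_post_hpo

paddle.enable_static()
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)

def calib_reader():
    # Hypothetical calibration reader: yields one fake NCHW sample per step.
    # Replace the shape and random data with real preprocessed inputs.
    for _ in range(32):
        yield [np.random.rand(1, 3, 224, 224).astype('float32')]

quant_post_hpo(
    executor=exe,
    place=place,
    model_dir='./fp32_model',            # placeholder path
    quantize_model_path='./int8_model',  # placeholder path
    train_sample_generator=calib_reader,
    eval_sample_generator=calib_reader,
    save_model_filename='__model__',
    save_params_filename='__params__',
    weight_quantize_type=['channel_wise_abs_max'],  # a list, per the demo fix
    runcount_limit=30)
```

With the default search space, `algo`, `bias_correct`, `hist_percent`, `batch_size`, and `batch_num` each contain more than one candidate, so they are registered as ConfigSpace hyperparameters and tuned for up to `runcount_limit` trials; a single-element list pins the value instead.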