diff --git a/paddleslim/auto_compression/compressor.py b/paddleslim/auto_compression/compressor.py
index ad07fa44fe50ad9d4229c0c42e89d1a847b1a4fa..a74105a237d96294ebc87ea8f367b2bcbbb77778 100644
--- a/paddleslim/auto_compression/compressor.py
+++ b/paddleslim/auto_compression/compressor.py
@@ -674,6 +674,7 @@ class AutoCompression:
                 hist_percent=config.hist_percent,
                 batch_size=[1],
                 batch_num=config.batch_num,
+                onnx_format=config.onnx_format,
                 runcount_limit=config.max_quant_count)
         else:
diff --git a/paddleslim/quant/post_quant_hpo.py b/paddleslim/quant/post_quant_hpo.py
index 28f1cb2d388d474fa1b867f7d4a5c7475223f003..92617a4c5f7a266a0fe29946bbe52fa564c2e053 100755
--- a/paddleslim/quant/post_quant_hpo.py
+++ b/paddleslim/quant/post_quant_hpo.py
@@ -78,6 +78,7 @@ class QuantConfig(object):
                  activation_bits=8,
                  weight_quantize_type='channel_wise_abs_max',
                  optimize_model=False,
+                 onnx_format=False,
                  is_use_cache_file=False,
                  cache_dir="./temp_post_training"):
         """QuantConfig init"""
@@ -106,6 +107,7 @@ class QuantConfig(object):
         self.activation_bits = activation_bits
         self.weight_quantize_type = weight_quantize_type
         self.optimize_model = optimize_model
+        self.onnx_format = onnx_format
         self.is_use_cache_file = is_use_cache_file
         self.cache_dir = cache_dir
@@ -291,7 +293,8 @@ def quantize(cfg):
         hist_percent=hist_percent, \
         bias_correction=bias_correct, \
         batch_size=batch_size, \
-        batch_nums=batch_num)
+        batch_nums=batch_num, \
+        onnx_format=g_quant_config.onnx_format)

     global g_min_emd_loss
     try:
@@ -356,6 +359,7 @@ def quant_post_hpo(
         batch_size=[10, 30],  ### uniform sample in list.
         batch_num=[10, 30],  ### uniform sample in list.
         optimize_model=False,
+        onnx_format=False,
         is_use_cache_file=False,
         cache_dir="./temp_post_training",
         runcount_limit=30):
@@ -403,6 +407,7 @@ def quant_post_hpo(
         optimize_model(bool, optional): If set optimize_model as True, it applies some passes
             to optimize the model before quantization. So far, the place of executor must be cpu
             it supports fusing batch_norm into convs.
+        onnx_format(bool, optional): Whether to export the quantized model in ONNX format. Default is False.
         is_use_cache_file(bool): This param is deprecated.
         cache_dir(str): This param is deprecated.
         runcount_limit(int): max. number of model quantization.
@@ -429,7 +434,7 @@ def quant_post_hpo(
         model_filename, params_filename, save_model_filename,
         save_params_filename, scope, quantizable_op_type, is_full_quantize,
         weight_bits, activation_bits, weight_quantize_type, optimize_model,
-        is_use_cache_file, cache_dir)
+        onnx_format, is_use_cache_file, cache_dir)

     cs = ConfigurationSpace()
     hyper_params = []
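
For reference, a minimal usage sketch of the new flag follows. It assumes quant_post_hpo is importable from paddleslim.quant and that a float inference model plus a calibration reader exist; the paths, file names, and the train_sample_generator/calib_reader names are illustrative assumptions, while onnx_format, runcount_limit, model_filename, and params_filename come from the signature changed in this diff.

# Hedged sketch (not part of this diff): enabling ONNX-format export in
# post-training quantization with HPO. Paths and the calibration reader are
# placeholders; train_sample_generator is an assumed parameter name.
import numpy as np
import paddle
from paddleslim.quant import quant_post_hpo

paddle.enable_static()
place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
exe = paddle.static.Executor(place)

def calib_reader():
    # Yields one calibration sample per iteration; shape/dtype depend on the model.
    for _ in range(32):
        yield [np.random.random((1, 3, 224, 224)).astype("float32")]

quant_post_hpo(
    exe,
    place,
    model_dir="./inference_model",        # float inference model directory (placeholder)
    quantize_model_path="./quant_model",  # output directory for the quantized model (placeholder)
    train_sample_generator=calib_reader,  # assumed name of the calibration-reader argument
    model_filename="model.pdmodel",
    params_filename="model.pdiparams",
    onnx_format=True,                     # new flag added in this diff
    runcount_limit=30)                    # max number of quantization trials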