diff --git a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
index ae2298e10a38fe5eb14c8b43cc022383e4bc7ab8..d52e3ea10459d4e2488aba32bf06f88ee6eccfca 100644
--- a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
+++ b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
@@ -54,17 +54,19 @@ def _set_variable_data(scope, place, var_name, np_value):
 
 class PostTrainingQuantization(object):
     def __init__(self,
-                 executor,
-                 sample_generator,
-                 model_dir,
+                 executor=None,
+                 scope=None,
+                 model_dir=None,
                  model_filename=None,
                  params_filename=None,
+                 sample_generator=None,
                  batch_size=10,
                  batch_nums=None,
-                 scope=None,
                  algo="KL",
                  quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
                  is_full_quantize=False,
+                 weight_bits=8,
+                 activation_bits=8,
                  is_use_cache_file=False,
                  cache_dir="./temp_post_training"):
         '''
@@ -76,9 +78,8 @@ class PostTrainingQuantization(object):
         Args:
             executor(fluid.Executor): The executor to load, run and save the
                 quantized model.
-            sample_generator(Python Generator): The sample generator provides
-                calibrate data for DataLoader, and it only returns a sample every
-                time.
+            scope(fluid.Scope, optional): The scope of the program, used to load
+                and save variables. If scope is None, global_scope() is used.
             model_dir(str): The path of the fp32 model that will be quantized,
                 and the model and params files are under the path.
             model_filename(str, optional): The name of file to load the inference
@@ -88,12 +89,13 @@ class PostTrainingQuantization(object):
                 When all parameters were saved in a single binary file, set it
                 as the real filename. If parameters were saved in separate files,
                 set it as 'None'. Default is 'None'.
+            sample_generator(Python Generator): The sample generator provides
+                calibration data for the DataLoader, yielding one sample at a
+                time.
             batch_size(int, optional): The batch size of DataLoader. Default is 10.
             batch_nums(int, optional): If batch_nums is not None, the number of
                 calibrate data is batch_size*batch_nums. If batch_nums is None, use
                 all data provided by sample_generator as calibrate data.
-            scope(fluid.Scope, optional): The scope of the program, use it to load
-                and save variables. If scope=None, get scope by global_scope().
             algo(str, optional): If algo=KL, use KL-divergence method to get the
                 more precise scale factor. If algo='direct', use abs_max method to
                 get the scale factor. Default is KL.
@@ -104,6 +106,8 @@ class PostTrainingQuantization(object):
                 apply quantization to all supported quantizable op type. If set
                 is_full_quantized as False, only apply quantization to the op type
                 according to the input quantizable_op_type.
+            weight_bits(int, optional): quantization bit number for weights. Default is 8.
+            activation_bits(int, optional): quantization bit number for activations. Default is 8.
             is_use_cache_file(bool, optional): If set is_use_cache_file as False,
                 all temp data will be saved in memory. If set is_use_cache_file
                 as True, it will save temp data to disk. When the fp32 model is complex or
@@ -150,14 +154,20 @@ class PostTrainingQuantization(object):
             ptq.quantize()
             ptq.save_quantized_model(save_model_path)
         '''
+
+        assert executor is not None, "The executor cannot be None."
+        assert model_dir is not None, "The model_dir cannot be None."
+        assert sample_generator is not None, \
+            "The sample_generator cannot be None."
+
         self._executor = executor
-        self._sample_generator = sample_generator
+        self._scope = global_scope() if scope is None else scope
         self._model_dir = model_dir
         self._model_filename = model_filename
         self._params_filename = params_filename
+        self._sample_generator = sample_generator
         self._batch_size = batch_size
         self._batch_nums = batch_nums
-        self._scope = global_scope() if scope == None else scope
         self._algo = algo
         self._is_use_cache_file = is_use_cache_file
         self._cache_dir = cache_dir
@@ -604,7 +614,7 @@ class WeightQuantization(object):
                                save_model_filename=None,
                                save_params_filename=None,
                                quantizable_op_type=["conv2d", "mul"],
-                               quantize_weight_bits=8,
+                               weight_bits=8,
                                threshold_rate=0.0):
         '''
         In order to reduce the size of model, this api quantizes the weight
@@ -624,8 +634,8 @@ class WeightQuantization(object):
                 that will be quantized, and the quantized ops should be
                 contained in ["conv2d", "depthwise_conv2d", "mul"].
                 Default is ["conv2d","mul"].
-            quantize_weight_bits(int, optional): The bits for the quantized
-                weight, and it should be 8 or 16. Default is 8.
+            weight_bits(int, optional): The bits for the quantized weight,
+                and it should be 8 or 16. Default is 8.
             threshold_rate(float, optional): This api uses abs_max method to
                 quantize the weight from float32 to int8/16, and the abs max
                 value is important for quantization diff. When the abs_max
@@ -637,10 +647,10 @@ class WeightQuantization(object):
             assert op_type in self._supported_quantizable_op_type, \
                 "input error:" + op_type + \
                 " is not supported for weight quantization."
-        assert quantize_weight_bits in [8, 16], \
-            "input error: quantize_weight_bits should be 8 or 16."
-        quantize_range = (1 << (quantize_weight_bits - 1)) - 1
-        save_weight_dtype = np.int8 if quantize_weight_bits == 8 else np.int16
+        assert weight_bits in [8, 16], \
+            "input error: weight_bits should be 8 or 16."
+        quantize_range = (1 << (weight_bits - 1)) - 1
+        save_weight_dtype = np.int8 if weight_bits == 8 else np.int16
 
         place = core.CPUPlace()
         exe = Executor(place)
@@ -677,8 +687,7 @@ class WeightQuantization(object):
                     _set_variable_data(scope, place, var_name,
                                        quantized_var_tensor_data)
                     op._set_attr(var_name + "_quant_scale", [scale])
-                    op._set_attr('quantize_weight_bits',
-                                 quantize_weight_bits)
+                    op._set_attr('quantize_weight_bits', weight_bits)
 
         io.save_inference_model(
             dirname=save_model_dir,
diff --git a/python/paddle/fluid/contrib/slim/tests/test_weight_quantization_mobilenetv1.py b/python/paddle/fluid/contrib/slim/tests/test_weight_quantization_mobilenetv1.py
index c6380adf6b63cffbcbcc7d5e75a86926e6bcde8b..e872a26d4b319f0a1e4dd2338aabaf9ee16bbd32 100644
--- a/python/paddle/fluid/contrib/slim/tests/test_weight_quantization_mobilenetv1.py
+++ b/python/paddle/fluid/contrib/slim/tests/test_weight_quantization_mobilenetv1.py
@@ -42,8 +42,8 @@ class TestWeightQuantization(unittest.TestCase):
                                                             zip_path)
         os.system(cmd)
 
-    def run_test(self, model_name, model_data_url, model_data_md5,
-                 quantize_weight_bits, quantizable_op_type, threshold_rate):
+    def run_test(self, model_name, model_data_url, model_data_md5, weight_bits,
+                 quantizable_op_type, threshold_rate):
         model_dir = self.download_model(model_name, model_data_url,
                                         model_data_md5)
 
@@ -51,11 +51,11 @@ class TestWeightQuantization(unittest.TestCase):
         timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
         save_model_dir = os.path.join(
             os.getcwd(),
-            model_name + "_wq_" + str(quantize_weight_bits) + "_" + timestamp)
+            model_name + "_wq_" + str(weight_bits) + "_" + timestamp)
         weight_quant = WeightQuantization(model_dir=model_dir + "/model")
         weight_quant.quantize_weight_to_int(
             save_model_dir=save_model_dir,
-            quantize_weight_bits=quantize_weight_bits,
+            weight_bits=weight_bits,
             quantizable_op_type=quantizable_op_type,
             threshold_rate=threshold_rate)
         print("finish weight quantization for " + model_name + "\n")
@@ -73,18 +73,18 @@ class TestWeightQuantizationMobilenetv1(TestWeightQuantization):
     model_data_md5 = "13892b0716d26443a8cdea15b3c6438b"
 
     def test_weight_quantization_mobilenetv1_8bit(self):
-        quantize_weight_bits = 8
+        weight_bits = 8
         quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
         threshold_rate = 0.0
         self.run_test(self.model_name, self.model_data_url, self.model_data_md5,
-                      quantize_weight_bits, quantizable_op_type, threshold_rate)
+                      weight_bits, quantizable_op_type, threshold_rate)
 
     def test_weight_quantization_mobilenetv1_16bit(self):
-        quantize_weight_bits = 16
+        weight_bits = 16
         quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
         threshold_rate = 1e-9
         self.run_test(self.model_name, self.model_data_url, self.model_data_md5,
-                      quantize_weight_bits, quantizable_op_type, threshold_rate)
+                      weight_bits, quantizable_op_type, threshold_rate)
 
 
 if __name__ == '__main__':
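For reference, a minimal usage sketch of the two entry points touched by this patch. Only the argument names and call sequence come from the diff and its docstrings; the model paths, the calibration reader, and its sample shape are hypothetical placeholders.

```python
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization import (PostTrainingQuantization,
                                                    WeightQuantization)

def sample_generator():
    # Placeholder calibration reader: yields one sample per iteration;
    # the shape must match the real model's input.
    for _ in range(100):
        yield np.random.random((3, 224, 224)).astype('float32')

exe = fluid.Executor(fluid.CPUPlace())

# Post-training quantization with the reordered keyword arguments and the
# new weight_bits/activation_bits parameters (both default to 8).
ptq = PostTrainingQuantization(
    executor=exe,
    model_dir="./fp32_model",          # placeholder path
    sample_generator=sample_generator,
    batch_size=10,
    algo="KL",
    weight_bits=8,
    activation_bits=8)
ptq.quantize()
ptq.save_quantized_model("./ptq_int8_model")

# Weight-only quantization: quantize_weight_bits is now weight_bits.
weight_quant = WeightQuantization(model_dir="./fp32_model")
weight_quant.quantize_weight_to_int(
    save_model_dir="./wq_int8_model",
    weight_bits=8,
    quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"])
```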