diff --git a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
index 6335c80e0839bc3bb33ebfa78c99a7037c6799bf..b61f4acaee57543d51cd7aadb8163d164999c274 100644
--- a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
+++ b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
@@ -57,6 +57,14 @@ def _set_variable_data(scope, place, var_name, np_value):
     tensor.set(np_value, place)
 
 
+def _all_persistable_var_names(program):
+    persistable_var_names = []
+    for var in program.list_vars():
+        if var.persistable:
+            persistable_var_names.append(var.name)
+    return persistable_var_names
+
+
 class PostTrainingQuantization(object):
     """
     Utilizing post training quantization methon to quantize the FP32 model,
@@ -365,11 +373,7 @@ class PostTrainingQuantization(object):
             else:
                 self._quantized_act_var_name.add(var_name)
 
-        persistable_var_names = []
-        for var in self._program.list_vars():
-            if var.persistable:
-                persistable_var_names.append(var.name)
-
+        persistable_var_names = _all_persistable_var_names(self._program)
         for op in self._program.global_block().ops:
             op_type = op.type
             # For quantized ops, sample inputs and outputs
@@ -738,6 +742,7 @@ class PostTrainingQuantization(object):
 
 class WeightQuantization(object):
     _supported_quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
+    _supported_weight_quantize_type = ['channel_wise_abs_max', 'abs_max']
 
     def __init__(self, model_dir, model_filename=None, params_filename=None):
         '''
@@ -765,6 +770,8 @@ class WeightQuantization(object):
                                save_params_filename=None,
                                quantizable_op_type=["conv2d", "mul"],
                                weight_bits=8,
+                               weight_quantize_type="channel_wise_abs_max",
+                               generate_test_model=False,
                                threshold_rate=0.0):
         '''
         In order to reduce the size of model, this api quantizes the weight
@@ -786,6 +793,13 @@ class WeightQuantization(object):
                 Default is ["conv2d","mul"].
             weight_bits(int, optional): The bits for the quantized weight,
                 and it should be 8 or 16. Default is 8.
+            weight_quantize_type(str, optional): The quantization type for
+                weights, supporting 'channel_wise_abs_max' and 'abs_max'.
+                'channel_wise_abs_max' usually gives better accuracy and is the default.
+            generate_test_model(bool, optional): If generate_test_model is
+                True, it also saves a fake quantized model, in which the
+                weights are quantized and then dequantized. PaddlePaddle can
+                load this model to test the accuracy on GPU or CPU.
             threshold_rate(float, optional): This api uses abs_max methd to
                 quantize the weight from float32 to int8/16, and the abs max
                 value is important for quantization diff. When the abs_max
@@ -795,13 +809,35 @@ class WeightQuantization(object):
         '''
         for op_type in quantizable_op_type:
             assert op_type in self._supported_quantizable_op_type, \
-                "input error:" + op_type + \
+                "Input error: " + op_type + \
                 " is not supported for weight quantization."
         assert weight_bits in [8, 16], \
-            "input error: weight_bits should be 8 or 16."
-        quantize_range = (1 << (weight_bits - 1)) - 1
-        save_weight_dtype = np.int8 if weight_bits == 8 else np.int16
-
+            "Input error: weight_bits should be 8 or 16."
+        assert weight_quantize_type in self._supported_weight_quantize_type, \
+            "Input error: weight_quantize_type should be in {}".format(
+                self._supported_weight_quantize_type)
+
+        quantized_model_dir = os.path.join(save_model_dir, "quantized_model")
+        self._quantize_weight_to_int(quantized_model_dir, save_model_filename,
+                                     save_params_filename, quantizable_op_type,
+                                     weight_bits, weight_quantize_type, False,
+                                     threshold_rate)
+
+        if generate_test_model:
+            test_model_dir = os.path.join(save_model_dir, "test_model")
+            self._quantize_weight_to_int(
+                test_model_dir, save_model_filename, save_params_filename,
+                quantizable_op_type, weight_bits, weight_quantize_type, True,
+                threshold_rate)
+
+    def _quantize_weight_to_int(self, save_model_dir, save_model_filename,
+                                save_params_filename, quantizable_op_type,
+                                weight_bits, weight_quantize_type, for_test,
+                                threshold_rate):
+        """
+        Generate quantized model or fake quantized model.
+        """
+        # Load model
         place = core.CPUPlace()
         exe = Executor(place)
         scope = global_scope()
@@ -811,33 +847,25 @@ class WeightQuantization(object):
             model_filename=self._model_filename,
             params_filename=self._params_filename)
 
-        persistable_var_names = []
-        for var in program.list_vars():
-            if var.persistable:
-                persistable_var_names.append(var.name)
-        for op in program.global_block().ops:
-            if op.type in quantizable_op_type:
-                for var_name in op.input_arg_names:
-                    if var_name in persistable_var_names:
-                        var_tensor_data = _load_variable_data(scope, var_name)
-                        if abs(threshold_rate) < 1e-10:
-                            threshold_value = np.max(np.abs(var_tensor_data))
-                        else:
-                            threshold_value = self._calculate_threshold(\
-                                var_tensor_data, threshold_rate)
-                        var_tensor_data[var_tensor_data >
-                                        threshold_value] = threshold_value
-                        var_tensor_data[var_tensor_data <
-                                        -threshold_value] = -threshold_value
-                        scale = threshold_value / quantize_range
-                        quantized_var_tensor_data = \
-                            np.around(var_tensor_data / scale)
-                        quantized_var_tensor_data = \
-                            quantized_var_tensor_data.astype(save_weight_dtype)
-                        _set_variable_data(scope, place, var_name,
-                                           quantized_var_tensor_data)
-                        op._set_attr(var_name + "_quant_scale", [scale])
-                        op._set_attr('quantize_weight_bits', weight_bits)
+        quantized_ops = []
+        for index in range(program.num_blocks):
+            block = program.block(index)
+            for op in block.ops:
+                if op.type in quantizable_op_type:
+                    quantized_ops.append(op)
+
+        # Quantize weights
+        persistable_var_names = _all_persistable_var_names(program)
+        for op in quantized_ops:
+            for var_name in op.input_arg_names:
+                if var_name in persistable_var_names:
+                    if weight_quantize_type == "abs_max":
+                        self._weight_abs_max_quantization(
+                            scope, place, weight_bits, threshold_rate, op,
+                            var_name, for_test)
+                    elif weight_quantize_type == "channel_wise_abs_max":
+                        self._weight_channel_wise_abs_max_quantization(
+                            scope, place, weight_bits, op, var_name, for_test)
 
         io.save_inference_model(
             dirname=save_model_dir,
@@ -848,6 +876,137 @@ class WeightQuantization(object):
             model_filename=save_model_filename,
             params_filename=save_params_filename)
 
+    def _weight_abs_max_quantization(self, scope, place, weight_bits,
+                                     threshold_rate, op, var_name, for_test):
+        '''
+        Use abs_max method to quantize weight.
+        '''
+        quantize_range = (1 << (weight_bits - 1)) - 1
+        save_weight_dtype = np.int8 if weight_bits == 8 else np.int16
+
+        # Get quantized scale and weight data
+        weight_data = _load_variable_data(scope, var_name)
+        if abs(threshold_rate) < 1e-10:
+            threshold_value = np.max(np.abs(weight_data))
+        else:
+            threshold_value = self._calculate_threshold(\
+                weight_data, threshold_rate)
+            weight_data[weight_data > threshold_value] = threshold_value
+            weight_data[weight_data < -threshold_value] = -threshold_value
+        scale = threshold_value / quantize_range
+        quantized_weight_data = \
+            np.around(weight_data / scale).astype(save_weight_dtype)
+
+        # Set weight data
+        if not for_test:
+            _set_variable_data(scope, place, var_name, quantized_weight_data)
+        else:
+            dequantized_weight_data = \
+                (quantized_weight_data * scale).astype(np.float32)
+            _set_variable_data(scope, place, var_name, dequantized_weight_data)
+
+        # Save info
+        op._set_attr('quantization_type', 'post_weight_abs_max')
+        op._set_attr('quantize_weight_bits', weight_bits)
+        op._set_attr(var_name + "_quant_scale", [scale])  # Save as list
+
+    def _weight_channel_wise_abs_max_quantization(
+            self, scope, place, weight_bits, op, var_name, for_test):
+        '''
+        Use channel_wise_abs_max method to quantize weight.
+        '''
+        quantize_range = (1 << (weight_bits - 1)) - 1
+        save_weight_dtype = np.int8 if weight_bits == 8 else np.int16
+
+        # Get quantized scale and weight data
+        weight_data = _load_variable_data(scope, var_name)
+        if op.type == "mul":
+            scales, quantized_weight_data = \
+                self._mul_channel_wise_quantization(weight_data,
+                    quantize_range, save_weight_dtype)
+        elif op.type in ["conv2d", "depthwise_conv2d"]:
+            scales, quantized_weight_data = \
+                self._conv_channel_wise_quantization(weight_data,
+                    quantize_range, save_weight_dtype)
+        else:
+            _logger.error(op.type + " is not supported by weight quantization")
+
+        # Set weight data
+        if not for_test:
+            _set_variable_data(scope, place, var_name, quantized_weight_data)
+        else:
+            if op.type == "mul":
+                dequantized_weight_data = \
+                    self._mul_channel_wise_dequantization(quantized_weight_data, scales)
+            elif op.type in ["conv2d", "depthwise_conv2d"]:
+                dequantized_weight_data = \
+                    self._conv_channel_wise_dequantization(quantized_weight_data, scales)
+            else:
+                _logger.error(op.type +
+                              " is not supported by weight quantization")
+            _set_variable_data(scope, place, var_name, dequantized_weight_data)
+
+        # Save info
+        op._set_attr('quantization_type', 'post_weight_channel_wise_abs_max')
+        op._set_attr('quantize_weight_bits', weight_bits)
+        op._set_attr(var_name + "_quant_scale", scales)
+
+    def _conv_channel_wise_quantization(self, weight_data, quantize_range,
+                                        save_weight_dtype):
+        '''
+        Get channel wise scale for the weights of conv2d and depthwise_conv2d,
+        and quantize the weights.
+        '''
+        scales = []
+        quantized_weight_data = np.zeros_like(
+            weight_data, dtype=save_weight_dtype)
+        channel_num = weight_data.shape[0]
+        for i in range(channel_num):
+            scale = np.max(np.abs(weight_data[i])) / quantize_range
+            scales.append(scale)
+            quantized_weight_data[i] = \
+                np.around(weight_data[i] / scale).astype(save_weight_dtype)
+        return scales, quantized_weight_data
+
+    def _conv_channel_wise_dequantization(self, quantized_weight_data, scales):
+        '''
+        For conv2d and depthwise_conv2d, dequantize the weights to fp32.
+        '''
+        dequantized_weight_data = np.zeros_like(
+            quantized_weight_data, dtype=np.float32)
+        for i in range(len(scales)):
+            dequantized_weight_data[i] = \
+                (quantized_weight_data[i] * scales[i]).astype(np.float32)
+        return dequantized_weight_data
+
+    def _mul_channel_wise_quantization(self, weight_data, quantize_range,
+                                       save_weight_dtype):
+        '''
+        Get channel wise scale for the weights of mul, and quantize
+        the weights.
+        '''
+        scales = []
+        quantized_weight_data = np.zeros_like(
+            weight_data, dtype=save_weight_dtype)
+        channel_num = weight_data.shape[-1]
+        for i in range(channel_num):
+            scale = np.max(np.abs(weight_data[:, i])) / quantize_range
+            scales.append(scale)
+            quantized_weight_data[:, i] = \
+                np.around(weight_data[:, i] / scale).astype(save_weight_dtype)
+        return scales, quantized_weight_data
+
+    def _mul_channel_wise_dequantization(self, quantized_weight_data, scales):
+        '''
+        For mul, dequantize the weights to fp32.
+        '''
+        dequantized_weight_data = np.zeros_like(
+            quantized_weight_data, dtype=np.float32)
+        for i in range(len(scales)):
+            dequantized_weight_data[:, i] = \
+                (quantized_weight_data[:, i] * scales[i]).astype(np.float32)
+        return dequantized_weight_data
+
     def _calculate_threshold(self, input, threshold_rate, histogram_bins=5000):
         input_abs = np.abs(input)
         hist, hist_edeges = np.histogram(
diff --git a/python/paddle/fluid/contrib/slim/tests/test_weight_quantization_mobilenetv1.py b/python/paddle/fluid/contrib/slim/tests/test_weight_quantization_mobilenetv1.py
index e872a26d4b319f0a1e4dd2338aabaf9ee16bbd32..ff22b1b61e68f9c7d364b34a3b6b185a766f8c64 100644
--- a/python/paddle/fluid/contrib/slim/tests/test_weight_quantization_mobilenetv1.py
+++ b/python/paddle/fluid/contrib/slim/tests/test_weight_quantization_mobilenetv1.py
@@ -43,7 +43,8 @@ class TestWeightQuantization(unittest.TestCase):
         os.system(cmd)
 
     def run_test(self, model_name, model_data_url, model_data_md5, weight_bits,
-                 quantizable_op_type, threshold_rate):
+                 quantizable_op_type, weight_quantize_type, generate_test_model,
+                 threshold_rate):
 
         model_dir = self.download_model(model_name, model_data_url,
                                         model_data_md5)
@@ -57,6 +58,8 @@ class TestWeightQuantization(unittest.TestCase):
             save_model_dir=save_model_dir,
             weight_bits=weight_bits,
             quantizable_op_type=quantizable_op_type,
+            weight_quantize_type=weight_quantize_type,
+            generate_test_model=generate_test_model,
             threshold_rate=threshold_rate)
         print("finish weight quantization for " + model_name + "\n")
 
@@ -72,19 +75,45 @@ class TestWeightQuantizationMobilenetv1(TestWeightQuantization):
     model_data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz"
     model_data_md5 = "13892b0716d26443a8cdea15b3c6438b"
 
-    def test_weight_quantization_mobilenetv1_8bit(self):
+    def test_weight_quantization_mobilenetv1_8bit_abs_max(self):
         weight_bits = 8
         quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
+        weight_quantize_type = "abs_max"
+        generate_test_model = True
         threshold_rate = 0.0
         self.run_test(self.model_name, self.model_data_url, self.model_data_md5,
-                      weight_bits, quantizable_op_type, threshold_rate)
+                      weight_bits, quantizable_op_type, weight_quantize_type,
+                      generate_test_model, threshold_rate)
 
-    def test_weight_quantization_mobilenetv1_16bit(self):
+    def test_weight_quantization_mobilenetv1_8bit_channel_wise_abs_max(self):
+        weight_bits = 8
+        quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
+        weight_quantize_type = "channel_wise_abs_max"
+        generate_test_model = True
+        threshold_rate = 0.0
+        self.run_test(self.model_name, self.model_data_url, self.model_data_md5,
+                      weight_bits, quantizable_op_type, weight_quantize_type,
+                      generate_test_model, threshold_rate)
+
+    def test_weight_quantization_mobilenetv1_16bit_abs_max(self):
+        weight_bits = 16
+        quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
+        weight_quantize_type = "abs_max"
+        generate_test_model = False
+        threshold_rate = 1e-9
+        self.run_test(self.model_name, self.model_data_url, self.model_data_md5,
+                      weight_bits, quantizable_op_type, weight_quantize_type,
+                      generate_test_model, threshold_rate)
+
+    def test_weight_quantization_mobilenetv1_16bit_channel_wise_abs_max(self):
         weight_bits = 16
         quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
+        weight_quantize_type = "channel_wise_abs_max"
+        generate_test_model = False
         threshold_rate = 1e-9
         self.run_test(self.model_name, self.model_data_url, self.model_data_md5,
-                      weight_bits, quantizable_op_type, threshold_rate)
+                      weight_bits, quantizable_op_type, weight_quantize_type,
+                      generate_test_model, threshold_rate)
 
 
 if __name__ == '__main__':
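
A minimal usage sketch of the extended API, for review convenience. The model directories are hypothetical; per the patch, quantize_weight_to_int writes the deployable model to <save_model_dir>/quantized_model and, when generate_test_model=True, a fake quantized model to <save_model_dir>/test_model:

from paddle.fluid.contrib.slim.quantization import WeightQuantization

# Hypothetical directory holding an exported FP32 inference model.
weight_quant = WeightQuantization(model_dir="./mobilenetv1_fp32")
weight_quant.quantize_weight_to_int(
    save_model_dir="./mobilenetv1_w8",
    weight_bits=8,
    quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul'],
    weight_quantize_type='channel_wise_abs_max',
    generate_test_model=True)
# -> ./mobilenetv1_w8/quantized_model : int8 weights, for deployment
# -> ./mobilenetv1_w8/test_model : quantize-dequantized FP32 weights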
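The abs_max branch of _weight_abs_max_quantization reduces to the following arithmetic; this is a standalone NumPy sketch of the patch's math, not the Paddle code itself:

import numpy as np

weight = np.random.randn(32, 3, 3, 3).astype(np.float32)
quantize_range = (1 << (8 - 1)) - 1  # 127 for weight_bits=8

# One scale for the whole tensor.
scale = np.max(np.abs(weight)) / quantize_range
quantized = np.around(weight / scale).astype(np.int8)

# The fake quantized (test) model stores this dequantized FP32 tensor,
# so the quantization error shows up directly in inference accuracy.
dequantized = quantized.astype(np.float32) * scale
print("max abs error:", np.max(np.abs(dequantized - weight)))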
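channel_wise_abs_max differs only in scale granularity, and the channel axis depends on the op: _conv_channel_wise_quantization scales along axis 0 (output channels) of a [out_ch, in_ch, kh, kw] conv2d weight, while _mul_channel_wise_quantization scales along the last axis of a [in_dim, out_dim] mul weight. A vectorized NumPy sketch of both:

import numpy as np

quantize_range = (1 << (8 - 1)) - 1

# conv2d / depthwise_conv2d weight: one scale per output channel (axis 0).
conv_w = np.random.randn(32, 3, 3, 3).astype(np.float32)
conv_scales = np.max(np.abs(conv_w.reshape(conv_w.shape[0], -1)),
                     axis=1) / quantize_range
conv_q = np.around(conv_w / conv_scales[:, None, None, None]).astype(np.int8)

# mul weight: one scale per output column (last axis).
mul_w = np.random.randn(1024, 1000).astype(np.float32)
mul_scales = np.max(np.abs(mul_w), axis=0) / quantize_range
mul_q = np.around(mul_w / mul_scales[None, :]).astype(np.int8)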
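_calculate_threshold is only partially visible in this hunk; the sketch below is one plausible reading of the histogram-based clipping that a nonzero threshold_rate enables (explicitly an assumption, not the patch's code): pick a clip value that keeps roughly the smallest (1 - threshold_rate) fraction of |w|, so a few outliers do not inflate the abs_max scale.

import numpy as np

def calculate_threshold_sketch(weights, threshold_rate, histogram_bins=5000):
    # Assumption: approximate the (1 - threshold_rate) quantile of |w|
    # with a fixed-bin histogram, as the visible np.histogram call suggests.
    w_abs = np.abs(weights).flatten()
    hist, edges = np.histogram(w_abs, bins=histogram_bins)
    keep = (1.0 - threshold_rate) * w_abs.size
    bin_idx = int(np.searchsorted(np.cumsum(hist), keep))
    return edges[min(bin_idx + 1, len(edges) - 1)]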
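Finally, the accuracy check that generate_test_model enables: the fake quantized model is an ordinary FP32 inference model, so it loads through the standard fluid API. The path and input shape below are assumptions for MobileNetV1:

import numpy as np
import paddle.fluid as fluid

place = fluid.CPUPlace()
exe = fluid.Executor(place)
# Hypothetical path produced by generate_test_model=True above.
[program, feed_names, fetch_targets] = fluid.io.load_inference_model(
    dirname="./mobilenetv1_w8/test_model", executor=exe)

image = np.random.random([1, 3, 224, 224]).astype(np.float32)
results = exe.run(program,
                  feed={feed_names[0]: image},
                  fetch_list=fetch_targets)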