diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index 09b171fe901c35376f081b5fdc5fbe1bb8687fab..8a41d79433a8dade7bd931b3c68c8c2c40f0250a 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -90,4 +90,5 @@ REGISTER_OP_CPU_KERNEL( ops::SaveOpKernel, ops::SaveOpKernel, ops::SaveOpKernel, + ops::SaveOpKernel, ops::SaveOpKernel); diff --git a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py index d0d69ae91a16b670494f1ec5310a029e52d8894b..ae2298e10a38fe5eb14c8b43cc022383e4bc7ab8 100644 --- a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py +++ b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py @@ -16,10 +16,10 @@ import os import re import logging import numpy as np -from ....executor import global_scope from .... import io from .... import core from .... import framework +from ....executor import global_scope, Executor from ....framework import IrGraph from ....log_helper import get_logger from .quantization_pass import QuantizationTransformPass @@ -27,12 +27,31 @@ from .quantization_pass import QuantizationFreezePass from .quantization_pass import AddQuantDequantPass from .quantization_pass import _op_real_in_out_name -__all__ = ['PostTrainingQuantization'] +__all__ = ['PostTrainingQuantization', 'WeightQuantization'] _logger = get_logger( __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') +def _load_variable_data(scope, var_name): + ''' + Load variable value from scope + ''' + return np.array(scope.find_var(var_name).get_tensor()) + + +def _set_variable_data(scope, place, var_name, np_value): + ''' + Set the value of var node by name, if the node exits, + ''' + assert isinstance(np_value, np.ndarray), \ + 'The type of value should be numpy array.' + var_node = scope.find_var(var_name) + if var_node != None: + tensor = var_node.get_tensor() + tensor.set(np_value, place) + + class PostTrainingQuantization(object): def __init__(self, executor, @@ -297,12 +316,12 @@ class PostTrainingQuantization(object): ''' for var_name in self._quantized_weight_var_name: if var_name not in self._sampling_data: - var_tensor = self._load_var_value(var_name) + var_tensor = _load_variable_data(self._scope, var_name) self._sampling_data[var_name] = var_tensor if self._is_use_cache_file: for var_name in self._quantized_act_var_name: - var_tensor = self._load_var_value(var_name) + var_tensor = _load_variable_data(self._scope, var_name) var_tensor = var_tensor.ravel() save_path = os.path.join(self._cache_dir, var_name + "_" + str(iter) + ".npy") @@ -311,7 +330,7 @@ class PostTrainingQuantization(object): for var_name in self._quantized_act_var_name: if var_name not in self._sampling_data: self._sampling_data[var_name] = [] - var_tensor = self._load_var_value(var_name) + var_tensor = _load_variable_data(self._scope, var_name) var_tensor = var_tensor.ravel() self._sampling_data[var_name].append(var_tensor) @@ -397,11 +416,17 @@ class PostTrainingQuantization(object): # save scale factor to scale var node for key, val in self._quantized_var_scale_factor.items(): - self._set_var_node_value( - key + ".scale", np.array( + _set_variable_data( + self._scope, + self._place, + key + ".scale", + np.array( [val], dtype=np.float32)) - self._set_var_node_value( - key + ".quant_dequant.scale", np.array( + _set_variable_data( + self._scope, + self._place, + key + ".quant_dequant.scale", + np.array( [val], dtype=np.float32)) # apply QuantizationFreezePass, and obtain the final quant model @@ -430,23 +455,6 @@ class PostTrainingQuantization(object): self._quantized_var_scale_factor[ output_var_name]) - def _load_var_value(self, var_name): - ''' - Load variable value from scope - ''' - return np.array(self._scope.find_var(var_name).get_tensor()) - - def _set_var_node_value(self, var_node_name, np_value): - ''' - Set the value of var node by name, if the node exits, - ''' - assert isinstance(np_value, np.ndarray), \ - 'The type of value should be numpy array.' - var_node = self._scope.find_var(var_node_name) - if var_node != None: - tensor = var_node.get_tensor() - tensor.set(np_value, self._place) - def _is_input_all_not_persistable(self, op, persistable_var_names): ''' Analyze the real inputs of the op are all not persistable. @@ -566,3 +574,132 @@ class PostTrainingQuantization(object): tmp_sum1 += p_idx * (math.log(Q_sum * p_idx)) tmp_sum2 += p_idx * (math.log(P_sum * q_idx)) return (tmp_sum1 - tmp_sum2) / P_sum + + +class WeightQuantization(object): + _supported_quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul'] + + def __init__(self, model_dir, model_filename=None, params_filename=None): + ''' + This class quantizes the weight of some ops to reduce the size of model + or improve the perforemace. + + Args: + model_dir(str): The path of the fp32 model that will be quantized, + and the model and params files are under the path. + model_filename(str, optional): The name of file to load the inference + program. If it is None, the default filename '__model__' will + be used. Default is 'None'. + params_filename(str, optional): The name of file to load all parameters. + When all parameters were saved in a single binary file, set it + as the real filename. If parameters were saved in separate files, + set it as 'None'. Default is 'None'. + ''' + self._model_dir = model_dir + self._model_filename = model_filename + self._params_filename = params_filename + + def quantize_weight_to_int(self, + save_model_dir, + save_model_filename=None, + save_params_filename=None, + quantizable_op_type=["conv2d", "mul"], + quantize_weight_bits=8, + threshold_rate=0.0): + ''' + In order to reduce the size of model, this api quantizes the weight + of some ops from float32 to int8/16. In the inference stage, the + quantized weight will be dequantized to float32 again. + + Args: + save_model_dir(str): The path to save the quantized model. + save_model_filename(str, optional): The name of file to + save the inference program. If it is None, the default + filename '__model__' will be used. Default is 'None'. + save_params_filename(str, optional): The name of file to + save all parameters. If it is None, parameters were + saved in separate files. If it is not None, all + parameters were saved in a single binary file. + quantizable_op_type(list[str], optional): The list of ops + that will be quantized, and the quantized ops should be + contained in ["conv2d", "depthwise_conv2d", "mul"]. + Default is ["conv2d","mul"]. + quantize_weight_bits(int, optional): The bits for the quantized + weight, and it should be 8 or 16. Default is 8. + threshold_rate(float, optional): This api uses abs_max methd to + quantize the weight from float32 to int8/16, and the abs max + value is important for quantization diff. When the abs_max + value is far away from the center of the numerical distribution, + we can set threshold_rate between 1e-6 and 1e-8, so the abs max + value will be optimized. Default is 0.0. + ''' + for op_type in quantizable_op_type: + assert op_type in self._supported_quantizable_op_type, \ + "input error:" + op_type + \ + " is not supported for weight quantization." + assert quantize_weight_bits in [8, 16], \ + "input error: quantize_weight_bits should be 8 or 16." + quantize_range = (1 << (quantize_weight_bits - 1)) - 1 + save_weight_dtype = np.int8 if quantize_weight_bits == 8 else np.int16 + + place = core.CPUPlace() + exe = Executor(place) + scope = global_scope() + [program, feed_list, fetch_list] = \ + io.load_inference_model(dirname=self._model_dir, + executor=exe, + model_filename=self._model_filename, + params_filename=self._params_filename) + + persistable_var_names = [] + for var in program.list_vars(): + if var.persistable: + persistable_var_names.append(var.name) + for op in program.global_block().ops: + if op.type in quantizable_op_type: + for var_name in op.input_arg_names: + if var_name in persistable_var_names: + var_tensor_data = _load_variable_data(scope, var_name) + if abs(threshold_rate) < 1e-10: + threshold_value = np.max(np.abs(var_tensor_data)) + else: + threshold_value = self._calculate_threshold(\ + var_tensor_data, threshold_rate) + var_tensor_data[var_tensor_data > + threshold_value] = threshold_value + var_tensor_data[var_tensor_data < + -threshold_value] = -threshold_value + scale = threshold_value / quantize_range + quantized_var_tensor_data = \ + np.around(var_tensor_data / scale) + quantized_var_tensor_data = \ + quantized_var_tensor_data.astype(save_weight_dtype) + _set_variable_data(scope, place, var_name, + quantized_var_tensor_data) + op._set_attr(var_name + "_quant_scale", [scale]) + op._set_attr('quantize_weight_bits', + quantize_weight_bits) + + io.save_inference_model( + dirname=save_model_dir, + feeded_var_names=feed_list, + target_vars=fetch_list, + executor=exe, + main_program=program, + model_filename=save_model_filename, + params_filename=save_params_filename) + + def _calculate_threshold(self, input, threshold_rate, histogram_bins=5000): + input_abs = np.abs(input) + hist, hist_edeges = np.histogram( + input_abs, bins=histogram_bins, range=(0, np.max(input_abs))) + hist = hist / float(sum(hist)) + hist_sum = 0 + hist_index = 0 + for i in range(len(hist)): + hist_sum += hist[i] + if hist_sum >= 1.0 - threshold_rate: + hist_index = i + 1 + break + bin_width = hist_edeges[1] - hist_edeges[0] + return hist_index * bin_width diff --git a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt index 65fd4984d95e9fdf98f86a23b06d52ef477ecd8b..8b80cac9018b71b3bffb5184e308f641dca08f18 100644 --- a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt +++ b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt @@ -63,6 +63,7 @@ if(WIN32) list(REMOVE_ITEM TEST_OPS test_light_nas) list(REMOVE_ITEM TEST_OPS test_post_training_quantization_mobilenetv1) list(REMOVE_ITEM TEST_OPS test_post_training_quantization_resnet50) + list(REMOVE_ITEM TEST_OPS test_weight_quantization_mobilenetv1) endif() # int8 image classification python api test diff --git a/python/paddle/fluid/contrib/slim/tests/test_weight_quantization_mobilenetv1.py b/python/paddle/fluid/contrib/slim/tests/test_weight_quantization_mobilenetv1.py new file mode 100644 index 0000000000000000000000000000000000000000..c6380adf6b63cffbcbcc7d5e75a86926e6bcde8b --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/test_weight_quantization_mobilenetv1.py @@ -0,0 +1,91 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import os +import time +from paddle.dataset.common import download, DATA_HOME +from paddle.fluid.contrib.slim.quantization import WeightQuantization + + +class TestWeightQuantization(unittest.TestCase): + def setUp(self): + self.weight_quantization_dir = 'weight_quantization' + self.cache_folder = os.path.join(DATA_HOME, + self.weight_quantization_dir) + + def download_model(self, model_name, data_url, data_md5): + download(data_url, self.weight_quantization_dir, data_md5) + file_name = data_url.split('/')[-1] + file_path = os.path.join(self.cache_folder, file_name) + print(model_name + ' is downloaded at ' + file_path) + + unziped_path = os.path.join(self.cache_folder, model_name) + self.cache_unzipping(unziped_path, file_path) + print(model_name + ' is unziped at ' + unziped_path) + return unziped_path + + def cache_unzipping(self, target_folder, zip_path): + if not os.path.exists(target_folder): + cmd = 'mkdir {0} && tar xf {1} -C {0}'.format(target_folder, + zip_path) + os.system(cmd) + + def run_test(self, model_name, model_data_url, model_data_md5, + quantize_weight_bits, quantizable_op_type, threshold_rate): + + model_dir = self.download_model(model_name, model_data_url, + model_data_md5) + + timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime()) + save_model_dir = os.path.join( + os.getcwd(), + model_name + "_wq_" + str(quantize_weight_bits) + "_" + timestamp) + weight_quant = WeightQuantization(model_dir=model_dir + "/model") + weight_quant.quantize_weight_to_int( + save_model_dir=save_model_dir, + quantize_weight_bits=quantize_weight_bits, + quantizable_op_type=quantizable_op_type, + threshold_rate=threshold_rate) + print("finish weight quantization for " + model_name + "\n") + + try: + os.system("rm -rf {}".format(save_model_dir)) + except Exception as e: + print("Failed to delete {} due to {}".format(save_model_dir, str( + e))) + + +class TestWeightQuantizationMobilenetv1(TestWeightQuantization): + model_name = "mobilenetv1" + model_data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz" + model_data_md5 = "13892b0716d26443a8cdea15b3c6438b" + + def test_weight_quantization_mobilenetv1_8bit(self): + quantize_weight_bits = 8 + quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul'] + threshold_rate = 0.0 + self.run_test(self.model_name, self.model_data_url, self.model_data_md5, + quantize_weight_bits, quantizable_op_type, threshold_rate) + + def test_weight_quantization_mobilenetv1_16bit(self): + quantize_weight_bits = 16 + quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul'] + threshold_rate = 1e-9 + self.run_test(self.model_name, self.model_data_url, self.model_data_md5, + quantize_weight_bits, quantizable_op_type, threshold_rate) + + +if __name__ == '__main__': + unittest.main()