diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc
index f619f3d59cece50ea39666170fe57479334457b5..194274cdd5bb4d59188e171866f685b127cb1369 100644
--- a/paddle/fluid/operators/save_op.cc
+++ b/paddle/fluid/operators/save_op.cc
@@ -88,6 +88,8 @@ REGISTER_OPERATOR(save, ops::SaveOp, ops::SaveOpProtoMaker,
 REGISTER_OP_CPU_KERNEL(
     save, ops::SaveOpKernel<paddle::platform::CPUDeviceContext, float>,
     ops::SaveOpKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::SaveOpKernel<paddle::platform::CPUDeviceContext,
+                      paddle::platform::float16>,
     ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int>,
     ops::SaveOpKernel<paddle::platform::CPUDeviceContext, uint8_t>,
     ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int8_t>,
diff --git a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
index b59534b5965adf96233d461f24854a96d643a27f..aba6005f0cfdf07b5f7374fbdeed2ed602620274 100644
--- a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
+++ b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
@@ -16,9 +16,11 @@ import os
 import re
 import logging
 import numpy as np
+import shutil
 from .... import io
 from .... import core
 from .... import framework
+from .... import unique_name
 from ....executor import global_scope, Executor
 from ....framework import IrGraph
 from ....log_helper import get_logger
@@ -1006,6 +1008,82 @@ class WeightQuantization(object):
             quantizable_op_type, weight_bits, weight_quantize_type, True,
             threshold_rate)
 
+    def convert_weight_to_fp16(self, save_model_dir):
+        """
+        Convert all persistable vars from fp32 to fp16.
+        Note that this API only changes the data type of variables in the
+        __params__ file; the __model__ file remains unchanged.
+
+        Args:
+            save_model_dir(str): The path to save the fp16 model.
+        """
+
+        # Load model
+        place = core.CPUPlace()
+        exe = Executor(place)
+        scope = global_scope()
+        [infer_program, feed_list, fetch_list] = \
+            io.load_inference_model(dirname=self._model_dir,
+                                    executor=exe,
+                                    model_filename=self._model_filename,
+                                    params_filename=self._params_filename)
+
+        # Clone and save fp16 weights
+        save_program = framework.Program()
+        save_block = save_program.global_block()
+        save_var_map = {}
+
+        for var in infer_program.list_vars():
+            if (var.type == core.VarDesc.VarType.RAW) or \
+                (not var.persistable) or (var.name in ['feed', 'fetch']) \
+                or (var.dtype != core.VarDesc.VarType.FP32):
+                continue
+
+            #new_var = _clone_var_to_block_(var, save_block)
+            new_var = save_block._clone_variable(var)
+            if self._params_filename is not None:
+                save_var_map[new_var.name] = new_var
+            else:
+                save_file_path = os.path.join(
+                    os.path.normpath(save_model_dir), new_var.name)
+                save_block.append_op(
+                    type='save',
+                    inputs={'X': [new_var]},
+                    outputs={},
+                    attrs={
+                        'file_path': os.path.normpath(save_file_path),
+                        'save_as_fp16': True
+                    })
+
+        if self._params_filename is not None:
+            save_var_list = []
+            for name in sorted(save_var_map.keys()):
+                save_var_list.append(save_var_map[name])
+
+            saved_params_var = save_block.create_var(
+                type=core.VarDesc.VarType.RAW,
+                name=unique_name.generate("saved_params"))
+            saved_params_var.desc.set_persistable(True)
+
+            save_path = os.path.join(
+                os.path.normpath(save_model_dir), self._params_filename)
+            save_block.append_op(
+                type='save_combine',
+                inputs={'X': save_var_list},
+                outputs={'Y': saved_params_var},
+                attrs={'file_path': save_path,
+                       'save_as_fp16': True})
+
+        save_program._sync_with_cpp()
+        exe.run(save_program)
+
+        # Copy model
+        model_filename = "__model__" if self._model_filename is None \
+            else self._model_filename
+        src_model = os.path.join(self._model_dir, model_filename)
+        dest_model = os.path.join(save_model_dir, model_filename)
+        shutil.copyfile(src_model, dest_model)
+
     def _quantize_weight_to_int(self, save_model_dir, save_model_filename,
                                 save_params_filename, quantizable_op_type,
                                 weight_bits, weight_quantize_type, for_test,
diff --git a/python/paddle/fluid/contrib/slim/tests/test_weight_quantization_mobilenetv1.py b/python/paddle/fluid/contrib/slim/tests/test_weight_quantization_mobilenetv1.py
index 1e8fa51d635e32d5d0169cf23ca0681051028ae9..744c97c514b3613d9642c3767ea4603adece68fd 100644
--- a/python/paddle/fluid/contrib/slim/tests/test_weight_quantization_mobilenetv1.py
+++ b/python/paddle/fluid/contrib/slim/tests/test_weight_quantization_mobilenetv1.py
@@ -15,6 +15,7 @@
 import unittest
 import os
 import time
+import numpy as np
 from paddle.dataset.common import download, DATA_HOME
 from paddle.fluid.contrib.slim.quantization import WeightQuantization
 import paddle
@@ -22,6 +23,28 @@ import paddle
 paddle.enable_static()
 
 
+def _load_variable_data(scope, var_name):
+    '''
+    Load variable value from scope.
+    '''
+    var_node = scope.find_var(var_name)
+    assert var_node is not None, \
+        "Cannot find " + var_name + " in scope."
+    return np.array(var_node.get_tensor())
+
+
+def _set_variable_data(scope, place, var_name, np_value):
+    '''
+    Set the value of the var node by name, if the node exists.
+    '''
+    assert isinstance(np_value, np.ndarray), \
+        'The type of value should be numpy array.'
+    var_node = scope.find_var(var_name)
+    if var_node is not None:
+        tensor = var_node.get_tensor()
+        tensor.set(np_value, place)
+
+
 class TestWeightQuantization(unittest.TestCase):
     def setUp(self):
         self.weight_quantization_dir = 'weight_quantization'
@@ -45,18 +68,20 @@ class TestWeightQuantization(unittest.TestCase):
                 zip_path)
             os.system(cmd)
 
-    def run_test(self, model_name, model_data_url, model_data_md5, weight_bits,
-                 quantizable_op_type, weight_quantize_type, generate_test_model,
-                 threshold_rate):
+    def quantize_to_int(self, model_name, model_data_url, model_data_md5,
+                        weight_bits, quantizable_op_type, weight_quantize_type,
+                        generate_test_model, threshold_rate):
         model_dir = self.download_model(model_name, model_data_url,
                                         model_data_md5)
+        load_model_dir = os.path.join(model_dir, model_name)
 
         timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
         save_model_dir = os.path.join(
             os.getcwd(),
             model_name + "_wq_" + str(weight_bits) + "_" + timestamp)
-        weight_quant = WeightQuantization(model_dir=model_dir + "/model")
+
+        weight_quant = WeightQuantization(model_dir=load_model_dir)
         weight_quant.quantize_weight_to_int(
             save_model_dir=save_model_dir,
             weight_bits=weight_bits,
@@ -72,11 +97,79 @@ class TestWeightQuantization(unittest.TestCase):
             print("Failed to delete {} due to {}".format(save_model_dir, str(
                 e)))
 
+    def convert_to_fp16(self, model_name, model_data_url, model_data_md5,
+                        model_filename, params_filename):
+        model_dir = self.download_model(model_name, model_data_url,
+                                        model_data_md5)
+        load_model_dir = os.path.join(model_dir, model_name)
+
+        timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
+        save_model_dir = os.path.join(os.getcwd(),
+                                      model_name + "_wq_fp16_" + timestamp)
+
+        weight_quant = WeightQuantization(load_model_dir, model_filename,
+                                          params_filename)
+
+        weight_quant.convert_weight_to_fp16(save_model_dir)
+
+        print("finish converting the data type of weights to fp16 for " +
+              model_name)
+        print("fp16 model saved in " + save_model_dir + "\n")
+
+        input_data = np.ones([1, 3, 224, 224], dtype=np.float32)
+        res_fp32 = self.run_models(load_model_dir, model_filename,
+                                   params_filename, input_data, False)
+        res_fp16 = self.run_models(save_model_dir, model_filename,
+                                   params_filename, input_data, True)
+
+        self.assertTrue(
+            np.allclose(
+                res_fp32, res_fp16, rtol=1e-5, atol=1e-08, equal_nan=True),
+            msg='Failed to test the accuracy of the fp32 and fp16 model.')
+
+        try:
+            os.system("rm -rf {}".format(save_model_dir))
+        except Exception as e:
+            print("Failed to delete {} due to {}".format(save_model_dir, str(
+                e)))
+
+    def run_models(self, model_dir, model_filename, params_filename,
+                   input_data, is_fp16_model):
+        print(model_dir)
+
+        place = paddle.CPUPlace()
+        exe = paddle.static.Executor(place)
+        scope = paddle.static.Scope()
+        with paddle.static.scope_guard(scope):
+            [inference_program, feed_target_names, fetch_targets] = \
+                paddle.fluid.io.load_inference_model(model_dir, exe,
+                    model_filename=model_filename,
+                    params_filename=params_filename)
+
+            if is_fp16_model:
+                for var in inference_program.list_vars():
+                    if (var.type == paddle.fluid.core.VarDesc.VarType.RAW) or \
+                        (not var.persistable) or (var.name in ['feed', 'fetch']) \
+                        or (var.dtype != paddle.fluid.core.VarDesc.VarType.FP16):
+                        continue
+                    tensor = _load_variable_data(scope, var.name)
+                    _set_variable_data(scope, place, var.name,
+                                       tensor.astype(np.float32))
+
+            results = exe.run(inference_program,
+                              feed={feed_target_names[0]: input_data},
+                              fetch_list=fetch_targets)
+        return np.array(results[0])
+
 
 class TestWeightQuantizationMobilenetv1(TestWeightQuantization):
-    model_name = "mobilenetv1"
-    model_data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz"
-    model_data_md5 = "13892b0716d26443a8cdea15b3c6438b"
+    nocomb_model_name = "mobilenetv1_fp32_nocombined"
+    nocomb_model_data_url = "https://paddle-inference-dist.cdn.bcebos.com/Paddle-Inference-Demo/mobilenetv1_fp32_nocombined.tar.gz"
+    nocomb_model_data_md5 = "c9aae3b04d9d535c84590ae557be0a0b"
+
+    comb_model_name = "mobilenetv1_fp32_combined"
+    comb_model_data_url = "https://paddle-inference-dist.cdn.bcebos.com/Paddle-Inference-Demo/mobilenetv1_fp32_combined.tar.gz"
+    comb_model_data_md5 = "087c67e2b2b0a8b689fcc570a56c005f"
 
     def test_weight_quantization_mobilenetv1_8bit_abs_max(self):
         weight_bits = 8
@@ -84,9 +177,10 @@ class TestWeightQuantizationMobilenetv1(TestWeightQuantization):
         weight_quantize_type = "abs_max"
         generate_test_model = True
         threshold_rate = 0.0
-        self.run_test(self.model_name, self.model_data_url, self.model_data_md5,
-                      weight_bits, quantizable_op_type, weight_quantize_type,
-                      generate_test_model, threshold_rate)
+        self.quantize_to_int(self.nocomb_model_name, self.nocomb_model_data_url,
+                             self.nocomb_model_data_md5, weight_bits,
+                             quantizable_op_type, weight_quantize_type,
+                             generate_test_model, threshold_rate)
 
     def test_weight_quantization_mobilenetv1_8bit_channel_wise_abs_max(self):
         weight_bits = 8
@@ -94,19 +188,21 @@ class TestWeightQuantizationMobilenetv1(TestWeightQuantization):
         weight_quantize_type = "channel_wise_abs_max"
         generate_test_model = True
         threshold_rate = 0.0
-        self.run_test(self.model_name, self.model_data_url, self.model_data_md5,
-                      weight_bits, quantizable_op_type, weight_quantize_type,
-                      generate_test_model, threshold_rate)
+        self.quantize_to_int(self.nocomb_model_name, self.nocomb_model_data_url,
+                             self.nocomb_model_data_md5, weight_bits,
+                             quantizable_op_type, weight_quantize_type,
+                             generate_test_model, threshold_rate)
 
     def test_weight_quantization_mobilenetv1_16bit_abs_max(self):
         weight_bits = 16
         quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
         weight_quantize_type = "abs_max"
         generate_test_model = False
-        threshold_rate = 1e-9
-        self.run_test(self.model_name, self.model_data_url, self.model_data_md5,
-                      weight_bits, quantizable_op_type, weight_quantize_type,
-                      generate_test_model, threshold_rate)
+        threshold_rate = 0
+        self.quantize_to_int(self.nocomb_model_name, self.nocomb_model_data_url,
+                             self.nocomb_model_data_md5, weight_bits,
+                             quantizable_op_type, weight_quantize_type,
+                             generate_test_model, threshold_rate)
 
     def test_weight_quantization_mobilenetv1_16bit_channel_wise_abs_max(self):
         weight_bits = 16
@@ -114,9 +210,24 @@ class TestWeightQuantizationMobilenetv1(TestWeightQuantization):
         weight_quantize_type = "channel_wise_abs_max"
         generate_test_model = False
         threshold_rate = 1e-9
-        self.run_test(self.model_name, self.model_data_url, self.model_data_md5,
-                      weight_bits, quantizable_op_type, weight_quantize_type,
-                      generate_test_model, threshold_rate)
+        self.quantize_to_int(self.nocomb_model_name, self.nocomb_model_data_url,
+                             self.nocomb_model_data_md5, weight_bits,
+                             quantizable_op_type, weight_quantize_type,
+                             generate_test_model, threshold_rate)
+
+    def test_mobilenetv1_fp16_combined(self):
+        model_filename = '__model__'
+        params_filename = '__params__'
+        self.convert_to_fp16(self.comb_model_name, self.comb_model_data_url,
+                             self.comb_model_data_md5, model_filename,
+                             params_filename)
+
+    def test_mobilenetv1_fp16_nocombined(self):
+        model_filename = None
+        params_filename = None
+        self.convert_to_fp16(self.nocomb_model_name, self.nocomb_model_data_url,
+                             self.nocomb_model_data_md5, model_filename,
+                             params_filename)
 
 
 if __name__ == '__main__':
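
Usage sketch (not part of the patch): a minimal example of how the new WeightQuantization.convert_weight_to_fp16 API introduced above is meant to be called, mirroring the added test; the model directories below are illustrative placeholders.

import paddle
from paddle.fluid.contrib.slim.quantization import WeightQuantization

paddle.enable_static()

# Placeholder paths: point fp32_model_dir at a saved fp32 inference model.
fp32_model_dir = "./mobilenetv1_fp32_combined"
fp16_model_dir = "./mobilenetv1_fp16_combined"

# For a combined model, pass the model/params file names;
# for a non-combined model, leave them as None.
weight_quant = WeightQuantization(fp32_model_dir, '__model__', '__params__')

# Re-saves the persistable fp32 weights as fp16 under fp16_model_dir and
# copies the __model__ program file over unchanged.
weight_quant.convert_weight_to_fp16(fp16_model_dir)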