Unverified commit 1bae1e74, authored by cc, committed by GitHub

Support converting the model from fp32 to fp16 (#32112)

* Support converting the model from fp32 to fp16
Parent e45c3fa5
@@ -88,6 +88,8 @@ REGISTER_OPERATOR(save, ops::SaveOp, ops::SaveOpProtoMaker,
 REGISTER_OP_CPU_KERNEL(
     save, ops::SaveOpKernel<paddle::platform::CPUDeviceContext, float>,
     ops::SaveOpKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::SaveOpKernel<paddle::platform::CPUDeviceContext,
+                      paddle::platform::float16>,
     ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int>,
     ops::SaveOpKernel<paddle::platform::CPUDeviceContext, uint8_t>,
     ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int8_t>,
...
@@ -16,9 +16,11 @@ import os
 import re
 import logging
 import numpy as np
+import shutil
 from .... import io
 from .... import core
 from .... import framework
+from .... import unique_name
 from ....executor import global_scope, Executor
 from ....framework import IrGraph
 from ....log_helper import get_logger
@@ -1006,6 +1008,82 @@ class WeightQuantization(object):
             quantizable_op_type, weight_bits, weight_quantize_type, True,
             threshold_rate)
 
+    def convert_weight_to_fp16(self, save_model_dir):
+        """
+        Convert all persistable vars from fp32 to fp16.
+        Note that this API only changes the data type of the variables in
+        the __params__ file; the __model__ file remains unchanged.
+
+        Args:
+            save_model_dir(str): The path to save the fp16 model.
+        """
+        # Load model
+        place = core.CPUPlace()
+        exe = Executor(place)
+        scope = global_scope()
+        [infer_program, feed_list, fetch_list] = \
+            io.load_inference_model(dirname=self._model_dir,
+                                    executor=exe,
+                                    model_filename=self._model_filename,
+                                    params_filename=self._params_filename)
+
+        # Clone and save fp16 weights
+        save_program = framework.Program()
+        save_block = save_program.global_block()
+        save_var_map = {}
+        for var in infer_program.list_vars():
+            if (var.type == core.VarDesc.VarType.RAW) or \
+                (not var.persistable) or (var.name in ['feed', 'fetch']) \
+                or (var.dtype != core.VarDesc.VarType.FP32):
+                continue
+
+            #new_var = _clone_var_to_block_(var, save_block)
+            new_var = save_block._clone_variable(var)
+            if self._params_filename is not None:
+                save_var_map[new_var.name] = new_var
+            else:
+                save_file_path = os.path.join(
+                    os.path.normpath(save_model_dir), new_var.name)
+                save_block.append_op(
+                    type='save',
+                    inputs={'X': [new_var]},
+                    outputs={},
+                    attrs={
+                        'file_path': os.path.normpath(save_file_path),
+                        'save_as_fp16': True
+                    })
+
+        if self._params_filename is not None:
+            save_var_list = []
+            for name in sorted(save_var_map.keys()):
+                save_var_list.append(save_var_map[name])
+
+            saved_params_var = save_block.create_var(
+                type=core.VarDesc.VarType.RAW,
+                name=unique_name.generate("saved_params"))
+            saved_params_var.desc.set_persistable(True)
+
+            save_path = os.path.join(
+                os.path.normpath(save_model_dir), self._params_filename)
+            save_block.append_op(
+                type='save_combine',
+                inputs={'X': save_var_list},
+                outputs={'Y': saved_params_var},
+                attrs={'file_path': save_path,
+                       'save_as_fp16': True})
+
+        save_program._sync_with_cpp()
+        exe.run(save_program)
+
+        # Copy model
+        model_filename = "__model__" if self._model_filename is None \
+            else self._model_filename
+        src_model = os.path.join(self._model_dir, model_filename)
+        dest_model = os.path.join(save_model_dir, model_filename)
+        shutil.copyfile(src_model, dest_model)
+
     def _quantize_weight_to_int(self, save_model_dir, save_model_filename,
                                 save_params_filename, quantizable_op_type,
                                 weight_bits, weight_quantize_type, for_test,
...
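A minimal usage sketch of the new API (the directory and file names below are hypothetical; any fp32 inference model saved as a combined __model__/__params__ pair would work the same way):

    from paddle.fluid.contrib.slim.quantization import WeightQuantization

    # Rewrites only the parameter file as fp16; the __model__ program file is
    # copied to the target directory unchanged, as the docstring above notes.
    weight_quant = WeightQuantization(
        model_dir='mobilenetv1_fp32_combined',  # hypothetical fp32 model dir
        model_filename='__model__',
        params_filename='__params__')
    weight_quant.convert_weight_to_fp16('mobilenetv1_fp16_combined')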
@@ -15,6 +15,7 @@
 import unittest
 import os
 import time
+import numpy as np
 from paddle.dataset.common import download, DATA_HOME
 from paddle.fluid.contrib.slim.quantization import WeightQuantization
 import paddle
@@ -22,6 +23,28 @@ import paddle
 
 paddle.enable_static()
 
+
+def _load_variable_data(scope, var_name):
+    '''
+    Load variable value from scope.
+    '''
+    var_node = scope.find_var(var_name)
+    assert var_node is not None, \
+        "Cannot find " + var_name + " in scope."
+    return np.array(var_node.get_tensor())
+
+
+def _set_variable_data(scope, place, var_name, np_value):
+    '''
+    Set the value of the variable with the given name, if it exists in the scope.
+    '''
+    assert isinstance(np_value, np.ndarray), \
+        'The type of value should be numpy array.'
+    var_node = scope.find_var(var_name)
+    if var_node is not None:
+        tensor = var_node.get_tensor()
+        tensor.set(np_value, place)
+
+
 class TestWeightQuantization(unittest.TestCase):
     def setUp(self):
         self.weight_quantization_dir = 'weight_quantization'
@@ -45,18 +68,20 @@ class TestWeightQuantization(unittest.TestCase):
                 zip_path)
             os.system(cmd)
 
-    def run_test(self, model_name, model_data_url, model_data_md5, weight_bits,
-                 quantizable_op_type, weight_quantize_type, generate_test_model,
-                 threshold_rate):
+    def quantize_to_int(self, model_name, model_data_url, model_data_md5,
+                        weight_bits, quantizable_op_type, weight_quantize_type,
+                        generate_test_model, threshold_rate):
         model_dir = self.download_model(model_name, model_data_url,
                                         model_data_md5)
+        load_model_dir = os.path.join(model_dir, model_name)
+
         timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
         save_model_dir = os.path.join(
             os.getcwd(),
             model_name + "_wq_" + str(weight_bits) + "_" + timestamp)
-        weight_quant = WeightQuantization(model_dir=model_dir + "/model")
+
+        weight_quant = WeightQuantization(model_dir=load_model_dir)
         weight_quant.quantize_weight_to_int(
             save_model_dir=save_model_dir,
             weight_bits=weight_bits,
@@ -72,11 +97,79 @@ class TestWeightQuantization(unittest.TestCase):
             print("Failed to delete {} due to {}".format(save_model_dir, str(
                 e)))
 
+    def convert_to_fp16(self, model_name, model_data_url, model_data_md5,
+                        model_filename, params_filename):
+        model_dir = self.download_model(model_name, model_data_url,
+                                        model_data_md5)
+        load_model_dir = os.path.join(model_dir, model_name)
+
+        timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
+        save_model_dir = os.path.join(os.getcwd(),
+                                      model_name + "_wq_fp16_" + timestamp)
+
+        weight_quant = WeightQuantization(load_model_dir, model_filename,
+                                          params_filename)
+        weight_quant.convert_weight_to_fp16(save_model_dir)
+        print("finish converting the data type of weights to fp16 for " +
+              model_name)
+        print("fp16 model saved in " + save_model_dir + "\n")
+
+        input_data = np.ones([1, 3, 224, 224], dtype=np.float32)
+        res_fp32 = self.run_models(load_model_dir, model_filename,
+                                   params_filename, input_data, False)
+        res_fp16 = self.run_models(save_model_dir, model_filename,
+                                   params_filename, input_data, True)
+
+        self.assertTrue(
+            np.allclose(
+                res_fp32, res_fp16, rtol=1e-5, atol=1e-08, equal_nan=True),
+            msg='Failed to test the accuracy of the fp32 and fp16 model.')
+
+        try:
+            os.system("rm -rf {}".format(save_model_dir))
+        except Exception as e:
+            print("Failed to delete {} due to {}".format(save_model_dir, str(
+                e)))
+
+    def run_models(self, model_dir, model_filename, params_filename, input_data,
+                   is_fp16_model):
+        print(model_dir)
+        place = paddle.CPUPlace()
+        exe = paddle.static.Executor(place)
+        scope = paddle.static.Scope()
+        with paddle.static.scope_guard(scope):
+            [inference_program, feed_target_names, fetch_targets] = \
+                paddle.fluid.io.load_inference_model(model_dir, exe,
+                    model_filename=model_filename,
+                    params_filename=params_filename)
+
+            if is_fp16_model:
+                for var in inference_program.list_vars():
+                    if (var.type == paddle.fluid.core.VarDesc.VarType.RAW) or \
+                        (not var.persistable) or (var.name in ['feed', 'fetch']) \
+                        or (var.dtype != paddle.fluid.core.VarDesc.VarType.FP16):
+                        continue
+                    tensor = _load_variable_data(scope, var.name)
+                    _set_variable_data(scope, place, var.name,
+                                       tensor.astype(np.float32))
+
+            results = exe.run(inference_program,
+                              feed={feed_target_names[0]: input_data},
+                              fetch_list=fetch_targets)
+
+        return np.array(results[0])
+
 
 class TestWeightQuantizationMobilenetv1(TestWeightQuantization):
-    model_name = "mobilenetv1"
-    model_data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz"
-    model_data_md5 = "13892b0716d26443a8cdea15b3c6438b"
+    nocomb_model_name = "mobilenetv1_fp32_nocombined"
+    nocomb_model_data_url = "https://paddle-inference-dist.cdn.bcebos.com/Paddle-Inference-Demo/mobilenetv1_fp32_nocombined.tar.gz"
+    nocomb_model_data_md5 = "c9aae3b04d9d535c84590ae557be0a0b"
+
+    comb_model_name = "mobilenetv1_fp32_combined"
+    comb_model_data_url = "https://paddle-inference-dist.cdn.bcebos.com/Paddle-Inference-Demo/mobilenetv1_fp32_combined.tar.gz"
+    comb_model_data_md5 = "087c67e2b2b0a8b689fcc570a56c005f"
 
     def test_weight_quantization_mobilenetv1_8bit_abs_max(self):
         weight_bits = 8
@@ -84,9 +177,10 @@ class TestWeightQuantizationMobilenetv1(TestWeightQuantization):
         weight_quantize_type = "abs_max"
         generate_test_model = True
         threshold_rate = 0.0
-        self.run_test(self.model_name, self.model_data_url, self.model_data_md5,
-                      weight_bits, quantizable_op_type, weight_quantize_type,
-                      generate_test_model, threshold_rate)
+        self.quantize_to_int(self.nocomb_model_name, self.nocomb_model_data_url,
+                             self.nocomb_model_data_md5, weight_bits,
+                             quantizable_op_type, weight_quantize_type,
+                             generate_test_model, threshold_rate)
 
     def test_weight_quantization_mobilenetv1_8bit_channel_wise_abs_max(self):
         weight_bits = 8
@@ -94,19 +188,21 @@ class TestWeightQuantizationMobilenetv1(TestWeightQuantization):
         weight_quantize_type = "channel_wise_abs_max"
         generate_test_model = True
         threshold_rate = 0.0
-        self.run_test(self.model_name, self.model_data_url, self.model_data_md5,
-                      weight_bits, quantizable_op_type, weight_quantize_type,
-                      generate_test_model, threshold_rate)
+        self.quantize_to_int(self.nocomb_model_name, self.nocomb_model_data_url,
+                             self.nocomb_model_data_md5, weight_bits,
+                             quantizable_op_type, weight_quantize_type,
+                             generate_test_model, threshold_rate)
 
     def test_weight_quantization_mobilenetv1_16bit_abs_max(self):
         weight_bits = 16
         quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
         weight_quantize_type = "abs_max"
         generate_test_model = False
-        threshold_rate = 1e-9
-        self.run_test(self.model_name, self.model_data_url, self.model_data_md5,
-                      weight_bits, quantizable_op_type, weight_quantize_type,
-                      generate_test_model, threshold_rate)
+        threshold_rate = 0
+        self.quantize_to_int(self.nocomb_model_name, self.nocomb_model_data_url,
+                             self.nocomb_model_data_md5, weight_bits,
+                             quantizable_op_type, weight_quantize_type,
+                             generate_test_model, threshold_rate)
 
     def test_weight_quantization_mobilenetv1_16bit_channel_wise_abs_max(self):
         weight_bits = 16
@@ -114,9 +210,24 @@ class TestWeightQuantizationMobilenetv1(TestWeightQuantization):
         weight_quantize_type = "channel_wise_abs_max"
         generate_test_model = False
         threshold_rate = 1e-9
-        self.run_test(self.model_name, self.model_data_url, self.model_data_md5,
-                      weight_bits, quantizable_op_type, weight_quantize_type,
-                      generate_test_model, threshold_rate)
+        self.quantize_to_int(self.nocomb_model_name, self.nocomb_model_data_url,
+                             self.nocomb_model_data_md5, weight_bits,
+                             quantizable_op_type, weight_quantize_type,
+                             generate_test_model, threshold_rate)
+
+    def test_mobilenetv1_fp16_combined(self):
+        model_filename = '__model__'
+        params_filename = '__params__'
+        self.convert_to_fp16(self.comb_model_name, self.comb_model_data_url,
+                             self.comb_model_data_md5, model_filename,
+                             params_filename)
+
+    def test_mobilenetv1_fp16_nocombined(self):
+        model_filename = None
+        params_filename = None
+        self.convert_to_fp16(self.nocomb_model_name, self.nocomb_model_data_url,
+                             self.nocomb_model_data_md5, model_filename,
+                             params_filename)
 
 if __name__ == '__main__':
...
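The run_models helper above also shows the pattern needed to actually execute an fp16-weight model on CPU: the persistable fp16 tensors are cast back to fp32 in the scope before inference, since the compute kernels used here expect fp32 inputs. A minimal standalone sketch of that step, assuming the program has already been loaded into a scope with paddle.fluid.io.load_inference_model (the helper name is hypothetical):

    import numpy as np
    import paddle

    def _upcast_fp16_weights(program, scope, place):
        # Cast every persistable fp16 tensor in the scope back to fp32 so the
        # fp32 CPU kernels can consume it (mirrors run_models above).
        for var in program.list_vars():
            if var.type == paddle.fluid.core.VarDesc.VarType.RAW:
                continue
            if not var.persistable or var.name in ['feed', 'fetch']:
                continue
            if var.dtype != paddle.fluid.core.VarDesc.VarType.FP16:
                continue
            tensor = scope.find_var(var.name).get_tensor()
            tensor.set(np.array(tensor).astype(np.float32), place)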