Unverified commit cbaebb04, authored by W Wilber, committed by GitHub

convert to mixed model python api (#43881)

Parent 74c9b57b
@@ -404,6 +404,16 @@ void BindInferenceApi(py::module *m) {
  m->def("get_trt_compile_version", &paddle_infer::GetTrtCompileVersion);
  m->def("get_trt_runtime_version", &paddle_infer::GetTrtRuntimeVersion);
  m->def("get_num_bytes_of_data_type", &paddle_infer::GetNumBytesOfDataType);
m->def("convert_to_mixed_precision_bind",
&paddle_infer::ConvertToMixedPrecision,
py::arg("model_file"),
py::arg("params_file"),
py::arg("mixed_model_file"),
py::arg("mixed_params_file"),
py::arg("mixed_precision"),
py::arg("backend"),
py::arg("keep_io_types") = true,
py::arg("black_list") = std::unordered_set<std::string>());
}
namespace {
@@ -586,6 +596,14 @@ void BindAnalysisConfig(py::module *m) {
      .value("Float32", AnalysisConfig::Precision::kFloat32)
      .value("Int8", AnalysisConfig::Precision::kInt8)
      .value("Half", AnalysisConfig::Precision::kHalf)
.value("Bfloat16", AnalysisConfig::Precision::kBf16)
.export_values();
py::enum_<AnalysisConfig::Backend>(analysis_config, "Backend")
.value("CPU", AnalysisConfig::Backend::kCPU)
.value("GPU", AnalysisConfig::Backend::kGPU)
.value("NPU", AnalysisConfig::Backend::kNPU)
.value("XPU", AnalysisConfig::Backend::kXPU)
      .export_values();
  analysis_config.def(py::init<>())
...
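Once these pybind11 bindings are compiled into the extension module, the new Backend enum, the Bfloat16 precision value, and the raw converter entry point all become visible from Python. A minimal sketch (not part of the diff, and it assumes a Paddle build that includes this patch):

    # A minimal sketch, assuming a Paddle build with this patch applied.
    # The bindings above are registered into paddle.fluid.core.
    from paddle.fluid import core

    # The new Backend enum and the Precision enum (now with Bfloat16):
    print(core.AnalysisConfig.Backend.GPU)
    print(core.AnalysisConfig.Precision.Bfloat16)

    # core.convert_to_mixed_precision_bind is also callable directly, with
    # keyword arguments matching the py::arg names and defaults above.
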
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .wrapper import Config, DataType, PlaceType, PrecisionType, BackendType, Tensor, Predictor
from .wrapper import convert_to_mixed_precision
from ..core import create_predictor, get_version, get_num_bytes_of_data_type, PredictorPool, get_trt_compile_version, get_trt_runtime_version
@@ -14,13 +14,17 @@
from ..core import AnalysisConfig, PaddleDType, PaddlePlace
from ..core import PaddleInferPredictor, PaddleInferTensor
from ..core import convert_to_mixed_precision_bind
from .. import core
import os
import numpy as np
from typing import Set
DataType = PaddleDType
PlaceType = PaddlePlace
PrecisionType = AnalysisConfig.Precision
BackendType = AnalysisConfig.Backend
Config = AnalysisConfig
Tensor = PaddleInferTensor
Predictor = PaddleInferPredictor
@@ -50,5 +54,37 @@ def tensor_share_external_data(self, data):
            "In share_external_data, we only support LoDTensor data type.")
def convert_to_mixed_precision(model_file: str,
                               params_file: str,
                               mixed_model_file: str,
                               mixed_params_file: str,
                               mixed_precision: PrecisionType,
                               backend: BackendType,
                               keep_io_types: bool = True,
                               black_list: Set[str] = set()):
    '''
    Convert an FP32 model to a mixed-precision model.

    Args:
        model_file: FP32 model file, e.g. inference.pdmodel.
        params_file: FP32 params file, e.g. inference.pdiparams.
        mixed_model_file: The storage path of the converted mixed-precision model.
        mixed_params_file: The storage path of the converted mixed-precision params.
        mixed_precision: The target precision, e.g. PrecisionType.Half.
        backend: The target backend, e.g. BackendType.GPU.
        keep_io_types: Whether the model's input and output dtypes remain unchanged.
        black_list: Names of operator types that are kept in FP32 (not converted).
    '''
    # Create the output directories if needed; guard against an empty
    # dirname when an output path has no directory component.
    mixed_model_dirname = os.path.dirname(mixed_model_file)
    mixed_params_dirname = os.path.dirname(mixed_params_file)
    if mixed_model_dirname and not os.path.exists(mixed_model_dirname):
        os.makedirs(mixed_model_dirname)
    if mixed_params_dirname and not os.path.exists(mixed_params_dirname):
        os.makedirs(mixed_params_dirname)
    convert_to_mixed_precision_bind(model_file, params_file, mixed_model_file,
                                    mixed_params_file, mixed_precision,
                                    backend, keep_io_types, black_list)
Tensor.copy_from_cpu = tensor_copy_from_cpu
Tensor.share_external_data = tensor_share_external_data
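Note that black_list takes a set of operator type names, so a one-op blacklist is written {'conv2d'}, not set('conv2d') (the latter is a set of single characters). A hedged usage sketch of the new wrapper, assuming an FP32 model has already been exported to resnet50/inference.*:

    # A minimal sketch, assuming an exported FP32 model at resnet50/inference.*.
    from paddle.inference import PrecisionType, BackendType
    from paddle.inference import convert_to_mixed_precision

    convert_to_mixed_precision(
        'resnet50/inference.pdmodel',    # FP32 model
        'resnet50/inference.pdiparams',  # FP32 params
        'mixed/inference.pdmodel',       # output mixed-precision model
        'mixed/inference.pdiparams',     # output mixed-precision params
        PrecisionType.Half,              # target precision
        BackendType.GPU,                 # target backend
        keep_io_types=True,              # keep FP32 model inputs/outputs
        black_list={'conv2d'})           # op types kept in FP32
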
@@ -17,3 +17,4 @@ endforeach()
add_subdirectory(inference)
set_tests_properties(test_fuse_resnet_unit PROPERTIES TIMEOUT 120)
set_tests_properties(test_convert_to_mixed_precision PROPERTIES TIMEOUT 300)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle.vision.models import resnet50
from paddle.jit import to_static
from paddle.static import InputSpec
from paddle.inference import PrecisionType, BackendType
from paddle.inference import convert_to_mixed_precision
@unittest.skipIf(not paddle.is_compiled_with_cuda()
                 or paddle.get_cudnn_version() < 8000,
                 'should compile with cuda and cudnn version >= 8.0.')
class TestConvertToMixedPrecision(unittest.TestCase):

    def test_convert_to_fp16(self):
        model = resnet50(pretrained=True)
        net = to_static(
            model, input_spec=[InputSpec(shape=[None, 3, 224, 224], name='x')])
        paddle.jit.save(net, 'resnet50/inference')
        convert_to_mixed_precision('resnet50/inference.pdmodel',
                                   'resnet50/inference.pdiparams',
                                   'mixed/inference.pdmodel',
                                   'mixed/inference.pdiparams',
                                   PrecisionType.Half, BackendType.GPU, True)

    def test_convert_to_fp16_with_fp16_input(self):
        model = resnet50(pretrained=True)
        net = to_static(
            model, input_spec=[InputSpec(shape=[None, 3, 224, 224], name='x')])
        paddle.jit.save(net, 'resnet50/inference')
        convert_to_mixed_precision('resnet50/inference.pdmodel',
                                   'resnet50/inference.pdiparams',
                                   'mixed1/inference.pdmodel',
                                   'mixed1/inference.pdiparams',
                                   PrecisionType.Half, BackendType.GPU, False)

    def test_convert_to_fp16_with_blacklist(self):
        model = resnet50(pretrained=True)
        net = to_static(
            model, input_spec=[InputSpec(shape=[None, 3, 224, 224], name='x')])
        paddle.jit.save(net, 'resnet50/inference')
        # black_list expects op type names: use the set literal {'conv2d'},
        # not set('conv2d'), which would be a set of single characters.
        convert_to_mixed_precision('resnet50/inference.pdmodel',
                                   'resnet50/inference.pdiparams',
                                   'mixed2/inference.pdmodel',
                                   'mixed2/inference.pdiparams',
                                   PrecisionType.Half, BackendType.GPU, False,
                                   {'conv2d'})

    def test_convert_to_bf16(self):
        model = resnet50(pretrained=True)
        net = to_static(
            model, input_spec=[InputSpec(shape=[None, 3, 224, 224], name='x')])
        paddle.jit.save(net, 'resnet50/inference')
        convert_to_mixed_precision('resnet50/inference.pdmodel',
                                   'resnet50/inference.pdiparams',
                                   'mixed3/inference.pdmodel',
                                   'mixed3/inference.pdiparams',
                                   PrecisionType.Bfloat16, BackendType.GPU,
                                   True)


if __name__ == '__main__':
    unittest.main()
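The tests above only exercise the conversion itself. To check a converted artifact end to end, one could load it with the regular inference API. A hedged sketch (not part of this commit), reusing the 'mixed/' output paths from the first test and an arbitrary GPU memory pool size:

    # A minimal sketch of consuming the converted mixed-precision model.
    import numpy as np
    from paddle.inference import Config, create_predictor

    config = Config('mixed/inference.pdmodel', 'mixed/inference.pdiparams')
    config.enable_use_gpu(256, 0)  # 256 MB initial pool, GPU 0
    predictor = create_predictor(config)

    # keep_io_types=True in the test, so inputs stay float32.
    input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
    input_handle.copy_from_cpu(
        np.random.rand(1, 3, 224, 224).astype('float32'))
    predictor.run()
    result = predictor.get_output_handle(
        predictor.get_output_names()[0]).copy_to_cpu()
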
@@ -16,17 +16,20 @@ from ..fluid.inference import Config  # noqa: F401
from ..fluid.inference import DataType  # noqa: F401
from ..fluid.inference import PlaceType  # noqa: F401
from ..fluid.inference import PrecisionType  # noqa: F401
from ..fluid.inference import BackendType # noqa: F401
from ..fluid.inference import Tensor  # noqa: F401
from ..fluid.inference import Predictor  # noqa: F401
from ..fluid.inference import create_predictor  # noqa: F401
from ..fluid.inference import get_version  # noqa: F401
from ..fluid.inference import get_trt_compile_version  # noqa: F401
from ..fluid.inference import get_trt_runtime_version  # noqa: F401
from ..fluid.inference import convert_to_mixed_precision # noqa: F401
from ..fluid.inference import get_num_bytes_of_data_type  # noqa: F401
from ..fluid.inference import PredictorPool  # noqa: F401

__all__ = [  # noqa
    'Config', 'DataType', 'PlaceType', 'PrecisionType', 'BackendType', 'Tensor',
    'Predictor', 'create_predictor', 'get_version', 'get_trt_compile_version',
    'convert_to_mixed_precision', 'get_trt_runtime_version',
    'get_num_bytes_of_data_type', 'PredictorPool'
]
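With these exports in place, user code needs nothing from the fluid namespace; everything the patch adds is reachable from the public module (a minimal sketch):

    # The new symbols, imported from the public paddle.inference namespace.
    from paddle.inference import (BackendType, PrecisionType,
                                  convert_to_mixed_precision)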