From cbaebb049e715833e538daed2c07fee9a6104627 Mon Sep 17 00:00:00 2001 From: Wilber Date: Wed, 29 Jun 2022 18:49:26 +0800 Subject: [PATCH] convert to mixed model python api (#43881) --- paddle/fluid/pybind/inference_api.cc | 18 +++++ python/paddle/fluid/inference/__init__.py | 3 +- python/paddle/fluid/inference/wrapper.py | 36 +++++++++ .../fluid/tests/unittests/ir/CMakeLists.txt | 1 + .../ir/test_convert_to_mixed_precision.py | 80 +++++++++++++++++++ python/paddle/inference/__init__.py | 9 ++- 6 files changed, 143 insertions(+), 4 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/ir/test_convert_to_mixed_precision.py diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc index 2461a9c952..3d25958603 100644 --- a/paddle/fluid/pybind/inference_api.cc +++ b/paddle/fluid/pybind/inference_api.cc @@ -404,6 +404,16 @@ void BindInferenceApi(py::module *m) { m->def("get_trt_compile_version", &paddle_infer::GetTrtCompileVersion); m->def("get_trt_runtime_version", &paddle_infer::GetTrtRuntimeVersion); m->def("get_num_bytes_of_data_type", &paddle_infer::GetNumBytesOfDataType); + m->def("convert_to_mixed_precision_bind", + &paddle_infer::ConvertToMixedPrecision, + py::arg("model_file"), + py::arg("params_file"), + py::arg("mixed_model_file"), + py::arg("mixed_params_file"), + py::arg("mixed_precision"), + py::arg("backend"), + py::arg("keep_io_types") = true, + py::arg("black_list") = std::unordered_set()); } namespace { @@ -586,6 +596,14 @@ void BindAnalysisConfig(py::module *m) { .value("Float32", AnalysisConfig::Precision::kFloat32) .value("Int8", AnalysisConfig::Precision::kInt8) .value("Half", AnalysisConfig::Precision::kHalf) + .value("Bfloat16", AnalysisConfig::Precision::kBf16) + .export_values(); + + py::enum_(analysis_config, "Backend") + .value("CPU", AnalysisConfig::Backend::kCPU) + .value("GPU", AnalysisConfig::Backend::kGPU) + .value("NPU", AnalysisConfig::Backend::kNPU) + .value("XPU", AnalysisConfig::Backend::kXPU) .export_values(); analysis_config.def(py::init<>()) diff --git a/python/paddle/fluid/inference/__init__.py b/python/paddle/fluid/inference/__init__.py index 946b4f0c8d..6f7bc32a51 100644 --- a/python/paddle/fluid/inference/__init__.py +++ b/python/paddle/fluid/inference/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .wrapper import Config, DataType, PlaceType, PrecisionType, Tensor, Predictor +from .wrapper import Config, DataType, PlaceType, PrecisionType, BackendType, Tensor, Predictor +from .wrapper import convert_to_mixed_precision from ..core import create_predictor, get_version, get_num_bytes_of_data_type, PredictorPool, get_trt_compile_version, get_trt_runtime_version diff --git a/python/paddle/fluid/inference/wrapper.py b/python/paddle/fluid/inference/wrapper.py index c81ad03df7..ec778c6339 100644 --- a/python/paddle/fluid/inference/wrapper.py +++ b/python/paddle/fluid/inference/wrapper.py @@ -14,13 +14,17 @@ from ..core import AnalysisConfig, PaddleDType, PaddlePlace from ..core import PaddleInferPredictor, PaddleInferTensor +from ..core import convert_to_mixed_precision_bind from .. import core +import os import numpy as np +from typing import Set DataType = PaddleDType PlaceType = PaddlePlace PrecisionType = AnalysisConfig.Precision +BackendType = AnalysisConfig.Backend Config = AnalysisConfig Tensor = PaddleInferTensor Predictor = PaddleInferPredictor @@ -50,5 +54,37 @@ def tensor_share_external_data(self, data): "In share_external_data, we only support LoDTensor data type.") +def convert_to_mixed_precision(model_file: str, + params_file: str, + mixed_model_file: str, + mixed_params_file: str, + mixed_precision: PrecisionType, + backend: BackendType, + keep_io_types: bool = True, + black_list: Set = set()): + ''' + Convert a fp32 model to mixed precision model. + + Args: + model_file: fp32 model file, e.g. inference.pdmodel. + params_file: fp32 params file, e.g. inference.pdiparams. + mixed_model_file: The storage path of the converted mixed-precision model. + mixed_params_file: The storage path of the converted mixed-precision params. + mixed_precision: The precision, e.g. PrecisionType.Half. + backend: The backend, e.g. BackendType.GPU. + keep_io_types: Whether the model input and output dtype remains unchanged. + black_list: Operators that do not convert precision. + ''' + mixed_model_dirname = os.path.dirname(mixed_model_file) + mixed_params_dirname = os.path.dirname(mixed_params_file) + if not os.path.exists(mixed_model_dirname): + os.makedirs(mixed_model_dirname) + if not os.path.exists(mixed_params_dirname): + os.makedirs(mixed_params_dirname) + convert_to_mixed_precision_bind(model_file, params_file, mixed_model_file, + mixed_params_file, mixed_precision, backend, + keep_io_types, black_list) + + Tensor.copy_from_cpu = tensor_copy_from_cpu Tensor.share_external_data = tensor_share_external_data diff --git a/python/paddle/fluid/tests/unittests/ir/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ir/CMakeLists.txt index d34ee9380e..58a3c8f987 100644 --- a/python/paddle/fluid/tests/unittests/ir/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/ir/CMakeLists.txt @@ -17,3 +17,4 @@ endforeach() add_subdirectory(inference) set_tests_properties(test_fuse_resnet_unit PROPERTIES TIMEOUT 120) +set_tests_properties(test_convert_to_mixed_precision PROPERTIES TIMEOUT 300) diff --git a/python/paddle/fluid/tests/unittests/ir/test_convert_to_mixed_precision.py b/python/paddle/fluid/tests/unittests/ir/test_convert_to_mixed_precision.py new file mode 100644 index 0000000000..deb4990cf5 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/test_convert_to_mixed_precision.py @@ -0,0 +1,80 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +import paddle + +from paddle.vision.models import resnet50 +from paddle.jit import to_static +from paddle.static import InputSpec + +from paddle.inference import PrecisionType, BackendType +from paddle.inference import convert_to_mixed_precision + + +@unittest.skipIf(not paddle.is_compiled_with_cuda() + or paddle.get_cudnn_version() < 8000, + 'should compile with cuda.') +class TestConvertToMixedPrecision(unittest.TestCase): + + def test_convert_to_fp16(self): + model = resnet50(True) + net = to_static( + model, input_spec=[InputSpec(shape=[None, 3, 224, 224], name='x')]) + paddle.jit.save(net, 'resnet50/inference') + convert_to_mixed_precision('resnet50/inference.pdmodel', + 'resnet50/inference.pdiparams', + 'mixed/inference.pdmodel', + 'mixed/inference.pdiparams', + PrecisionType.Half, BackendType.GPU, True) + + def test_convert_to_fp16_with_fp16_input(self): + model = resnet50(True) + net = to_static( + model, input_spec=[InputSpec(shape=[None, 3, 224, 224], name='x')]) + paddle.jit.save(net, 'resnet50/inference') + convert_to_mixed_precision('resnet50/inference.pdmodel', + 'resnet50/inference.pdiparams', + 'mixed1/inference.pdmodel', + 'mixed1/inference.pdiparams', + PrecisionType.Half, BackendType.GPU, False) + + def test_convert_to_fp16_with_blacklist(self): + model = resnet50(True) + net = to_static( + model, input_spec=[InputSpec(shape=[None, 3, 224, 224], name='x')]) + paddle.jit.save(net, 'resnet50/inference') + convert_to_mixed_precision('resnet50/inference.pdmodel', + 'resnet50/inference.pdiparams', + 'mixed2/inference.pdmodel', + 'mixed2/inference.pdiparams', + PrecisionType.Half, BackendType.GPU, False, + set('conv2d')) + + def test_convert_to_bf16(self): + model = resnet50(True) + net = to_static( + model, input_spec=[InputSpec(shape=[None, 3, 224, 224], name='x')]) + paddle.jit.save(net, 'resnet50/inference') + convert_to_mixed_precision('resnet50/inference.pdmodel', + 'resnet50/inference.pdiparams', + 'mixed3/inference.pdmodel', + 'mixed3/inference.pdiparams', + PrecisionType.Bfloat16, BackendType.GPU, + True) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/inference/__init__.py b/python/paddle/inference/__init__.py index 670c2cc8e4..b33fb1fdfe 100644 --- a/python/paddle/inference/__init__.py +++ b/python/paddle/inference/__init__.py @@ -16,17 +16,20 @@ from ..fluid.inference import Config # noqa: F401 from ..fluid.inference import DataType # noqa: F401 from ..fluid.inference import PlaceType # noqa: F401 from ..fluid.inference import PrecisionType # noqa: F401 +from ..fluid.inference import BackendType # noqa: F401 from ..fluid.inference import Tensor # noqa: F401 from ..fluid.inference import Predictor # noqa: F401 from ..fluid.inference import create_predictor # noqa: F401 from ..fluid.inference import get_version # noqa: F401 from ..fluid.inference import get_trt_compile_version # noqa: F401 from ..fluid.inference import get_trt_runtime_version # noqa: F401 +from ..fluid.inference import convert_to_mixed_precision # noqa: F401 from ..fluid.inference import get_num_bytes_of_data_type # noqa: F401 from ..fluid.inference import PredictorPool # noqa: F401 __all__ = [ # noqa - 'Config', 'DataType', 'PlaceType', 'PrecisionType', 'Tensor', 'Predictor', - 'create_predictor', 'get_version', 'get_trt_compile_version', - 'get_trt_runtime_version', 'get_num_bytes_of_data_type', 'PredictorPool' + 'Config', 'DataType', 'PlaceType', 'PrecisionType', 'BackendType', 'Tensor', + 'Predictor', 'create_predictor', 'get_version', 'get_trt_compile_version', + 'convert_to_mixed_precision', 'get_trt_runtime_version', + 'get_num_bytes_of_data_type', 'PredictorPool' ] -- GitLab