From be6a83301e04389902137fee6aee41134e83f4f3 Mon Sep 17 00:00:00 2001
From: Wilber
Date: Tue, 19 Oct 2021 15:49:13 +0800
Subject: [PATCH] Inference add type check in copy_from_cpu (#36429)

* update

* fix ut error

* update ut
---
 .../fluid/inference/api/analysis_predictor.cc | 18 ++++++
 .../api/analysis_predictor_tester.cc          |  9 +++
 .../inference/api/paddle_inference_api.h      |  2 +
 paddle/fluid/inference/tensorrt/engine.cc     | 13 ++++
 paddle/fluid/inference/tensorrt/helper.h      | 16 +++++
 paddle/fluid/pybind/inference_api.cc          | 11 ++--
 python/paddle/fluid/inference/__init__.py     |  2 +-
 python/paddle/fluid/inference/wrapper.py      | 15 +++++
 .../tests/unittests/test_inference_api.py     | 59 +++++++++++++++++++
 python/paddle/inference/__init__.py           |  4 ++
 10 files changed, 144 insertions(+), 5 deletions(-)

diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index dfa27037205..491ed71c4bc 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -36,6 +36,7 @@
 #include "paddle/fluid/inference/analysis/helper.h"
 #include "paddle/fluid/inference/analysis/passes/memory_optimize_pass.h"
 #include "paddle/fluid/inference/api/helper.h"
+#include "paddle/fluid/inference/api/paddle_inference_api.h"
 #include "paddle/fluid/inference/api/paddle_inference_pass.h"
 #include "paddle/fluid/inference/utils/io_utils.h"
 #include "paddle/fluid/inference/utils/singleton.h"
@@ -56,6 +57,7 @@

 #if PADDLE_WITH_TENSORRT
 #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+#include "paddle/fluid/inference/tensorrt/helper.h"
 #include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"
 #endif

@@ -1471,6 +1473,22 @@ int GetNumBytesOfDataType(DataType dtype) {

 std::string GetVersion() { return paddle::get_version(); }

+std::tuple<int, int, int> GetTrtCompileVersion() {
+#ifdef PADDLE_WITH_TENSORRT
+  return paddle::inference::tensorrt::GetTrtCompileVersion();
+#else
+  return std::tuple<int, int, int>{0, 0, 0};
+#endif
+}
+
+std::tuple<int, int, int> GetTrtRuntimeVersion() {
+#ifdef PADDLE_WITH_TENSORRT
+  return paddle::inference::tensorrt::GetTrtRuntimeVersion();
+#else
+  return std::tuple<int, int, int>{0, 0, 0};
+#endif
+}
+
 std::string UpdateDllFlag(const char *name, const char *value) {
   return paddle::UpdateDllFlag(name, value);
 }
diff --git a/paddle/fluid/inference/api/analysis_predictor_tester.cc b/paddle/fluid/inference/api/analysis_predictor_tester.cc
index 86fbde00075..a15a1cd84b1 100644
--- a/paddle/fluid/inference/api/analysis_predictor_tester.cc
+++ b/paddle/fluid/inference/api/analysis_predictor_tester.cc
@@ -359,6 +359,15 @@ TEST(AnalysisPredictor, set_xpu_device_id) {
 namespace paddle_infer {

 TEST(Predictor, Run) {
+  auto trt_compile_ver = GetTrtCompileVersion();
+  auto trt_runtime_ver = GetTrtRuntimeVersion();
+  LOG(INFO) << "trt compile version: " << std::get<0>(trt_compile_ver) << "."
+            << std::get<1>(trt_compile_ver) << "."
+            << std::get<2>(trt_compile_ver);
+  LOG(INFO) << "trt runtime version: " << std::get<0>(trt_runtime_ver) << "."
+            << std::get<1>(trt_runtime_ver) << "."
+            << std::get<2>(trt_runtime_ver);
+
   Config config;
   config.SetModel(FLAGS_dirname);

diff --git a/paddle/fluid/inference/api/paddle_inference_api.h b/paddle/fluid/inference/api/paddle_inference_api.h
index a516abb1432..35b90bfa54f 100644
--- a/paddle/fluid/inference/api/paddle_inference_api.h
+++ b/paddle/fluid/inference/api/paddle_inference_api.h
@@ -169,6 +169,8 @@ PD_INFER_DECL std::shared_ptr<Predictor> CreatePredictor(
 PD_INFER_DECL int GetNumBytesOfDataType(DataType dtype);

 PD_INFER_DECL std::string GetVersion();
+PD_INFER_DECL std::tuple<int, int, int> GetTrtCompileVersion();
+PD_INFER_DECL std::tuple<int, int, int> GetTrtRuntimeVersion();
 PD_INFER_DECL std::string UpdateDllFlag(const char* name, const char* value);

 namespace services {
diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc
index 24644645eee..26182a79321 100644
--- a/paddle/fluid/inference/tensorrt/engine.cc
+++ b/paddle/fluid/inference/tensorrt/engine.cc
@@ -190,6 +190,19 @@ void TensorRTEngine::FreezeNetwork() {
 #if IS_TRT_VERSION_GE(6000)
     LOG(INFO) << "Run Paddle-TRT Dynamic Shape mode.";
     for (auto &input : min_input_shape_) {
+#if IS_TRT_VERSION_LT(7000)
+      // trt6 will check all_of input > 0
+      if (!(std::all_of(input.second.begin(), input.second.end(),
+                        [](int x) { return x > 0; }) &&
+            std::all_of(max_input_shape_[input.first].begin(),
+                        max_input_shape_[input.first].end(),
+                        [](int x) { return x > 0; }) &&
+            std::all_of(optim_input_shape_[input.first].begin(),
+                        optim_input_shape_[input.first].end(),
+                        [](int x) { return x > 0; }))) {
+        continue;
+      }
+#endif
       VLOG(4) << "TRT dynamic_shape set " << input.first
               << " min: " << Vec2Str(input.second)
               << ", max: " << Vec2Str(max_input_shape_[input.first])
diff --git a/paddle/fluid/inference/tensorrt/helper.h b/paddle/fluid/inference/tensorrt/helper.h
index 16595b8a032..b8051d86104 100644
--- a/paddle/fluid/inference/tensorrt/helper.h
+++ b/paddle/fluid/inference/tensorrt/helper.h
@@ -73,8 +73,24 @@ static nvinfer1::IPluginRegistry* GetPluginRegistry() {
 static int GetInferLibVersion() {
   return static_cast<int>(dy::getInferLibVersion());
 }
+#else
+static int GetInferLibVersion() { return 0; }
 #endif

+static std::tuple<int, int, int> GetTrtRuntimeVersion() {
+  int ver = GetInferLibVersion();
+  int major = ver / 1000;
+  ver -= major * 1000;
+  int minor = ver / 100;
+  int patch = ver - minor * 100;
+  return std::tuple<int, int, int>{major, minor, patch};
+}
+
+static std::tuple<int, int, int> GetTrtCompileVersion() {
+  return std::tuple<int, int, int>{NV_TENSORRT_MAJOR, NV_TENSORRT_MINOR,
+                                   NV_TENSORRT_PATCH};
+}
+
 // A logger for create TensorRT infer builder.
 class NaiveLogger : public nvinfer1::ILogger {
  public:
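Note: GetTrtRuntimeVersion() above unpacks the single integer reported by getInferLibVersion(), which encodes the version as major * 1000 + minor * 100 + patch. A minimal Python sketch of the same decomposition, using a hypothetical raw value of 7201 (i.e. TensorRT 7.2.1); the real value comes from the TensorRT library at runtime:

    # Decompose a TensorRT version integer the same way GetTrtRuntimeVersion() does.
    ver = 7201                  # hypothetical example; real value from getInferLibVersion()
    major = ver // 1000         # 7
    ver -= major * 1000         # 201
    minor = ver // 100          # 2
    patch = ver - minor * 100   # 1
    print((major, minor, patch))  # (7, 2, 1)
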
diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc
index 8ce7bea2d8e..e02f25ff636 100644
--- a/paddle/fluid/pybind/inference_api.cc
+++ b/paddle/fluid/pybind/inference_api.cc
@@ -330,6 +330,8 @@ void BindInferenceApi(py::module *m) {
   m->def("paddle_dtype_size", &paddle::PaddleDtypeSize);
   m->def("paddle_tensor_to_bytes", &SerializePDTensorToBytes);
   m->def("get_version", &paddle_infer::GetVersion);
+  m->def("get_trt_compile_version", &paddle_infer::GetTrtCompileVersion);
+  m->def("get_trt_runtime_version", &paddle_infer::GetTrtRuntimeVersion);
   m->def("get_num_bytes_of_data_type", &paddle_infer::GetNumBytesOfDataType);
 }

@@ -739,10 +741,11 @@ void BindZeroCopyTensor(py::module *m) {
 void BindPaddleInferTensor(py::module *m) {
   py::class_<paddle_infer::Tensor>(*m, "PaddleInferTensor")
       .def("reshape", &paddle_infer::Tensor::Reshape)
-      .def("copy_from_cpu", &PaddleInferTensorCreate<int32_t>)
-      .def("copy_from_cpu", &PaddleInferTensorCreate<int64_t>)
-      .def("copy_from_cpu", &PaddleInferTensorCreate<float>)
-      .def("copy_from_cpu", &PaddleInferTensorCreate<paddle_infer::float16>)
+      .def("copy_from_cpu_bind", &PaddleInferTensorCreate<int32_t>)
+      .def("copy_from_cpu_bind", &PaddleInferTensorCreate<int64_t>)
+      .def("copy_from_cpu_bind", &PaddleInferTensorCreate<float>)
+      .def("copy_from_cpu_bind",
+           &PaddleInferTensorCreate<paddle_infer::float16>)
       .def("copy_to_cpu", &PaddleInferTensorToNumpy)
       .def("shape", &paddle_infer::Tensor::shape)
       .def("set_lod", &paddle_infer::Tensor::SetLoD)
diff --git a/python/paddle/fluid/inference/__init__.py b/python/paddle/fluid/inference/__init__.py
index 3013c1f2aff..946b4f0c8d7 100644
--- a/python/paddle/fluid/inference/__init__.py
+++ b/python/paddle/fluid/inference/__init__.py
@@ -14,4 +14,4 @@

 from .wrapper import Config, DataType, PlaceType, PrecisionType, Tensor, Predictor

-from ..core import create_predictor, get_version, get_num_bytes_of_data_type, PredictorPool
+from ..core import create_predictor, get_version, get_num_bytes_of_data_type, PredictorPool, get_trt_compile_version, get_trt_runtime_version
diff --git a/python/paddle/fluid/inference/wrapper.py b/python/paddle/fluid/inference/wrapper.py
index 96885edcc5e..2c1b2c77504 100644
--- a/python/paddle/fluid/inference/wrapper.py
+++ b/python/paddle/fluid/inference/wrapper.py
@@ -15,9 +15,24 @@
 from ..core import AnalysisConfig, PaddleDType, PaddlePlace
 from ..core import PaddleInferPredictor, PaddleInferTensor

+import numpy as np
+
 DataType = PaddleDType
 PlaceType = PaddlePlace
 PrecisionType = AnalysisConfig.Precision
 Config = AnalysisConfig
 Tensor = PaddleInferTensor
 Predictor = PaddleInferPredictor
+
+
+def tensor_copy_from_cpu(self, data):
+    '''
+    Support input type check based on tensor.copy_from_cpu.
+    '''
+    if not isinstance(data, np.ndarray):
+        raise TypeError(
+            "In copy_from_cpu, we only support numpy ndarray data type.")
+    self.copy_from_cpu_bind(data)
+
+
+Tensor.copy_from_cpu = tensor_copy_from_cpu
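Note: with the wrapper above, tensor.copy_from_cpu() now rejects anything that is not a numpy ndarray before the data reaches the C++ binding. A rough sketch of the user-facing effect; `input_handle` is hypothetical here and would normally come from predictor.get_input_handle(...):

    import numpy as np

    raw = [[1.0, 2.0, 3.0]]                    # plain Python list from user code
    # input_handle.copy_from_cpu(raw)          # after this patch: raises TypeError
    data = np.asarray(raw, dtype=np.float32)   # convert to an ndarray first
    # input_handle.copy_from_cpu(data)         # accepted, forwarded to copy_from_cpu_bind
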
diff --git a/python/paddle/fluid/tests/unittests/test_inference_api.py b/python/paddle/fluid/tests/unittests/test_inference_api.py
index 98ec0b3db04..7ed908eb33b 100644
--- a/python/paddle/fluid/tests/unittests/test_inference_api.py
+++ b/python/paddle/fluid/tests/unittests/test_inference_api.py
@@ -14,10 +14,14 @@

 import os, shutil
 import unittest
+import paddle
+paddle.enable_static()
 import numpy as np
 import paddle.fluid as fluid
 from paddle.fluid.core import PaddleTensor
 from paddle.fluid.core import PaddleDType
+from paddle.inference import Config, Predictor, create_predictor
+from paddle.inference import get_trt_compile_version, get_trt_runtime_version


 class TestInferenceApi(unittest.TestCase):
@@ -54,5 +58,60 @@ class TestInferenceApi(unittest.TestCase):
                          tensor_float.ravel().tolist())


+def get_sample_model():
+    place = fluid.CPUPlace()
+    exe = fluid.Executor(place)
+
+    main_program = fluid.Program()
+    startup_program = fluid.Program()
+    with fluid.program_guard(main_program, startup_program):
+        data = fluid.data(name="data", shape=[-1, 6, 64, 64], dtype="float32")
+        conv_out = fluid.layers.conv2d(
+            input=data,
+            num_filters=3,
+            filter_size=3,
+            groups=1,
+            padding=0,
+            bias_attr=False,
+            act=None)
+    exe.run(startup_program)
+    serialized_program = paddle.static.serialize_program(
+        data, conv_out, program=main_program)
+    serialized_params = paddle.static.serialize_persistables(
+        data, conv_out, executor=exe, program=main_program)
+    return serialized_program, serialized_params
+
+
+class TestInferenceBaseAPI(unittest.TestCase):
+    def get_config(self, model, params):
+        config = Config()
+        config.set_model_buffer(model, len(model), params, len(params))
+        config.enable_use_gpu(100, 0)
+        return config
+
+    def test_apis(self):
+        print('trt compile version:', get_trt_compile_version())
+        print('trt runtime version:', get_trt_runtime_version())
+        program, params = get_sample_model()
+        config = self.get_config(program, params)
+        predictor = create_predictor(config)
+        in_names = predictor.get_input_names()
+        in_handle = predictor.get_input_handle(in_names[0])
+        in_data = np.ones((1, 6, 32, 32)).astype(np.float32)
+        in_handle.copy_from_cpu(in_data)
+        predictor.run()
+
+    def test_wrong_input(self):
+        with self.assertRaises(TypeError):
+            program, params = get_sample_model()
+            config = self.get_config(program, params)
+            predictor = create_predictor(config)
+            in_names = predictor.get_input_names()
+            in_handle = predictor.get_input_handle(in_names[0])
+            in_data = np.ones((1, 6, 64, 64)).astype(np.float32)
+            in_handle.copy_from_cpu(list(in_data))
+            predictor.run()
+
+
 if __name__ == '__main__':
     unittest.main()
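Note: the new unit test also exercises the version helpers exposed to Python. A minimal sketch of querying them directly, assuming a Paddle wheel that contains this patch; both functions return a (major, minor, patch) tuple of ints, and a build compiled without TensorRT returns (0, 0, 0) per the #else branches in analysis_predictor.cc:

    import paddle.inference as paddle_infer

    compile_ver = paddle_infer.get_trt_compile_version()  # e.g. (7, 2, 1)
    runtime_ver = paddle_infer.get_trt_runtime_version()
    print("TRT compile version:", ".".join(map(str, compile_ver)))
    print("TRT runtime version:", ".".join(map(str, runtime_ver)))
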
diff --git a/python/paddle/inference/__init__.py b/python/paddle/inference/__init__.py
index 4e172039716..ec5295b6dfe 100644
--- a/python/paddle/inference/__init__.py
+++ b/python/paddle/inference/__init__.py
@@ -20,6 +20,8 @@ from ..fluid.inference import Tensor  # noqa: F401
 from ..fluid.inference import Predictor  # noqa: F401
 from ..fluid.inference import create_predictor  # noqa: F401
 from ..fluid.inference import get_version  # noqa: F401
+from ..fluid.inference import get_trt_compile_version  # noqa: F401
+from ..fluid.inference import get_trt_runtime_version  # noqa: F401
 from ..fluid.inference import get_num_bytes_of_data_type  # noqa: F401
 from ..fluid.inference import PredictorPool  # noqa: F401

@@ -32,6 +34,8 @@ __all__ = [  # noqa
     'Predictor',
     'create_predictor',
     'get_version',
+    'get_trt_compile_version',
+    'get_trt_runtime_version',
     'get_num_bytes_of_data_type',
     'PredictorPool'
 ]
--
GitLab