未验证 提交 be6a8330 编写于 作者: W Wilber 提交者: GitHub

Inference add type check in copy_from_cpu (#36429)

* update

* fix ut error

* update ut
上级 6cdc5a4b
......@@ -36,6 +36,7 @@
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/passes/memory_optimize_pass.h"
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/api/paddle_inference_pass.h"
#include "paddle/fluid/inference/utils/io_utils.h"
#include "paddle/fluid/inference/utils/singleton.h"
......@@ -56,6 +57,7 @@
#if PADDLE_WITH_TENSORRT
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"
#endif
......@@ -1471,6 +1473,22 @@ int GetNumBytesOfDataType(DataType dtype) {
std::string GetVersion() { return paddle::get_version(); }
std::tuple<int, int, int> GetTrtCompileVersion() {
#ifdef PADDLE_WITH_TENSORRT
return paddle::inference::tensorrt::GetTrtCompileVersion();
#else
return std::tuple<int, int, int>{0, 0, 0};
#endif
}
std::tuple<int, int, int> GetTrtRuntimeVersion() {
#ifdef PADDLE_WITH_TENSORRT
return paddle::inference::tensorrt::GetTrtRuntimeVersion();
#else
return std::tuple<int, int, int>{0, 0, 0};
#endif
}
std::string UpdateDllFlag(const char *name, const char *value) {
return paddle::UpdateDllFlag(name, value);
}
......
......@@ -359,6 +359,15 @@ TEST(AnalysisPredictor, set_xpu_device_id) {
namespace paddle_infer {
TEST(Predictor, Run) {
auto trt_compile_ver = GetTrtCompileVersion();
auto trt_runtime_ver = GetTrtRuntimeVersion();
LOG(INFO) << "trt compile version: " << std::get<0>(trt_compile_ver) << "."
<< std::get<1>(trt_compile_ver) << "."
<< std::get<2>(trt_compile_ver);
LOG(INFO) << "trt runtime version: " << std::get<0>(trt_runtime_ver) << "."
<< std::get<1>(trt_runtime_ver) << "."
<< std::get<2>(trt_runtime_ver);
Config config;
config.SetModel(FLAGS_dirname);
......
......@@ -169,6 +169,8 @@ PD_INFER_DECL std::shared_ptr<Predictor> CreatePredictor(
PD_INFER_DECL int GetNumBytesOfDataType(DataType dtype);
PD_INFER_DECL std::string GetVersion();
PD_INFER_DECL std::tuple<int, int, int> GetTrtCompileVersion();
PD_INFER_DECL std::tuple<int, int, int> GetTrtRuntimeVersion();
PD_INFER_DECL std::string UpdateDllFlag(const char* name, const char* value);
namespace services {
......
......@@ -190,6 +190,19 @@ void TensorRTEngine::FreezeNetwork() {
#if IS_TRT_VERSION_GE(6000)
LOG(INFO) << "Run Paddle-TRT Dynamic Shape mode.";
for (auto &input : min_input_shape_) {
#if IS_TRT_VERSION_LT(7000)
// trt6 will check all_of input > 0
if (!(std::all_of(input.second.begin(), input.second.end(),
[](int x) { return x > 0; }) &&
std::all_of(max_input_shape_[input.first].begin(),
max_input_shape_[input.first].end(),
[](int x) { return x > 0; }) &&
std::all_of(optim_input_shape_[input.first].begin(),
optim_input_shape_[input.first].end(),
[](int x) { return x > 0; }))) {
continue;
}
#endif
VLOG(4) << "TRT dynamic_shape set " << input.first
<< " min: " << Vec2Str(input.second)
<< ", max: " << Vec2Str(max_input_shape_[input.first])
......
......@@ -73,8 +73,24 @@ static nvinfer1::IPluginRegistry* GetPluginRegistry() {
static int GetInferLibVersion() {
return static_cast<int>(dy::getInferLibVersion());
}
#else
static int GetInferLibVersion() { return 0; }
#endif
static std::tuple<int, int, int> GetTrtRuntimeVersion() {
int ver = GetInferLibVersion();
int major = ver / 1000;
ver -= major * 1000;
int minor = ver / 100;
int patch = ver - minor * 100;
return std::tuple<int, int, int>{major, minor, patch};
}
static std::tuple<int, int, int> GetTrtCompileVersion() {
return std::tuple<int, int, int>{NV_TENSORRT_MAJOR, NV_TENSORRT_MINOR,
NV_TENSORRT_PATCH};
}
// A logger for create TensorRT infer builder.
class NaiveLogger : public nvinfer1::ILogger {
public:
......
......@@ -330,6 +330,8 @@ void BindInferenceApi(py::module *m) {
m->def("paddle_dtype_size", &paddle::PaddleDtypeSize);
m->def("paddle_tensor_to_bytes", &SerializePDTensorToBytes);
m->def("get_version", &paddle_infer::GetVersion);
m->def("get_trt_compile_version", &paddle_infer::GetTrtCompileVersion);
m->def("get_trt_runtime_version", &paddle_infer::GetTrtRuntimeVersion);
m->def("get_num_bytes_of_data_type", &paddle_infer::GetNumBytesOfDataType);
}
......@@ -739,10 +741,11 @@ void BindZeroCopyTensor(py::module *m) {
void BindPaddleInferTensor(py::module *m) {
py::class_<paddle_infer::Tensor>(*m, "PaddleInferTensor")
.def("reshape", &paddle_infer::Tensor::Reshape)
.def("copy_from_cpu", &PaddleInferTensorCreate<int32_t>)
.def("copy_from_cpu", &PaddleInferTensorCreate<int64_t>)
.def("copy_from_cpu", &PaddleInferTensorCreate<float>)
.def("copy_from_cpu", &PaddleInferTensorCreate<paddle_infer::float16>)
.def("copy_from_cpu_bind", &PaddleInferTensorCreate<int32_t>)
.def("copy_from_cpu_bind", &PaddleInferTensorCreate<int64_t>)
.def("copy_from_cpu_bind", &PaddleInferTensorCreate<float>)
.def("copy_from_cpu_bind",
&PaddleInferTensorCreate<paddle_infer::float16>)
.def("copy_to_cpu", &PaddleInferTensorToNumpy)
.def("shape", &paddle_infer::Tensor::shape)
.def("set_lod", &paddle_infer::Tensor::SetLoD)
......
......@@ -14,4 +14,4 @@
from .wrapper import Config, DataType, PlaceType, PrecisionType, Tensor, Predictor
from ..core import create_predictor, get_version, get_num_bytes_of_data_type, PredictorPool
from ..core import create_predictor, get_version, get_num_bytes_of_data_type, PredictorPool, get_trt_compile_version, get_trt_runtime_version
......@@ -15,9 +15,24 @@
from ..core import AnalysisConfig, PaddleDType, PaddlePlace
from ..core import PaddleInferPredictor, PaddleInferTensor
import numpy as np
DataType = PaddleDType
PlaceType = PaddlePlace
PrecisionType = AnalysisConfig.Precision
Config = AnalysisConfig
Tensor = PaddleInferTensor
Predictor = PaddleInferPredictor
def tensor_copy_from_cpu(self, data):
'''
Support input type check based on tensor.copy_from_cpu.
'''
if not isinstance(data, np.ndarray):
raise TypeError(
"In copy_from_cpu, we only support numpy ndarray data type.")
self.copy_from_cpu_bind(data)
Tensor.copy_from_cpu = tensor_copy_from_cpu
......@@ -14,10 +14,14 @@
import os, shutil
import unittest
import paddle
paddle.enable_static()
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.core import PaddleTensor
from paddle.fluid.core import PaddleDType
from paddle.inference import Config, Predictor, create_predictor
from paddle.inference import get_trt_compile_version, get_trt_runtime_version
class TestInferenceApi(unittest.TestCase):
......@@ -54,5 +58,60 @@ class TestInferenceApi(unittest.TestCase):
tensor_float.ravel().tolist())
def get_sample_model():
place = fluid.CPUPlace()
exe = fluid.Executor(place)
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
data = fluid.data(name="data", shape=[-1, 6, 64, 64], dtype="float32")
conv_out = fluid.layers.conv2d(
input=data,
num_filters=3,
filter_size=3,
groups=1,
padding=0,
bias_attr=False,
act=None)
exe.run(startup_program)
serialized_program = paddle.static.serialize_program(
data, conv_out, program=main_program)
serialized_params = paddle.static.serialize_persistables(
data, conv_out, executor=exe, program=main_program)
return serialized_program, serialized_params
class TestInferenceBaseAPI(unittest.TestCase):
def get_config(self, model, params):
config = Config()
config.set_model_buffer(model, len(model), params, len(params))
config.enable_use_gpu(100, 0)
return config
def test_apis(self):
print('trt compile version:', get_trt_compile_version())
print('trt runtime version:', get_trt_runtime_version())
program, params = get_sample_model()
config = self.get_config(program, params)
predictor = create_predictor(config)
in_names = predictor.get_input_names()
in_handle = predictor.get_input_handle(in_names[0])
in_data = np.ones((1, 6, 32, 32)).astype(np.float32)
in_handle.copy_from_cpu(in_data)
predictor.run()
def test_wrong_input(self):
with self.assertRaises(TypeError):
program, params = get_sample_model()
config = self.get_config(program, params)
predictor = create_predictor(config)
in_names = predictor.get_input_names()
in_handle = predictor.get_input_handle(in_names[0])
in_data = np.ones((1, 6, 64, 64)).astype(np.float32)
in_handle.copy_from_cpu(list(in_data))
predictor.run()
if __name__ == '__main__':
unittest.main()
......@@ -20,6 +20,8 @@ from ..fluid.inference import Tensor # noqa: F401
from ..fluid.inference import Predictor # noqa: F401
from ..fluid.inference import create_predictor # noqa: F401
from ..fluid.inference import get_version # noqa: F401
from ..fluid.inference import get_trt_compile_version # noqa: F401
from ..fluid.inference import get_trt_runtime_version # noqa: F401
from ..fluid.inference import get_num_bytes_of_data_type # noqa: F401
from ..fluid.inference import PredictorPool # noqa: F401
......@@ -32,6 +34,8 @@ __all__ = [ # noqa
'Predictor',
'create_predictor',
'get_version',
'get_trt_compile_version',
'get_trt_runtime_version',
'get_num_bytes_of_data_type',
'PredictorPool'
]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册