Unverified commit fcbb7440, authored by H heliqi, committed by GitHub

Optimize CopyFromCpu and CopyToCpu in the ONNX Runtime back-end (#40561)

* add onnxruntime predictor

* Add code comments

* support link paddle2onnx onnxruntime

* support onnxruntime with python

* support onnxruntime with python

* support onnxruntime with windows

* paddle2onnx compile with windows

* support windows compile

* support windows compile with onnxruntime

* support windows compile with paddle2onnx

* support mac compile

* compile with mac

* compile with mac

* add code comments

* fix remind word

* code optimization

* add test case

* add test case

* add inference demo_ci test case

* fix compile paddle2onnx with no python

* add inference demo_ci test case

* add inference demo_ci test case

* add inference infer_ut test case

* support c go api and test cases

* add coverage test case

* add coverage test case

* add capi test case

* add capi test case

* fix onnxruntime copyfromcpu and copytocpu

* fix goapi

* modify code
Parent da558f0e
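The hunks below replace the scope/LoDTensor staging path of the ONNX Runtime back-end with direct Ort::IoBinding access: CopyFromCpu now binds the caller's host buffer as the session input, and CopyToCpu reads straight from the bound output value. As a caller-side reference, here is a minimal sketch of the API this change speeds up; the model paths, the include path, and the EnableONNXRuntime() config switch are assumptions for illustration, not part of this diff:

#include <vector>

#include "paddle_inference_api.h"  // include path depends on the install layout

int main() {
  // Hypothetical model files; any exported Paddle model works the same way.
  paddle_infer::Config config("./model.pdmodel", "./model.pdiparams");
  config.EnableONNXRuntime();  // assumed switch that selects the ORT back-end
  auto predictor = paddle_infer::CreatePredictor(config);

  // Bind input memory directly; on the ORT path this ends up in
  // Ort::IoBinding::BindInput instead of an intermediate LoDTensor.
  auto input = predictor->GetInputHandle(predictor->GetInputNames()[0]);
  std::vector<int> shape{1, 3, 224, 224};
  std::vector<float> data(1 * 3 * 224 * 224, 0.f);
  input->Reshape(shape);
  input->CopyFromCpu(data.data());

  predictor->Run();  // ZeroCopyRun() underneath: one session_.Run over the prepared binding

  // Read the bound output back to host memory (ORTCopyToCpu in the diff).
  auto output = predictor->GetOutputHandle(predictor->GetOutputNames()[0]);
  int64_t numel = 1;
  for (int d : output->shape()) numel *= d;
  std::vector<float> result(numel);
  output->CopyToCpu(result.data());
  return 0;
}

On the ORT path Reshape() must precede CopyFromCpu(), since the cached shape_ is what ORTCopyFromCpu uses to size the bound tensor.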
@@ -61,6 +61,7 @@ set(PADDLE2ONNX_OPTIONAL_ARGS
     -DONNX_CUSTOM_PROTOC_PATH=${PROTOC_BIN_PATH}
     -DWITH_STATIC=OFF
     -DCMAKE_INSTALL_PREFIX=${PADDLE2ONNX_INSTALL_DIR}
+    -DCMAKE_INSTALL_LIBDIR=${PADDLE2ONNX_INSTALL_DIR}/${LIBDIR}
     -DCMAKE_POSITION_INDEPENDENT_CODE=ON
     -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
     ${EXTERNAL_OPTIONAL_ARGS}
...
@@ -14,7 +14,11 @@
 #
 cc_library(reset_tensor_array SRCS reset_tensor_array.cc DEPS lod_tensor scope)
-cc_library(zero_copy_tensor SRCS zero_copy_tensor.cc DEPS scope lod_tensor enforce)
+if (WITH_ONNXRUNTIME)
+  cc_library(zero_copy_tensor SRCS zero_copy_tensor.cc DEPS scope lod_tensor enforce onnxruntime)
+else (WITH_ONNXRUNTIME)
+  cc_library(zero_copy_tensor SRCS zero_copy_tensor.cc DEPS scope lod_tensor enforce)
+endif (WITH_ONNXRUNTIME)
 cc_library(zero_copy_tensor_dummy SRCS zero_copy_tensor_dummy.cc)
 cc_test(zero_copy_tensor_test SRCS zero_copy_tensor_test.cc DEPS paddle_inference_api)
@@ -22,12 +22,22 @@
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/float16.h"
 #include "paddle/phi/core/allocator.h"
+#ifdef PADDLE_WITH_ONNXRUNTIME
+#include "paddle/fluid/inference/api/onnxruntime_predictor.h"
+#endif
 
 namespace paddle_infer {
 
 using float16 = paddle::platform::float16;
 
 void Tensor::Reshape(const std::vector<int> &shape) {
+#ifdef PADDLE_WITH_ONNXRUNTIME
+  if (is_ort_tensor_) {
+    shape_.assign(shape.begin(), shape.end());
+    return;
+  }
+#endif
   PADDLE_ENFORCE_EQ(
       name_.empty(), false,
       paddle::platform::errors::PreconditionNotMet(
@@ -123,6 +133,11 @@ T *Tensor::data(PlaceType *place, int *size) const {
 }
 
 DataType Tensor::type() const {
+#ifdef PADDLE_WITH_ONNXRUNTIME
+  if (is_ort_tensor_) {
+    return dtype_;
+  }
+#endif
   EAGER_GET_TENSOR(paddle::framework::LoDTensor);
   auto type = paddle::framework::TransToProtoVarType(tensor->dtype());
   if (type == paddle::framework::proto::VarType::FP32) {
@@ -145,6 +160,13 @@ PlaceType Tensor::place() const { return place_; }
 
 template <typename T>
 void Tensor::CopyFromCpu(const T *data) {
+#ifdef PADDLE_WITH_ONNXRUNTIME
+  if (is_ort_tensor_) {
+    ORTCopyFromCpu<T>(data);
+    return;
+  }
+#endif
   EAGER_GET_TENSOR(paddle::framework::LoDTensor);
   PADDLE_ENFORCE_GE(tensor->numel(), 0,
                     paddle::platform::errors::PreconditionNotMet(
@@ -382,6 +404,13 @@ void Tensor::CopyToCpuImpl(T *data, void *exec_stream, CallbackFunc cb,
 
 template <typename T>
 void Tensor::CopyToCpu(T *data) const {
+#ifdef PADDLE_WITH_ONNXRUNTIME
+  if (is_ort_tensor_) {
+    ORTCopyToCpu<T>(data);
+    return;
+  }
+#endif
   CopyToCpuImpl<T>(data, nullptr, nullptr, nullptr);
 }
@@ -489,12 +518,7 @@ template PD_INFER_DECL uint8_t *Tensor::mutable_data<uint8_t>(PlaceType place);
 template PD_INFER_DECL int8_t *Tensor::mutable_data<int8_t>(PlaceType place);
 template PD_INFER_DECL float16 *Tensor::mutable_data<float16>(PlaceType place);
 
-Tensor::Tensor(void *scope) : scope_{scope} {
-  PADDLE_ENFORCE_NOT_NULL(scope_,
-                          paddle::platform::errors::PreconditionNotMet(
-                              "The `scope` can not be nullptr. It should be "
-                              "set to the pointer of scope."));
-}
+Tensor::Tensor(void *scope) : scope_{scope} {}
 
 template <typename T>
 void *Tensor::FindTensor() const {
@@ -513,6 +537,26 @@ void *Tensor::FindTensor() const {
 }
 
 std::vector<int> Tensor::shape() const {
+#ifdef PADDLE_WITH_ONNXRUNTIME
+  if (is_ort_tensor_) {
+    std::vector<int> shape;
+    // input handle
+    if (idx_ < 0) {
+      shape.assign(shape_.begin(), shape_.end());
+    } else {  // output handle
+      auto binding = binding_.lock();
+      PADDLE_ENFORCE_NOT_NULL(binding,
+                              paddle::platform::errors::PreconditionNotMet(
+                                  "output tensor [%s] no binding ptr", name_));
+      std::vector<Ort::Value> outputs = binding->GetOutputValues();
+      Ort::Value &value = outputs[idx_];
+      auto info = value.GetTensorTypeAndShapeInfo();
+      auto ort_shape = info.GetShape();
+      shape.assign(ort_shape.begin(), ort_shape.end());
+    }
+    return shape;
+  }
+#endif
   EAGER_GET_TENSOR(paddle::framework::LoDTensor);
   PADDLE_ENFORCE_NOT_NULL(
       tensor_, paddle::platform::errors::PreconditionNotMet(
@@ -573,4 +617,99 @@ void Tensor::SetPlace(PlaceType place, int device) {
   device_ = device;
 }
+
+#ifdef PADDLE_WITH_ONNXRUNTIME
+void Tensor::SetOrtMark(bool is_ort_tensor) { is_ort_tensor_ = is_ort_tensor; }
+
+void Tensor::SetOrtBinding(const std::shared_ptr<Ort::IoBinding> binding) {
+  binding_ = binding;
+}
+
+Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info, float *data,
+                       size_t size, const int64_t *shape, size_t shape_len) {
+  return Ort::Value::CreateTensor<float>(memory_info, data, size, shape,
+                                         shape_len);
+}
+
+Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info, int64_t *data,
+                       size_t size, const int64_t *shape, size_t shape_len) {
+  return Ort::Value::CreateTensor<int64_t>(memory_info, data, size, shape,
+                                           shape_len);
+}
+
+Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info, int32_t *data,
+                       size_t size, const int64_t *shape, size_t shape_len) {
+  return Ort::Value::CreateTensor<int32_t>(memory_info, data, size, shape,
+                                           shape_len);
+}
+
+Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info, uint8_t *data,
+                       size_t size, const int64_t *shape, size_t shape_len) {
+  return Ort::Value::CreateTensor<uint8_t>(memory_info, data, size, shape,
+                                           shape_len);
+}
+
+Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info, int8_t *data,
+                       size_t size, const int64_t *shape, size_t shape_len) {
+  return Ort::Value::CreateTensor<int8_t>(memory_info, data, size, shape,
+                                          shape_len);
+}
+
+Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info, float16 *data,
+                       size_t size, const int64_t *shape, size_t shape_len) {
+  return Ort::Value::CreateTensor(memory_info, static_cast<void *>(data),
+                                  size * sizeof(float16), shape, shape_len,
+                                  ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
+}
+
+template <typename T>
+void Tensor::ORTCopyFromCpu(const T *data) {
+  auto binding = binding_.lock();
+  PADDLE_ENFORCE_NOT_NULL(binding,
+                          paddle::platform::errors::PreconditionNotMet(
+                              "input tensor [%s] no binding ptr", name_));
+  const char *device_name = place_ == PlaceType::kCPU ? "Cpu" : "Cuda";
+  Ort::MemoryInfo memory_info(device_name, OrtDeviceAllocator, device_,
+                              OrtMemTypeDefault);
+  size_t size = std::accumulate(begin(shape_), end(shape_), 1UL,
+                                std::multiplies<size_t>());
+  auto ort_value = GetOrtVaule(memory_info, const_cast<T *>(data), size,
+                               shape_.data(), shape_.size());
+  binding->BindInput(name_.c_str(), ort_value);
+}
+
+template <typename T>
+void Tensor::ORTCopyToCpu(T *data) const {
+  auto binding = binding_.lock();
+  PADDLE_ENFORCE_NOT_NULL(binding,
+                          paddle::platform::errors::PreconditionNotMet(
+                              "output tensor [%s] no binding ptr", name_));
+  std::vector<Ort::Value> outputs = binding->GetOutputValues();
+  Ort::Value &value = outputs[idx_];
+  auto info = value.GetTensorTypeAndShapeInfo();
+  size_t size = info.GetElementCount() * sizeof(T);
+
+  if (place_ == PlaceType::kCPU) {
+    std::memcpy(static_cast<void *>(data), value.GetTensorData<void *>(), size);
+  } else {
+    paddle::memory::Copy(paddle::platform::CPUPlace(),
+                         static_cast<void *>(data),
+                         paddle::platform::CUDAPlace(device_),
+                         value.GetTensorData<void>(), size, nullptr);
+  }
+}
+
+template void Tensor::ORTCopyFromCpu<float>(const float *data);
+template void Tensor::ORTCopyFromCpu<int64_t>(const int64_t *data);
+template void Tensor::ORTCopyFromCpu<int32_t>(const int32_t *data);
+template void Tensor::ORTCopyFromCpu<uint8_t>(const uint8_t *data);
+template void Tensor::ORTCopyFromCpu<int8_t>(const int8_t *data);
+template void Tensor::ORTCopyFromCpu<float16>(const float16 *data);
+
+template void Tensor::ORTCopyToCpu<float>(float *data) const;
+template void Tensor::ORTCopyToCpu<int32_t>(int32_t *data) const;
+template void Tensor::ORTCopyToCpu<uint8_t>(uint8_t *data) const;
+template void Tensor::ORTCopyToCpu<int8_t>(int8_t *data) const;
+template void Tensor::ORTCopyToCpu<float16>(float16 *data) const;
+#endif
+
 }  // namespace paddle_infer
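The helpers above wrap the caller's buffer in an Ort::Value for each supported dtype and bind it by name; on the output side the data is read back from the binding's output values. To show that pattern in isolation, here is a minimal stand-alone ONNX Runtime C++ sketch; the model file name and the "x"/"y" tensor names are placeholders, not taken from this PR:

#include <cstring>
#include <vector>

#include "onnxruntime_cxx_api.h"

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "demo");
  Ort::SessionOptions opts;
  Ort::Session session(env, "model.onnx", opts);  // hypothetical model file
  Ort::IoBinding binding(session);

  // Wrap an existing host buffer as an Ort::Value -- no copy happens here,
  // which is what Tensor::ORTCopyFromCpu relies on.
  std::vector<float> input(1 * 3 * 224 * 224, 0.f);
  std::vector<int64_t> shape{1, 3, 224, 224};
  Ort::MemoryInfo cpu_info =
      Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  Ort::Value value = Ort::Value::CreateTensor<float>(
      cpu_info, input.data(), input.size(), shape.data(), shape.size());
  binding.BindInput("x", value);      // "x" is a placeholder input name
  binding.BindOutput("y", cpu_info);  // let ORT allocate the output on CPU

  session.Run({}, binding);

  // Read the result back, mirroring Tensor::ORTCopyToCpu.
  std::vector<Ort::Value> outputs = binding.GetOutputValues();
  auto info = outputs[0].GetTensorTypeAndShapeInfo();
  std::vector<float> result(info.GetElementCount());
  std::memcpy(result.data(), outputs[0].GetTensorData<float>(),
              result.size() * sizeof(float));
  return 0;
}

Because the output is bound to a CPU Ort::MemoryInfo up front (as Init() now does once), GetOutputValues() already holds host-accessible memory and the copy-out is a single memcpy.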
@@ -25,11 +25,7 @@
 #include <vector>
 
 #include "paddle/fluid//platform/device/gpu/gpu_types.h"
-#include "paddle/fluid/framework/feed_fetch_method.h"
-#include "paddle/fluid/framework/feed_fetch_type.h"
 #include "paddle/fluid/framework/scope.h"
-#include "paddle/fluid/framework/var_type_traits.h"
-#include "paddle/fluid/framework/variable_helper.h"
 #include "paddle/fluid/framework/version.h"
 #include "paddle/fluid/inference/analysis/helper.h"
 #include "paddle/fluid/inference/analysis/passes/memory_optimize_pass.h"
@@ -45,24 +41,23 @@
 namespace paddle {
 
-framework::proto::VarType::Type ConvertONNXType(
-    ONNXTensorElementDataType type) {
+paddle_infer::DataType ConvertONNXType(ONNXTensorElementDataType type) {
   switch (type) {
     case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
-      return framework::proto::VarType::FP32;
-    // case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
-    //   return DataType::FP16;
+      return paddle_infer::DataType::FLOAT32;
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
+      return paddle_infer::DataType::FLOAT16;
     case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
-      return framework::proto::VarType::INT8;
+      return paddle_infer::DataType::INT8;
     case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
-      return framework::proto::VarType::INT32;
+      return paddle_infer::DataType::INT32;
     case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
-      return framework::proto::VarType::INT64;
+      return paddle_infer::DataType::INT64;
     case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
-      return framework::proto::VarType::UINT8;
+      return paddle_infer::DataType::UINT8;
     default:
       LOG(ERROR) << "unsupported ONNX Tensor Type: " << static_cast<int>(type);
-      return framework::proto::VarType::FP32;
+      return paddle_infer::DataType::FLOAT32;
   }
 }
@@ -87,13 +82,12 @@ bool ONNXRuntimePredictor::Init() {
   VLOG(3) << "ONNXRuntime Predictor::init()";
 
   // Now ONNXRuntime only suuport CPU
+  const char *device_name = config_.use_gpu() ? "Cuda" : "Cpu";
   if (config_.use_gpu()) {
     place_ = paddle::platform::CUDAPlace(config_.gpu_device_id());
   } else {
     place_ = paddle::platform::CPUPlace();
   }
-  scope_.reset(new paddle::framework::Scope());
-  sub_scope_ = &scope_->NewScope();
 
   std::string onnx_proto;
   paddle2onnx::Export(config_.prog_file(), config_.params_file(), &onnx_proto,
@@ -125,13 +119,12 @@ bool ONNXRuntimePredictor::Init() {
                "generated.";
   }
   session_ = {env_, onnx_proto.data(), onnx_proto.size(), session_options};
+  binding_ = std::make_shared<Ort::IoBinding>(session_);
 
-  auto memory_info =
-      Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
+  Ort::MemoryInfo memory_info(device_name, OrtDeviceAllocator,
+                              place_.GetDeviceId(), OrtMemTypeDefault);
   Ort::Allocator allocator(session_, memory_info);
 
-  framework::proto::VarType::Type proto_type =
-      framework::proto::VarType::LOD_TENSOR;
   size_t n_inputs = session_.GetInputCount();
   for (size_t i = 0; i < n_inputs; ++i) {
     auto input_name = session_.GetInputName(i, allocator);
@@ -141,8 +134,6 @@ bool ONNXRuntimePredictor::Init() {
     ONNXTensorElementDataType data_type =
         type_info.GetTensorTypeAndShapeInfo().GetElementType();
     input_desc_.emplace_back(ONNXDesc{input_name, shape, data_type});
-    auto *ptr = scope_->Var(input_name);
-    framework::InitializeVariable(ptr, proto_type);
     allocator.Free(input_name);
   }
@@ -155,11 +146,13 @@ bool ONNXRuntimePredictor::Init() {
     ONNXTensorElementDataType data_type =
         type_info.GetTensorTypeAndShapeInfo().GetElementType();
     output_desc_.emplace_back(ONNXDesc{output_name, shape, data_type});
-    auto *ptr = scope_->Var(output_name);
-    framework::InitializeVariable(ptr, proto_type);
+
+    Ort::MemoryInfo out_memory_info(device_name, OrtDeviceAllocator,
+                                    place_.GetDeviceId(), OrtMemTypeDefault);
+    binding_->BindOutput(output_name, out_memory_info);
+
     allocator.Free(output_name);
   }
   return true;
 }
@@ -216,15 +209,26 @@ std::vector<std::string> ONNXRuntimePredictor::GetOutputNames() {
   return output_names;
 }
 
+bool ONNXRuntimePredictor::FindONNXDesc(const std::string &name,
+                                        bool is_input) {
+  if (is_input) {
+    for (auto i : input_desc_)
+      if (i.name == name) return true;
+  } else {
+    for (auto i : output_desc_)
+      if (i.name == name) return true;
+  }
+  return false;
+}
+
 std::unique_ptr<ZeroCopyTensor> ONNXRuntimePredictor::GetInputTensor(
     const std::string &name) {
-  PADDLE_ENFORCE_NOT_NULL(scope_->FindVar(name),
+  PADDLE_ENFORCE_EQ(FindONNXDesc(name, true), true,
                     platform::errors::PreconditionNotMet(
                         "The in variable named %s is not found in the "
-                        "scope of the ONNXPredictor.",
+                        "ONNXPredictor.",
                         name));
-  std::unique_ptr<ZeroCopyTensor> res(
-      new ZeroCopyTensor(static_cast<void *>(scope_.get())));
+  std::unique_ptr<ZeroCopyTensor> res(new ZeroCopyTensor(nullptr));
   res->input_or_output_ = true;
   res->SetName(name);
   if (platform::is_cpu_place(place_)) {
@@ -233,18 +237,19 @@ std::unique_ptr<ZeroCopyTensor> ONNXRuntimePredictor::GetInputTensor(
     auto gpu_place = place_;
     res->SetPlace(PaddlePlace::kGPU, gpu_place.GetDeviceId());
   }
+  res->SetOrtMark(true);
+  res->SetOrtBinding(binding_);
   return res;
 }
 
 std::unique_ptr<ZeroCopyTensor> ONNXRuntimePredictor::GetOutputTensor(
     const std::string &name) {
-  PADDLE_ENFORCE_NOT_NULL(scope_->FindVar(name),
+  PADDLE_ENFORCE_EQ(FindONNXDesc(name, false), true,
                     platform::errors::PreconditionNotMet(
                         "The out variable named %s is not found in the "
-                        "scope of the ONNXPredictor.",
+                        "ONNXPredictor.",
                         name));
-  std::unique_ptr<ZeroCopyTensor> res(
-      new ZeroCopyTensor(static_cast<void *>(scope_.get())));
+  std::unique_ptr<ZeroCopyTensor> res(new ZeroCopyTensor(nullptr));
   res->input_or_output_ = false;
   res->SetName(name);
   if (platform::is_cpu_place(place_)) {
@@ -253,46 +258,18 @@ std::unique_ptr<ZeroCopyTensor> ONNXRuntimePredictor::GetOutputTensor(
     auto gpu_place = place_;
     res->SetPlace(PaddlePlace::kGPU, gpu_place.GetDeviceId());
   }
+  res->SetOrtMark(true);
+  res->SetOrtBinding(binding_);
+
+  int size = output_desc_.size();
+  for (int i = 0; i < size; ++i)
+    if (output_desc_[i].name == name) {
+      res->idx_ = i;
+      res->dtype_ = ConvertONNXType(output_desc_[i].dtype);
+      break;
+    }
+
   return res;
 }
 
-Ort::Value ONNXRuntimePredictor::GetOrtValue(const ONNXDesc &desc,
-                                             const char *device_name) {
-  Ort::MemoryInfo memory_info(device_name, OrtDeviceAllocator,
-                              place_.GetDeviceId(), OrtMemTypeDefault);
-  auto *var = scope_->FindVar(desc.name);
-  auto *tensor = var->GetMutable<framework::LoDTensor>();
-  size_t size =
-      tensor->numel() *
-      framework::SizeOfType(framework::TransToProtoVarType(tensor->dtype()));
-  std::vector<int64_t> shape = phi::vectorize<int64_t>(tensor->dims());
-  return Ort::Value::CreateTensor(memory_info,
-                                  static_cast<void *>(tensor->data()), size,
-                                  shape.data(), shape.size(), desc.dtype);
-}
-
-void ONNXRuntimePredictor::AsTensor(const Ort::Value &value,
-                                    const ONNXDesc &desc) {
-  auto info = value.GetTensorTypeAndShapeInfo();
-
-  auto *var = scope_->FindVar(desc.name);
-  auto *tensor = var->GetMutable<framework::LoDTensor>();
-  tensor->Resize(phi::make_ddim(info.GetShape()));
-  auto dtype = ConvertONNXType(info.GetElementType());
-  auto *ptr = tensor->mutable_data(place_, dtype);
-
-  if (platform::is_cpu_place(place_)) {
-    std::memcpy(ptr, const_cast<void *>(value.GetTensorData<void>()),
-                tensor->numel() * framework::SizeOfType(dtype));
-  } else {
-    auto src_place = place_;
-    auto dst_place = place_;
-    memory::Copy(dst_place, ptr, src_place,
-                 const_cast<void *>(value.GetTensorData<void>()),
-                 tensor->numel() * framework::SizeOfType(dtype));
-  }
-}
-
 bool ONNXRuntimePredictor::Run(const std::vector<PaddleTensor> &inputs,
                                std::vector<PaddleTensor> *output_data,
                                int batch_size) {
@@ -302,31 +279,7 @@ bool ONNXRuntimePredictor::Run(const std::vector<PaddleTensor> &inputs,
 
 bool ONNXRuntimePredictor::ZeroCopyRun() {
   try {
-    Ort::IoBinding binding(session_);
-    std::vector<Ort::Value> inputs;
-    std::vector<Ort::Value> outputs;
-    Ort::RunOptions options;
-
-    inputs.reserve(input_desc_.size());
-    const char *device_name = config_.use_gpu() ? "Cuda" : "Cpu";
-    for (auto desc : input_desc_) {
-      inputs.push_back(GetOrtValue(desc, device_name));
-      binding.BindInput(desc.name.c_str(), inputs.back());
-    }
-
-    // TODO(heliqi): Optimization —— move to Init()
-    for (auto desc : output_desc_) {
-      Ort::MemoryInfo memory_info(device_name, OrtDeviceAllocator,
-                                  place_.GetDeviceId(), OrtMemTypeDefault);
-      binding.BindOutput(desc.name.c_str(), memory_info);
-    }
-
-    session_.Run({}, binding);
-
-    outputs = binding.GetOutputValues();
-    for (size_t i = 0; i < output_desc_.size(); ++i) {
-      AsTensor(outputs[i], output_desc_[i]);
-    }
+    session_.Run({}, *(binding_.get()));
   } catch (const std::exception &e) {
     LOG(ERROR) << e.what();
     return false;
@@ -345,9 +298,9 @@ uint64_t ONNXRuntimePredictor::TryShrinkMemory() {
 }
 
 ONNXRuntimePredictor::~ONNXRuntimePredictor() {
-  if (sub_scope_) {
-    scope_->DeleteScope(sub_scope_);
-  }
+  binding_->ClearBoundInputs();
+  binding_->ClearBoundOutputs();
+
   memory::Release(place_);
 }
...
@@ -94,9 +94,8 @@ class ONNXRuntimePredictor : public PaddlePredictor {
   /// \param[in] AnalysisConfig config
   ///
   explicit ONNXRuntimePredictor(const AnalysisConfig &config)
-      : config_(config) {
+      : config_(config), env_(ORT_LOGGING_LEVEL_WARNING, "onnx") {
     predictor_id_ = inference::GetUniqueId();
-    env_ = Ort::Env(ORT_LOGGING_LEVEL_INFO, "onnx");
   }
   ///
   /// \brief Destroy the ONNXRuntime Predictor object
@@ -177,30 +176,17 @@ class ONNXRuntimePredictor : public PaddlePredictor {
   ///
   std::unique_ptr<PaddlePredictor> Clone() override;
 
-  std::shared_ptr<framework::Scope> scope_;
-
 private:
   ///
-  /// \brief get the Ort Value(input Tensor).
-  ///
-  /// \param[in] desc ONNXDesce(name、shape、dtype)
-  ///
-  /// \param[in] device_name "cpu" or "gpu" of device
-  ///
-  /// \return get a Ort::Value
-  ///
-  Ort::Value GetOrtValue(const ONNXDesc &desc, const char *device_name);
-  ///
-  /// \brief Ort::Value to Paddle::ZeroCopyTensor.
+  /// \brief Whether to find in/out by name.
   ///
-  /// \param[in] value Ort::Value(output Tensor)
+  /// \param[in] name input or output name
   ///
-  /// \param[in] desc a ONNXDesce(name、shape、dtype)
+  /// \param[in] is_input input(true) or output(false)
   ///
-  /// \return get a Ort::Value
+  /// \return Whether to find by name
   ///
-  void AsTensor(const Ort::Value &value, const ONNXDesc &desc);
+  bool FindONNXDesc(const std::string &name, bool is_input);
 
 private:
   AnalysisConfig config_;
@@ -208,9 +194,9 @@ class ONNXRuntimePredictor : public PaddlePredictor {
   // ONNXRuntime
   Ort::Env env_;
   Ort::Session session_{nullptr};
+  std::shared_ptr<Ort::IoBinding> binding_;
 
   platform::Place place_;
-  framework::Scope *sub_scope_{nullptr};
   std::vector<ONNXDesc> input_desc_;
   std::vector<ONNXDesc> output_desc_;
   int predictor_id_;
...
@@ -18,6 +18,11 @@
 #include "paddle_infer_declare.h"  // NOLINT
 
+#ifdef PADDLE_WITH_ONNXRUNTIME
+#include "onnxruntime_c_api.h"    // NOLINT
+#include "onnxruntime_cxx_api.h"  // NOLINT
+#endif
+
 namespace paddle_infer {
 
 /// \brief Experimental.
@@ -175,6 +180,23 @@ class PD_INFER_DECL Tensor {
   PlaceType place_;
   int device_;
 
+#ifdef PADDLE_WITH_ONNXRUNTIME
+  bool is_ort_tensor_{false};
+  std::vector<int64_t> shape_;
+  std::weak_ptr<Ort::IoBinding> binding_;
+  int idx_{-1};
+
+  void SetOrtMark(bool is_ort_tensor);
+
+  void SetOrtBinding(const std::shared_ptr<Ort::IoBinding> binding);
+
+  template <typename T>
+  void ORTCopyFromCpu(const T* data);
+
+  template <typename T>
+  void ORTCopyToCpu(T* data) const;
+#endif
+
   friend class paddle_infer::contrib::TensorUtils;
 #if defined(PADDLE_WITH_TESTING) && defined(PADDLE_WITH_INFERENCE_API_TEST)
   friend class paddle_infer::InferApiTesterUtils;
...
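Note that the Tensor keeps only a std::weak_ptr to the predictor-owned Ort::IoBinding, so handing tensor handles out to callers never extends the binding's lifetime; each ORT copy first lock()s the pointer and raises a PreconditionNotMet error if the predictor is already gone. A small, hypothetical stand-in (not Paddle code) that illustrates the same guard:

#include <iostream>
#include <memory>

// Stand-in for the Ort::IoBinding owned by the predictor.
struct Binding {};

// Stand-in for the ORT branch of paddle_infer::Tensor.
struct OrtTensorHandle {
  std::weak_ptr<Binding> binding_;

  bool CopyFromCpu(const float* /*data*/) {
    auto binding = binding_.lock();  // promote to shared_ptr if still alive
    if (!binding) {
      std::cerr << "no binding ptr: predictor was destroyed\n";
      return false;
    }
    // ... BindInput on *binding ...
    return true;
  }
};

int main() {
  auto owner = std::make_shared<Binding>();  // lives inside the predictor
  OrtTensorHandle t{owner};
  t.CopyFromCpu(nullptr);  // ok: binding still alive
  owner.reset();           // predictor destroyed
  t.CopyFromCpu(nullptr);  // safely reports the dangling binding
  return 0;
}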