Unverified commit fcbb7440, authored by H heliqi, committed by GitHub

Optimize CopyFromCpu and CopyToCpu of the ONNX Runtime backend (#40561)

* add onnxruntime predictor

* Add code comments

* support link paddle2onnx onnxruntime

* support onnxruntime with python

* support onnxruntime with python

* support onnxruntime with windows

* paddle2onnx compile with windows

* support windows compile

* support windows compile with onnxruntime

* support windows compile with paddle2onnx

* support mac compile

* compile with mac

* compile with mac

* add code comments

* fix remind word

* code optimization

* add test case

* add test case

* add inference demo_ci test case

* fix compile paddle2onnx with no python

* add inference demo_ci test case

* add inference demo_ci test case

* add inference infer_ut test case

* support c go api and test cases

* add coverage test case

* add coverage test case

* add capi test case

* add capi test case

* fix onnxruntime copyfromcpu and copytocpu

* fix goapi

* modify code
Parent da558f0e
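In short, this commit makes the ONNX Runtime backend bind user buffers to the session through a shared Ort::IoBinding instead of staging them in a framework::Scope/LoDTensor, so CopyFromCpu and CopyToCpu avoid an extra copy. The sketch below shows the user-facing path that benefits; it is a minimal example assuming the public paddle_infer C++ API, and the model file names and the EnableONNXRuntime() switch are illustrative, not taken from this diff.

// User-side sketch (assumed paddle_infer API; file names are placeholders).
#include <vector>
#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config("model.pdmodel", "model.pdiparams");
  config.EnableONNXRuntime();  // route execution through the ONNX Runtime backend
  auto predictor = paddle_infer::CreatePredictor(config);

  std::vector<float> input(1 * 3 * 224 * 224, 0.f);
  auto in = predictor->GetInputHandle(predictor->GetInputNames()[0]);
  in->Reshape({1, 3, 224, 224});
  in->CopyFromCpu(input.data());  // with this commit: binds the buffer via Ort::IoBinding

  predictor->Run();

  auto out = predictor->GetOutputHandle(predictor->GetOutputNames()[0]);
  auto shape = out->shape();
  int numel = 1;
  for (int d : shape) numel *= d;
  std::vector<float> result(numel);
  out->CopyToCpu(result.data());  // reads straight from the bound Ort::Value
  return 0;
}

Before this change the same calls went through an intermediate LoDTensor held in a scope; after it they bind or read the Ort::Value directly.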
......@@ -61,6 +61,7 @@ set(PADDLE2ONNX_OPTIONAL_ARGS
-DONNX_CUSTOM_PROTOC_PATH=${PROTOC_BIN_PATH}
-DWITH_STATIC=OFF
-DCMAKE_INSTALL_PREFIX=${PADDLE2ONNX_INSTALL_DIR}
-DCMAKE_INSTALL_LIBDIR=${PADDLE2ONNX_INSTALL_DIR}/${LIBDIR}
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
${EXTERNAL_OPTIONAL_ARGS}
......
......@@ -14,7 +14,11 @@
#
cc_library(reset_tensor_array SRCS reset_tensor_array.cc DEPS lod_tensor scope)
cc_library(zero_copy_tensor SRCS zero_copy_tensor.cc DEPS scope lod_tensor enforce)
if (WITH_ONNXRUNTIME)
cc_library(zero_copy_tensor SRCS zero_copy_tensor.cc DEPS scope lod_tensor enforce onnxruntime)
else (WITH_ONNXRUNTIME)
cc_library(zero_copy_tensor SRCS zero_copy_tensor.cc DEPS scope lod_tensor enforce)
endif (WITH_ONNXRUNTIME)
cc_library(zero_copy_tensor_dummy SRCS zero_copy_tensor_dummy.cc)
cc_test(zero_copy_tensor_test SRCS zero_copy_tensor_test.cc DEPS paddle_inference_api)
......@@ -22,12 +22,22 @@
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/core/allocator.h"
#ifdef PADDLE_WITH_ONNXRUNTIME
#include "paddle/fluid/inference/api/onnxruntime_predictor.h"
#endif
namespace paddle_infer {
using float16 = paddle::platform::float16;
void Tensor::Reshape(const std::vector<int> &shape) {
#ifdef PADDLE_WITH_ONNXRUNTIME
if (is_ort_tensor_) {
shape_.assign(shape.begin(), shape.end());
return;
}
#endif
PADDLE_ENFORCE_EQ(
name_.empty(), false,
paddle::platform::errors::PreconditionNotMet(
......@@ -123,6 +133,11 @@ T *Tensor::data(PlaceType *place, int *size) const {
}
DataType Tensor::type() const {
#ifdef PADDLE_WITH_ONNXRUNTIME
if (is_ort_tensor_) {
return dtype_;
}
#endif
EAGER_GET_TENSOR(paddle::framework::LoDTensor);
auto type = paddle::framework::TransToProtoVarType(tensor->dtype());
if (type == paddle::framework::proto::VarType::FP32) {
......@@ -145,6 +160,13 @@ PlaceType Tensor::place() const { return place_; }
template <typename T>
void Tensor::CopyFromCpu(const T *data) {
#ifdef PADDLE_WITH_ONNXRUNTIME
if (is_ort_tensor_) {
ORTCopyFromCpu<T>(data);
return;
}
#endif
EAGER_GET_TENSOR(paddle::framework::LoDTensor);
PADDLE_ENFORCE_GE(tensor->numel(), 0,
paddle::platform::errors::PreconditionNotMet(
......@@ -382,6 +404,13 @@ void Tensor::CopyToCpuImpl(T *data, void *exec_stream, CallbackFunc cb,
template <typename T>
void Tensor::CopyToCpu(T *data) const {
#ifdef PADDLE_WITH_ONNXRUNTIME
if (is_ort_tensor_) {
ORTCopyToCpu<T>(data);
return;
}
#endif
CopyToCpuImpl<T>(data, nullptr, nullptr, nullptr);
}
......@@ -489,12 +518,7 @@ template PD_INFER_DECL uint8_t *Tensor::mutable_data<uint8_t>(PlaceType place);
template PD_INFER_DECL int8_t *Tensor::mutable_data<int8_t>(PlaceType place);
template PD_INFER_DECL float16 *Tensor::mutable_data<float16>(PlaceType place);
Tensor::Tensor(void *scope) : scope_{scope} {
PADDLE_ENFORCE_NOT_NULL(scope_,
paddle::platform::errors::PreconditionNotMet(
"The `scope` can not be nullptr. It should be "
"set to the pointer of scope."));
}
Tensor::Tensor(void *scope) : scope_{scope} {}
template <typename T>
void *Tensor::FindTensor() const {
......@@ -513,6 +537,26 @@ void *Tensor::FindTensor() const {
}
std::vector<int> Tensor::shape() const {
#ifdef PADDLE_WITH_ONNXRUNTIME
if (is_ort_tensor_) {
std::vector<int> shape;
// input handle
if (idx_ < 0) {
shape.assign(shape_.begin(), shape_.end());
} else { // output handle
auto binding = binding_.lock();
PADDLE_ENFORCE_NOT_NULL(binding,
paddle::platform::errors::PreconditionNotMet(
"output tensor [%s] no binding ptr", name_));
std::vector<Ort::Value> outputs = binding->GetOutputValues();
Ort::Value &value = outputs[idx_];
auto info = value.GetTensorTypeAndShapeInfo();
auto ort_shape = info.GetShape();
shape.assign(ort_shape.begin(), ort_shape.end());
}
return shape;
}
#endif
EAGER_GET_TENSOR(paddle::framework::LoDTensor);
PADDLE_ENFORCE_NOT_NULL(
tensor_, paddle::platform::errors::PreconditionNotMet(
......@@ -573,4 +617,99 @@ void Tensor::SetPlace(PlaceType place, int device) {
device_ = device;
}
#ifdef PADDLE_WITH_ONNXRUNTIME
void Tensor::SetOrtMark(bool is_ort_tensor) { is_ort_tensor_ = is_ort_tensor; }
void Tensor::SetOrtBinding(const std::shared_ptr<Ort::IoBinding> binding) {
binding_ = binding;
}
Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info, float *data,
size_t size, const int64_t *shape, size_t shape_len) {
return Ort::Value::CreateTensor<float>(memory_info, data, size, shape,
shape_len);
}
Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info, int64_t *data,
size_t size, const int64_t *shape, size_t shape_len) {
return Ort::Value::CreateTensor<int64_t>(memory_info, data, size, shape,
shape_len);
}
Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info, int32_t *data,
size_t size, const int64_t *shape, size_t shape_len) {
return Ort::Value::CreateTensor<int32_t>(memory_info, data, size, shape,
shape_len);
}
Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info, uint8_t *data,
size_t size, const int64_t *shape, size_t shape_len) {
return Ort::Value::CreateTensor<uint8_t>(memory_info, data, size, shape,
shape_len);
}
Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info, int8_t *data,
size_t size, const int64_t *shape, size_t shape_len) {
return Ort::Value::CreateTensor<int8_t>(memory_info, data, size, shape,
shape_len);
}
Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info, float16 *data,
size_t size, const int64_t *shape, size_t shape_len) {
return Ort::Value::CreateTensor(memory_info, static_cast<void *>(data),
size * sizeof(float16), shape, shape_len,
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
}
template <typename T>
void Tensor::ORTCopyFromCpu(const T *data) {
auto binding = binding_.lock();
PADDLE_ENFORCE_NOT_NULL(binding,
paddle::platform::errors::PreconditionNotMet(
"input tensor [%s] no binding ptr", name_));
const char *device_name = place_ == PlaceType::kCPU ? "Cpu" : "Cuda";
Ort::MemoryInfo memory_info(device_name, OrtDeviceAllocator, device_,
OrtMemTypeDefault);
size_t size = std::accumulate(begin(shape_), end(shape_), 1UL,
std::multiplies<size_t>());
auto ort_value = GetOrtValue(memory_info, const_cast<T *>(data), size,
shape_.data(), shape_.size());
binding->BindInput(name_.c_str(), ort_value);
}
template <typename T>
void Tensor::ORTCopyToCpu(T *data) const {
auto binding = binding_.lock();
PADDLE_ENFORCE_NOT_NULL(binding,
paddle::platform::errors::PreconditionNotMet(
"output tensor [%s] no binding ptr", name_));
std::vector<Ort::Value> outputs = binding->GetOutputValues();
Ort::Value &value = outputs[idx_];
auto info = value.GetTensorTypeAndShapeInfo();
size_t size = info.GetElementCount() * sizeof(T);
if (place_ == PlaceType::kCPU) {
std::memcpy(static_cast<void *>(data), value.GetTensorData<void *>(), size);
} else {
paddle::memory::Copy(paddle::platform::CPUPlace(),
static_cast<void *>(data),
paddle::platform::CUDAPlace(device_),
value.GetTensorData<void>(), size, nullptr);
}
}
template void Tensor::ORTCopyFromCpu<float>(const float *data);
template void Tensor::ORTCopyFromCpu<int64_t>(const int64_t *data);
template void Tensor::ORTCopyFromCpu<int32_t>(const int32_t *data);
template void Tensor::ORTCopyFromCpu<uint8_t>(const uint8_t *data);
template void Tensor::ORTCopyFromCpu<int8_t>(const int8_t *data);
template void Tensor::ORTCopyFromCpu<float16>(const float16 *data);
template void Tensor::ORTCopyToCpu<float>(float *data) const;
template void Tensor::ORTCopyToCpu<int32_t>(int32_t *data) const;
template void Tensor::ORTCopyToCpu<uint8_t>(uint8_t *data) const;
template void Tensor::ORTCopyToCpu<int8_t>(int8_t *data) const;
template void Tensor::ORTCopyToCpu<float16>(float16 *data) const;
#endif
} // namespace paddle_infer
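The ORTCopyFromCpu/ORTCopyToCpu paths above wrap the caller's buffer in an Ort::Value and attach it to the session's IoBinding, rather than copying into an intermediate LoDTensor. For reference, a minimal standalone sketch of that binding pattern against the ONNX Runtime C++ API is shown below; the model path and the tensor names "x" and "y" are placeholders, not taken from this commit.

// Standalone IoBinding sketch (ONNX Runtime C++ API); names are placeholders.
#include <cstdint>
#include <vector>
#include "onnxruntime_cxx_api.h"

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "demo");
  Ort::SessionOptions opts;
  Ort::Session session(env, "model.onnx", opts);
  Ort::IoBinding binding(session);

  // Wrap an existing host buffer as an Ort::Value, with no staging copy;
  // this is what Tensor::ORTCopyFromCpu does with the user's pointer.
  std::vector<float> input(10, 0.f);
  std::vector<int64_t> shape{1, 10};
  Ort::MemoryInfo cpu_info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
  Ort::Value in_value = Ort::Value::CreateTensor<float>(
      cpu_info, input.data(), input.size(), shape.data(), shape.size());
  binding.BindInput("x", in_value);

  // Let the runtime place the output; the caller reads it back afterwards,
  // which is what Tensor::ORTCopyToCpu does.
  binding.BindOutput("y", cpu_info);

  session.Run(Ort::RunOptions{}, binding);

  std::vector<Ort::Value> outputs = binding.GetOutputValues();
  const float *out_data = outputs[0].GetTensorData<float>();
  (void)out_data;
  return 0;
}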
......@@ -25,11 +25,7 @@
#include <vector>
#include "paddle/fluid//platform/device/gpu/gpu_types.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/framework/version.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/passes/memory_optimize_pass.h"
......@@ -45,24 +41,23 @@
namespace paddle {
framework::proto::VarType::Type ConvertONNXType(
ONNXTensorElementDataType type) {
paddle_infer::DataType ConvertONNXType(ONNXTensorElementDataType type) {
switch (type) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
return framework::proto::VarType::FP32;
// case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
// return DataType::FP16;
return paddle_infer::DataType::FLOAT32;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
return paddle_infer::DataType::FLOAT16;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
return framework::proto::VarType::INT8;
return paddle_infer::DataType::INT8;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
return framework::proto::VarType::INT32;
return paddle_infer::DataType::INT32;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
return framework::proto::VarType::INT64;
return paddle_infer::DataType::INT64;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
return framework::proto::VarType::UINT8;
return paddle_infer::DataType::UINT8;
default:
LOG(ERROR) << "unsupported ONNX Tensor Type: " << static_cast<int>(type);
return framework::proto::VarType::FP32;
return paddle_infer::DataType::FLOAT32;
}
}
......@@ -87,13 +82,12 @@ bool ONNXRuntimePredictor::Init() {
VLOG(3) << "ONNXRuntime Predictor::init()";
// Now ONNXRuntime only supports CPU
const char *device_name = config_.use_gpu() ? "Cuda" : "Cpu";
if (config_.use_gpu()) {
place_ = paddle::platform::CUDAPlace(config_.gpu_device_id());
} else {
place_ = paddle::platform::CPUPlace();
}
scope_.reset(new paddle::framework::Scope());
sub_scope_ = &scope_->NewScope();
std::string onnx_proto;
paddle2onnx::Export(config_.prog_file(), config_.params_file(), &onnx_proto,
......@@ -125,13 +119,12 @@ bool ONNXRuntimePredictor::Init() {
"generated.";
}
session_ = {env_, onnx_proto.data(), onnx_proto.size(), session_options};
binding_ = std::make_shared<Ort::IoBinding>(session_);
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
Ort::MemoryInfo memory_info(device_name, OrtDeviceAllocator,
place_.GetDeviceId(), OrtMemTypeDefault);
Ort::Allocator allocator(session_, memory_info);
framework::proto::VarType::Type proto_type =
framework::proto::VarType::LOD_TENSOR;
size_t n_inputs = session_.GetInputCount();
for (size_t i = 0; i < n_inputs; ++i) {
auto input_name = session_.GetInputName(i, allocator);
......@@ -141,8 +134,6 @@ bool ONNXRuntimePredictor::Init() {
ONNXTensorElementDataType data_type =
type_info.GetTensorTypeAndShapeInfo().GetElementType();
input_desc_.emplace_back(ONNXDesc{input_name, shape, data_type});
auto *ptr = scope_->Var(input_name);
framework::InitializeVariable(ptr, proto_type);
allocator.Free(input_name);
}
......@@ -155,11 +146,13 @@ bool ONNXRuntimePredictor::Init() {
ONNXTensorElementDataType data_type =
type_info.GetTensorTypeAndShapeInfo().GetElementType();
output_desc_.emplace_back(ONNXDesc{output_name, shape, data_type});
auto *ptr = scope_->Var(output_name);
framework::InitializeVariable(ptr, proto_type);
Ort::MemoryInfo out_memory_info(device_name, OrtDeviceAllocator,
place_.GetDeviceId(), OrtMemTypeDefault);
binding_->BindOutput(output_name, out_memory_info);
allocator.Free(output_name);
}
return true;
}
......@@ -216,15 +209,26 @@ std::vector<std::string> ONNXRuntimePredictor::GetOutputNames() {
return output_names;
}
bool ONNXRuntimePredictor::FindONNXDesc(const std::string &name,
bool is_input) {
if (is_input) {
for (auto i : input_desc_)
if (i.name == name) return true;
} else {
for (auto i : output_desc_)
if (i.name == name) return true;
}
return false;
}
std::unique_ptr<ZeroCopyTensor> ONNXRuntimePredictor::GetInputTensor(
const std::string &name) {
PADDLE_ENFORCE_NOT_NULL(scope_->FindVar(name),
platform::errors::PreconditionNotMet(
"The in variable named %s is not found in the "
"scope of the ONNXPredictor.",
name));
std::unique_ptr<ZeroCopyTensor> res(
new ZeroCopyTensor(static_cast<void *>(scope_.get())));
PADDLE_ENFORCE_EQ(FindONNXDesc(name, true), true,
platform::errors::PreconditionNotMet(
"The in variable named %s is not found in the "
"ONNXPredictor.",
name));
std::unique_ptr<ZeroCopyTensor> res(new ZeroCopyTensor(nullptr));
res->input_or_output_ = true;
res->SetName(name);
if (platform::is_cpu_place(place_)) {
......@@ -233,18 +237,19 @@ std::unique_ptr<ZeroCopyTensor> ONNXRuntimePredictor::GetInputTensor(
auto gpu_place = place_;
res->SetPlace(PaddlePlace::kGPU, gpu_place.GetDeviceId());
}
res->SetOrtMark(true);
res->SetOrtBinding(binding_);
return res;
}
std::unique_ptr<ZeroCopyTensor> ONNXRuntimePredictor::GetOutputTensor(
const std::string &name) {
PADDLE_ENFORCE_NOT_NULL(scope_->FindVar(name),
platform::errors::PreconditionNotMet(
"The out variable named %s is not found in the "
"scope of the ONNXPredictor.",
name));
std::unique_ptr<ZeroCopyTensor> res(
new ZeroCopyTensor(static_cast<void *>(scope_.get())));
PADDLE_ENFORCE_EQ(FindONNXDesc(name, false), true,
platform::errors::PreconditionNotMet(
"The out variable named %s is not found in the "
"ONNXPredictor.",
name));
std::unique_ptr<ZeroCopyTensor> res(new ZeroCopyTensor(nullptr));
res->input_or_output_ = false;
res->SetName(name);
if (platform::is_cpu_place(place_)) {
......@@ -253,46 +258,18 @@ std::unique_ptr<ZeroCopyTensor> ONNXRuntimePredictor::GetOutputTensor(
auto gpu_place = place_;
res->SetPlace(PaddlePlace::kGPU, gpu_place.GetDeviceId());
}
res->SetOrtMark(true);
res->SetOrtBinding(binding_);
int size = output_desc_.size();
for (int i = 0; i < size; ++i)
if (output_desc_[i].name == name) {
res->idx_ = i;
res->dtype_ = ConvertONNXType(output_desc_[i].dtype);
break;
}
return res;
}
Ort::Value ONNXRuntimePredictor::GetOrtValue(const ONNXDesc &desc,
const char *device_name) {
Ort::MemoryInfo memory_info(device_name, OrtDeviceAllocator,
place_.GetDeviceId(), OrtMemTypeDefault);
auto *var = scope_->FindVar(desc.name);
auto *tensor = var->GetMutable<framework::LoDTensor>();
size_t size =
tensor->numel() *
framework::SizeOfType(framework::TransToProtoVarType(tensor->dtype()));
std::vector<int64_t> shape = phi::vectorize<int64_t>(tensor->dims());
return Ort::Value::CreateTensor(memory_info,
static_cast<void *>(tensor->data()), size,
shape.data(), shape.size(), desc.dtype);
}
void ONNXRuntimePredictor::AsTensor(const Ort::Value &value,
const ONNXDesc &desc) {
auto info = value.GetTensorTypeAndShapeInfo();
auto *var = scope_->FindVar(desc.name);
auto *tensor = var->GetMutable<framework::LoDTensor>();
tensor->Resize(phi::make_ddim(info.GetShape()));
auto dtype = ConvertONNXType(info.GetElementType());
auto *ptr = tensor->mutable_data(place_, dtype);
if (platform::is_cpu_place(place_)) {
std::memcpy(ptr, const_cast<void *>(value.GetTensorData<void>()),
tensor->numel() * framework::SizeOfType(dtype));
} else {
auto src_place = place_;
auto dst_place = place_;
memory::Copy(dst_place, ptr, src_place,
const_cast<void *>(value.GetTensorData<void>()),
tensor->numel() * framework::SizeOfType(dtype));
}
}
bool ONNXRuntimePredictor::Run(const std::vector<PaddleTensor> &inputs,
std::vector<PaddleTensor> *output_data,
int batch_size) {
......@@ -302,31 +279,7 @@ bool ONNXRuntimePredictor::Run(const std::vector<PaddleTensor> &inputs,
bool ONNXRuntimePredictor::ZeroCopyRun() {
try {
Ort::IoBinding binding(session_);
std::vector<Ort::Value> inputs;
std::vector<Ort::Value> outputs;
Ort::RunOptions options;
inputs.reserve(input_desc_.size());
const char *device_name = config_.use_gpu() ? "Cuda" : "Cpu";
for (auto desc : input_desc_) {
inputs.push_back(GetOrtValue(desc, device_name));
binding.BindInput(desc.name.c_str(), inputs.back());
}
// TODO(heliqi): Optimization —— move to Init()
for (auto desc : output_desc_) {
Ort::MemoryInfo memory_info(device_name, OrtDeviceAllocator,
place_.GetDeviceId(), OrtMemTypeDefault);
binding.BindOutput(desc.name.c_str(), memory_info);
}
session_.Run({}, binding);
outputs = binding.GetOutputValues();
for (size_t i = 0; i < output_desc_.size(); ++i) {
AsTensor(outputs[i], output_desc_[i]);
}
session_.Run({}, *(binding_.get()));
} catch (const std::exception &e) {
LOG(ERROR) << e.what();
return false;
......@@ -345,9 +298,9 @@ uint64_t ONNXRuntimePredictor::TryShrinkMemory() {
}
ONNXRuntimePredictor::~ONNXRuntimePredictor() {
if (sub_scope_) {
scope_->DeleteScope(sub_scope_);
}
binding_->ClearBoundInputs();
binding_->ClearBoundOutputs();
memory::Release(place_);
}
......
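One design note on the predictor changes above: the predictor now owns the Ort::IoBinding through a std::shared_ptr created in Init(), while the tensors handed out by GetInputTensor/GetOutputTensor keep only a std::weak_ptr to it (see SetOrtBinding). Every copy therefore starts with binding_.lock(), so a tensor that outlives its predictor fails with a clear precondition error instead of touching a destroyed binding. A small illustrative sketch of that ownership pattern, with hypothetical names rather than Paddle API, follows.

// Ownership sketch: the owner holds a shared_ptr, handles observe via weak_ptr.
#include <iostream>
#include <memory>

struct Binding {};  // stands in for Ort::IoBinding

struct Handle {
  std::weak_ptr<Binding> binding;
  void Copy() {
    auto b = binding.lock();            // same idea as binding_.lock() above
    if (!b) {
      std::cerr << "no binding ptr\n";  // owner already destroyed
      return;
    }
    // ... bind the user buffer through *b ...
  }
};

int main() {
  Handle h;
  {
    auto owner = std::make_shared<Binding>();
    h.binding = owner;
    h.Copy();  // ok: binding still alive
  }
  h.Copy();    // owner gone: fails cleanly instead of dereferencing a dangling pointer
  return 0;
}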
......@@ -94,9 +94,8 @@ class ONNXRuntimePredictor : public PaddlePredictor {
/// \param[in] AnalysisConfig config
///
explicit ONNXRuntimePredictor(const AnalysisConfig &config)
: config_(config) {
: config_(config), env_(ORT_LOGGING_LEVEL_WARNING, "onnx") {
predictor_id_ = inference::GetUniqueId();
env_ = Ort::Env(ORT_LOGGING_LEVEL_INFO, "onnx");
}
///
/// \brief Destroy the ONNXRuntime Predictor object
......@@ -177,30 +176,17 @@ class ONNXRuntimePredictor : public PaddlePredictor {
///
std::unique_ptr<PaddlePredictor> Clone() override;
std::shared_ptr<framework::Scope> scope_;
private:
///
/// \brief get the Ort Value(input Tensor).
///
/// \param[in] desc ONNXDesc (name, shape, dtype)
///
/// \param[in] device_name "cpu" or "gpu" of device
///
/// \return get a Ort::Value
///
Ort::Value GetOrtValue(const ONNXDesc &desc, const char *device_name);
///
/// \brief Ort::Value to Paddle::ZeroCopyTensor.
/// \brief Whether an input/output tensor with the given name exists.
///
/// \param[in] value Ort::Value(output Tensor)
/// \param[in] name input or output name
///
/// \param[in] desc an ONNXDesc (name, shape, dtype)
/// \param[in] is_input input(true) or output(false)
///
/// \return get a Ort::Value
/// \return true if a tensor with the given name is found
///
void AsTensor(const Ort::Value &value, const ONNXDesc &desc);
bool FindONNXDesc(const std::string &name, bool is_input);
private:
AnalysisConfig config_;
......@@ -208,9 +194,9 @@ class ONNXRuntimePredictor : public PaddlePredictor {
// ONNXRuntime
Ort::Env env_;
Ort::Session session_{nullptr};
std::shared_ptr<Ort::IoBinding> binding_;
platform::Place place_;
framework::Scope *sub_scope_{nullptr};
std::vector<ONNXDesc> input_desc_;
std::vector<ONNXDesc> output_desc_;
int predictor_id_;
......
......@@ -18,6 +18,11 @@
#include "paddle_infer_declare.h" // NOLINT
#ifdef PADDLE_WITH_ONNXRUNTIME
#include "onnxruntime_c_api.h" // NOLINT
#include "onnxruntime_cxx_api.h" // NOLINT
#endif
namespace paddle_infer {
/// \brief Experimental.
......@@ -175,6 +180,23 @@ class PD_INFER_DECL Tensor {
PlaceType place_;
int device_;
#ifdef PADDLE_WITH_ONNXRUNTIME
bool is_ort_tensor_{false};
std::vector<int64_t> shape_;
std::weak_ptr<Ort::IoBinding> binding_;
int idx_{-1};
void SetOrtMark(bool is_ort_tensor);
void SetOrtBinding(const std::shared_ptr<Ort::IoBinding> binding);
template <typename T>
void ORTCopyFromCpu(const T* data);
template <typename T>
void ORTCopyToCpu(T* data) const;
#endif
friend class paddle_infer::contrib::TensorUtils;
#if defined(PADDLE_WITH_TESTING) && defined(PADDLE_WITH_INFERENCE_API_TEST)
friend class paddle_infer::InferApiTesterUtils;
......