未验证 提交 35297bd8 编写于 作者: H heliqi 提交者: GitHub

[cherry-pick]Ort backend optimizer(#44136 #44703 #44724) (#44766)

* [Inference]ort backend optimizer (#44136)

* add ort clone interface

* paddle2onnx update to 1.0.0rc

* ort input_tensor use mutable data of scope

* clone ort_predictor reuse session (#44703)

* ort backend support output mutable data (#44724)

* 2.3 interface is different from the Develop interface

* 2.3 interface is different from the Develop interface

* 2.3 interface is different from the Develop interface
上级 e7547ca7
...@@ -24,7 +24,7 @@ endif() ...@@ -24,7 +24,7 @@ endif()
include(ExternalProject) include(ExternalProject)
set(PADDLE2ONNX_PROJECT "extern_paddle2onnx") set(PADDLE2ONNX_PROJECT "extern_paddle2onnx")
set(PADDLE2ONNX_VERSION "0.9.9") set(PADDLE2ONNX_VERSION "1.0.0rc")
set(PADDLE2ONNX_PREFIX_DIR ${THIRD_PARTY_PATH}/paddle2onnx) set(PADDLE2ONNX_PREFIX_DIR ${THIRD_PARTY_PATH}/paddle2onnx)
set(PADDLE2ONNX_SOURCE_DIR set(PADDLE2ONNX_SOURCE_DIR
${THIRD_PARTY_PATH}/paddle2onnx/src/${PADDLE2ONNX_PROJECT}) ${THIRD_PARTY_PATH}/paddle2onnx/src/${PADDLE2ONNX_PROJECT})
......
...@@ -40,36 +40,42 @@ void Tensor::Reshape(const std::vector<int> &shape) { ...@@ -40,36 +40,42 @@ void Tensor::Reshape(const std::vector<int> &shape) {
#endif #endif
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
name_.empty(), false, name_.empty(),
false,
paddle::platform::errors::PreconditionNotMet( paddle::platform::errors::PreconditionNotMet(
"Need to SetName first, so that the corresponding tensor can " "Need to SetName first, so that the corresponding tensor can "
"be retrieved.")); "be retrieved."));
PADDLE_ENFORCE_EQ(input_or_output_, true, PADDLE_ENFORCE_EQ(input_or_output_,
true,
paddle::platform::errors::PermissionDenied( paddle::platform::errors::PermissionDenied(
"Can't reshape the output tensor, it is readonly")); "Can't reshape the output tensor, it is readonly"));
auto *scope = static_cast<paddle::framework::Scope *>(scope_); auto *scope = static_cast<paddle::framework::Scope *>(scope_);
auto *var = scope->FindVar(name_); auto *var = scope->FindVar(name_);
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
var, paddle::platform::errors::PreconditionNotMet( var,
"No tensor called [%s] in the runtime scope", name_)); paddle::platform::errors::PreconditionNotMet(
"No tensor called [%s] in the runtime scope", name_));
auto *tensor = var->GetMutable<paddle::framework::LoDTensor>(); auto *tensor = var->GetMutable<paddle::framework::LoDTensor>();
tensor->Resize(phi::make_ddim(shape)); tensor->Resize(phi::make_ddim(shape));
} }
void Tensor::ReshapeStrings(const size_t &shape) { void Tensor::ReshapeStrings(const size_t &shape) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
name_.empty(), false, name_.empty(),
false,
paddle::platform::errors::PreconditionNotMet( paddle::platform::errors::PreconditionNotMet(
"Need to SetName first, so that the corresponding tensor can " "Need to SetName first, so that the corresponding tensor can "
"be retrieved.")); "be retrieved."));
PADDLE_ENFORCE_EQ(input_or_output_, true, PADDLE_ENFORCE_EQ(input_or_output_,
true,
paddle::platform::errors::PermissionDenied( paddle::platform::errors::PermissionDenied(
"Can't reshape the output tensor, it is readonly")); "Can't reshape the output tensor, it is readonly"));
auto *scope = static_cast<paddle::framework::Scope *>(scope_); auto *scope = static_cast<paddle::framework::Scope *>(scope_);
auto *var = scope->FindVar(name_); auto *var = scope->FindVar(name_);
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
var, paddle::platform::errors::PreconditionNotMet( var,
"No tensor called [%s] in the runtime scope", name_)); paddle::platform::errors::PreconditionNotMet(
"No tensor called [%s] in the runtime scope", name_));
paddle_infer::Strings *tensor = var->GetMutable<paddle_infer::Strings>(); paddle_infer::Strings *tensor = var->GetMutable<paddle_infer::Strings>();
tensor->resize(shape); tensor->resize(shape);
} }
...@@ -82,9 +88,15 @@ void Tensor::ReshapeStrings(const size_t &shape) { ...@@ -82,9 +88,15 @@ void Tensor::ReshapeStrings(const size_t &shape) {
template <typename T> template <typename T>
T *Tensor::mutable_data(PlaceType place) { T *Tensor::mutable_data(PlaceType place) {
#ifdef PADDLE_WITH_ONNXRUNTIME
if (is_ort_tensor_) {
return ORTGetMutableData<T>();
}
#endif
EAGER_GET_TENSOR(paddle::framework::LoDTensor); EAGER_GET_TENSOR(paddle::framework::LoDTensor);
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
tensor->numel(), 0, tensor->numel(),
0,
paddle::platform::errors::PreconditionNotMet( paddle::platform::errors::PreconditionNotMet(
"You should call Tensor::Reshape(const std::vector<int> " "You should call Tensor::Reshape(const std::vector<int> "
"&shape)" "&shape)"
...@@ -161,15 +173,9 @@ PlaceType Tensor::place() const { return place_; } ...@@ -161,15 +173,9 @@ PlaceType Tensor::place() const { return place_; }
template <typename T> template <typename T>
void Tensor::CopyFromCpu(const T *data) { void Tensor::CopyFromCpu(const T *data) {
#ifdef PADDLE_WITH_ONNXRUNTIME
if (is_ort_tensor_) {
ORTCopyFromCpu<T>(data);
return;
}
#endif
EAGER_GET_TENSOR(paddle::framework::LoDTensor); EAGER_GET_TENSOR(paddle::framework::LoDTensor);
PADDLE_ENFORCE_GE(tensor->numel(), 0, PADDLE_ENFORCE_GE(tensor->numel(),
0,
paddle::platform::errors::PreconditionNotMet( paddle::platform::errors::PreconditionNotMet(
"You should call Tensor::Reshape(const " "You should call Tensor::Reshape(const "
"std::vector<int> &shape)" "std::vector<int> &shape)"
...@@ -188,8 +194,11 @@ void Tensor::CopyFromCpu(const T *data) { ...@@ -188,8 +194,11 @@ void Tensor::CopyFromCpu(const T *data) {
auto *dev_ctx = static_cast<const paddle::platform::CUDADeviceContext *>( auto *dev_ctx = static_cast<const paddle::platform::CUDADeviceContext *>(
pool.Get(gpu_place)); pool.Get(gpu_place));
paddle::memory::Copy(gpu_place, static_cast<void *>(t_data), paddle::memory::Copy(gpu_place,
paddle::platform::CPUPlace(), data, ele_size, static_cast<void *>(t_data),
paddle::platform::CPUPlace(),
data,
ele_size,
dev_ctx->stream()); dev_ctx->stream());
#else #else
PADDLE_THROW(paddle::platform::errors::Unavailable( PADDLE_THROW(paddle::platform::errors::Unavailable(
...@@ -200,8 +209,11 @@ void Tensor::CopyFromCpu(const T *data) { ...@@ -200,8 +209,11 @@ void Tensor::CopyFromCpu(const T *data) {
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
paddle::platform::XPUPlace xpu_place(device_); paddle::platform::XPUPlace xpu_place(device_);
auto *t_data = tensor->mutable_data<T>(xpu_place); auto *t_data = tensor->mutable_data<T>(xpu_place);
paddle::memory::Copy(xpu_place, static_cast<void *>(t_data), paddle::memory::Copy(xpu_place,
paddle::platform::CPUPlace(), data, ele_size); static_cast<void *>(t_data),
paddle::platform::CPUPlace(),
data,
ele_size);
#else #else
PADDLE_THROW(paddle::platform::errors::Unavailable( PADDLE_THROW(paddle::platform::errors::Unavailable(
"Can not create tensor with XPU place because paddle is not compiled " "Can not create tensor with XPU place because paddle is not compiled "
...@@ -215,8 +227,11 @@ void Tensor::CopyFromCpu(const T *data) { ...@@ -215,8 +227,11 @@ void Tensor::CopyFromCpu(const T *data) {
auto *t_data = tensor->mutable_data<T>(npu_place); auto *t_data = tensor->mutable_data<T>(npu_place);
auto *dev_ctx = static_cast<const paddle::platform::NPUDeviceContext *>( auto *dev_ctx = static_cast<const paddle::platform::NPUDeviceContext *>(
pool.Get(npu_place)); pool.Get(npu_place));
paddle::memory::Copy(npu_place, static_cast<void *>(t_data), paddle::memory::Copy(npu_place,
paddle::platform::CPUPlace(), data, ele_size, static_cast<void *>(t_data),
paddle::platform::CPUPlace(),
data,
ele_size,
dev_ctx->stream()); dev_ctx->stream());
#else #else
PADDLE_THROW(paddle::platform::errors::Unavailable( PADDLE_THROW(paddle::platform::errors::Unavailable(
...@@ -264,30 +279,33 @@ struct DataTypeInfo<int32_t> { ...@@ -264,30 +279,33 @@ struct DataTypeInfo<int32_t> {
paddle::experimental::DataLayout LayoutConvert(DataLayout layout) { paddle::experimental::DataLayout LayoutConvert(DataLayout layout) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
layout, DataLayout::kNCHW, layout,
DataLayout::kNCHW,
paddle::platform::errors::InvalidArgument("Only NCHW is supported now.")); paddle::platform::errors::InvalidArgument("Only NCHW is supported now."));
return paddle::experimental::DataLayout::NCHW; return paddle::experimental::DataLayout::NCHW;
} }
template <typename T> template <typename T>
void Tensor::ShareExternalData(const T *data, const std::vector<int> &shape, void Tensor::ShareExternalData(const T *data,
PlaceType place, DataLayout layout) { const std::vector<int> &shape,
PlaceType place,
DataLayout layout) {
EAGER_GET_TENSOR(paddle::framework::LoDTensor) EAGER_GET_TENSOR(paddle::framework::LoDTensor)
size_t size = size_t size =
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()) * std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()) *
sizeof(T); sizeof(T);
phi::DenseTensorMeta meta(DataTypeInfo<T>().TYPE, phi::make_ddim(shape), phi::DenseTensorMeta meta(
LayoutConvert(layout)); DataTypeInfo<T>().TYPE, phi::make_ddim(shape), LayoutConvert(layout));
if (place == PlaceType::kCPU) { if (place == PlaceType::kCPU) {
phi::DenseTensor dtensor( phi::DenseTensor dtensor(
std::make_shared<phi::Allocation>(const_cast<T *>(data), size, std::make_shared<phi::Allocation>(
paddle::platform::CPUPlace()), const_cast<T *>(data), size, paddle::platform::CPUPlace()),
meta); meta);
*tensor = std::move(dtensor); *tensor = std::move(dtensor);
} else if (place == PlaceType::kGPU) { } else if (place == PlaceType::kGPU) {
phi::DenseTensor dtensor( phi::DenseTensor dtensor(
std::make_shared<phi::Allocation>(const_cast<T *>(data), size, std::make_shared<phi::Allocation>(
paddle::platform::CUDAPlace(device_)), const_cast<T *>(data), size, paddle::platform::CUDAPlace(device_)),
meta); meta);
*tensor = std::move(dtensor); *tensor = std::move(dtensor);
} else { } else {
...@@ -298,7 +316,8 @@ void Tensor::ShareExternalData(const T *data, const std::vector<int> &shape, ...@@ -298,7 +316,8 @@ void Tensor::ShareExternalData(const T *data, const std::vector<int> &shape,
void Tensor::CopyStringsFromCpu(const paddle_infer::Strings *data) { void Tensor::CopyStringsFromCpu(const paddle_infer::Strings *data) {
EAGER_GET_TENSOR(paddle_infer::Strings); EAGER_GET_TENSOR(paddle_infer::Strings);
PADDLE_ENFORCE_GE(tensor->size(), 0, PADDLE_ENFORCE_GE(tensor->size(),
0,
paddle::platform::errors::PreconditionNotMet( paddle::platform::errors::PreconditionNotMet(
"You should call Tensor::Reshape(const " "You should call Tensor::Reshape(const "
"std::size_t &shape)function before copying" "std::size_t &shape)function before copying"
...@@ -307,7 +326,9 @@ void Tensor::CopyStringsFromCpu(const paddle_infer::Strings *data) { ...@@ -307,7 +326,9 @@ void Tensor::CopyStringsFromCpu(const paddle_infer::Strings *data) {
} }
template <typename T> template <typename T>
void Tensor::CopyToCpuImpl(T *data, void *exec_stream, CallbackFunc cb, void Tensor::CopyToCpuImpl(T *data,
void *exec_stream,
CallbackFunc cb,
void *cb_params) const { void *cb_params) const {
EAGER_GET_TENSOR(paddle::framework::LoDTensor); EAGER_GET_TENSOR(paddle::framework::LoDTensor);
auto ele_num = tensor->numel(); auto ele_num = tensor->numel();
...@@ -317,7 +338,8 @@ void Tensor::CopyToCpuImpl(T *data, void *exec_stream, CallbackFunc cb, ...@@ -317,7 +338,8 @@ void Tensor::CopyToCpuImpl(T *data, void *exec_stream, CallbackFunc cb,
paddle::framework::Tensor out; paddle::framework::Tensor out;
auto mem_allocation = auto mem_allocation =
std::make_shared<paddle::memory::allocation::Allocation>( std::make_shared<paddle::memory::allocation::Allocation>(
static_cast<void *>(data), ele_num * sizeof(T), static_cast<void *>(data),
ele_num * sizeof(T),
paddle::platform::CPUPlace()); paddle::platform::CPUPlace());
out.ResetHolder(mem_allocation); out.ResetHolder(mem_allocation);
...@@ -325,9 +347,13 @@ void Tensor::CopyToCpuImpl(T *data, void *exec_stream, CallbackFunc cb, ...@@ -325,9 +347,13 @@ void Tensor::CopyToCpuImpl(T *data, void *exec_stream, CallbackFunc cb,
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (tensor->layout() == paddle::framework::DataLayout::kMKLDNN) if (tensor->layout() == paddle::framework::DataLayout::kMKLDNN)
paddle::framework::innerTransDataLayoutFromMKLDNN( paddle::framework::innerTransDataLayoutFromMKLDNN(
tensor->layout(), paddle::platform::MKLDNNDeviceContext::tls() tensor->layout(),
.get_cur_paddle_data_layout(), paddle::platform::MKLDNNDeviceContext::tls()
*tensor, &out, paddle::platform::CPUPlace(), true); .get_cur_paddle_data_layout(),
*tensor,
&out,
paddle::platform::CPUPlace(),
true);
else else
std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T)); std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
#else #else
...@@ -349,8 +375,11 @@ void Tensor::CopyToCpuImpl(T *data, void *exec_stream, CallbackFunc cb, ...@@ -349,8 +375,11 @@ void Tensor::CopyToCpuImpl(T *data, void *exec_stream, CallbackFunc cb,
auto *dev_ctx = static_cast<const paddle::platform::CUDADeviceContext *>( auto *dev_ctx = static_cast<const paddle::platform::CUDADeviceContext *>(
pool.Get(gpu_place)); pool.Get(gpu_place));
paddle::memory::Copy(paddle::platform::CPUPlace(), paddle::memory::Copy(paddle::platform::CPUPlace(),
static_cast<void *>(data), gpu_place, t_data, static_cast<void *>(data),
ele_num * sizeof(T), dev_ctx->stream()); gpu_place,
t_data,
ele_num * sizeof(T),
dev_ctx->stream());
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
hipStreamSynchronize(dev_ctx->stream()); hipStreamSynchronize(dev_ctx->stream());
#else #else
...@@ -374,7 +403,9 @@ void Tensor::CopyToCpuImpl(T *data, void *exec_stream, CallbackFunc cb, ...@@ -374,7 +403,9 @@ void Tensor::CopyToCpuImpl(T *data, void *exec_stream, CallbackFunc cb,
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
auto xpu_place = t_place; auto xpu_place = t_place;
paddle::memory::Copy(paddle::platform::CPUPlace(), paddle::memory::Copy(paddle::platform::CPUPlace(),
static_cast<void *>(data), xpu_place, t_data, static_cast<void *>(data),
xpu_place,
t_data,
ele_num * sizeof(T)); ele_num * sizeof(T));
#else #else
PADDLE_THROW(paddle::platform::errors::Unavailable( PADDLE_THROW(paddle::platform::errors::Unavailable(
...@@ -389,8 +420,11 @@ void Tensor::CopyToCpuImpl(T *data, void *exec_stream, CallbackFunc cb, ...@@ -389,8 +420,11 @@ void Tensor::CopyToCpuImpl(T *data, void *exec_stream, CallbackFunc cb,
auto *dev_ctx = static_cast<const paddle::platform::NPUDeviceContext *>( auto *dev_ctx = static_cast<const paddle::platform::NPUDeviceContext *>(
pool.Get(npu_place)); pool.Get(npu_place));
paddle::memory::Copy(paddle::platform::CPUPlace(), paddle::memory::Copy(paddle::platform::CPUPlace(),
static_cast<void *>(data), npu_place, t_data, static_cast<void *>(data),
ele_num * sizeof(T), dev_ctx->stream()); npu_place,
t_data,
ele_num * sizeof(T),
dev_ctx->stream());
paddle::platform::NPUStreamSync(dev_ctx->stream()); paddle::platform::NPUStreamSync(dev_ctx->stream());
#else #else
PADDLE_THROW(paddle::platform::errors::Unavailable( PADDLE_THROW(paddle::platform::errors::Unavailable(
...@@ -433,22 +467,34 @@ template PD_INFER_DECL void Tensor::CopyFromCpu<int8_t>(const int8_t *data); ...@@ -433,22 +467,34 @@ template PD_INFER_DECL void Tensor::CopyFromCpu<int8_t>(const int8_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<float16>(const float16 *data); template PD_INFER_DECL void Tensor::CopyFromCpu<float16>(const float16 *data);
template PD_INFER_DECL void Tensor::ShareExternalData<float>( template PD_INFER_DECL void Tensor::ShareExternalData<float>(
const float *data, const std::vector<int> &shape, PlaceType place, const float *data,
const std::vector<int> &shape,
PlaceType place,
DataLayout layout); DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<int64_t>( template PD_INFER_DECL void Tensor::ShareExternalData<int64_t>(
const int64_t *data, const std::vector<int> &shape, PlaceType place, const int64_t *data,
const std::vector<int> &shape,
PlaceType place,
DataLayout layout); DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<int32_t>( template PD_INFER_DECL void Tensor::ShareExternalData<int32_t>(
const int32_t *data, const std::vector<int> &shape, PlaceType place, const int32_t *data,
const std::vector<int> &shape,
PlaceType place,
DataLayout layout); DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<uint8_t>( template PD_INFER_DECL void Tensor::ShareExternalData<uint8_t>(
const uint8_t *data, const std::vector<int> &shape, PlaceType place, const uint8_t *data,
const std::vector<int> &shape,
PlaceType place,
DataLayout layout); DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<int8_t>( template PD_INFER_DECL void Tensor::ShareExternalData<int8_t>(
const int8_t *data, const std::vector<int> &shape, PlaceType place, const int8_t *data,
const std::vector<int> &shape,
PlaceType place,
DataLayout layout); DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<float16>( template PD_INFER_DECL void Tensor::ShareExternalData<float16>(
const float16 *data, const std::vector<int> &shape, PlaceType place, const float16 *data,
const std::vector<int> &shape,
PlaceType place,
DataLayout layout); DataLayout layout);
template PD_INFER_DECL void Tensor::CopyToCpu<float>(float *data) const; template PD_INFER_DECL void Tensor::CopyToCpu<float>(float *data) const;
...@@ -524,15 +570,17 @@ Tensor::Tensor(void *scope) : scope_{scope} {} ...@@ -524,15 +570,17 @@ Tensor::Tensor(void *scope) : scope_{scope} {}
template <typename T> template <typename T>
void *Tensor::FindTensor() const { void *Tensor::FindTensor() const {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
name_.empty(), false, name_.empty(),
false,
paddle::platform::errors::PreconditionNotMet( paddle::platform::errors::PreconditionNotMet(
"Need to SetName first, so that the corresponding tensor can " "Need to SetName first, so that the corresponding tensor can "
"be retrieved.")); "be retrieved."));
auto *scope = static_cast<paddle::framework::Scope *>(scope_); auto *scope = static_cast<paddle::framework::Scope *>(scope_);
auto *var = scope->FindVar(name_); auto *var = scope->FindVar(name_);
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
var, paddle::platform::errors::PreconditionNotMet( var,
"No tensor called [%s] in the runtime scope", name_)); paddle::platform::errors::PreconditionNotMet(
"No tensor called [%s] in the runtime scope", name_));
auto *tensor = var->GetMutable<T>(); auto *tensor = var->GetMutable<T>();
return tensor; return tensor;
} }
...@@ -560,8 +608,9 @@ std::vector<int> Tensor::shape() const { ...@@ -560,8 +608,9 @@ std::vector<int> Tensor::shape() const {
#endif #endif
EAGER_GET_TENSOR(paddle::framework::LoDTensor); EAGER_GET_TENSOR(paddle::framework::LoDTensor);
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
tensor_, paddle::platform::errors::PreconditionNotMet( tensor_,
"Not found tensor called %s in the scope", name_)); paddle::platform::errors::PreconditionNotMet(
"Not found tensor called %s in the scope", name_));
// mkldnn may does layout transform internally, so need to reorder before // mkldnn may does layout transform internally, so need to reorder before
// return // return
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
...@@ -626,91 +675,15 @@ void Tensor::SetOrtBinding(const std::shared_ptr<Ort::IoBinding> binding) { ...@@ -626,91 +675,15 @@ void Tensor::SetOrtBinding(const std::shared_ptr<Ort::IoBinding> binding) {
binding_ = binding; binding_ = binding;
} }
void Tensor::SetOrtBuffer(const std::shared_ptr<std::vector<int8_t>> buffer) {
buffer_ = buffer;
}
Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info, float *data,
size_t size, const int64_t *shape, size_t shape_len) {
return Ort::Value::CreateTensor<float>(memory_info, data, size, shape,
shape_len);
}
Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info, int64_t *data,
size_t size, const int64_t *shape, size_t shape_len) {
return Ort::Value::CreateTensor<int64_t>(memory_info, data, size, shape,
shape_len);
}
Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info, int32_t *data,
size_t size, const int64_t *shape, size_t shape_len) {
return Ort::Value::CreateTensor<int32_t>(memory_info, data, size, shape,
shape_len);
}
Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info, uint8_t *data,
size_t size, const int64_t *shape, size_t shape_len) {
return Ort::Value::CreateTensor<uint8_t>(memory_info, data, size, shape,
shape_len);
}
Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info, int8_t *data,
size_t size, const int64_t *shape, size_t shape_len) {
return Ort::Value::CreateTensor<int8_t>(memory_info, data, size, shape,
shape_len);
}
Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info, float16 *data,
size_t size, const int64_t *shape, size_t shape_len) {
return Ort::Value::CreateTensor(memory_info, static_cast<void *>(data),
size * sizeof(float16), shape, shape_len,
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
}
template <typename T> template <typename T>
void Tensor::ORTCopyFromCpu(const T *data) { T *Tensor::ORTGetMutableData() {
auto binding = binding_.lock(); auto binding = binding_.lock();
PADDLE_ENFORCE_NOT_NULL(binding, PADDLE_ENFORCE_NOT_NULL(binding,
paddle::platform::errors::PreconditionNotMet( paddle::platform::errors::PreconditionNotMet(
"input tensor [%s] no binding ptr", name_)); "output tensor [%s] no binding ptr", name_));
const char *device_name = place_ == PlaceType::kCPU ? "Cpu" : "Cuda"; std::vector<Ort::Value> outputs = binding->GetOutputValues();
Ort::MemoryInfo memory_info(device_name, OrtDeviceAllocator, device_, Ort::Value &value = outputs[idx_];
OrtMemTypeDefault); return value.GetTensorMutableData<T>();
size_t size = std::accumulate(begin(shape_), end(shape_), 1UL,
std::multiplies<size_t>());
auto buffer = buffer_.lock();
size_t buffer_size = size * sizeof(T);
if (buffer_size > buffer->size()) {
buffer->resize(buffer_size);
}
std::memcpy(static_cast<void *>(buffer->data()), data, buffer_size);
auto onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
if (std::is_same<T, float>::value) {
onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
} else if (std::is_same<T, double>::value) {
onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
} else if (std::is_same<T, int64_t>::value) {
onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
} else if (std::is_same<T, int32_t>::value) {
onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
} else if (std::is_same<T, uint8_t>::value) {
onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
} else if (std::is_same<T, int8_t>::value) {
onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
} else if (std::is_same<T, float16>::value) {
onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"Found undefined data type for onnxruntime, only supports "
"float16/float32/float64/int8/uint8/int32/int64."));
}
auto ort_value =
Ort::Value::CreateTensor(memory_info, buffer->data(), buffer_size,
shape_.data(), shape_.size(), onnx_dtype);
binding->BindInput(name_.c_str(), ort_value);
} }
template <typename T> template <typename T>
...@@ -733,13 +706,6 @@ void Tensor::ORTCopyToCpu(T *data) const { ...@@ -733,13 +706,6 @@ void Tensor::ORTCopyToCpu(T *data) const {
} }
} }
template void Tensor::ORTCopyFromCpu<float>(const float *data);
template void Tensor::ORTCopyFromCpu<int64_t>(const int64_t *data);
template void Tensor::ORTCopyFromCpu<int32_t>(const int32_t *data);
template void Tensor::ORTCopyFromCpu<uint8_t>(const uint8_t *data);
template void Tensor::ORTCopyFromCpu<int8_t>(const int8_t *data);
template void Tensor::ORTCopyFromCpu<float16>(const float16 *data);
template void Tensor::ORTCopyToCpu<float>(float *data) const; template void Tensor::ORTCopyToCpu<float>(float *data) const;
template void Tensor::ORTCopyToCpu<int32_t>(int32_t *data) const; template void Tensor::ORTCopyToCpu<int32_t>(int32_t *data) const;
template void Tensor::ORTCopyToCpu<uint8_t>(uint8_t *data) const; template void Tensor::ORTCopyToCpu<uint8_t>(uint8_t *data) const;
......
...@@ -24,11 +24,10 @@ ...@@ -24,11 +24,10 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid//platform/device/gpu/gpu_types.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/version.h" #include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/inference/analysis/helper.h" #include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/passes/memory_optimize_pass.h"
#include "paddle/fluid/inference/api/helper.h" #include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/api/paddle_inference_pass.h" #include "paddle/fluid/inference/api/paddle_inference_pass.h"
...@@ -71,22 +70,23 @@ bool CheckConvertToONNX(const AnalysisConfig &config) { ...@@ -71,22 +70,23 @@ bool CheckConvertToONNX(const AnalysisConfig &config) {
} else if (config.prog_file().empty() || config.params_file().empty()) { } else if (config.prog_file().empty() || config.params_file().empty()) {
LOG(ERROR) << string::Sprintf( LOG(ERROR) << string::Sprintf(
"not valid model path '%s' or program path '%s' or params path '%s'.", "not valid model path '%s' or program path '%s' or params path '%s'.",
config.model_dir(), config.prog_file(), config.params_file()); config.model_dir(),
config.prog_file(),
config.params_file());
return false; return false;
} }
if (config.model_from_memory()) { if (config.model_from_memory()) {
return paddle2onnx::IsExportable( return paddle2onnx::IsExportable(config.prog_file().data(),
config.prog_file().data(), config.prog_file().size(), config.prog_file().size(),
config.params_file().data(), config.params_file().size()); config.params_file().data(),
config.params_file().size());
} else { } else {
return paddle2onnx::IsExportable(config.prog_file().c_str(), return paddle2onnx::IsExportable(config.prog_file().c_str(),
config.params_file().c_str()); config.params_file().c_str());
} }
} }
bool ONNXRuntimePredictor::Init() { bool ONNXRuntimePredictor::InitBinding() {
VLOG(3) << "ONNXRuntime Predictor::init()";
// Now ONNXRuntime only support CPU // Now ONNXRuntime only support CPU
const char *device_name = config_.use_gpu() ? "Cuda" : "Cpu"; const char *device_name = config_.use_gpu() ? "Cuda" : "Cpu";
if (config_.use_gpu()) { if (config_.use_gpu()) {
...@@ -94,16 +94,69 @@ bool ONNXRuntimePredictor::Init() { ...@@ -94,16 +94,69 @@ bool ONNXRuntimePredictor::Init() {
} else { } else {
place_ = paddle::platform::CPUPlace(); place_ = paddle::platform::CPUPlace();
} }
scope_.reset(new paddle::framework::Scope());
binding_ = std::make_shared<Ort::IoBinding>(*session_);
Ort::MemoryInfo memory_info(
device_name, OrtDeviceAllocator, place_.GetDeviceId(), OrtMemTypeDefault);
Ort::Allocator allocator(*session_, memory_info);
size_t n_inputs = session_->GetInputCount();
framework::proto::VarType::Type proto_type =
framework::proto::VarType::LOD_TENSOR;
for (size_t i = 0; i < n_inputs; ++i) {
auto input_name = session_->GetInputName(i, allocator);
auto type_info = session_->GetInputTypeInfo(i);
std::vector<int64_t> shape =
type_info.GetTensorTypeAndShapeInfo().GetShape();
ONNXTensorElementDataType data_type =
type_info.GetTensorTypeAndShapeInfo().GetElementType();
input_desc_.emplace_back(ONNXDesc{input_name, shape, data_type});
auto *ptr = scope_->Var(input_name);
framework::InitializeVariable(ptr, proto_type);
allocator.Free(input_name);
}
size_t n_outputs = session_->GetOutputCount();
for (size_t i = 0; i < n_outputs; ++i) {
auto output_name = session_->GetOutputName(i, allocator);
auto type_info = session_->GetOutputTypeInfo(i);
std::vector<int64_t> shape =
type_info.GetTensorTypeAndShapeInfo().GetShape();
ONNXTensorElementDataType data_type =
type_info.GetTensorTypeAndShapeInfo().GetElementType();
output_desc_.emplace_back(ONNXDesc{output_name, shape, data_type});
Ort::MemoryInfo out_memory_info(device_name,
OrtDeviceAllocator,
place_.GetDeviceId(),
OrtMemTypeDefault);
binding_->BindOutput(output_name, out_memory_info);
allocator.Free(output_name);
}
return true;
}
bool ONNXRuntimePredictor::Init() {
VLOG(3) << "ONNXRuntime Predictor::init()";
char *onnx_proto = nullptr; char *onnx_proto = nullptr;
int out_size; int out_size;
if (config_.model_from_memory()) { if (config_.model_from_memory()) {
paddle2onnx::Export(config_.prog_file().data(), config_.prog_file().size(), paddle2onnx::Export(config_.prog_file().data(),
config_.prog_file().size(),
config_.params_file().data(), config_.params_file().data(),
config_.params_file().size(), &onnx_proto, &out_size); config_.params_file().size(),
&onnx_proto,
&out_size);
} else { } else {
paddle2onnx::Export(config_.prog_file().c_str(), paddle2onnx::Export(config_.prog_file().c_str(),
config_.params_file().c_str(), &onnx_proto, &out_size); config_.params_file().c_str(),
&onnx_proto,
&out_size);
} }
Ort::SessionOptions session_options; Ort::SessionOptions session_options;
...@@ -131,42 +184,11 @@ bool ONNXRuntimePredictor::Init() { ...@@ -131,42 +184,11 @@ bool ONNXRuntimePredictor::Init() {
"will be " "will be "
"generated."; "generated.";
} }
session_ = {env_, onnx_proto, static_cast<size_t>(out_size), session_options}; session_ = std::make_shared<Ort::Session>(
binding_ = std::make_shared<Ort::IoBinding>(session_); *env_, onnx_proto, static_cast<size_t>(out_size), session_options);
InitBinding();
Ort::MemoryInfo memory_info(device_name, OrtDeviceAllocator, delete[] onnx_proto;
place_.GetDeviceId(), OrtMemTypeDefault);
Ort::Allocator allocator(session_, memory_info);
size_t n_inputs = session_.GetInputCount();
for (size_t i = 0; i < n_inputs; ++i) {
auto input_name = session_.GetInputName(i, allocator);
auto type_info = session_.GetInputTypeInfo(i);
std::vector<int64_t> shape =
type_info.GetTensorTypeAndShapeInfo().GetShape();
ONNXTensorElementDataType data_type =
type_info.GetTensorTypeAndShapeInfo().GetElementType();
input_desc_.emplace_back(ONNXDesc{input_name, shape, data_type});
allocator.Free(input_name);
}
size_t n_outputs = session_.GetOutputCount();
for (size_t i = 0; i < n_outputs; ++i) {
auto output_name = session_.GetOutputName(i, allocator);
auto type_info = session_.GetOutputTypeInfo(i);
std::vector<int64_t> shape =
type_info.GetTensorTypeAndShapeInfo().GetShape();
ONNXTensorElementDataType data_type =
type_info.GetTensorTypeAndShapeInfo().GetElementType();
output_desc_.emplace_back(ONNXDesc{output_name, shape, data_type});
Ort::MemoryInfo out_memory_info(device_name, OrtDeviceAllocator,
place_.GetDeviceId(), OrtMemTypeDefault);
binding_->BindOutput(output_name, out_memory_info);
allocator.Free(output_name);
}
delete onnx_proto;
onnx_proto = nullptr; onnx_proto = nullptr;
return true; return true;
} }
...@@ -181,7 +203,8 @@ CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kONNXRuntime>( ...@@ -181,7 +203,8 @@ CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kONNXRuntime>(
} }
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
config.is_valid(), true, config.is_valid(),
true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Note: Each config can only be used for one predictor.")); "Note: Each config can only be used for one predictor."));
...@@ -238,12 +261,13 @@ bool ONNXRuntimePredictor::FindONNXDesc(const std::string &name, ...@@ -238,12 +261,13 @@ bool ONNXRuntimePredictor::FindONNXDesc(const std::string &name,
std::unique_ptr<ZeroCopyTensor> ONNXRuntimePredictor::GetInputTensor( std::unique_ptr<ZeroCopyTensor> ONNXRuntimePredictor::GetInputTensor(
const std::string &name) { const std::string &name) {
PADDLE_ENFORCE_EQ(FindONNXDesc(name, true), true, PADDLE_ENFORCE_NOT_NULL(scope_->FindVar(name),
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"The in variable named %s is not found in the " "The in variable named %s is not found in the "
"ONNXPredictor.", "ONNXPredictor.",
name)); name));
std::unique_ptr<ZeroCopyTensor> res(new ZeroCopyTensor(nullptr)); std::unique_ptr<ZeroCopyTensor> res(
new ZeroCopyTensor(static_cast<void *>(scope_.get())));
res->input_or_output_ = true; res->input_or_output_ = true;
res->SetName(name); res->SetName(name);
if (platform::is_cpu_place(place_)) { if (platform::is_cpu_place(place_)) {
...@@ -252,22 +276,13 @@ std::unique_ptr<ZeroCopyTensor> ONNXRuntimePredictor::GetInputTensor( ...@@ -252,22 +276,13 @@ std::unique_ptr<ZeroCopyTensor> ONNXRuntimePredictor::GetInputTensor(
auto gpu_place = place_; auto gpu_place = place_;
res->SetPlace(PaddlePlace::kGPU, gpu_place.GetDeviceId()); res->SetPlace(PaddlePlace::kGPU, gpu_place.GetDeviceId());
} }
res->SetOrtMark(true);
res->SetOrtBinding(binding_);
auto iter = input_buffers_.find(name);
if (iter == input_buffers_.end()) {
std::vector<int8_t> i_vector;
input_buffers_[name] = std::make_shared<std::vector<int8_t>>(i_vector);
res->SetOrtBuffer(input_buffers_[name]);
} else {
res->SetOrtBuffer(iter->second);
}
return res; return res;
} }
std::unique_ptr<ZeroCopyTensor> ONNXRuntimePredictor::GetOutputTensor( std::unique_ptr<ZeroCopyTensor> ONNXRuntimePredictor::GetOutputTensor(
const std::string &name) { const std::string &name) {
PADDLE_ENFORCE_EQ(FindONNXDesc(name, false), true, PADDLE_ENFORCE_EQ(FindONNXDesc(name, false),
true,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"The out variable named %s is not found in the " "The out variable named %s is not found in the "
"ONNXPredictor.", "ONNXPredictor.",
...@@ -293,6 +308,24 @@ std::unique_ptr<ZeroCopyTensor> ONNXRuntimePredictor::GetOutputTensor( ...@@ -293,6 +308,24 @@ std::unique_ptr<ZeroCopyTensor> ONNXRuntimePredictor::GetOutputTensor(
return res; return res;
} }
Ort::Value ONNXRuntimePredictor::GetOrtValue(const ONNXDesc &desc,
const char *device_name) {
Ort::MemoryInfo memory_info(
device_name, OrtDeviceAllocator, place_.GetDeviceId(), OrtMemTypeDefault);
auto *var = scope_->FindVar(desc.name);
auto *tensor = var->GetMutable<framework::LoDTensor>();
size_t size =
tensor->numel() *
framework::SizeOfType(framework::TransToProtoVarType(tensor->dtype()));
std::vector<int64_t> shape = phi::vectorize<int64_t>(tensor->dims());
return Ort::Value::CreateTensor(memory_info,
static_cast<void *>(tensor->data()),
size,
shape.data(),
shape.size(),
desc.dtype);
}
bool ONNXRuntimePredictor::Run(const std::vector<PaddleTensor> &inputs, bool ONNXRuntimePredictor::Run(const std::vector<PaddleTensor> &inputs,
std::vector<PaddleTensor> *output_data, std::vector<PaddleTensor> *output_data,
int batch_size) { int batch_size) {
...@@ -302,13 +335,21 @@ bool ONNXRuntimePredictor::Run(const std::vector<PaddleTensor> &inputs, ...@@ -302,13 +335,21 @@ bool ONNXRuntimePredictor::Run(const std::vector<PaddleTensor> &inputs,
bool ONNXRuntimePredictor::ZeroCopyRun() { bool ONNXRuntimePredictor::ZeroCopyRun() {
try { try {
const char *device_name = place_ == PlaceType::kCPU ? "Cpu" : "Cuda"; const char *device_name = platform::is_cpu_place(place_) ? "Cpu" : "Cuda";
std::vector<Ort::Value> inputs;
inputs.reserve(input_desc_.size());
for (auto desc : input_desc_) {
inputs.push_back(GetOrtValue(desc, device_name));
binding_->BindInput(desc.name.c_str(), inputs.back());
}
for (auto output : output_desc_) { for (auto output : output_desc_) {
Ort::MemoryInfo out_memory_info(device_name, OrtDeviceAllocator, Ort::MemoryInfo out_memory_info(device_name,
place_.GetDeviceId(), OrtMemTypeDefault); OrtDeviceAllocator,
place_.GetDeviceId(),
OrtMemTypeDefault);
binding_->BindOutput(output.name.c_str(), out_memory_info); binding_->BindOutput(output.name.c_str(), out_memory_info);
} }
session_.Run({}, *(binding_.get())); session_->Run({}, *(binding_.get()));
} catch (const std::exception &e) { } catch (const std::exception &e) {
LOG(ERROR) << e.what(); LOG(ERROR) << e.what();
return false; return false;
...@@ -318,8 +359,10 @@ bool ONNXRuntimePredictor::ZeroCopyRun() { ...@@ -318,8 +359,10 @@ bool ONNXRuntimePredictor::ZeroCopyRun() {
} }
std::unique_ptr<PaddlePredictor> ONNXRuntimePredictor::Clone() { std::unique_ptr<PaddlePredictor> ONNXRuntimePredictor::Clone() {
LOG(ERROR) << "Not support Clone(), Please create new Predictor"; std::lock_guard<std::mutex> lk(clone_mutex_);
return nullptr; auto *x = new ONNXRuntimePredictor(config_, env_, session_);
x->InitBinding();
return std::unique_ptr<PaddlePredictor>(x);
} }
uint64_t ONNXRuntimePredictor::TryShrinkMemory() { uint64_t ONNXRuntimePredictor::TryShrinkMemory() {
......
...@@ -18,8 +18,9 @@ ...@@ -18,8 +18,9 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/framework/op_compatible_info.h" #include "onnxruntime_c_api.h" // NOLINT
#include "onnxruntime_cxx_api.h" // NOLINT
#include "paddle/fluid/inference/analysis/analyzer.h" #include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/api/api_impl.h" #include "paddle/fluid/inference/api/api_impl.h"
#include "paddle/fluid/inference/api/details/reset_tensor_array.h" #include "paddle/fluid/inference/api/details/reset_tensor_array.h"
...@@ -27,9 +28,6 @@ ...@@ -27,9 +28,6 @@
#include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/platform/device/gpu/gpu_types.h" #include "paddle/fluid/platform/device/gpu/gpu_types.h"
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/printf.h"
#include "onnxruntime_c_api.h" // NOLINT
#include "onnxruntime_cxx_api.h" // NOLINT
#include "paddle2onnx/converter.h" #include "paddle2onnx/converter.h"
#ifdef PADDLE_WITH_TESTING #ifdef PADDLE_WITH_TESTING
...@@ -94,7 +92,22 @@ class ONNXRuntimePredictor : public PaddlePredictor { ...@@ -94,7 +92,22 @@ class ONNXRuntimePredictor : public PaddlePredictor {
/// \param[in] AnalysisConfig config /// \param[in] AnalysisConfig config
/// ///
explicit ONNXRuntimePredictor(const AnalysisConfig &config) explicit ONNXRuntimePredictor(const AnalysisConfig &config)
: config_(config), env_(ORT_LOGGING_LEVEL_WARNING, "onnx") { : env_(std::make_shared<Ort::Env>(ORT_LOGGING_LEVEL_WARNING,
"paddle-ort")),
session_(nullptr),
binding_(nullptr),
config_(config) {
predictor_id_ = inference::GetUniqueId();
}
///
/// \brief Clone a ONNXRuntime Predictor object
///
/// \param[in] AnalysisConfig config
///
explicit ONNXRuntimePredictor(const AnalysisConfig &config,
std::shared_ptr<Ort::Env> env,
std::shared_ptr<Ort::Session> session)
: env_(env), session_(session), binding_(nullptr), config_(config) {
predictor_id_ = inference::GetUniqueId(); predictor_id_ = inference::GetUniqueId();
} }
/// ///
...@@ -102,6 +115,13 @@ class ONNXRuntimePredictor : public PaddlePredictor { ...@@ -102,6 +115,13 @@ class ONNXRuntimePredictor : public PaddlePredictor {
/// ///
~ONNXRuntimePredictor(); ~ONNXRuntimePredictor();
///
/// \brief Initialize ORT Binding
///
/// \return Whether the init function executed successfully
///
bool InitBinding();
/// ///
/// \brief Initialize predictor /// \brief Initialize predictor
/// ///
...@@ -176,6 +196,8 @@ class ONNXRuntimePredictor : public PaddlePredictor { ...@@ -176,6 +196,8 @@ class ONNXRuntimePredictor : public PaddlePredictor {
/// ///
std::unique_ptr<PaddlePredictor> Clone() override; std::unique_ptr<PaddlePredictor> Clone() override;
std::shared_ptr<framework::Scope> scope_;
private: private:
/// ///
/// \brief Whether to find in/out by name. /// \brief Whether to find in/out by name.
...@@ -188,18 +210,27 @@ class ONNXRuntimePredictor : public PaddlePredictor { ...@@ -188,18 +210,27 @@ class ONNXRuntimePredictor : public PaddlePredictor {
/// ///
bool FindONNXDesc(const std::string &name, bool is_input); bool FindONNXDesc(const std::string &name, bool is_input);
private: /// \brief get the Ort Value(input Tensor).
AnalysisConfig config_; ///
/// \param[in] desc ONNXDesce(name、shape、dtype)
///
/// \param[in] device_name "cpu" or "gpu" of device
///
/// \return get a Ort::Value
///
Ort::Value GetOrtValue(const ONNXDesc &desc, const char *device_name);
private:
// ONNXRuntime // ONNXRuntime
Ort::Env env_; std::shared_ptr<Ort::Env> env_;
Ort::Session session_{nullptr}; std::shared_ptr<Ort::Session> session_{nullptr};
std::shared_ptr<Ort::IoBinding> binding_; std::shared_ptr<Ort::IoBinding> binding_;
AnalysisConfig config_;
std::mutex clone_mutex_;
platform::Place place_; platform::Place place_;
std::vector<ONNXDesc> input_desc_; std::vector<ONNXDesc> input_desc_;
std::vector<ONNXDesc> output_desc_; std::vector<ONNXDesc> output_desc_;
std::map<std::string, std::shared_ptr<std::vector<int8_t>>> input_buffers_;
int predictor_id_; int predictor_id_;
// Some more detailed tests, they are made the friends of the predictor, so that // Some more detailed tests, they are made the friends of the predictor, so that
......
...@@ -183,7 +183,6 @@ class PD_INFER_DECL Tensor { ...@@ -183,7 +183,6 @@ class PD_INFER_DECL Tensor {
#ifdef PADDLE_WITH_ONNXRUNTIME #ifdef PADDLE_WITH_ONNXRUNTIME
bool is_ort_tensor_{false}; bool is_ort_tensor_{false};
std::vector<int64_t> shape_; std::vector<int64_t> shape_;
std::weak_ptr<std::vector<int8_t>> buffer_;
std::weak_ptr<Ort::IoBinding> binding_; std::weak_ptr<Ort::IoBinding> binding_;
int idx_{-1}; int idx_{-1};
...@@ -191,7 +190,8 @@ class PD_INFER_DECL Tensor { ...@@ -191,7 +190,8 @@ class PD_INFER_DECL Tensor {
void SetOrtBinding(const std::shared_ptr<Ort::IoBinding> binding); void SetOrtBinding(const std::shared_ptr<Ort::IoBinding> binding);
void SetOrtBuffer(const std::shared_ptr<std::vector<int8_t>> buffer); template <typename T>
T* ORTGetMutableData();
template <typename T> template <typename T>
void ORTCopyFromCpu(const T* data); void ORTCopyFromCpu(const T* data);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册