From 3948c243efea49206db2bf2c9f77fe3362f84c52 Mon Sep 17 00:00:00 2001 From: heliqi <1101791222@qq.com> Date: Sun, 31 Jul 2022 21:38:26 -0500 Subject: [PATCH] ort backend support output mutable data (#44724) --- .../inference/api/details/zero_copy_tensor.cc | 16 ++++++++++++++++ paddle/fluid/inference/api/paddle_tensor.h | 3 +++ 2 files changed, 19 insertions(+) diff --git a/paddle/fluid/inference/api/details/zero_copy_tensor.cc b/paddle/fluid/inference/api/details/zero_copy_tensor.cc index 81c34ae29c0..022ba1483b9 100644 --- a/paddle/fluid/inference/api/details/zero_copy_tensor.cc +++ b/paddle/fluid/inference/api/details/zero_copy_tensor.cc @@ -88,6 +88,11 @@ void Tensor::ReshapeStrings(const size_t &shape) { template T *Tensor::mutable_data(PlaceType place) { +#ifdef PADDLE_WITH_ONNXRUNTIME + if (is_ort_tensor_) { + return ORTGetMutableData(); + } +#endif EAGER_GET_TENSOR(paddle::framework::LoDTensor); PADDLE_ENFORCE_GT( tensor->numel(), @@ -720,6 +725,17 @@ void Tensor::SetOrtBinding(const std::shared_ptr binding) { binding_ = binding; } +template +T *Tensor::ORTGetMutableData() { + auto binding = binding_.lock(); + PADDLE_ENFORCE_NOT_NULL(binding, + paddle::platform::errors::PreconditionNotMet( + "output tensor [%s] no binding ptr", name_)); + std::vector outputs = binding->GetOutputValues(); + Ort::Value &value = outputs[idx_]; + return value.GetTensorMutableData(); +} + template void Tensor::ORTCopyToCpu(T *data) const { auto binding = binding_.lock(); diff --git a/paddle/fluid/inference/api/paddle_tensor.h b/paddle/fluid/inference/api/paddle_tensor.h index d96148abd3b..b10f051d6e4 100644 --- a/paddle/fluid/inference/api/paddle_tensor.h +++ b/paddle/fluid/inference/api/paddle_tensor.h @@ -198,6 +198,9 @@ class PD_INFER_DECL Tensor { void SetOrtBinding(const std::shared_ptr binding); + template + T* ORTGetMutableData(); + template void ORTCopyFromCpu(const T* data); -- GitLab