diff --git a/paddle/fluid/inference/api/details/zero_copy_tensor.cc b/paddle/fluid/inference/api/details/zero_copy_tensor.cc index 81c34ae29c05a79cd6407fb025a82e1eed600417..022ba1483b955dfd8809102a47c63fbd8082c5f3 100644 --- a/paddle/fluid/inference/api/details/zero_copy_tensor.cc +++ b/paddle/fluid/inference/api/details/zero_copy_tensor.cc @@ -88,6 +88,11 @@ void Tensor::ReshapeStrings(const size_t &shape) { template T *Tensor::mutable_data(PlaceType place) { +#ifdef PADDLE_WITH_ONNXRUNTIME + if (is_ort_tensor_) { + return ORTGetMutableData(); + } +#endif EAGER_GET_TENSOR(paddle::framework::LoDTensor); PADDLE_ENFORCE_GT( tensor->numel(), @@ -720,6 +725,17 @@ void Tensor::SetOrtBinding(const std::shared_ptr binding) { binding_ = binding; } +template +T *Tensor::ORTGetMutableData() { + auto binding = binding_.lock(); + PADDLE_ENFORCE_NOT_NULL(binding, + paddle::platform::errors::PreconditionNotMet( + "output tensor [%s] no binding ptr", name_)); + std::vector outputs = binding->GetOutputValues(); + Ort::Value &value = outputs[idx_]; + return value.GetTensorMutableData(); +} + template void Tensor::ORTCopyToCpu(T *data) const { auto binding = binding_.lock(); diff --git a/paddle/fluid/inference/api/paddle_tensor.h b/paddle/fluid/inference/api/paddle_tensor.h index d96148abd3b560d6782f820ee6579df18cb91e65..b10f051d6e44e499e6c278c2bec92e8c6f0e4046 100644 --- a/paddle/fluid/inference/api/paddle_tensor.h +++ b/paddle/fluid/inference/api/paddle_tensor.h @@ -198,6 +198,9 @@ class PD_INFER_DECL Tensor { void SetOrtBinding(const std::shared_ptr binding); + template + T* ORTGetMutableData(); + template void ORTCopyFromCpu(const T* data);