From 26cc5c54f74baa1c5a6d220117520c15ce43a0e9 Mon Sep 17 00:00:00 2001 From: heliqi <1101791222@qq.com> Date: Fri, 22 Apr 2022 14:42:58 +0800 Subject: [PATCH] fix onnxruntime bug (#42095) (#42104) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修复ORT在batch变动时,输出shape不对问题 --- paddle/fluid/inference/api/details/zero_copy_tensor.cc | 7 +++---- paddle/fluid/inference/api/onnxruntime_predictor.cc | 6 ++++++ 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/inference/api/details/zero_copy_tensor.cc b/paddle/fluid/inference/api/details/zero_copy_tensor.cc index 0f26a1076a6..c38088a2b80 100644 --- a/paddle/fluid/inference/api/details/zero_copy_tensor.cc +++ b/paddle/fluid/inference/api/details/zero_copy_tensor.cc @@ -693,10 +693,9 @@ void Tensor::ORTCopyToCpu(T *data) const { if (place_ == PlaceType::kCPU) { std::memcpy(static_cast(data), value.GetTensorData(), size); } else { - paddle::memory::Copy(paddle::platform::CPUPlace(), - static_cast(data), - paddle::platform::CUDAPlace(device_), - value.GetTensorData(), size, nullptr); + PADDLE_THROW(paddle::platform::errors::Unavailable( + "CopyToCpu error.The current ONNXRuntime backend doesn't support " + "GPU.")); } } diff --git a/paddle/fluid/inference/api/onnxruntime_predictor.cc b/paddle/fluid/inference/api/onnxruntime_predictor.cc index eb561667fe1..e42e395ce90 100644 --- a/paddle/fluid/inference/api/onnxruntime_predictor.cc +++ b/paddle/fluid/inference/api/onnxruntime_predictor.cc @@ -279,6 +279,12 @@ bool ONNXRuntimePredictor::Run(const std::vector &inputs, bool ONNXRuntimePredictor::ZeroCopyRun() { try { + const char *device_name = place_ == PlaceType::kCPU ? "Cpu" : "Cuda"; + for (auto output : output_desc_) { + Ort::MemoryInfo out_memory_info(device_name, OrtDeviceAllocator, + place_.GetDeviceId(), OrtMemTypeDefault); + binding_->BindOutput(output.name.c_str(), out_memory_info); + } session_.Run({}, *(binding_.get())); } catch (const std::exception &e) { LOG(ERROR) << e.what(); -- GitLab