From 87e6149c68c960f944bd93583d70680f8faebc90 Mon Sep 17 00:00:00 2001 From: heliqi <1101791222@qq.com> Date: Wed, 4 May 2022 07:55:21 -0500 Subject: [PATCH] fix paddle-ort python bug (#42464) (#42470) * fix paddle-ort python bug * fix paddle-ort python bug --- .../inference/api/details/zero_copy_tensor.cc | 35 +++++++++++++++++-- paddle/fluid/inference/api/paddle_tensor.h | 1 + 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/inference/api/details/zero_copy_tensor.cc b/paddle/fluid/inference/api/details/zero_copy_tensor.cc index c38088a2b8..e6cea1b46f 100644 --- a/paddle/fluid/inference/api/details/zero_copy_tensor.cc +++ b/paddle/fluid/inference/api/details/zero_copy_tensor.cc @@ -674,8 +674,39 @@ void Tensor::ORTCopyFromCpu(const T *data) { OrtMemTypeDefault); size_t size = std::accumulate(begin(shape_), end(shape_), 1UL, std::multiplies()); - auto ort_value = GetOrtVaule(memory_info, const_cast(data), size, - shape_.data(), shape_.size()); + size_t buffer_size = size * sizeof(T); + if (buffer_size > buffer_.size()) { + buffer_.resize(buffer_size); + } + std::memcpy(static_cast(buffer_.data()), data, buffer_size); + + auto onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + if (std::is_same::value) { + onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + } else if (std::is_same::value) { + onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; + } else if (std::is_same::value) { + onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; + } else if (std::is_same::value) { + onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; + } else if (std::is_same::value) { + onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; + } else if (std::is_same::value) { + onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; + } else if (std::is_same::value) { + onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; + } + + if (onnx_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) { + PADDLE_THROW(paddle::platform::errors::InvalidArgument( + "Found undefined data type for onnxruntime, only supports " + "float16/float32/float64/int8/uint8/int32/int64.")); + } + + auto ort_value = + Ort::Value::CreateTensor(memory_info, buffer_.data(), buffer_size, + shape_.data(), shape_.size(), onnx_dtype); + binding->BindInput(name_.c_str(), ort_value); } diff --git a/paddle/fluid/inference/api/paddle_tensor.h b/paddle/fluid/inference/api/paddle_tensor.h index 2afe2d32e2..2ae5ac5e6d 100644 --- a/paddle/fluid/inference/api/paddle_tensor.h +++ b/paddle/fluid/inference/api/paddle_tensor.h @@ -183,6 +183,7 @@ class PD_INFER_DECL Tensor { #ifdef PADDLE_WITH_ONNXRUNTIME bool is_ort_tensor_{false}; std::vector shape_; + std::vector buffer_; std::weak_ptr binding_; int idx_{-1}; -- GitLab