未验证 提交 e7eb0e25 编写于 作者: H heliqi 提交者: GitHub

fix paddle-ort python bug (#42464)

* fix paddle-ort python bug

* fix paddle-ort python bug
上级 be77aeea
...@@ -674,8 +674,39 @@ void Tensor::ORTCopyFromCpu(const T *data) { ...@@ -674,8 +674,39 @@ void Tensor::ORTCopyFromCpu(const T *data) {
OrtMemTypeDefault); OrtMemTypeDefault);
size_t size = std::accumulate(begin(shape_), end(shape_), 1UL, size_t size = std::accumulate(begin(shape_), end(shape_), 1UL,
std::multiplies<size_t>()); std::multiplies<size_t>());
auto ort_value = GetOrtVaule(memory_info, const_cast<T *>(data), size, size_t buffer_size = size * sizeof(T);
shape_.data(), shape_.size()); if (buffer_size > buffer_.size()) {
buffer_.resize(buffer_size);
}
std::memcpy(static_cast<void *>(buffer_.data()), data, buffer_size);
auto onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
if (std::is_same<T, float>::value) {
onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
} else if (std::is_same<T, double>::value) {
onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
} else if (std::is_same<T, int64_t>::value) {
onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
} else if (std::is_same<T, int32_t>::value) {
onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
} else if (std::is_same<T, uint8_t>::value) {
onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
} else if (std::is_same<T, int8_t>::value) {
onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
} else if (std::is_same<T, float16>::value) {
onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
}
if (onnx_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"Found undefined data type for onnxruntime, only supports "
"float16/float32/float64/int8/uint8/int32/int64."));
}
auto ort_value =
Ort::Value::CreateTensor(memory_info, buffer_.data(), buffer_size,
shape_.data(), shape_.size(), onnx_dtype);
binding->BindInput(name_.c_str(), ort_value); binding->BindInput(name_.c_str(), ort_value);
} }
......
...@@ -187,6 +187,7 @@ class PD_INFER_DECL Tensor { ...@@ -187,6 +187,7 @@ class PD_INFER_DECL Tensor {
#ifdef PADDLE_WITH_ONNXRUNTIME #ifdef PADDLE_WITH_ONNXRUNTIME
bool is_ort_tensor_{false}; bool is_ort_tensor_{false};
std::vector<int64_t> shape_; std::vector<int64_t> shape_;
std::vector<int8_t> buffer_;
std::weak_ptr<Ort::IoBinding> binding_; std::weak_ptr<Ort::IoBinding> binding_;
int idx_{-1}; int idx_{-1};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册