未验证 提交 3948c243 编写于 作者: H heliqi 提交者: GitHub

ort backend support output mutable data (#44724)

上级 d92b2f2d
...@@ -88,6 +88,11 @@ void Tensor::ReshapeStrings(const size_t &shape) { ...@@ -88,6 +88,11 @@ void Tensor::ReshapeStrings(const size_t &shape) {
template <typename T> template <typename T>
T *Tensor::mutable_data(PlaceType place) { T *Tensor::mutable_data(PlaceType place) {
#ifdef PADDLE_WITH_ONNXRUNTIME
if (is_ort_tensor_) {
return ORTGetMutableData<T>();
}
#endif
EAGER_GET_TENSOR(paddle::framework::LoDTensor); EAGER_GET_TENSOR(paddle::framework::LoDTensor);
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
tensor->numel(), tensor->numel(),
...@@ -720,6 +725,17 @@ void Tensor::SetOrtBinding(const std::shared_ptr<Ort::IoBinding> binding) { ...@@ -720,6 +725,17 @@ void Tensor::SetOrtBinding(const std::shared_ptr<Ort::IoBinding> binding) {
binding_ = binding; binding_ = binding;
} }
template <typename T>
T *Tensor::ORTGetMutableData() {
auto binding = binding_.lock();
PADDLE_ENFORCE_NOT_NULL(binding,
paddle::platform::errors::PreconditionNotMet(
"output tensor [%s] no binding ptr", name_));
std::vector<Ort::Value> outputs = binding->GetOutputValues();
Ort::Value &value = outputs[idx_];
return value.GetTensorMutableData<T>();
}
template <typename T> template <typename T>
void Tensor::ORTCopyToCpu(T *data) const { void Tensor::ORTCopyToCpu(T *data) const {
auto binding = binding_.lock(); auto binding = binding_.lock();
......
...@@ -198,6 +198,9 @@ class PD_INFER_DECL Tensor { ...@@ -198,6 +198,9 @@ class PD_INFER_DECL Tensor {
void SetOrtBinding(const std::shared_ptr<Ort::IoBinding> binding); void SetOrtBinding(const std::shared_ptr<Ort::IoBinding> binding);
template <typename T>
T* ORTGetMutableData();
template <typename T> template <typename T>
void ORTCopyFromCpu(const T* data); void ORTCopyFromCpu(const T* data);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册