未验证 提交 3948c243 编写于 作者: H heliqi 提交者: GitHub

ort backend support output mutable data (#44724)

上级 d92b2f2d
......@@ -88,6 +88,11 @@ void Tensor::ReshapeStrings(const size_t &shape) {
template <typename T>
T *Tensor::mutable_data(PlaceType place) {
#ifdef PADDLE_WITH_ONNXRUNTIME
if (is_ort_tensor_) {
return ORTGetMutableData<T>();
}
#endif
EAGER_GET_TENSOR(paddle::framework::LoDTensor);
PADDLE_ENFORCE_GT(
tensor->numel(),
......@@ -720,6 +725,17 @@ void Tensor::SetOrtBinding(const std::shared_ptr<Ort::IoBinding> binding) {
binding_ = binding;
}
template <typename T>
T *Tensor::ORTGetMutableData() {
auto binding = binding_.lock();
PADDLE_ENFORCE_NOT_NULL(binding,
paddle::platform::errors::PreconditionNotMet(
"output tensor [%s] no binding ptr", name_));
std::vector<Ort::Value> outputs = binding->GetOutputValues();
Ort::Value &value = outputs[idx_];
return value.GetTensorMutableData<T>();
}
template <typename T>
void Tensor::ORTCopyToCpu(T *data) const {
auto binding = binding_.lock();
......
......@@ -198,6 +198,9 @@ class PD_INFER_DECL Tensor {
void SetOrtBinding(const std::shared_ptr<Ort::IoBinding> binding);
template <typename T>
T* ORTGetMutableData();
template <typename T>
void ORTCopyFromCpu(const T* data);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册