zero_copy_tensor.cc 26.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/fluid/framework/convert_utils.h"
16
#include "paddle/fluid/framework/data_layout_transform.h"
17 18 19
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
W
Wilber 已提交
20
#include "paddle/fluid/inference/api/paddle_tensor.h"
N
nhzlx 已提交
21
#include "paddle/fluid/memory/memcpy.h"
22
#include "paddle/fluid/platform/enforce.h"
23
#include "paddle/fluid/platform/float16.h"
24
#include "paddle/phi/core/allocator.h"
25
#ifdef PADDLE_WITH_ONNXRUNTIME
H
heliqi 已提交
26 27
#include "onnxruntime_c_api.h"    // NOLINT
#include "onnxruntime_cxx_api.h"  // NOLINT
28
#endif
29

30
namespace paddle_infer {
31

32 33
// Short alias for Paddle's half-precision type, used by the FLOAT16 type
// mapping and by the explicit instantiation lists in this file.
using float16 = paddle::platform::float16;

34
void Tensor::Reshape(const std::vector<int> &shape) {
35 36 37 38 39 40 41
#ifdef PADDLE_WITH_ONNXRUNTIME
  if (is_ort_tensor_) {
    shape_.assign(shape.begin(), shape.end());
    return;
  }
#endif

W
Wilber 已提交
42
  PADDLE_ENFORCE_EQ(
43 44
      name_.empty(),
      false,
45
      paddle::platform::errors::PreconditionNotMet(
W
Wilber 已提交
46 47
          "Need to SetName first, so that the corresponding tensor can "
          "be retrieved."));
48 49
  PADDLE_ENFORCE_EQ(input_or_output_,
                    true,
50
                    paddle::platform::errors::PermissionDenied(
W
Wilber 已提交
51
                        "Can't reshape the output tensor, it is readonly"));
52
  auto *scope = static_cast<paddle::framework::Scope *>(scope_);
53
  auto *var = scope->FindVar(name_);
W
Wilber 已提交
54
  PADDLE_ENFORCE_NOT_NULL(
55 56 57
      var,
      paddle::platform::errors::PreconditionNotMet(
          "No tensor called [%s] in the runtime scope", name_));
58
  auto *tensor = var->GetMutable<paddle::framework::LoDTensor>();
59
  tensor->Resize(phi::make_ddim(shape));
60 61
}

S
Steffy-zxf 已提交
62 63
void Tensor::ReshapeStrings(const size_t &shape) {
  PADDLE_ENFORCE_EQ(
64 65
      name_.empty(),
      false,
S
Steffy-zxf 已提交
66 67 68
      paddle::platform::errors::PreconditionNotMet(
          "Need to SetName first, so that the corresponding tensor can "
          "be retrieved."));
69 70
  PADDLE_ENFORCE_EQ(input_or_output_,
                    true,
S
Steffy-zxf 已提交
71 72 73 74 75
                    paddle::platform::errors::PermissionDenied(
                        "Can't reshape the output tensor, it is readonly"));
  auto *scope = static_cast<paddle::framework::Scope *>(scope_);
  auto *var = scope->FindVar(name_);
  PADDLE_ENFORCE_NOT_NULL(
76 77 78
      var,
      paddle::platform::errors::PreconditionNotMet(
          "No tensor called [%s] in the runtime scope", name_));
S
Steffy-zxf 已提交
79 80 81 82 83 84 85 86 87
  paddle_infer::Strings *tensor = var->GetMutable<paddle_infer::Strings>();
  tensor->resize(shape);
}

// Resolves the object backing this handle and declares a local pointer
// `tensor` of type `tensor_type *` in the caller's scope.
// On first use the object is obtained from FindTensor<tensor_type>() and
// cached in `tensor_`; later expansions reuse the cached pointer. Because the
// macro is also expanded inside const member functions, `tensor_` is
// presumably declared `mutable` in the class (verify in the header).
// The expansion is plain statements rather than a do/while(0) block so that
// `tensor` remains visible after the macro.
// NOTE(review): the cache does not remember which type was requested, so a
// later expansion with a different `tensor_type` reinterprets the same
// pointer.
#define EAGER_GET_TENSOR(tensor_type)    \
  if (!tensor_) {                        \
    tensor_ = FindTensor<tensor_type>(); \
  }                                      \
  auto *tensor = static_cast<tensor_type *>(tensor_);
88

89
template <typename T>
90
T *Tensor::mutable_data(PlaceType place) {
91 92 93 94 95
#ifdef PADDLE_WITH_ONNXRUNTIME
  if (is_ort_tensor_) {
    return ORTGetMutableData<T>();
  }
#endif
S
Steffy-zxf 已提交
96
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
97
  PADDLE_ENFORCE_GT(
98 99
      tensor->numel(),
      0,
100 101
      paddle::platform::errors::PreconditionNotMet(
          "You should call Tensor::Reshape(const std::vector<int> "
W
Wilber 已提交
102 103
          "&shape)"
          "function before retrieving mutable_data from input tensor."));
104
  switch (static_cast<int>(place)) {
105 106
    case static_cast<int>(PlaceType::kCPU): {
      return tensor->mutable_data<T>(paddle::platform::CPUPlace());
107
    }
108 109 110 111 112
    case static_cast<int>(PlaceType::kGPU): {
      return tensor->mutable_data<T>(paddle::platform::CUDAPlace(device_));
    }
    case static_cast<int>(PlaceType::kXPU): {
      return tensor->mutable_data<T>(paddle::platform::XPUPlace(device_));
113
    }
114 115 116
    case static_cast<int>(PlaceType::kNPU): {
      return tensor->mutable_data<T>(paddle::platform::NPUPlace(device_));
    }
117
    default:
118
      PADDLE_THROW(paddle::platform::errors::Unavailable(
119 120
          "Only CPU / CUDA / XPU / NPU places is supported. The place `%d` is "
          "not supported.",
121
          static_cast<int>(place)));
122 123 124 125 126 127
      break;
  }
  return nullptr;
}

template <typename T>
128
T *Tensor::data(PlaceType *place, int *size) const {
S
Steffy-zxf 已提交
129
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
130 131
  auto *res = tensor->data<T>();

132 133 134 135 136 137
  if (paddle::platform::is_cpu_place(tensor->place())) {
    *place = PlaceType::kCPU;
  } else if (paddle::platform::is_gpu_place(tensor->place())) {
    *place = PlaceType::kGPU;
  } else if (paddle::platform::is_xpu_place(tensor->place())) {
    *place = PlaceType::kXPU;
138 139
  } else if (paddle::platform::is_npu_place(tensor->place())) {
    *place = PlaceType::kNPU;
140
  } else {
141
    *place = PlaceType::kUNK;
142 143 144 145 146 147
  }

  *size = tensor->numel();
  return res;
}

148
/// Returns the element type of this tensor as an inference-API DataType.
///
/// ONNXRuntime handles report their stored `dtype_`. Paddle tensors map FP32,
/// FP16, INT64, INT32, UINT8 and INT8 directly. Every other element type is
/// reported as FLOAT32.
DataType Tensor::type() const {
#ifdef PADDLE_WITH_ONNXRUNTIME
  if (is_ort_tensor_) {
    return dtype_;
  }
#endif
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
  switch (paddle::framework::TransToProtoVarType(tensor->dtype())) {
    case paddle::framework::proto::VarType::FP32:
      return DataType::FLOAT32;
    case paddle::framework::proto::VarType::FP16:
      return DataType::FLOAT16;
    case paddle::framework::proto::VarType::INT64:
      return DataType::INT64;
    case paddle::framework::proto::VarType::INT32:
      return DataType::INT32;
    case paddle::framework::proto::VarType::UINT8:
      return DataType::UINT8;
    case paddle::framework::proto::VarType::INT8:
      return DataType::INT8;
    default:
      // Unmapped element types fall back to FLOAT32.
      return DataType::FLOAT32;
  }
}

172 173
// Returns the place last recorded by SetPlace(). The value is not derived from
// the underlying tensor's actual location.
PlaceType Tensor::place() const {
  return place_;
}

N
nhzlx 已提交
174
template <typename T>
175
void Tensor::CopyFromCpu(const T *data) {
S
Steffy-zxf 已提交
176
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
177 178
  PADDLE_ENFORCE_GE(tensor->numel(),
                    0,
179 180
                    paddle::platform::errors::PreconditionNotMet(
                        "You should call Tensor::Reshape(const "
W
Wilber 已提交
181 182
                        "std::vector<int> &shape)"
                        "function before copying data from cpu."));
N
nhzlx 已提交
183 184
  size_t ele_size = tensor->numel() * sizeof(T);

185 186
  if (place_ == PlaceType::kCPU) {
    auto *t_data = tensor->mutable_data<T>(paddle::platform::CPUPlace());
N
nhzlx 已提交
187
    std::memcpy(static_cast<void *>(t_data), data, ele_size);
188
  } else if (place_ == PlaceType::kGPU) {
189
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
190 191 192
    paddle::platform::DeviceContextPool &pool =
        paddle::platform::DeviceContextPool::Instance();
    paddle::platform::CUDAPlace gpu_place(device_);
N
nhzlx 已提交
193
    auto *t_data = tensor->mutable_data<T>(gpu_place);
194 195
    auto *dev_ctx = static_cast<const paddle::platform::CUDADeviceContext *>(
        pool.Get(gpu_place));
N
nhzlx 已提交
196

197 198 199 200 201
    paddle::memory::Copy(gpu_place,
                         static_cast<void *>(t_data),
                         paddle::platform::CPUPlace(),
                         data,
                         ele_size,
202
                         dev_ctx->stream());
N
nhzlx 已提交
203
#else
204 205 206
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with CUDA place because paddle is not compiled "
        "with CUDA."));
N
nhzlx 已提交
207
#endif
208
  } else if (place_ == PlaceType::kXPU) {
209
#ifdef PADDLE_WITH_XPU
210
    paddle::platform::XPUPlace xpu_place(device_);
211
    auto *t_data = tensor->mutable_data<T>(xpu_place);
212 213 214 215 216
    paddle::memory::Copy(xpu_place,
                         static_cast<void *>(t_data),
                         paddle::platform::CPUPlace(),
                         data,
                         ele_size);
217
#else
218 219 220
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with XPU place because paddle is not compiled "
        "with XPU."));
W
Wilber 已提交
221 222 223 224 225 226 227 228 229
#endif
  } else if (place_ == PlaceType::kNPU) {
#ifdef PADDLE_WITH_ASCEND_CL
    paddle::platform::DeviceContextPool &pool =
        paddle::platform::DeviceContextPool::Instance();
    paddle::platform::NPUPlace npu_place(device_);
    auto *t_data = tensor->mutable_data<T>(npu_place);
    auto *dev_ctx = static_cast<const paddle::platform::NPUDeviceContext *>(
        pool.Get(npu_place));
230 231 232 233 234
    paddle::memory::Copy(npu_place,
                         static_cast<void *>(t_data),
                         paddle::platform::CPUPlace(),
                         data,
                         ele_size,
W
Wilber 已提交
235 236 237 238 239
                         dev_ctx->stream());
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with NPU place because paddle is not compiled "
        "with NPU."));
240 241 242
#endif
  } else {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
W
Wilber 已提交
243
        "The analysis predictor supports CPU, GPU, NPU and XPU now."));
N
nhzlx 已提交
244 245 246
  }
}

247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281
// Compile-time mapping from a C++ element type to the phi DataType tag used
// when ShareExternalData builds a DenseTensorMeta. The primary template is
// declared but never defined, so an unsupported T fails to compile.
template <typename T>
struct DataTypeInfo;

template <>
struct DataTypeInfo<float> {
  paddle::experimental::DataType TYPE{paddle::experimental::DataType::FLOAT32};
};

template <>
struct DataTypeInfo<float16> {
  paddle::experimental::DataType TYPE{paddle::experimental::DataType::FLOAT16};
};

template <>
struct DataTypeInfo<int64_t> {
  paddle::experimental::DataType TYPE{paddle::experimental::DataType::INT64};
};

template <>
struct DataTypeInfo<int8_t> {
  paddle::experimental::DataType TYPE{paddle::experimental::DataType::INT8};
};

template <>
struct DataTypeInfo<uint8_t> {
  paddle::experimental::DataType TYPE{paddle::experimental::DataType::UINT8};
};

template <>
struct DataTypeInfo<int32_t> {
  paddle::experimental::DataType TYPE{paddle::experimental::DataType::INT32};
};

// Translates the inference-API DataLayout into phi's layout enum. Only NCHW
// is accepted; any other layout raises InvalidArgument.
paddle::experimental::DataLayout LayoutConvert(DataLayout layout) {
  const auto supported_layout = paddle::experimental::DataLayout::NCHW;
  PADDLE_ENFORCE_EQ(
      layout,
      DataLayout::kNCHW,
      paddle::platform::errors::InvalidArgument("Only NCHW is supported now."));
  return supported_layout;
}

template <typename T>
289 290 291 292
void Tensor::ShareExternalData(const T *data,
                               const std::vector<int> &shape,
                               PlaceType place,
                               DataLayout layout) {
293 294 295 296
  EAGER_GET_TENSOR(paddle::framework::LoDTensor)
  size_t size =
      std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()) *
      sizeof(T);
297 298
  phi::DenseTensorMeta meta(
      DataTypeInfo<T>().TYPE, phi::make_ddim(shape), LayoutConvert(layout));
299 300
  if (place == PlaceType::kCPU) {
    phi::DenseTensor dtensor(
301 302
        std::make_shared<phi::Allocation>(
            const_cast<T *>(data), size, paddle::platform::CPUPlace()),
303 304 305 306
        meta);
    *tensor = std::move(dtensor);
  } else if (place == PlaceType::kGPU) {
    phi::DenseTensor dtensor(
307 308
        std::make_shared<phi::Allocation>(
            const_cast<T *>(data), size, paddle::platform::CUDAPlace(device_)),
309 310 311 312 313 314 315 316
        meta);
    *tensor = std::move(dtensor);
  } else {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "PlaceType must be PlaceType::kCPU or PlaceType::kGPU."));
  }
}

S
Steffy-zxf 已提交
317 318
/// Replaces the contents of the string container backing this handle with a
/// copy of `*data`.
///
/// @param data Source strings; must not be null.
/// @throws InvalidArgument if `data` is null.
void Tensor::CopyStringsFromCpu(const paddle_infer::Strings *data) {
  EAGER_GET_TENSOR(paddle_infer::Strings);
  // `*data` is dereferenced below, so reject null up front instead of
  // crashing.
  PADDLE_ENFORCE_NOT_NULL(
      data,
      paddle::platform::errors::InvalidArgument(
          "The input data of Tensor::CopyStringsFromCpu must not be null."));
  // NOTE(review): size() is unsigned, so this check can never fail; the
  // whole container is assigned below regardless of its current size.
  PADDLE_ENFORCE_GE(tensor->size(),
                    0,
                    paddle::platform::errors::PreconditionNotMet(
                        "You should call Tensor::ReshapeStrings(const "
                        "std::size_t &shape) function before copying "
                        "the string data from cpu."));
  *tensor = *data;
}

N
nhzlx 已提交
328
template <typename T>
329 330 331
void Tensor::CopyToCpuImpl(T *data,
                           void *exec_stream,
                           CallbackFunc cb,
332
                           void *cb_params) const {
S
Steffy-zxf 已提交
333
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
N
nhzlx 已提交
334 335 336 337
  auto ele_num = tensor->numel();
  auto *t_data = tensor->data<T>();
  auto t_place = tensor->place();

338
  paddle::framework::Tensor out;
339 340
  auto mem_allocation =
      std::make_shared<paddle::memory::allocation::Allocation>(
341 342
          static_cast<void *>(data),
          ele_num * sizeof(T),
343
          paddle::platform::CPUPlace());
344 345
  out.ResetHolder(mem_allocation);

346
  if (paddle::platform::is_cpu_place(t_place)) {
347 348 349
#ifdef PADDLE_WITH_MKLDNN
    if (tensor->layout() == paddle::framework::DataLayout::kMKLDNN)
      paddle::framework::innerTransDataLayoutFromMKLDNN(
350 351 352 353 354 355 356
          tensor->layout(),
          paddle::platform::MKLDNNDeviceContext::tls()
              .get_cur_paddle_data_layout(),
          *tensor,
          &out,
          paddle::platform::CPUPlace(),
          true);
357 358 359
    else
      std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
#else
N
nhzlx 已提交
360
    std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
J
jianghaicheng 已提交
361 362 363 364 365 366 367 368
#endif
  } else if (paddle::platform::is_ipu_place(t_place)) {
#ifdef PADDLE_WITH_IPU
    std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with IPU place because paddle is not compiled "
        "with IPU."));
369
#endif
370
  } else if (place_ == PlaceType::kGPU) {
371
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
372 373
    paddle::platform::DeviceContextPool &pool =
        paddle::platform::DeviceContextPool::Instance();
374
    auto gpu_place = t_place;
375 376 377
    auto *dev_ctx = static_cast<const paddle::platform::CUDADeviceContext *>(
        pool.Get(gpu_place));
    paddle::memory::Copy(paddle::platform::CPUPlace(),
378 379 380 381 382
                         static_cast<void *>(data),
                         gpu_place,
                         t_data,
                         ele_num * sizeof(T),
                         dev_ctx->stream());
383 384 385
#ifdef PADDLE_WITH_HIP
    hipStreamSynchronize(dev_ctx->stream());
#else
386 387 388 389 390 391 392 393 394 395
    // async, return stream
    if (nullptr != exec_stream) {
      *(static_cast<cudaStream_t *>(exec_stream)) = dev_ctx->stream();
      // async with callback
    } else if (cb) {
      cudaLaunchHostFunc(dev_ctx->stream(), cb, cb_params);
      // sync
    } else {
      cudaStreamSynchronize(dev_ctx->stream());
    }
396
#endif
N
nhzlx 已提交
397
#else
398 399 400
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with CUDA place because paddle is not compiled "
        "with CUDA."));
N
nhzlx 已提交
401
#endif
402
  } else if (place_ == PlaceType::kXPU) {
403
#ifdef PADDLE_WITH_XPU
404
    auto xpu_place = t_place;
405
    paddle::memory::Copy(paddle::platform::CPUPlace(),
406 407 408
                         static_cast<void *>(data),
                         xpu_place,
                         t_data,
409
                         ele_num * sizeof(T));
410
#else
411 412 413
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with XPU place because paddle is not compiled "
        "with XPU."));
W
Wilber 已提交
414 415 416 417 418
#endif
  } else if (place_ == PlaceType::kNPU) {
#ifdef PADDLE_WITH_ASCEND_CL
    paddle::platform::DeviceContextPool &pool =
        paddle::platform::DeviceContextPool::Instance();
419
    auto npu_place = t_place;
W
Wilber 已提交
420 421 422
    auto *dev_ctx = static_cast<const paddle::platform::NPUDeviceContext *>(
        pool.Get(npu_place));
    paddle::memory::Copy(paddle::platform::CPUPlace(),
423 424 425 426 427
                         static_cast<void *>(data),
                         npu_place,
                         t_data,
                         ele_num * sizeof(T),
                         dev_ctx->stream());
428
    paddle::platform::NPUStreamSync(dev_ctx->stream());
W
Wilber 已提交
429 430 431 432
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with NPU place because paddle is not compiled "
        "with NPU."));
433 434 435
#endif
  } else {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
W
Wilber 已提交
436
        "The analysis predictor supports CPU, GPU, NPU and XPU now."));
N
nhzlx 已提交
437 438
  }
}
439 440 441

template <typename T>
void Tensor::CopyToCpu(T *data) const {
442 443 444 445 446 447 448
#ifdef PADDLE_WITH_ONNXRUNTIME
  if (is_ort_tensor_) {
    ORTCopyToCpu<T>(data);
    return;
  }
#endif

449 450 451 452 453 454 455 456 457 458 459 460 461
  CopyToCpuImpl<T>(data, nullptr, nullptr, nullptr);
}

/// Starts a copy into `data`. On CUDA builds the stream carrying the copy is
/// written to `*exec_stream` (when non-null) instead of being waited on.
template <typename T>
void Tensor::CopyToCpuAsync(T *data, void *exec_stream) const {
  CopyToCpuImpl<T>(data, exec_stream, /*cb=*/nullptr, /*cb_params=*/nullptr);
}

/// Starts a copy into `data`. On CUDA builds `cb(cb_params)` is enqueued on
/// the copy's stream instead of waiting for completion.
template <typename T>
void Tensor::CopyToCpuAsync(T *data, CallbackFunc cb, void *cb_params) const {
  CopyToCpuImpl<T>(data, /*exec_stream=*/nullptr, cb, cb_params);
}

462 463 464 465 466
template PD_INFER_DECL void Tensor::CopyFromCpu<float>(const float *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<int64_t>(const int64_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<int32_t>(const int32_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<uint8_t>(const uint8_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<int8_t>(const int8_t *data);
467
template PD_INFER_DECL void Tensor::CopyFromCpu<float16>(const float16 *data);
468

469
template PD_INFER_DECL void Tensor::ShareExternalData<float>(
470 471 472
    const float *data,
    const std::vector<int> &shape,
    PlaceType place,
473 474
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<int64_t>(
475 476 477
    const int64_t *data,
    const std::vector<int> &shape,
    PlaceType place,
478 479
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<int32_t>(
480 481 482
    const int32_t *data,
    const std::vector<int> &shape,
    PlaceType place,
483 484
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<uint8_t>(
485 486 487
    const uint8_t *data,
    const std::vector<int> &shape,
    PlaceType place,
488 489
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<int8_t>(
490 491 492
    const int8_t *data,
    const std::vector<int> &shape,
    PlaceType place,
493 494
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<float16>(
495 496 497
    const float16 *data,
    const std::vector<int> &shape,
    PlaceType place,
498 499
    DataLayout layout);

500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546
template PD_INFER_DECL void Tensor::CopyToCpu<float>(float *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<int64_t>(int64_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<int32_t>(int32_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<uint8_t>(uint8_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<int8_t>(int8_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<float16>(float16 *data) const;

template PD_INFER_DECL void Tensor::CopyToCpuImpl<float>(float *data,
                                                         void *exec_stream,
                                                         CallbackFunc cb,
                                                         void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<int64_t>(
    int64_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<int32_t>(
    int32_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<uint8_t>(
    uint8_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<int8_t>(
    int8_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<float16>(
    float16 *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;

template PD_INFER_DECL void Tensor::CopyToCpuAsync<float>(
    float *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int64_t>(
    int64_t *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int32_t>(
    int32_t *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<uint8_t>(
    uint8_t *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int8_t>(
    int8_t *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<float16>(
    float16 *data, void *exec_stream) const;

template PD_INFER_DECL void Tensor::CopyToCpuAsync<float>(
    float *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int64_t>(
    int64_t *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int32_t>(
    int32_t *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<uint8_t>(
    uint8_t *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int8_t>(
    int8_t *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<float16>(
    float16 *data, CallbackFunc cb, void *cb_params) const;
547

548 549 550 551 552 553 554 555 556 557
template PD_INFER_DECL float *Tensor::data<float>(PlaceType *place,
                                                  int *size) const;
template PD_INFER_DECL int64_t *Tensor::data<int64_t>(PlaceType *place,
                                                      int *size) const;
template PD_INFER_DECL int32_t *Tensor::data<int32_t>(PlaceType *place,
                                                      int *size) const;
template PD_INFER_DECL uint8_t *Tensor::data<uint8_t>(PlaceType *place,
                                                      int *size) const;
template PD_INFER_DECL int8_t *Tensor::data<int8_t>(PlaceType *place,
                                                    int *size) const;
558 559
template PD_INFER_DECL float16 *Tensor::data<float16>(PlaceType *place,
                                                      int *size) const;
560

561 562 563 564 565
template PD_INFER_DECL float *Tensor::mutable_data<float>(PlaceType place);
template PD_INFER_DECL int64_t *Tensor::mutable_data<int64_t>(PlaceType place);
template PD_INFER_DECL int32_t *Tensor::mutable_data<int32_t>(PlaceType place);
template PD_INFER_DECL uint8_t *Tensor::mutable_data<uint8_t>(PlaceType place);
template PD_INFER_DECL int8_t *Tensor::mutable_data<int8_t>(PlaceType place);
566
template PD_INFER_DECL float16 *Tensor::mutable_data<float16>(PlaceType place);
567

568
// Stores the opaque scope pointer. It is only cast to
// paddle::framework::Scope and queried by name later (FindTensor, Reshape).
Tensor::Tensor(void *scope) : scope_(scope) {}
569

S
Steffy-zxf 已提交
570
template <typename T>
571
void *Tensor::FindTensor() const {
W
Wilber 已提交
572
  PADDLE_ENFORCE_EQ(
573 574
      name_.empty(),
      false,
575
      paddle::platform::errors::PreconditionNotMet(
W
Wilber 已提交
576 577
          "Need to SetName first, so that the corresponding tensor can "
          "be retrieved."));
578
  auto *scope = static_cast<paddle::framework::Scope *>(scope_);
579
  auto *var = scope->FindVar(name_);
W
Wilber 已提交
580
  PADDLE_ENFORCE_NOT_NULL(
581 582 583
      var,
      paddle::platform::errors::PreconditionNotMet(
          "No tensor called [%s] in the runtime scope", name_));
S
Steffy-zxf 已提交
584
  auto *tensor = var->GetMutable<T>();
585 586 587
  return tensor;
}

588
std::vector<int> Tensor::shape() const {
589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608
#ifdef PADDLE_WITH_ONNXRUNTIME
  if (is_ort_tensor_) {
    std::vector<int> shape;
    // input handle
    if (idx_ < 0) {
      shape.assign(shape_.begin(), shape_.end());
    } else {  // output handle
      auto binding = binding_.lock();
      PADDLE_ENFORCE_NOT_NULL(binding,
                              paddle::platform::errors::PreconditionNotMet(
                                  "output tensor [%s] no binding ptr", name_));
      std::vector<Ort::Value> outputs = binding->GetOutputValues();
      Ort::Value &value = outputs[idx_];
      auto info = value.GetTensorTypeAndShapeInfo();
      auto ort_shape = info.GetShape();
      shape.assign(ort_shape.begin(), ort_shape.end());
    }
    return shape;
  }
#endif
S
Steffy-zxf 已提交
609
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
W
Wilber 已提交
610
  PADDLE_ENFORCE_NOT_NULL(
611 612 613
      tensor_,
      paddle::platform::errors::PreconditionNotMet(
          "Not found tensor called %s in the scope", name_));
W
wenbin 已提交
614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629
// mkldnn may does layout transform internally, so need to reorder before
// return
#ifdef PADDLE_WITH_MKLDNN
  if (tensor->layout() == paddle::framework::DataLayout::kMKLDNN) {
    paddle::framework::DataLayout out_layout =
        paddle::platform::MKLDNNDeviceContext::tls()
            .get_cur_paddle_data_layout();
    // Set default as NCHW in case not specified
    out_layout = out_layout == paddle::framework::DataLayout::kAnyLayout
                     ? paddle::framework::DataLayout::kNCHW
                     : out_layout;
    // In these data layouts, channel dimension is either on 2nd position: nChw
    // or
    // at last nhwC, so for dim==2 these layouts are the same and nothing should
    // be done. Similarly for dim==1 when you have just one possible
    // combination.
630
    if (tensor->dims().size() < 3) return phi::vectorize<int>(tensor->dims());
J
Jacek Czaja 已提交
631 632
    if (out_layout == paddle::framework::DataLayout::kNHWC ||
        out_layout == paddle::framework::DataLayout::kNDHWC) {
633
      auto dims = phi::vectorize<int>(tensor->dims());
W
wenbin 已提交
634 635 636
      std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end());
      return dims;
    } else {
637
      return phi::vectorize<int>(tensor->dims());
W
wenbin 已提交
638 639 640
    }
  }
#endif
641
  return phi::vectorize<int>(tensor->dims());
642 643
}

644
void Tensor::SetLoD(const std::vector<std::vector<size_t>> &x) {
S
Steffy-zxf 已提交
645
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
646
  paddle::framework::LoD lod;
647 648 649 650 651 652
  for (auto &level : x) {
    lod.emplace_back(level);
  }
  tensor->set_lod(lod);
}

653
std::vector<std::vector<size_t>> Tensor::lod() const {
S
Steffy-zxf 已提交
654
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
655 656 657 658 659 660 661
  std::vector<std::vector<size_t>> res;
  for (auto &level : tensor->lod()) {
    res.emplace_back(level);
  }
  return res;
}

662 663 664 665 666 667 668 669 670
// Sets the variable name used to find this tensor in the scope.
void Tensor::SetName(const std::string &name) {
  name_ = name;
}

// Returns the name set by SetName().
const std::string &Tensor::name() const {
  return name_;
}

// Records the target device kind and the device index used when building
// GPU/XPU/NPU places.
void Tensor::SetPlace(PlaceType place, int device) {
  device_ = device;
  place_ = place;
}

671 672 673 674 675 676 677 678
#ifdef PADDLE_WITH_ONNXRUNTIME
// Marks whether this handle is backed by ONNXRuntime instead of a Paddle
// scope variable.
void Tensor::SetOrtMark(bool is_ort_tensor) {
  is_ort_tensor_ = is_ort_tensor;
}

// Stores the IoBinding that ORT output handles read from. `binding_` is
// lock()ed on every use, so presumably only a weak reference is kept — verify
// in the header.
void Tensor::SetOrtBinding(const std::shared_ptr<Ort::IoBinding> binding) {
  binding_ = binding;
}

template <typename T>
679
T *Tensor::ORTGetMutableData() {
680 681 682
  auto binding = binding_.lock();
  PADDLE_ENFORCE_NOT_NULL(binding,
                          paddle::platform::errors::PreconditionNotMet(
683 684 685 686
                              "output tensor [%s] no binding ptr", name_));
  std::vector<Ort::Value> outputs = binding->GetOutputValues();
  Ort::Value &value = outputs[idx_];
  return value.GetTensorMutableData<T>();
687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702
}

template <typename T>
void Tensor::ORTCopyToCpu(T *data) const {
  auto binding = binding_.lock();
  PADDLE_ENFORCE_NOT_NULL(binding,
                          paddle::platform::errors::PreconditionNotMet(
                              "output tensor [%s] no binding ptr", name_));
  std::vector<Ort::Value> outputs = binding->GetOutputValues();
  Ort::Value &value = outputs[idx_];
  auto info = value.GetTensorTypeAndShapeInfo();
  size_t size = info.GetElementCount() * sizeof(T);

  if (place_ == PlaceType::kCPU) {
    std::memcpy(static_cast<void *>(data), value.GetTensorData<void *>(), size);
  } else {
H
heliqi 已提交
703 704 705
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "CopyToCpu error.The current ONNXRuntime backend doesn't support "
        "GPU."));
706 707 708 709 710 711 712 713 714 715
  }
}

// Explicit instantiations of ORTCopyToCpu for the same element types as the
// public CopyToCpu API (int64_t was previously missing from this list).
template void Tensor::ORTCopyToCpu<float>(float *data) const;
template void Tensor::ORTCopyToCpu<int64_t>(int64_t *data) const;
template void Tensor::ORTCopyToCpu<int32_t>(int32_t *data) const;
template void Tensor::ORTCopyToCpu<uint8_t>(uint8_t *data) const;
template void Tensor::ORTCopyToCpu<int8_t>(int8_t *data) const;
template void Tensor::ORTCopyToCpu<float16>(float16 *data) const;
#endif

716
}  // namespace paddle_infer