zero_copy_tensor.cc 36.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/fluid/framework/convert_utils.h"
16
#include "paddle/fluid/framework/data_layout_transform.h"
17 18
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
19
#include "paddle/fluid/framework/string_array.h"
20
#include "paddle/fluid/inference/api/paddle_inference_api.h"
W
Wilber 已提交
21
#include "paddle/fluid/inference/api/paddle_tensor.h"
N
nhzlx 已提交
22
#include "paddle/fluid/memory/memcpy.h"
23
#include "paddle/fluid/platform/enforce.h"
24
#include "paddle/fluid/platform/float16.h"
25
#include "paddle/phi/core/allocator.h"
26
#ifdef PADDLE_WITH_ONNXRUNTIME
H
heliqi 已提交
27 28
#include "onnxruntime_c_api.h"    // NOLINT
#include "onnxruntime_cxx_api.h"  // NOLINT
29
#endif
30

31
namespace paddle_infer {
32

33 34
using float16 = paddle::platform::float16;

35
void Tensor::Reshape(const std::vector<int> &shape) {
36 37 38 39 40 41 42
#ifdef PADDLE_WITH_ONNXRUNTIME
  if (is_ort_tensor_) {
    shape_.assign(shape.begin(), shape.end());
    return;
  }
#endif

W
Wilber 已提交
43
  PADDLE_ENFORCE_EQ(
44 45
      name_.empty(),
      false,
46
      paddle::platform::errors::PreconditionNotMet(
W
Wilber 已提交
47 48
          "Need to SetName first, so that the corresponding tensor can "
          "be retrieved."));
49 50
  PADDLE_ENFORCE_EQ(input_or_output_,
                    true,
51
                    paddle::platform::errors::PermissionDenied(
W
Wilber 已提交
52
                        "Can't reshape the output tensor, it is readonly"));
53
  auto *scope = static_cast<paddle::framework::Scope *>(scope_);
54
  auto *var = scope->FindVar(name_);
W
Wilber 已提交
55
  PADDLE_ENFORCE_NOT_NULL(
56 57 58
      var,
      paddle::platform::errors::PreconditionNotMet(
          "No tensor called [%s] in the runtime scope", name_));
59
  auto *tensor = var->GetMutable<phi::DenseTensor>();
60
  tensor->Resize(phi::make_ddim(shape));
61 62
}

S
Steffy-zxf 已提交
63 64
void Tensor::ReshapeStrings(const size_t &shape) {
  PADDLE_ENFORCE_EQ(
65 66
      name_.empty(),
      false,
S
Steffy-zxf 已提交
67 68 69
      paddle::platform::errors::PreconditionNotMet(
          "Need to SetName first, so that the corresponding tensor can "
          "be retrieved."));
70 71
  PADDLE_ENFORCE_EQ(input_or_output_,
                    true,
S
Steffy-zxf 已提交
72 73 74 75 76
                    paddle::platform::errors::PermissionDenied(
                        "Can't reshape the output tensor, it is readonly"));
  auto *scope = static_cast<paddle::framework::Scope *>(scope_);
  auto *var = scope->FindVar(name_);
  PADDLE_ENFORCE_NOT_NULL(
77 78 79
      var,
      paddle::platform::errors::PreconditionNotMet(
          "No tensor called [%s] in the runtime scope", name_));
80 81
  paddle::framework::Strings *tensor =
      var->GetMutable<paddle::framework::Strings>();
S
Steffy-zxf 已提交
82 83 84 85 86 87 88 89
  tensor->resize(shape);
}

// Declares a local `tensor` of type `tensor_type *` that points at this
// handle's backing object. On first use the object is looked up with
// FindTensor<tensor_type>() and the untyped pointer is cached in `tensor_`;
// later expansions reuse the cached pointer without another scope lookup.
// The macro expands to several statements (it is not wrapped in
// do { } while (0)), so it must be used at statement level in a block.
// NOTE(review): the cache is not keyed by type, so it assumes a handle is
// always accessed with the same tensor_type -- confirm.
// NOTE(review): const member functions also assign `tensor_` through this
// macro, so `tensor_` is presumably declared mutable in the header -- confirm.
#define EAGER_GET_TENSOR(tensor_type)    \
  if (!tensor_) {                        \
    tensor_ = FindTensor<tensor_type>(); \
  }                                      \
  auto *tensor = static_cast<tensor_type *>(tensor_);
90

91
template <typename T>
92
T *Tensor::mutable_data(PlaceType place) {
93 94 95 96 97
#ifdef PADDLE_WITH_ONNXRUNTIME
  if (is_ort_tensor_) {
    return ORTGetMutableData<T>();
  }
#endif
98
  EAGER_GET_TENSOR(phi::DenseTensor);
99
  PADDLE_ENFORCE_GT(
100 101
      tensor->numel(),
      0,
102 103
      paddle::platform::errors::PreconditionNotMet(
          "You should call Tensor::Reshape(const std::vector<int> "
W
Wilber 已提交
104 105
          "&shape)"
          "function before retrieving mutable_data from input tensor."));
106
  switch (static_cast<int>(place)) {
107 108
    case static_cast<int>(PlaceType::kCPU): {
      return tensor->mutable_data<T>(paddle::platform::CPUPlace());
109
    }
110
    case static_cast<int>(PlaceType::kGPU): {
111 112 113
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      paddle::platform::CUDAPlace gpu_place(device_);
      auto *dev_ctxs = reinterpret_cast<const std::map<
114 115 116
          phi::Place,
          std::shared_future<std::unique_ptr<phi::DeviceContext>>> *>(
          device_contexs_);
117 118 119 120
      auto *dev_ctx =
          static_cast<phi::GPUContext *>(dev_ctxs->at(gpu_place).get().get());
      return dev_ctx->Alloc<T>(tensor, tensor->numel() * sizeof(T));
#else
121
      return tensor->mutable_data<T>(paddle::platform::CUDAPlace(device_));
122
#endif
123 124 125
    }
    case static_cast<int>(PlaceType::kXPU): {
      return tensor->mutable_data<T>(paddle::platform::XPUPlace(device_));
126
    }
127 128 129
    case static_cast<int>(PlaceType::kNPU): {
      return tensor->mutable_data<T>(paddle::platform::NPUPlace(device_));
    }
130
    default:
131
      PADDLE_THROW(paddle::platform::errors::Unavailable(
132 133
          "Only CPU / CUDA / XPU / NPU places is supported. The place `%d` is "
          "not supported.",
134
          static_cast<int>(place)));
135 136 137 138 139 140
      break;
  }
  return nullptr;
}

template <typename T>
141
T *Tensor::data(PlaceType *place, int *size) const {
142
  EAGER_GET_TENSOR(phi::DenseTensor);
143 144
  auto *res = tensor->data<T>();

145 146 147 148 149 150
  if (paddle::platform::is_cpu_place(tensor->place())) {
    *place = PlaceType::kCPU;
  } else if (paddle::platform::is_gpu_place(tensor->place())) {
    *place = PlaceType::kGPU;
  } else if (paddle::platform::is_xpu_place(tensor->place())) {
    *place = PlaceType::kXPU;
151 152
  } else if (paddle::platform::is_npu_place(tensor->place())) {
    *place = PlaceType::kNPU;
153
  } else {
154
    *place = PlaceType::kUNK;
155 156 157 158 159 160
  }

  *size = tensor->numel();
  return res;
}

161
// Returns the element type of this tensor as the public inference DataType.
// ONNX Runtime handles report their stored `dtype_`. For scope tensors the
// phi dtype is converted to the proto VarType and mapped; any type without a
// listed counterpart is reported as FLOAT32.
DataType Tensor::type() const {
#ifdef PADDLE_WITH_ONNXRUNTIME
  if (is_ort_tensor_) {
    return dtype_;
  }
#endif
  EAGER_GET_TENSOR(phi::DenseTensor);
  using VarType = paddle::framework::proto::VarType;
  switch (paddle::framework::TransToProtoVarType(tensor->dtype())) {
    case VarType::FP32:
      return DataType::FLOAT32;
    case VarType::FP16:
      return DataType::FLOAT16;
    case VarType::INT64:
      return DataType::INT64;
    case VarType::INT32:
      return DataType::INT32;
    case VarType::UINT8:
      return DataType::UINT8;
    case VarType::INT8:
      return DataType::INT8;
    case VarType::BOOL:
      return DataType::BOOL;
    default:
      return DataType::FLOAT32;
  }
}

187 188
// Returns the place configured through SetPlace. This is the stored setting,
// not a query of where the underlying storage currently is.
PlaceType Tensor::place() const {
  return place_;
}

N
nhzlx 已提交
189
template <typename T>
190
void Tensor::CopyFromCpu(const T *data) {
191
  EAGER_GET_TENSOR(phi::DenseTensor);
192 193
  PADDLE_ENFORCE_GE(tensor->numel(),
                    0,
194 195
                    paddle::platform::errors::PreconditionNotMet(
                        "You should call Tensor::Reshape(const "
W
Wilber 已提交
196 197
                        "std::vector<int> &shape)"
                        "function before copying data from cpu."));
N
nhzlx 已提交
198 199
  size_t ele_size = tensor->numel() * sizeof(T);

200 201
  if (place_ == PlaceType::kCPU) {
    auto *t_data = tensor->mutable_data<T>(paddle::platform::CPUPlace());
N
nhzlx 已提交
202
    std::memcpy(static_cast<void *>(t_data), data, ele_size);
203
  } else if (place_ == PlaceType::kGPU) {
204
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
205

206
    paddle::platform::CUDAPlace gpu_place(device_);
207
    auto *dev_ctxs = reinterpret_cast<const std::map<
208 209
        phi::Place,
        std::shared_future<std::unique_ptr<phi::DeviceContext>>> *>(
210 211 212 213
        device_contexs_);
    auto *dev_ctx =
        static_cast<phi::GPUContext *>(dev_ctxs->at(gpu_place).get().get());
    auto *t_data = dev_ctx->Alloc<T>(tensor, tensor->numel() * sizeof(T));
N
nhzlx 已提交
214

215 216 217 218 219
    paddle::memory::Copy(gpu_place,
                         static_cast<void *>(t_data),
                         paddle::platform::CPUPlace(),
                         data,
                         ele_size,
220
                         dev_ctx->stream());
N
nhzlx 已提交
221
#else
222 223 224
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with CUDA place because paddle is not compiled "
        "with CUDA."));
N
nhzlx 已提交
225
#endif
226
  } else if (place_ == PlaceType::kXPU) {
227
#ifdef PADDLE_WITH_XPU
228
    paddle::platform::XPUPlace xpu_place(device_);
229
    auto *t_data = tensor->mutable_data<T>(xpu_place);
230 231 232 233 234
    paddle::memory::Copy(xpu_place,
                         static_cast<void *>(t_data),
                         paddle::platform::CPUPlace(),
                         data,
                         ele_size);
235
#else
236 237 238
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with XPU place because paddle is not compiled "
        "with XPU."));
W
Wilber 已提交
239 240 241 242 243 244 245 246 247
#endif
  } else if (place_ == PlaceType::kNPU) {
#ifdef PADDLE_WITH_ASCEND_CL
    paddle::platform::DeviceContextPool &pool =
        paddle::platform::DeviceContextPool::Instance();
    paddle::platform::NPUPlace npu_place(device_);
    auto *t_data = tensor->mutable_data<T>(npu_place);
    auto *dev_ctx = static_cast<const paddle::platform::NPUDeviceContext *>(
        pool.Get(npu_place));
248 249 250 251 252
    paddle::memory::Copy(npu_place,
                         static_cast<void *>(t_data),
                         paddle::platform::CPUPlace(),
                         data,
                         ele_size,
W
Wilber 已提交
253 254 255 256 257
                         dev_ctx->stream());
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with NPU place because paddle is not compiled "
        "with NPU."));
258 259
#endif
  } else {
260 261 262 263 264 265
#ifdef PADDLE_WITH_CUSTOM_DEVICE
    auto device_type_id =
        static_cast<size_t>(place_) - static_cast<size_t>(PlaceType::kCUSTOM);
    paddle::platform::DeviceContextPool &pool =
        paddle::platform::DeviceContextPool::Instance();
    paddle::platform::CustomPlace custom_place(
266 267 268
        phi::CustomRegisteredDeviceMap::Instance().GetGlobalDeviceType(
            device_type_id),
        device_);
269 270 271
    auto *t_data = tensor->mutable_data<T>(custom_place);
    auto *dev_ctx = static_cast<const paddle::platform::CustomDeviceContext *>(
        pool.Get(custom_place));
272 273 274 275 276
    paddle::memory::Copy(custom_place,
                         static_cast<void *>(t_data),
                         paddle::platform::CPUPlace(),
                         data,
                         ele_size,
277 278
                         dev_ctx->stream());
#else
279
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
W
Wilber 已提交
280
        "The analysis predictor supports CPU, GPU, NPU and XPU now."));
281
#endif
N
nhzlx 已提交
282 283 284
  }
}

285 286 287
// Maps a C++ element type to the phi DataType used when wrapping external
// buffers (see ShareExternalData). Only the specializations below are
// defined; other element types fail to compile.
template <typename T>
struct DataTypeInfo;

template <>
struct DataTypeInfo<bool> {
  paddle::experimental::DataType TYPE{paddle::experimental::DataType::BOOL};
};

template <>
struct DataTypeInfo<float> {
  paddle::experimental::DataType TYPE{
      paddle::experimental::DataType::FLOAT32};
};

template <>
struct DataTypeInfo<float16> {
  paddle::experimental::DataType TYPE{
      paddle::experimental::DataType::FLOAT16};
};

template <>
struct DataTypeInfo<int64_t> {
  paddle::experimental::DataType TYPE{paddle::experimental::DataType::INT64};
};

template <>
struct DataTypeInfo<int8_t> {
  paddle::experimental::DataType TYPE{paddle::experimental::DataType::INT8};
};

template <>
struct DataTypeInfo<uint8_t> {
  paddle::experimental::DataType TYPE{paddle::experimental::DataType::UINT8};
};

template <>
struct DataTypeInfo<int32_t> {
  paddle::experimental::DataType TYPE{paddle::experimental::DataType::INT32};
};

323
// Maps the public inference DataLayout onto phi::DataLayout. Only NCHW is
// currently accepted; any other layout raises InvalidArgument.
phi::DataLayout LayoutConvert(DataLayout layout) {
  PADDLE_ENFORCE_EQ(layout,
                    DataLayout::kNCHW,
                    paddle::platform::errors::InvalidArgument(
                        "Only NCHW is supported now."));
  // The check above guarantees the input was kNCHW.
  return phi::DataLayout::NCHW;
}

template <typename T>
332 333 334 335
void Tensor::ShareExternalData(const T *data,
                               const std::vector<int> &shape,
                               PlaceType place,
                               DataLayout layout) {
336
  EAGER_GET_TENSOR(phi::DenseTensor)
337 338 339
  size_t size =
      std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()) *
      sizeof(T);
340 341
  phi::DenseTensorMeta meta(
      DataTypeInfo<T>().TYPE, phi::make_ddim(shape), LayoutConvert(layout));
342 343
  if (place == PlaceType::kCPU) {
    phi::DenseTensor dtensor(
344 345
        std::make_shared<phi::Allocation>(
            const_cast<T *>(data), size, paddle::platform::CPUPlace()),
346 347 348 349
        meta);
    *tensor = std::move(dtensor);
  } else if (place == PlaceType::kGPU) {
    phi::DenseTensor dtensor(
350 351
        std::make_shared<phi::Allocation>(
            const_cast<T *>(data), size, paddle::platform::CUDAPlace(device_)),
352 353 354 355 356 357 358 359
        meta);
    *tensor = std::move(dtensor);
  } else {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "PlaceType must be PlaceType::kCPU or PlaceType::kGPU."));
  }
}

S
Steffy-zxf 已提交
360
// Replaces the contents of this handle's string tensor with a copy of `*data`.
//
// The whole container is assigned, so the resulting size is data->size()
// regardless of any earlier ReshapeStrings call.
// The previous `size() >= 0` check compared an unsigned value against 0 and
// could never fail (its message was also malformed). It is replaced by a
// check that actually guards the dereference below.
// Throws InvalidArgument if `data` is null.
void Tensor::CopyStringsFromCpu(const paddle_infer::Strings *data) {
  EAGER_GET_TENSOR(paddle::framework::Strings);
  PADDLE_ENFORCE_NOT_NULL(
      data,
      paddle::platform::errors::InvalidArgument(
          "The string data to copy from cpu must not be null."));
  *tensor = *data;
}

N
nhzlx 已提交
371
template <typename T>
372 373 374
void Tensor::CopyToCpuImpl(T *data,
                           void *exec_stream,
                           CallbackFunc cb,
375
                           void *cb_params) const {
376
  EAGER_GET_TENSOR(phi::DenseTensor);
N
nhzlx 已提交
377 378 379 380
  auto ele_num = tensor->numel();
  auto *t_data = tensor->data<T>();
  auto t_place = tensor->place();

381
  phi::DenseTensor out;
382 383
  auto mem_allocation =
      std::make_shared<paddle::memory::allocation::Allocation>(
384 385
          static_cast<void *>(data),
          ele_num * sizeof(T),
386
          paddle::platform::CPUPlace());
387 388
  out.ResetHolder(mem_allocation);

389
  if (paddle::platform::is_cpu_place(t_place)) {
390
#ifdef PADDLE_WITH_MKLDNN
391
    if (tensor->layout() == phi::DataLayout::ONEDNN)
392
      phi::funcs::TransDataLayoutFromOneDNN(
393
          tensor->layout(),
394
          phi::OneDNNContext::tls().get_cur_paddle_data_layout(),
395 396 397 398
          *tensor,
          &out,
          paddle::platform::CPUPlace(),
          true);
399 400 401
    else
      std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
#else
N
nhzlx 已提交
402
    std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
J
jianghaicheng 已提交
403 404 405 406 407 408 409 410
#endif
  } else if (paddle::platform::is_ipu_place(t_place)) {
#ifdef PADDLE_WITH_IPU
    std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with IPU place because paddle is not compiled "
        "with IPU."));
411
#endif
412
  } else if (place_ == PlaceType::kGPU) {
413
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
414
    auto gpu_place = t_place;
415
    auto *dev_ctxs = reinterpret_cast<const std::map<
416 417
        phi::Place,
        std::shared_future<std::unique_ptr<phi::DeviceContext>>> *>(
418 419 420
        device_contexs_);
    auto *dev_ctx =
        static_cast<phi::GPUContext *>(dev_ctxs->at(gpu_place).get().get());
421
    paddle::memory::Copy(paddle::platform::CPUPlace(),
422 423 424 425 426
                         static_cast<void *>(data),
                         gpu_place,
                         t_data,
                         ele_num * sizeof(T),
                         dev_ctx->stream());
427 428 429
#ifdef PADDLE_WITH_HIP
    hipStreamSynchronize(dev_ctx->stream());
#else
430 431 432 433 434 435 436 437 438 439
    // async, return stream
    if (nullptr != exec_stream) {
      *(static_cast<cudaStream_t *>(exec_stream)) = dev_ctx->stream();
      // async with callback
    } else if (cb) {
      cudaLaunchHostFunc(dev_ctx->stream(), cb, cb_params);
      // sync
    } else {
      cudaStreamSynchronize(dev_ctx->stream());
    }
440
#endif
N
nhzlx 已提交
441
#else
442 443 444
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with CUDA place because paddle is not compiled "
        "with CUDA."));
N
nhzlx 已提交
445
#endif
446
  } else if (place_ == PlaceType::kXPU) {
447
#ifdef PADDLE_WITH_XPU
448
    auto xpu_place = t_place;
449
    paddle::memory::Copy(paddle::platform::CPUPlace(),
450 451 452
                         static_cast<void *>(data),
                         xpu_place,
                         t_data,
453
                         ele_num * sizeof(T));
454
#else
455 456 457
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with XPU place because paddle is not compiled "
        "with XPU."));
W
Wilber 已提交
458 459 460 461 462
#endif
  } else if (place_ == PlaceType::kNPU) {
#ifdef PADDLE_WITH_ASCEND_CL
    paddle::platform::DeviceContextPool &pool =
        paddle::platform::DeviceContextPool::Instance();
463
    auto npu_place = t_place;
W
Wilber 已提交
464 465 466
    auto *dev_ctx = static_cast<const paddle::platform::NPUDeviceContext *>(
        pool.Get(npu_place));
    paddle::memory::Copy(paddle::platform::CPUPlace(),
467 468 469 470 471
                         static_cast<void *>(data),
                         npu_place,
                         t_data,
                         ele_num * sizeof(T),
                         dev_ctx->stream());
472
    paddle::platform::NPUStreamSync(dev_ctx->stream());
W
Wilber 已提交
473 474 475 476
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with NPU place because paddle is not compiled "
        "with NPU."));
477 478
#endif
  } else {
479 480 481 482 483 484 485
#ifdef PADDLE_WITH_CUSTOM_DEVICE
    paddle::platform::DeviceContextPool &pool =
        paddle::platform::DeviceContextPool::Instance();
    auto custom_place = t_place;
    auto *dev_ctx = static_cast<const paddle::platform::CustomDeviceContext *>(
        pool.Get(custom_place));
    paddle::memory::Copy(paddle::platform::CPUPlace(),
486 487 488 489 490
                         static_cast<void *>(data),
                         custom_place,
                         t_data,
                         ele_num * sizeof(T),
                         dev_ctx->stream());
491 492
// TODO(wangran16): sync_stream
#else
493
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
W
Wilber 已提交
494
        "The analysis predictor supports CPU, GPU, NPU and XPU now."));
495
#endif
N
nhzlx 已提交
496 497
  }
}
498 499 500

template <typename T>
void Tensor::CopyToCpu(T *data) const {
501 502 503 504 505 506 507
#ifdef PADDLE_WITH_ONNXRUNTIME
  if (is_ort_tensor_) {
    ORTCopyToCpu<T>(data);
    return;
  }
#endif

508 509 510 511 512 513 514 515 516 517 518 519 520
  CopyToCpuImpl<T>(data, nullptr, nullptr, nullptr);
}

template <typename T>
void Tensor::CopyToCpuAsync(T *data, void *exec_stream) const {
  CopyToCpuImpl<T>(data, exec_stream, nullptr, nullptr);
}

// Starts copying this tensor into `data`. `cb` and `cb_params` are forwarded
// to CopyToCpuImpl, which on the CUDA path enqueues cb(cb_params) on the copy
// stream. No stream handle is returned.
template <typename T>
void Tensor::CopyToCpuAsync(T *data, CallbackFunc cb, void *cb_params) const {
  void *no_stream_out = nullptr;
  CopyToCpuImpl<T>(data, no_stream_out, cb, cb_params);
}

521 522 523 524 525
template PD_INFER_DECL void Tensor::CopyFromCpu<float>(const float *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<int64_t>(const int64_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<int32_t>(const int32_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<uint8_t>(const uint8_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<int8_t>(const int8_t *data);
526
template PD_INFER_DECL void Tensor::CopyFromCpu<float16>(const float16 *data);
527
template PD_INFER_DECL void Tensor::CopyFromCpu<bool>(const bool *data);
528

529
template PD_INFER_DECL void Tensor::ShareExternalData<float>(
530 531 532
    const float *data,
    const std::vector<int> &shape,
    PlaceType place,
533 534
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<int64_t>(
535 536 537
    const int64_t *data,
    const std::vector<int> &shape,
    PlaceType place,
538 539
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<int32_t>(
540 541 542
    const int32_t *data,
    const std::vector<int> &shape,
    PlaceType place,
543 544
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<uint8_t>(
545 546 547
    const uint8_t *data,
    const std::vector<int> &shape,
    PlaceType place,
548 549
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<int8_t>(
550 551 552
    const int8_t *data,
    const std::vector<int> &shape,
    PlaceType place,
553 554
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<float16>(
555 556 557
    const float16 *data,
    const std::vector<int> &shape,
    PlaceType place,
558
    DataLayout layout);
559 560 561 562 563
template PD_INFER_DECL void Tensor::ShareExternalData<bool>(
    const bool *data,
    const std::vector<int> &shape,
    PlaceType place,
    DataLayout layout);
564

565 566 567 568 569 570
template PD_INFER_DECL void Tensor::CopyToCpu<float>(float *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<int64_t>(int64_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<int32_t>(int32_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<uint8_t>(uint8_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<int8_t>(int8_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<float16>(float16 *data) const;
571
template PD_INFER_DECL void Tensor::CopyToCpu<bool>(bool *data) const;
572 573 574 575 576 577 578 579 580 581 582 583 584 585 586

template PD_INFER_DECL void Tensor::CopyToCpuImpl<float>(float *data,
                                                         void *exec_stream,
                                                         CallbackFunc cb,
                                                         void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<int64_t>(
    int64_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<int32_t>(
    int32_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<uint8_t>(
    uint8_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<int8_t>(
    int8_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<float16>(
    float16 *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
587 588 589 590
template PD_INFER_DECL void Tensor::CopyToCpuImpl<bool>(bool *data,
                                                        void *exec_stream,
                                                        CallbackFunc cb,
                                                        void *cb_params) const;
591 592 593 594 595 596 597 598 599 600 601 602 603

template PD_INFER_DECL void Tensor::CopyToCpuAsync<float>(
    float *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int64_t>(
    int64_t *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int32_t>(
    int32_t *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<uint8_t>(
    uint8_t *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int8_t>(
    int8_t *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<float16>(
    float16 *data, void *exec_stream) const;
604 605
template PD_INFER_DECL void Tensor::CopyToCpuAsync<bool>(
    bool *data, void *exec_stream) const;
606 607 608 609 610 611 612 613 614 615 616 617 618

template PD_INFER_DECL void Tensor::CopyToCpuAsync<float>(
    float *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int64_t>(
    int64_t *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int32_t>(
    int32_t *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<uint8_t>(
    uint8_t *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int8_t>(
    int8_t *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<float16>(
    float16 *data, CallbackFunc cb, void *cb_params) const;
619 620 621
template PD_INFER_DECL void Tensor::CopyToCpuAsync<bool>(bool *data,
                                                         CallbackFunc cb,
                                                         void *cb_params) const;
622

623 624 625 626 627 628 629 630 631 632
template PD_INFER_DECL float *Tensor::data<float>(PlaceType *place,
                                                  int *size) const;
template PD_INFER_DECL int64_t *Tensor::data<int64_t>(PlaceType *place,
                                                      int *size) const;
template PD_INFER_DECL int32_t *Tensor::data<int32_t>(PlaceType *place,
                                                      int *size) const;
template PD_INFER_DECL uint8_t *Tensor::data<uint8_t>(PlaceType *place,
                                                      int *size) const;
template PD_INFER_DECL int8_t *Tensor::data<int8_t>(PlaceType *place,
                                                    int *size) const;
633 634
template PD_INFER_DECL float16 *Tensor::data<float16>(PlaceType *place,
                                                      int *size) const;
635 636
template PD_INFER_DECL bool *Tensor::data<bool>(PlaceType *place,
                                                int *size) const;
637

638 639 640 641 642
template PD_INFER_DECL float *Tensor::mutable_data<float>(PlaceType place);
template PD_INFER_DECL int64_t *Tensor::mutable_data<int64_t>(PlaceType place);
template PD_INFER_DECL int32_t *Tensor::mutable_data<int32_t>(PlaceType place);
template PD_INFER_DECL uint8_t *Tensor::mutable_data<uint8_t>(PlaceType place);
template PD_INFER_DECL int8_t *Tensor::mutable_data<int8_t>(PlaceType place);
643
template PD_INFER_DECL float16 *Tensor::mutable_data<float16>(PlaceType place);
644
template PD_INFER_DECL bool *Tensor::mutable_data<bool>(PlaceType place);
645

646 647
Tensor::Tensor(void *scope, const void *device_contexts)
    : scope_{scope}, device_contexs_(device_contexts) {}
648

S
Steffy-zxf 已提交
649
template <typename T>
650
void *Tensor::FindTensor() const {
W
Wilber 已提交
651
  PADDLE_ENFORCE_EQ(
652 653
      name_.empty(),
      false,
654
      paddle::platform::errors::PreconditionNotMet(
W
Wilber 已提交
655 656
          "Need to SetName first, so that the corresponding tensor can "
          "be retrieved."));
657
  auto *scope = static_cast<paddle::framework::Scope *>(scope_);
658
  auto *var = scope->FindVar(name_);
W
Wilber 已提交
659
  PADDLE_ENFORCE_NOT_NULL(
660 661 662
      var,
      paddle::platform::errors::PreconditionNotMet(
          "No tensor called [%s] in the runtime scope", name_));
S
Steffy-zxf 已提交
663
  auto *tensor = var->GetMutable<T>();
664 665 666
  return tensor;
}

667
std::vector<int> Tensor::shape() const {
668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687
#ifdef PADDLE_WITH_ONNXRUNTIME
  if (is_ort_tensor_) {
    std::vector<int> shape;
    // input handle
    if (idx_ < 0) {
      shape.assign(shape_.begin(), shape_.end());
    } else {  // output handle
      auto binding = binding_.lock();
      PADDLE_ENFORCE_NOT_NULL(binding,
                              paddle::platform::errors::PreconditionNotMet(
                                  "output tensor [%s] no binding ptr", name_));
      std::vector<Ort::Value> outputs = binding->GetOutputValues();
      Ort::Value &value = outputs[idx_];
      auto info = value.GetTensorTypeAndShapeInfo();
      auto ort_shape = info.GetShape();
      shape.assign(ort_shape.begin(), ort_shape.end());
    }
    return shape;
  }
#endif
688
  EAGER_GET_TENSOR(phi::DenseTensor);
W
Wilber 已提交
689
  PADDLE_ENFORCE_NOT_NULL(
690 691 692
      tensor_,
      paddle::platform::errors::PreconditionNotMet(
          "Not found tensor called %s in the scope", name_));
693
// oneDNN may does layout transform internally, so need to reorder before
W
wenbin 已提交
694 695
// return
#ifdef PADDLE_WITH_MKLDNN
696
  if (tensor->layout() == phi::DataLayout::ONEDNN) {
697 698
    phi::DataLayout out_layout =
        phi::OneDNNContext::tls().get_cur_paddle_data_layout();
W
wenbin 已提交
699
    // Set default as NCHW in case not specified
700 701
    out_layout = out_layout == phi::DataLayout::kAnyLayout
                     ? phi::DataLayout::kNCHW
W
wenbin 已提交
702 703 704 705 706 707
                     : out_layout;
    // In these data layouts, channel dimension is either on 2nd position: nChw
    // or
    // at last nhwC, so for dim==2 these layouts are the same and nothing should
    // be done. Similarly for dim==1 when you have just one possible
    // combination.
708
    if (tensor->dims().size() < 3) return phi::vectorize<int>(tensor->dims());
709 710
    if (out_layout == phi::DataLayout::kNHWC ||
        out_layout == phi::DataLayout::kNDHWC) {
711
      auto dims = phi::vectorize<int>(tensor->dims());
W
wenbin 已提交
712 713 714
      std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end());
      return dims;
    } else {
715
      return phi::vectorize<int>(tensor->dims());
W
wenbin 已提交
716 717 718
    }
  }
#endif
719
  return phi::vectorize<int>(tensor->dims());
720 721
}

722
void Tensor::SetLoD(const std::vector<std::vector<size_t>> &x) {
723
  EAGER_GET_TENSOR(phi::DenseTensor);
724
  paddle::framework::LoD lod;
725 726 727 728 729 730
  for (auto &level : x) {
    lod.emplace_back(level);
  }
  tensor->set_lod(lod);
}

731
std::vector<std::vector<size_t>> Tensor::lod() const {
732
  EAGER_GET_TENSOR(phi::DenseTensor);
733 734 735 736 737 738 739
  std::vector<std::vector<size_t>> res;
  for (auto &level : tensor->lod()) {
    res.emplace_back(level);
  }
  return res;
}

740 741 742 743 744 745 746 747 748
void Tensor::SetName(const std::string &name) { name_ = name; }

const std::string &Tensor::name() const { return name_; }

void Tensor::SetPlace(PlaceType place, int device) {
  // Record the requested placement; the copy helpers consult place_ to pick
  // the CPU or GPU path and device_ to select the GPU ordinal.
  this->device_ = device;
  this->place_ = place;
}

749 750 751 752 753 754 755
#ifdef PADDLE_WITH_ONNXRUNTIME
void Tensor::SetOrtMark(bool is_ort_tensor) { is_ort_tensor_ = is_ort_tensor; }

void Tensor::SetOrtBinding(const std::shared_ptr<Ort::IoBinding> binding) {
  // binding_ is later accessed via lock(), so it presumably holds a weak
  // reference that does not keep the IoBinding alive -- TODO confirm in header.
  this->binding_ = binding;
}

756 757 758 759 760 761 762 763 764 765 766
template <typename T>
T *Tensor::ORTGetMutableData() {
  auto binding = binding_.lock();
  PADDLE_ENFORCE_NOT_NULL(binding,
                          paddle::platform::errors::PreconditionNotMet(
                              "output tensor [%s] no binding ptr", name_));
  std::vector<Ort::Value> outputs = binding->GetOutputValues();
  Ort::Value &value = outputs[idx_];
  return value.GetTensorMutableData<T>();
}

767 768 769 770 771 772 773 774 775 776 777 778 779 780
template <typename T>
void Tensor::ORTCopyToCpu(T *data) const {
  auto binding = binding_.lock();
  PADDLE_ENFORCE_NOT_NULL(binding,
                          paddle::platform::errors::PreconditionNotMet(
                              "output tensor [%s] no binding ptr", name_));
  std::vector<Ort::Value> outputs = binding->GetOutputValues();
  Ort::Value &value = outputs[idx_];
  auto info = value.GetTensorTypeAndShapeInfo();
  size_t size = info.GetElementCount() * sizeof(T);

  if (place_ == PlaceType::kCPU) {
    std::memcpy(static_cast<void *>(data), value.GetTensorData<void *>(), size);
  } else {
H
heliqi 已提交
781 782 783
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "CopyToCpu error.The current ONNXRuntime backend doesn't support "
        "GPU."));
784 785 786 787 788 789 790 791 792 793
  }
}

// Explicit instantiations: only these element types of ORTCopyToCpu are
// emitted from this translation unit.
template void Tensor::ORTCopyToCpu<int8_t>(int8_t *data) const;
template void Tensor::ORTCopyToCpu<uint8_t>(uint8_t *data) const;
template void Tensor::ORTCopyToCpu<int32_t>(int32_t *data) const;
template void Tensor::ORTCopyToCpu<float16>(float16 *data) const;
template void Tensor::ORTCopyToCpu<float>(float *data) const;
#endif

W
Wilber 已提交
794 795 796 797 798 799 800
namespace experimental {
template <typename T>
void InternalUtils::CopyFromCpuWithIoStream(paddle_infer::Tensor *t,
                                            const T *data,
                                            cudaStream_t stream) {
  if (t->tensor_ == nullptr) {
    PADDLE_ENFORCE_EQ(
801 802
        t->name_.empty(),
        false,
W
Wilber 已提交
803 804 805 806 807 808
        paddle::platform::errors::PreconditionNotMet(
            "Need to SetName first, so that the corresponding tensor can "
            "be retrieved."));
    auto *scope = static_cast<paddle::framework::Scope *>(t->scope_);
    auto *var = scope->FindVar(t->name_);
    PADDLE_ENFORCE_NOT_NULL(
809 810 811
        var,
        paddle::platform::errors::PreconditionNotMet(
            "No tensor called [%s] in the runtime scope", t->name_));
812
    auto *tensor = var->GetMutable<phi::DenseTensor>();
W
Wilber 已提交
813 814 815
    t->tensor_ = tensor;
  }

816
  auto *tensor = static_cast<phi::DenseTensor *>(t->tensor_);
817 818
  PADDLE_ENFORCE_GE(tensor->numel(),
                    0,
W
Wilber 已提交
819 820 821 822 823 824 825 826 827 828 829 830
                    paddle::platform::errors::PreconditionNotMet(
                        "You should call Tensor::Reshape(const "
                        "std::vector<int> &shape)"
                        "function before copying data from cpu."));
  size_t ele_size = tensor->numel() * sizeof(T);
  if (t->place_ == PlaceType::kCPU) {
    auto *t_data = tensor->mutable_data<T>(paddle::platform::CPUPlace());
    std::memcpy(static_cast<void *>(t_data), data, ele_size);
  } else if (t->place_ == PlaceType::kGPU) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    paddle::platform::CUDAPlace gpu_place(t->device_);
    auto *t_data = tensor->mutable_data<T>(gpu_place);
831 832 833 834 835 836
    paddle::memory::Copy(gpu_place,
                         static_cast<void *>(t_data),
                         paddle::platform::CPUPlace(),
                         data,
                         ele_size,
                         stream);
W
Wilber 已提交
837 838 839 840 841 842 843 844 845 846 847 848
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with CUDA place because paddle is not compiled "
        "with CUDA."));
#endif
  } else {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "CopyFromCpuWithIoStream only supports CPU and GPU now."));
  }
}

template <typename T>
849 850
void InternalUtils::CopyToCpuWithIoStream(paddle_infer::Tensor *t,
                                          T *data,
W
Wilber 已提交
851 852 853
                                          cudaStream_t stream) {
  if (t->tensor_ == nullptr) {
    PADDLE_ENFORCE_EQ(
854 855
        t->name_.empty(),
        false,
W
Wilber 已提交
856 857 858 859 860 861
        paddle::platform::errors::PreconditionNotMet(
            "Need to SetName first, so that the corresponding tensor can "
            "be retrieved."));
    auto *scope = static_cast<paddle::framework::Scope *>(t->scope_);
    auto *var = scope->FindVar(t->name_);
    PADDLE_ENFORCE_NOT_NULL(
862 863 864
        var,
        paddle::platform::errors::PreconditionNotMet(
            "No tensor called [%s] in the runtime scope", t->name_));
865
    auto *tensor = var->GetMutable<phi::DenseTensor>();
W
Wilber 已提交
866 867 868
    t->tensor_ = tensor;
  }

869
  auto *tensor = static_cast<phi::DenseTensor *>(t->tensor_);
W
Wilber 已提交
870 871 872 873
  auto ele_num = tensor->numel();
  auto *t_data = tensor->data<T>();
  auto t_place = tensor->place();

874
  phi::DenseTensor out;
W
Wilber 已提交
875 876
  auto mem_allocation =
      std::make_shared<paddle::memory::allocation::Allocation>(
877 878
          static_cast<void *>(data),
          ele_num * sizeof(T),
W
Wilber 已提交
879 880 881 882 883
          paddle::platform::CPUPlace());
  out.ResetHolder(mem_allocation);

  if (paddle::platform::is_cpu_place(t_place)) {
#ifdef PADDLE_WITH_MKLDNN
884
    if (tensor->layout() == phi::DataLayout::ONEDNN)
885
      phi::funcs::TransDataLayoutFromOneDNN(
886
          tensor->layout(),
887
          phi::OneDNNContext::tls().get_cur_paddle_data_layout(),
888 889 890 891
          *tensor,
          &out,
          paddle::platform::CPUPlace(),
          true);
W
Wilber 已提交
892 893 894 895 896 897 898 899
    else
      std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
#else
    std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
#endif
  } else if (t->place_ == PlaceType::kGPU) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    paddle::memory::Copy(paddle::platform::CPUPlace(),
900 901 902 903 904
                         static_cast<void *>(data),
                         t_place,
                         t_data,
                         ele_num * sizeof(T),
                         stream);
W
Wilber 已提交
905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with CUDA place because paddle is not compiled "
        "with CUDA."));
#endif
  } else {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "CopyToCpuWithIoStream only supports CPU and GPU now."));
  }
}

template void InternalUtils::CopyFromCpuWithIoStream<float>(
    paddle_infer::Tensor *t, const float *data, cudaStream_t stream);
template void InternalUtils::CopyFromCpuWithIoStream<int64_t>(
    paddle_infer::Tensor *t, const int64_t *data, cudaStream_t stream);
template void InternalUtils::CopyFromCpuWithIoStream<int32_t>(
    paddle_infer::Tensor *t, const int32_t *data, cudaStream_t stream);
template void InternalUtils::CopyFromCpuWithIoStream<uint8_t>(
    paddle_infer::Tensor *t, const uint8_t *data, cudaStream_t stream);
template void InternalUtils::CopyFromCpuWithIoStream<int8_t>(
    paddle_infer::Tensor *t, const int8_t *data, cudaStream_t stream);
template void InternalUtils::CopyFromCpuWithIoStream<float16>(
    paddle_infer::Tensor *t, const float16 *data, cudaStream_t stream);
928 929
template void InternalUtils::CopyFromCpuWithIoStream<bool>(
    paddle_infer::Tensor *t, const bool *data, cudaStream_t stream);
W
Wilber 已提交
930 931 932 933 934 935 936 937 938 939 940 941 942

template void InternalUtils::CopyToCpuWithIoStream<float>(
    paddle_infer::Tensor *t, float *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<int64_t>(
    paddle_infer::Tensor *t, int64_t *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<int32_t>(
    paddle_infer::Tensor *t, int32_t *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<uint8_t>(
    paddle_infer::Tensor *t, uint8_t *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<int8_t>(
    paddle_infer::Tensor *t, int8_t *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<float16>(
    paddle_infer::Tensor *t, float16 *data, cudaStream_t stream);
943 944
template void InternalUtils::CopyToCpuWithIoStream<bool>(
    paddle_infer::Tensor *t, bool *data, cudaStream_t stream);
W
Wilber 已提交
945 946 947

}  // namespace experimental

948
}  // namespace paddle_infer