// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/string_array.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/api/paddle_tensor.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/core/allocator.h"
#ifdef PADDLE_WITH_ONNXRUNTIME
#include "onnxruntime_c_api.h"    // NOLINT
#include "onnxruntime_cxx_api.h"  // NOLINT
#endif

namespace paddle_infer {

using float16 = paddle::platform::float16;

void Tensor::Reshape(const std::vector<int> &shape) {
#ifdef PADDLE_WITH_ONNXRUNTIME
  if (is_ort_tensor_) {
    shape_.assign(shape.begin(), shape.end());
    return;
  }
#endif

  PADDLE_ENFORCE_EQ(
      name_.empty(),
      false,
      paddle::platform::errors::PreconditionNotMet(
          "Need to SetName first, so that the corresponding tensor can "
          "be retrieved."));
  PADDLE_ENFORCE_EQ(input_or_output_,
                    true,
                    paddle::platform::errors::PermissionDenied(
                        "Can't reshape the output tensor, it is readonly"));
  auto *scope = static_cast<paddle::framework::Scope *>(scope_);
  auto *var = scope->FindVar(name_);
  PADDLE_ENFORCE_NOT_NULL(
      var,
      paddle::platform::errors::PreconditionNotMet(
          "No tensor called [%s] in the runtime scope", name_));
  auto *tensor = var->GetMutable<phi::DenseTensor>();
  tensor->Resize(phi::make_ddim(shape));
}

void Tensor::ReshapeStrings(const size_t &shape) {
  PADDLE_ENFORCE_EQ(
      name_.empty(),
      false,
      paddle::platform::errors::PreconditionNotMet(
          "Need to SetName first, so that the corresponding tensor can "
          "be retrieved."));
  PADDLE_ENFORCE_EQ(input_or_output_,
                    true,
                    paddle::platform::errors::PermissionDenied(
                        "Can't reshape the output tensor, it is readonly"));
  auto *scope = static_cast<paddle::framework::Scope *>(scope_);
  auto *var = scope->FindVar(name_);
  PADDLE_ENFORCE_NOT_NULL(
      var,
      paddle::platform::errors::PreconditionNotMet(
          "No tensor called [%s] in the runtime scope", name_));
  paddle::framework::Strings *tensor =
      var->GetMutable<paddle::framework::Strings>();
  tensor->resize(shape);
}

#define EAGER_GET_TENSOR(tensor_type)    \
  if (!tensor_) {                        \
    tensor_ = FindTensor<tensor_type>(); \
  }                                      \
  auto *tensor = static_cast<tensor_type *>(tensor_);

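// Returns a mutable device pointer on the requested place, allocating the
// buffer if necessary. Reshape() must be called first so that numel() > 0.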
template <typename T>
T *Tensor::mutable_data(PlaceType place) {
#ifdef PADDLE_WITH_ONNXRUNTIME
  if (is_ort_tensor_) {
    return ORTGetMutableData<T>();
  }
#endif
  EAGER_GET_TENSOR(phi::DenseTensor);
  PADDLE_ENFORCE_GT(
      tensor->numel(),
      0,
      paddle::platform::errors::PreconditionNotMet(
          "You should call Tensor::Reshape(const std::vector<int> &shape) "
          "function before retrieving mutable_data from input tensor."));
  switch (static_cast<int>(place)) {
    case static_cast<int>(PlaceType::kCPU): {
      return tensor->mutable_data<T>(paddle::platform::CPUPlace());
    }
    case static_cast<int>(PlaceType::kGPU): {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      paddle::platform::CUDAPlace gpu_place(device_);
      auto *dev_ctxs = reinterpret_cast<const std::map<
          phi::Place,
          std::shared_future<std::unique_ptr<phi::DeviceContext>>> *>(
          device_contexs_);
      auto *dev_ctx =
          static_cast<phi::GPUContext *>(dev_ctxs->at(gpu_place).get().get());
      return dev_ctx->Alloc<T>(tensor, tensor->numel() * sizeof(T));
#else
      return tensor->mutable_data<T>(paddle::platform::CUDAPlace(device_));
#endif
    }
    case static_cast<int>(PlaceType::kXPU): {
      return tensor->mutable_data<T>(paddle::platform::XPUPlace(device_));
    }
    case static_cast<int>(PlaceType::kNPU): {
      return tensor->mutable_data<T>(paddle::platform::NPUPlace(device_));
    }
    case static_cast<int>(PlaceType::kCUSTOM): {
      return tensor->mutable_data<T>(
          paddle::platform::CustomPlace(device_type_, device_));
    }
    default:
      PADDLE_THROW(paddle::platform::errors::Unavailable(
          "Only CPU / CUDA / XPU / NPU places are supported. The place `%d` "
          "is not supported.",
          static_cast<int>(place)));
      break;
  }
  return nullptr;
}

template <typename T>
T *Tensor::data(PlaceType *place, int *size) const {
  EAGER_GET_TENSOR(phi::DenseTensor);
  auto *res = tensor->data<T>();

  if (paddle::platform::is_cpu_place(tensor->place())) {
    *place = PlaceType::kCPU;
  } else if (paddle::platform::is_gpu_place(tensor->place())) {
    *place = PlaceType::kGPU;
  } else if (paddle::platform::is_xpu_place(tensor->place())) {
    *place = PlaceType::kXPU;
  } else if (paddle::platform::is_npu_place(tensor->place())) {
    *place = PlaceType::kNPU;
  } else if (paddle::platform::is_custom_place(tensor->place())) {
    *place = PlaceType::kCUSTOM;
  } else {
    *place = PlaceType::kUNK;
  }

  *size = tensor->numel();
  return res;
}

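// Maps the tensor's proto VarType to the public DataType enum; unrecognized
// types fall back to FLOAT32.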
DataType Tensor::type() const {
#ifdef PADDLE_WITH_ONNXRUNTIME
  if (is_ort_tensor_) {
    return dtype_;
  }
#endif
  EAGER_GET_TENSOR(phi::DenseTensor);
  auto type = paddle::framework::TransToProtoVarType(tensor->dtype());
  if (type == paddle::framework::proto::VarType::FP32) {
    return DataType::FLOAT32;
  } else if (type == paddle::framework::proto::VarType::FP16) {
    return DataType::FLOAT16;
  } else if (type == paddle::framework::proto::VarType::INT64) {
    return DataType::INT64;
  } else if (type == paddle::framework::proto::VarType::INT32) {
    return DataType::INT32;
  } else if (type == paddle::framework::proto::VarType::UINT8) {
    return DataType::UINT8;
  } else if (type == paddle::framework::proto::VarType::INT8) {
    return DataType::INT8;
  } else if (type == paddle::framework::proto::VarType::BOOL) {
    return DataType::BOOL;
  }
  return DataType::FLOAT32;
}

PlaceType Tensor::place() const { return place_; }

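// Copies host data into the tensor on its configured place (CPU, GPU, XPU,
// NPU or a custom device). Typical usage (sketch):
//   auto input = predictor->GetInputHandle("x");
//   input->Reshape({1, 3, 224, 224});
//   input->CopyFromCpu(host_buffer.data());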
template <typename T>
void Tensor::CopyFromCpu(const T *data) {
  EAGER_GET_TENSOR(phi::DenseTensor);
  PADDLE_ENFORCE_GE(tensor->numel(),
                    0,
                    paddle::platform::errors::PreconditionNotMet(
                        "You should call Tensor::Reshape(const "
                        "std::vector<int> &shape) "
                        "function before copying data from cpu."));
  size_t ele_size = tensor->numel() * sizeof(T);

  if (place_ == PlaceType::kCPU) {
    auto *t_data = tensor->mutable_data<T>(paddle::platform::CPUPlace());
    std::memcpy(static_cast<void *>(t_data), data, ele_size);
  } else if (place_ == PlaceType::kGPU) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)

    paddle::platform::CUDAPlace gpu_place(device_);
    auto *dev_ctxs = reinterpret_cast<const std::map<
        phi::Place,
        std::shared_future<std::unique_ptr<phi::DeviceContext>>> *>(
        device_contexs_);
    auto *dev_ctx =
        static_cast<phi::GPUContext *>(dev_ctxs->at(gpu_place).get().get());
    auto *t_data = dev_ctx->Alloc<T>(tensor, tensor->numel() * sizeof(T));

    paddle::memory::Copy(gpu_place,
                         static_cast<void *>(t_data),
                         paddle::platform::CPUPlace(),
                         data,
                         ele_size,
                         dev_ctx->stream());
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with CUDA place because paddle is not compiled "
        "with CUDA."));
#endif
  } else if (place_ == PlaceType::kXPU) {
#ifdef PADDLE_WITH_XPU
    paddle::platform::XPUPlace xpu_place(device_);
    auto *t_data = tensor->mutable_data<T>(xpu_place);
    paddle::memory::Copy(xpu_place,
                         static_cast<void *>(t_data),
                         paddle::platform::CPUPlace(),
                         data,
                         ele_size);
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with XPU place because paddle is not compiled "
        "with XPU."));
#endif
  } else if (place_ == PlaceType::kNPU) {
#ifdef PADDLE_WITH_ASCEND_CL
    paddle::platform::DeviceContextPool &pool =
        paddle::platform::DeviceContextPool::Instance();
    paddle::platform::NPUPlace npu_place(device_);
    auto *t_data = tensor->mutable_data<T>(npu_place);
    auto *dev_ctx = static_cast<const paddle::platform::NPUDeviceContext *>(
        pool.Get(npu_place));
    paddle::memory::Copy(npu_place,
                         static_cast<void *>(t_data),
                         paddle::platform::CPUPlace(),
                         data,
                         ele_size,
                         dev_ctx->stream());
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with NPU place because paddle is not compiled "
        "with NPU."));
#endif
  } else {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
    auto device_type_id =
        static_cast<size_t>(place_) - static_cast<size_t>(PlaceType::kCUSTOM);
    paddle::platform::DeviceContextPool &pool =
        paddle::platform::DeviceContextPool::Instance();
    paddle::platform::CustomPlace custom_place(
        phi::CustomRegisteredDeviceMap::Instance().GetGlobalDeviceType(
            device_type_id),
        device_);
    auto *t_data = tensor->mutable_data<T>(custom_place);
    auto *dev_ctx = static_cast<const paddle::platform::CustomDeviceContext *>(
        pool.Get(custom_place));
    paddle::memory::Copy(custom_place,
                         static_cast<void *>(t_data),
                         paddle::platform::CPUPlace(),
                         data,
                         ele_size,
                         dev_ctx->stream());
#else
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "The analysis predictor supports CPU, GPU, NPU and XPU now."));
#endif
  }
}

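// Maps C++ element types to the corresponding phi::DataType; used by
// ShareExternalData to build the DenseTensorMeta.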
template <typename T>
struct DataTypeInfo;

template <>
struct DataTypeInfo<bool> {
  phi::DataType TYPE = phi::DataType::BOOL;
};

template <>
struct DataTypeInfo<float> {
  phi::DataType TYPE = phi::DataType::FLOAT32;
};

template <>
struct DataTypeInfo<float16> {
  phi::DataType TYPE = phi::DataType::FLOAT16;
};

template <>
struct DataTypeInfo<int64_t> {
  phi::DataType TYPE = phi::DataType::INT64;
};

template <>
struct DataTypeInfo<int8_t> {
  phi::DataType TYPE = phi::DataType::INT8;
};

template <>
struct DataTypeInfo<uint8_t> {
  phi::DataType TYPE = phi::DataType::UINT8;
};

template <>
struct DataTypeInfo<int32_t> {
  phi::DataType TYPE = phi::DataType::INT32;
};

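// Converts the public DataLayout enum to phi::DataLayout; only NCHW is
// accepted for now.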
phi::DataLayout LayoutConvert(DataLayout layout) {
  PADDLE_ENFORCE_EQ(
      layout,
      DataLayout::kNCHW,
      paddle::platform::errors::InvalidArgument("Only NCHW is supported now."));
  return phi::DataLayout::NCHW;
}

template <typename T>
void Tensor::ShareExternalData(const T *data,
                               const std::vector<int> &shape,
                               PlaceType place,
                               DataLayout layout) {
  EAGER_GET_TENSOR(phi::DenseTensor)
  size_t size =
      std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()) *
      sizeof(T);
  phi::DenseTensorMeta meta(
      DataTypeInfo<T>().TYPE, phi::make_ddim(shape), LayoutConvert(layout));
  if (place == PlaceType::kCPU) {
    phi::DenseTensor dtensor(
        std::make_shared<phi::Allocation>(
            const_cast<T *>(data), size, paddle::platform::CPUPlace()),
        meta);
    *tensor = std::move(dtensor);
  } else if (place == PlaceType::kGPU) {
    phi::DenseTensor dtensor(
        std::make_shared<phi::Allocation>(
            const_cast<T *>(data), size, paddle::platform::CUDAPlace(device_)),
        meta);
    *tensor = std::move(dtensor);
  } else {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "PlaceType must be PlaceType::kCPU or PlaceType::kGPU."));
  }
}

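// Copies a host-side Strings container into the underlying
// paddle::framework::Strings variable.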
void Tensor::CopyStringsFromCpu(const paddle_infer::Strings *data) {
  EAGER_GET_TENSOR(paddle::framework::Strings);
  PADDLE_ENFORCE_GE(tensor->size(),
                    0,
                    paddle::platform::errors::PreconditionNotMet(
                        "You should call Tensor::ReshapeStrings(const "
                        "std::size_t &shape) function before copying "
                        "the string data from cpu."));
  *tensor = *data;
}

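// Copies tensor data back to host memory. On GPU the copy is synchronous by
// default; passing exec_stream hands the CUDA stream back to the caller,
// while cb/cb_params registers a host callback on that stream instead.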
template <typename T>
void Tensor::CopyToCpuImpl(T *data,
                           void *exec_stream,
                           CallbackFunc cb,
                           void *cb_params) const {
  EAGER_GET_TENSOR(phi::DenseTensor);
  auto ele_num = tensor->numel();
  auto *t_data = tensor->data<T>();
  auto t_place = tensor->place();

  if (paddle::platform::is_cpu_place(t_place)) {
#ifdef PADDLE_WITH_MKLDNN
    if (tensor->layout() == phi::DataLayout::ONEDNN) {
      phi::DenseTensor out;
      auto mem_allocation =
          std::make_shared<paddle::memory::allocation::Allocation>(
              static_cast<void *>(data),
              ele_num * sizeof(T),
              paddle::platform::CPUPlace());
      out.ResetHolder(mem_allocation);
      phi::funcs::TransDataLayoutFromOneDNN(
          tensor->layout(),
          phi::OneDNNContext::tls().get_cur_paddle_data_layout(),
          *tensor,
          &out,
          paddle::platform::CPUPlace(),
          true);
    } else {
      std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
    }
#else
    std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
#endif
  } else if (paddle::platform::is_ipu_place(t_place)) {
#ifdef PADDLE_WITH_IPU
    std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with IPU place because paddle is not compiled "
        "with IPU."));
#endif
  } else if (place_ == PlaceType::kGPU) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    auto gpu_place = t_place;
    auto *dev_ctxs = reinterpret_cast<const std::map<
        phi::Place,
        std::shared_future<std::unique_ptr<phi::DeviceContext>>> *>(
        device_contexs_);
    auto *dev_ctx =
        static_cast<phi::GPUContext *>(dev_ctxs->at(gpu_place).get().get());
    paddle::memory::Copy(paddle::platform::CPUPlace(),
                         static_cast<void *>(data),
                         gpu_place,
                         t_data,
                         ele_num * sizeof(T),
                         dev_ctx->stream());
#ifdef PADDLE_WITH_HIP
    hipStreamSynchronize(dev_ctx->stream());
#else
    // async, return stream
    if (nullptr != exec_stream) {
      *(static_cast<cudaStream_t *>(exec_stream)) = dev_ctx->stream();
      // async with callback
    } else if (cb) {
      cudaLaunchHostFunc(dev_ctx->stream(), cb, cb_params);
      // sync
    } else {
      cudaStreamSynchronize(dev_ctx->stream());
    }
#endif
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with CUDA place because paddle is not compiled "
        "with CUDA."));
#endif
  } else if (place_ == PlaceType::kXPU) {
#ifdef PADDLE_WITH_XPU
    auto xpu_place = t_place;
    paddle::memory::Copy(paddle::platform::CPUPlace(),
                         static_cast<void *>(data),
                         xpu_place,
                         t_data,
                         ele_num * sizeof(T));
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with XPU place because paddle is not compiled "
        "with XPU."));
#endif
  } else if (place_ == PlaceType::kNPU) {
#ifdef PADDLE_WITH_ASCEND_CL
    paddle::platform::DeviceContextPool &pool =
        paddle::platform::DeviceContextPool::Instance();
    auto npu_place = t_place;
    auto *dev_ctx = static_cast<const paddle::platform::NPUDeviceContext *>(
        pool.Get(npu_place));
    paddle::memory::Copy(paddle::platform::CPUPlace(),
                         static_cast<void *>(data),
                         npu_place,
                         t_data,
                         ele_num * sizeof(T),
                         dev_ctx->stream());
    paddle::platform::NPUStreamSync(dev_ctx->stream());
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with NPU place because paddle is not compiled "
        "with NPU."));
#endif
  } else {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
    paddle::platform::DeviceContextPool &pool =
        paddle::platform::DeviceContextPool::Instance();
    auto custom_place = t_place;
    auto *dev_ctx = static_cast<const paddle::platform::CustomDeviceContext *>(
        pool.Get(custom_place));
    paddle::memory::Copy(paddle::platform::CPUPlace(),
                         static_cast<void *>(data),
                         custom_place,
                         t_data,
                         ele_num * sizeof(T),
                         dev_ctx->stream());
// TODO(wangran16): sync_stream
#else
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "The analysis predictor supports CPU, GPU, NPU and XPU now."));
#endif
  }
}

template <typename T>
void Tensor::CopyToCpu(T *data) const {
#ifdef PADDLE_WITH_ONNXRUNTIME
  if (is_ort_tensor_) {
    ORTCopyToCpu<T>(data);
    return;
  }
#endif

  CopyToCpuImpl<T>(data, nullptr, nullptr, nullptr);
}

template <typename T>
void Tensor::CopyToCpuAsync(T *data, void *exec_stream) const {
  CopyToCpuImpl<T>(data, exec_stream, nullptr, nullptr);
}

template <typename T>
void Tensor::CopyToCpuAsync(T *data, CallbackFunc cb, void *cb_params) const {
  CopyToCpuImpl<T>(data, nullptr, cb, cb_params);
}

template PD_INFER_DECL void Tensor::CopyFromCpu<float>(const float *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<int64_t>(const int64_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<int32_t>(const int32_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<uint8_t>(const uint8_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<int8_t>(const int8_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<float16>(const float16 *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<bool>(const bool *data);

template PD_INFER_DECL void Tensor::ShareExternalData<float>(
    const float *data,
    const std::vector<int> &shape,
    PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<int64_t>(
    const int64_t *data,
    const std::vector<int> &shape,
    PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<int32_t>(
    const int32_t *data,
    const std::vector<int> &shape,
    PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<uint8_t>(
    const uint8_t *data,
    const std::vector<int> &shape,
    PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<int8_t>(
    const int8_t *data,
    const std::vector<int> &shape,
    PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<float16>(
    const float16 *data,
    const std::vector<int> &shape,
    PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<bool>(
    const bool *data,
    const std::vector<int> &shape,
    PlaceType place,
    DataLayout layout);

template PD_INFER_DECL void Tensor::CopyToCpu<float>(float *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<int64_t>(int64_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<int32_t>(int32_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<uint8_t>(uint8_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<int8_t>(int8_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<float16>(float16 *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<bool>(bool *data) const;

template PD_INFER_DECL void Tensor::CopyToCpuImpl<float>(float *data,
                                                         void *exec_stream,
                                                         CallbackFunc cb,
                                                         void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<int64_t>(
    int64_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<int32_t>(
    int32_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<uint8_t>(
    uint8_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<int8_t>(
    int8_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<float16>(
    float16 *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<bool>(bool *data,
                                                        void *exec_stream,
                                                        CallbackFunc cb,
                                                        void *cb_params) const;

template PD_INFER_DECL void Tensor::CopyToCpuAsync<float>(
    float *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int64_t>(
    int64_t *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int32_t>(
    int32_t *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<uint8_t>(
    uint8_t *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int8_t>(
    int8_t *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<float16>(
    float16 *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<bool>(
    bool *data, void *exec_stream) const;

template PD_INFER_DECL void Tensor::CopyToCpuAsync<float>(
    float *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int64_t>(
    int64_t *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int32_t>(
    int32_t *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<uint8_t>(
    uint8_t *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int8_t>(
    int8_t *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<float16>(
    float16 *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<bool>(bool *data,
                                                         CallbackFunc cb,
                                                         void *cb_params) const;

template PD_INFER_DECL float *Tensor::data<float>(PlaceType *place,
                                                  int *size) const;
template PD_INFER_DECL int64_t *Tensor::data<int64_t>(PlaceType *place,
                                                      int *size) const;
template PD_INFER_DECL int32_t *Tensor::data<int32_t>(PlaceType *place,
                                                      int *size) const;
template PD_INFER_DECL uint8_t *Tensor::data<uint8_t>(PlaceType *place,
                                                      int *size) const;
template PD_INFER_DECL int8_t *Tensor::data<int8_t>(PlaceType *place,
                                                    int *size) const;
template PD_INFER_DECL float16 *Tensor::data<float16>(PlaceType *place,
                                                      int *size) const;
template PD_INFER_DECL bool *Tensor::data<bool>(PlaceType *place,
                                                int *size) const;

template PD_INFER_DECL float *Tensor::mutable_data<float>(PlaceType place);
template PD_INFER_DECL int64_t *Tensor::mutable_data<int64_t>(PlaceType place);
template PD_INFER_DECL int32_t *Tensor::mutable_data<int32_t>(PlaceType place);
template PD_INFER_DECL uint8_t *Tensor::mutable_data<uint8_t>(PlaceType place);
template PD_INFER_DECL int8_t *Tensor::mutable_data<int8_t>(PlaceType place);
template PD_INFER_DECL float16 *Tensor::mutable_data<float16>(PlaceType place);
template PD_INFER_DECL bool *Tensor::mutable_data<bool>(PlaceType place);

Tensor::Tensor(void *scope, const void *device_contexts)
    : scope_{scope}, device_contexs_(device_contexts) {}

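// Looks up the variable named name_ in the runtime scope and returns a
// pointer to its mutable tensor payload.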
template <typename T>
void *Tensor::FindTensor() const {
  PADDLE_ENFORCE_EQ(
      name_.empty(),
      false,
      paddle::platform::errors::PreconditionNotMet(
          "Need to SetName first, so that the corresponding tensor can "
          "be retrieved."));
  auto *scope = static_cast<paddle::framework::Scope *>(scope_);
  auto *var = scope->FindVar(name_);
  PADDLE_ENFORCE_NOT_NULL(
      var,
      paddle::platform::errors::PreconditionNotMet(
          "No tensor called [%s] in the runtime scope", name_));
  auto *tensor = var->GetMutable<T>();
  return tensor;
}

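// Returns the tensor dimensions. ONNXRuntime output tensors report the shape
// of the bound Ort::Value; oneDNN tensors may have their dims rotated to
// match the current paddle data layout.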
std::vector<int> Tensor::shape() const {
#ifdef PADDLE_WITH_ONNXRUNTIME
  if (is_ort_tensor_) {
    std::vector<int> shape;
    // input handle
    if (idx_ < 0) {
      shape.assign(shape_.begin(), shape_.end());
    } else {  // output handle
      auto binding = binding_.lock();
      PADDLE_ENFORCE_NOT_NULL(binding,
                              paddle::platform::errors::PreconditionNotMet(
                                  "output tensor [%s] has no binding", name_));
      std::vector<Ort::Value> outputs = binding->GetOutputValues();
      Ort::Value &value = outputs[idx_];
      auto info = value.GetTensorTypeAndShapeInfo();
      auto ort_shape = info.GetShape();
      shape.assign(ort_shape.begin(), ort_shape.end());
    }
    return shape;
  }
#endif
  EAGER_GET_TENSOR(phi::DenseTensor);
  PADDLE_ENFORCE_NOT_NULL(
      tensor_,
      paddle::platform::errors::PreconditionNotMet(
          "No tensor called [%s] found in the scope", name_));
// oneDNN may do layout transforms internally, so we need to reorder before
// return
#ifdef PADDLE_WITH_MKLDNN
  if (tensor->layout() == phi::DataLayout::ONEDNN) {
    phi::DataLayout out_layout =
        phi::OneDNNContext::tls().get_cur_paddle_data_layout();
    // Set default as NCHW in case not specified
    out_layout = out_layout == phi::DataLayout::kAnyLayout
                     ? phi::DataLayout::kNCHW
                     : out_layout;
    // In these data layouts, the channel dimension is either in the 2nd
    // position (nChw) or in the last position (nhwC), so for dim==2 the
    // layouts are identical and nothing needs to be done. The same holds for
    // dim==1, where only one combination is possible.
    if (tensor->dims().size() < 3) return phi::vectorize<int>(tensor->dims());
    if (out_layout == phi::DataLayout::kNHWC ||
        out_layout == phi::DataLayout::kNDHWC) {
      auto dims = phi::vectorize<int>(tensor->dims());
      std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end());
      return dims;
    } else {
      return phi::vectorize<int>(tensor->dims());
    }
  }
#endif
  return phi::vectorize<int>(tensor->dims());
}

void Tensor::SetLoD(const std::vector<std::vector<size_t>> &x) {
  EAGER_GET_TENSOR(phi::DenseTensor);
  paddle::framework::LoD lod;
  for (auto &level : x) {
    lod.emplace_back(level);
  }
  tensor->set_lod(lod);
}

std::vector<std::vector<size_t>> Tensor::lod() const {
  EAGER_GET_TENSOR(phi::DenseTensor);
  std::vector<std::vector<size_t>> res;
  for (auto &level : tensor->lod()) {
    res.emplace_back(level);
  }
  return res;
}

void Tensor::SetName(const std::string &name) { name_ = name; }

const std::string &Tensor::name() const { return name_; }

void Tensor::SetPlace(PlaceType place,
                      int device,
                      const std::string device_type) {
  place_ = place;
  device_ = device;
  device_type_ = device_type;
}

#ifdef PADDLE_WITH_ONNXRUNTIME
void Tensor::SetOrtMark(bool is_ort_tensor) { is_ort_tensor_ = is_ort_tensor; }

void Tensor::SetOrtBinding(const std::shared_ptr<Ort::IoBinding> binding) {
  binding_ = binding;
}

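// Returns the mutable data pointer of the bound ONNXRuntime output value at
// this tensor's output index.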
template <typename T>
T *Tensor::ORTGetMutableData() {
  auto binding = binding_.lock();
  PADDLE_ENFORCE_NOT_NULL(binding,
                          paddle::platform::errors::PreconditionNotMet(
                              "output tensor [%s] has no binding", name_));
  std::vector<Ort::Value> outputs = binding->GetOutputValues();
  Ort::Value &value = outputs[idx_];
  return value.GetTensorMutableData<T>();
}

template <typename T>
void Tensor::ORTCopyToCpu(T *data) const {
  auto binding = binding_.lock();
  PADDLE_ENFORCE_NOT_NULL(binding,
                          paddle::platform::errors::PreconditionNotMet(
                              "output tensor [%s] has no binding", name_));
  std::vector<Ort::Value> outputs = binding->GetOutputValues();
  Ort::Value &value = outputs[idx_];
  auto info = value.GetTensorTypeAndShapeInfo();
  size_t size = info.GetElementCount() * sizeof(T);

  if (place_ == PlaceType::kCPU) {
    std::memcpy(static_cast<void *>(data), value.GetTensorData<void *>(), size);
  } else {
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "CopyToCpu error. The current ONNXRuntime backend doesn't support "
        "GPU."));
  }
}

template void Tensor::ORTCopyToCpu<float>(float *data) const;
template void Tensor::ORTCopyToCpu<int32_t>(int32_t *data) const;
template void Tensor::ORTCopyToCpu<uint8_t>(uint8_t *data) const;
template void Tensor::ORTCopyToCpu<int8_t>(int8_t *data) const;
template void Tensor::ORTCopyToCpu<float16>(float16 *data) const;
#endif

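// InternalUtils copy helpers that issue the host/device copies on a
// caller-provided CUDA stream instead of the predictor's device context
// stream.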
namespace experimental {
template <typename T>
void InternalUtils::CopyFromCpuWithIoStream(paddle_infer::Tensor *t,
                                            const T *data,
                                            cudaStream_t stream) {
  if (t->tensor_ == nullptr) {
    PADDLE_ENFORCE_EQ(
        t->name_.empty(),
        false,
        paddle::platform::errors::PreconditionNotMet(
            "Need to SetName first, so that the corresponding tensor can "
            "be retrieved."));
    auto *scope = static_cast<paddle::framework::Scope *>(t->scope_);
    auto *var = scope->FindVar(t->name_);
    PADDLE_ENFORCE_NOT_NULL(
        var,
        paddle::platform::errors::PreconditionNotMet(
            "No tensor called [%s] in the runtime scope", t->name_));
    auto *tensor = var->GetMutable<phi::DenseTensor>();
    t->tensor_ = tensor;
  }

  auto *tensor = static_cast<phi::DenseTensor *>(t->tensor_);
  PADDLE_ENFORCE_GE(tensor->numel(),
                    0,
                    paddle::platform::errors::PreconditionNotMet(
                        "You should call Tensor::Reshape(const "
                        "std::vector<int> &shape) "
                        "function before copying data from cpu."));
  size_t ele_size = tensor->numel() * sizeof(T);
  if (t->place_ == PlaceType::kCPU) {
    auto *t_data = tensor->mutable_data<T>(paddle::platform::CPUPlace());
    std::memcpy(static_cast<void *>(t_data), data, ele_size);
  } else if (t->place_ == PlaceType::kGPU) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    paddle::platform::CUDAPlace gpu_place(t->device_);
    auto *t_data = tensor->mutable_data<T>(gpu_place);
    paddle::memory::Copy(gpu_place,
                         static_cast<void *>(t_data),
                         paddle::platform::CPUPlace(),
                         data,
                         ele_size,
                         stream);
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with CUDA place because paddle is not compiled "
        "with CUDA."));
#endif
  } else {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "CopyFromCpuWithIoStream only supports CPU and GPU now."));
  }
}

template <typename T>
void InternalUtils::CopyToCpuWithIoStream(paddle_infer::Tensor *t,
                                          T *data,
                                          cudaStream_t stream) {
  if (t->tensor_ == nullptr) {
    PADDLE_ENFORCE_EQ(
        t->name_.empty(),
        false,
        paddle::platform::errors::PreconditionNotMet(
            "Need to SetName first, so that the corresponding tensor can "
            "be retrieved."));
    auto *scope = static_cast<paddle::framework::Scope *>(t->scope_);
    auto *var = scope->FindVar(t->name_);
    PADDLE_ENFORCE_NOT_NULL(
        var,
        paddle::platform::errors::PreconditionNotMet(
            "No tensor called [%s] in the runtime scope", t->name_));
    auto *tensor = var->GetMutable<phi::DenseTensor>();
    t->tensor_ = tensor;
  }

  auto *tensor = static_cast<phi::DenseTensor *>(t->tensor_);
  auto ele_num = tensor->numel();
  auto *t_data = tensor->data<T>();
  auto t_place = tensor->place();

  if (paddle::platform::is_cpu_place(t_place)) {
#ifdef PADDLE_WITH_MKLDNN
    if (tensor->layout() == phi::DataLayout::ONEDNN) {
      phi::DenseTensor out;
      auto mem_allocation =
          std::make_shared<paddle::memory::allocation::Allocation>(
              static_cast<void *>(data),
              ele_num * sizeof(T),
              paddle::platform::CPUPlace());
      out.ResetHolder(mem_allocation);
      phi::funcs::TransDataLayoutFromOneDNN(
          tensor->layout(),
          phi::OneDNNContext::tls().get_cur_paddle_data_layout(),
          *tensor,
          &out,
          paddle::platform::CPUPlace(),
          true);
    } else {
      std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
    }
#else
    std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
#endif
  } else if (t->place_ == PlaceType::kGPU) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    paddle::memory::Copy(paddle::platform::CPUPlace(),
                         static_cast<void *>(data),
                         t_place,
                         t_data,
                         ele_num * sizeof(T),
                         stream);
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with CUDA place because paddle is not compiled "
        "with CUDA."));
#endif
  } else {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "CopyToCpuWithIoStream only supports CPU and GPU now."));
  }
}

template void InternalUtils::CopyFromCpuWithIoStream<float>(
    paddle_infer::Tensor *t, const float *data, cudaStream_t stream);
template void InternalUtils::CopyFromCpuWithIoStream<int64_t>(
    paddle_infer::Tensor *t, const int64_t *data, cudaStream_t stream);
template void InternalUtils::CopyFromCpuWithIoStream<int32_t>(
    paddle_infer::Tensor *t, const int32_t *data, cudaStream_t stream);
template void InternalUtils::CopyFromCpuWithIoStream<uint8_t>(
    paddle_infer::Tensor *t, const uint8_t *data, cudaStream_t stream);
template void InternalUtils::CopyFromCpuWithIoStream<int8_t>(
    paddle_infer::Tensor *t, const int8_t *data, cudaStream_t stream);
template void InternalUtils::CopyFromCpuWithIoStream<float16>(
    paddle_infer::Tensor *t, const float16 *data, cudaStream_t stream);
template void InternalUtils::CopyFromCpuWithIoStream<bool>(
    paddle_infer::Tensor *t, const bool *data, cudaStream_t stream);

template void InternalUtils::CopyToCpuWithIoStream<float>(
    paddle_infer::Tensor *t, float *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<int64_t>(
    paddle_infer::Tensor *t, int64_t *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<int32_t>(
    paddle_infer::Tensor *t, int32_t *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<uint8_t>(
    paddle_infer::Tensor *t, uint8_t *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<int8_t>(
    paddle_infer::Tensor *t, int8_t *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<float16>(
    paddle_infer::Tensor *t, float16 *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<bool>(
    paddle_infer::Tensor *t, bool *data, cudaStream_t stream);

}  // namespace experimental

}  // namespace paddle_infer