zero_copy_tensor.cc 34.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/fluid/framework/convert_utils.h"
16
#include "paddle/fluid/framework/data_layout_transform.h"
17 18 19
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
W
Wilber 已提交
20
#include "paddle/fluid/inference/api/paddle_tensor.h"
N
nhzlx 已提交
21
#include "paddle/fluid/memory/memcpy.h"
22
#include "paddle/fluid/platform/enforce.h"
23
#include "paddle/fluid/platform/float16.h"
24
#include "paddle/phi/core/allocator.h"
25
#ifdef PADDLE_WITH_ONNXRUNTIME
H
heliqi 已提交
26 27
#include "onnxruntime_c_api.h"    // NOLINT
#include "onnxruntime_cxx_api.h"  // NOLINT
28
#endif
29

30
namespace paddle_infer {
31

32 33
// Short alias for Paddle's half-precision type, used by the DataTypeInfo
// specialization and the explicit template instantiations in this file.
using float16 = paddle::platform::float16;

34
void Tensor::Reshape(const std::vector<int> &shape) {
35 36 37 38 39 40 41
#ifdef PADDLE_WITH_ONNXRUNTIME
  if (is_ort_tensor_) {
    shape_.assign(shape.begin(), shape.end());
    return;
  }
#endif

W
Wilber 已提交
42
  PADDLE_ENFORCE_EQ(
43 44
      name_.empty(),
      false,
45
      paddle::platform::errors::PreconditionNotMet(
W
Wilber 已提交
46 47
          "Need to SetName first, so that the corresponding tensor can "
          "be retrieved."));
48 49
  PADDLE_ENFORCE_EQ(input_or_output_,
                    true,
50
                    paddle::platform::errors::PermissionDenied(
W
Wilber 已提交
51
                        "Can't reshape the output tensor, it is readonly"));
52
  auto *scope = static_cast<paddle::framework::Scope *>(scope_);
53
  auto *var = scope->FindVar(name_);
W
Wilber 已提交
54
  PADDLE_ENFORCE_NOT_NULL(
55 56 57
      var,
      paddle::platform::errors::PreconditionNotMet(
          "No tensor called [%s] in the runtime scope", name_));
58
  auto *tensor = var->GetMutable<phi::DenseTensor>();
59
  tensor->Resize(phi::make_ddim(shape));
60 61
}

S
Steffy-zxf 已提交
62 63
void Tensor::ReshapeStrings(const size_t &shape) {
  PADDLE_ENFORCE_EQ(
64 65
      name_.empty(),
      false,
S
Steffy-zxf 已提交
66 67 68
      paddle::platform::errors::PreconditionNotMet(
          "Need to SetName first, so that the corresponding tensor can "
          "be retrieved."));
69 70
  PADDLE_ENFORCE_EQ(input_or_output_,
                    true,
S
Steffy-zxf 已提交
71 72 73 74 75
                    paddle::platform::errors::PermissionDenied(
                        "Can't reshape the output tensor, it is readonly"));
  auto *scope = static_cast<paddle::framework::Scope *>(scope_);
  auto *var = scope->FindVar(name_);
  PADDLE_ENFORCE_NOT_NULL(
76 77 78
      var,
      paddle::platform::errors::PreconditionNotMet(
          "No tensor called [%s] in the runtime scope", name_));
S
Steffy-zxf 已提交
79 80 81 82 83 84 85 86 87
  paddle_infer::Strings *tensor = var->GetMutable<paddle_infer::Strings>();
  tensor->resize(shape);
}

// Resolves the tensor behind this handle on first use and binds it to a local.
// If `tensor_` is still null it is filled from FindTensor<tensor_type>() and
// cached; later calls reuse the cached pointer. The expansion declares a local
// `tensor` of type `tensor_type *` in the caller's scope, so it may appear at
// most once per scope. The expansion already ends with ';', so call sites
// work with or without a trailing semicolon.
#define EAGER_GET_TENSOR(tensor_type)    \
  if (!tensor_) {                        \
    tensor_ = FindTensor<tensor_type>(); \
  }                                      \
  auto *tensor = static_cast<tensor_type *>(tensor_);
88

89
template <typename T>
90
T *Tensor::mutable_data(PlaceType place) {
91 92 93 94 95
#ifdef PADDLE_WITH_ONNXRUNTIME
  if (is_ort_tensor_) {
    return ORTGetMutableData<T>();
  }
#endif
96
  EAGER_GET_TENSOR(phi::DenseTensor);
97
  PADDLE_ENFORCE_GT(
98 99
      tensor->numel(),
      0,
100 101
      paddle::platform::errors::PreconditionNotMet(
          "You should call Tensor::Reshape(const std::vector<int> "
W
Wilber 已提交
102 103
          "&shape)"
          "function before retrieving mutable_data from input tensor."));
104
  switch (static_cast<int>(place)) {
105 106
    case static_cast<int>(PlaceType::kCPU): {
      return tensor->mutable_data<T>(paddle::platform::CPUPlace());
107
    }
108
    case static_cast<int>(PlaceType::kGPU): {
109 110 111
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      paddle::platform::CUDAPlace gpu_place(device_);
      auto *dev_ctxs = reinterpret_cast<const std::map<
112 113 114
          phi::Place,
          std::shared_future<std::unique_ptr<phi::DeviceContext>>> *>(
          device_contexs_);
115 116 117 118
      auto *dev_ctx =
          static_cast<phi::GPUContext *>(dev_ctxs->at(gpu_place).get().get());
      return dev_ctx->Alloc<T>(tensor, tensor->numel() * sizeof(T));
#else
119
      return tensor->mutable_data<T>(paddle::platform::CUDAPlace(device_));
120
#endif
121 122 123
    }
    case static_cast<int>(PlaceType::kXPU): {
      return tensor->mutable_data<T>(paddle::platform::XPUPlace(device_));
124
    }
125 126 127
    case static_cast<int>(PlaceType::kNPU): {
      return tensor->mutable_data<T>(paddle::platform::NPUPlace(device_));
    }
128
    default:
129
      PADDLE_THROW(paddle::platform::errors::Unavailable(
130 131
          "Only CPU / CUDA / XPU / NPU places is supported. The place `%d` is "
          "not supported.",
132
          static_cast<int>(place)));
133 134 135 136 137 138
      break;
  }
  return nullptr;
}

template <typename T>
139
T *Tensor::data(PlaceType *place, int *size) const {
140
  EAGER_GET_TENSOR(phi::DenseTensor);
141 142
  auto *res = tensor->data<T>();

143 144 145 146 147 148
  if (paddle::platform::is_cpu_place(tensor->place())) {
    *place = PlaceType::kCPU;
  } else if (paddle::platform::is_gpu_place(tensor->place())) {
    *place = PlaceType::kGPU;
  } else if (paddle::platform::is_xpu_place(tensor->place())) {
    *place = PlaceType::kXPU;
149 150
  } else if (paddle::platform::is_npu_place(tensor->place())) {
    *place = PlaceType::kNPU;
151
  } else {
152
    *place = PlaceType::kUNK;
153 154 155 156 157 158
  }

  *size = tensor->numel();
  return res;
}

159
// Returns the element type of the tensor behind this handle.
//
// ONNXRuntime handles report the stored `dtype_`. Paddle tensors map
// FP32/FP16/INT64/INT32/UINT8/INT8 to the matching DataType. Any other
// element type falls back to FLOAT32.
DataType Tensor::type() const {
#ifdef PADDLE_WITH_ONNXRUNTIME
  if (is_ort_tensor_) {
    return dtype_;
  }
#endif
  EAGER_GET_TENSOR(phi::DenseTensor);
  auto type = paddle::framework::TransToProtoVarType(tensor->dtype());
  if (type == paddle::framework::proto::VarType::FP32) {
    return DataType::FLOAT32;
  } else if (type == paddle::framework::proto::VarType::FP16) {
    return DataType::FLOAT16;
  } else if (type == paddle::framework::proto::VarType::INT64) {
    return DataType::INT64;
  } else if (type == paddle::framework::proto::VarType::INT32) {
    return DataType::INT32;
  } else if (type == paddle::framework::proto::VarType::UINT8) {
    return DataType::UINT8;
  } else if (type == paddle::framework::proto::VarType::INT8) {
    return DataType::INT8;
  }
  // NOTE(review): unsupported element types are silently reported as FLOAT32.
  return DataType::FLOAT32;
}

183 184
// Reports the place kind this handle was configured with via SetPlace.
PlaceType Tensor::place() const {
  return place_;
}

N
nhzlx 已提交
185
template <typename T>
186
void Tensor::CopyFromCpu(const T *data) {
187
  EAGER_GET_TENSOR(phi::DenseTensor);
188 189
  PADDLE_ENFORCE_GE(tensor->numel(),
                    0,
190 191
                    paddle::platform::errors::PreconditionNotMet(
                        "You should call Tensor::Reshape(const "
W
Wilber 已提交
192 193
                        "std::vector<int> &shape)"
                        "function before copying data from cpu."));
N
nhzlx 已提交
194 195
  size_t ele_size = tensor->numel() * sizeof(T);

196 197
  if (place_ == PlaceType::kCPU) {
    auto *t_data = tensor->mutable_data<T>(paddle::platform::CPUPlace());
N
nhzlx 已提交
198
    std::memcpy(static_cast<void *>(t_data), data, ele_size);
199
  } else if (place_ == PlaceType::kGPU) {
200
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
201

202
    paddle::platform::CUDAPlace gpu_place(device_);
203
    auto *dev_ctxs = reinterpret_cast<const std::map<
204 205
        phi::Place,
        std::shared_future<std::unique_ptr<phi::DeviceContext>>> *>(
206 207 208 209
        device_contexs_);
    auto *dev_ctx =
        static_cast<phi::GPUContext *>(dev_ctxs->at(gpu_place).get().get());
    auto *t_data = dev_ctx->Alloc<T>(tensor, tensor->numel() * sizeof(T));
N
nhzlx 已提交
210

211 212 213 214 215
    paddle::memory::Copy(gpu_place,
                         static_cast<void *>(t_data),
                         paddle::platform::CPUPlace(),
                         data,
                         ele_size,
216
                         dev_ctx->stream());
N
nhzlx 已提交
217
#else
218 219 220
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with CUDA place because paddle is not compiled "
        "with CUDA."));
N
nhzlx 已提交
221
#endif
222
  } else if (place_ == PlaceType::kXPU) {
223
#ifdef PADDLE_WITH_XPU
224
    paddle::platform::XPUPlace xpu_place(device_);
225
    auto *t_data = tensor->mutable_data<T>(xpu_place);
226 227 228 229 230
    paddle::memory::Copy(xpu_place,
                         static_cast<void *>(t_data),
                         paddle::platform::CPUPlace(),
                         data,
                         ele_size);
231
#else
232 233 234
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with XPU place because paddle is not compiled "
        "with XPU."));
W
Wilber 已提交
235 236 237 238 239 240 241 242 243
#endif
  } else if (place_ == PlaceType::kNPU) {
#ifdef PADDLE_WITH_ASCEND_CL
    paddle::platform::DeviceContextPool &pool =
        paddle::platform::DeviceContextPool::Instance();
    paddle::platform::NPUPlace npu_place(device_);
    auto *t_data = tensor->mutable_data<T>(npu_place);
    auto *dev_ctx = static_cast<const paddle::platform::NPUDeviceContext *>(
        pool.Get(npu_place));
244 245 246 247 248
    paddle::memory::Copy(npu_place,
                         static_cast<void *>(t_data),
                         paddle::platform::CPUPlace(),
                         data,
                         ele_size,
W
Wilber 已提交
249 250 251 252 253
                         dev_ctx->stream());
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with NPU place because paddle is not compiled "
        "with NPU."));
254 255
#endif
  } else {
256 257 258 259 260 261 262 263 264 265
#ifdef PADDLE_WITH_CUSTOM_DEVICE
    auto device_type_id =
        static_cast<size_t>(place_) - static_cast<size_t>(PlaceType::kCUSTOM);
    paddle::platform::DeviceContextPool &pool =
        paddle::platform::DeviceContextPool::Instance();
    paddle::platform::CustomPlace custom_place(
        phi::GetGlobalDeviceType(device_type_id), device_);
    auto *t_data = tensor->mutable_data<T>(custom_place);
    auto *dev_ctx = static_cast<const paddle::platform::CustomDeviceContext *>(
        pool.Get(custom_place));
266 267 268 269 270
    paddle::memory::Copy(custom_place,
                         static_cast<void *>(t_data),
                         paddle::platform::CPUPlace(),
                         data,
                         ele_size,
271 272
                         dev_ctx->stream());
#else
273
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
W
Wilber 已提交
274
        "The analysis predictor supports CPU, GPU, NPU and XPU now."));
275
#endif
N
nhzlx 已提交
276 277 278
  }
}

279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311
// Compile-time map from a C++ element type to Paddle's runtime DataType.
// The primary template is declared but never defined, so DataTypeInfo<T> is
// usable (e.g. by ShareExternalData<T>) only for the specializations below.
template <typename T>
struct DataTypeInfo;

template <>
struct DataTypeInfo<float> {
  paddle::experimental::DataType TYPE{
      paddle::experimental::DataType::FLOAT32};
};

template <>
struct DataTypeInfo<float16> {
  paddle::experimental::DataType TYPE{
      paddle::experimental::DataType::FLOAT16};
};

template <>
struct DataTypeInfo<int64_t> {
  paddle::experimental::DataType TYPE{paddle::experimental::DataType::INT64};
};

template <>
struct DataTypeInfo<int8_t> {
  paddle::experimental::DataType TYPE{paddle::experimental::DataType::INT8};
};

template <>
struct DataTypeInfo<uint8_t> {
  paddle::experimental::DataType TYPE{paddle::experimental::DataType::UINT8};
};

template <>
struct DataTypeInfo<int32_t> {
  paddle::experimental::DataType TYPE{paddle::experimental::DataType::INT32};
};

312
// Converts an inference-API DataLayout to phi::DataLayout.
// Only NCHW is accepted; any other layout throws InvalidArgument.
phi::DataLayout LayoutConvert(DataLayout layout) {
  PADDLE_ENFORCE_EQ(
      layout,
      DataLayout::kNCHW,
      paddle::platform::errors::InvalidArgument("Only NCHW is supported now."));
  return phi::DataLayout::NCHW;
}

template <typename T>
321 322 323 324
void Tensor::ShareExternalData(const T *data,
                               const std::vector<int> &shape,
                               PlaceType place,
                               DataLayout layout) {
325
  EAGER_GET_TENSOR(phi::DenseTensor)
326 327 328
  size_t size =
      std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()) *
      sizeof(T);
329 330
  phi::DenseTensorMeta meta(
      DataTypeInfo<T>().TYPE, phi::make_ddim(shape), LayoutConvert(layout));
331 332
  if (place == PlaceType::kCPU) {
    phi::DenseTensor dtensor(
333 334
        std::make_shared<phi::Allocation>(
            const_cast<T *>(data), size, paddle::platform::CPUPlace()),
335 336 337 338
        meta);
    *tensor = std::move(dtensor);
  } else if (place == PlaceType::kGPU) {
    phi::DenseTensor dtensor(
339 340
        std::make_shared<phi::Allocation>(
            const_cast<T *>(data), size, paddle::platform::CUDAPlace(device_)),
341 342 343 344 345 346 347 348
        meta);
    *tensor = std::move(dtensor);
  } else {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "PlaceType must be PlaceType::kCPU or PlaceType::kGPU."));
  }
}

S
Steffy-zxf 已提交
349 350
// Replaces the string container behind this handle with a copy of `*data`.
// NOTE(review): size() is unsigned, so the `>= 0` check below can never fail.
void Tensor::CopyStringsFromCpu(const paddle_infer::Strings *data) {
  EAGER_GET_TENSOR(paddle_infer::Strings);
  PADDLE_ENFORCE_GE(tensor->size(),
                    0,
                    paddle::platform::errors::PreconditionNotMet(
                        "You should call Tensor::Reshape(const "
                        "std::size_t &shape)function before copying"
                        "the string data from cpu."));
  *tensor = *data;
}

N
nhzlx 已提交
360
template <typename T>
361 362 363
void Tensor::CopyToCpuImpl(T *data,
                           void *exec_stream,
                           CallbackFunc cb,
364
                           void *cb_params) const {
365
  EAGER_GET_TENSOR(phi::DenseTensor);
N
nhzlx 已提交
366 367 368 369
  auto ele_num = tensor->numel();
  auto *t_data = tensor->data<T>();
  auto t_place = tensor->place();

370
  phi::DenseTensor out;
371 372
  auto mem_allocation =
      std::make_shared<paddle::memory::allocation::Allocation>(
373 374
          static_cast<void *>(data),
          ele_num * sizeof(T),
375
          paddle::platform::CPUPlace());
376 377
  out.ResetHolder(mem_allocation);

378
  if (paddle::platform::is_cpu_place(t_place)) {
379
#ifdef PADDLE_WITH_MKLDNN
380
    if (tensor->layout() == phi::DataLayout::ONEDNN)
381
      phi::funcs::TransDataLayoutFromOneDNN(
382
          tensor->layout(),
383
          phi::OneDNNContext::tls().get_cur_paddle_data_layout(),
384 385 386 387
          *tensor,
          &out,
          paddle::platform::CPUPlace(),
          true);
388 389 390
    else
      std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
#else
N
nhzlx 已提交
391
    std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
J
jianghaicheng 已提交
392 393 394 395 396 397 398 399
#endif
  } else if (paddle::platform::is_ipu_place(t_place)) {
#ifdef PADDLE_WITH_IPU
    std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with IPU place because paddle is not compiled "
        "with IPU."));
400
#endif
401
  } else if (place_ == PlaceType::kGPU) {
402
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
403
    auto gpu_place = t_place;
404
    auto *dev_ctxs = reinterpret_cast<const std::map<
405 406
        phi::Place,
        std::shared_future<std::unique_ptr<phi::DeviceContext>>> *>(
407 408 409
        device_contexs_);
    auto *dev_ctx =
        static_cast<phi::GPUContext *>(dev_ctxs->at(gpu_place).get().get());
410
    paddle::memory::Copy(paddle::platform::CPUPlace(),
411 412 413 414 415
                         static_cast<void *>(data),
                         gpu_place,
                         t_data,
                         ele_num * sizeof(T),
                         dev_ctx->stream());
416 417 418
#ifdef PADDLE_WITH_HIP
    hipStreamSynchronize(dev_ctx->stream());
#else
419 420 421 422 423 424 425 426 427 428
    // async, return stream
    if (nullptr != exec_stream) {
      *(static_cast<cudaStream_t *>(exec_stream)) = dev_ctx->stream();
      // async with callback
    } else if (cb) {
      cudaLaunchHostFunc(dev_ctx->stream(), cb, cb_params);
      // sync
    } else {
      cudaStreamSynchronize(dev_ctx->stream());
    }
429
#endif
N
nhzlx 已提交
430
#else
431 432 433
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with CUDA place because paddle is not compiled "
        "with CUDA."));
N
nhzlx 已提交
434
#endif
435
  } else if (place_ == PlaceType::kXPU) {
436
#ifdef PADDLE_WITH_XPU
437
    auto xpu_place = t_place;
438
    paddle::memory::Copy(paddle::platform::CPUPlace(),
439 440 441
                         static_cast<void *>(data),
                         xpu_place,
                         t_data,
442
                         ele_num * sizeof(T));
443
#else
444 445 446
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with XPU place because paddle is not compiled "
        "with XPU."));
W
Wilber 已提交
447 448 449 450 451
#endif
  } else if (place_ == PlaceType::kNPU) {
#ifdef PADDLE_WITH_ASCEND_CL
    paddle::platform::DeviceContextPool &pool =
        paddle::platform::DeviceContextPool::Instance();
452
    auto npu_place = t_place;
W
Wilber 已提交
453 454 455
    auto *dev_ctx = static_cast<const paddle::platform::NPUDeviceContext *>(
        pool.Get(npu_place));
    paddle::memory::Copy(paddle::platform::CPUPlace(),
456 457 458 459 460
                         static_cast<void *>(data),
                         npu_place,
                         t_data,
                         ele_num * sizeof(T),
                         dev_ctx->stream());
461
    paddle::platform::NPUStreamSync(dev_ctx->stream());
W
Wilber 已提交
462 463 464 465
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with NPU place because paddle is not compiled "
        "with NPU."));
466 467
#endif
  } else {
468 469 470 471 472 473 474
#ifdef PADDLE_WITH_CUSTOM_DEVICE
    paddle::platform::DeviceContextPool &pool =
        paddle::platform::DeviceContextPool::Instance();
    auto custom_place = t_place;
    auto *dev_ctx = static_cast<const paddle::platform::CustomDeviceContext *>(
        pool.Get(custom_place));
    paddle::memory::Copy(paddle::platform::CPUPlace(),
475 476 477 478 479
                         static_cast<void *>(data),
                         custom_place,
                         t_data,
                         ele_num * sizeof(T),
                         dev_ctx->stream());
480 481
// TODO(wangran16): sync_stream
#else
482
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
W
Wilber 已提交
483
        "The analysis predictor supports CPU, GPU, NPU and XPU now."));
484
#endif
N
nhzlx 已提交
485 486
  }
}
487 488 489

template <typename T>
void Tensor::CopyToCpu(T *data) const {
490 491 492 493 494 495 496
#ifdef PADDLE_WITH_ONNXRUNTIME
  if (is_ort_tensor_) {
    ORTCopyToCpu<T>(data);
    return;
  }
#endif

497 498 499 500 501 502 503 504 505 506 507 508 509
  CopyToCpuImpl<T>(data, nullptr, nullptr, nullptr);
}

// Starts a device-to-host copy into `data` without waiting for it.
// On CUDA GPU tensors, the stream carrying the copy is written to
// `*exec_stream`, and the caller must synchronize it before reading `data`.
template <typename T>
void Tensor::CopyToCpuAsync(T *data, void *exec_stream) const {
  CopyToCpuImpl<T>(data,
                   exec_stream,
                   /*cb=*/nullptr,
                   /*cb_params=*/nullptr);
}

// Starts a device-to-host copy into `data` and arranges for `cb(cb_params)`
// to run once it has completed (on CUDA GPU tensors; see CopyToCpuImpl).
template <typename T>
void Tensor::CopyToCpuAsync(T *data, CallbackFunc cb, void *cb_params) const {
  CopyToCpuImpl<T>(data,
                   /*exec_stream=*/nullptr,
                   cb,
                   cb_params);
}

510 511 512 513 514
// Exported CopyFromCpu element types.
template PD_INFER_DECL void Tensor::CopyFromCpu<float>(const float *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<int64_t>(const int64_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<int32_t>(const int32_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<uint8_t>(const uint8_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<int8_t>(const int8_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<float16>(const float16 *data);
516

517
// Exported ShareExternalData element types (each needs a DataTypeInfo).
template PD_INFER_DECL void Tensor::ShareExternalData<float>(
    const float *data,
    const std::vector<int> &shape,
    PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<int64_t>(
    const int64_t *data,
    const std::vector<int> &shape,
    PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<int32_t>(
    const int32_t *data,
    const std::vector<int> &shape,
    PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<uint8_t>(
    const uint8_t *data,
    const std::vector<int> &shape,
    PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<int8_t>(
    const int8_t *data,
    const std::vector<int> &shape,
    PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<float16>(
    const float16 *data,
    const std::vector<int> &shape,
    PlaceType place,
    DataLayout layout);

548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594
// Exported device-to-host copy entry points, grouped by element type.

// float
template PD_INFER_DECL void Tensor::CopyToCpu<float>(float *data) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<float>(
    float *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<float>(
    float *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<float>(
    float *data, CallbackFunc cb, void *cb_params) const;

// int64_t
template PD_INFER_DECL void Tensor::CopyToCpu<int64_t>(int64_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<int64_t>(
    int64_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int64_t>(
    int64_t *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int64_t>(
    int64_t *data, CallbackFunc cb, void *cb_params) const;

// int32_t
template PD_INFER_DECL void Tensor::CopyToCpu<int32_t>(int32_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<int32_t>(
    int32_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int32_t>(
    int32_t *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int32_t>(
    int32_t *data, CallbackFunc cb, void *cb_params) const;

// uint8_t
template PD_INFER_DECL void Tensor::CopyToCpu<uint8_t>(uint8_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<uint8_t>(
    uint8_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<uint8_t>(
    uint8_t *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<uint8_t>(
    uint8_t *data, CallbackFunc cb, void *cb_params) const;

// int8_t
template PD_INFER_DECL void Tensor::CopyToCpu<int8_t>(int8_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<int8_t>(
    int8_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int8_t>(
    int8_t *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int8_t>(
    int8_t *data, CallbackFunc cb, void *cb_params) const;

// float16
template PD_INFER_DECL void Tensor::CopyToCpu<float16>(float16 *data) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<float16>(
    float16 *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<float16>(
    float16 *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<float16>(
    float16 *data, CallbackFunc cb, void *cb_params) const;
595

596 597 598 599 600 601 602 603 604 605
template PD_INFER_DECL float *Tensor::data<float>(PlaceType *place,
                                                  int *size) const;
template PD_INFER_DECL int64_t *Tensor::data<int64_t>(PlaceType *place,
                                                      int *size) const;
template PD_INFER_DECL int32_t *Tensor::data<int32_t>(PlaceType *place,
                                                      int *size) const;
template PD_INFER_DECL uint8_t *Tensor::data<uint8_t>(PlaceType *place,
                                                      int *size) const;
template PD_INFER_DECL int8_t *Tensor::data<int8_t>(PlaceType *place,
                                                    int *size) const;
606 607
template PD_INFER_DECL float16 *Tensor::data<float16>(PlaceType *place,
                                                      int *size) const;
608

609 610 611 612 613
template PD_INFER_DECL float *Tensor::mutable_data<float>(PlaceType place);
template PD_INFER_DECL int64_t *Tensor::mutable_data<int64_t>(PlaceType place);
template PD_INFER_DECL int32_t *Tensor::mutable_data<int32_t>(PlaceType place);
template PD_INFER_DECL uint8_t *Tensor::mutable_data<uint8_t>(PlaceType place);
template PD_INFER_DECL int8_t *Tensor::mutable_data<int8_t>(PlaceType place);
614
template PD_INFER_DECL float16 *Tensor::mutable_data<float16>(PlaceType place);
615

616 617
Tensor::Tensor(void *scope, const void *device_contexts)
    : scope_{scope}, device_contexs_(device_contexts) {}
618

S
Steffy-zxf 已提交
619
template <typename T>
620
void *Tensor::FindTensor() const {
W
Wilber 已提交
621
  PADDLE_ENFORCE_EQ(
622 623
      name_.empty(),
      false,
624
      paddle::platform::errors::PreconditionNotMet(
W
Wilber 已提交
625 626
          "Need to SetName first, so that the corresponding tensor can "
          "be retrieved."));
627
  auto *scope = static_cast<paddle::framework::Scope *>(scope_);
628
  auto *var = scope->FindVar(name_);
W
Wilber 已提交
629
  PADDLE_ENFORCE_NOT_NULL(
630 631 632
      var,
      paddle::platform::errors::PreconditionNotMet(
          "No tensor called [%s] in the runtime scope", name_));
S
Steffy-zxf 已提交
633
  auto *tensor = var->GetMutable<T>();
634 635 636
  return tensor;
}

637
std::vector<int> Tensor::shape() const {
638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657
#ifdef PADDLE_WITH_ONNXRUNTIME
  if (is_ort_tensor_) {
    std::vector<int> shape;
    // input handle
    if (idx_ < 0) {
      shape.assign(shape_.begin(), shape_.end());
    } else {  // output handle
      auto binding = binding_.lock();
      PADDLE_ENFORCE_NOT_NULL(binding,
                              paddle::platform::errors::PreconditionNotMet(
                                  "output tensor [%s] no binding ptr", name_));
      std::vector<Ort::Value> outputs = binding->GetOutputValues();
      Ort::Value &value = outputs[idx_];
      auto info = value.GetTensorTypeAndShapeInfo();
      auto ort_shape = info.GetShape();
      shape.assign(ort_shape.begin(), ort_shape.end());
    }
    return shape;
  }
#endif
658
  EAGER_GET_TENSOR(phi::DenseTensor);
W
Wilber 已提交
659
  PADDLE_ENFORCE_NOT_NULL(
660 661 662
      tensor_,
      paddle::platform::errors::PreconditionNotMet(
          "Not found tensor called %s in the scope", name_));
663
// oneDNN may does layout transform internally, so need to reorder before
W
wenbin 已提交
664 665
// return
#ifdef PADDLE_WITH_MKLDNN
666
  if (tensor->layout() == phi::DataLayout::ONEDNN) {
667 668
    phi::DataLayout out_layout =
        phi::OneDNNContext::tls().get_cur_paddle_data_layout();
W
wenbin 已提交
669
    // Set default as NCHW in case not specified
670 671
    out_layout = out_layout == phi::DataLayout::kAnyLayout
                     ? phi::DataLayout::kNCHW
W
wenbin 已提交
672 673 674 675 676 677
                     : out_layout;
    // In these data layouts, channel dimension is either on 2nd position: nChw
    // or
    // at last nhwC, so for dim==2 these layouts are the same and nothing should
    // be done. Similarly for dim==1 when you have just one possible
    // combination.
678
    if (tensor->dims().size() < 3) return phi::vectorize<int>(tensor->dims());
679 680
    if (out_layout == phi::DataLayout::kNHWC ||
        out_layout == phi::DataLayout::kNDHWC) {
681
      auto dims = phi::vectorize<int>(tensor->dims());
W
wenbin 已提交
682 683 684
      std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end());
      return dims;
    } else {
685
      return phi::vectorize<int>(tensor->dims());
W
wenbin 已提交
686 687 688
    }
  }
#endif
689
  return phi::vectorize<int>(tensor->dims());
690 691
}

692
void Tensor::SetLoD(const std::vector<std::vector<size_t>> &x) {
693
  EAGER_GET_TENSOR(phi::DenseTensor);
694
  paddle::framework::LoD lod;
695 696 697 698 699 700
  for (auto &level : x) {
    lod.emplace_back(level);
  }
  tensor->set_lod(lod);
}

701
std::vector<std::vector<size_t>> Tensor::lod() const {
702
  EAGER_GET_TENSOR(phi::DenseTensor);
703 704 705 706 707 708 709
  std::vector<std::vector<size_t>> res;
  for (auto &level : tensor->lod()) {
    res.emplace_back(level);
  }
  return res;
}

710 711 712 713 714 715 716 717 718
// Sets the scope variable name this handle resolves to.
void Tensor::SetName(const std::string &name) {
  name_ = name;
}

// Returns the scope variable name this handle resolves to.
const std::string &Tensor::name() const {
  return name_;
}

// Records where data for this handle should live: the place kind and the
// device index within that kind.
void Tensor::SetPlace(PlaceType place, int device) {
  device_ = device;
  place_ = place;
}

719 720 721 722 723 724 725
#ifdef PADDLE_WITH_ONNXRUNTIME
// Marks whether this handle is served by the ONNXRuntime backend.
void Tensor::SetOrtMark(bool is_ort_tensor) {
  is_ort_tensor_ = is_ort_tensor;
}

// Stores the ORT IoBinding. It is later accessed through binding_.lock(), so
// this handle does not keep the binding alive on its own.
void Tensor::SetOrtBinding(const std::shared_ptr<Ort::IoBinding> binding) {
  binding_ = binding;
}

726 727 728 729 730 731 732 733 734 735 736
template <typename T>
T *Tensor::ORTGetMutableData() {
  auto binding = binding_.lock();
  PADDLE_ENFORCE_NOT_NULL(binding,
                          paddle::platform::errors::PreconditionNotMet(
                              "output tensor [%s] no binding ptr", name_));
  std::vector<Ort::Value> outputs = binding->GetOutputValues();
  Ort::Value &value = outputs[idx_];
  return value.GetTensorMutableData<T>();
}

737 738 739 740 741 742 743 744 745 746 747 748 749 750
template <typename T>
void Tensor::ORTCopyToCpu(T *data) const {
  auto binding = binding_.lock();
  PADDLE_ENFORCE_NOT_NULL(binding,
                          paddle::platform::errors::PreconditionNotMet(
                              "output tensor [%s] no binding ptr", name_));
  std::vector<Ort::Value> outputs = binding->GetOutputValues();
  Ort::Value &value = outputs[idx_];
  auto info = value.GetTensorTypeAndShapeInfo();
  size_t size = info.GetElementCount() * sizeof(T);

  if (place_ == PlaceType::kCPU) {
    std::memcpy(static_cast<void *>(data), value.GetTensorData<void *>(), size);
  } else {
H
heliqi 已提交
751 752 753
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "CopyToCpu error.The current ONNXRuntime backend doesn't support "
        "GPU."));
754 755 756 757 758 759 760 761 762 763
  }
}

// Explicit instantiations of Tensor::ORTCopyToCpu for float, int32_t,
// uint8_t, int8_t and float16.
template void Tensor::ORTCopyToCpu<float>(float *) const;
template void Tensor::ORTCopyToCpu<int32_t>(int32_t *) const;
template void Tensor::ORTCopyToCpu<uint8_t>(uint8_t *) const;
template void Tensor::ORTCopyToCpu<int8_t>(int8_t *) const;
template void Tensor::ORTCopyToCpu<float16>(float16 *) const;
#endif

W
Wilber 已提交
764 765 766 767 768 769 770
namespace experimental {
template <typename T>
void InternalUtils::CopyFromCpuWithIoStream(paddle_infer::Tensor *t,
                                            const T *data,
                                            cudaStream_t stream) {
  if (t->tensor_ == nullptr) {
    PADDLE_ENFORCE_EQ(
771 772
        t->name_.empty(),
        false,
W
Wilber 已提交
773 774 775 776 777 778
        paddle::platform::errors::PreconditionNotMet(
            "Need to SetName first, so that the corresponding tensor can "
            "be retrieved."));
    auto *scope = static_cast<paddle::framework::Scope *>(t->scope_);
    auto *var = scope->FindVar(t->name_);
    PADDLE_ENFORCE_NOT_NULL(
779 780 781
        var,
        paddle::platform::errors::PreconditionNotMet(
            "No tensor called [%s] in the runtime scope", t->name_));
782
    auto *tensor = var->GetMutable<phi::DenseTensor>();
W
Wilber 已提交
783 784 785
    t->tensor_ = tensor;
  }

786
  auto *tensor = static_cast<phi::DenseTensor *>(t->tensor_);
787 788
  PADDLE_ENFORCE_GE(tensor->numel(),
                    0,
W
Wilber 已提交
789 790 791 792 793 794 795 796 797 798 799 800
                    paddle::platform::errors::PreconditionNotMet(
                        "You should call Tensor::Reshape(const "
                        "std::vector<int> &shape)"
                        "function before copying data from cpu."));
  size_t ele_size = tensor->numel() * sizeof(T);
  if (t->place_ == PlaceType::kCPU) {
    auto *t_data = tensor->mutable_data<T>(paddle::platform::CPUPlace());
    std::memcpy(static_cast<void *>(t_data), data, ele_size);
  } else if (t->place_ == PlaceType::kGPU) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    paddle::platform::CUDAPlace gpu_place(t->device_);
    auto *t_data = tensor->mutable_data<T>(gpu_place);
801 802 803 804 805 806
    paddle::memory::Copy(gpu_place,
                         static_cast<void *>(t_data),
                         paddle::platform::CPUPlace(),
                         data,
                         ele_size,
                         stream);
W
Wilber 已提交
807 808 809 810 811 812 813 814 815 816 817 818
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with CUDA place because paddle is not compiled "
        "with CUDA."));
#endif
  } else {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "CopyFromCpuWithIoStream only supports CPU and GPU now."));
  }
}

template <typename T>
819 820
void InternalUtils::CopyToCpuWithIoStream(paddle_infer::Tensor *t,
                                          T *data,
W
Wilber 已提交
821 822 823
                                          cudaStream_t stream) {
  if (t->tensor_ == nullptr) {
    PADDLE_ENFORCE_EQ(
824 825
        t->name_.empty(),
        false,
W
Wilber 已提交
826 827 828 829 830 831
        paddle::platform::errors::PreconditionNotMet(
            "Need to SetName first, so that the corresponding tensor can "
            "be retrieved."));
    auto *scope = static_cast<paddle::framework::Scope *>(t->scope_);
    auto *var = scope->FindVar(t->name_);
    PADDLE_ENFORCE_NOT_NULL(
832 833 834
        var,
        paddle::platform::errors::PreconditionNotMet(
            "No tensor called [%s] in the runtime scope", t->name_));
835
    auto *tensor = var->GetMutable<phi::DenseTensor>();
W
Wilber 已提交
836 837 838
    t->tensor_ = tensor;
  }

839
  auto *tensor = static_cast<phi::DenseTensor *>(t->tensor_);
W
Wilber 已提交
840 841 842 843
  auto ele_num = tensor->numel();
  auto *t_data = tensor->data<T>();
  auto t_place = tensor->place();

844
  phi::DenseTensor out;
W
Wilber 已提交
845 846
  auto mem_allocation =
      std::make_shared<paddle::memory::allocation::Allocation>(
847 848
          static_cast<void *>(data),
          ele_num * sizeof(T),
W
Wilber 已提交
849 850 851 852 853
          paddle::platform::CPUPlace());
  out.ResetHolder(mem_allocation);

  if (paddle::platform::is_cpu_place(t_place)) {
#ifdef PADDLE_WITH_MKLDNN
854
    if (tensor->layout() == phi::DataLayout::ONEDNN)
855
      phi::funcs::TransDataLayoutFromOneDNN(
856
          tensor->layout(),
857
          phi::OneDNNContext::tls().get_cur_paddle_data_layout(),
858 859 860 861
          *tensor,
          &out,
          paddle::platform::CPUPlace(),
          true);
W
Wilber 已提交
862 863 864 865 866 867 868 869
    else
      std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
#else
    std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
#endif
  } else if (t->place_ == PlaceType::kGPU) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    paddle::memory::Copy(paddle::platform::CPUPlace(),
870 871 872 873 874
                         static_cast<void *>(data),
                         t_place,
                         t_data,
                         ele_num * sizeof(T),
                         stream);
W
Wilber 已提交
875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with CUDA place because paddle is not compiled "
        "with CUDA."));
#endif
  } else {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "CopyToCpuWithIoStream only supports CPU and GPU now."));
  }
}

template void InternalUtils::CopyFromCpuWithIoStream<float>(
    paddle_infer::Tensor *t, const float *data, cudaStream_t stream);
template void InternalUtils::CopyFromCpuWithIoStream<int64_t>(
    paddle_infer::Tensor *t, const int64_t *data, cudaStream_t stream);
template void InternalUtils::CopyFromCpuWithIoStream<int32_t>(
    paddle_infer::Tensor *t, const int32_t *data, cudaStream_t stream);
template void InternalUtils::CopyFromCpuWithIoStream<uint8_t>(
    paddle_infer::Tensor *t, const uint8_t *data, cudaStream_t stream);
template void InternalUtils::CopyFromCpuWithIoStream<int8_t>(
    paddle_infer::Tensor *t, const int8_t *data, cudaStream_t stream);
template void InternalUtils::CopyFromCpuWithIoStream<float16>(
    paddle_infer::Tensor *t, const float16 *data, cudaStream_t stream);

template void InternalUtils::CopyToCpuWithIoStream<float>(
    paddle_infer::Tensor *t, float *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<int64_t>(
    paddle_infer::Tensor *t, int64_t *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<int32_t>(
    paddle_infer::Tensor *t, int32_t *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<uint8_t>(
    paddle_infer::Tensor *t, uint8_t *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<int8_t>(
    paddle_infer::Tensor *t, int8_t *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<float16>(
    paddle_infer::Tensor *t, float16 *data, cudaStream_t stream);

}  // namespace experimental

914
}  // namespace paddle_infer