zero_copy_tensor.cc
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/api/paddle_tensor.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/core/allocator.h"
#ifdef PADDLE_WITH_ONNXRUNTIME
#include "onnxruntime_c_api.h"    // NOLINT
#include "onnxruntime_cxx_api.h"  // NOLINT
#endif

namespace paddle_infer {

using float16 = paddle::platform::float16;

void Tensor::Reshape(const std::vector<int> &shape) {
#ifdef PADDLE_WITH_ONNXRUNTIME
  if (is_ort_tensor_) {
    shape_.assign(shape.begin(), shape.end());
    return;
  }
#endif

  PADDLE_ENFORCE_EQ(
      name_.empty(),
      false,
      paddle::platform::errors::PreconditionNotMet(
          "Need to SetName first, so that the corresponding tensor can "
          "be retrieved."));
  PADDLE_ENFORCE_EQ(input_or_output_,
                    true,
                    paddle::platform::errors::PermissionDenied(
                        "Can't reshape the output tensor, it is readonly"));
  auto *scope = static_cast<paddle::framework::Scope *>(scope_);
  auto *var = scope->FindVar(name_);
  PADDLE_ENFORCE_NOT_NULL(
      var,
      paddle::platform::errors::PreconditionNotMet(
          "No tensor called [%s] in the runtime scope", name_));
  auto *tensor = var->GetMutable<paddle::framework::LoDTensor>();
  tensor->Resize(phi::make_ddim(shape));
}

void Tensor::ReshapeStrings(const size_t &shape) {
  PADDLE_ENFORCE_EQ(
      name_.empty(),
      false,
      paddle::platform::errors::PreconditionNotMet(
          "Need to SetName first, so that the corresponding tensor can "
          "be retrieved."));
  PADDLE_ENFORCE_EQ(input_or_output_,
                    true,
                    paddle::platform::errors::PermissionDenied(
                        "Can't reshape the output tensor, it is readonly"));
  auto *scope = static_cast<paddle::framework::Scope *>(scope_);
  auto *var = scope->FindVar(name_);
  PADDLE_ENFORCE_NOT_NULL(
      var,
      paddle::platform::errors::PreconditionNotMet(
          "No tensor called [%s] in the runtime scope", name_));
  paddle_infer::Strings *tensor = var->GetMutable<paddle_infer::Strings>();
  tensor->resize(shape);
}
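
// EAGER_GET_TENSOR lazily looks up the underlying variable of the requested
// tensor_type in the scope on first use and caches the raw pointer in
// tensor_; subsequent calls reuse the cached pointer.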

#define EAGER_GET_TENSOR(tensor_type)    \
  if (!tensor_) {                        \
    tensor_ = FindTensor<tensor_type>(); \
  }                                      \
  auto *tensor = static_cast<tensor_type *>(tensor_);

template <typename T>
T *Tensor::mutable_data(PlaceType place) {
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
  PADDLE_ENFORCE_GT(
      tensor->numel(),
      0,
      paddle::platform::errors::PreconditionNotMet(
          "You should call Tensor::Reshape(const std::vector<int> "
          "&shape) "
          "function before retrieving mutable_data from input tensor."));
  switch (static_cast<int>(place)) {
    case static_cast<int>(PlaceType::kCPU): {
      return tensor->mutable_data<T>(paddle::platform::CPUPlace());
    }
    case static_cast<int>(PlaceType::kGPU): {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      paddle::platform::CUDAPlace gpu_place(device_);
      auto *dev_ctxs = reinterpret_cast<const std::map<
          phi::Place,
          std::shared_future<std::unique_ptr<phi::DeviceContext>>> *>(
          device_contexs_);
      auto *dev_ctx =
          static_cast<phi::GPUContext *>(dev_ctxs->at(gpu_place).get().get());
      return dev_ctx->Alloc<T>(tensor, tensor->numel() * sizeof(T));
#else
      return tensor->mutable_data<T>(paddle::platform::CUDAPlace(device_));
#endif
    }
    case static_cast<int>(PlaceType::kXPU): {
      return tensor->mutable_data<T>(paddle::platform::XPUPlace(device_));
    }
    case static_cast<int>(PlaceType::kNPU): {
      return tensor->mutable_data<T>(paddle::platform::NPUPlace(device_));
    }
    default:
      PADDLE_THROW(paddle::platform::errors::Unavailable(
          "Only CPU / CUDA / XPU / NPU places are supported. The place `%d` is "
          "not supported.",
          static_cast<int>(place)));
      break;
  }
  return nullptr;
}

template <typename T>
T *Tensor::data(PlaceType *place, int *size) const {
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
  auto *res = tensor->data<T>();

  if (paddle::platform::is_cpu_place(tensor->place())) {
    *place = PlaceType::kCPU;
  } else if (paddle::platform::is_gpu_place(tensor->place())) {
    *place = PlaceType::kGPU;
  } else if (paddle::platform::is_xpu_place(tensor->place())) {
    *place = PlaceType::kXPU;
  } else if (paddle::platform::is_npu_place(tensor->place())) {
    *place = PlaceType::kNPU;
  } else {
    *place = PlaceType::kUNK;
  }

  *size = tensor->numel();
  return res;
}

DataType Tensor::type() const {
#ifdef PADDLE_WITH_ONNXRUNTIME
  if (is_ort_tensor_) {
    return dtype_;
  }
#endif
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
  auto type = paddle::framework::TransToProtoVarType(tensor->dtype());
  if (type == paddle::framework::proto::VarType::FP32) {
    return DataType::FLOAT32;
  } else if (type == paddle::framework::proto::VarType::FP16) {
    return DataType::FLOAT16;
  } else if (type == paddle::framework::proto::VarType::INT64) {
    return DataType::INT64;
  } else if (type == paddle::framework::proto::VarType::INT32) {
    return DataType::INT32;
  } else if (type == paddle::framework::proto::VarType::UINT8) {
    return DataType::UINT8;
  } else if (type == paddle::framework::proto::VarType::INT8) {
    return DataType::INT8;
  }
  return DataType::FLOAT32;
}

PlaceType Tensor::place() const { return place_; }

template <typename T>
void Tensor::CopyFromCpu(const T *data) {
#ifdef PADDLE_WITH_ONNXRUNTIME
  if (is_ort_tensor_) {
    ORTCopyFromCpu<T>(data);
    return;
  }
#endif

  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
  PADDLE_ENFORCE_GE(tensor->numel(),
                    0,
                    paddle::platform::errors::PreconditionNotMet(
                        "You should call Tensor::Reshape(const "
                        "std::vector<int> &shape) "
                        "function before copying data from cpu."));
  size_t ele_size = tensor->numel() * sizeof(T);

  if (place_ == PlaceType::kCPU) {
    auto *t_data = tensor->mutable_data<T>(paddle::platform::CPUPlace());
    std::memcpy(static_cast<void *>(t_data), data, ele_size);
  } else if (place_ == PlaceType::kGPU) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)

    paddle::platform::CUDAPlace gpu_place(device_);
    auto *dev_ctxs = reinterpret_cast<const std::map<
        phi::Place,
        std::shared_future<std::unique_ptr<phi::DeviceContext>>> *>(
        device_contexs_);
    auto *dev_ctx =
        static_cast<phi::GPUContext *>(dev_ctxs->at(gpu_place).get().get());
    auto *t_data = dev_ctx->Alloc<T>(tensor, tensor->numel() * sizeof(T));

    paddle::memory::Copy(gpu_place,
                         static_cast<void *>(t_data),
                         paddle::platform::CPUPlace(),
                         data,
                         ele_size,
                         dev_ctx->stream());
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with CUDA place because paddle is not compiled "
        "with CUDA."));
#endif
  } else if (place_ == PlaceType::kXPU) {
#ifdef PADDLE_WITH_XPU
    paddle::platform::XPUPlace xpu_place(device_);
    auto *t_data = tensor->mutable_data<T>(xpu_place);
    paddle::memory::Copy(xpu_place,
                         static_cast<void *>(t_data),
                         paddle::platform::CPUPlace(),
                         data,
                         ele_size);
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with XPU place because paddle is not compiled "
        "with XPU."));
#endif
  } else if (place_ == PlaceType::kNPU) {
#ifdef PADDLE_WITH_ASCEND_CL
    paddle::platform::DeviceContextPool &pool =
        paddle::platform::DeviceContextPool::Instance();
    paddle::platform::NPUPlace npu_place(device_);
    auto *t_data = tensor->mutable_data<T>(npu_place);
    auto *dev_ctx = static_cast<const paddle::platform::NPUDeviceContext *>(
        pool.Get(npu_place));
    paddle::memory::Copy(npu_place,
                         static_cast<void *>(t_data),
                         paddle::platform::CPUPlace(),
                         data,
                         ele_size,
                         dev_ctx->stream());
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with NPU place because paddle is not compiled "
        "with NPU."));
#endif
  } else {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
    auto device_type_id =
        static_cast<size_t>(place_) - static_cast<size_t>(PlaceType::kCUSTOM);
    paddle::platform::DeviceContextPool &pool =
        paddle::platform::DeviceContextPool::Instance();
    paddle::platform::CustomPlace custom_place(
        phi::GetGlobalDeviceType(device_type_id), device_);
    auto *t_data = tensor->mutable_data<T>(custom_place);
    auto *dev_ctx = static_cast<const paddle::platform::CustomDeviceContext *>(
        pool.Get(custom_place));
    paddle::memory::Copy(custom_place,
                         static_cast<void *>(t_data),
                         paddle::platform::CPUPlace(),
                         data,
                         ele_size,
                         dev_ctx->stream());
#else
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "The analysis predictor supports CPU, GPU, NPU and XPU now."));
#endif
  }
}

template <typename T>
struct DataTypeInfo;

template <>
struct DataTypeInfo<float> {
  paddle::experimental::DataType TYPE = paddle::experimental::DataType::FLOAT32;
};

template <>
struct DataTypeInfo<float16> {
  paddle::experimental::DataType TYPE = paddle::experimental::DataType::FLOAT16;
};

template <>
struct DataTypeInfo<int64_t> {
  paddle::experimental::DataType TYPE = paddle::experimental::DataType::INT64;
};

template <>
struct DataTypeInfo<int8_t> {
  paddle::experimental::DataType TYPE = paddle::experimental::DataType::INT8;
};

template <>
struct DataTypeInfo<uint8_t> {
  paddle::experimental::DataType TYPE = paddle::experimental::DataType::UINT8;
};

template <>
struct DataTypeInfo<int32_t> {
  paddle::experimental::DataType TYPE = paddle::experimental::DataType::INT32;
};

paddle::experimental::DataLayout LayoutConvert(DataLayout layout) {
  PADDLE_ENFORCE_EQ(
      layout,
      DataLayout::kNCHW,
      paddle::platform::errors::InvalidArgument("Only NCHW is supported now."));
  return paddle::experimental::DataLayout::NCHW;
}

template <typename T>
void Tensor::ShareExternalData(const T *data,
                               const std::vector<int> &shape,
                               PlaceType place,
                               DataLayout layout) {
  EAGER_GET_TENSOR(paddle::framework::LoDTensor)
  size_t size =
      std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()) *
      sizeof(T);
  phi::DenseTensorMeta meta(
      DataTypeInfo<T>().TYPE, phi::make_ddim(shape), LayoutConvert(layout));
  if (place == PlaceType::kCPU) {
    phi::DenseTensor dtensor(
        std::make_shared<phi::Allocation>(
            const_cast<T *>(data), size, paddle::platform::CPUPlace()),
        meta);
    *tensor = std::move(dtensor);
  } else if (place == PlaceType::kGPU) {
    phi::DenseTensor dtensor(
        std::make_shared<phi::Allocation>(
            const_cast<T *>(data), size, paddle::platform::CUDAPlace(device_)),
        meta);
    *tensor = std::move(dtensor);
  } else {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "PlaceType must be PlaceType::kCPU or PlaceType::kGPU."));
  }
}
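
// A minimal usage sketch of ShareExternalData (illustrative only; `predictor`
// and `input_name` are assumed to come from the caller's paddle_infer setup).
// It wraps a caller-owned float buffer as the input tensor without an extra
// copy:
//
//   std::vector<float> buf(1 * 3 * 224 * 224, 0.f);  // caller keeps ownership
//   auto input = predictor->GetInputHandle(input_name);
//   input->ShareExternalData<float>(
//       buf.data(), {1, 3, 224, 224}, PlaceType::kCPU, DataLayout::kNCHW);
//   // `buf` must stay alive until the predictor has finished running.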

void Tensor::CopyStringsFromCpu(const paddle_infer::Strings *data) {
  EAGER_GET_TENSOR(paddle_infer::Strings);
  PADDLE_ENFORCE_GE(tensor->size(),
                    0,
                    paddle::platform::errors::PreconditionNotMet(
                        "You should call Tensor::ReshapeStrings(const "
                        "std::size_t &shape) function before copying "
                        "the string data from cpu."));
  *tensor = *data;
}
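
// CopyToCpuImpl copies the tensor back into the host buffer `data`. On the
// CUDA path it supports three modes: hand the stream back through
// `exec_stream` so the caller can synchronize, register `cb`/`cb_params` as a
// host callback on the stream, or (by default) synchronize the stream before
// returning.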

template <typename T>
void Tensor::CopyToCpuImpl(T *data,
                           void *exec_stream,
                           CallbackFunc cb,
                           void *cb_params) const {
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
  auto ele_num = tensor->numel();
  auto *t_data = tensor->data<T>();
  auto t_place = tensor->place();

  paddle::framework::Tensor out;
  auto mem_allocation =
      std::make_shared<paddle::memory::allocation::Allocation>(
          static_cast<void *>(data),
          ele_num * sizeof(T),
          paddle::platform::CPUPlace());
  out.ResetHolder(mem_allocation);

  if (paddle::platform::is_cpu_place(t_place)) {
#ifdef PADDLE_WITH_MKLDNN
    if (tensor->layout() == paddle::framework::DataLayout::kMKLDNN)
      paddle::framework::innerTransDataLayoutFromMKLDNN(
          tensor->layout(),
          paddle::platform::MKLDNNDeviceContext::tls()
              .get_cur_paddle_data_layout(),
          *tensor,
          &out,
          paddle::platform::CPUPlace(),
          true);
    else
      std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
#else
    std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
#endif
  } else if (paddle::platform::is_ipu_place(t_place)) {
#ifdef PADDLE_WITH_IPU
    std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with IPU place because paddle is not compiled "
        "with IPU."));
#endif
  } else if (place_ == PlaceType::kGPU) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    auto gpu_place = t_place;
    auto *dev_ctxs = reinterpret_cast<const std::map<
        phi::Place,
        std::shared_future<std::unique_ptr<phi::DeviceContext>>> *>(
        device_contexs_);
    auto *dev_ctx =
        static_cast<phi::GPUContext *>(dev_ctxs->at(gpu_place).get().get());
    paddle::memory::Copy(paddle::platform::CPUPlace(),
                         static_cast<void *>(data),
                         gpu_place,
                         t_data,
                         ele_num * sizeof(T),
                         dev_ctx->stream());
#ifdef PADDLE_WITH_HIP
    hipStreamSynchronize(dev_ctx->stream());
#else
    // async, return stream
    if (nullptr != exec_stream) {
      *(static_cast<cudaStream_t *>(exec_stream)) = dev_ctx->stream();
      // async with callback
    } else if (cb) {
      cudaLaunchHostFunc(dev_ctx->stream(), cb, cb_params);
      // sync
    } else {
      cudaStreamSynchronize(dev_ctx->stream());
    }
#endif
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with CUDA place because paddle is not compiled "
        "with CUDA."));
#endif
  } else if (place_ == PlaceType::kXPU) {
#ifdef PADDLE_WITH_XPU
    auto xpu_place = t_place;
    paddle::memory::Copy(paddle::platform::CPUPlace(),
                         static_cast<void *>(data),
                         xpu_place,
                         t_data,
                         ele_num * sizeof(T));
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with XPU place because paddle is not compiled "
        "with XPU."));
#endif
  } else if (place_ == PlaceType::kNPU) {
#ifdef PADDLE_WITH_ASCEND_CL
    paddle::platform::DeviceContextPool &pool =
        paddle::platform::DeviceContextPool::Instance();
    auto npu_place = t_place;
    auto *dev_ctx = static_cast<const paddle::platform::NPUDeviceContext *>(
        pool.Get(npu_place));
    paddle::memory::Copy(paddle::platform::CPUPlace(),
                         static_cast<void *>(data),
                         npu_place,
                         t_data,
                         ele_num * sizeof(T),
                         dev_ctx->stream());
    paddle::platform::NPUStreamSync(dev_ctx->stream());
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with NPU place because paddle is not compiled "
        "with NPU."));
#endif
  } else {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
    paddle::platform::DeviceContextPool &pool =
        paddle::platform::DeviceContextPool::Instance();
    auto custom_place = t_place;
    auto *dev_ctx = static_cast<const paddle::platform::CustomDeviceContext *>(
        pool.Get(custom_place));
    paddle::memory::Copy(paddle::platform::CPUPlace(),
                         static_cast<void *>(data),
                         custom_place,
                         t_data,
                         ele_num * sizeof(T),
                         dev_ctx->stream());
// TODO(wangran16): sync_stream
#else
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "The analysis predictor supports CPU, GPU, NPU and XPU now."));
#endif
  }
}

template <typename T>
void Tensor::CopyToCpu(T *data) const {
#ifdef PADDLE_WITH_ONNXRUNTIME
  if (is_ort_tensor_) {
    ORTCopyToCpu<T>(data);
    return;
  }
#endif

  CopyToCpuImpl<T>(data, nullptr, nullptr, nullptr);
}

template <typename T>
void Tensor::CopyToCpuAsync(T *data, void *exec_stream) const {
  CopyToCpuImpl<T>(data, exec_stream, nullptr, nullptr);
}

template <typename T>
void Tensor::CopyToCpuAsync(T *data, CallbackFunc cb, void *cb_params) const {
  CopyToCpuImpl<T>(data, nullptr, cb, cb_params);
}
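
// A minimal end-to-end sketch of the copy-based I/O path (illustrative only;
// `predictor` is assumed to be a configured paddle_infer::Predictor and
// `host_input` a correctly sized std::vector<float>):
//
//   auto input = predictor->GetInputHandle(predictor->GetInputNames()[0]);
//   input->Reshape({1, 3, 224, 224});
//   input->CopyFromCpu(host_input.data());  // host -> tensor copy
//   predictor->Run();
//   auto output = predictor->GetOutputHandle(predictor->GetOutputNames()[0]);
//   auto out_shape = output->shape();
//   std::vector<float> host_output(std::accumulate(
//       out_shape.begin(), out_shape.end(), 1, std::multiplies<int>()));
//   output->CopyToCpu(host_output.data());  // synchronous copy back to host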

template PD_INFER_DECL void Tensor::CopyFromCpu<float>(const float *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<int64_t>(const int64_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<int32_t>(const int32_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<uint8_t>(const uint8_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<int8_t>(const int8_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<float16>(const float16 *data);

template PD_INFER_DECL void Tensor::ShareExternalData<float>(
    const float *data,
    const std::vector<int> &shape,
    PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<int64_t>(
    const int64_t *data,
    const std::vector<int> &shape,
    PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<int32_t>(
    const int32_t *data,
    const std::vector<int> &shape,
    PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<uint8_t>(
    const uint8_t *data,
    const std::vector<int> &shape,
    PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<int8_t>(
    const int8_t *data,
    const std::vector<int> &shape,
    PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<float16>(
    const float16 *data,
    const std::vector<int> &shape,
    PlaceType place,
    DataLayout layout);

template PD_INFER_DECL void Tensor::CopyToCpu<float>(float *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<int64_t>(int64_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<int32_t>(int32_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<uint8_t>(uint8_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<int8_t>(int8_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<float16>(float16 *data) const;

template PD_INFER_DECL void Tensor::CopyToCpuImpl<float>(float *data,
                                                         void *exec_stream,
                                                         CallbackFunc cb,
                                                         void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<int64_t>(
    int64_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<int32_t>(
    int32_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<uint8_t>(
    uint8_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<int8_t>(
    int8_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<float16>(
    float16 *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;

template PD_INFER_DECL void Tensor::CopyToCpuAsync<float>(
    float *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int64_t>(
    int64_t *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int32_t>(
    int32_t *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<uint8_t>(
    uint8_t *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int8_t>(
    int8_t *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<float16>(
    float16 *data, void *exec_stream) const;

template PD_INFER_DECL void Tensor::CopyToCpuAsync<float>(
    float *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int64_t>(
    int64_t *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int32_t>(
    int32_t *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<uint8_t>(
    uint8_t *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int8_t>(
    int8_t *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<float16>(
    float16 *data, CallbackFunc cb, void *cb_params) const;

template PD_INFER_DECL float *Tensor::data<float>(PlaceType *place,
                                                  int *size) const;
template PD_INFER_DECL int64_t *Tensor::data<int64_t>(PlaceType *place,
                                                      int *size) const;
template PD_INFER_DECL int32_t *Tensor::data<int32_t>(PlaceType *place,
                                                      int *size) const;
template PD_INFER_DECL uint8_t *Tensor::data<uint8_t>(PlaceType *place,
                                                      int *size) const;
template PD_INFER_DECL int8_t *Tensor::data<int8_t>(PlaceType *place,
                                                    int *size) const;
template PD_INFER_DECL float16 *Tensor::data<float16>(PlaceType *place,
                                                      int *size) const;

template PD_INFER_DECL float *Tensor::mutable_data<float>(PlaceType place);
template PD_INFER_DECL int64_t *Tensor::mutable_data<int64_t>(PlaceType place);
template PD_INFER_DECL int32_t *Tensor::mutable_data<int32_t>(PlaceType place);
template PD_INFER_DECL uint8_t *Tensor::mutable_data<uint8_t>(PlaceType place);
template PD_INFER_DECL int8_t *Tensor::mutable_data<int8_t>(PlaceType place);
template PD_INFER_DECL float16 *Tensor::mutable_data<float16>(PlaceType place);

Tensor::Tensor(void *scope, const void *device_contexts)
    : scope_{scope}, device_contexs_(device_contexts) {}

template <typename T>
void *Tensor::FindTensor() const {
  PADDLE_ENFORCE_EQ(
      name_.empty(),
      false,
      paddle::platform::errors::PreconditionNotMet(
          "Need to SetName first, so that the corresponding tensor can "
          "be retrieved."));
  auto *scope = static_cast<paddle::framework::Scope *>(scope_);
  auto *var = scope->FindVar(name_);
  PADDLE_ENFORCE_NOT_NULL(
      var,
      paddle::platform::errors::PreconditionNotMet(
          "No tensor called [%s] in the runtime scope", name_));
  auto *tensor = var->GetMutable<T>();
  return tensor;
}

std::vector<int> Tensor::shape() const {
#ifdef PADDLE_WITH_ONNXRUNTIME
  if (is_ort_tensor_) {
    std::vector<int> shape;
    // input handle
    if (idx_ < 0) {
      shape.assign(shape_.begin(), shape_.end());
    } else {  // output handle
      auto binding = binding_.lock();
      PADDLE_ENFORCE_NOT_NULL(binding,
                              paddle::platform::errors::PreconditionNotMet(
                                  "output tensor [%s] no binding ptr", name_));
      std::vector<Ort::Value> outputs = binding->GetOutputValues();
      Ort::Value &value = outputs[idx_];
      auto info = value.GetTensorTypeAndShapeInfo();
      auto ort_shape = info.GetShape();
      shape.assign(ort_shape.begin(), ort_shape.end());
    }
    return shape;
  }
#endif
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
  PADDLE_ENFORCE_NOT_NULL(
      tensor_,
      paddle::platform::errors::PreconditionNotMet(
          "Not found tensor called %s in the scope", name_));
// mkldnn may do a layout transform internally, so the dims need to be
// reordered before returning
#ifdef PADDLE_WITH_MKLDNN
  if (tensor->layout() == paddle::framework::DataLayout::kMKLDNN) {
    paddle::framework::DataLayout out_layout =
        paddle::platform::MKLDNNDeviceContext::tls()
            .get_cur_paddle_data_layout();
    // Set default as NCHW in case not specified
    out_layout = out_layout == paddle::framework::DataLayout::kAnyLayout
                     ? paddle::framework::DataLayout::kNCHW
                     : out_layout;
    // In these data layouts the channel dimension is either in the 2nd
    // position (nChw) or last (nhwC), so for dim==2 the layouts are identical
    // and nothing needs to be done. The same holds for dim==1, where only one
    // combination is possible.
    if (tensor->dims().size() < 3) return phi::vectorize<int>(tensor->dims());
    if (out_layout == paddle::framework::DataLayout::kNHWC ||
        out_layout == paddle::framework::DataLayout::kNDHWC) {
      auto dims = phi::vectorize<int>(tensor->dims());
      std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end());
      return dims;
    } else {
      return phi::vectorize<int>(tensor->dims());
    }
  }
#endif
  return phi::vectorize<int>(tensor->dims());
}

void Tensor::SetLoD(const std::vector<std::vector<size_t>> &x) {
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
  paddle::framework::LoD lod;
  for (auto &level : x) {
    lod.emplace_back(level);
  }
  tensor->set_lod(lod);
}

std::vector<std::vector<size_t>> Tensor::lod() const {
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
  std::vector<std::vector<size_t>> res;
  for (auto &level : tensor->lod()) {
    res.emplace_back(level);
  }
  return res;
}

void Tensor::SetName(const std::string &name) { name_ = name; }

const std::string &Tensor::name() const { return name_; }

void Tensor::SetPlace(PlaceType place, int device) {
  place_ = place;
  device_ = device;
}

#ifdef PADDLE_WITH_ONNXRUNTIME
void Tensor::SetOrtMark(bool is_ort_tensor) { is_ort_tensor_ = is_ort_tensor; }

void Tensor::SetOrtBinding(const std::shared_ptr<Ort::IoBinding> binding) {
  binding_ = binding;
}

void Tensor::SetOrtBuffer(const std::shared_ptr<std::vector<int8_t>> buffer) {
  buffer_ = buffer;
}

Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info,
                       float *data,
                       size_t size,
                       const int64_t *shape,
                       size_t shape_len) {
  return Ort::Value::CreateTensor<float>(
      memory_info, data, size, shape, shape_len);
}

Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info,
                       int64_t *data,
                       size_t size,
                       const int64_t *shape,
                       size_t shape_len) {
  return Ort::Value::CreateTensor<int64_t>(
      memory_info, data, size, shape, shape_len);
}

Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info,
                       int32_t *data,
                       size_t size,
                       const int64_t *shape,
                       size_t shape_len) {
  return Ort::Value::CreateTensor<int32_t>(
      memory_info, data, size, shape, shape_len);
}

Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info,
                       uint8_t *data,
                       size_t size,
                       const int64_t *shape,
                       size_t shape_len) {
  return Ort::Value::CreateTensor<uint8_t>(
      memory_info, data, size, shape, shape_len);
}

Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info,
                       int8_t *data,
                       size_t size,
                       const int64_t *shape,
                       size_t shape_len) {
  return Ort::Value::CreateTensor<int8_t>(
      memory_info, data, size, shape, shape_len);
}

Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info,
                       float16 *data,
                       size_t size,
                       const int64_t *shape,
                       size_t shape_len) {
  return Ort::Value::CreateTensor(memory_info,
                                  static_cast<void *>(data),
                                  size * sizeof(float16),
                                  shape,
                                  shape_len,
                                  ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
}

template <typename T>
void Tensor::ORTCopyFromCpu(const T *data) {
  auto binding = binding_.lock();
  PADDLE_ENFORCE_NOT_NULL(binding,
                          paddle::platform::errors::PreconditionNotMet(
                              "input tensor [%s] no binding ptr", name_));
  const char *device_name = place_ == PlaceType::kCPU ? "Cpu" : "Cuda";
  Ort::MemoryInfo memory_info(
      device_name, OrtDeviceAllocator, device_, OrtMemTypeDefault);
  size_t size = std::accumulate(
      begin(shape_), end(shape_), 1UL, std::multiplies<size_t>());
  auto buffer = buffer_.lock();
  size_t buffer_size = size * sizeof(T);
  if (buffer_size > buffer->size()) {
    buffer->resize(buffer_size);
  }
  std::memcpy(static_cast<void *>(buffer->data()), data, buffer_size);

  auto onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
  if (std::is_same<T, float>::value) {
    onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  } else if (std::is_same<T, double>::value) {
    onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
  } else if (std::is_same<T, int64_t>::value) {
    onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
  } else if (std::is_same<T, int32_t>::value) {
    onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
  } else if (std::is_same<T, uint8_t>::value) {
    onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
  } else if (std::is_same<T, int8_t>::value) {
    onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
  } else if (std::is_same<T, float16>::value) {
    onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
  } else {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "Found undefined data type for onnxruntime, only supports "
        "float16/float32/float64/int8/uint8/int32/int64."));
  }

  auto ort_value = Ort::Value::CreateTensor(memory_info,
                                            buffer->data(),
                                            buffer_size,
                                            shape_.data(),
                                            shape_.size(),
                                            onnx_dtype);
  binding->BindInput(name_.c_str(), ort_value);
}

template <typename T>
void Tensor::ORTCopyToCpu(T *data) const {
  auto binding = binding_.lock();
  PADDLE_ENFORCE_NOT_NULL(binding,
                          paddle::platform::errors::PreconditionNotMet(
                              "output tensor [%s] no binding ptr", name_));
  std::vector<Ort::Value> outputs = binding->GetOutputValues();
  Ort::Value &value = outputs[idx_];
  auto info = value.GetTensorTypeAndShapeInfo();
  size_t size = info.GetElementCount() * sizeof(T);

  if (place_ == PlaceType::kCPU) {
    std::memcpy(static_cast<void *>(data), value.GetTensorData<void *>(), size);
  } else {
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "CopyToCpu error. The current ONNXRuntime backend doesn't support "
        "GPU."));
  }
}

template void Tensor::ORTCopyFromCpu<float>(const float *data);
template void Tensor::ORTCopyFromCpu<int64_t>(const int64_t *data);
template void Tensor::ORTCopyFromCpu<int32_t>(const int32_t *data);
template void Tensor::ORTCopyFromCpu<uint8_t>(const uint8_t *data);
template void Tensor::ORTCopyFromCpu<int8_t>(const int8_t *data);
template void Tensor::ORTCopyFromCpu<float16>(const float16 *data);

template void Tensor::ORTCopyToCpu<float>(float *data) const;
template void Tensor::ORTCopyToCpu<int32_t>(int32_t *data) const;
template void Tensor::ORTCopyToCpu<uint8_t>(uint8_t *data) const;
template void Tensor::ORTCopyToCpu<int8_t>(int8_t *data) const;
template void Tensor::ORTCopyToCpu<float16>(float16 *data) const;
#endif

namespace experimental {
template <typename T>
void InternalUtils::CopyFromCpuWithIoStream(paddle_infer::Tensor *t,
                                            const T *data,
                                            cudaStream_t stream) {
  if (t->tensor_ == nullptr) {
    PADDLE_ENFORCE_EQ(
        t->name_.empty(),
        false,
        paddle::platform::errors::PreconditionNotMet(
            "Need to SetName first, so that the corresponding tensor can "
            "be retrieved."));
    auto *scope = static_cast<paddle::framework::Scope *>(t->scope_);
    auto *var = scope->FindVar(t->name_);
    PADDLE_ENFORCE_NOT_NULL(
        var,
        paddle::platform::errors::PreconditionNotMet(
            "No tensor called [%s] in the runtime scope", t->name_));
    auto *tensor = var->GetMutable<paddle::framework::LoDTensor>();
    t->tensor_ = tensor;
  }

  auto *tensor = static_cast<paddle::framework::LoDTensor *>(t->tensor_);
  PADDLE_ENFORCE_GE(tensor->numel(),
                    0,
                    paddle::platform::errors::PreconditionNotMet(
                        "You should call Tensor::Reshape(const "
                        "std::vector<int> &shape) "
                        "function before copying data from cpu."));
  size_t ele_size = tensor->numel() * sizeof(T);
  if (t->place_ == PlaceType::kCPU) {
    auto *t_data = tensor->mutable_data<T>(paddle::platform::CPUPlace());
    std::memcpy(static_cast<void *>(t_data), data, ele_size);
  } else if (t->place_ == PlaceType::kGPU) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    paddle::platform::CUDAPlace gpu_place(t->device_);
    auto *t_data = tensor->mutable_data<T>(gpu_place);
    paddle::memory::Copy(gpu_place,
                         static_cast<void *>(t_data),
                         paddle::platform::CPUPlace(),
                         data,
                         ele_size,
                         stream);
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with CUDA place because paddle is not compiled "
        "with CUDA."));
#endif
  } else {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "CopyFromCpuWithIoStream only supports CPU and GPU now."));
  }
}

template <typename T>
void InternalUtils::CopyToCpuWithIoStream(paddle_infer::Tensor *t,
                                          T *data,
                                          cudaStream_t stream) {
  if (t->tensor_ == nullptr) {
    PADDLE_ENFORCE_EQ(
        t->name_.empty(),
        false,
        paddle::platform::errors::PreconditionNotMet(
            "Need to SetName first, so that the corresponding tensor can "
            "be retrieved."));
    auto *scope = static_cast<paddle::framework::Scope *>(t->scope_);
    auto *var = scope->FindVar(t->name_);
    PADDLE_ENFORCE_NOT_NULL(
        var,
        paddle::platform::errors::PreconditionNotMet(
            "No tensor called [%s] in the runtime scope", t->name_));
    auto *tensor = var->GetMutable<paddle::framework::LoDTensor>();
    t->tensor_ = tensor;
  }

  auto *tensor = static_cast<paddle::framework::LoDTensor *>(t->tensor_);
  auto ele_num = tensor->numel();
  auto *t_data = tensor->data<T>();
  auto t_place = tensor->place();

  paddle::framework::Tensor out;
  auto mem_allocation =
      std::make_shared<paddle::memory::allocation::Allocation>(
          static_cast<void *>(data),
          ele_num * sizeof(T),
          paddle::platform::CPUPlace());
  out.ResetHolder(mem_allocation);

  if (paddle::platform::is_cpu_place(t_place)) {
#ifdef PADDLE_WITH_MKLDNN
    if (tensor->layout() == paddle::framework::DataLayout::kMKLDNN)
      paddle::framework::innerTransDataLayoutFromMKLDNN(
          tensor->layout(),
          paddle::platform::MKLDNNDeviceContext::tls()
              .get_cur_paddle_data_layout(),
          *tensor,
          &out,
          paddle::platform::CPUPlace(),
          true);
    else
      std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
#else
    std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
#endif
  } else if (t->place_ == PlaceType::kGPU) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    paddle::memory::Copy(paddle::platform::CPUPlace(),
                         static_cast<void *>(data),
                         t_place,
                         t_data,
                         ele_num * sizeof(T),
                         stream);
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with CUDA place because paddle is not compiled "
        "with CUDA."));
#endif
  } else {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "CopyToCpuWithIoStream only supports CPU and GPU now."));
  }
}

template void InternalUtils::CopyFromCpuWithIoStream<float>(
    paddle_infer::Tensor *t, const float *data, cudaStream_t stream);
template void InternalUtils::CopyFromCpuWithIoStream<int64_t>(
    paddle_infer::Tensor *t, const int64_t *data, cudaStream_t stream);
template void InternalUtils::CopyFromCpuWithIoStream<int32_t>(
    paddle_infer::Tensor *t, const int32_t *data, cudaStream_t stream);
template void InternalUtils::CopyFromCpuWithIoStream<uint8_t>(
    paddle_infer::Tensor *t, const uint8_t *data, cudaStream_t stream);
template void InternalUtils::CopyFromCpuWithIoStream<int8_t>(
    paddle_infer::Tensor *t, const int8_t *data, cudaStream_t stream);
template void InternalUtils::CopyFromCpuWithIoStream<float16>(
    paddle_infer::Tensor *t, const float16 *data, cudaStream_t stream);

template void InternalUtils::CopyToCpuWithIoStream<float>(
    paddle_infer::Tensor *t, float *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<int64_t>(
    paddle_infer::Tensor *t, int64_t *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<int32_t>(
    paddle_infer::Tensor *t, int32_t *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<uint8_t>(
    paddle_infer::Tensor *t, uint8_t *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<int8_t>(
    paddle_infer::Tensor *t, int8_t *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<float16>(
    paddle_infer::Tensor *t, float16 *data, cudaStream_t stream);

}  // namespace experimental

}  // namespace paddle_infer