zero_copy_tensor.cc
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/api/paddle_tensor.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/core/allocator.h"
#ifdef PADDLE_WITH_ONNXRUNTIME
#include "onnxruntime_c_api.h"    // NOLINT
#include "onnxruntime_cxx_api.h"  // NOLINT
#endif

namespace paddle_infer {

using float16 = paddle::platform::float16;

void Tensor::Reshape(const std::vector<int> &shape) {
#ifdef PADDLE_WITH_ONNXRUNTIME
  if (is_ort_tensor_) {
    shape_.assign(shape.begin(), shape.end());
    return;
  }
#endif

  PADDLE_ENFORCE_EQ(
      name_.empty(),
      false,
      paddle::platform::errors::PreconditionNotMet(
          "Need to SetName first, so that the corresponding tensor can "
          "be retrieved."));
  PADDLE_ENFORCE_EQ(input_or_output_,
                    true,
                    paddle::platform::errors::PermissionDenied(
                        "Can't reshape the output tensor, it is readonly"));
  auto *scope = static_cast<paddle::framework::Scope *>(scope_);
  auto *var = scope->FindVar(name_);
  PADDLE_ENFORCE_NOT_NULL(
      var,
      paddle::platform::errors::PreconditionNotMet(
          "No tensor called [%s] in the runtime scope", name_));
  auto *tensor = var->GetMutable<paddle::framework::LoDTensor>();
  tensor->Resize(phi::make_ddim(shape));
}

void Tensor::ReshapeStrings(const size_t &shape) {
  PADDLE_ENFORCE_EQ(
      name_.empty(),
      false,
      paddle::platform::errors::PreconditionNotMet(
          "Need to SetName first, so that the corresponding tensor can "
          "be retrieved."));
  PADDLE_ENFORCE_EQ(input_or_output_,
                    true,
                    paddle::platform::errors::PermissionDenied(
                        "Can't reshape the output tensor, it is readonly"));
  auto *scope = static_cast<paddle::framework::Scope *>(scope_);
  auto *var = scope->FindVar(name_);
  PADDLE_ENFORCE_NOT_NULL(
      var,
      paddle::platform::errors::PreconditionNotMet(
          "No tensor called [%s] in the runtime scope", name_));
  paddle_infer::Strings *tensor = var->GetMutable<paddle_infer::Strings>();
  tensor->resize(shape);
}
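
// EAGER_GET_TENSOR lazily resolves the underlying tensor: on first use it
// looks the variable up in the runtime scope by name (via FindTensor) and
// caches the raw pointer in tensor_; later calls reuse the cached pointer.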

#define EAGER_GET_TENSOR(tensor_type)    \
  if (!tensor_) {                        \
    tensor_ = FindTensor<tensor_type>(); \
  }                                      \
  auto *tensor = static_cast<tensor_type *>(tensor_);

template <typename T>
T *Tensor::mutable_data(PlaceType place) {
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
  PADDLE_ENFORCE_GT(
      tensor->numel(),
      0,
      paddle::platform::errors::PreconditionNotMet(
          "You should call Tensor::Reshape(const std::vector<int> "
          "&shape) "
          "function before retrieving mutable_data from input tensor."));
  switch (static_cast<int>(place)) {
    case static_cast<int>(PlaceType::kCPU): {
      return tensor->mutable_data<T>(paddle::platform::CPUPlace());
    }
    case static_cast<int>(PlaceType::kGPU): {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      paddle::platform::CUDAPlace gpu_place(device_);
      auto *dev_ctxs = reinterpret_cast<const std::map<
          phi::Place,
          std::shared_future<std::unique_ptr<phi::DeviceContext>>> *>(
          device_contexs_);
      auto *dev_ctx =
          static_cast<phi::GPUContext *>(dev_ctxs->at(gpu_place).get().get());
      return dev_ctx->Alloc<T>(tensor, tensor->numel() * sizeof(T));
#else
      return tensor->mutable_data<T>(paddle::platform::CUDAPlace(device_));
#endif
    }
    case static_cast<int>(PlaceType::kXPU): {
      return tensor->mutable_data<T>(paddle::platform::XPUPlace(device_));
    }
    case static_cast<int>(PlaceType::kNPU): {
      return tensor->mutable_data<T>(paddle::platform::NPUPlace(device_));
    }
    default:
      PADDLE_THROW(paddle::platform::errors::Unavailable(
          "Only CPU / CUDA / XPU / NPU places are supported. The place `%d` is "
          "not supported.",
          static_cast<int>(place)));
      break;
  }
  return nullptr;
}

template <typename T>
T *Tensor::data(PlaceType *place, int *size) const {
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
  auto *res = tensor->data<T>();

  if (paddle::platform::is_cpu_place(tensor->place())) {
    *place = PlaceType::kCPU;
  } else if (paddle::platform::is_gpu_place(tensor->place())) {
    *place = PlaceType::kGPU;
  } else if (paddle::platform::is_xpu_place(tensor->place())) {
    *place = PlaceType::kXPU;
  } else if (paddle::platform::is_npu_place(tensor->place())) {
    *place = PlaceType::kNPU;
  } else {
    *place = PlaceType::kUNK;
  }

  *size = tensor->numel();
  return res;
}

DataType Tensor::type() const {
#ifdef PADDLE_WITH_ONNXRUNTIME
  if (is_ort_tensor_) {
    return dtype_;
  }
#endif
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
  auto type = paddle::framework::TransToProtoVarType(tensor->dtype());
  if (type == paddle::framework::proto::VarType::FP32) {
    return DataType::FLOAT32;
  } else if (type == paddle::framework::proto::VarType::FP16) {
    return DataType::FLOAT16;
  } else if (type == paddle::framework::proto::VarType::INT64) {
    return DataType::INT64;
  } else if (type == paddle::framework::proto::VarType::INT32) {
    return DataType::INT32;
  } else if (type == paddle::framework::proto::VarType::UINT8) {
    return DataType::UINT8;
  } else if (type == paddle::framework::proto::VarType::INT8) {
    return DataType::INT8;
  }
  return DataType::FLOAT32;
}

PlaceType Tensor::place() const { return place_; }

template <typename T>
void Tensor::CopyFromCpu(const T *data) {
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
  PADDLE_ENFORCE_GE(tensor->numel(),
                    0,
                    paddle::platform::errors::PreconditionNotMet(
                        "You should call Tensor::Reshape(const "
                        "std::vector<int> &shape) "
                        "function before copying data from cpu."));
  size_t ele_size = tensor->numel() * sizeof(T);

  if (place_ == PlaceType::kCPU) {
    auto *t_data = tensor->mutable_data<T>(paddle::platform::CPUPlace());
    std::memcpy(static_cast<void *>(t_data), data, ele_size);
  } else if (place_ == PlaceType::kGPU) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)

    paddle::platform::CUDAPlace gpu_place(device_);
    auto *dev_ctxs = reinterpret_cast<const std::map<
        phi::Place,
        std::shared_future<std::unique_ptr<phi::DeviceContext>>> *>(
        device_contexs_);
    auto *dev_ctx =
        static_cast<phi::GPUContext *>(dev_ctxs->at(gpu_place).get().get());
    auto *t_data = dev_ctx->Alloc<T>(tensor, tensor->numel() * sizeof(T));

    paddle::memory::Copy(gpu_place,
                         static_cast<void *>(t_data),
                         paddle::platform::CPUPlace(),
                         data,
                         ele_size,
                         dev_ctx->stream());
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with CUDA place because paddle is not compiled "
        "with CUDA."));
#endif
  } else if (place_ == PlaceType::kXPU) {
#ifdef PADDLE_WITH_XPU
    paddle::platform::XPUPlace xpu_place(device_);
    auto *t_data = tensor->mutable_data<T>(xpu_place);
    paddle::memory::Copy(xpu_place,
                         static_cast<void *>(t_data),
                         paddle::platform::CPUPlace(),
                         data,
                         ele_size);
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with XPU place because paddle is not compiled "
        "with XPU."));
#endif
  } else if (place_ == PlaceType::kNPU) {
#ifdef PADDLE_WITH_ASCEND_CL
    paddle::platform::DeviceContextPool &pool =
        paddle::platform::DeviceContextPool::Instance();
    paddle::platform::NPUPlace npu_place(device_);
    auto *t_data = tensor->mutable_data<T>(npu_place);
    auto *dev_ctx = static_cast<const paddle::platform::NPUDeviceContext *>(
        pool.Get(npu_place));
    paddle::memory::Copy(npu_place,
                         static_cast<void *>(t_data),
                         paddle::platform::CPUPlace(),
                         data,
                         ele_size,
                         dev_ctx->stream());
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with NPU place because paddle is not compiled "
        "with NPU."));
#endif
  } else {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
    auto device_type_id =
        static_cast<size_t>(place_) - static_cast<size_t>(PlaceType::kCUSTOM);
    paddle::platform::DeviceContextPool &pool =
        paddle::platform::DeviceContextPool::Instance();
    paddle::platform::CustomPlace custom_place(
        phi::GetGlobalDeviceType(device_type_id), device_);
    auto *t_data = tensor->mutable_data<T>(custom_place);
    auto *dev_ctx = static_cast<const paddle::platform::CustomDeviceContext *>(
        pool.Get(custom_place));
    paddle::memory::Copy(custom_place,
                         static_cast<void *>(t_data),
                         paddle::platform::CPUPlace(),
                         data,
                         ele_size,
                         dev_ctx->stream());
#else
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "The analysis predictor supports CPU, GPU, NPU and XPU now."));
#endif
  }
}
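
// Illustrative usage sketch (not part of the original file): feeding an input
// tensor from host memory. `predictor` and the concrete shape are assumptions
// made only for this example.
//
//   auto input_names = predictor->GetInputNames();
//   auto input = predictor->GetInputHandle(input_names[0]);
//   std::vector<float> host_data(1 * 3 * 224 * 224, 0.f);
//   input->Reshape({1, 3, 224, 224});
//   input->CopyFromCpu(host_data.data());  // copies the host buffer to place_
//   predictor->Run();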

template <typename T>
struct DataTypeInfo;

template <>
struct DataTypeInfo<float> {
  paddle::experimental::DataType TYPE = paddle::experimental::DataType::FLOAT32;
};

template <>
struct DataTypeInfo<float16> {
  paddle::experimental::DataType TYPE = paddle::experimental::DataType::FLOAT16;
};

template <>
struct DataTypeInfo<int64_t> {
  paddle::experimental::DataType TYPE = paddle::experimental::DataType::INT64;
};

template <>
struct DataTypeInfo<int8_t> {
  paddle::experimental::DataType TYPE = paddle::experimental::DataType::INT8;
};

template <>
struct DataTypeInfo<uint8_t> {
  paddle::experimental::DataType TYPE = paddle::experimental::DataType::UINT8;
};

template <>
struct DataTypeInfo<int32_t> {
  paddle::experimental::DataType TYPE = paddle::experimental::DataType::INT32;
};

paddle::experimental::DataLayout LayoutConvert(DataLayout layout) {
  PADDLE_ENFORCE_EQ(
      layout,
      DataLayout::kNCHW,
      paddle::platform::errors::InvalidArgument("Only NCHW is supported now."));
  return paddle::experimental::DataLayout::NCHW;
}

template <typename T>
void Tensor::ShareExternalData(const T *data,
                               const std::vector<int> &shape,
                               PlaceType place,
                               DataLayout layout) {
  EAGER_GET_TENSOR(paddle::framework::LoDTensor)
  size_t size =
      std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()) *
      sizeof(T);
  phi::DenseTensorMeta meta(
      DataTypeInfo<T>().TYPE, phi::make_ddim(shape), LayoutConvert(layout));
  if (place == PlaceType::kCPU) {
    phi::DenseTensor dtensor(
        std::make_shared<phi::Allocation>(
            const_cast<T *>(data), size, paddle::platform::CPUPlace()),
        meta);
    *tensor = std::move(dtensor);
  } else if (place == PlaceType::kGPU) {
    phi::DenseTensor dtensor(
        std::make_shared<phi::Allocation>(
            const_cast<T *>(data), size, paddle::platform::CUDAPlace(device_)),
        meta);
    *tensor = std::move(dtensor);
  } else {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "PlaceType must be PlaceType::kCPU or PlaceType::kGPU."));
  }
}
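
// Illustrative sketch (an assumption, not from the original file): binding an
// externally owned buffer without an extra copy. The caller must keep
// `ext_buf` alive, correctly sized, and laid out as NCHW until inference has
// finished.
//
//   std::vector<float> ext_buf(1 * 3 * 224 * 224);
//   input->ShareExternalData(ext_buf.data(), {1, 3, 224, 224},
//                            PlaceType::kCPU, DataLayout::kNCHW);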

void Tensor::CopyStringsFromCpu(const paddle_infer::Strings *data) {
  EAGER_GET_TENSOR(paddle_infer::Strings);
  PADDLE_ENFORCE_GE(tensor->size(),
                    0,
                    paddle::platform::errors::PreconditionNotMet(
                        "You should call Tensor::ReshapeStrings(const "
                        "std::size_t &shape) function before copying "
                        "the string data from cpu."));
  *tensor = *data;
}

template <typename T>
void Tensor::CopyToCpuImpl(T *data,
                           void *exec_stream,
                           CallbackFunc cb,
                           void *cb_params) const {
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
  auto ele_num = tensor->numel();
  auto *t_data = tensor->data<T>();
  auto t_place = tensor->place();

  paddle::framework::Tensor out;
  auto mem_allocation =
      std::make_shared<paddle::memory::allocation::Allocation>(
          static_cast<void *>(data),
          ele_num * sizeof(T),
          paddle::platform::CPUPlace());
  out.ResetHolder(mem_allocation);

  if (paddle::platform::is_cpu_place(t_place)) {
#ifdef PADDLE_WITH_MKLDNN
    if (tensor->layout() == paddle::framework::DataLayout::kMKLDNN)
      paddle::framework::innerTransDataLayoutFromMKLDNN(
          tensor->layout(),
          paddle::platform::MKLDNNDeviceContext::tls()
              .get_cur_paddle_data_layout(),
          *tensor,
          &out,
          paddle::platform::CPUPlace(),
          true);
    else
      std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
#else
    std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
#endif
  } else if (paddle::platform::is_ipu_place(t_place)) {
#ifdef PADDLE_WITH_IPU
    std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with IPU place because paddle is not compiled "
        "with IPU."));
#endif
  } else if (place_ == PlaceType::kGPU) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    auto gpu_place = t_place;
    auto *dev_ctxs = reinterpret_cast<const std::map<
        phi::Place,
        std::shared_future<std::unique_ptr<phi::DeviceContext>>> *>(
        device_contexs_);
    auto *dev_ctx =
        static_cast<phi::GPUContext *>(dev_ctxs->at(gpu_place).get().get());
    paddle::memory::Copy(paddle::platform::CPUPlace(),
                         static_cast<void *>(data),
                         gpu_place,
                         t_data,
                         ele_num * sizeof(T),
                         dev_ctx->stream());
#ifdef PADDLE_WITH_HIP
    hipStreamSynchronize(dev_ctx->stream());
#else
    // async, return stream
    if (nullptr != exec_stream) {
      *(static_cast<cudaStream_t *>(exec_stream)) = dev_ctx->stream();
      // async with callback
    } else if (cb) {
      cudaLaunchHostFunc(dev_ctx->stream(), cb, cb_params);
      // sync
    } else {
      cudaStreamSynchronize(dev_ctx->stream());
    }
#endif
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with CUDA place because paddle is not compiled "
        "with CUDA."));
#endif
  } else if (place_ == PlaceType::kXPU) {
#ifdef PADDLE_WITH_XPU
    auto xpu_place = t_place;
    paddle::memory::Copy(paddle::platform::CPUPlace(),
                         static_cast<void *>(data),
                         xpu_place,
                         t_data,
                         ele_num * sizeof(T));
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with XPU place because paddle is not compiled "
        "with XPU."));
#endif
  } else if (place_ == PlaceType::kNPU) {
#ifdef PADDLE_WITH_ASCEND_CL
    paddle::platform::DeviceContextPool &pool =
        paddle::platform::DeviceContextPool::Instance();
    auto npu_place = t_place;
    auto *dev_ctx = static_cast<const paddle::platform::NPUDeviceContext *>(
        pool.Get(npu_place));
    paddle::memory::Copy(paddle::platform::CPUPlace(),
                         static_cast<void *>(data),
                         npu_place,
                         t_data,
                         ele_num * sizeof(T),
                         dev_ctx->stream());
    paddle::platform::NPUStreamSync(dev_ctx->stream());
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with NPU place because paddle is not compiled "
        "with NPU."));
#endif
  } else {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
    paddle::platform::DeviceContextPool &pool =
        paddle::platform::DeviceContextPool::Instance();
    auto custom_place = t_place;
    auto *dev_ctx = static_cast<const paddle::platform::CustomDeviceContext *>(
        pool.Get(custom_place));
    paddle::memory::Copy(paddle::platform::CPUPlace(),
                         static_cast<void *>(data),
                         custom_place,
                         t_data,
                         ele_num * sizeof(T),
                         dev_ctx->stream());
// TODO(wangran16): sync_stream
#else
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "The analysis predictor supports CPU, GPU, NPU and XPU now."));
#endif
  }
}

template <typename T>
void Tensor::CopyToCpu(T *data) const {
#ifdef PADDLE_WITH_ONNXRUNTIME
  if (is_ort_tensor_) {
    ORTCopyToCpu<T>(data);
    return;
  }
#endif

  CopyToCpuImpl<T>(data, nullptr, nullptr, nullptr);
}

template <typename T>
void Tensor::CopyToCpuAsync(T *data, void *exec_stream) const {
  CopyToCpuImpl<T>(data, exec_stream, nullptr, nullptr);
}

template <typename T>
void Tensor::CopyToCpuAsync(T *data, CallbackFunc cb, void *cb_params) const {
  CopyToCpuImpl<T>(data, nullptr, cb, cb_params);
}
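
// Illustrative sketch (not part of the original file) of the three ways to
// copy results back to the host on a GPU place; `predictor`, `my_callback`
// and `my_params` are placeholders for this example.
//
//   auto output = predictor->GetOutputHandle(predictor->GetOutputNames()[0]);
//   auto out_shape = output->shape();
//   int out_num = std::accumulate(out_shape.begin(), out_shape.end(), 1,
//                                 std::multiplies<int>());
//   std::vector<float> result(out_num);
//
//   output->CopyToCpu(result.data());                // 1. blocking copy
//
//   cudaStream_t stream;
//   output->CopyToCpuAsync(result.data(), &stream);  // 2. async; the stream
//   cudaStreamSynchronize(stream);                   //    used is returned
//                                                    //    and the caller
//                                                    //    synchronizes it
//
//   output->CopyToCpuAsync(result.data(), my_callback, my_params);
//                                                    // 3. async; my_callback
//                                                    //    runs on the stream
//                                                    //    once the copy is
//                                                    //    enqueued/finished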

template PD_INFER_DECL void Tensor::CopyFromCpu<float>(const float *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<int64_t>(const int64_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<int32_t>(const int32_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<uint8_t>(const uint8_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<int8_t>(const int8_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<float16>(const float16 *data);

template PD_INFER_DECL void Tensor::ShareExternalData<float>(
    const float *data,
    const std::vector<int> &shape,
    PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<int64_t>(
    const int64_t *data,
    const std::vector<int> &shape,
    PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<int32_t>(
    const int32_t *data,
    const std::vector<int> &shape,
    PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<uint8_t>(
    const uint8_t *data,
    const std::vector<int> &shape,
    PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<int8_t>(
    const int8_t *data,
    const std::vector<int> &shape,
    PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<float16>(
    const float16 *data,
    const std::vector<int> &shape,
    PlaceType place,
    DataLayout layout);

template PD_INFER_DECL void Tensor::CopyToCpu<float>(float *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<int64_t>(int64_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<int32_t>(int32_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<uint8_t>(uint8_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<int8_t>(int8_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<float16>(float16 *data) const;

template PD_INFER_DECL void Tensor::CopyToCpuImpl<float>(float *data,
                                                         void *exec_stream,
                                                         CallbackFunc cb,
                                                         void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<int64_t>(
    int64_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<int32_t>(
    int32_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<uint8_t>(
    uint8_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<int8_t>(
    int8_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<float16>(
    float16 *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;

template PD_INFER_DECL void Tensor::CopyToCpuAsync<float>(
    float *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int64_t>(
    int64_t *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int32_t>(
    int32_t *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<uint8_t>(
    uint8_t *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int8_t>(
    int8_t *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<float16>(
    float16 *data, void *exec_stream) const;

template PD_INFER_DECL void Tensor::CopyToCpuAsync<float>(
    float *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int64_t>(
    int64_t *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int32_t>(
    int32_t *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<uint8_t>(
    uint8_t *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int8_t>(
    int8_t *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<float16>(
    float16 *data, CallbackFunc cb, void *cb_params) const;

template PD_INFER_DECL float *Tensor::data<float>(PlaceType *place,
                                                  int *size) const;
template PD_INFER_DECL int64_t *Tensor::data<int64_t>(PlaceType *place,
                                                      int *size) const;
template PD_INFER_DECL int32_t *Tensor::data<int32_t>(PlaceType *place,
                                                      int *size) const;
template PD_INFER_DECL uint8_t *Tensor::data<uint8_t>(PlaceType *place,
                                                      int *size) const;
template PD_INFER_DECL int8_t *Tensor::data<int8_t>(PlaceType *place,
                                                    int *size) const;
template PD_INFER_DECL float16 *Tensor::data<float16>(PlaceType *place,
                                                      int *size) const;

template PD_INFER_DECL float *Tensor::mutable_data<float>(PlaceType place);
template PD_INFER_DECL int64_t *Tensor::mutable_data<int64_t>(PlaceType place);
template PD_INFER_DECL int32_t *Tensor::mutable_data<int32_t>(PlaceType place);
template PD_INFER_DECL uint8_t *Tensor::mutable_data<uint8_t>(PlaceType place);
template PD_INFER_DECL int8_t *Tensor::mutable_data<int8_t>(PlaceType place);
template PD_INFER_DECL float16 *Tensor::mutable_data<float16>(PlaceType place);

Tensor::Tensor(void *scope, const void *device_contexts)
    : scope_{scope}, device_contexs_(device_contexts) {}

template <typename T>
void *Tensor::FindTensor() const {
  PADDLE_ENFORCE_EQ(
      name_.empty(),
      false,
      paddle::platform::errors::PreconditionNotMet(
          "Need to SetName first, so that the corresponding tensor can "
          "be retrieved."));
  auto *scope = static_cast<paddle::framework::Scope *>(scope_);
  auto *var = scope->FindVar(name_);
  PADDLE_ENFORCE_NOT_NULL(
      var,
      paddle::platform::errors::PreconditionNotMet(
          "No tensor called [%s] in the runtime scope", name_));
  auto *tensor = var->GetMutable<T>();
  return tensor;
}

std::vector<int> Tensor::shape() const {
#ifdef PADDLE_WITH_ONNXRUNTIME
  if (is_ort_tensor_) {
    std::vector<int> shape;
    // input handle
    if (idx_ < 0) {
      shape.assign(shape_.begin(), shape_.end());
    } else {  // output handle
      auto binding = binding_.lock();
      PADDLE_ENFORCE_NOT_NULL(binding,
                              paddle::platform::errors::PreconditionNotMet(
                                  "output tensor [%s] no binding ptr", name_));
      std::vector<Ort::Value> outputs = binding->GetOutputValues();
      Ort::Value &value = outputs[idx_];
      auto info = value.GetTensorTypeAndShapeInfo();
      auto ort_shape = info.GetShape();
      shape.assign(ort_shape.begin(), ort_shape.end());
    }
    return shape;
  }
#endif
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
  PADDLE_ENFORCE_NOT_NULL(
      tensor_,
      paddle::platform::errors::PreconditionNotMet(
          "Not found tensor called %s in the scope", name_));
// MKL-DNN may do layout transforms internally, so we need to reorder the
// dims before returning.
#ifdef PADDLE_WITH_MKLDNN
  if (tensor->layout() == paddle::framework::DataLayout::kMKLDNN) {
    paddle::framework::DataLayout out_layout =
        paddle::platform::MKLDNNDeviceContext::tls()
            .get_cur_paddle_data_layout();
    // Set default as NCHW in case not specified
    out_layout = out_layout == paddle::framework::DataLayout::kAnyLayout
                     ? paddle::framework::DataLayout::kNCHW
                     : out_layout;
    // In these data layouts the channel dimension is either in the 2nd
    // position (nChw) or last (nhwC), so for dims == 2 the layouts coincide
    // and nothing needs to be done. The same holds for dims == 1, where only
    // one layout is possible.
    if (tensor->dims().size() < 3) return phi::vectorize<int>(tensor->dims());
    if (out_layout == paddle::framework::DataLayout::kNHWC ||
        out_layout == paddle::framework::DataLayout::kNDHWC) {
      auto dims = phi::vectorize<int>(tensor->dims());
      std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end());
      return dims;
    } else {
      return phi::vectorize<int>(tensor->dims());
    }
  }
#endif
  return phi::vectorize<int>(tensor->dims());
}

void Tensor::SetLoD(const std::vector<std::vector<size_t>> &x) {
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
  paddle::framework::LoD lod;
  for (auto &level : x) {
    lod.emplace_back(level);
  }
  tensor->set_lod(lod);
}

std::vector<std::vector<size_t>> Tensor::lod() const {
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
  std::vector<std::vector<size_t>> res;
  for (auto &level : tensor->lod()) {
    res.emplace_back(level);
  }
  return res;
}

void Tensor::SetName(const std::string &name) { name_ = name; }

const std::string &Tensor::name() const { return name_; }

void Tensor::SetPlace(PlaceType place, int device) {
  place_ = place;
  device_ = device;
}

#ifdef PADDLE_WITH_ONNXRUNTIME
void Tensor::SetOrtMark(bool is_ort_tensor) { is_ort_tensor_ = is_ort_tensor; }

void Tensor::SetOrtBinding(const std::shared_ptr<Ort::IoBinding> binding) {
  binding_ = binding;
}

template <typename T>
void Tensor::ORTCopyToCpu(T *data) const {
  auto binding = binding_.lock();
  PADDLE_ENFORCE_NOT_NULL(binding,
                          paddle::platform::errors::PreconditionNotMet(
                              "output tensor [%s] no binding ptr", name_));
  std::vector<Ort::Value> outputs = binding->GetOutputValues();
  Ort::Value &value = outputs[idx_];
  auto info = value.GetTensorTypeAndShapeInfo();
  size_t size = info.GetElementCount() * sizeof(T);

  if (place_ == PlaceType::kCPU) {
    std::memcpy(static_cast<void *>(data), value.GetTensorData<void *>(), size);
  } else {
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "CopyToCpu error. The current ONNXRuntime backend doesn't support "
        "GPU."));
  }
}

template void Tensor::ORTCopyToCpu<float>(float *data) const;
template void Tensor::ORTCopyToCpu<int32_t>(int32_t *data) const;
template void Tensor::ORTCopyToCpu<uint8_t>(uint8_t *data) const;
template void Tensor::ORTCopyToCpu<int8_t>(int8_t *data) const;
template void Tensor::ORTCopyToCpu<float16>(float16 *data) const;
#endif

namespace experimental {
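// The InternalUtils helpers below copy between host memory and the tensor on
// a caller-supplied cudaStream_t rather than the stream owned by the tensor's
// device context; only CPU and GPU places are handled here.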
template <typename T>
void InternalUtils::CopyFromCpuWithIoStream(paddle_infer::Tensor *t,
                                            const T *data,
                                            cudaStream_t stream) {
  if (t->tensor_ == nullptr) {
    PADDLE_ENFORCE_EQ(
        t->name_.empty(),
        false,
        paddle::platform::errors::PreconditionNotMet(
            "Need to SetName first, so that the corresponding tensor can "
            "be retrieved."));
    auto *scope = static_cast<paddle::framework::Scope *>(t->scope_);
    auto *var = scope->FindVar(t->name_);
    PADDLE_ENFORCE_NOT_NULL(
        var,
        paddle::platform::errors::PreconditionNotMet(
            "No tensor called [%s] in the runtime scope", t->name_));
    auto *tensor = var->GetMutable<paddle::framework::LoDTensor>();
    t->tensor_ = tensor;
  }

  auto *tensor = static_cast<paddle::framework::LoDTensor *>(t->tensor_);
  PADDLE_ENFORCE_GE(tensor->numel(),
                    0,
                    paddle::platform::errors::PreconditionNotMet(
                        "You should call Tensor::Reshape(const "
                        "std::vector<int> &shape) "
                        "function before copying data from cpu."));
  size_t ele_size = tensor->numel() * sizeof(T);
  if (t->place_ == PlaceType::kCPU) {
    auto *t_data = tensor->mutable_data<T>(paddle::platform::CPUPlace());
    std::memcpy(static_cast<void *>(t_data), data, ele_size);
  } else if (t->place_ == PlaceType::kGPU) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    paddle::platform::CUDAPlace gpu_place(t->device_);
    auto *t_data = tensor->mutable_data<T>(gpu_place);
    paddle::memory::Copy(gpu_place,
                         static_cast<void *>(t_data),
                         paddle::platform::CPUPlace(),
                         data,
                         ele_size,
                         stream);
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with CUDA place because paddle is not compiled "
        "with CUDA."));
#endif
  } else {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "CopyFromCpuWithIoStream only supports CPU and GPU now."));
  }
}

template <typename T>
void InternalUtils::CopyToCpuWithIoStream(paddle_infer::Tensor *t,
                                          T *data,
                                          cudaStream_t stream) {
  if (t->tensor_ == nullptr) {
    PADDLE_ENFORCE_EQ(
        t->name_.empty(),
        false,
        paddle::platform::errors::PreconditionNotMet(
            "Need to SetName first, so that the corresponding tensor can "
            "be retrieved."));
    auto *scope = static_cast<paddle::framework::Scope *>(t->scope_);
    auto *var = scope->FindVar(t->name_);
    PADDLE_ENFORCE_NOT_NULL(
        var,
        paddle::platform::errors::PreconditionNotMet(
            "No tensor called [%s] in the runtime scope", t->name_));
    auto *tensor = var->GetMutable<paddle::framework::LoDTensor>();
    t->tensor_ = tensor;
  }

  auto *tensor = static_cast<paddle::framework::LoDTensor *>(t->tensor_);
  auto ele_num = tensor->numel();
  auto *t_data = tensor->data<T>();
  auto t_place = tensor->place();

  paddle::framework::Tensor out;
  auto mem_allocation =
      std::make_shared<paddle::memory::allocation::Allocation>(
          static_cast<void *>(data),
          ele_num * sizeof(T),
          paddle::platform::CPUPlace());
  out.ResetHolder(mem_allocation);

  if (paddle::platform::is_cpu_place(t_place)) {
#ifdef PADDLE_WITH_MKLDNN
    if (tensor->layout() == paddle::framework::DataLayout::kMKLDNN)
      paddle::framework::innerTransDataLayoutFromMKLDNN(
          tensor->layout(),
          paddle::platform::MKLDNNDeviceContext::tls()
              .get_cur_paddle_data_layout(),
          *tensor,
          &out,
          paddle::platform::CPUPlace(),
          true);
    else
      std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
#else
    std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
#endif
  } else if (t->place_ == PlaceType::kGPU) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    paddle::memory::Copy(paddle::platform::CPUPlace(),
                         static_cast<void *>(data),
                         t_place,
                         t_data,
                         ele_num * sizeof(T),
                         stream);
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with CUDA place because paddle is not compiled "
        "with CUDA."));
#endif
  } else {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "CopyToCpuWithIoStream only supports CPU and GPU now."));
  }
}

template void InternalUtils::CopyFromCpuWithIoStream<float>(
    paddle_infer::Tensor *t, const float *data, cudaStream_t stream);
template void InternalUtils::CopyFromCpuWithIoStream<int64_t>(
    paddle_infer::Tensor *t, const int64_t *data, cudaStream_t stream);
template void InternalUtils::CopyFromCpuWithIoStream<int32_t>(
    paddle_infer::Tensor *t, const int32_t *data, cudaStream_t stream);
template void InternalUtils::CopyFromCpuWithIoStream<uint8_t>(
    paddle_infer::Tensor *t, const uint8_t *data, cudaStream_t stream);
template void InternalUtils::CopyFromCpuWithIoStream<int8_t>(
    paddle_infer::Tensor *t, const int8_t *data, cudaStream_t stream);
template void InternalUtils::CopyFromCpuWithIoStream<float16>(
    paddle_infer::Tensor *t, const float16 *data, cudaStream_t stream);

template void InternalUtils::CopyToCpuWithIoStream<float>(
    paddle_infer::Tensor *t, float *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<int64_t>(
    paddle_infer::Tensor *t, int64_t *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<int32_t>(
    paddle_infer::Tensor *t, int32_t *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<uint8_t>(
    paddle_infer::Tensor *t, uint8_t *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<int8_t>(
    paddle_infer::Tensor *t, int8_t *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<float16>(
    paddle_infer::Tensor *t, float16 *data, cudaStream_t stream);

}  // namespace experimental

}  // namespace paddle_infer