zero_copy_tensor.cc 37.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/fluid/framework/convert_utils.h"
16
#include "paddle/fluid/framework/data_layout_transform.h"
17 18 19
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
W
Wilber 已提交
20
#include "paddle/fluid/inference/api/paddle_tensor.h"
N
nhzlx 已提交
21
#include "paddle/fluid/memory/memcpy.h"
22
#include "paddle/fluid/platform/enforce.h"
23
#include "paddle/fluid/platform/float16.h"
24
#include "paddle/phi/core/allocator.h"
25
#ifdef PADDLE_WITH_ONNXRUNTIME
H
heliqi 已提交
26 27
#include "onnxruntime_c_api.h"    // NOLINT
#include "onnxruntime_cxx_api.h"  // NOLINT
28
#endif
29

30
namespace paddle_infer {
31

32 33
using float16 = paddle::platform::float16;

34
void Tensor::Reshape(const std::vector<int> &shape) {
35 36 37 38 39 40 41
#ifdef PADDLE_WITH_ONNXRUNTIME
  if (is_ort_tensor_) {
    shape_.assign(shape.begin(), shape.end());
    return;
  }
#endif

W
Wilber 已提交
42 43
  PADDLE_ENFORCE_EQ(
      name_.empty(), false,
44
      paddle::platform::errors::PreconditionNotMet(
W
Wilber 已提交
45 46 47
          "Need to SetName first, so that the corresponding tensor can "
          "be retrieved."));
  PADDLE_ENFORCE_EQ(input_or_output_, true,
48
                    paddle::platform::errors::PermissionDenied(
W
Wilber 已提交
49
                        "Can't reshape the output tensor, it is readonly"));
50
  auto *scope = static_cast<paddle::framework::Scope *>(scope_);
51
  auto *var = scope->FindVar(name_);
W
Wilber 已提交
52
  PADDLE_ENFORCE_NOT_NULL(
53
      var, paddle::platform::errors::PreconditionNotMet(
W
Wilber 已提交
54
               "No tensor called [%s] in the runtime scope", name_));
55
  auto *tensor = var->GetMutable<paddle::framework::LoDTensor>();
56
  tensor->Resize(phi::make_ddim(shape));
57 58
}

S
Steffy-zxf 已提交
59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81
void Tensor::ReshapeStrings(const size_t &shape) {
  PADDLE_ENFORCE_EQ(
      name_.empty(), false,
      paddle::platform::errors::PreconditionNotMet(
          "Need to SetName first, so that the corresponding tensor can "
          "be retrieved."));
  PADDLE_ENFORCE_EQ(input_or_output_, true,
                    paddle::platform::errors::PermissionDenied(
                        "Can't reshape the output tensor, it is readonly"));
  auto *scope = static_cast<paddle::framework::Scope *>(scope_);
  auto *var = scope->FindVar(name_);
  PADDLE_ENFORCE_NOT_NULL(
      var, paddle::platform::errors::PreconditionNotMet(
               "No tensor called [%s] in the runtime scope", name_));
  paddle_infer::Strings *tensor = var->GetMutable<paddle_infer::Strings>();
  tensor->resize(shape);
}

// Declares a local `tensor` of type `tensor_type *` that points at the
// variable bound to this handle. On first use, the pointer is resolved through
// FindTensor<tensor_type>() and cached in `tensor_`. Because this macro is
// used inside const member functions, `tensor_` is presumably declared
// `mutable` (confirm in the header).
#define EAGER_GET_TENSOR(tensor_type)        \
  if (tensor_ == nullptr) {                  \
    tensor_ = FindTensor<tensor_type>();     \
  }                                          \
  auto *tensor = static_cast<tensor_type *>(tensor_);
82

83
template <typename T>
84
T *Tensor::mutable_data(PlaceType place) {
S
Steffy-zxf 已提交
85
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
86 87
  PADDLE_ENFORCE_GT(
      tensor->numel(), 0,
88 89
      paddle::platform::errors::PreconditionNotMet(
          "You should call Tensor::Reshape(const std::vector<int> "
W
Wilber 已提交
90 91
          "&shape)"
          "function before retrieving mutable_data from input tensor."));
92
  switch (static_cast<int>(place)) {
93 94
    case static_cast<int>(PlaceType::kCPU): {
      return tensor->mutable_data<T>(paddle::platform::CPUPlace());
95
    }
96
    case static_cast<int>(PlaceType::kGPU): {
97 98 99 100 101 102 103 104 105
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      paddle::platform::CUDAPlace gpu_place(device_);
      auto *dev_ctxs = reinterpret_cast<const std::map<
          phi::Place, std::shared_future<std::unique_ptr<phi::DeviceContext>>>
                                            *>(device_contexs_);
      auto *dev_ctx =
          static_cast<phi::GPUContext *>(dev_ctxs->at(gpu_place).get().get());
      return dev_ctx->Alloc<T>(tensor, tensor->numel() * sizeof(T));
#else
106
      return tensor->mutable_data<T>(paddle::platform::CUDAPlace(device_));
107
#endif
108 109 110
    }
    case static_cast<int>(PlaceType::kXPU): {
      return tensor->mutable_data<T>(paddle::platform::XPUPlace(device_));
111
    }
112 113 114
    case static_cast<int>(PlaceType::kNPU): {
      return tensor->mutable_data<T>(paddle::platform::NPUPlace(device_));
    }
115
    default:
116
      PADDLE_THROW(paddle::platform::errors::Unavailable(
117 118
          "Only CPU / CUDA / XPU / NPU places is supported. The place `%d` is "
          "not supported.",
119
          static_cast<int>(place)));
120 121 122 123 124 125
      break;
  }
  return nullptr;
}

template <typename T>
126
T *Tensor::data(PlaceType *place, int *size) const {
S
Steffy-zxf 已提交
127
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
128 129
  auto *res = tensor->data<T>();

130 131 132 133 134 135
  if (paddle::platform::is_cpu_place(tensor->place())) {
    *place = PlaceType::kCPU;
  } else if (paddle::platform::is_gpu_place(tensor->place())) {
    *place = PlaceType::kGPU;
  } else if (paddle::platform::is_xpu_place(tensor->place())) {
    *place = PlaceType::kXPU;
136 137
  } else if (paddle::platform::is_npu_place(tensor->place())) {
    *place = PlaceType::kNPU;
138
  } else {
139
    *place = PlaceType::kUNK;
140 141 142 143 144 145
  }

  *size = tensor->numel();
  return res;
}

146
// Reports the element type of the bound tensor as a paddle_infer::DataType.
// ONNXRuntime handles return the dtype recorded on the handle. FP32, and
// any dtype without an explicit mapping below, is reported as FLOAT32.
DataType Tensor::type() const {
#ifdef PADDLE_WITH_ONNXRUNTIME
  if (is_ort_tensor_) {
    return dtype_;
  }
#endif
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
  switch (paddle::framework::TransToProtoVarType(tensor->dtype())) {
    case paddle::framework::proto::VarType::FP16:
      return DataType::FLOAT16;
    case paddle::framework::proto::VarType::INT64:
      return DataType::INT64;
    case paddle::framework::proto::VarType::INT32:
      return DataType::INT32;
    case paddle::framework::proto::VarType::UINT8:
      return DataType::UINT8;
    case paddle::framework::proto::VarType::INT8:
      return DataType::INT8;
    case paddle::framework::proto::VarType::FP32:
    default:
      return DataType::FLOAT32;
  }
}

170 171
// Returns the place configured for this handle via SetPlace(). This is not
// necessarily where the tensor's memory lives; data() reports that instead.
PlaceType Tensor::place() const {
  return place_;
}

N
nhzlx 已提交
172
template <typename T>
173
void Tensor::CopyFromCpu(const T *data) {
174 175 176 177 178 179 180
#ifdef PADDLE_WITH_ONNXRUNTIME
  if (is_ort_tensor_) {
    ORTCopyFromCpu<T>(data);
    return;
  }
#endif

S
Steffy-zxf 已提交
181
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
W
Wilber 已提交
182
  PADDLE_ENFORCE_GE(tensor->numel(), 0,
183 184
                    paddle::platform::errors::PreconditionNotMet(
                        "You should call Tensor::Reshape(const "
W
Wilber 已提交
185 186
                        "std::vector<int> &shape)"
                        "function before copying data from cpu."));
N
nhzlx 已提交
187 188
  size_t ele_size = tensor->numel() * sizeof(T);

189 190
  if (place_ == PlaceType::kCPU) {
    auto *t_data = tensor->mutable_data<T>(paddle::platform::CPUPlace());
N
nhzlx 已提交
191
    std::memcpy(static_cast<void *>(t_data), data, ele_size);
192
  } else if (place_ == PlaceType::kGPU) {
193
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
194

195
    paddle::platform::CUDAPlace gpu_place(device_);
196 197 198 199 200 201
    auto *dev_ctxs = reinterpret_cast<const std::map<
        phi::Place, std::shared_future<std::unique_ptr<phi::DeviceContext>>> *>(
        device_contexs_);
    auto *dev_ctx =
        static_cast<phi::GPUContext *>(dev_ctxs->at(gpu_place).get().get());
    auto *t_data = dev_ctx->Alloc<T>(tensor, tensor->numel() * sizeof(T));
N
nhzlx 已提交
202

203 204 205
    paddle::memory::Copy(gpu_place, static_cast<void *>(t_data),
                         paddle::platform::CPUPlace(), data, ele_size,
                         dev_ctx->stream());
N
nhzlx 已提交
206
#else
207 208 209
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with CUDA place because paddle is not compiled "
        "with CUDA."));
N
nhzlx 已提交
210
#endif
211
  } else if (place_ == PlaceType::kXPU) {
212
#ifdef PADDLE_WITH_XPU
213
    paddle::platform::XPUPlace xpu_place(device_);
214
    auto *t_data = tensor->mutable_data<T>(xpu_place);
215 216
    paddle::memory::Copy(xpu_place, static_cast<void *>(t_data),
                         paddle::platform::CPUPlace(), data, ele_size);
217
#else
218 219 220
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with XPU place because paddle is not compiled "
        "with XPU."));
W
Wilber 已提交
221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236
#endif
  } else if (place_ == PlaceType::kNPU) {
#ifdef PADDLE_WITH_ASCEND_CL
    paddle::platform::DeviceContextPool &pool =
        paddle::platform::DeviceContextPool::Instance();
    paddle::platform::NPUPlace npu_place(device_);
    auto *t_data = tensor->mutable_data<T>(npu_place);
    auto *dev_ctx = static_cast<const paddle::platform::NPUDeviceContext *>(
        pool.Get(npu_place));
    paddle::memory::Copy(npu_place, static_cast<void *>(t_data),
                         paddle::platform::CPUPlace(), data, ele_size,
                         dev_ctx->stream());
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with NPU place because paddle is not compiled "
        "with NPU."));
237 238
#endif
  } else {
239 240 241 242 243 244 245 246 247 248 249 250 251 252
#ifdef PADDLE_WITH_CUSTOM_DEVICE
    auto device_type_id =
        static_cast<size_t>(place_) - static_cast<size_t>(PlaceType::kCUSTOM);
    paddle::platform::DeviceContextPool &pool =
        paddle::platform::DeviceContextPool::Instance();
    paddle::platform::CustomPlace custom_place(
        phi::GetGlobalDeviceType(device_type_id), device_);
    auto *t_data = tensor->mutable_data<T>(custom_place);
    auto *dev_ctx = static_cast<const paddle::platform::CustomDeviceContext *>(
        pool.Get(custom_place));
    paddle::memory::Copy(custom_place, static_cast<void *>(t_data),
                         paddle::platform::CPUPlace(), data, ele_size,
                         dev_ctx->stream());
#else
253
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
W
Wilber 已提交
254
        "The analysis predictor supports CPU, GPU, NPU and XPU now."));
255
#endif
N
nhzlx 已提交
256 257 258
  }
}

259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325
// Maps a C++ element type to the phi DataType used when wrapping external
// memory. It is read as `DataTypeInfo<T>().TYPE`. Only the element types
// specialized below are supported; any other T is an incomplete type and
// fails to compile.
template <typename T>
struct DataTypeInfo;

#define PD_DEFINE_DATA_TYPE_INFO(cpp_type, phi_dtype) \
  template <>                                         \
  struct DataTypeInfo<cpp_type> {                     \
    paddle::experimental::DataType TYPE =             \
        paddle::experimental::DataType::phi_dtype;    \
  }

PD_DEFINE_DATA_TYPE_INFO(float, FLOAT32);
PD_DEFINE_DATA_TYPE_INFO(float16, FLOAT16);
PD_DEFINE_DATA_TYPE_INFO(int64_t, INT64);
PD_DEFINE_DATA_TYPE_INFO(int8_t, INT8);
PD_DEFINE_DATA_TYPE_INFO(uint8_t, UINT8);
PD_DEFINE_DATA_TYPE_INFO(int32_t, INT32);

#undef PD_DEFINE_DATA_TYPE_INFO

paddle::experimental::DataLayout LayoutConvert(DataLayout layout) {
  PADDLE_ENFORCE_EQ(
      layout, DataLayout::kNCHW,
      paddle::platform::errors::InvalidArgument("Only NCHW is supported now."));
  return paddle::experimental::DataLayout::NCHW;
}

template <typename T>
void Tensor::ShareExternalData(const T *data, const std::vector<int> &shape,
                               PlaceType place, DataLayout layout) {
  EAGER_GET_TENSOR(paddle::framework::LoDTensor)
  size_t size =
      std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()) *
      sizeof(T);
  phi::DenseTensorMeta meta(DataTypeInfo<T>().TYPE, phi::make_ddim(shape),
                            LayoutConvert(layout));
  if (place == PlaceType::kCPU) {
    phi::DenseTensor dtensor(
        std::make_shared<phi::Allocation>(const_cast<T *>(data), size,
                                          paddle::platform::CPUPlace()),
        meta);
    *tensor = std::move(dtensor);
  } else if (place == PlaceType::kGPU) {
    phi::DenseTensor dtensor(
        std::make_shared<phi::Allocation>(const_cast<T *>(data), size,
                                          paddle::platform::CUDAPlace(device_)),
        meta);
    *tensor = std::move(dtensor);
  } else {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "PlaceType must be PlaceType::kCPU or PlaceType::kGPU."));
  }
}

S
Steffy-zxf 已提交
326 327 328 329 330 331 332 333 334 335
// Copies the host string batch `*data` into the string tensor bound to this
// handle.
//
// No size precondition is needed, because the copy-assignment below replaces
// the destination's contents. The previous `tensor->size() >= 0` check
// compared an unsigned size against 0 and could never fail. The real
// precondition is a non-null `data`, which is dereferenced below.
// Throws InvalidArgument when `data` is null.
void Tensor::CopyStringsFromCpu(const paddle_infer::Strings *data) {
  EAGER_GET_TENSOR(paddle_infer::Strings);
  PADDLE_ENFORCE_NOT_NULL(
      data, paddle::platform::errors::InvalidArgument(
                "The input strings of CopyStringsFromCpu should not be "
                "null."));
  *tensor = *data;
}

N
nhzlx 已提交
336
template <typename T>
337 338
void Tensor::CopyToCpuImpl(T *data, void *exec_stream, CallbackFunc cb,
                           void *cb_params) const {
S
Steffy-zxf 已提交
339
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
N
nhzlx 已提交
340 341 342 343
  auto ele_num = tensor->numel();
  auto *t_data = tensor->data<T>();
  auto t_place = tensor->place();

344
  paddle::framework::Tensor out;
345 346 347 348
  auto mem_allocation =
      std::make_shared<paddle::memory::allocation::Allocation>(
          static_cast<void *>(data), ele_num * sizeof(T),
          paddle::platform::CPUPlace());
349 350
  out.ResetHolder(mem_allocation);

351
  if (paddle::platform::is_cpu_place(t_place)) {
352 353 354
#ifdef PADDLE_WITH_MKLDNN
    if (tensor->layout() == paddle::framework::DataLayout::kMKLDNN)
      paddle::framework::innerTransDataLayoutFromMKLDNN(
355 356 357
          tensor->layout(),
          paddle::platform::MKLDNNDeviceContext::tls()
              .get_cur_paddle_data_layout(),
358 359 360 361
          *tensor, &out, paddle::platform::CPUPlace(), true);
    else
      std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
#else
N
nhzlx 已提交
362
    std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
J
jianghaicheng 已提交
363 364 365 366 367 368 369 370
#endif
  } else if (paddle::platform::is_ipu_place(t_place)) {
#ifdef PADDLE_WITH_IPU
    std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with IPU place because paddle is not compiled "
        "with IPU."));
371
#endif
372
  } else if (place_ == PlaceType::kGPU) {
373
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
374
    auto gpu_place = t_place;
375 376 377 378 379
    auto *dev_ctxs = reinterpret_cast<const std::map<
        phi::Place, std::shared_future<std::unique_ptr<phi::DeviceContext>>> *>(
        device_contexs_);
    auto *dev_ctx =
        static_cast<phi::GPUContext *>(dev_ctxs->at(gpu_place).get().get());
380 381 382
    paddle::memory::Copy(paddle::platform::CPUPlace(),
                         static_cast<void *>(data), gpu_place, t_data,
                         ele_num * sizeof(T), dev_ctx->stream());
383 384 385
#ifdef PADDLE_WITH_HIP
    hipStreamSynchronize(dev_ctx->stream());
#else
386 387 388 389 390 391 392 393 394 395
    // async, return stream
    if (nullptr != exec_stream) {
      *(static_cast<cudaStream_t *>(exec_stream)) = dev_ctx->stream();
      // async with callback
    } else if (cb) {
      cudaLaunchHostFunc(dev_ctx->stream(), cb, cb_params);
      // sync
    } else {
      cudaStreamSynchronize(dev_ctx->stream());
    }
396
#endif
N
nhzlx 已提交
397
#else
398 399 400
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with CUDA place because paddle is not compiled "
        "with CUDA."));
N
nhzlx 已提交
401
#endif
402
  } else if (place_ == PlaceType::kXPU) {
403
#ifdef PADDLE_WITH_XPU
404
    auto xpu_place = t_place;
405 406 407
    paddle::memory::Copy(paddle::platform::CPUPlace(),
                         static_cast<void *>(data), xpu_place, t_data,
                         ele_num * sizeof(T));
408
#else
409 410 411
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with XPU place because paddle is not compiled "
        "with XPU."));
W
Wilber 已提交
412 413 414 415 416
#endif
  } else if (place_ == PlaceType::kNPU) {
#ifdef PADDLE_WITH_ASCEND_CL
    paddle::platform::DeviceContextPool &pool =
        paddle::platform::DeviceContextPool::Instance();
417
    auto npu_place = t_place;
W
Wilber 已提交
418 419 420 421 422
    auto *dev_ctx = static_cast<const paddle::platform::NPUDeviceContext *>(
        pool.Get(npu_place));
    paddle::memory::Copy(paddle::platform::CPUPlace(),
                         static_cast<void *>(data), npu_place, t_data,
                         ele_num * sizeof(T), dev_ctx->stream());
423
    paddle::platform::NPUStreamSync(dev_ctx->stream());
W
Wilber 已提交
424 425 426 427
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with NPU place because paddle is not compiled "
        "with NPU."));
428 429
#endif
  } else {
430 431 432 433 434 435 436 437 438 439 440
#ifdef PADDLE_WITH_CUSTOM_DEVICE
    paddle::platform::DeviceContextPool &pool =
        paddle::platform::DeviceContextPool::Instance();
    auto custom_place = t_place;
    auto *dev_ctx = static_cast<const paddle::platform::CustomDeviceContext *>(
        pool.Get(custom_place));
    paddle::memory::Copy(paddle::platform::CPUPlace(),
                         static_cast<void *>(data), custom_place, t_data,
                         ele_num * sizeof(T), dev_ctx->stream());
// TODO(wangran16): sync_stream
#else
441
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
W
Wilber 已提交
442
        "The analysis predictor supports CPU, GPU, NPU and XPU now."));
443
#endif
N
nhzlx 已提交
444 445
  }
}
446 447 448

// Copies the bound tensor into the host buffer `data`, passing CopyToCpuImpl
// neither a stream nor a callback. ONNXRuntime handles take their own copy
// path instead.
template <typename T>
void Tensor::CopyToCpu(T *data) const {
#ifdef PADDLE_WITH_ONNXRUNTIME
  if (is_ort_tensor_) {
    ORTCopyToCpu<T>(data);
    return;
  }
#endif
  CopyToCpuImpl<T>(data, /*exec_stream=*/nullptr, /*cb=*/nullptr,
                   /*cb_params=*/nullptr);
}

// Enqueues the copy. On CUDA builds the copy's stream is written to
// *exec_stream so the caller can synchronize on it. ONNXRuntime handles are
// not special-cased here.
template <typename T>
void Tensor::CopyToCpuAsync(T *data, void *exec_stream) const {
  CopyToCpuImpl<T>(data, exec_stream, /*cb=*/nullptr, /*cb_params=*/nullptr);
}

// Enqueues the copy. On CUDA builds `cb(cb_params)` is scheduled on the copy
// stream once the copy completes. ONNXRuntime handles are not special-cased
// here.
template <typename T>
void Tensor::CopyToCpuAsync(T *data, CallbackFunc cb, void *cb_params) const {
  CopyToCpuImpl<T>(data, /*exec_stream=*/nullptr, cb, cb_params);
}

469 470 471 472 473
template PD_INFER_DECL void Tensor::CopyFromCpu<float>(const float *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<int64_t>(const int64_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<int32_t>(const int32_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<uint8_t>(const uint8_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<int8_t>(const int8_t *data);
474
template PD_INFER_DECL void Tensor::CopyFromCpu<float16>(const float16 *data);
475

476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494
template PD_INFER_DECL void Tensor::ShareExternalData<float>(
    const float *data, const std::vector<int> &shape, PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<int64_t>(
    const int64_t *data, const std::vector<int> &shape, PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<int32_t>(
    const int32_t *data, const std::vector<int> &shape, PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<uint8_t>(
    const uint8_t *data, const std::vector<int> &shape, PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<int8_t>(
    const int8_t *data, const std::vector<int> &shape, PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<float16>(
    const float16 *data, const std::vector<int> &shape, PlaceType place,
    DataLayout layout);

495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541
template PD_INFER_DECL void Tensor::CopyToCpu<float>(float *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<int64_t>(int64_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<int32_t>(int32_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<uint8_t>(uint8_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<int8_t>(int8_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<float16>(float16 *data) const;

template PD_INFER_DECL void Tensor::CopyToCpuImpl<float>(float *data,
                                                         void *exec_stream,
                                                         CallbackFunc cb,
                                                         void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<int64_t>(
    int64_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<int32_t>(
    int32_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<uint8_t>(
    uint8_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<int8_t>(
    int8_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<float16>(
    float16 *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;

template PD_INFER_DECL void Tensor::CopyToCpuAsync<float>(
    float *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int64_t>(
    int64_t *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int32_t>(
    int32_t *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<uint8_t>(
    uint8_t *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int8_t>(
    int8_t *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<float16>(
    float16 *data, void *exec_stream) const;

template PD_INFER_DECL void Tensor::CopyToCpuAsync<float>(
    float *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int64_t>(
    int64_t *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int32_t>(
    int32_t *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<uint8_t>(
    uint8_t *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int8_t>(
    int8_t *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<float16>(
    float16 *data, CallbackFunc cb, void *cb_params) const;
542

543 544 545 546 547 548 549 550 551 552
template PD_INFER_DECL float *Tensor::data<float>(PlaceType *place,
                                                  int *size) const;
template PD_INFER_DECL int64_t *Tensor::data<int64_t>(PlaceType *place,
                                                      int *size) const;
template PD_INFER_DECL int32_t *Tensor::data<int32_t>(PlaceType *place,
                                                      int *size) const;
template PD_INFER_DECL uint8_t *Tensor::data<uint8_t>(PlaceType *place,
                                                      int *size) const;
template PD_INFER_DECL int8_t *Tensor::data<int8_t>(PlaceType *place,
                                                    int *size) const;
553 554
template PD_INFER_DECL float16 *Tensor::data<float16>(PlaceType *place,
                                                      int *size) const;
555

556 557 558 559 560
template PD_INFER_DECL float *Tensor::mutable_data<float>(PlaceType place);
template PD_INFER_DECL int64_t *Tensor::mutable_data<int64_t>(PlaceType place);
template PD_INFER_DECL int32_t *Tensor::mutable_data<int32_t>(PlaceType place);
template PD_INFER_DECL uint8_t *Tensor::mutable_data<uint8_t>(PlaceType place);
template PD_INFER_DECL int8_t *Tensor::mutable_data<int8_t>(PlaceType place);
561
template PD_INFER_DECL float16 *Tensor::mutable_data<float16>(PlaceType place);
562

563 564
// Binds the handle to the runtime scope and the device-context map it
// resolves against. Both are stored as opaque pointers and cast back to
// their concrete types where they are used.
Tensor::Tensor(void *scope, const void *device_contexts)
    : scope_(scope), device_contexs_(device_contexts) {
}
565

S
Steffy-zxf 已提交
566
template <typename T>
567
void *Tensor::FindTensor() const {
W
Wilber 已提交
568 569
  PADDLE_ENFORCE_EQ(
      name_.empty(), false,
570
      paddle::platform::errors::PreconditionNotMet(
W
Wilber 已提交
571 572
          "Need to SetName first, so that the corresponding tensor can "
          "be retrieved."));
573
  auto *scope = static_cast<paddle::framework::Scope *>(scope_);
574
  auto *var = scope->FindVar(name_);
W
Wilber 已提交
575
  PADDLE_ENFORCE_NOT_NULL(
576
      var, paddle::platform::errors::PreconditionNotMet(
W
Wilber 已提交
577
               "No tensor called [%s] in the runtime scope", name_));
S
Steffy-zxf 已提交
578
  auto *tensor = var->GetMutable<T>();
579 580 581
  return tensor;
}

582
std::vector<int> Tensor::shape() const {
583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602
#ifdef PADDLE_WITH_ONNXRUNTIME
  if (is_ort_tensor_) {
    std::vector<int> shape;
    // input handle
    if (idx_ < 0) {
      shape.assign(shape_.begin(), shape_.end());
    } else {  // output handle
      auto binding = binding_.lock();
      PADDLE_ENFORCE_NOT_NULL(binding,
                              paddle::platform::errors::PreconditionNotMet(
                                  "output tensor [%s] no binding ptr", name_));
      std::vector<Ort::Value> outputs = binding->GetOutputValues();
      Ort::Value &value = outputs[idx_];
      auto info = value.GetTensorTypeAndShapeInfo();
      auto ort_shape = info.GetShape();
      shape.assign(ort_shape.begin(), ort_shape.end());
    }
    return shape;
  }
#endif
S
Steffy-zxf 已提交
603
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
W
Wilber 已提交
604
  PADDLE_ENFORCE_NOT_NULL(
605
      tensor_, paddle::platform::errors::PreconditionNotMet(
W
Wilber 已提交
606
                   "Not found tensor called %s in the scope", name_));
W
wenbin 已提交
607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622
// mkldnn may does layout transform internally, so need to reorder before
// return
#ifdef PADDLE_WITH_MKLDNN
  if (tensor->layout() == paddle::framework::DataLayout::kMKLDNN) {
    paddle::framework::DataLayout out_layout =
        paddle::platform::MKLDNNDeviceContext::tls()
            .get_cur_paddle_data_layout();
    // Set default as NCHW in case not specified
    out_layout = out_layout == paddle::framework::DataLayout::kAnyLayout
                     ? paddle::framework::DataLayout::kNCHW
                     : out_layout;
    // In these data layouts, channel dimension is either on 2nd position: nChw
    // or
    // at last nhwC, so for dim==2 these layouts are the same and nothing should
    // be done. Similarly for dim==1 when you have just one possible
    // combination.
623
    if (tensor->dims().size() < 3) return phi::vectorize<int>(tensor->dims());
624 625
    if (out_layout == paddle::framework::DataLayout::kNHWC ||
        out_layout == paddle::framework::DataLayout::kNDHWC) {
626
      auto dims = phi::vectorize<int>(tensor->dims());
W
wenbin 已提交
627 628 629
      std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end());
      return dims;
    } else {
630
      return phi::vectorize<int>(tensor->dims());
W
wenbin 已提交
631 632 633
    }
  }
#endif
634
  return phi::vectorize<int>(tensor->dims());
635 636
}

637
void Tensor::SetLoD(const std::vector<std::vector<size_t>> &x) {
S
Steffy-zxf 已提交
638
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
639
  paddle::framework::LoD lod;
640 641 642 643 644 645
  for (auto &level : x) {
    lod.emplace_back(level);
  }
  tensor->set_lod(lod);
}

646
std::vector<std::vector<size_t>> Tensor::lod() const {
S
Steffy-zxf 已提交
647
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
648 649 650 651 652 653 654
  std::vector<std::vector<size_t>> res;
  for (auto &level : tensor->lod()) {
    res.emplace_back(level);
  }
  return res;
}

655 656 657 658 659 660 661 662 663
// Sets the scope variable name this handle resolves to.
void Tensor::SetName(const std::string &name) {
  name_ = name;
}

// Returns the scope variable name this handle resolves to.
const std::string &Tensor::name() const {
  return name_;
}

// Records the place kind and device index used by later copies and
// allocations on this handle.
void Tensor::SetPlace(PlaceType place, int device) {
  device_ = device;
  place_ = place;
}

664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718
#ifdef PADDLE_WITH_ONNXRUNTIME
// Marks whether this handle is backed by ONNXRuntime instead of a scope
// variable.
void Tensor::SetOrtMark(bool is_ort_tensor) {
  is_ort_tensor_ = is_ort_tensor;
}

// Stores the ORT IoBinding this handle reads from and writes to. `binding_`
// is lock()-ed before use, so presumably only a weak reference is kept.
void Tensor::SetOrtBinding(const std::shared_ptr<Ort::IoBinding> binding) {
  binding_ = binding;
}

namespace {
// Forwards to the typed Ort::Value::CreateTensor<T> overload. `size` is an
// element count, not a byte count.
template <typename T>
Ort::Value CreateTypedOrtValue(const Ort::MemoryInfo &memory_info, T *data,
                               size_t size, const int64_t *shape,
                               size_t shape_len) {
  return Ort::Value::CreateTensor<T>(memory_info, data, size, shape,
                                     shape_len);
}
}  // namespace

// Wraps `data` (`size` elements, dims `shape[0..shape_len)`) in an Ort::Value.
Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info, float *data,
                       size_t size, const int64_t *shape, size_t shape_len) {
  return CreateTypedOrtValue(memory_info, data, size, shape, shape_len);
}

Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info, int64_t *data,
                       size_t size, const int64_t *shape, size_t shape_len) {
  return CreateTypedOrtValue(memory_info, data, size, shape, shape_len);
}

Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info, int32_t *data,
                       size_t size, const int64_t *shape, size_t shape_len) {
  return CreateTypedOrtValue(memory_info, data, size, shape, shape_len);
}

Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info, uint8_t *data,
                       size_t size, const int64_t *shape, size_t shape_len) {
  return CreateTypedOrtValue(memory_info, data, size, shape, shape_len);
}

Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info, int8_t *data,
                       size_t size, const int64_t *shape, size_t shape_len) {
  return CreateTypedOrtValue(memory_info, data, size, shape, shape_len);
}

// float16 has no typed CreateTensor overload, so use the untyped one with
// an explicit byte size and element type.
Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info, float16 *data,
                       size_t size, const int64_t *shape, size_t shape_len) {
  const size_t num_bytes = size * sizeof(float16);
  return Ort::Value::CreateTensor(memory_info, static_cast<void *>(data),
                                  num_bytes, shape, shape_len,
                                  ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
}

template <typename T>
void Tensor::ORTCopyFromCpu(const T *data) {
  auto binding = binding_.lock();
  PADDLE_ENFORCE_NOT_NULL(binding,
                          paddle::platform::errors::PreconditionNotMet(
                              "input tensor [%s] no binding ptr", name_));
  const char *device_name = place_ == PlaceType::kCPU ? "Cpu" : "Cuda";
  Ort::MemoryInfo memory_info(device_name, OrtDeviceAllocator, device_,
                              OrtMemTypeDefault);
  size_t size = std::accumulate(begin(shape_), end(shape_), 1UL,
                                std::multiplies<size_t>());
H
heliqi 已提交
719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751
  size_t buffer_size = size * sizeof(T);
  if (buffer_size > buffer_.size()) {
    buffer_.resize(buffer_size);
  }
  std::memcpy(static_cast<void *>(buffer_.data()), data, buffer_size);

  auto onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
  if (std::is_same<T, float>::value) {
    onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  } else if (std::is_same<T, double>::value) {
    onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
  } else if (std::is_same<T, int64_t>::value) {
    onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
  } else if (std::is_same<T, int32_t>::value) {
    onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
  } else if (std::is_same<T, uint8_t>::value) {
    onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
  } else if (std::is_same<T, int8_t>::value) {
    onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
  } else if (std::is_same<T, float16>::value) {
    onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
  }

  if (onnx_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "Found undefined data type for onnxruntime, only supports "
        "float16/float32/float64/int8/uint8/int32/int64."));
  }

  auto ort_value =
      Ort::Value::CreateTensor(memory_info, buffer_.data(), buffer_size,
                               shape_.data(), shape_.size(), onnx_dtype);

752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768
  binding->BindInput(name_.c_str(), ort_value);
}

template <typename T>
void Tensor::ORTCopyToCpu(T *data) const {
  auto binding = binding_.lock();
  PADDLE_ENFORCE_NOT_NULL(binding,
                          paddle::platform::errors::PreconditionNotMet(
                              "output tensor [%s] no binding ptr", name_));
  std::vector<Ort::Value> outputs = binding->GetOutputValues();
  Ort::Value &value = outputs[idx_];
  auto info = value.GetTensorTypeAndShapeInfo();
  size_t size = info.GetElementCount() * sizeof(T);

  if (place_ == PlaceType::kCPU) {
    std::memcpy(static_cast<void *>(data), value.GetTensorData<void *>(), size);
  } else {
H
heliqi 已提交
769 770 771
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "CopyToCpu error.The current ONNXRuntime backend doesn't support "
        "GPU."));
772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788
  }
}

// Explicit instantiations of the ONNX Runtime copy paths.
// Both directions list the same element types, including int64_t, so that
// integer outputs (e.g. index tensors) can be read back as well as written.
template void Tensor::ORTCopyFromCpu<float>(const float *data);
template void Tensor::ORTCopyFromCpu<int64_t>(const int64_t *data);
template void Tensor::ORTCopyFromCpu<int32_t>(const int32_t *data);
template void Tensor::ORTCopyFromCpu<uint8_t>(const uint8_t *data);
template void Tensor::ORTCopyFromCpu<int8_t>(const int8_t *data);
template void Tensor::ORTCopyFromCpu<float16>(const float16 *data);

template void Tensor::ORTCopyToCpu<float>(float *data) const;
template void Tensor::ORTCopyToCpu<int64_t>(int64_t *data) const;
template void Tensor::ORTCopyToCpu<int32_t>(int32_t *data) const;
template void Tensor::ORTCopyToCpu<uint8_t>(uint8_t *data) const;
template void Tensor::ORTCopyToCpu<int8_t>(int8_t *data) const;
template void Tensor::ORTCopyToCpu<float16>(float16 *data) const;
#endif

W
Wilber 已提交
789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869
namespace experimental {
// Copies tensor->numel() elements of T from host `data` into the Paddle
// tensor behind `t`.
//
// t->tensor_ is resolved lazily: it is looked up by t->name_ in t's scope
// and cached on first use.
//
// On the GPU path the copy is issued on `stream` via paddle::memory::Copy.
// It is presumably asynchronous with respect to the host, so `data` should
// stay valid until the stream is synchronized. TODO confirm.
//
// Raises:
//   - PreconditionNotMet if no name is set or the scope has no such variable.
//   - Unavailable for a GPU place in a non-CUDA build.
//   - InvalidArgument for any place other than CPU/GPU.
template <typename T>
void InternalUtils::CopyFromCpuWithIoStream(paddle_infer::Tensor *t,
                                            const T *data,
                                            cudaStream_t stream) {
  if (t->tensor_ == nullptr) {
    PADDLE_ENFORCE_EQ(
        t->name_.empty(), false,
        paddle::platform::errors::PreconditionNotMet(
            "Need to SetName first, so that the corresponding tensor can "
            "be retrieved."));
    auto *scope = static_cast<paddle::framework::Scope *>(t->scope_);
    auto *var = scope->FindVar(t->name_);
    PADDLE_ENFORCE_NOT_NULL(
        var, paddle::platform::errors::PreconditionNotMet(
                 "No tensor called [%s] in the runtime scope", t->name_));
    auto *tensor = var->GetMutable<paddle::framework::LoDTensor>();
    t->tensor_ = tensor;
  }

  auto *tensor = static_cast<paddle::framework::LoDTensor *>(t->tensor_);
  // Note the trailing space after "shape)": without it the two literals
  // concatenate into "...&shape)function before...".
  PADDLE_ENFORCE_GE(tensor->numel(), 0,
                    paddle::platform::errors::PreconditionNotMet(
                        "You should call Tensor::Reshape(const "
                        "std::vector<int> &shape) "
                        "function before copying data from cpu."));
  size_t ele_size = tensor->numel() * sizeof(T);
  if (t->place_ == PlaceType::kCPU) {
    auto *t_data = tensor->mutable_data<T>(paddle::platform::CPUPlace());
    std::memcpy(static_cast<void *>(t_data), data, ele_size);
  } else if (t->place_ == PlaceType::kGPU) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    paddle::platform::CUDAPlace gpu_place(t->device_);
    auto *t_data = tensor->mutable_data<T>(gpu_place);
    paddle::memory::Copy(gpu_place, static_cast<void *>(t_data),
                         paddle::platform::CPUPlace(), data, ele_size, stream);
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with CUDA place because paddle is not compiled "
        "with CUDA."));
#endif
  } else {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "CopyFromCpuWithIoStream only supports CPU and GPU now."));
  }
}

// Copies the Paddle tensor behind `t` into host `data`
// (tensor->numel() elements of T).
//
// t->tensor_ is resolved lazily by name, as in CopyFromCpuWithIoStream.
//
// Dispatch:
//   - CPU-resident tensors use std::memcpy. Under MKLDNN, a kMKLDNN-layout
//     tensor is instead converted back to the current Paddle layout.
//   - When t->place_ is kGPU, the copy is issued on `stream` via
//     paddle::memory::Copy.
//
// Raises:
//   - PreconditionNotMet if no name is set or the scope has no such variable.
//   - Unavailable for GPU in a non-CUDA build.
//   - InvalidArgument otherwise.
template <typename T>
void InternalUtils::CopyToCpuWithIoStream(paddle_infer::Tensor *t, T *data,
                                          cudaStream_t stream) {
  if (t->tensor_ == nullptr) {
    PADDLE_ENFORCE_EQ(
        t->name_.empty(), false,
        paddle::platform::errors::PreconditionNotMet(
            "Need to SetName first, so that the corresponding tensor can "
            "be retrieved."));
    auto *scope = static_cast<paddle::framework::Scope *>(t->scope_);
    auto *var = scope->FindVar(t->name_);
    PADDLE_ENFORCE_NOT_NULL(
        var, paddle::platform::errors::PreconditionNotMet(
                 "No tensor called [%s] in the runtime scope", t->name_));
    auto *tensor = var->GetMutable<paddle::framework::LoDTensor>();
    t->tensor_ = tensor;
  }

  auto *tensor = static_cast<paddle::framework::LoDTensor *>(t->tensor_);
  auto ele_num = tensor->numel();
  auto *t_data = tensor->data<T>();
  auto t_place = tensor->place();

  if (paddle::platform::is_cpu_place(t_place)) {
#ifdef PADDLE_WITH_MKLDNN
    if (tensor->layout() == paddle::framework::DataLayout::kMKLDNN) {
      // Only this path needs a Tensor view over the caller's buffer. It is
      // built here rather than on every call. The holder wraps `data`, so the
      // layout transform is intended to write straight into it. This
      // presumably relies on the holder already being large enough that
      // `out` is not reallocated. TODO confirm.
      paddle::framework::Tensor out;
      auto mem_allocation =
          std::make_shared<paddle::memory::allocation::Allocation>(
              static_cast<void *>(data), ele_num * sizeof(T),
              paddle::platform::CPUPlace());
      out.ResetHolder(mem_allocation);
      paddle::framework::innerTransDataLayoutFromMKLDNN(
          tensor->layout(),
          paddle::platform::MKLDNNDeviceContext::tls()
              .get_cur_paddle_data_layout(),
          *tensor, &out, paddle::platform::CPUPlace(), true);
    } else {
      std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
    }
#else
    std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
#endif
  } else if (t->place_ == PlaceType::kGPU) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    paddle::memory::Copy(paddle::platform::CPUPlace(),
                         static_cast<void *>(data), t_place, t_data,
                         ele_num * sizeof(T), stream);
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with CUDA place because paddle is not compiled "
        "with CUDA."));
#endif
  } else {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "CopyToCpuWithIoStream only supports CPU and GPU now."));
  }
}

// Explicit instantiations of the stream-aware copy helpers. Each element type
// is listed once, with the host->tensor and tensor->host directions side by
// side, so the two lists cannot drift apart.

// float
template void InternalUtils::CopyFromCpuWithIoStream<float>(
    paddle_infer::Tensor *t, const float *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<float>(
    paddle_infer::Tensor *t, float *data, cudaStream_t stream);

// int64_t
template void InternalUtils::CopyFromCpuWithIoStream<int64_t>(
    paddle_infer::Tensor *t, const int64_t *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<int64_t>(
    paddle_infer::Tensor *t, int64_t *data, cudaStream_t stream);

// int32_t
template void InternalUtils::CopyFromCpuWithIoStream<int32_t>(
    paddle_infer::Tensor *t, const int32_t *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<int32_t>(
    paddle_infer::Tensor *t, int32_t *data, cudaStream_t stream);

// uint8_t
template void InternalUtils::CopyFromCpuWithIoStream<uint8_t>(
    paddle_infer::Tensor *t, const uint8_t *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<uint8_t>(
    paddle_infer::Tensor *t, uint8_t *data, cudaStream_t stream);

// int8_t
template void InternalUtils::CopyFromCpuWithIoStream<int8_t>(
    paddle_infer::Tensor *t, const int8_t *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<int8_t>(
    paddle_infer::Tensor *t, int8_t *data, cudaStream_t stream);

// float16
template void InternalUtils::CopyFromCpuWithIoStream<float16>(
    paddle_infer::Tensor *t, const float16 *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<float16>(
    paddle_infer::Tensor *t, float16 *data, cudaStream_t stream);

}  // namespace experimental

923
}  // namespace paddle_infer