zero_copy_tensor.cc 28.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/fluid/framework/convert_utils.h"
16
#include "paddle/fluid/framework/data_layout_transform.h"
17 18 19
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
W
Wilber 已提交
20
#include "paddle/fluid/inference/api/paddle_tensor.h"
N
nhzlx 已提交
21
#include "paddle/fluid/memory/memcpy.h"
22
#include "paddle/fluid/platform/enforce.h"
23
#include "paddle/fluid/platform/float16.h"
24
#include "paddle/phi/core/allocator.h"
25 26 27
#ifdef PADDLE_WITH_ONNXRUNTIME
#include "paddle/fluid/inference/api/onnxruntime_predictor.h"
#endif
28

29
namespace paddle_infer {
30

31 32
using float16 = paddle::platform::float16;

33
void Tensor::Reshape(const std::vector<int> &shape) {
34 35 36 37 38 39 40
#ifdef PADDLE_WITH_ONNXRUNTIME
  if (is_ort_tensor_) {
    shape_.assign(shape.begin(), shape.end());
    return;
  }
#endif

W
Wilber 已提交
41 42
  PADDLE_ENFORCE_EQ(
      name_.empty(), false,
43
      paddle::platform::errors::PreconditionNotMet(
W
Wilber 已提交
44 45 46
          "Need to SetName first, so that the corresponding tensor can "
          "be retrieved."));
  PADDLE_ENFORCE_EQ(input_or_output_, true,
47
                    paddle::platform::errors::PermissionDenied(
W
Wilber 已提交
48
                        "Can't reshape the output tensor, it is readonly"));
49
  auto *scope = static_cast<paddle::framework::Scope *>(scope_);
50
  auto *var = scope->FindVar(name_);
W
Wilber 已提交
51
  PADDLE_ENFORCE_NOT_NULL(
52
      var, paddle::platform::errors::PreconditionNotMet(
W
Wilber 已提交
53
               "No tensor called [%s] in the runtime scope", name_));
54
  auto *tensor = var->GetMutable<paddle::framework::LoDTensor>();
55
  tensor->Resize(phi::make_ddim(shape));
56 57
}

S
Steffy-zxf 已提交
58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80
void Tensor::ReshapeStrings(const size_t &shape) {
  PADDLE_ENFORCE_EQ(
      name_.empty(), false,
      paddle::platform::errors::PreconditionNotMet(
          "Need to SetName first, so that the corresponding tensor can "
          "be retrieved."));
  PADDLE_ENFORCE_EQ(input_or_output_, true,
                    paddle::platform::errors::PermissionDenied(
                        "Can't reshape the output tensor, it is readonly"));
  auto *scope = static_cast<paddle::framework::Scope *>(scope_);
  auto *var = scope->FindVar(name_);
  PADDLE_ENFORCE_NOT_NULL(
      var, paddle::platform::errors::PreconditionNotMet(
               "No tensor called [%s] in the runtime scope", name_));
  paddle_infer::Strings *tensor = var->GetMutable<paddle_infer::Strings>();
  tensor->resize(shape);
}

// Declares a local variable `tensor` of type `tensor_type *` that points at
// the object cached in `tensor_`. When `tensor_` is still null it is first
// filled in through FindTensor<tensor_type>().
// The expansion already ends with a semicolon, and it introduces the name
// `tensor` into the enclosing scope, so it can appear at most once per scope.
#define EAGER_GET_TENSOR(tensor_type)    \
  if (!tensor_) {                        \
    tensor_ = FindTensor<tensor_type>(); \
  }                                      \
  auto *tensor = static_cast<tensor_type *>(tensor_);
81

82
template <typename T>
83
T *Tensor::mutable_data(PlaceType place) {
S
Steffy-zxf 已提交
84
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
85 86
  PADDLE_ENFORCE_GT(
      tensor->numel(), 0,
87 88
      paddle::platform::errors::PreconditionNotMet(
          "You should call Tensor::Reshape(const std::vector<int> "
W
Wilber 已提交
89 90
          "&shape)"
          "function before retrieving mutable_data from input tensor."));
91
  switch (static_cast<int>(place)) {
92 93
    case static_cast<int>(PlaceType::kCPU): {
      return tensor->mutable_data<T>(paddle::platform::CPUPlace());
94
    }
95 96 97 98 99
    case static_cast<int>(PlaceType::kGPU): {
      return tensor->mutable_data<T>(paddle::platform::CUDAPlace(device_));
    }
    case static_cast<int>(PlaceType::kXPU): {
      return tensor->mutable_data<T>(paddle::platform::XPUPlace(device_));
100
    }
101 102 103
    case static_cast<int>(PlaceType::kNPU): {
      return tensor->mutable_data<T>(paddle::platform::NPUPlace(device_));
    }
104
    default:
105
      PADDLE_THROW(paddle::platform::errors::Unavailable(
106 107
          "Only CPU / CUDA / XPU / NPU places is supported. The place `%d` is "
          "not supported.",
108
          static_cast<int>(place)));
109 110 111 112 113 114
      break;
  }
  return nullptr;
}

template <typename T>
115
T *Tensor::data(PlaceType *place, int *size) const {
S
Steffy-zxf 已提交
116
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
117 118
  auto *res = tensor->data<T>();

119 120 121 122 123 124
  if (paddle::platform::is_cpu_place(tensor->place())) {
    *place = PlaceType::kCPU;
  } else if (paddle::platform::is_gpu_place(tensor->place())) {
    *place = PlaceType::kGPU;
  } else if (paddle::platform::is_xpu_place(tensor->place())) {
    *place = PlaceType::kXPU;
125 126
  } else if (paddle::platform::is_npu_place(tensor->place())) {
    *place = PlaceType::kNPU;
127
  } else {
128
    *place = PlaceType::kUNK;
129 130 131 132 133 134
  }

  *size = tensor->numel();
  return res;
}

135
// Returns the element type of the tensor as an inference-API DataType.
//
// ONNX Runtime tensors report the cached `dtype_`. For other tensors the
// dtype is translated through the proto VarType; any type without an explicit
// case below is reported as FLOAT32.
DataType Tensor::type() const {
#ifdef PADDLE_WITH_ONNXRUNTIME
  if (is_ort_tensor_) {
    return dtype_;
  }
#endif
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
  using VarType = paddle::framework::proto::VarType;
  switch (paddle::framework::TransToProtoVarType(tensor->dtype())) {
    case VarType::FP32:
      return DataType::FLOAT32;
    case VarType::FP16:
      return DataType::FLOAT16;
    case VarType::INT64:
      return DataType::INT64;
    case VarType::INT32:
      return DataType::INT32;
    case VarType::UINT8:
      return DataType::UINT8;
    case VarType::INT8:
      return DataType::INT8;
    default:
      return DataType::FLOAT32;
  }
}

159 160
// Returns the place type stored in `place_`.
PlaceType Tensor::place() const { return place_; }

N
nhzlx 已提交
161
template <typename T>
162
void Tensor::CopyFromCpu(const T *data) {
163 164 165 166 167 168 169
#ifdef PADDLE_WITH_ONNXRUNTIME
  if (is_ort_tensor_) {
    ORTCopyFromCpu<T>(data);
    return;
  }
#endif

S
Steffy-zxf 已提交
170
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
W
Wilber 已提交
171
  PADDLE_ENFORCE_GE(tensor->numel(), 0,
172 173
                    paddle::platform::errors::PreconditionNotMet(
                        "You should call Tensor::Reshape(const "
W
Wilber 已提交
174 175
                        "std::vector<int> &shape)"
                        "function before copying data from cpu."));
N
nhzlx 已提交
176 177
  size_t ele_size = tensor->numel() * sizeof(T);

178 179
  if (place_ == PlaceType::kCPU) {
    auto *t_data = tensor->mutable_data<T>(paddle::platform::CPUPlace());
N
nhzlx 已提交
180
    std::memcpy(static_cast<void *>(t_data), data, ele_size);
181
  } else if (place_ == PlaceType::kGPU) {
182
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
183 184 185
    paddle::platform::DeviceContextPool &pool =
        paddle::platform::DeviceContextPool::Instance();
    paddle::platform::CUDAPlace gpu_place(device_);
N
nhzlx 已提交
186
    auto *t_data = tensor->mutable_data<T>(gpu_place);
187 188
    auto *dev_ctx = static_cast<const paddle::platform::CUDADeviceContext *>(
        pool.Get(gpu_place));
N
nhzlx 已提交
189

190 191 192
    paddle::memory::Copy(gpu_place, static_cast<void *>(t_data),
                         paddle::platform::CPUPlace(), data, ele_size,
                         dev_ctx->stream());
N
nhzlx 已提交
193
#else
194 195 196
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with CUDA place because paddle is not compiled "
        "with CUDA."));
N
nhzlx 已提交
197
#endif
198
  } else if (place_ == PlaceType::kXPU) {
199
#ifdef PADDLE_WITH_XPU
200
    paddle::platform::XPUPlace xpu_place(device_);
201
    auto *t_data = tensor->mutable_data<T>(xpu_place);
202 203
    paddle::memory::Copy(xpu_place, static_cast<void *>(t_data),
                         paddle::platform::CPUPlace(), data, ele_size);
204
#else
205 206 207
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with XPU place because paddle is not compiled "
        "with XPU."));
W
Wilber 已提交
208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223
#endif
  } else if (place_ == PlaceType::kNPU) {
#ifdef PADDLE_WITH_ASCEND_CL
    paddle::platform::DeviceContextPool &pool =
        paddle::platform::DeviceContextPool::Instance();
    paddle::platform::NPUPlace npu_place(device_);
    auto *t_data = tensor->mutable_data<T>(npu_place);
    auto *dev_ctx = static_cast<const paddle::platform::NPUDeviceContext *>(
        pool.Get(npu_place));
    paddle::memory::Copy(npu_place, static_cast<void *>(t_data),
                         paddle::platform::CPUPlace(), data, ele_size,
                         dev_ctx->stream());
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with NPU place because paddle is not compiled "
        "with NPU."));
224 225 226
#endif
  } else {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
W
Wilber 已提交
227
        "The analysis predictor supports CPU, GPU, NPU and XPU now."));
N
nhzlx 已提交
228 229 230
  }
}

231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297
// Maps a C++ element type to the matching paddle::experimental::DataType
// enumerator. Only the specializations below are defined; the primary
// template is declared but left undefined.
// TYPE is a non-static data member, so it is read from an instance, as in
// `DataTypeInfo<T>().TYPE` inside Tensor::ShareExternalData.
template <typename T>
struct DataTypeInfo;

// float -> FLOAT32
template <>
struct DataTypeInfo<float> {
  paddle::experimental::DataType TYPE = paddle::experimental::DataType::FLOAT32;
};

// float16 -> FLOAT16
template <>
struct DataTypeInfo<float16> {
  paddle::experimental::DataType TYPE = paddle::experimental::DataType::FLOAT16;
};

// int64_t -> INT64
template <>
struct DataTypeInfo<int64_t> {
  paddle::experimental::DataType TYPE = paddle::experimental::DataType::INT64;
};

// int8_t -> INT8
template <>
struct DataTypeInfo<int8_t> {
  paddle::experimental::DataType TYPE = paddle::experimental::DataType::INT8;
};

// uint8_t -> UINT8
template <>
struct DataTypeInfo<uint8_t> {
  paddle::experimental::DataType TYPE = paddle::experimental::DataType::UINT8;
};

// int32_t -> INT32
template <>
struct DataTypeInfo<int32_t> {
  paddle::experimental::DataType TYPE = paddle::experimental::DataType::INT32;
};

// Converts an inference-API DataLayout to paddle::experimental::DataLayout.
// Only DataLayout::kNCHW is accepted; any other value raises InvalidArgument.
paddle::experimental::DataLayout LayoutConvert(DataLayout layout) {
  PADDLE_ENFORCE_EQ(
      layout, DataLayout::kNCHW,
      paddle::platform::errors::InvalidArgument("Only NCHW is supported now."));
  return paddle::experimental::DataLayout::NCHW;
}

template <typename T>
void Tensor::ShareExternalData(const T *data, const std::vector<int> &shape,
                               PlaceType place, DataLayout layout) {
  EAGER_GET_TENSOR(paddle::framework::LoDTensor)
  size_t size =
      std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()) *
      sizeof(T);
  phi::DenseTensorMeta meta(DataTypeInfo<T>().TYPE, phi::make_ddim(shape),
                            LayoutConvert(layout));
  if (place == PlaceType::kCPU) {
    phi::DenseTensor dtensor(
        std::make_shared<phi::Allocation>(const_cast<T *>(data), size,
                                          paddle::platform::CPUPlace()),
        meta);
    *tensor = std::move(dtensor);
  } else if (place == PlaceType::kGPU) {
    phi::DenseTensor dtensor(
        std::make_shared<phi::Allocation>(const_cast<T *>(data), size,
                                          paddle::platform::CUDAPlace(device_)),
        meta);
    *tensor = std::move(dtensor);
  } else {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "PlaceType must be PlaceType::kCPU or PlaceType::kGPU."));
  }
}

S
Steffy-zxf 已提交
298 299 300 301 302 303 304 305 306 307
// Replaces the contents of this string tensor with a copy of `*data`.
//
// Raises InvalidArgument when `data` is null.
void Tensor::CopyStringsFromCpu(const paddle_infer::Strings *data) {
  PADDLE_ENFORCE_NOT_NULL(
      data, paddle::platform::errors::InvalidArgument(
                "The strings passed to Tensor::CopyStringsFromCpu should not "
                "be null."));
  EAGER_GET_TENSOR(paddle_infer::Strings);
  // NOTE(review): if size() returns an unsigned type this check can never
  // fail; it is kept so existing behavior is unchanged.
  PADDLE_ENFORCE_GE(tensor->size(), 0,
                    paddle::platform::errors::PreconditionNotMet(
                        "You should call Tensor::ReshapeStrings(const "
                        "std::size_t &shape) function before copying "
                        "the string data from cpu."));
  *tensor = *data;
}

N
nhzlx 已提交
308
template <typename T>
309 310
void Tensor::CopyToCpuImpl(T *data, void *exec_stream, CallbackFunc cb,
                           void *cb_params) const {
S
Steffy-zxf 已提交
311
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
N
nhzlx 已提交
312 313 314 315
  auto ele_num = tensor->numel();
  auto *t_data = tensor->data<T>();
  auto t_place = tensor->place();

316
  paddle::framework::Tensor out;
317 318 319 320
  auto mem_allocation =
      std::make_shared<paddle::memory::allocation::Allocation>(
          static_cast<void *>(data), ele_num * sizeof(T),
          paddle::platform::CPUPlace());
321 322
  out.ResetHolder(mem_allocation);

323
  if (paddle::platform::is_cpu_place(t_place)) {
324 325 326 327 328 329 330 331 332
#ifdef PADDLE_WITH_MKLDNN
    if (tensor->layout() == paddle::framework::DataLayout::kMKLDNN)
      paddle::framework::innerTransDataLayoutFromMKLDNN(
          tensor->layout(), paddle::platform::MKLDNNDeviceContext::tls()
                                .get_cur_paddle_data_layout(),
          *tensor, &out, paddle::platform::CPUPlace(), true);
    else
      std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
#else
N
nhzlx 已提交
333
    std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
J
jianghaicheng 已提交
334 335 336 337 338 339 340 341
#endif
  } else if (paddle::platform::is_ipu_place(t_place)) {
#ifdef PADDLE_WITH_IPU
    std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with IPU place because paddle is not compiled "
        "with IPU."));
342
#endif
343
  } else if (place_ == PlaceType::kGPU) {
344
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
345 346
    paddle::platform::DeviceContextPool &pool =
        paddle::platform::DeviceContextPool::Instance();
347
    auto gpu_place = t_place;
348 349 350 351 352
    auto *dev_ctx = static_cast<const paddle::platform::CUDADeviceContext *>(
        pool.Get(gpu_place));
    paddle::memory::Copy(paddle::platform::CPUPlace(),
                         static_cast<void *>(data), gpu_place, t_data,
                         ele_num * sizeof(T), dev_ctx->stream());
353 354 355
#ifdef PADDLE_WITH_HIP
    hipStreamSynchronize(dev_ctx->stream());
#else
356 357 358 359 360 361 362 363 364 365
    // async, return stream
    if (nullptr != exec_stream) {
      *(static_cast<cudaStream_t *>(exec_stream)) = dev_ctx->stream();
      // async with callback
    } else if (cb) {
      cudaLaunchHostFunc(dev_ctx->stream(), cb, cb_params);
      // sync
    } else {
      cudaStreamSynchronize(dev_ctx->stream());
    }
366
#endif
N
nhzlx 已提交
367
#else
368 369 370
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with CUDA place because paddle is not compiled "
        "with CUDA."));
N
nhzlx 已提交
371
#endif
372
  } else if (place_ == PlaceType::kXPU) {
373
#ifdef PADDLE_WITH_XPU
374
    auto xpu_place = t_place;
375 376 377
    paddle::memory::Copy(paddle::platform::CPUPlace(),
                         static_cast<void *>(data), xpu_place, t_data,
                         ele_num * sizeof(T));
378
#else
379 380 381
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with XPU place because paddle is not compiled "
        "with XPU."));
W
Wilber 已提交
382 383 384 385 386
#endif
  } else if (place_ == PlaceType::kNPU) {
#ifdef PADDLE_WITH_ASCEND_CL
    paddle::platform::DeviceContextPool &pool =
        paddle::platform::DeviceContextPool::Instance();
387
    auto npu_place = t_place;
W
Wilber 已提交
388 389 390 391 392
    auto *dev_ctx = static_cast<const paddle::platform::NPUDeviceContext *>(
        pool.Get(npu_place));
    paddle::memory::Copy(paddle::platform::CPUPlace(),
                         static_cast<void *>(data), npu_place, t_data,
                         ele_num * sizeof(T), dev_ctx->stream());
393
    paddle::platform::NPUStreamSync(dev_ctx->stream());
W
Wilber 已提交
394 395 396 397
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with NPU place because paddle is not compiled "
        "with NPU."));
398 399 400
#endif
  } else {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
W
Wilber 已提交
401
        "The analysis predictor supports CPU, GPU, NPU and XPU now."));
N
nhzlx 已提交
402 403
  }
}
404 405 406

template <typename T>
void Tensor::CopyToCpu(T *data) const {
407 408 409 410 411 412 413
#ifdef PADDLE_WITH_ONNXRUNTIME
  if (is_ort_tensor_) {
    ORTCopyToCpu<T>(data);
    return;
  }
#endif

414 415 416 417 418 419 420 421 422 423 424 425 426
  CopyToCpuImpl<T>(data, nullptr, nullptr, nullptr);
}

// Stream-returning overload: forwards to CopyToCpuImpl with `exec_stream` set
// and no callback.
template <typename T>
void Tensor::CopyToCpuAsync(T *data, void *exec_stream) const {
  CopyToCpuImpl<T>(data, exec_stream, nullptr, nullptr);
}

// Callback overload: forwards to CopyToCpuImpl with a null `exec_stream` and
// the given callback and callback argument.
template <typename T>
void Tensor::CopyToCpuAsync(T *data, CallbackFunc cb, void *cb_params) const {
  CopyToCpuImpl<T>(data, nullptr, cb, cb_params);
}

427 428 429 430 431
// Explicit instantiations of Tensor::CopyFromCpu, one per element type.
template PD_INFER_DECL void Tensor::CopyFromCpu<float>(const float *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<float16>(const float16 *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<int64_t>(const int64_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<int32_t>(const int32_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<int8_t>(const int8_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<uint8_t>(const uint8_t *data);
433

434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452
// Explicit instantiations of Tensor::ShareExternalData, one per element type
// that has a DataTypeInfo specialization.
template PD_INFER_DECL void Tensor::ShareExternalData<float>(
    const float *data, const std::vector<int> &shape, PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<int64_t>(
    const int64_t *data, const std::vector<int> &shape, PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<int32_t>(
    const int32_t *data, const std::vector<int> &shape, PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<uint8_t>(
    const uint8_t *data, const std::vector<int> &shape, PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<int8_t>(
    const int8_t *data, const std::vector<int> &shape, PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<float16>(
    const float16 *data, const std::vector<int> &shape, PlaceType place,
    DataLayout layout);

453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499
// Explicit instantiations of Tensor::CopyToCpu.
template PD_INFER_DECL void Tensor::CopyToCpu<float>(float *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<int64_t>(int64_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<int32_t>(int32_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<uint8_t>(uint8_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<int8_t>(int8_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<float16>(float16 *data) const;

// Explicit instantiations of Tensor::CopyToCpuImpl.
template PD_INFER_DECL void Tensor::CopyToCpuImpl<float>(float *data,
                                                         void *exec_stream,
                                                         CallbackFunc cb,
                                                         void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<int64_t>(
    int64_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<int32_t>(
    int32_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<uint8_t>(
    uint8_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<int8_t>(
    int8_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<float16>(
    float16 *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;

// Explicit instantiations of the stream-returning Tensor::CopyToCpuAsync.
template PD_INFER_DECL void Tensor::CopyToCpuAsync<float>(
    float *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int64_t>(
    int64_t *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int32_t>(
    int32_t *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<uint8_t>(
    uint8_t *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int8_t>(
    int8_t *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<float16>(
    float16 *data, void *exec_stream) const;

// Explicit instantiations of the callback Tensor::CopyToCpuAsync.
template PD_INFER_DECL void Tensor::CopyToCpuAsync<float>(
    float *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int64_t>(
    int64_t *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int32_t>(
    int32_t *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<uint8_t>(
    uint8_t *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int8_t>(
    int8_t *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<float16>(
    float16 *data, CallbackFunc cb, void *cb_params) const;
500

501 502 503 504 505 506 507 508 509 510
template PD_INFER_DECL float *Tensor::data<float>(PlaceType *place,
                                                  int *size) const;
template PD_INFER_DECL int64_t *Tensor::data<int64_t>(PlaceType *place,
                                                      int *size) const;
template PD_INFER_DECL int32_t *Tensor::data<int32_t>(PlaceType *place,
                                                      int *size) const;
template PD_INFER_DECL uint8_t *Tensor::data<uint8_t>(PlaceType *place,
                                                      int *size) const;
template PD_INFER_DECL int8_t *Tensor::data<int8_t>(PlaceType *place,
                                                    int *size) const;
511 512
template PD_INFER_DECL float16 *Tensor::data<float16>(PlaceType *place,
                                                      int *size) const;
513

514 515 516 517 518
template PD_INFER_DECL float *Tensor::mutable_data<float>(PlaceType place);
template PD_INFER_DECL int64_t *Tensor::mutable_data<int64_t>(PlaceType place);
template PD_INFER_DECL int32_t *Tensor::mutable_data<int32_t>(PlaceType place);
template PD_INFER_DECL uint8_t *Tensor::mutable_data<uint8_t>(PlaceType place);
template PD_INFER_DECL int8_t *Tensor::mutable_data<int8_t>(PlaceType place);
519
template PD_INFER_DECL float16 *Tensor::mutable_data<float16>(PlaceType place);
520

521
// Stores the opaque `scope` pointer; other members of this file cast it to
// paddle::framework::Scope* when looking up variables by name.
Tensor::Tensor(void *scope) : scope_{scope} {}
522

S
Steffy-zxf 已提交
523
template <typename T>
524
void *Tensor::FindTensor() const {
W
Wilber 已提交
525 526
  PADDLE_ENFORCE_EQ(
      name_.empty(), false,
527
      paddle::platform::errors::PreconditionNotMet(
W
Wilber 已提交
528 529
          "Need to SetName first, so that the corresponding tensor can "
          "be retrieved."));
530
  auto *scope = static_cast<paddle::framework::Scope *>(scope_);
531
  auto *var = scope->FindVar(name_);
W
Wilber 已提交
532
  PADDLE_ENFORCE_NOT_NULL(
533
      var, paddle::platform::errors::PreconditionNotMet(
W
Wilber 已提交
534
               "No tensor called [%s] in the runtime scope", name_));
S
Steffy-zxf 已提交
535
  auto *tensor = var->GetMutable<T>();
536 537 538
  return tensor;
}

539
std::vector<int> Tensor::shape() const {
540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559
#ifdef PADDLE_WITH_ONNXRUNTIME
  if (is_ort_tensor_) {
    std::vector<int> shape;
    // input handle
    if (idx_ < 0) {
      shape.assign(shape_.begin(), shape_.end());
    } else {  // output handle
      auto binding = binding_.lock();
      PADDLE_ENFORCE_NOT_NULL(binding,
                              paddle::platform::errors::PreconditionNotMet(
                                  "output tensor [%s] no binding ptr", name_));
      std::vector<Ort::Value> outputs = binding->GetOutputValues();
      Ort::Value &value = outputs[idx_];
      auto info = value.GetTensorTypeAndShapeInfo();
      auto ort_shape = info.GetShape();
      shape.assign(ort_shape.begin(), ort_shape.end());
    }
    return shape;
  }
#endif
S
Steffy-zxf 已提交
560
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
W
Wilber 已提交
561
  PADDLE_ENFORCE_NOT_NULL(
562
      tensor_, paddle::platform::errors::PreconditionNotMet(
W
Wilber 已提交
563
                   "Not found tensor called %s in the scope", name_));
W
wenbin 已提交
564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579
// mkldnn may does layout transform internally, so need to reorder before
// return
#ifdef PADDLE_WITH_MKLDNN
  if (tensor->layout() == paddle::framework::DataLayout::kMKLDNN) {
    paddle::framework::DataLayout out_layout =
        paddle::platform::MKLDNNDeviceContext::tls()
            .get_cur_paddle_data_layout();
    // Set default as NCHW in case not specified
    out_layout = out_layout == paddle::framework::DataLayout::kAnyLayout
                     ? paddle::framework::DataLayout::kNCHW
                     : out_layout;
    // In these data layouts, channel dimension is either on 2nd position: nChw
    // or
    // at last nhwC, so for dim==2 these layouts are the same and nothing should
    // be done. Similarly for dim==1 when you have just one possible
    // combination.
580
    if (tensor->dims().size() < 3) return phi::vectorize<int>(tensor->dims());
W
wenbin 已提交
581
    if (out_layout == paddle::framework::DataLayout::kNHWC) {
582
      auto dims = phi::vectorize<int>(tensor->dims());
W
wenbin 已提交
583 584 585
      std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end());
      return dims;
    } else {
586
      return phi::vectorize<int>(tensor->dims());
W
wenbin 已提交
587 588 589
    }
  }
#endif
590
  return phi::vectorize<int>(tensor->dims());
591 592
}

593
void Tensor::SetLoD(const std::vector<std::vector<size_t>> &x) {
S
Steffy-zxf 已提交
594
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
595
  paddle::framework::LoD lod;
596 597 598 599 600 601
  for (auto &level : x) {
    lod.emplace_back(level);
  }
  tensor->set_lod(lod);
}

602
std::vector<std::vector<size_t>> Tensor::lod() const {
S
Steffy-zxf 已提交
603
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
604 605 606 607 608 609 610
  std::vector<std::vector<size_t>> res;
  for (auto &level : tensor->lod()) {
    res.emplace_back(level);
  }
  return res;
}

611 612 613 614 615 616 617 618 619
// Stores `name` in `name_`; the lookups in this file (FindTensor, Reshape,
// ReshapeStrings) use it to find the variable in the scope.
void Tensor::SetName(const std::string &name) { name_ = name; }

// Returns a reference to the stored name; it refers to the member `name_`.
const std::string &Tensor::name() const { return name_; }

// Records the place type and the device index. The copy and mutable_data
// members of this file read `place_` and `device_` to pick the target device.
void Tensor::SetPlace(PlaceType place, int device) {
  place_ = place;
  device_ = device;
}

620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714
#ifdef PADDLE_WITH_ONNXRUNTIME
// Sets the flag that routes Reshape / type / shape / CopyFromCpu / CopyToCpu
// through their ONNX Runtime branches.
void Tensor::SetOrtMark(bool is_ort_tensor) { is_ort_tensor_ = is_ort_tensor; }

// Stores the IoBinding in `binding_`. The ORT members of this file call
// `binding_.lock()` and check the result for null before using it.
void Tensor::SetOrtBinding(const std::shared_ptr<Ort::IoBinding> binding) {
  binding_ = binding;
}

// GetOrtVaule overload set: builds an Ort::Value over the caller's buffer
// `data` with the given shape. `size` is the element count and `shape_len`
// is the number of entries in `shape`. One overload per element type so the
// templated caller can select it by pointer type.

// float buffer.
Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info, float *data,
                       size_t size, const int64_t *shape, size_t shape_len) {
  return Ort::Value::CreateTensor<float>(memory_info, data, size, shape,
                                         shape_len);
}

// int64_t buffer.
Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info, int64_t *data,
                       size_t size, const int64_t *shape, size_t shape_len) {
  return Ort::Value::CreateTensor<int64_t>(memory_info, data, size, shape,
                                           shape_len);
}

// int32_t buffer.
Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info, int32_t *data,
                       size_t size, const int64_t *shape, size_t shape_len) {
  return Ort::Value::CreateTensor<int32_t>(memory_info, data, size, shape,
                                           shape_len);
}

// uint8_t buffer.
Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info, uint8_t *data,
                       size_t size, const int64_t *shape, size_t shape_len) {
  return Ort::Value::CreateTensor<uint8_t>(memory_info, data, size, shape,
                                           shape_len);
}

// int8_t buffer.
Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info, int8_t *data,
                       size_t size, const int64_t *shape, size_t shape_len) {
  return Ort::Value::CreateTensor<int8_t>(memory_info, data, size, shape,
                                          shape_len);
}

// float16 buffer: uses the untyped CreateTensor overload, so the buffer length
// is passed in bytes (size * sizeof(float16)) together with the FLOAT16
// element-type tag.
Ort::Value GetOrtVaule(const Ort::MemoryInfo &memory_info, float16 *data,
                       size_t size, const int64_t *shape, size_t shape_len) {
  return Ort::Value::CreateTensor(memory_info, static_cast<void *>(data),
                                  size * sizeof(float16), shape, shape_len,
                                  ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
}

// Binds the host buffer `data` as the ONNX Runtime input called `name_`.
//
// The element count is the product of the cached `shape_`, so Reshape has to
// be called first. The memory info is labelled "Cpu" when `place_` is kCPU
// and "Cuda" otherwise, with device id `device_`.
// NOTE(review): the Ort::Value presumably wraps `data` without copying, so
// the buffer must stay valid until the run has consumed it - confirm against
// the ONNX Runtime API.
// Raises PreconditionNotMet when the IoBinding is no longer alive.
template <typename T>
void Tensor::ORTCopyFromCpu(const T *data) {
  auto binding = binding_.lock();
  PADDLE_ENFORCE_NOT_NULL(binding,
                          paddle::platform::errors::PreconditionNotMet(
                              "input tensor [%s] no binding ptr", name_));
  const char *device_name = place_ == PlaceType::kCPU ? "Cpu" : "Cuda";
  Ort::MemoryInfo memory_info(device_name, OrtDeviceAllocator, device_,
                              OrtMemTypeDefault);
  // Seed with size_t so the accumulator is size_t on every platform; a `1UL`
  // seed makes it `unsigned long`, which is 32 bits wide on LLP64 targets.
  size_t size = std::accumulate(begin(shape_), end(shape_), size_t{1},
                                std::multiplies<size_t>());
  auto ort_value = GetOrtVaule(memory_info, const_cast<T *>(data), size,
                               shape_.data(), shape_.size());
  binding->BindInput(name_.c_str(), ort_value);
}

template <typename T>
void Tensor::ORTCopyToCpu(T *data) const {
  auto binding = binding_.lock();
  PADDLE_ENFORCE_NOT_NULL(binding,
                          paddle::platform::errors::PreconditionNotMet(
                              "output tensor [%s] no binding ptr", name_));
  std::vector<Ort::Value> outputs = binding->GetOutputValues();
  Ort::Value &value = outputs[idx_];
  auto info = value.GetTensorTypeAndShapeInfo();
  size_t size = info.GetElementCount() * sizeof(T);

  if (place_ == PlaceType::kCPU) {
    std::memcpy(static_cast<void *>(data), value.GetTensorData<void *>(), size);
  } else {
    paddle::memory::Copy(paddle::platform::CPUPlace(),
                         static_cast<void *>(data),
                         paddle::platform::CUDAPlace(device_),
                         value.GetTensorData<void>(), size, nullptr);
  }
}

// Explicit instantiations of Tensor::ORTCopyFromCpu, one per element type
// that has a GetOrtVaule overload.
template void Tensor::ORTCopyFromCpu<float>(const float *data);
template void Tensor::ORTCopyFromCpu<int64_t>(const int64_t *data);
template void Tensor::ORTCopyFromCpu<int32_t>(const int32_t *data);
template void Tensor::ORTCopyFromCpu<uint8_t>(const uint8_t *data);
template void Tensor::ORTCopyFromCpu<int8_t>(const int8_t *data);
template void Tensor::ORTCopyFromCpu<float16>(const float16 *data);

// Explicit instantiations of Tensor::ORTCopyToCpu. The list matches the
// element types instantiated for Tensor::CopyToCpu, which calls
// ORTCopyToCpu<T> for ONNX Runtime tensors; int64_t was previously missing.
template void Tensor::ORTCopyToCpu<float>(float *data) const;
template void Tensor::ORTCopyToCpu<int64_t>(int64_t *data) const;
template void Tensor::ORTCopyToCpu<int32_t>(int32_t *data) const;
template void Tensor::ORTCopyToCpu<uint8_t>(uint8_t *data) const;
template void Tensor::ORTCopyToCpu<int8_t>(int8_t *data) const;
template void Tensor::ORTCopyToCpu<float16>(float16 *data) const;
#endif

715
}  // namespace paddle_infer