zero_copy_tensor.cc 23.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/fluid/framework/convert_utils.h"
16
#include "paddle/fluid/framework/data_layout_transform.h"
17 18 19
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
W
Wilber 已提交
20
#include "paddle/fluid/inference/api/paddle_tensor.h"
N
nhzlx 已提交
21
#include "paddle/fluid/memory/memcpy.h"
22
#include "paddle/fluid/platform/enforce.h"
23
#include "paddle/fluid/platform/float16.h"
24
#include "paddle/phi/core/allocator.h"
25

26
namespace paddle_infer {
27

28 29
using float16 = paddle::platform::float16;

30
void Tensor::Reshape(const std::vector<int> &shape) {
W
Wilber 已提交
31 32
  PADDLE_ENFORCE_EQ(
      name_.empty(), false,
33
      paddle::platform::errors::PreconditionNotMet(
W
Wilber 已提交
34 35 36
          "Need to SetName first, so that the corresponding tensor can "
          "be retrieved."));
  PADDLE_ENFORCE_EQ(input_or_output_, true,
37
                    paddle::platform::errors::PermissionDenied(
W
Wilber 已提交
38
                        "Can't reshape the output tensor, it is readonly"));
39
  auto *scope = static_cast<paddle::framework::Scope *>(scope_);
40
  auto *var = scope->FindVar(name_);
W
Wilber 已提交
41
  PADDLE_ENFORCE_NOT_NULL(
42
      var, paddle::platform::errors::PreconditionNotMet(
W
Wilber 已提交
43
               "No tensor called [%s] in the runtime scope", name_));
44
  auto *tensor = var->GetMutable<paddle::framework::LoDTensor>();
45
  tensor->Resize(phi::make_ddim(shape));
46 47
}

S
Steffy-zxf 已提交
48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
void Tensor::ReshapeStrings(const size_t &shape) {
  PADDLE_ENFORCE_EQ(
      name_.empty(), false,
      paddle::platform::errors::PreconditionNotMet(
          "Need to SetName first, so that the corresponding tensor can "
          "be retrieved."));
  PADDLE_ENFORCE_EQ(input_or_output_, true,
                    paddle::platform::errors::PermissionDenied(
                        "Can't reshape the output tensor, it is readonly"));
  auto *scope = static_cast<paddle::framework::Scope *>(scope_);
  auto *var = scope->FindVar(name_);
  PADDLE_ENFORCE_NOT_NULL(
      var, paddle::platform::errors::PreconditionNotMet(
               "No tensor called [%s] in the runtime scope", name_));
  paddle_infer::Strings *tensor = var->GetMutable<paddle_infer::Strings>();
  tensor->resize(shape);
}

// Resolves the cached `tensor_` handle on first use by calling
// FindTensor<tensor_type>(), then declares a local `tensor` pointer typed as
// `tensor_type *`. The expansion already ends with a semicolon, so a trailing
// `;` at the call site is optional. Because it declares a local named
// `tensor`, it can be used at most once per scope.
#define EAGER_GET_TENSOR(tensor_type)    \
  if (!tensor_) {                        \
    tensor_ = FindTensor<tensor_type>(); \
  }                                      \
  auto *tensor = static_cast<tensor_type *>(tensor_);
71

72
template <typename T>
73
T *Tensor::mutable_data(PlaceType place) {
S
Steffy-zxf 已提交
74
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
75 76
  PADDLE_ENFORCE_GT(
      tensor->numel(), 0,
77 78
      paddle::platform::errors::PreconditionNotMet(
          "You should call Tensor::Reshape(const std::vector<int> "
W
Wilber 已提交
79 80
          "&shape)"
          "function before retrieving mutable_data from input tensor."));
81
  switch (static_cast<int>(place)) {
82 83
    case static_cast<int>(PlaceType::kCPU): {
      return tensor->mutable_data<T>(paddle::platform::CPUPlace());
84
    }
85 86 87 88 89
    case static_cast<int>(PlaceType::kGPU): {
      return tensor->mutable_data<T>(paddle::platform::CUDAPlace(device_));
    }
    case static_cast<int>(PlaceType::kXPU): {
      return tensor->mutable_data<T>(paddle::platform::XPUPlace(device_));
90
    }
91 92 93
    case static_cast<int>(PlaceType::kNPU): {
      return tensor->mutable_data<T>(paddle::platform::NPUPlace(device_));
    }
94
    default:
95
      PADDLE_THROW(paddle::platform::errors::Unavailable(
96 97
          "Only CPU / CUDA / XPU / NPU places is supported. The place `%d` is "
          "not supported.",
98
          static_cast<int>(place)));
99 100 101 102 103 104
      break;
  }
  return nullptr;
}

template <typename T>
105
T *Tensor::data(PlaceType *place, int *size) const {
S
Steffy-zxf 已提交
106
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
107 108
  auto *res = tensor->data<T>();

109 110 111 112 113 114
  if (paddle::platform::is_cpu_place(tensor->place())) {
    *place = PlaceType::kCPU;
  } else if (paddle::platform::is_gpu_place(tensor->place())) {
    *place = PlaceType::kGPU;
  } else if (paddle::platform::is_xpu_place(tensor->place())) {
    *place = PlaceType::kXPU;
115 116
  } else if (paddle::platform::is_npu_place(tensor->place())) {
    *place = PlaceType::kNPU;
117
  } else {
118
    *place = PlaceType::kUNK;
119 120 121 122 123 124
  }

  *size = tensor->numel();
  return res;
}

125
// Maps the tensor's dtype onto the inference-API DataType enum. Any dtype
// that has no dedicated mapping below is reported as FLOAT32.
DataType Tensor::type() const {
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
  switch (paddle::framework::TransToProtoVarType(tensor->dtype())) {
    case paddle::framework::proto::VarType::FP16:
      return DataType::FLOAT16;
    case paddle::framework::proto::VarType::INT64:
      return DataType::INT64;
    case paddle::framework::proto::VarType::INT32:
      return DataType::INT32;
    case paddle::framework::proto::VarType::UINT8:
      return DataType::UINT8;
    case paddle::framework::proto::VarType::INT8:
      return DataType::INT8;
    case paddle::framework::proto::VarType::FP32:
    default:
      return DataType::FLOAT32;
  }
}

144 145
// Returns the place type currently stored in `place_`.
PlaceType Tensor::place() const { return place_; }

N
nhzlx 已提交
146
template <typename T>
147
void Tensor::CopyFromCpu(const T *data) {
S
Steffy-zxf 已提交
148
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
W
Wilber 已提交
149
  PADDLE_ENFORCE_GE(tensor->numel(), 0,
150 151
                    paddle::platform::errors::PreconditionNotMet(
                        "You should call Tensor::Reshape(const "
W
Wilber 已提交
152 153
                        "std::vector<int> &shape)"
                        "function before copying data from cpu."));
N
nhzlx 已提交
154 155
  size_t ele_size = tensor->numel() * sizeof(T);

156 157
  if (place_ == PlaceType::kCPU) {
    auto *t_data = tensor->mutable_data<T>(paddle::platform::CPUPlace());
N
nhzlx 已提交
158
    std::memcpy(static_cast<void *>(t_data), data, ele_size);
159
  } else if (place_ == PlaceType::kGPU) {
160
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
161 162 163
    paddle::platform::DeviceContextPool &pool =
        paddle::platform::DeviceContextPool::Instance();
    paddle::platform::CUDAPlace gpu_place(device_);
N
nhzlx 已提交
164
    auto *t_data = tensor->mutable_data<T>(gpu_place);
165 166
    auto *dev_ctx = static_cast<const paddle::platform::CUDADeviceContext *>(
        pool.Get(gpu_place));
N
nhzlx 已提交
167

168 169 170
    paddle::memory::Copy(gpu_place, static_cast<void *>(t_data),
                         paddle::platform::CPUPlace(), data, ele_size,
                         dev_ctx->stream());
N
nhzlx 已提交
171
#else
172 173 174
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with CUDA place because paddle is not compiled "
        "with CUDA."));
N
nhzlx 已提交
175
#endif
176
  } else if (place_ == PlaceType::kXPU) {
177
#ifdef PADDLE_WITH_XPU
178
    paddle::platform::XPUPlace xpu_place(device_);
179
    auto *t_data = tensor->mutable_data<T>(xpu_place);
180 181
    paddle::memory::Copy(xpu_place, static_cast<void *>(t_data),
                         paddle::platform::CPUPlace(), data, ele_size);
182
#else
183 184 185
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with XPU place because paddle is not compiled "
        "with XPU."));
W
Wilber 已提交
186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201
#endif
  } else if (place_ == PlaceType::kNPU) {
#ifdef PADDLE_WITH_ASCEND_CL
    paddle::platform::DeviceContextPool &pool =
        paddle::platform::DeviceContextPool::Instance();
    paddle::platform::NPUPlace npu_place(device_);
    auto *t_data = tensor->mutable_data<T>(npu_place);
    auto *dev_ctx = static_cast<const paddle::platform::NPUDeviceContext *>(
        pool.Get(npu_place));
    paddle::memory::Copy(npu_place, static_cast<void *>(t_data),
                         paddle::platform::CPUPlace(), data, ele_size,
                         dev_ctx->stream());
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with NPU place because paddle is not compiled "
        "with NPU."));
202 203 204
#endif
  } else {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
W
Wilber 已提交
205
        "The analysis predictor supports CPU, GPU, NPU and XPU now."));
N
nhzlx 已提交
206 207 208
  }
}

209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275
// Compile-time mapping from a C++ element type to the matching
// paddle::experimental::DataType value. Only the specializations below are
// defined; using any other T fails to compile because the primary template
// is declared but never defined. `TYPE` is a non-static member, so it is read
// from a temporary instance, e.g. `DataTypeInfo<float>().TYPE`.
template <typename T>
struct DataTypeInfo;

// float -> FLOAT32
template <>
struct DataTypeInfo<float> {
  paddle::experimental::DataType TYPE = paddle::experimental::DataType::FLOAT32;
};

// float16 -> FLOAT16
template <>
struct DataTypeInfo<float16> {
  paddle::experimental::DataType TYPE = paddle::experimental::DataType::FLOAT16;
};

// int64_t -> INT64
template <>
struct DataTypeInfo<int64_t> {
  paddle::experimental::DataType TYPE = paddle::experimental::DataType::INT64;
};

// int8_t -> INT8
template <>
struct DataTypeInfo<int8_t> {
  paddle::experimental::DataType TYPE = paddle::experimental::DataType::INT8;
};

// uint8_t -> UINT8
template <>
struct DataTypeInfo<uint8_t> {
  paddle::experimental::DataType TYPE = paddle::experimental::DataType::UINT8;
};

// int32_t -> INT32
template <>
struct DataTypeInfo<int32_t> {
  paddle::experimental::DataType TYPE = paddle::experimental::DataType::INT32;
};

// Converts the inference-API layout enum to paddle::experimental::DataLayout.
// Only DataLayout::kNCHW is accepted; any other value fails the enforce check
// with an InvalidArgument error.
paddle::experimental::DataLayout LayoutConvert(DataLayout layout) {
  PADDLE_ENFORCE_EQ(
      layout, DataLayout::kNCHW,
      paddle::platform::errors::InvalidArgument("Only NCHW is supported now."));
  return paddle::experimental::DataLayout::NCHW;
}

// Makes this tensor reference caller-owned memory instead of copying it.
//
// data   - pointer to the external buffer; it is wrapped as-is in an
//          allocation object, and nothing in this function copies it.
// shape  - dimensions of the tensor; the byte size of the buffer is derived
//          as product(shape) * sizeof(T).
// place  - where `data` lives; only PlaceType::kCPU and PlaceType::kGPU are
//          accepted (GPU uses the device index in `device_`), anything else
//          raises InvalidArgument.
// layout - must be DataLayout::kNCHW (checked by LayoutConvert).
//
// NOTE(review): the allocation wraps the raw pointer without any visible
// deleter, so the caller presumably has to keep `data` alive for as long as
// the tensor is used — confirm against phi::Allocation.
template <typename T>
void Tensor::ShareExternalData(const T *data, const std::vector<int> &shape,
                               PlaceType place, DataLayout layout) {
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
  // Accumulate in 64 bits: with an `int` accumulator the element count wraps
  // for tensors holding 2^31 or more elements, yielding a bogus byte size.
  const int64_t num_elements =
      std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1),
                      std::multiplies<int64_t>());
  const size_t size = static_cast<size_t>(num_elements) * sizeof(T);
  phi::DenseTensorMeta meta(DataTypeInfo<T>().TYPE, phi::make_ddim(shape),
                            LayoutConvert(layout));
  if (place == PlaceType::kCPU) {
    phi::DenseTensor dtensor(
        std::make_shared<phi::Allocation>(const_cast<T *>(data), size,
                                          paddle::platform::CPUPlace()),
        meta);
    *tensor = std::move(dtensor);
  } else if (place == PlaceType::kGPU) {
    phi::DenseTensor dtensor(
        std::make_shared<phi::Allocation>(const_cast<T *>(data), size,
                                          paddle::platform::CUDAPlace(device_)),
        meta);
    *tensor = std::move(dtensor);
  } else {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "PlaceType must be PlaceType::kCPU or PlaceType::kGPU."));
  }
}

S
Steffy-zxf 已提交
276 277 278 279 280 281 282 283 284 285
// Replaces the contents of the string tensor with a copy of `*data`.
// `data` must not be null; a null pointer now raises InvalidArgument instead
// of being dereferenced.
void Tensor::CopyStringsFromCpu(const paddle_infer::Strings *data) {
  PADDLE_ENFORCE_NOT_NULL(
      data, paddle::platform::errors::InvalidArgument(
                "The string data to copy from cpu should not be nullptr."));
  EAGER_GET_TENSOR(paddle_infer::Strings);
  // NOTE(review): if Strings::size() returns an unsigned type this check can
  // never fail — confirm whether a stricter precondition was intended.
  PADDLE_ENFORCE_GE(tensor->size(), 0,
                    paddle::platform::errors::PreconditionNotMet(
                        "You should call Tensor::ReshapeStrings(const "
                        "std::size_t &shape) function before copying "
                        "the string data from cpu."));
  *tensor = *data;
}

N
nhzlx 已提交
286
template <typename T>
287 288
void Tensor::CopyToCpuImpl(T *data, void *exec_stream, CallbackFunc cb,
                           void *cb_params) const {
S
Steffy-zxf 已提交
289
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
N
nhzlx 已提交
290 291 292 293
  auto ele_num = tensor->numel();
  auto *t_data = tensor->data<T>();
  auto t_place = tensor->place();

294
  paddle::framework::Tensor out;
295 296 297 298
  auto mem_allocation =
      std::make_shared<paddle::memory::allocation::Allocation>(
          static_cast<void *>(data), ele_num * sizeof(T),
          paddle::platform::CPUPlace());
299 300
  out.ResetHolder(mem_allocation);

301
  if (paddle::platform::is_cpu_place(t_place)) {
302 303 304 305 306 307 308 309 310
#ifdef PADDLE_WITH_MKLDNN
    if (tensor->layout() == paddle::framework::DataLayout::kMKLDNN)
      paddle::framework::innerTransDataLayoutFromMKLDNN(
          tensor->layout(), paddle::platform::MKLDNNDeviceContext::tls()
                                .get_cur_paddle_data_layout(),
          *tensor, &out, paddle::platform::CPUPlace(), true);
    else
      std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
#else
N
nhzlx 已提交
311
    std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
J
jianghaicheng 已提交
312 313 314 315 316 317 318 319
#endif
  } else if (paddle::platform::is_ipu_place(t_place)) {
#ifdef PADDLE_WITH_IPU
    std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with IPU place because paddle is not compiled "
        "with IPU."));
320
#endif
321
  } else if (place_ == PlaceType::kGPU) {
322
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
323 324
    paddle::platform::DeviceContextPool &pool =
        paddle::platform::DeviceContextPool::Instance();
325
    auto gpu_place = t_place;
326 327 328 329 330
    auto *dev_ctx = static_cast<const paddle::platform::CUDADeviceContext *>(
        pool.Get(gpu_place));
    paddle::memory::Copy(paddle::platform::CPUPlace(),
                         static_cast<void *>(data), gpu_place, t_data,
                         ele_num * sizeof(T), dev_ctx->stream());
331 332 333
#ifdef PADDLE_WITH_HIP
    hipStreamSynchronize(dev_ctx->stream());
#else
334 335 336 337 338 339 340 341 342 343
    // async, return stream
    if (nullptr != exec_stream) {
      *(static_cast<cudaStream_t *>(exec_stream)) = dev_ctx->stream();
      // async with callback
    } else if (cb) {
      cudaLaunchHostFunc(dev_ctx->stream(), cb, cb_params);
      // sync
    } else {
      cudaStreamSynchronize(dev_ctx->stream());
    }
344
#endif
N
nhzlx 已提交
345
#else
346 347 348
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with CUDA place because paddle is not compiled "
        "with CUDA."));
N
nhzlx 已提交
349
#endif
350
  } else if (place_ == PlaceType::kXPU) {
351
#ifdef PADDLE_WITH_XPU
352
    auto xpu_place = t_place;
353 354 355
    paddle::memory::Copy(paddle::platform::CPUPlace(),
                         static_cast<void *>(data), xpu_place, t_data,
                         ele_num * sizeof(T));
356
#else
357 358 359
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with XPU place because paddle is not compiled "
        "with XPU."));
W
Wilber 已提交
360 361 362 363 364
#endif
  } else if (place_ == PlaceType::kNPU) {
#ifdef PADDLE_WITH_ASCEND_CL
    paddle::platform::DeviceContextPool &pool =
        paddle::platform::DeviceContextPool::Instance();
365
    auto npu_place = t_place;
W
Wilber 已提交
366 367 368 369 370
    auto *dev_ctx = static_cast<const paddle::platform::NPUDeviceContext *>(
        pool.Get(npu_place));
    paddle::memory::Copy(paddle::platform::CPUPlace(),
                         static_cast<void *>(data), npu_place, t_data,
                         ele_num * sizeof(T), dev_ctx->stream());
371
    paddle::platform::NPUStreamSync(dev_ctx->stream());
W
Wilber 已提交
372 373 374 375
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Can not create tensor with NPU place because paddle is not compiled "
        "with NPU."));
376 377 378
#endif
  } else {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
W
Wilber 已提交
379
        "The analysis predictor supports CPU, GPU, NPU and XPU now."));
N
nhzlx 已提交
380 381
  }
}
382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397

// Copy to host with no stream out-parameter and no callback; forwards to
// CopyToCpuImpl with all optional arguments set to nullptr.
template <typename T>
void Tensor::CopyToCpu(T *data) const {
  CopyToCpuImpl<T>(data, nullptr, nullptr, nullptr);
}

// Copy to host, passing `exec_stream` through to CopyToCpuImpl and leaving
// the callback unset.
template <typename T>
void Tensor::CopyToCpuAsync(T *data, void *exec_stream) const {
  CopyToCpuImpl<T>(data, exec_stream, nullptr, nullptr);
}

// Copy to host, passing the callback `cb` and its argument `cb_params`
// through to CopyToCpuImpl and leaving the stream out-parameter unset.
template <typename T>
void Tensor::CopyToCpuAsync(T *data, CallbackFunc cb, void *cb_params) const {
  CopyToCpuImpl<T>(data, nullptr, cb, cb_params);
}

398 399 400 401 402
// Explicit instantiations of CopyFromCpu for the supported element types.
template PD_INFER_DECL void Tensor::CopyFromCpu<float>(const float *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<int64_t>(const int64_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<int32_t>(const int32_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<uint8_t>(const uint8_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<int8_t>(const int8_t *data);
403
// Explicit instantiation of CopyFromCpu for half-precision data.
template PD_INFER_DECL void Tensor::CopyFromCpu<float16>(const float16 *data);
404

405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423
// Explicit instantiations of ShareExternalData, one per element type that
// has a DataTypeInfo specialization.
template PD_INFER_DECL void Tensor::ShareExternalData<float>(
    const float *data, const std::vector<int> &shape, PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<int64_t>(
    const int64_t *data, const std::vector<int> &shape, PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<int32_t>(
    const int32_t *data, const std::vector<int> &shape, PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<uint8_t>(
    const uint8_t *data, const std::vector<int> &shape, PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<int8_t>(
    const int8_t *data, const std::vector<int> &shape, PlaceType place,
    DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<float16>(
    const float16 *data, const std::vector<int> &shape, PlaceType place,
    DataLayout layout);

424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470
// Explicit instantiations of the synchronous CopyToCpu entry point.
template PD_INFER_DECL void Tensor::CopyToCpu<float>(float *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<int64_t>(int64_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<int32_t>(int32_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<uint8_t>(uint8_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<int8_t>(int8_t *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<float16>(float16 *data) const;

// Explicit instantiations of the shared implementation CopyToCpuImpl.
template PD_INFER_DECL void Tensor::CopyToCpuImpl<float>(float *data,
                                                         void *exec_stream,
                                                         CallbackFunc cb,
                                                         void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<int64_t>(
    int64_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<int32_t>(
    int32_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<uint8_t>(
    uint8_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<int8_t>(
    int8_t *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuImpl<float16>(
    float16 *data, void *exec_stream, CallbackFunc cb, void *cb_params) const;

// Explicit instantiations of the CopyToCpuAsync overload that takes a stream
// out-parameter.
template PD_INFER_DECL void Tensor::CopyToCpuAsync<float>(
    float *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int64_t>(
    int64_t *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int32_t>(
    int32_t *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<uint8_t>(
    uint8_t *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int8_t>(
    int8_t *data, void *exec_stream) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<float16>(
    float16 *data, void *exec_stream) const;

// Explicit instantiations of the CopyToCpuAsync overload that takes a
// completion callback.
template PD_INFER_DECL void Tensor::CopyToCpuAsync<float>(
    float *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int64_t>(
    int64_t *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int32_t>(
    int32_t *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<uint8_t>(
    uint8_t *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<int8_t>(
    int8_t *data, CallbackFunc cb, void *cb_params) const;
template PD_INFER_DECL void Tensor::CopyToCpuAsync<float16>(
    float16 *data, CallbackFunc cb, void *cb_params) const;
471

472 473 474 475 476 477 478 479 480 481
// Explicit instantiations of the read accessor Tensor::data.
template PD_INFER_DECL float *Tensor::data<float>(PlaceType *place,
                                                  int *size) const;
template PD_INFER_DECL int64_t *Tensor::data<int64_t>(PlaceType *place,
                                                      int *size) const;
template PD_INFER_DECL int32_t *Tensor::data<int32_t>(PlaceType *place,
                                                      int *size) const;
template PD_INFER_DECL uint8_t *Tensor::data<uint8_t>(PlaceType *place,
                                                      int *size) const;
template PD_INFER_DECL int8_t *Tensor::data<int8_t>(PlaceType *place,
                                                    int *size) const;
482 483
// Explicit instantiation of Tensor::data for half-precision data.
template PD_INFER_DECL float16 *Tensor::data<float16>(PlaceType *place,
                                                      int *size) const;
484

485 486 487 488 489
// Explicit instantiations of the write accessor Tensor::mutable_data.
template PD_INFER_DECL float *Tensor::mutable_data<float>(PlaceType place);
template PD_INFER_DECL int64_t *Tensor::mutable_data<int64_t>(PlaceType place);
template PD_INFER_DECL int32_t *Tensor::mutable_data<int32_t>(PlaceType place);
template PD_INFER_DECL uint8_t *Tensor::mutable_data<uint8_t>(PlaceType place);
template PD_INFER_DECL int8_t *Tensor::mutable_data<int8_t>(PlaceType place);
490
// Explicit instantiation of Tensor::mutable_data for half-precision data.
template PD_INFER_DECL float16 *Tensor::mutable_data<float16>(PlaceType place);
491

492 493 494 495 496 497
// Constructs a tensor handle bound to `scope`, a type-erased pointer that the
// other members cast to paddle::framework::Scope*. A null scope is rejected
// with PreconditionNotMet.
Tensor::Tensor(void *scope) : scope_{scope} {
  PADDLE_ENFORCE_NOT_NULL(scope_,
                          paddle::platform::errors::PreconditionNotMet(
                              "The `scope` can not be nullptr. It should be "
                              "set to the pointer of scope."));
}
498

S
Steffy-zxf 已提交
499
template <typename T>
500
void *Tensor::FindTensor() const {
W
Wilber 已提交
501 502
  PADDLE_ENFORCE_EQ(
      name_.empty(), false,
503
      paddle::platform::errors::PreconditionNotMet(
W
Wilber 已提交
504 505
          "Need to SetName first, so that the corresponding tensor can "
          "be retrieved."));
506
  auto *scope = static_cast<paddle::framework::Scope *>(scope_);
507
  auto *var = scope->FindVar(name_);
W
Wilber 已提交
508
  PADDLE_ENFORCE_NOT_NULL(
509
      var, paddle::platform::errors::PreconditionNotMet(
W
Wilber 已提交
510
               "No tensor called [%s] in the runtime scope", name_));
S
Steffy-zxf 已提交
511
  auto *tensor = var->GetMutable<T>();
512 513 514
  return tensor;
}

515
std::vector<int> Tensor::shape() const {
S
Steffy-zxf 已提交
516
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
W
Wilber 已提交
517
  PADDLE_ENFORCE_NOT_NULL(
518
      tensor_, paddle::platform::errors::PreconditionNotMet(
W
Wilber 已提交
519
                   "Not found tensor called %s in the scope", name_));
W
wenbin 已提交
520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535
// mkldnn may does layout transform internally, so need to reorder before
// return
#ifdef PADDLE_WITH_MKLDNN
  if (tensor->layout() == paddle::framework::DataLayout::kMKLDNN) {
    paddle::framework::DataLayout out_layout =
        paddle::platform::MKLDNNDeviceContext::tls()
            .get_cur_paddle_data_layout();
    // Set default as NCHW in case not specified
    out_layout = out_layout == paddle::framework::DataLayout::kAnyLayout
                     ? paddle::framework::DataLayout::kNCHW
                     : out_layout;
    // In these data layouts, channel dimension is either on 2nd position: nChw
    // or
    // at last nhwC, so for dim==2 these layouts are the same and nothing should
    // be done. Similarly for dim==1 when you have just one possible
    // combination.
536
    if (tensor->dims().size() < 3) return phi::vectorize<int>(tensor->dims());
W
wenbin 已提交
537
    if (out_layout == paddle::framework::DataLayout::kNHWC) {
538
      auto dims = phi::vectorize<int>(tensor->dims());
W
wenbin 已提交
539 540 541
      std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end());
      return dims;
    } else {
542
      return phi::vectorize<int>(tensor->dims());
W
wenbin 已提交
543 544 545
    }
  }
#endif
546
  return phi::vectorize<int>(tensor->dims());
547 548
}

549
void Tensor::SetLoD(const std::vector<std::vector<size_t>> &x) {
S
Steffy-zxf 已提交
550
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
551
  paddle::framework::LoD lod;
552 553 554 555 556 557
  for (auto &level : x) {
    lod.emplace_back(level);
  }
  tensor->set_lod(lod);
}

558
std::vector<std::vector<size_t>> Tensor::lod() const {
S
Steffy-zxf 已提交
559
  EAGER_GET_TENSOR(paddle::framework::LoDTensor);
560 561 562 563 564 565 566
  std::vector<std::vector<size_t>> res;
  for (auto &level : tensor->lod()) {
    res.emplace_back(level);
  }
  return res;
}

567 568 569 570 571 572 573 574 575 576
// Stores the variable name used for later lookups in the scope.
void Tensor::SetName(const std::string &name) { name_ = name; }

// Returns a reference to the stored variable name.
const std::string &Tensor::name() const { return name_; }

// Records the place type and the device index that the copy and allocation
// members read through `place_` and `device_`.
void Tensor::SetPlace(PlaceType place, int device) {
  place_ = place;
  device_ = device;
}

}  // namespace paddle_infer