/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <Python.h>
// Avoid a problem with copysign defined in pyconfig.h on Windows.
#ifdef copysign
#undef copysign
#endif

#include <algorithm>
#include <memory>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/pybind/complex.h"
#include "paddle/phi/kernels/funcs/strided_memcpy.h"
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/phi/common/pstring.h"
#include "paddle/phi/core/string_tensor.h"
#include "paddle/phi/kernels/strings/unicode.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"

namespace py = pybind11;

namespace pybind11 {
namespace detail {

// Note: use same enum number of float16 in numpy.
// import numpy as np
// print np.dtype(np.float16).num  # 23
constexpr int NPY_FLOAT16_ = 23;
constexpr int NPY_UINT16_ = 4;
constexpr int NPY_COMPLEX64 = 14;
constexpr int NPY_COMPLEX128 = 15;
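// These enum values can be cross-checked from Python (illustrative):
//   import numpy as np
//   np.dtype(np.uint16).num      # 4
//   np.dtype(np.complex64).num   # 14
//   np.dtype(np.complex128).num  # 15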

template <typename T, typename S>
struct casting_complex_to_non_complex {
  static const bool value = pybind11::detail::is_complex<S>::value &&
                            !pybind11::detail::is_complex<T>::value;
};

// cast numpy type from S to T, this may allocate new memory
template <
    class T,
    class S,
    std::enable_if_t<!std::is_same<T, S>::value &&
                     !casting_complex_to_non_complex<T, S>::value> * = nullptr>
static py::array_t<T> CastNumpyType(py::array_t<S> array) {
  auto dim = array.ndim();
  std::vector<py::ssize_t> result_shape(dim);
  for (auto i = 0; i < dim; i++) {
    result_shape[i] = array.shape(i);
  }

  py::array_t<T> result(result_shape);

  return py::vectorize([](S s) { return static_cast<T>(s); })(array);
}

template <
    class T,
    class S,
    std::enable_if_t<(!std::is_same<T, S>::value) &&
                     casting_complex_to_non_complex<T, S>::value> * = nullptr>
static py::array_t<T> CastNumpyType(py::array_t<S> array) {
  auto dim = array.ndim();
  std::vector<py::ssize_t> result_shape(dim);
  for (auto i = 0; i < dim; i++) {
    result_shape[i] = array.shape(i);
  }

  py::array_t<T> result(result_shape);

  return py::vectorize([](S s) { return static_cast<T>(s.real()); })(array);
}

template <class T,
          class S,
          std::enable_if_t<std::is_same<T, S>::value> * = nullptr>
static py::array_t<T> CastNumpyType(py::array_t<S> array) {
  return array;
}

template <class T>
static py::array_t<T> CastNumpyArray(const py::object &array) {
  if (py::isinstance<py::array_t<float>>(array)) {
    return CastNumpyType<T>(array.cast<py::array_t<float>>());
  } else if (py::isinstance<py::array_t<double>>(array)) {
    return CastNumpyType<T>(array.cast<py::array_t<double>>());
  } else if (py::isinstance<py::array_t<int32_t>>(array)) {
    return CastNumpyType<T>(array.cast<py::array_t<int32_t>>());
  } else if (py::isinstance<py::array_t<int64_t>>(array)) {
    return CastNumpyType<T>(array.cast<py::array_t<int64_t>>());
  } else if (py::isinstance<py::array_t<bool>>(array)) {
    return CastNumpyType<T>(array.cast<py::array_t<bool>>());
  } else if (py::isinstance<py::array_t<std::complex<float>>>(array)) {
    return CastNumpyType<T>(array.cast<py::array_t<std::complex<float>>>());
  } else if (py::isinstance<py::array_t<std::complex<double>>>(array)) {
    return CastNumpyType<T>(array.cast<py::array_t<std::complex<double>>>());
  } else {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "Value type error. The assigned numpy value must be integer, float, "
        "double, complex64, complex128, or bool, "
        "but received %s.",
        Py_TYPE(array.ptr())->tp_name));
  }
  // can't reach here
  return py::array_t<T>();
}
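
// Illustrative use of CastNumpyArray (hypothetical call site): normalizing an
// arbitrary numpy object to float64 before assignment; complex inputs are
// narrowed to their real part by the overloads above.
//   py::array_t<double> f64 = CastNumpyArray<double>(obj);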

// Note: Since float16 is not a builtin type in C++, we register
// paddle::platform::float16 as numpy.float16.
// Ref: https://github.com/pybind/pybind11/issues/1776
template <>
struct npy_format_descriptor<paddle::platform::float16> {
  static py::dtype dtype() {
    handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16_);
    return reinterpret_borrow<py::dtype>(ptr);
  }
  static std::string format() {
    // Note: "e" represents float16.
    // Details at:
    // https://docs.python.org/3/library/struct.html#format-characters.
    return "e";
  }
  static constexpr auto name = _("float16");
};

// Note: Since bfloat16 is not a builtin type in C++ and in numpy,
// we register paddle::platform::bfloat16 as numpy.uint16.
template <>
struct npy_format_descriptor<paddle::platform::bfloat16> {
  static py::dtype dtype() {
    handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_UINT16_);
    return reinterpret_borrow<py::dtype>(ptr);
  }
  static std::string format() {
    // Note: "H" represents UINT16.
    // Details at:
    // https://docs.python.org/3/library/struct.html#format-characters.
    return "H";
  }
  static constexpr auto name = _("bfloat16");
};
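
// Note: as a consequence of the registration above, a bfloat16 tensor
// surfaces in Python as a plain uint16 ndarray; interpreting the raw bits as
// bfloat16 is left to the caller. This is an observation, not an API
// guarantee.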

// we register paddle::platform::complex<float> as numpy.complex64.
template <>
struct npy_format_descriptor<paddle::platform::complex<float>> {
  static py::dtype dtype() {
    handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_COMPLEX64);
    return reinterpret_borrow<py::dtype>(ptr);
  }

  static std::string format() {
    // Note: "F" represents complex64.
    // Details at:
    // https://stackoverflow.com/questions/13997087/what-are-the-available-datatypes-for-dtype-with-numpys-loadtxt-an-genfromtx
    // for k, v in np.sctypeDict.iteritems():
    //     print '{0:14s} : {1:40s}'.format(str(k), v)
    return "F";
  }
  static constexpr auto name = _("complex64");
};

template <>
struct npy_format_descriptor<paddle::platform::complex<double>> {
  static py::dtype dtype() {
    handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_COMPLEX128);
    return reinterpret_borrow<py::dtype>(ptr);
  }

  static std::string format() {
    // Note: "D" represents complex128.
    // Details at:
    // https://stackoverflow.com/questions/13997087/what-are-the-available-datatypes-for-dtype-with-numpys-loadtxt-an-genfromtx
    // for k, v in np.sctypeDict.iteritems():
    //     print '{0:14s} : {1:40s}'.format(str(k), v)
    return "D";
  }
  static constexpr auto name = _("complex128");
};

}  // namespace detail
}  // namespace pybind11

namespace paddle {
namespace pybind {

namespace details {

template <typename T>
class PYBIND11_HIDDEN NumpyAllocation : public memory::Allocation {
 public:
  explicit NumpyAllocation(const py::array &arr)
      : Allocation(const_cast<void *>(arr.data()),
                   sizeof(T) * (arr.size()),
                   paddle::platform::CPUPlace()),
        arr_(arr.ptr()) {
    PADDLE_ENFORCE_NOT_NULL(
        arr_,
        platform::errors::InvalidArgument("The underlying PyObject pointer of "
                                          "numpy array cannot be nullptr"));
    PADDLE_ENFORCE_NE(
        arr_,
        Py_None,
        platform::errors::PreconditionNotMet(
            "The underlying PyObject pointer of numpy array cannot be None"));
    Py_INCREF(arr_);
  }
  ~NumpyAllocation() override {
    py::gil_scoped_acquire gil;
    Py_DECREF(arr_);
  }

 private:
  PyObject *arr_;
};

template <typename T>
struct ValidDTypeToPyArrayChecker {
  static constexpr bool kValue = false;
};

#define DECLARE_VALID_DTYPE_TO_PY_ARRAY(type) \
  template <>                                 \
  struct ValidDTypeToPyArrayChecker<type> {   \
    static constexpr bool kValue = true;      \
  }

DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::float16);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::bfloat16);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::complex<float>);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::complex<double>);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(float);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(double);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(bool);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(int8_t);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(int16_t);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(int);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(int64_t);
DECLARE_VALID_DTYPE_TO_PY_ARRAY(uint8_t);

inline std::string TensorDTypeToPyDTypeStr(
    framework::proto::VarType::Type type) {
#define TENSOR_DTYPE_TO_PY_DTYPE(T, proto_type)                             \
  if (type == proto_type) {                                                 \
    if (std::is_same<T, platform::float16>::value) {                        \
      return "e";                                                           \
    } else if (std::is_same<T, platform::bfloat16>::value) {                \
      /* NumPy character code of uint16 due to no support for bfloat16 */   \
      return "H";                                                           \
    } else if (std::is_same<T, platform::complex<float>>::value) {          \
      return "F";                                                           \
    } else if (std::is_same<T, platform::complex<double>>::value) {         \
      return "D";                                                           \
    } else {                                                                \
      constexpr auto kIsValidDType = ValidDTypeToPyArrayChecker<T>::kValue; \
      PADDLE_ENFORCE_EQ(                                                    \
          kIsValidDType,                                                    \
          true,                                                             \
          platform::errors::Unimplemented(                                  \
              "This type [%s] of tensor cannot be exposed to Python",       \
              typeid(T).name()));                                           \
      return py::format_descriptor<T>::format();                            \
    }                                                                       \
  }

  _ForEachDataType_(TENSOR_DTYPE_TO_PY_DTYPE);
#undef TENSOR_DTYPE_TO_PY_DTYPE
  PADDLE_THROW(platform::errors::Unimplemented(
      "Unsupported tensor data type: %s", framework::DataTypeToString(type)));
}
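
// Summary of the buffer-protocol format characters produced above: "e" for
// float16, "H" (uint16) standing in for bfloat16, "F"/"D" for complex64 and
// complex128, and py::format_descriptor<T>::format() for the remaining
// builtin types.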

}  // namespace details

template <typename T>
T TensorGetElement(const phi::DenseTensor &self, size_t offset) {
  PADDLE_ENFORCE_LT(offset,
                    self.numel(),
                    platform::errors::InvalidArgument(
                        "The offset exceeds the size of tensor."));

  T b = static_cast<T>(0);
  if (platform::is_cpu_place(self.place()) ||
      platform::is_cuda_pinned_place(self.place())) {
    b = self.data<T>()[offset];
  } else if (platform::is_xpu_place(self.place())) {
#ifdef PADDLE_WITH_XPU
    const T *a = self.data<T>();
    auto p = self.place();
    paddle::memory::Copy(platform::CPUPlace(), &b, p, a + offset, sizeof(T));
#endif
  } else if (platform::is_gpu_place(self.place()) ||
             platform::is_cuda_pinned_place(self.place())) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    const T *a = self.data<T>();
    auto p = self.place();
    paddle::memory::Copy(
        platform::CPUPlace(), &b, p, a + offset, sizeof(T), nullptr);
#endif
  } else if (platform::is_custom_place(self.place())) {
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
    const T *a = self.data<T>();
    auto p = self.place();
    paddle::memory::Copy(
        platform::CPUPlace(), &b, p, a + offset, sizeof(T), nullptr);
#endif
  }
  VLOG(10) << "TensorGetElement, place: " << self.place()
           << ", offset: " << offset << ", element: " << b;
  return b;
}

template <typename T>
void TensorSetElement(phi::DenseTensor *self, size_t offset, T elem) {
  PADDLE_ENFORCE_LT(offset,
                    self->numel(),
                    platform::errors::InvalidArgument(
                        "The offset exceeds the size of tensor."));
  VLOG(10) << "TensorSetElement, place: " << self->place()
           << ", offset: " << offset << ", element: " << elem;
  if (platform::is_cpu_place(self->place())) {
    self->mutable_data<T>(self->place())[offset] = elem;
  } else if (platform::is_xpu_place(self->place())) {
#ifdef PADDLE_WITH_XPU
    auto p = self->place();
    T *a = self->mutable_data<T>(p);
    paddle::memory::Copy(p, a + offset, platform::CPUPlace(), &elem, sizeof(T));
#endif
  } else if (platform::is_gpu_place(self->place()) ||
             platform::is_cuda_pinned_place(self->place())) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    auto p = self->place();
    T *a = self->mutable_data<T>(p);
    paddle::memory::Copy(
        p, a + offset, platform::CPUPlace(), &elem, sizeof(T), nullptr);
#endif
  } else if (platform::is_custom_place(self->place())) {
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
    auto p = self->place();
    T *a = self->mutable_data<T>(p);
    paddle::memory::Copy(
        p, a + offset, platform::CPUPlace(), &elem, sizeof(T), nullptr);
#endif
  }
}
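
// Illustrative round trip (hypothetical usage, CPU place):
//   phi::DenseTensor t;
//   t.Resize(phi::make_ddim({4}));
//   t.mutable_data<float>(platform::CPUPlace());
//   TensorSetElement<float>(&t, 2, 3.14f);
//   float v = TensorGetElement<float>(t, 2);  // v == 3.14f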

template <typename T, typename P>
void SetTensorFromPyArrayT(
    phi::DenseTensor *self,
    const py::array_t<T, py::array::c_style | py::array::forcecast> &array,
    const P &place,
    bool zero_copy) {
  std::vector<int64_t> dims;
  dims.reserve(array.ndim());
  for (decltype(array.ndim()) i = 0; i < array.ndim(); ++i) {
    dims.push_back(static_cast<int64_t>(array.shape()[i]));
  }
  self->Resize(phi::make_ddim(dims));

  if (paddle::platform::is_cpu_place(place)) {
    if (zero_copy) {
      auto holder = std::make_shared<details::NumpyAllocation<T>>(array);
      auto type = framework::ToDataType(std::type_index(typeid(T)));
      self->ResetHolderWithType(holder, framework::TransToPhiDataType(type));
    } else {
      auto dst = self->mutable_data<T>(place);
      std::memcpy(dst, array.data(), array.nbytes());
    }
  } else if (paddle::platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
    // NOTE(wangxi): When copying data to the accelerator card,
    // we need set_device(dev_id) first.
    platform::Place tmp_place = place;
    platform::XPUDeviceGuard guard(tmp_place.device);
    auto dst = self->mutable_data<T>(place);
    memory::Copy(tmp_place,
                 static_cast<void *>(dst),
                 platform::CPUPlace(),
                 static_cast<const void *>(array.data()),
                 array.nbytes());
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Cannot use XPUPlace in CPU/GPU version, "
        "Please recompile or reinstall Paddle with XPU support."));
#endif
  } else if (paddle::platform::is_ipu_place(place)) {
#ifdef PADDLE_WITH_IPU
    if (zero_copy) {
      auto holder = std::make_shared<details::NumpyAllocation<T>>(array);
      auto type = framework::ToDataType(std::type_index(typeid(T)));
      self->ResetHolderWithType(holder, framework::TransToPhiDataType(type));
    } else {
      // IPU does not store Tensor data, Tensor will be created on CPU
      if (!self->initialized()) {
        auto dst = self->mutable_data<T>(place);
        std::memcpy(dst, array.data(), array.nbytes());
      } else {
        auto dst = self->mutable_data<T>(self->place());
        std::memcpy(dst, array.data(), array.nbytes());
      }
    }
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Cannot use IPUPlace in CPU/GPU/XPU/NPU version, "
        "Please recompile or reinstall Paddle with IPU support."));
#endif
  } else if (paddle::platform::is_custom_place(place)) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
    platform::Place tmp_place = place;
    phi::DeviceGuard guard(tmp_place);
    auto dst = self->mutable_data<T>(place);

    phi::DeviceManager::GetDeviceWithPlace(tmp_place)->MemoryCopyH2D(
        reinterpret_cast<void *>(dst),
        const_cast<void *>(reinterpret_cast<const void *>(array.data())),
        array.nbytes());
    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
    auto &ctx = *pool.Get(place);
    ctx.Wait();
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Cannot use CustomDevice in CPU/GPU/XPU version. "
        "Please recompile or reinstall Paddle with CustomDevice support."));
#endif
  } else {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    if (paddle::platform::is_gpu_place(place)) {
      // NOTE(wangxi): When copying data to the accelerator card,
      // we need set_device(dev_id) first.
      platform::CUDADeviceGuard guard(place.device);
      auto dst = self->mutable_data<T>(place);
#ifdef PADDLE_WITH_HIP
      paddle::platform::GpuMemcpySync(
          dst, array.data(), array.nbytes(), hipMemcpyHostToDevice);
#else
      paddle::platform::GpuMemcpySync(
          dst, array.data(), array.nbytes(), cudaMemcpyHostToDevice);
#endif

    } else if (paddle::platform::is_cuda_pinned_place(place)) {
      auto dst = self->mutable_data<T>(place);
      std::memcpy(dst, array.data(), array.nbytes());
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Incompatible place type: Tensor.set() supports "
          "CPUPlace, CUDAPlace "
          "and CUDAPinnedPlace, but got %s!",
          place));
    }
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Cannot use CUDAPlace or CUDAPinnedPlace in CPU only version, "
        "Please recompile or reinstall Paddle with CUDA support."));
#endif
  }
}

template <typename P>
void SetTensorFromPyArray(phi::DenseTensor *self,
                          const py::object &obj,
                          const P &place,
                          bool zero_copy) {
  auto array = obj.cast<py::array>();
  if (py::isinstance<py::array_t<float>>(array)) {
    SetTensorFromPyArrayT<float, P>(self, array, place, zero_copy);
  } else if (py::isinstance<py::array_t<int>>(array)) {
    SetTensorFromPyArrayT<int, P>(self, array, place, zero_copy);
  } else if (py::isinstance<py::array_t<int64_t>>(array)) {
    SetTensorFromPyArrayT<int64_t, P>(self, array, place, zero_copy);
  } else if (py::isinstance<py::array_t<double>>(array)) {
    SetTensorFromPyArrayT<double, P>(self, array, place, zero_copy);
  } else if (py::isinstance<py::array_t<int8_t>>(array)) {
    SetTensorFromPyArrayT<int8_t, P>(self, array, place, zero_copy);
  } else if (py::isinstance<py::array_t<int16_t>>(array)) {
    SetTensorFromPyArrayT<int16_t, P>(self, array, place, zero_copy);
  } else if (py::isinstance<py::array_t<uint8_t>>(array)) {
    SetTensorFromPyArrayT<uint8_t, P>(self, array, place, zero_copy);
  } else if (py::isinstance<py::array_t<paddle::platform::float16>>(array)) {
    SetTensorFromPyArrayT<paddle::platform::float16, P>(
        self, array, place, zero_copy);
  } else if (py::isinstance<py::array_t<paddle::platform::complex<float>>>(
                 array)) {
    SetTensorFromPyArrayT<paddle::platform::complex<float>, P>(
        self, array, place, zero_copy);
  } else if (py::isinstance<py::array_t<paddle::platform::complex<double>>>(
                 array)) {
    SetTensorFromPyArrayT<paddle::platform::complex<double>, P>(
        self, array, place, zero_copy);
  } else if (py::isinstance<py::array_t<uint16_t>>(array)) {
    // since there is still no support for bfloat16 in NumPy,
    // uint16 is used for casting bfloat16
    SetTensorFromPyArrayT<paddle::platform::bfloat16, P>(
        self, array, place, zero_copy);
  } else if (py::isinstance<py::array_t<bool>>(array)) {
    SetTensorFromPyArrayT<bool, P>(self, array, place, zero_copy);
  } else {
    // obj may be any type; obj.cast<py::array>() may fail, in which case
    // array.dtype will be a string of unknown meaning.
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Input object type error or incompatible array data type. "
        "tensor.set() supports array with bool, float16, float32, "
        "float64, int8, int16, int32, int64, uint8 or uint16, "
        "please check your input or input array data type."));
  }
}
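
// Illustrative use (hypothetical call site): filling a CPU tensor from a
// pybind11 object wrapping a numpy array, without zero-copy:
//   phi::DenseTensor t;
//   SetTensorFromPyArray<platform::CPUPlace>(
//       &t, obj, platform::CPUPlace(), /*zero_copy=*/false);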

template <typename P>
void SetStringTensorFromPyArray(phi::StringTensor *self,
                                const py::array &array,
                                const P &place) {
  bool is_string_pyarray =
      array.dtype().kind() == 'S' || array.dtype().kind() == 'U';
  PADDLE_ENFORCE_EQ(is_string_pyarray,
                    true,
                    platform::errors::InvalidArgument(
                        "Expected the dtype of the numpy array to be string "
                        "or unicode, but received dtype %s",
                        array.dtype()));
  std::vector<int64_t> dims;
  dims.reserve(array.ndim());
  for (decltype(array.ndim()) i = 0; i < array.ndim(); ++i) {
    dims.push_back(static_cast<int>(array.shape()[i]));
  }
  self->Resize(phi::make_ddim(dims));
  auto itemsize = array.itemsize();
  if (paddle::platform::is_cpu_place(place)) {
    auto dst = self->mutable_data(place);
    if (array.dtype().kind() == 'S') {
      for (int i = 0; i < self->numel(); ++i) {
        dst[i] =
            pstring(reinterpret_cast<const char *>(array.data()) + itemsize * i,
                    itemsize);
      }
    } else {
      // array.dtype().kind() == 'U'
      VLOG(6) << "numpy array itemsize: " << itemsize;
      for (int i = 0; i < self->numel(); ++i) {
        // Note(zhoushunjie): The itemsize of a unicode numpy array is the
        // size of each unicode string. Each string is padded to the max
        // length of the strings in the array, so all elements have the same
        // size. Each unicode code point occupies 4 bytes, so the byte size of
        // a string is 4 times its length.
        auto unicode_len = itemsize / 4;
        auto utf8_len = phi::strings::GetUTF8StrLen(
            reinterpret_cast<const uint32_t *>(array.data()) + unicode_len * i,
            unicode_len);
        pstring pstr(utf8_len - 1, 0);
        phi::strings::GetUTF8Str(
            reinterpret_cast<const uint32_t *>(array.data()) + unicode_len * i,
            pstr.mdata(),
            unicode_len);
        dst[i] = pstr;
      }
    }
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "StringTensor only support CPUPlace now, but receive %s",
        place.DebugString()));
  }
}
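
// Worked example for the unicode branch above (illustrative): an array of
// dtype '<U5' has itemsize 20, so unicode_len = 20 / 4 = 5 code points per
// element, which are then re-encoded into UTF-8 pstrings.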

template <typename T>
void SetUVATensorFromPyArrayImpl(
    phi::DenseTensor *self_tensor,
    const py::array_t<T, py::array::c_style | py::array::forcecast> &array,
    int device_id) {
#if defined(PADDLE_WITH_CUDA)
  VLOG(4) << "Running in SetUVATensorFromPyArrayImpl.";
  std::vector<int64_t> dims;
  dims.reserve(array.ndim());
  int64_t numel = 1;
  for (decltype(array.ndim()) i = 0; i < array.ndim(); ++i) {
    dims.emplace_back(static_cast<int64_t>(array.shape()[i]));
    numel *= static_cast<int64_t>(array.shape()[i]);
  }
  self_tensor->Resize(phi::make_ddim(dims));

  auto data_type = framework::ToDataType(std::type_index(typeid(T)));
  const auto &need_allocate_size = numel * framework::SizeOfType(data_type);
  T *data_ptr;
  cudaHostAlloc(reinterpret_cast<void **>(&data_ptr),
                need_allocate_size,
                cudaHostAllocWriteCombined | cudaHostAllocMapped);
  std::memcpy(data_ptr, array.data(), array.nbytes());

  void *cuda_device_pointer = nullptr;
  cudaHostGetDevicePointer(reinterpret_cast<void **>(&cuda_device_pointer),
                           reinterpret_cast<void *>(data_ptr),
                           0);
  std::shared_ptr<memory::allocation::Allocation> holder =
      std::make_shared<memory::allocation::Allocation>(
          cuda_device_pointer,
          need_allocate_size,
          platform::CUDAPlace(device_id));
  self_tensor->ResetHolderWithType(holder,
                                   framework::TransToPhiDataType(data_type));
#endif
}
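
// Note: the helper above backs the tensor with page-locked, mapped (UVA) host
// memory; the device pointer returned by cudaHostGetDevicePointer aliases the
// host allocation, so no explicit host-to-device copy is issued here.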

template <typename T>
void SetUVATensorFromPyArray(
    const std::shared_ptr<paddle::imperative::VarBase> &self,
    const py::array_t<T, py::array::c_style | py::array::forcecast> &array,
    int device_id) {
#if defined(PADDLE_WITH_CUDA)
  VLOG(4) << "Running in SetUVATensorFromPyArray for VarBase.";
  auto *self_tensor = self->MutableVar()->GetMutable<phi::DenseTensor>();
  SetUVATensorFromPyArrayImpl<T>(self_tensor, array, device_id);
#endif
}

template <typename T>
void SetUVATensorFromPyArray(const std::shared_ptr<paddle::Tensor> &self,
                             const py::array_t<T> &array,
                             int device_id) {
#if defined(PADDLE_WITH_CUDA)
  VLOG(4) << "Running in SetUVATensorFromPyArray for Phi::Tensor.";
  phi::DenseTensorMeta meta =
      phi::DenseTensorMeta(phi::DataType::FLOAT32, phi::make_ddim({1, 1}));
  std::shared_ptr<phi::DenseTensor> tmp_t = std::make_shared<phi::DenseTensor>(
      std::make_unique<paddle::experimental::DefaultAllocator>(
          paddle::platform::CPUPlace())
          .get(),
      meta);
  self.get()->set_impl(tmp_t);
  auto *self_tensor = static_cast<phi::DenseTensor *>(self.get()->impl().get());

  SetUVATensorFromPyArrayImpl<T>(self_tensor, array, device_id);
#endif
}

template <typename T, size_t D>
void _sliceCompute(const phi::DenseTensor *in,
                   phi::DenseTensor *out,
                   const phi::CPUContext &ctx,
                   const std::vector<int> &axes,
                   const std::vector<int> &starts) {
  auto &eigen_place = *ctx.eigen_device();
  auto out_dims = phi::vectorize<int>(out->dims());
  auto in_dims = in->dims();

  auto offsets = Eigen::DSizes<Eigen::DenseIndex, D>();
  auto extents = Eigen::DSizes<Eigen::DenseIndex, D>();
  for (size_t i = 0; i < D; ++i) {
    offsets[i] = 0;
    extents[i] = out_dims[i];
  }
  int start;
  for (size_t i = 0; i < axes.size(); ++i) {
    start = starts[i];
    if (start < 0) {
      start = (start + in_dims[axes[i]]);
    }
    start = std::max(start, 0);
    offsets[axes[i]] = start;
  }
  auto in_t =
      framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
          *in);
  auto out_t =
      framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
          *out);
  operators::EigenSlice<std::decay_t<decltype(eigen_place)>, T, D>::Eval(
      eigen_place, out_t, in_t, offsets, extents);
}

template <typename T>
void _concatCompute(const std::vector<phi::DenseTensor> &ins,
                    phi::DenseTensor *out,
                    const phi::CPUContext &ctx,
                    int64_t axis) {
  if (axis == 0 && ins.size() < 10) {
    size_t output_offset = 0;
    for (auto &in : ins) {
      auto in_stride = phi::stride_numel(in.dims());
      auto out_stride = phi::stride_numel(out->dims());
      phi::funcs::StridedNumelCopyWithAxis<T, phi::CPUContext>(
          ctx,
          axis,
          out->data<T>() + output_offset,
          out_stride,
          in.data<T>(),
          in_stride,
          in_stride[axis]);
      output_offset += in_stride[axis];
    }
  } else {
    paddle::operators::math::ConcatFunctor<phi::CPUContext, T> concat_functor;
    concat_functor(ctx, ins, static_cast<int>(axis), out);
  }
}

inline void _getSliceinfo(const phi::DenseTensor &self,
                          py::object obj,
                          const int64_t dim,
                          int64_t *pstart,
                          int64_t *pstop,
                          int64_t *pstep,
                          int64_t *pslicelength) {
  auto &start = *pstart;
  auto &stop = *pstop;
  auto &step = *pstep;
  auto &slicelength = *pslicelength;
  const framework::DDim &srcDDim = self.dims();
  PADDLE_ENFORCE(
      0 <= dim && dim < srcDDim.size(),
      platform::errors::OutOfRange("The dim %d of slice is out of bounds, it "
                                   "should be in the range of [0, %d).",
                                   dim,
                                   srcDDim.size()));

  if (py::isinstance<py::slice>(obj)) {
    size_t lstart, lstop, lstep, lslicelength;
    py::slice s = static_cast<py::slice>(obj);
    if (!s.compute(srcDDim[dim], &lstart, &lstop, &lstep, &lslicelength)) {
      PADDLE_THROW(platform::errors::OutOfRange(
          "Slice on dim: %d is error, please check the validity of tensor "
          "dims or slice item.",
          dim));
    }
    start = static_cast<int64_t>(lstart);
    stop = static_cast<int64_t>(lstop);
    step = static_cast<int64_t>(lstep);
    slicelength = static_cast<int64_t>(lslicelength);
  } else if (py::isinstance<py::int_>(obj)) {
    start = static_cast<int64_t>(static_cast<py::int_>(obj));
    PADDLE_ENFORCE(
        std::abs(start) < srcDDim[dim],
        platform::errors::OutOfRange("The start %d of slice is out of bounds, "
                                     "it should be in the range of (%d, %d).",
                                     start,
                                     -srcDDim[dim],
                                     srcDDim[dim]));
    start = (start >= 0) ? start : srcDDim[dim] + start;  // wrap negative index
    stop = start + 1;
    step = 1;
    slicelength = 1;
  } else {
    PADDLE_THROW(
        platform::errors::OutOfRange("Index object error, the index object for "
                                     "slice only supports slice(::) and int."));
  }
}
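
// Example (illustrative): for a dimension of size 10, obj = slice(1, 7, 2)
// yields start = 1, stop = 7, step = 2 and slicelength = 3, while an integer
// index yields a single-element slice with step = 1 and slicelength = 1.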

inline phi::DenseTensor *_getTensor(const phi::DenseTensor &self,
                                    const framework::DDim &ddim) {
  phi::DenseTensor *output = new phi::DenseTensor();
  output->Resize(ddim);
  auto place = self.place();
  if (platform::is_cpu_place(place)) {
    output->mutable_data(place, self.dtype());
  } else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
    output->mutable_data(place, self.dtype());
#endif
  } else {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    if (platform::is_cuda_pinned_place(place)) {
      output->mutable_data(place, self.dtype());
    } else if ((platform::is_gpu_place(place))) {
      output->mutable_data(place, self.dtype());
    }
#endif
  }
  return output;
}

template <typename T>
void _sliceDapper(const phi::DenseTensor *in,
                  phi::DenseTensor *out,
                  const phi::CPUContext &ctx,
                  const std::vector<int> &axes,
                  const std::vector<int> &starts,
                  int size) {
  switch (size) {
    case 1:
      _sliceCompute<T, 1>(in, out, ctx, axes, starts);
      break;
    case 2:
      _sliceCompute<T, 2>(in, out, ctx, axes, starts);
      break;
    case 3:
      _sliceCompute<T, 3>(in, out, ctx, axes, starts);
      break;
    case 4:
      _sliceCompute<T, 4>(in, out, ctx, axes, starts);
      break;
    case 5:
      _sliceCompute<T, 5>(in, out, ctx, axes, starts);
      break;
    case 6:
      _sliceCompute<T, 6>(in, out, ctx, axes, starts);
      break;
    case 7:
      _sliceCompute<T, 7>(in, out, ctx, axes, starts);
      break;
    case 8:
      _sliceCompute<T, 8>(in, out, ctx, axes, starts);
      break;
    case 9:
      _sliceCompute<T, 9>(in, out, ctx, axes, starts);
      break;
    default:
      PADDLE_THROW(platform::errors::InvalidArgument(
          "The dim size should be 1 to 9, current is %d", size));
      break;
  }
}

template <typename T>
inline phi::DenseTensor *_sliceWrapper(const phi::DenseTensor &self,
                                       const phi::CPUContext &ctx,
                                       py::object obj UNUSED,
                                       int dim,
                                       int64_t start,
                                       int64_t slicelength) {
  framework::DDim dstDDim = self.dims();
  dstDDim[dim] = static_cast<int64_t>(slicelength);
  std::vector<int> axes({dim});
  std::vector<int> starts({static_cast<int>(start)});
  phi::DenseTensor *output = _getTensor(self, dstDDim);
  _sliceDapper<T>(&self, output, ctx, axes, starts, dstDDim.size());
  return output;
}

template <typename T>
inline phi::DenseTensor *_sliceAndConcat(const phi::DenseTensor &self,
                                         py::object obj,
                                         int dim) {
  phi::CPUContext ctx;
  int64_t start, stop, step, slicelength;
  _getSliceinfo(self, obj, dim, &start, &stop, &step, &slicelength);
  if (step == 1 || slicelength == 1) {
    return _sliceWrapper<T>(self, ctx, obj, dim, start, slicelength);
  } else {
    std::vector<phi::DenseTensor> ins;
    for (auto i = 0; i < slicelength; ++i, start += step) {
      ins.emplace_back(*_sliceWrapper<T>(self, ctx, obj, dim, start, 1));
    }

    // do the concat operation
    framework::DDim dstDDim = self.dims();
    dstDDim[dim] = static_cast<int64_t>(slicelength);
    phi::DenseTensor *output1 = _getTensor(self, dstDDim);
    _concatCompute<T>(ins, output1, ctx, dim);
    return output1;
  }
}
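
// Design note: non-unit steps are handled by taking `slicelength` unit-width
// slices along `dim` and concatenating them, which trades extra copies for
// reuse of the plain _sliceCompute path.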

inline phi::DenseTensor *_sliceTensor(const phi::DenseTensor &self,
                                      py::object obj,
                                      int dim) {
  auto src_type = framework::TransToProtoVarType(self.dtype());
  switch (src_type) {
    case framework::proto::VarType::FP16:
      return _sliceAndConcat<paddle::platform::float16>(self, obj, dim);
    case framework::proto::VarType::BF16:
      return _sliceAndConcat<paddle::platform::bfloat16>(self, obj, dim);
    case framework::proto::VarType::COMPLEX64:
      return _sliceAndConcat<paddle::platform::complex<float>>(self, obj, dim);
    case framework::proto::VarType::COMPLEX128:
      return _sliceAndConcat<paddle::platform::complex<double>>(self, obj, dim);
    case framework::proto::VarType::FP32:
      return _sliceAndConcat<float>(self, obj, dim);
    case framework::proto::VarType::FP64:
      return _sliceAndConcat<double>(self, obj, dim);
    case framework::proto::VarType::INT8:
      return _sliceAndConcat<int8_t>(self, obj, dim);
    case framework::proto::VarType::INT16:
      return _sliceAndConcat<int16_t>(self, obj, dim);
    case framework::proto::VarType::INT32:
      return _sliceAndConcat<int>(self, obj, dim);
    case framework::proto::VarType::INT64:
      return _sliceAndConcat<int64_t>(self, obj, dim);
    case framework::proto::VarType::BOOL:
      return _sliceAndConcat<bool>(self, obj, dim);
    case framework::proto::VarType::UINT8:
      return _sliceAndConcat<uint8_t>(self, obj, dim);
    default:
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unsupported tensor type: %s",
          framework::DataTypeToString(src_type)));
  }
}

inline phi::DenseTensor *_pySliceTensor(const phi::DenseTensor &self,
                                        py::object obj) {
  if (py::isinstance<py::tuple>(obj)) {
    py::list l = static_cast<py::list>(obj);
    std::unique_ptr<phi::DenseTensor> target;
    phi::DenseTensor *src = const_cast<phi::DenseTensor *>(&self);
    for (auto i = 0; i < static_cast<int>(l.size()); ++i) {
      src = _sliceTensor(*src, l[i], i);
      if (i + 1 == static_cast<int>(l.size())) {
        return src;
      } else {
        target.reset(src);
      }
    }
    return nullptr;
  } else {
    return _sliceTensor(self, obj, 0);
  }
}

inline phi::DenseTensor *PySliceTensor(const phi::DenseTensor &self,
                                       py::object obj) {
  if (platform::is_gpu_place(self.place())) {
    std::unique_ptr<phi::DenseTensor> holder;
    phi::DenseTensor src;
    framework::TensorCopySync(self, platform::CPUPlace(), &src);
    phi::DenseTensor *output = _pySliceTensor(src, obj);
    holder.reset(output);
    phi::DenseTensor *dst = _getTensor(*output, output->dims());
    framework::TensorCopySync(*output, self.place(), dst);
    return dst;
  } else {
    return _pySliceTensor(self, obj);
  }
}

inline py::array TensorToPyArray(const phi::DenseTensor &tensor,
                                 bool need_deep_copy = false) {
  if (!tensor.IsInitialized()) {
    return py::array();
  }
  bool is_gpu_tensor = platform::is_gpu_place(tensor.place());
  bool is_xpu_tensor = platform::is_xpu_place(tensor.place());
  bool is_custom_device_tensor = platform::is_custom_place(tensor.place());
  const auto &tensor_dims = tensor.dims();
  auto tensor_dtype = framework::TransToProtoVarType(tensor.dtype());
  size_t sizeof_dtype = framework::SizeOfType(tensor_dtype);

  std::vector<size_t> py_dims(tensor_dims.size());
  std::vector<size_t> py_strides(tensor_dims.size());

  size_t numel = 1;
  for (int i = tensor_dims.size() - 1; i >= 0; --i) {
    py_dims[i] = static_cast<size_t>(tensor_dims[i]);
    py_strides[i] = sizeof_dtype * numel;
    numel *= py_dims[i];
  }

  const void *tensor_buf_ptr = tensor.data();

  std::string py_dtype_str = details::TensorDTypeToPyDTypeStr(
      framework::TransToProtoVarType(tensor.dtype()));

  if (!is_gpu_tensor && !is_xpu_tensor && !is_custom_device_tensor) {
    if (!need_deep_copy) {
      auto base = py::cast(std::move(tensor));
      return py::array(py::dtype(py_dtype_str.c_str()),
                       py_dims,
                       py_strides,
                       const_cast<void *>(tensor_buf_ptr),
                       base);
    } else {
      py::array py_arr(py::dtype(py_dtype_str.c_str()), py_dims, py_strides);
      PADDLE_ENFORCE_EQ(
          py_arr.writeable(),
          true,
          platform::errors::InvalidArgument(
              "PyArray is not writable, in which case memory leak "
              "or double free would occur"));
      PADDLE_ENFORCE_EQ(
          py_arr.owndata(),
          true,
          platform::errors::InvalidArgument(
              "PyArray does not own data, in which case memory leak "
              "or double free would occur"));
      platform::CPUPlace place;
      size_t copy_bytes = sizeof_dtype * numel;
      paddle::memory::Copy(
          place, py_arr.mutable_data(), place, tensor_buf_ptr, copy_bytes);
      return py_arr;
    }
  } else if (is_xpu_tensor) {
#ifdef PADDLE_WITH_XPU
    py::array py_arr(py::dtype(py_dtype_str.c_str()), py_dims, py_strides);
    PADDLE_ENFORCE_EQ(py_arr.writeable(),
                      true,
                      platform::errors::InvalidArgument(
                          "PyArray is not writable, in which case memory leak "
                          "or double free would occur"));
    PADDLE_ENFORCE_EQ(
        py_arr.owndata(),
        true,
        platform::errors::InvalidArgument(
            "PyArray does not own data, in which case memory leak "
            "or double free would occur"));

    size_t copy_bytes = sizeof_dtype * numel;
    auto p = tensor.place();
    paddle::memory::Copy(platform::CPUPlace(),
                         py_arr.mutable_data(),
                         p,
                         tensor_buf_ptr,
                         copy_bytes);
    return py_arr;
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Cannot use XPUPlace in CPU/GPU version, "
        "Please recompile or reinstall Paddle with XPU support."));
#endif
  } else if (is_gpu_tensor) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    py::array py_arr(py::dtype(py_dtype_str.c_str()), py_dims, py_strides);
    PADDLE_ENFORCE_EQ(py_arr.writeable(),
                      true,
                      platform::errors::InvalidArgument(
                          "PyArray is not writable, in which case memory leak "
                          "or double free would occur"));
    PADDLE_ENFORCE_EQ(
        py_arr.owndata(),
        true,
        platform::errors::InvalidArgument(
            "PyArray does not own data, in which case memory leak "
            "or double free would occur"));

    size_t copy_bytes = sizeof_dtype * numel;
    auto p = tensor.place();
    paddle::memory::Copy(platform::CPUPlace(),
                         py_arr.mutable_data(),
                         p,
                         tensor_buf_ptr,
                         copy_bytes,
                         nullptr);
    return py_arr;
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Cannot use CUDAPlace in CPU only version, "
        "Please recompile or reinstall Paddle with CUDA support."));
#endif
  } else if (is_custom_device_tensor) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
    py::array py_arr(py::dtype(py_dtype_str.c_str()), py_dims, py_strides);
    PADDLE_ENFORCE_EQ(py_arr.writeable(),
                      true,
                      platform::errors::InvalidArgument(
                          "PyArray is not writable, in which case memory leak "
                          "or double free would occur"));
    PADDLE_ENFORCE_EQ(
        py_arr.owndata(),
        true,
        platform::errors::InvalidArgument(
            "PyArray does not own data, in which case memory leak "
            "or double free would occur"));

    // TODO(qili93): temporary for ascend npu performance to be removed along
    // with npu_identity op
    paddle::Tensor tensor_out(std::make_shared<phi::DenseTensor>());
    if (tensor.storage_properties_initialized()) {
      paddle::Tensor tensor_in(std::make_shared<phi::DenseTensor>(tensor));
      tensor_out = npu_identity_ad_func(tensor_in, -1);
      auto dense_tensor =
          std::dynamic_pointer_cast<phi::DenseTensor>(tensor_out.impl());
      tensor_buf_ptr = dense_tensor->data();
    }

    size_t copy_bytes = sizeof_dtype * numel;
    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
    auto &ctx = *pool.Get(tensor.place());
    paddle::memory::Copy(
        platform::CPUPlace(),
        py_arr.mutable_data(),
        tensor.place(),
        tensor_buf_ptr,
        copy_bytes,
        reinterpret_cast<const platform::CustomDeviceContext &>(ctx).stream());
    ctx.Wait();
    return py_arr;
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Cannot use CustomPlace in CPU/GPU/XPU/NPU version, "
        "Please recompile or reinstall Paddle with CustomPlace "
        "support."));
#endif
  }
  PADDLE_THROW(platform::errors::Unimplemented("Place is not supported"));
  return py::array();
}

}  // namespace pybind
}  // namespace paddle