/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include #include #include #include #include #include #include #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/operators/eigen/eigen_function.h" #include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/operators/strided_memcpy.h" #include "paddle/fluid/platform/bfloat16.h" #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #include "paddle/fluid/platform/cuda_device_guard.h" #endif #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/profiler.h" #include "pybind11/numpy.h" #include "pybind11/pybind11.h" namespace py = pybind11; namespace pybind11 { namespace detail { // Note: use same enum number of float16 in numpy. // import numpy as np // print np.dtype(np.float16).num # 23 constexpr int NPY_FLOAT16_ = 23; constexpr int NPY_UINT16_ = 4; constexpr int NPY_COMPLEX64 = 14; constexpr int NPY_COMPLEX128 = 15; // Note: Since float16 is not a builtin type in C++, we register // paddle::platform::float16 as numpy.float16. // Ref: https://github.com/pybind/pybind11/issues/1776 template <> struct npy_format_descriptor { static py::dtype dtype() { handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16_); return reinterpret_borrow(ptr); } static std::string format() { // Note: "e" represents float16. // Details at: // https://docs.python.org/3/library/struct.html#format-characters. return "e"; } static constexpr auto name = _("float16"); }; // Note: Since bfloat16 is not a builtin type in C++ and in numpy, // we register paddle::platform::bfloat16 as numpy.uint16. template <> struct npy_format_descriptor { static py::dtype dtype() { handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_UINT16_); return reinterpret_borrow(ptr); } static std::string format() { // Note: "H" represents UINT16. // Details at: // https://docs.python.org/3/library/struct.html#format-characters. return "H"; } static constexpr auto name = _("bfloat16"); }; // we register paddle::platform::complex as numpy.complex64. template <> struct npy_format_descriptor> { static py::dtype dtype() { handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_COMPLEX64); return reinterpret_borrow(ptr); } static std::string format() { // Note: "F" represents complex64. // Details at: // https://stackoverflow.com/questions/13997087/what-are-the-available-datatypes-for-dtype-with-numpys-loadtxt-an-genfromtx // for k, v in np.sctypeDict.iteritems(): // print '{0:14s} : {1:40s}'.format(str(k), v) return "F"; } static constexpr auto name = _("complext64"); }; template <> struct npy_format_descriptor> { static py::dtype dtype() { handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_COMPLEX128); return reinterpret_borrow(ptr); } static std::string format() { // Note: "D" represents complex128. // Details at: // https://stackoverflow.com/questions/13997087/what-are-the-available-datatypes-for-dtype-with-numpys-loadtxt-an-genfromtx // for k, v in np.sctypeDict.iteritems(): // print '{0:14s} : {1:40s}'.format(str(k), v) return "D"; } static constexpr auto name = _("complext128"); }; } // namespace detail } // namespace pybind11 namespace paddle { namespace pybind { namespace details { template class PYBIND11_HIDDEN NumpyAllocation : public memory::Allocation { public: explicit NumpyAllocation(const py::array &arr) : Allocation(const_cast(arr.data()), sizeof(T) * (arr.size()), paddle::platform::CPUPlace()), arr_(arr.ptr()) { PADDLE_ENFORCE_NOT_NULL(arr_, platform::errors::InvalidArgument( "The underlying PyObject pointer of " "numpy array cannot be nullptr")); PADDLE_ENFORCE_NE( arr_, Py_None, platform::errors::PreconditionNotMet( "The underlying PyObject pointer of numpy array cannot be None")); Py_INCREF(arr_); } ~NumpyAllocation() override { py::gil_scoped_acquire gil; Py_DECREF(arr_); } private: PyObject *arr_; }; template struct ValidDTypeToPyArrayChecker { static constexpr bool kValue = false; }; #define DECLARE_VALID_DTYPE_TO_PY_ARRAY(type) \ template <> \ struct ValidDTypeToPyArrayChecker { \ static constexpr bool kValue = true; \ } DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::float16); DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::bfloat16); DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::complex); DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::complex); DECLARE_VALID_DTYPE_TO_PY_ARRAY(float); DECLARE_VALID_DTYPE_TO_PY_ARRAY(double); DECLARE_VALID_DTYPE_TO_PY_ARRAY(bool); DECLARE_VALID_DTYPE_TO_PY_ARRAY(int8_t); DECLARE_VALID_DTYPE_TO_PY_ARRAY(int16_t); DECLARE_VALID_DTYPE_TO_PY_ARRAY(int); DECLARE_VALID_DTYPE_TO_PY_ARRAY(int64_t); DECLARE_VALID_DTYPE_TO_PY_ARRAY(uint8_t); inline std::string TensorDTypeToPyDTypeStr( framework::proto::VarType::Type type) { #define TENSOR_DTYPE_TO_PY_DTYPE(T, proto_type) \ if (type == proto_type) { \ if (std::is_same::value) { \ return "e"; \ } else if (std::is_same::value) { \ /* NumPy character code of uint16 due to no support for bfloat16 */ \ return "H"; \ } else if (std::is_same>::value) { \ return "F"; \ } else if (std::is_same>::value) { \ return "D"; \ } else { \ constexpr auto kIsValidDType = ValidDTypeToPyArrayChecker::kValue; \ PADDLE_ENFORCE_EQ( \ kIsValidDType, true, \ platform::errors::Unimplemented( \ "This type [%s] of tensor cannot be expose to Python", \ typeid(T).name())); \ return py::format_descriptor::format(); \ } \ } _ForEachDataType_(TENSOR_DTYPE_TO_PY_DTYPE); #undef TENSOR_DTYPE_TO_PY_DTYPE PADDLE_THROW(platform::errors::Unimplemented( "Unsupported tensor data type: %s", framework::DataTypeToString(type))); } } // namespace details template T TensorGetElement(const framework::Tensor &self, size_t offset) { PADDLE_ENFORCE_LT(offset, self.numel(), platform::errors::InvalidArgument( "The offset exceeds the size of tensor.")); T b = static_cast(0); if (platform::is_cpu_place(self.place())) { b = self.data()[offset]; } else if (platform::is_xpu_place(self.place())) { #ifdef PADDLE_WITH_XPU const T *a = self.data(); auto p = BOOST_GET_CONST(platform::XPUPlace, self.place()); paddle::memory::Copy(platform::CPUPlace(), &b, p, a + offset, sizeof(T)); #endif } else if (platform::is_gpu_place(self.place())) { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) const T *a = self.data(); auto p = BOOST_GET_CONST(platform::CUDAPlace, self.place()); paddle::memory::Copy(platform::CPUPlace(), &b, p, a + offset, sizeof(T), nullptr); #endif } else if (platform::is_npu_place(self.place())) { #if defined(PADDLE_WITH_ASCEND_CL) const T *a = self.data(); auto p = BOOST_GET_CONST(platform::NPUPlace, self.place()); paddle::memory::Copy(platform::CPUPlace(), &b, p, a + offset, sizeof(T), nullptr); #endif } VLOG(10) << "TensorGetElement, place: " << self.place() << ", offset: " << offset << ", element: " << b; return b; } template void TensorSetElement(framework::Tensor *self, size_t offset, T elem) { PADDLE_ENFORCE_LT(offset, self->numel(), platform::errors::InvalidArgument( "The offset exceeds the size of tensor.")); VLOG(10) << "TensorSetElement, place: " << self->place() << ", offset: " << offset << ", element: " << elem; if (platform::is_cpu_place(self->place())) { self->mutable_data(self->place())[offset] = elem; } else if (platform::is_xpu_place(self->place())) { #ifdef PADDLE_WITH_XPU auto p = BOOST_GET_CONST(platform::XPUPlace, self->place()); T *a = self->mutable_data(p); paddle::memory::Copy(p, a + offset, platform::CPUPlace(), &elem, sizeof(T)); #endif } else if (platform::is_gpu_place(self->place())) { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) auto p = BOOST_GET_CONST(platform::CUDAPlace, self->place()); T *a = self->mutable_data(p); paddle::memory::Copy(p, a + offset, platform::CPUPlace(), &elem, sizeof(T), nullptr); #endif } else if (platform::is_npu_place(self->place())) { #if defined(PADDLE_WITH_ASCEND_CL) auto p = BOOST_GET_CONST(platform::NPUPlace, self->place()); T *a = self->mutable_data(p); paddle::memory::Copy(p, a + offset, platform::CPUPlace(), &elem, sizeof(T), nullptr); #endif } } template void SetTensorFromPyArrayT( framework::Tensor *self, const py::array_t &array, const P &place, bool zero_copy) { std::vector dims; dims.reserve(array.ndim()); for (decltype(array.ndim()) i = 0; i < array.ndim(); ++i) { dims.push_back(static_cast(array.shape()[i])); } self->Resize(framework::make_ddim(dims)); if (paddle::platform::is_cpu_place(place)) { if (zero_copy) { auto holder = std::make_shared>(array); auto type = framework::ToDataType(std::type_index(typeid(T))); self->ResetHolderWithType(holder, type); } else { auto dst = self->mutable_data(place); std::memcpy(dst, array.data(), array.nbytes()); } } else if (paddle::platform::is_xpu_place(place)) { #ifdef PADDLE_WITH_XPU // NOTE(wangxi): When copying data to the accelerator card, // we need set_device(dev_id) first. platform::Place tmp_place = place; platform::XPUDeviceGuard guard( BOOST_GET_CONST(platform::XPUPlace, tmp_place).device); auto dst = self->mutable_data(place); xpu_memcpy(dst, array.data(), array.nbytes(), XPUMemcpyKind::XPU_HOST_TO_DEVICE); #else PADDLE_THROW(platform::errors::PermissionDenied( "Cannot use XPUPlace in CPU/GPU version, " "Please recompile or reinstall Paddle with XPU support.")); #endif } else if (paddle::platform::is_ipu_place(place)) { #ifdef PADDLE_WITH_IPU if (zero_copy) { auto holder = std::make_shared>(array); auto type = framework::ToDataType(std::type_index(typeid(T))); self->ResetHolderWithType(holder, type); } else { auto dst = self->mutable_data(place); std::memcpy(dst, array.data(), array.nbytes()); } #else PADDLE_THROW(platform::errors::PermissionDenied( "Cannot use IPUPlace in CPU/GPU/XPU/NPU version, " "Please recompile or reinstall Paddle with IPU support.")); #endif } else if (paddle::platform::is_npu_place(place)) { #ifdef PADDLE_WITH_ASCEND_CL platform::Place tmp_place = place; platform::NPUDeviceGuard guard( BOOST_GET_CONST(platform::NPUPlace, tmp_place).device); auto dst = self->mutable_data(place); platform::NPUMemcpySync(dst, array.data(), array.nbytes(), ACL_MEMCPY_HOST_TO_DEVICE); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &ctx = *pool.Get(place); ctx.Wait(); #else PADDLE_THROW(platform::errors::PermissionDenied( "Cannot use NPUPlace in CPU/GPU/XPU version. " "Please recompile or reinstall Paddle with NPU support.")); #endif } else { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (paddle::platform::is_gpu_place(place)) { // NOTE(wangxi): When copying data to the accelerator card, // we need set_device(dev_id) first. platform::Place tmp_place = place; platform::CUDADeviceGuard guard( BOOST_GET_CONST(platform::CUDAPlace, tmp_place).device); auto dst = self->mutable_data(place); #ifdef PADDLE_WITH_HIP paddle::platform::GpuMemcpySync(dst, array.data(), array.nbytes(), hipMemcpyHostToDevice); #else paddle::platform::GpuMemcpySync(dst, array.data(), array.nbytes(), cudaMemcpyHostToDevice); #endif } else if (paddle::platform::is_cuda_pinned_place(place)) { auto dst = self->mutable_data(place); std::memcpy(dst, array.data(), array.nbytes()); } else { PADDLE_THROW(platform::errors::InvalidArgument( "Incompatible place type: Tensor.set() supports " "CPUPlace, CUDAPlace " "and CUDAPinnedPlace, but got %s!", place)); } #else PADDLE_THROW(platform::errors::PermissionDenied( "Cannot use CUDAPlace or CUDAPinnedPlace in CPU only version, " "Please recompile or reinstall Paddle with CUDA support.")); #endif } } template void SetTensorFromPyArray(framework::Tensor *self, const py::object &obj, const P &place, bool zero_copy) { auto array = obj.cast(); if (py::isinstance>(array)) { SetTensorFromPyArrayT(self, array, place, zero_copy); } else if (py::isinstance>(array)) { SetTensorFromPyArrayT(self, array, place, zero_copy); } else if (py::isinstance>(array)) { SetTensorFromPyArrayT(self, array, place, zero_copy); } else if (py::isinstance>(array)) { SetTensorFromPyArrayT(self, array, place, zero_copy); } else if (py::isinstance>(array)) { SetTensorFromPyArrayT(self, array, place, zero_copy); } else if (py::isinstance>(array)) { SetTensorFromPyArrayT(self, array, place, zero_copy); } else if (py::isinstance>(array)) { SetTensorFromPyArrayT(self, array, place, zero_copy); } else if (py::isinstance>(array)) { SetTensorFromPyArrayT(self, array, place, zero_copy); } else if (py::isinstance>>( array)) { SetTensorFromPyArrayT, P>( self, array, place, zero_copy); } else if (py::isinstance>>( array)) { SetTensorFromPyArrayT, P>( self, array, place, zero_copy); } else if (py::isinstance>(array)) { // since there is still no support for bfloat16 in NumPy, // uint16 is used for casting bfloat16 SetTensorFromPyArrayT(self, array, place, zero_copy); } else if (py::isinstance>(array)) { SetTensorFromPyArrayT(self, array, place, zero_copy); } else { // obj may be any type, obj.cast() may be failed, // then the array.dtype will be string of unknown meaning, PADDLE_THROW(platform::errors::InvalidArgument( "Input object type error or incompatible array data type. " "tensor.set() supports array with bool, float16, float32, " "float64, int8, int16, int32, int64, uint8 or uint16, " "please check your input or input array data type.")); } } template void _sliceCompute(const framework::Tensor *in, framework::Tensor *out, const platform::CPUDeviceContext &ctx, const std::vector &axes, const std::vector &starts) { auto &eigen_place = *ctx.eigen_device(); auto place = in->place(); auto out_dims = out->dims(); auto in_dims = in->dims(); auto offsets = Eigen::DSizes(); auto extents = Eigen::DSizes(); for (size_t i = 0; i < D; ++i) { offsets[i] = 0; extents[i] = out_dims[i]; } int start; for (size_t i = 0; i < axes.size(); ++i) { start = starts[i]; if (start < 0) { start = (start + in_dims[axes[i]]); } start = std::max(start, 0); offsets[axes[i]] = start; } auto in_t = framework::EigenTensor::From( *in); auto out_t = framework::EigenTensor::From( *out); operators::EigenSlice, T, D>::Eval( eigen_place, out_t, in_t, offsets, extents); } template void _concatCompute(const std::vector &ins, paddle::framework::Tensor *out, const platform::CPUDeviceContext &ctx, int64_t axis) { if (axis == 0 && ins.size() < 10) { size_t output_offset = 0; for (auto &in : ins) { auto in_stride = framework::stride_numel(in.dims()); auto out_stride = framework::stride_numel(out->dims()); paddle::operators::StridedNumelCopyWithAxis( ctx, axis, out->data() + output_offset, out_stride, in.data(), in_stride, in_stride[axis]); output_offset += in_stride[axis]; } } else { paddle::operators::math::ConcatFunctor concat_functor; concat_functor(ctx, ins, static_cast(axis), out); } } inline void _getSliceinfo(const framework::Tensor &self, py::object obj, const int64_t dim, int64_t *pstart, int64_t *pstop, int64_t *pstep, int64_t *pslicelength) { auto &start = *pstart; auto &stop = *pstop; auto &step = *pstep; auto &slicelength = *pslicelength; const framework::DDim &srcDDim = self.dims(); if (dim < 0 || dim >= srcDDim.size()) { throw py::index_error(); } if (py::isinstance(obj)) { size_t lstart, lstop, lstep, lslicelength; py::slice s = static_cast(obj); if (!s.compute(srcDDim[dim], &lstart, &lstop, &lstep, &lslicelength)) { throw py::index_error(); } start = static_cast(lstart); stop = static_cast(lstop); step = static_cast(lstep); slicelength = static_cast(lslicelength); } else if (py::isinstance(obj)) { start = static_cast(static_cast(obj)); if (std::abs(start) >= srcDDim[dim]) { throw py::index_error(); } start = (start >= 0) ? start : srcDDim[dim] - start; stop = start + 1; step = 1; slicelength = 1; } else { throw py::index_error(); } } inline framework::Tensor *_getTensor(const framework::Tensor &self, const framework::DDim &ddim) { framework::Tensor *output = new framework::Tensor(); output->Resize(ddim); auto place = self.place(); if (platform::is_cpu_place(place)) { output->mutable_data(BOOST_GET_CONST(platform::CPUPlace, place), self.type()); } else if (platform::is_xpu_place(place)) { #ifdef PADDLE_WITH_XPU output->mutable_data(BOOST_GET_CONST(platform::XPUPlace, place), self.type()); #endif } else { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (platform::is_cuda_pinned_place(place)) { output->mutable_data(BOOST_GET_CONST(platform::CUDAPinnedPlace, place), self.type()); } else if ((platform::is_gpu_place(place))) { output->mutable_data(BOOST_GET_CONST(platform::CUDAPlace, place), self.type()); } #endif } return output; } template void _sliceDapper(const framework::Tensor *in, framework::Tensor *out, const platform::CPUDeviceContext &ctx, const std::vector &axes, const std::vector &starts, int size) { switch (size) { case 1: _sliceCompute(in, out, ctx, axes, starts); break; case 2: _sliceCompute(in, out, ctx, axes, starts); break; case 3: _sliceCompute(in, out, ctx, axes, starts); break; case 4: _sliceCompute(in, out, ctx, axes, starts); break; case 5: _sliceCompute(in, out, ctx, axes, starts); break; case 6: _sliceCompute(in, out, ctx, axes, starts); break; case 7: _sliceCompute(in, out, ctx, axes, starts); break; case 8: _sliceCompute(in, out, ctx, axes, starts); break; case 9: _sliceCompute(in, out, ctx, axes, starts); break; default: PADDLE_THROW(platform::errors::InvalidArgument( "The dim size should be 1 to 9, current is %d", size)); break; } } template inline framework::Tensor *_sliceWrapper(const framework::Tensor &self, const platform::CPUDeviceContext &ctx, py::object obj, int dim, int64_t start, int64_t slicelength) { framework::DDim dstDDim = self.dims(); dstDDim[dim] = static_cast(slicelength); std::vector axes({dim}); std::vector starts({static_cast(start)}); framework::Tensor *output = _getTensor(self, dstDDim); _sliceDapper(&self, output, ctx, axes, starts, dstDDim.size()); return output; } template inline framework::Tensor *_sliceAndConcat(const framework::Tensor &self, py::object obj, int dim) { platform::CPUDeviceContext ctx; int64_t start, stop, step, slicelength; _getSliceinfo(self, obj, dim, &start, &stop, &step, &slicelength); if (step == 1 || slicelength == 1) { return _sliceWrapper(self, ctx, obj, dim, start, slicelength); } else { std::vector ins; for (auto i = 0; i < slicelength; ++i, start += step) { ins.emplace_back(*_sliceWrapper(self, ctx, obj, dim, start, 1)); } // do the concat operation framework::DDim dstDDim = self.dims(); dstDDim[dim] = static_cast(slicelength); framework::Tensor *output1 = _getTensor(self, dstDDim); _concatCompute(ins, output1, ctx, dim); return output1; } } inline framework::Tensor *_sliceTensor(const framework::Tensor &self, py::object obj, int dim) { auto src_type = self.type(); switch (src_type) { case framework::proto::VarType::FP16: return _sliceAndConcat(self, obj, dim); case framework::proto::VarType::BF16: return _sliceAndConcat(self, obj, dim); case framework::proto::VarType::COMPLEX64: return _sliceAndConcat>(self, obj, dim); case framework::proto::VarType::COMPLEX128: return _sliceAndConcat>(self, obj, dim); case framework::proto::VarType::FP32: return _sliceAndConcat(self, obj, dim); case framework::proto::VarType::FP64: return _sliceAndConcat(self, obj, dim); case framework::proto::VarType::INT8: return _sliceAndConcat(self, obj, dim); case framework::proto::VarType::INT16: return _sliceAndConcat(self, obj, dim); case framework::proto::VarType::INT32: return _sliceAndConcat(self, obj, dim); case framework::proto::VarType::INT64: return _sliceAndConcat(self, obj, dim); case framework::proto::VarType::BOOL: return _sliceAndConcat(self, obj, dim); case framework::proto::VarType::UINT8: return _sliceAndConcat(self, obj, dim); default: PADDLE_THROW(platform::errors::InvalidArgument( "Not support tensor type: %s", framework::DataTypeToString(src_type))); } } inline framework::Tensor *_pySliceTensor(const framework::Tensor &self, py::object obj) { if (py::isinstance(obj)) { py::list l = static_cast(obj); std::unique_ptr target; framework::Tensor *src = const_cast(&self); for (auto i = 0; i < static_cast(l.size()); ++i) { src = _sliceTensor(*src, l[i], i); if (i + 1 == static_cast(l.size())) { return src; } else { target.reset(src); } } return nullptr; } else { return _sliceTensor(self, obj, 0); } } inline framework::Tensor *PySliceTensor(const framework::Tensor &self, py::object obj) { if (platform::is_gpu_place(self.place())) { std::unique_ptr holder; framework::Tensor src; framework::TensorCopySync(self, platform::CPUPlace(), &src); framework::Tensor *output = _pySliceTensor(src, obj); holder.reset(output); framework::Tensor *dst = _getTensor(*output, output->dims()); framework::TensorCopySync(*output, self.place(), dst); return dst; } else { return _pySliceTensor(self, obj); } } inline py::array TensorToPyArray(const framework::Tensor &tensor, bool need_deep_copy = false) { if (!tensor.IsInitialized()) { return py::array(); } bool is_gpu_tensor = platform::is_gpu_place(tensor.place()); bool is_xpu_tensor = platform::is_xpu_place(tensor.place()); bool is_npu_tensor = platform::is_npu_place(tensor.place()); const auto &tensor_dims = tensor.dims(); auto tensor_dtype = tensor.type(); size_t sizeof_dtype = framework::SizeOfType(tensor_dtype); std::vector py_dims(tensor_dims.size()); std::vector py_strides(tensor_dims.size()); size_t numel = 1; for (int i = tensor_dims.size() - 1; i >= 0; --i) { py_dims[i] = static_cast(tensor_dims[i]); py_strides[i] = sizeof_dtype * numel; numel *= py_dims[i]; } const void *tensor_buf_ptr = tensor.data(); std::string py_dtype_str = details::TensorDTypeToPyDTypeStr(tensor.type()); if (!is_gpu_tensor && !is_xpu_tensor && !is_npu_tensor) { if (!need_deep_copy) { auto base = py::cast(std::move(tensor)); return py::array(py::dtype(py_dtype_str.c_str()), py_dims, py_strides, const_cast(tensor_buf_ptr), base); } else { py::array py_arr(py::dtype(py_dtype_str.c_str()), py_dims, py_strides); PADDLE_ENFORCE_EQ( py_arr.writeable(), true, platform::errors::InvalidArgument( "PyArray is not writable, in which case memory leak " "or double free would occur")); PADDLE_ENFORCE_EQ( py_arr.owndata(), true, platform::errors::InvalidArgument( "PyArray does not own data, in which case memory leak " "or double free would occur")); platform::CPUPlace place; size_t copy_bytes = sizeof_dtype * numel; paddle::memory::Copy(place, py_arr.mutable_data(), place, tensor_buf_ptr, copy_bytes); return py_arr; } } else if (is_xpu_tensor) { #ifdef PADDLE_WITH_XPU py::array py_arr(py::dtype(py_dtype_str.c_str()), py_dims, py_strides); PADDLE_ENFORCE_EQ(py_arr.writeable(), true, platform::errors::InvalidArgument( "PyArray is not writable, in which case memory leak " "or double free would occur")); PADDLE_ENFORCE_EQ( py_arr.owndata(), true, platform::errors::InvalidArgument( "PyArray does not own data, in which case memory leak " "or double free would occur")); size_t copy_bytes = sizeof_dtype * numel; auto p = BOOST_GET_CONST(platform::XPUPlace, tensor.place()); paddle::memory::Copy(platform::CPUPlace(), py_arr.mutable_data(), p, tensor_buf_ptr, copy_bytes); return py_arr; #else PADDLE_THROW(platform::errors::PermissionDenied( "Cannot use XPUPlace in CPU/GPU version, " "Please recompile or reinstall Paddle with XPU support.")); #endif } else if (is_gpu_tensor) { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) py::array py_arr(py::dtype(py_dtype_str.c_str()), py_dims, py_strides); PADDLE_ENFORCE_EQ(py_arr.writeable(), true, platform::errors::InvalidArgument( "PyArray is not writable, in which case memory leak " "or double free would occur")); PADDLE_ENFORCE_EQ( py_arr.owndata(), true, platform::errors::InvalidArgument( "PyArray does not own data, in which case memory leak " "or double free would occur")); size_t copy_bytes = sizeof_dtype * numel; auto p = BOOST_GET_CONST(platform::CUDAPlace, tensor.place()); paddle::memory::Copy(platform::CPUPlace(), py_arr.mutable_data(), p, tensor_buf_ptr, copy_bytes, nullptr); return py_arr; #else PADDLE_THROW(platform::errors::PermissionDenied( "Cannot use CUDAPlace in CPU only version, " "Please recompile or reinstall Paddle with CUDA support.")); #endif } else if (is_npu_tensor) { #ifdef PADDLE_WITH_ASCEND_CL py::array py_arr(py::dtype(py_dtype_str.c_str()), py_dims, py_strides); PADDLE_ENFORCE_EQ(py_arr.writeable(), true, platform::errors::InvalidArgument( "PyArray is not writable, in which case memory leak " "or double free would occur")); PADDLE_ENFORCE_EQ( py_arr.owndata(), true, platform::errors::InvalidArgument( "PyArray does not own data, in which case memory leak " "or double free would occur")); size_t copy_bytes = sizeof_dtype * numel; auto p = BOOST_GET_CONST(platform::NPUPlace, tensor.place()); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &ctx = *pool.Get(tensor.place()); paddle::memory::Copy( platform::CPUPlace(), py_arr.mutable_data(), p, tensor_buf_ptr, copy_bytes, reinterpret_cast(ctx).stream()); ctx.Wait(); return py_arr; #else PADDLE_THROW(platform::errors::PermissionDenied( "Cannot use NPUPlace in CPU/GPU/XPU version, " "Please recompile or reinstall Paddle with NPU support.")); #endif } PADDLE_THROW(platform::errors::Unimplemented("Place is not supported")); return py::array(); } } // namespace pybind } // namespace paddle