Unverified commit 31a5829a, authored by Siming Dai, committed by GitHub

dlpack fix (#35817) (#36177)

Parent: 21c65f66
@@ -18,7 +18,7 @@ set(DLPACK_PREFIX_DIR ${THIRD_PARTY_PATH}/dlpack)
 set(DLPACK_SOURCE_DIR ${THIRD_PARTY_PATH}/dlpack/src/extern_dlpack)
 set(DLPACK_REPOSITORY ${GIT_URL}/dmlc/dlpack.git)
-set(DLPACK_TAG v0.2)
+set(DLPACK_TAG v0.4)
 
 cache_third_party(extern_dlpack
   REPOSITORY ${DLPACK_REPOSITORY}
...
@@ -30,14 +30,10 @@ static ::DLDataType GetDLDataTypeCode() {
   ::DLDataType dtype;
   if (std::is_same<T, platform::complex<float>>::value ||
       std::is_same<T, platform::complex<double>>::value) {
-    // The current dlpack library version is v0.2, and does not define
-    // kDLComplex value. But kDLComplex is defined by 5U in v0.4, so we set
-    // dtype.code to 5U directly here. After the dlpack library version being
-    // upgraded to v0.4, it should be written as follow.
-    // dtype.code = kDLComplex;
-    dtype.code = 5U;
+    dtype.code = kDLComplex;
+  } else if (std::is_same<T, platform::bfloat16>::value) {
+    dtype.code = kDLBfloat;
   } else if (std::is_same<T, platform::float16>::value ||
-             std::is_same<T, platform::bfloat16>::value ||
              std::is_floating_point<T>::value) {
     dtype.code = kDLFloat;
   } else if (std::is_unsigned<T>::value) {
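For reference, the change above relies on enum values that only exist once the DLPack dependency is at v0.4. A minimal sketch of the resulting dtype-to-type-code mapping; the numeric enum values are taken from the DLPack v0.4 header and are an assumption here, not part of this diff:

```python
# Sketch of the dtype -> DLPack type-code mapping after this change.
# Assumed DLPack v0.4 enum values: kDLInt = 0, kDLUInt = 1, kDLFloat = 2,
# kDLBfloat = 4, kDLComplex = 5 (v0.2 had no kDLComplex, hence the old 5U hack).
DLPACK_TYPE_CODE = {
    "complex64": 5,   # kDLComplex
    "complex128": 5,  # kDLComplex
    "bfloat16": 4,    # kDLBfloat, now a dedicated branch instead of kDLFloat
    "float16": 2,     # kDLFloat
    "float32": 2,     # kDLFloat
    "float64": 2,     # kDLFloat
    "uint8": 1,       # kDLUInt
    "int32": 0,       # kDLInt
    "int64": 0,       # kDLInt
}
```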
@@ -77,47 +73,47 @@ static DLDataType GetDLDataTypeFromTypeIndex(proto::VarType::Type type) {
 #undef REG_DL_DATA_TYPE
 }
 
-struct DLContextVisitor : public boost::static_visitor<::DLContext> {
-  inline ::DLContext operator()(const platform::CPUPlace &place) const {
-    ::DLContext ctx;
-    ctx.device_type = kDLCPU;
-    ctx.device_id = 0;
-    return ctx;
+struct DLDeviceVisitor : public boost::static_visitor<::DLDevice> {
+  inline ::DLDevice operator()(const platform::CPUPlace &place) const {
+    ::DLDevice device;
+    device.device_type = kDLCPU;
+    device.device_id = 0;
+    return device;
   }
 
-  inline ::DLContext operator()(const platform::XPUPlace &place) const {
+  inline ::DLDevice operator()(const platform::XPUPlace &place) const {
     PADDLE_THROW(
         platform::errors::Unimplemented("platform::XPUPlace is not supported"));
   }
 
-  inline ::DLContext operator()(const platform::NPUPlace &place) const {
+  inline ::DLDevice operator()(const platform::NPUPlace &place) const {
     PADDLE_THROW(
         platform::errors::Unimplemented("platform::NPUPlace is not supported"));
   }
 
-  inline ::DLContext operator()(const platform::NPUPinnedPlace &place) const {
+  inline ::DLDevice operator()(const platform::NPUPinnedPlace &place) const {
     PADDLE_THROW(platform::errors::Unimplemented(
         "platform::NPUPinnedPlace is not supported"));
   }
 
-  inline ::DLContext operator()(const platform::CUDAPlace &place) const {
+  inline ::DLDevice operator()(const platform::CUDAPlace &place) const {
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-    ::DLContext ctx;
-    ctx.device_type = kDLGPU;
-    ctx.device_id = place.device;
-    return ctx;
+    ::DLDevice device;
+    device.device_type = kDLGPU;
+    device.device_id = place.device;
+    return device;
 #else
     PADDLE_THROW(platform::errors::Unavailable(
         "platform::CUDAPlace is not supported in CPU only version."));
 #endif
   }
 
-  inline ::DLContext operator()(const platform::CUDAPinnedPlace &place) const {
+  inline ::DLDevice operator()(const platform::CUDAPinnedPlace &place) const {
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-    ::DLContext ctx;
-    ctx.device_type = kDLCPUPinned;
-    ctx.device_id = 0;
-    return ctx;
+    ::DLDevice device;
+    device.device_type = kDLCPUPinned;
+    device.device_id = 0;
+    return device;
 #else
     PADDLE_THROW(platform::errors::Unavailable(
         "platform::CUDAPinnedPlace is not supported in CPU only version."));
@@ -130,9 +126,9 @@ DLPackTensor::DLPackTensor(const Tensor &tensor, LaneType lanes) {
   // init data, data buffer
   t_.data = const_cast<void *>(tensor.data<void>());
 
-  // init ctx, DLContext type with device_type and device_id
+  // init device, DLDevice type with device_type and device_id
   auto place = tensor.place();
-  t_.ctx = boost::apply_visitor(internal::DLContextVisitor(), place);
+  t_.device = boost::apply_visitor(internal::DLDeviceVisitor(), place);
 
   // init dtype
   t_.dtype = internal::GetDLDataTypeFromTypeIndex(tensor.type());
@@ -156,10 +152,8 @@ DLPackTensor::DLPackTensor(const Tensor &tensor, LaneType lanes) {
   t_.byte_offset = 0;
 }
 
-::DLManagedTensor *DLPackTensor::ToCudfCompatibleDLManagedTensor() {
-  // init shape, tensor dims
-  // for DLManagedTensor shape need to be compatible with ndim
-  // refer to cupy and cudf, we new int64[ndim]
+::DLManagedTensor *DLPackTensor::ToDLManagedTensor() {
+  // init shape
   auto shape = new int64_t[t_.ndim];
   using DimType = decltype(t_.ndim);  // int
   for (DimType i = 0; i < t_.ndim; ++i) {
@@ -167,19 +161,15 @@ DLPackTensor::DLPackTensor(const Tensor &tensor, LaneType lanes) {
   }
   t_.shape = shape;
 
-  // init strides, nullptr means the tensor is compact
-  // refer to cupy and cudf, the compact tensor first dim's strides need to be 1
-  // and second dim's strides need to be length of rows of cudf
-  // cudf now only support dim=2
-  PADDLE_ENFORCE_LE(t_.ndim, 2, platform::errors::InvalidArgument(
-                                     "cudf now only supports dimension is 2, "
-                                     "but received dimension is %d.",
-                                     t_.ndim));
-
-  if (t_.ndim > 1)
-    t_.strides = new int64_t[2]{1, t_.shape[1]};
-  else
-    t_.strides = new int64_t[1]{1};
+  // init strides
+  auto strides = new int64_t[t_.ndim];
+  for (DimType i = 0; i < t_.ndim; ++i) {
+    strides[i] = 1;
+  }
+  for (DimType i = t_.ndim - 2; i >= 0; --i) {
+    strides[i] = t_.shape[i + 1] * strides[i + 1];
+  }
+  t_.strides = strides;
 
   auto tensor = new DLManagedTensor;
   tensor->dl_tensor = t_;
...
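The new stride initialization is the standard row-major (C-contiguous) computation, generalized to arbitrary rank now that the old two-dimensional cudf restriction is gone. A minimal Python sketch of the same two loops:

```python
def row_major_strides(shape):
    # Innermost stride is 1; each outer stride is the product of all
    # inner dimension sizes, exactly as in the C++ loops above.
    ndim = len(shape)
    strides = [1] * ndim
    for i in range(ndim - 2, -1, -1):
        strides[i] = shape[i + 1] * strides[i + 1]
    return strides

# For the {6, 7} tensor in the updated unit test below:
assert row_major_strides([6, 7]) == [7, 1]  # hence strides[0] == 7, strides[1] == 1
```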
@@ -36,7 +36,7 @@ class DLPackTensor {
 
   inline operator ::DLTensor&() { return t_; }
 
-  ::DLManagedTensor* ToCudfCompatibleDLManagedTensor();
+  ::DLManagedTensor* ToDLManagedTensor();
 
  private:
   ::DLTensor t_;
...
@@ -30,7 +30,11 @@ template <typename T>
 constexpr uint8_t GetDLDataTypeCode() {
   if (std::is_same<T, platform::complex<float>>::value ||
       std::is_same<T, platform::complex<double>>::value) {
-    return static_cast<uint8_t>(5);
+    return static_cast<uint8_t>(kDLComplex);
+  }
+
+  if (std::is_same<T, platform::bfloat16>::value) {
+    return static_cast<uint8_t>(kDLBfloat);
   }
 
   return std::is_same<platform::float16, T>::value ||
@@ -55,15 +59,15 @@ void TestMain(const platform::Place &place, uint16_t lanes) {
   CHECK_EQ(p, dl_tensor.data);
 
   if (platform::is_cpu_place(place)) {
-    CHECK_EQ(kDLCPU, dl_tensor.ctx.device_type);
-    CHECK_EQ(0, dl_tensor.ctx.device_id);
+    CHECK_EQ(kDLCPU, dl_tensor.device.device_type);
+    CHECK_EQ(0, dl_tensor.device.device_id);
   } else if (platform::is_gpu_place(place)) {
-    CHECK_EQ(kDLGPU, dl_tensor.ctx.device_type);
+    CHECK_EQ(kDLGPU, dl_tensor.device.device_type);
     CHECK_EQ(BOOST_GET_CONST(platform::CUDAPlace, place).device,
-             dl_tensor.ctx.device_id);
+             dl_tensor.device.device_id);
   } else if (platform::is_cuda_pinned_place(place)) {
-    CHECK_EQ(kDLCPUPinned, dl_tensor.ctx.device_type);
-    CHECK_EQ(0, dl_tensor.ctx.device_id);
+    CHECK_EQ(kDLCPUPinned, dl_tensor.device.device_type);
+    CHECK_EQ(0, dl_tensor.device.device_id);
   } else {
     CHECK_EQ(false, true);
   }
@@ -83,8 +87,7 @@ void TestMain(const platform::Place &place, uint16_t lanes) {
 }
 
 template <typename T>
-void TestToCudfCompatibleDLManagedTensor(const platform::Place &place,
-                                         uint16_t lanes) {
+void TestToDLManagedTensor(const platform::Place &place, uint16_t lanes) {
   DDim dims{6, 7};
   Tensor tensor;
   tensor.Resize(dims);
@@ -92,8 +95,7 @@ void TestToCudfCompatibleDLManagedTensor(const platform::Place &place,
 
   DLPackTensor dlpack_tensor(tensor, lanes);
 
-  ::DLManagedTensor *dl_managed_tensor =
-      dlpack_tensor.ToCudfCompatibleDLManagedTensor();
+  ::DLManagedTensor *dl_managed_tensor = dlpack_tensor.ToDLManagedTensor();
 
   CHECK_EQ(dl_managed_tensor->manager_ctx == nullptr, true);
@@ -101,7 +103,8 @@ void TestToCudfCompatibleDLManagedTensor(const platform::Place &place,
     CHECK_EQ(dims[i], dl_managed_tensor->dl_tensor.shape[i]);
   }
 
-  CHECK_EQ(dl_managed_tensor->dl_tensor.strides[0] == 1, true);
+  CHECK_EQ(dl_managed_tensor->dl_tensor.strides[0] == 7, true);
+  CHECK_EQ(dl_managed_tensor->dl_tensor.strides[1] == 1, true);
 
   dl_managed_tensor->deleter(dl_managed_tensor);
 }
@@ -122,7 +125,7 @@ void TestMainLoop() {
   for (auto &p : places) {
     for (auto &l : lanes) {
       TestMain<T>(p, l);
-      TestToCudfCompatibleDLManagedTensor<T>(p, l);
+      TestToDLManagedTensor<T>(p, l);
     }
   }
 }
...
@@ -1065,6 +1065,9 @@ void* GetDstPtrByDLDataType(DLDataType type, framework::Tensor* dst,
       if (type.code == kDLFloat)
         return static_cast<void*>(
             dst->mutable_data<paddle::platform::float16>(dst_place));
+      if (type.code == kDLBfloat)
+        return static_cast<void*>(
+            dst->mutable_data<paddle::platform::bfloat16>(dst_place));
       PADDLE_THROW(platform::errors::Unimplemented(
           "DLDataType code <%d> is illegal when DLDataType.bits is <%d>.",
           type.code, type.bits));
@@ -1081,6 +1084,16 @@ void* GetDstPtrByDLDataType(DLDataType type, framework::Tensor* dst,
         return static_cast<void*>(dst->mutable_data<int64_t>(dst_place));
       if (type.code == kDLFloat)
         return static_cast<void*>(dst->mutable_data<double>(dst_place));
+      if (type.code == kDLComplex)
+        return static_cast<void*>(
+            dst->mutable_data<paddle::platform::complex<float>>(dst_place));
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "DLDataType code <%d> is illegal when DLDataType.bits is <%d>.",
+          type.code, type.bits));
+    case 128:
+      if (type.code == kDLComplex)
+        return static_cast<void*>(
+            dst->mutable_data<paddle::platform::complex<double>>(dst_place));
       PADDLE_THROW(platform::errors::Unimplemented(
           "DLDataType code <%d> is illegal when DLDataType.bits is <%d>.",
           type.code, type.bits));
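`GetDstPtrByDLDataType` dispatches on the (type.code, type.bits) pair, so a 64-bit kDLComplex payload lands in `complex<float>` and a 128-bit one in the new `case 128:` branch as `complex<double>`. A sketch of the dispatch table, with numpy dtypes standing in for the Paddle element types (illustrative only, same assumed enum values as above):

```python
import numpy as np

# (DLPack type code, bits) -> destination element type. Assumed kDL* values:
# kDLInt=0, kDLUInt=1, kDLFloat=2, kDLBfloat=4, kDLComplex=5.
DST_DTYPE = {
    (2, 16): np.float16,      # kDLFloat/16    -> platform::float16
    (0, 64): np.int64,        # kDLInt/64      -> int64_t
    (2, 64): np.float64,      # kDLFloat/64    -> double
    (5, 64): np.complex64,    # kDLComplex/64  -> complex<float>  (new branch)
    (5, 128): np.complex128,  # kDLComplex/128 -> complex<double> (new case)
}
```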
@@ -1107,15 +1120,15 @@ void TensorFromDLPack(const ::DLTensor& dl_tensor, framework::Tensor* dst) {
   auto src_ptr = static_cast<const void*>(dl_tensor.data);
   auto size = paddle::framework::product(vddim) * type.bits / 8;
 
-  if (dl_tensor.ctx.device_type == kDLCPU) {
+  if (dl_tensor.device.device_type == kDLCPU) {
     memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
   }
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-  if (dl_tensor.ctx.device_type == kDLGPU) {
+  if (dl_tensor.device.device_type == kDLGPU) {
     platform::CUDAPlace dst_place =
-        platform::CUDAPlace(dl_tensor.ctx.device_id);
+        platform::CUDAPlace(dl_tensor.device.device_id);
     platform::CUDAPlace src_place =
-        platform::CUDAPlace(dl_tensor.ctx.device_id);
+        platform::CUDAPlace(dl_tensor.device.device_id);
     dst_ptr = GetDstPtrByDLDataType(type, dst, dst_place);
     auto* ctx = platform::DeviceContextPool::Instance().GetByPlace(dst_place);
     memory::Copy(
...
@@ -537,11 +537,11 @@ PYBIND11_MODULE(core_noavx, m) {
         DLTensor dl = dmt->dl_tensor;
         framework::Tensor tensor;
 
-        if (dl.ctx.device_type == kDLCPU) {
+        if (dl.device.device_type == kDLCPU) {
           paddle::framework::TensorFromDLPack(dl, &tensor);
         }
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-        if (dl.ctx.device_type == kDLGPU) {
+        if (dl.device.device_type == kDLGPU) {
           paddle::framework::TensorFromDLPack(dl, &tensor);
         }
 #endif
@@ -776,8 +776,7 @@ PYBIND11_MODULE(core_noavx, m) {
       .def("_to_dlpack",
           [](framework::Tensor &self) {
             DLPackTensor dlpack_tensor(self, 1);
-            DLManagedTensor *dmt =
-                dlpack_tensor.ToCudfCompatibleDLManagedTensor();
+            DLManagedTensor *dmt = dlpack_tensor.ToDLManagedTensor();
             auto capsule = py::capsule(
                 static_cast<void *>(dmt), "dltensor", [](PyObject *ptr) {
                   if (ptr) {
...
@@ -22,6 +22,7 @@ import paddle.fluid.core as core
 
 class TestDLPack(unittest.TestCase):
     def test_dlpack_dygraph(self):
+        paddle.disable_static()
         tensor = paddle.to_tensor(np.array([1, 2, 3, 4]).astype('int'))
         dlpack = paddle.utils.dlpack.to_dlpack(tensor)
         out_from_dlpack = paddle.utils.dlpack.from_dlpack(dlpack)
@@ -31,6 +32,15 @@ class TestDLPack(unittest.TestCase):
                 np.array(out_from_dlpack), np.array([1, 2, 3, 4]).astype(
                     'int')))
 
+    def test_dlpack_tensor_larger_than_2dim(self):
+        paddle.disable_static()
+        numpy_data = np.random.randn(4, 5, 6)
+        t = paddle.to_tensor(numpy_data)
+        # TODO: There may be a reference-counting problem in to_dlpack.
+        dlpack = paddle.utils.dlpack.to_dlpack(t)
+        out = paddle.utils.dlpack.from_dlpack(dlpack)
+        self.assertTrue(np.allclose(numpy_data, out.numpy()))
+
     def test_dlpack_static(self):
         paddle.enable_static()
         tensor = fluid.create_lod_tensor(
@@ -57,6 +67,37 @@ class TestDLPack(unittest.TestCase):
                 np.array(gout_from_dlpack),
                 np.array([[1], [2], [3], [4]]).astype('int')))
 
+    def test_dlpack_dtype_conversion(self):
+        paddle.disable_static()
+        # DLPack does not explicitly support the bool data type.
+        dtypes = [
+            "float16",
+            "float32",
+            "float64",
+            "int8",
+            "int16",
+            "int32",
+            "int64",
+            "uint8",
+        ]
+        data = np.ones((2, 3, 4))
+        for dtype in dtypes:
+            x = paddle.to_tensor(data, dtype=dtype)
+            dlpack = paddle.utils.dlpack.to_dlpack(x)
+            o = paddle.utils.dlpack.from_dlpack(dlpack)
+            self.assertEqual(x.dtype, o.dtype)
+            self.assertTrue(np.allclose(x.numpy(), o.numpy()))
+
+        complex_dtypes = ["complex64", "complex128"]
+        for dtype in complex_dtypes:
+            x = paddle.to_tensor(
+                [[1 + 6j, 2 + 5j, 3 + 4j], [4 + 3j, 5 + 2j, 6 + 1j]],
+                dtype=dtype)
+            dlpack = paddle.utils.dlpack.to_dlpack(x)
+            o = paddle.utils.dlpack.from_dlpack(dlpack)
+            self.assertEqual(x.dtype, o.dtype)
+            self.assertTrue(np.allclose(x.numpy(), o.numpy()))
+
 
 class TestRaiseError(unittest.TestCase):
     def test_from_dlpack_raise_type_error(self):
...
@@ -28,7 +28,9 @@ def to_dlpack(x):
     Encodes a tensor to DLPack.
 
     Args:
-        x (Tensor): A tensor, and the data type is bool, float32, float64, int32, int64.
+        x (Tensor): The input tensor, and the data type can be `bool`, `float16`, `float32`,
+                    `float64`, `int8`, `int16`, `int32`, `int64`, `uint8`, `complex64`,
+                    `complex128`.
 
     Returns:
         dltensor, and the data type is PyCapsule.
@@ -51,19 +53,9 @@ def to_dlpack(x):
                 "The type of 'x' in to_dlpack must be paddle.Tensor,"
                 " but received {}.".format(type(x)))
 
-        dtype = convert_dtype(x.dtype)
-
-        if dtype not in ['bool', 'int32', 'int64', 'float32', 'float64']:
-            raise TypeError(
-                "the dtype of 'x' in to_dlpack must be any of [bool, int32, int64, "
-                "float32, float64], but received {}.".format(dtype))
-
         return x.value().get_tensor()._to_dlpack()
 
     check_type(x, 'x', (LoDTensor), 'to_dlpack')
-    check_dtype(x._dtype(), 'x',
-                ['bool', 'int32', 'int64', 'float32', 'float64'], 'to_dlpack')
 
     return x._to_dlpack()
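With the Python-side dtype whitelist removed, `to_dlpack` accepts every dtype the C++ encoder supports. A round-trip usage sketch, mirroring the new unit tests above:

```python
import numpy as np
import paddle

paddle.disable_static()
x = paddle.to_tensor(np.random.randn(4, 5, 6).astype("float32"))
capsule = paddle.utils.dlpack.to_dlpack(x)  # a PyCapsule named "dltensor"
y = paddle.utils.dlpack.from_dlpack(capsule)
assert np.allclose(x.numpy(), y.numpy())
```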
@@ -75,7 +67,9 @@ def from_dlpack(dlpack):
         dlpack (PyCapsule): a PyCapsule object with the dltensor.
 
     Returns:
-        out (Tensor): a tensor decoded from DLPack.
+        out (Tensor): a tensor decoded from DLPack. Note that if the input
+                      dltensor carries `bool` data, the decoded tensor is
+                      returned as `uint8`.
 
     Examples:
         .. code-block:: python
...
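The `uint8` caveat in the docstring exists because DLPack has no boolean type code, so boolean data is exported as 8-bit unsigned integers. A short illustration of the documented behaviour (a sketch, not part of this diff):

```python
import numpy as np
import paddle

paddle.disable_static()
b = paddle.to_tensor(np.array([True, False, True]))  # bool tensor
capsule = paddle.utils.dlpack.to_dlpack(b)
out = paddle.utils.dlpack.from_dlpack(capsule)
print(out.dtype)  # uint8, per the docstring note above
```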