Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
31a5829a
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
31a5829a
编写于
10月 11, 2021
作者:
S
Siming Dai
提交者:
GitHub
10月 11, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
dlpack fix (#35817) (#36177)
上级
21c65f66
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
120 addition
and
80 deletion
+120
-80
cmake/external/dlpack.cmake
cmake/external/dlpack.cmake
+1
-1
paddle/fluid/framework/dlpack_tensor.cc
paddle/fluid/framework/dlpack_tensor.cc
+35
-45
paddle/fluid/framework/dlpack_tensor.h
paddle/fluid/framework/dlpack_tensor.h
+1
-1
paddle/fluid/framework/dlpack_tensor_test.cc
paddle/fluid/framework/dlpack_tensor_test.cc
+16
-13
paddle/fluid/framework/tensor_util.cc
paddle/fluid/framework/tensor_util.cc
+17
-4
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+3
-4
python/paddle/tests/test_dlpack.py
python/paddle/tests/test_dlpack.py
+41
-0
python/paddle/utils/dlpack.py
python/paddle/utils/dlpack.py
+6
-12
未找到文件。
cmake/external/dlpack.cmake
浏览文件 @
31a5829a
...
...
@@ -18,7 +18,7 @@ set(DLPACK_PREFIX_DIR ${THIRD_PARTY_PATH}/dlpack)
set
(
DLPACK_SOURCE_DIR
${
THIRD_PARTY_PATH
}
/dlpack/src/extern_dlpack
)
set
(
DLPACK_REPOSITORY
${
GIT_URL
}
/dmlc/dlpack.git
)
set
(
DLPACK_TAG v0.
2
)
set
(
DLPACK_TAG v0.
4
)
cache_third_party
(
extern_dlpack
REPOSITORY
${
DLPACK_REPOSITORY
}
...
...
paddle/fluid/framework/dlpack_tensor.cc
浏览文件 @
31a5829a
...
...
@@ -30,14 +30,10 @@ static ::DLDataType GetDLDataTypeCode() {
::
DLDataType
dtype
;
if
(
std
::
is_same
<
T
,
platform
::
complex
<
float
>>::
value
||
std
::
is_same
<
T
,
platform
::
complex
<
double
>>::
value
)
{
// The current dlpack library version is v0.2, and does not define
// kDLComplex value. But kDLComplex is defined by 5U in v0.4, so we set
// dtype.code to 5U directly here. After the dlpack library version being
// upgraded to v0.4, it should be written as follow.
// dtype.code = kDLComplex;
dtype
.
code
=
5U
;
dtype
.
code
=
kDLComplex
;
}
else
if
(
std
::
is_same
<
T
,
platform
::
bfloat16
>::
value
)
{
dtype
.
code
=
kDLBfloat
;
}
else
if
(
std
::
is_same
<
T
,
platform
::
float16
>::
value
||
std
::
is_same
<
T
,
platform
::
bfloat16
>::
value
||
std
::
is_floating_point
<
T
>::
value
)
{
dtype
.
code
=
kDLFloat
;
}
else
if
(
std
::
is_unsigned
<
T
>::
value
)
{
...
...
@@ -77,47 +73,47 @@ static DLDataType GetDLDataTypeFromTypeIndex(proto::VarType::Type type) {
#undef REG_DL_DATA_TYPE
}
struct
DL
ContextVisitor
:
public
boost
::
static_visitor
<::
DLContext
>
{
inline
::
DL
Context
operator
()(
const
platform
::
CPUPlace
&
place
)
const
{
::
DL
Context
ctx
;
ctx
.
device_type
=
kDLCPU
;
ctx
.
device_id
=
0
;
return
ctx
;
struct
DL
DeviceVisitor
:
public
boost
::
static_visitor
<::
DLDevice
>
{
inline
::
DL
Device
operator
()(
const
platform
::
CPUPlace
&
place
)
const
{
::
DL
Device
device
;
device
.
device_type
=
kDLCPU
;
device
.
device_id
=
0
;
return
device
;
}
inline
::
DL
Context
operator
()(
const
platform
::
XPUPlace
&
place
)
const
{
inline
::
DL
Device
operator
()(
const
platform
::
XPUPlace
&
place
)
const
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"platform::XPUPlace is not supported"
));
}
inline
::
DL
Context
operator
()(
const
platform
::
NPUPlace
&
place
)
const
{
inline
::
DL
Device
operator
()(
const
platform
::
NPUPlace
&
place
)
const
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"platform::NPUPlace is not supported"
));
}
inline
::
DL
Context
operator
()(
const
platform
::
NPUPinnedPlace
&
place
)
const
{
inline
::
DL
Device
operator
()(
const
platform
::
NPUPinnedPlace
&
place
)
const
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"platform::NPUPinnedPlace is not supported"
));
}
inline
::
DL
Context
operator
()(
const
platform
::
CUDAPlace
&
place
)
const
{
inline
::
DL
Device
operator
()(
const
platform
::
CUDAPlace
&
place
)
const
{
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
::
DL
Context
ctx
;
ctx
.
device_type
=
kDLGPU
;
ctx
.
device_id
=
place
.
device
;
return
ctx
;
::
DL
Device
device
;
device
.
device_type
=
kDLGPU
;
device
.
device_id
=
place
.
device
;
return
device
;
#else
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"platform::CUDAPlace is not supported in CPU only version."
));
#endif
}
inline
::
DL
Context
operator
()(
const
platform
::
CUDAPinnedPlace
&
place
)
const
{
inline
::
DL
Device
operator
()(
const
platform
::
CUDAPinnedPlace
&
place
)
const
{
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
::
DL
Context
ctx
;
ctx
.
device_type
=
kDLCPUPinned
;
ctx
.
device_id
=
0
;
return
ctx
;
::
DL
Device
device
;
device
.
device_type
=
kDLCPUPinned
;
device
.
device_id
=
0
;
return
device
;
#else
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"platform::CUDAPinnedPlace is not supported in CPU only version."
));
...
...
@@ -130,9 +126,9 @@ DLPackTensor::DLPackTensor(const Tensor &tensor, LaneType lanes) {
// init data, data buffer
t_
.
data
=
const_cast
<
void
*>
(
tensor
.
data
<
void
>
());
// init
ctx, DLContext
type with device_type and device_id
// init
device, DLDevice
type with device_type and device_id
auto
place
=
tensor
.
place
();
t_
.
ctx
=
boost
::
apply_visitor
(
internal
::
DLContext
Visitor
(),
place
);
t_
.
device
=
boost
::
apply_visitor
(
internal
::
DLDevice
Visitor
(),
place
);
// init dtype
t_
.
dtype
=
internal
::
GetDLDataTypeFromTypeIndex
(
tensor
.
type
());
...
...
@@ -156,10 +152,8 @@ DLPackTensor::DLPackTensor(const Tensor &tensor, LaneType lanes) {
t_
.
byte_offset
=
0
;
}
::
DLManagedTensor
*
DLPackTensor
::
ToCudfCompatibleDLManagedTensor
()
{
// init shape, tensor dims
// for DLManagedTensor shape need to be compatible with ndim
// refer to cupy and cudf, we new int64[ndim]
::
DLManagedTensor
*
DLPackTensor
::
ToDLManagedTensor
()
{
// init shape
auto
shape
=
new
int64_t
[
t_
.
ndim
];
using
DimType
=
decltype
(
t_
.
ndim
);
// int
for
(
DimType
i
=
0
;
i
<
t_
.
ndim
;
++
i
)
{
...
...
@@ -167,19 +161,15 @@ DLPackTensor::DLPackTensor(const Tensor &tensor, LaneType lanes) {
}
t_
.
shape
=
shape
;
// init strides, nullptr means the tensor is compact
// refer to cupy and cudf, the compact tensor first dim's strides need to be 1
// and second dim's strides need to be length of rows of cudf
// cudf now only support dim=2
PADDLE_ENFORCE_LE
(
t_
.
ndim
,
2
,
platform
::
errors
::
InvalidArgument
(
"cudf now only supports dimension is 2, "
"but received dimension is %d."
,
t_
.
ndim
));
if
(
t_
.
ndim
>
1
)
t_
.
strides
=
new
int64_t
[
2
]{
1
,
t_
.
shape
[
1
]};
else
t_
.
strides
=
new
int64_t
[
1
]{
1
};
// init strides
auto
strides
=
new
int64_t
[
t_
.
ndim
];
for
(
DimType
i
=
0
;
i
<
t_
.
ndim
;
++
i
)
{
strides
[
i
]
=
1
;
}
for
(
DimType
i
=
t_
.
ndim
-
2
;
i
>=
0
;
--
i
)
{
strides
[
i
]
=
t_
.
shape
[
i
+
1
]
*
strides
[
i
+
1
];
}
t_
.
strides
=
strides
;
auto
tensor
=
new
DLManagedTensor
;
tensor
->
dl_tensor
=
t_
;
...
...
paddle/fluid/framework/dlpack_tensor.h
浏览文件 @
31a5829a
...
...
@@ -36,7 +36,7 @@ class DLPackTensor {
inline
operator
::
DLTensor
&
()
{
return
t_
;
}
::
DLManagedTensor
*
To
CudfCompatible
DLManagedTensor
();
::
DLManagedTensor
*
ToDLManagedTensor
();
private:
::
DLTensor
t_
;
...
...
paddle/fluid/framework/dlpack_tensor_test.cc
浏览文件 @
31a5829a
...
...
@@ -30,7 +30,11 @@ template <typename T>
constexpr
uint8_t
GetDLDataTypeCode
()
{
if
(
std
::
is_same
<
T
,
platform
::
complex
<
float
>>::
value
||
std
::
is_same
<
T
,
platform
::
complex
<
double
>>::
value
)
{
return
static_cast
<
uint8_t
>
(
5
);
return
static_cast
<
uint8_t
>
(
kDLComplex
);
}
if
(
std
::
is_same
<
T
,
platform
::
bfloat16
>::
value
)
{
return
static_cast
<
uint8_t
>
(
kDLBfloat
);
}
return
std
::
is_same
<
platform
::
float16
,
T
>::
value
||
...
...
@@ -55,15 +59,15 @@ void TestMain(const platform::Place &place, uint16_t lanes) {
CHECK_EQ
(
p
,
dl_tensor
.
data
);
if
(
platform
::
is_cpu_place
(
place
))
{
CHECK_EQ
(
kDLCPU
,
dl_tensor
.
ctx
.
device_type
);
CHECK_EQ
(
0
,
dl_tensor
.
ctx
.
device_id
);
CHECK_EQ
(
kDLCPU
,
dl_tensor
.
device
.
device_type
);
CHECK_EQ
(
0
,
dl_tensor
.
device
.
device_id
);
}
else
if
(
platform
::
is_gpu_place
(
place
))
{
CHECK_EQ
(
kDLGPU
,
dl_tensor
.
ctx
.
device_type
);
CHECK_EQ
(
kDLGPU
,
dl_tensor
.
device
.
device_type
);
CHECK_EQ
(
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
place
).
device
,
dl_tensor
.
ctx
.
device_id
);
dl_tensor
.
device
.
device_id
);
}
else
if
(
platform
::
is_cuda_pinned_place
(
place
))
{
CHECK_EQ
(
kDLCPUPinned
,
dl_tensor
.
ctx
.
device_type
);
CHECK_EQ
(
0
,
dl_tensor
.
ctx
.
device_id
);
CHECK_EQ
(
kDLCPUPinned
,
dl_tensor
.
device
.
device_type
);
CHECK_EQ
(
0
,
dl_tensor
.
device
.
device_id
);
}
else
{
CHECK_EQ
(
false
,
true
);
}
...
...
@@ -83,8 +87,7 @@ void TestMain(const platform::Place &place, uint16_t lanes) {
}
template
<
typename
T
>
void
TestToCudfCompatibleDLManagedTensor
(
const
platform
::
Place
&
place
,
uint16_t
lanes
)
{
void
TestToDLManagedTensor
(
const
platform
::
Place
&
place
,
uint16_t
lanes
)
{
DDim
dims
{
6
,
7
};
Tensor
tensor
;
tensor
.
Resize
(
dims
);
...
...
@@ -92,8 +95,7 @@ void TestToCudfCompatibleDLManagedTensor(const platform::Place &place,
DLPackTensor
dlpack_tensor
(
tensor
,
lanes
);
::
DLManagedTensor
*
dl_managed_tensor
=
dlpack_tensor
.
ToCudfCompatibleDLManagedTensor
();
::
DLManagedTensor
*
dl_managed_tensor
=
dlpack_tensor
.
ToDLManagedTensor
();
CHECK_EQ
(
dl_managed_tensor
->
manager_ctx
==
nullptr
,
true
);
...
...
@@ -101,7 +103,8 @@ void TestToCudfCompatibleDLManagedTensor(const platform::Place &place,
CHECK_EQ
(
dims
[
i
],
dl_managed_tensor
->
dl_tensor
.
shape
[
i
]);
}
CHECK_EQ
(
dl_managed_tensor
->
dl_tensor
.
strides
[
0
]
==
1
,
true
);
CHECK_EQ
(
dl_managed_tensor
->
dl_tensor
.
strides
[
0
]
==
7
,
true
);
CHECK_EQ
(
dl_managed_tensor
->
dl_tensor
.
strides
[
1
]
==
1
,
true
);
dl_managed_tensor
->
deleter
(
dl_managed_tensor
);
}
...
...
@@ -122,7 +125,7 @@ void TestMainLoop() {
for
(
auto
&
p
:
places
)
{
for
(
auto
&
l
:
lanes
)
{
TestMain
<
T
>
(
p
,
l
);
TestTo
CudfCompatible
DLManagedTensor
<
T
>
(
p
,
l
);
TestToDLManagedTensor
<
T
>
(
p
,
l
);
}
}
}
...
...
paddle/fluid/framework/tensor_util.cc
浏览文件 @
31a5829a
...
...
@@ -1065,6 +1065,9 @@ void* GetDstPtrByDLDataType(DLDataType type, framework::Tensor* dst,
if
(
type
.
code
==
kDLFloat
)
return
static_cast
<
void
*>
(
dst
->
mutable_data
<
paddle
::
platform
::
float16
>
(
dst_place
));
if
(
type
.
code
==
kDLBfloat
)
return
static_cast
<
void
*>
(
dst
->
mutable_data
<
paddle
::
platform
::
bfloat16
>
(
dst_place
));
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"DLDataType code <%d> is illegal when DLDataType.bits is <%d>."
,
type
.
code
,
type
.
bits
));
...
...
@@ -1081,6 +1084,16 @@ void* GetDstPtrByDLDataType(DLDataType type, framework::Tensor* dst,
return
static_cast
<
void
*>
(
dst
->
mutable_data
<
int64_t
>
(
dst_place
));
if
(
type
.
code
==
kDLFloat
)
return
static_cast
<
void
*>
(
dst
->
mutable_data
<
double
>
(
dst_place
));
if
(
type
.
code
==
kDLComplex
)
return
static_cast
<
void
*>
(
dst
->
mutable_data
<
paddle
::
platform
::
complex
<
float
>>
(
dst_place
));
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"DLDataType code <%d> is illegal when DLDataType.bits is <%d>."
,
type
.
code
,
type
.
bits
));
case
128
:
if
(
type
.
code
==
kDLComplex
)
return
static_cast
<
void
*>
(
dst
->
mutable_data
<
paddle
::
platform
::
complex
<
double
>>
(
dst_place
));
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"DLDataType code <%d> is illegal when DLDataType.bits is <%d>."
,
type
.
code
,
type
.
bits
));
...
...
@@ -1107,15 +1120,15 @@ void TensorFromDLPack(const ::DLTensor& dl_tensor, framework::Tensor* dst) {
auto
src_ptr
=
static_cast
<
const
void
*>
(
dl_tensor
.
data
);
auto
size
=
paddle
::
framework
::
product
(
vddim
)
*
type
.
bits
/
8
;
if
(
dl_tensor
.
ctx
.
device_type
==
kDLCPU
)
{
if
(
dl_tensor
.
device
.
device_type
==
kDLCPU
)
{
memory
::
Copy
(
dst_place
,
dst_ptr
,
src_place
,
src_ptr
,
size
);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if
(
dl_tensor
.
ctx
.
device_type
==
kDLGPU
)
{
if
(
dl_tensor
.
device
.
device_type
==
kDLGPU
)
{
platform
::
CUDAPlace
dst_place
=
platform
::
CUDAPlace
(
dl_tensor
.
ctx
.
device_id
);
platform
::
CUDAPlace
(
dl_tensor
.
device
.
device_id
);
platform
::
CUDAPlace
src_place
=
platform
::
CUDAPlace
(
dl_tensor
.
ctx
.
device_id
);
platform
::
CUDAPlace
(
dl_tensor
.
device
.
device_id
);
dst_ptr
=
GetDstPtrByDLDataType
(
type
,
dst
,
dst_place
);
auto
*
ctx
=
platform
::
DeviceContextPool
::
Instance
().
GetByPlace
(
dst_place
);
memory
::
Copy
(
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
31a5829a
...
...
@@ -537,11 +537,11 @@ PYBIND11_MODULE(core_noavx, m) {
DLTensor
dl
=
dmt
->
dl_tensor
;
framework
::
Tensor
tensor
;
if
(
dl
.
ctx
.
device_type
==
kDLCPU
)
{
if
(
dl
.
device
.
device_type
==
kDLCPU
)
{
paddle
::
framework
::
TensorFromDLPack
(
dl
,
&
tensor
);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if
(
dl
.
ctx
.
device_type
==
kDLGPU
)
{
if
(
dl
.
device
.
device_type
==
kDLGPU
)
{
paddle
::
framework
::
TensorFromDLPack
(
dl
,
&
tensor
);
}
#endif
...
...
@@ -776,8 +776,7 @@ PYBIND11_MODULE(core_noavx, m) {
.
def
(
"_to_dlpack"
,
[](
framework
::
Tensor
&
self
)
{
DLPackTensor
dlpack_tensor
(
self
,
1
);
DLManagedTensor
*
dmt
=
dlpack_tensor
.
ToCudfCompatibleDLManagedTensor
();
DLManagedTensor
*
dmt
=
dlpack_tensor
.
ToDLManagedTensor
();
auto
capsule
=
py
::
capsule
(
static_cast
<
void
*>
(
dmt
),
"dltensor"
,
[](
PyObject
*
ptr
)
{
if
(
ptr
)
{
...
...
python/paddle/tests/test_dlpack.py
浏览文件 @
31a5829a
...
...
@@ -22,6 +22,7 @@ import paddle.fluid.core as core
class
TestDLPack
(
unittest
.
TestCase
):
def
test_dlpack_dygraph
(
self
):
paddle
.
disable_static
()
tensor
=
paddle
.
to_tensor
(
np
.
array
([
1
,
2
,
3
,
4
]).
astype
(
'int'
))
dlpack
=
paddle
.
utils
.
dlpack
.
to_dlpack
(
tensor
)
out_from_dlpack
=
paddle
.
utils
.
dlpack
.
from_dlpack
(
dlpack
)
...
...
@@ -31,6 +32,15 @@ class TestDLPack(unittest.TestCase):
np
.
array
(
out_from_dlpack
),
np
.
array
([
1
,
2
,
3
,
4
]).
astype
(
'int'
)))
def
test_dlpack_tensor_larger_than_2dim
(
self
):
paddle
.
disable_static
()
numpy_data
=
np
.
random
.
randn
(
4
,
5
,
6
)
t
=
paddle
.
to_tensor
(
numpy_data
)
# TODO: There may be a reference count problem of to_dlpack.
dlpack
=
paddle
.
utils
.
dlpack
.
to_dlpack
(
t
)
out
=
paddle
.
utils
.
dlpack
.
from_dlpack
(
dlpack
)
self
.
assertTrue
(
np
.
allclose
(
numpy_data
,
out
.
numpy
()))
def
test_dlpack_static
(
self
):
paddle
.
enable_static
()
tensor
=
fluid
.
create_lod_tensor
(
...
...
@@ -57,6 +67,37 @@ class TestDLPack(unittest.TestCase):
np
.
array
(
gout_from_dlpack
),
np
.
array
([[
1
],
[
2
],
[
3
],
[
4
]]).
astype
(
'int'
)))
def
test_dlpack_dtype_conversion
(
self
):
paddle
.
disable_static
()
# DLpack does not explicitly support bool data type.
dtypes
=
[
"float16"
,
"float32"
,
"float64"
,
"int8"
,
"int16"
,
"int32"
,
"int64"
,
"uint8"
,
]
data
=
np
.
ones
((
2
,
3
,
4
))
for
dtype
in
dtypes
:
x
=
paddle
.
to_tensor
(
data
,
dtype
=
dtype
)
dlpack
=
paddle
.
utils
.
dlpack
.
to_dlpack
(
x
)
o
=
paddle
.
utils
.
dlpack
.
from_dlpack
(
dlpack
)
self
.
assertEqual
(
x
.
dtype
,
o
.
dtype
)
self
.
assertTrue
(
np
.
allclose
(
x
.
numpy
(),
o
.
numpy
()))
complex_dtypes
=
[
"complex64"
,
"complex128"
]
for
dtype
in
complex_dtypes
:
x
=
paddle
.
to_tensor
(
[[
1
+
6j
,
2
+
5j
,
3
+
4j
],
[
4
+
3j
,
5
+
2j
,
6
+
1j
]],
dtype
=
dtype
)
dlpack
=
paddle
.
utils
.
dlpack
.
to_dlpack
(
x
)
o
=
paddle
.
utils
.
dlpack
.
from_dlpack
(
dlpack
)
self
.
assertEqual
(
x
.
dtype
,
o
.
dtype
)
self
.
assertTrue
(
np
.
allclose
(
x
.
numpy
(),
o
.
numpy
()))
class
TestRaiseError
(
unittest
.
TestCase
):
def
test_from_dlpack_raise_type_error
(
self
):
...
...
python/paddle/utils/dlpack.py
浏览文件 @
31a5829a
...
...
@@ -28,7 +28,9 @@ def to_dlpack(x):
Encodes a tensor to DLPack.
Args:
x (Tensor): A tensor, and the data type is bool, float32, float64, int32, int64.
x (Tensor): The input tensor, and the data type can be `bool`, `float16`, `float32`,
`float64`, `int8`, `int16`, `int32`, `int64`, `uint8`, `complex64`,
`complex128`.
Returns:
dltensor, and the data type is PyCapsule.
...
...
@@ -51,19 +53,9 @@ def to_dlpack(x):
"The type of 'x' in to_dlpack must be paddle.Tensor,"
" but received {}."
.
format
(
type
(
x
)))
dtype
=
convert_dtype
(
x
.
dtype
)
if
dtype
not
in
[
'bool'
,
'int32'
,
'int64'
,
'float32'
,
'float64'
]:
raise
TypeError
(
"the dtype of 'x' in to_dlpack must be any of [bool, int32, int64, "
"float32, float64], but received {}."
.
format
(
dtype
))
return
x
.
value
().
get_tensor
().
_to_dlpack
()
check_type
(
x
,
'x'
,
(
LoDTensor
),
'to_dlpack'
)
check_dtype
(
x
.
_dtype
(),
'x'
,
[
'bool'
,
'int32'
,
'int64'
,
'float32'
,
'float64'
],
'to_dlpack'
)
return
x
.
_to_dlpack
()
...
...
@@ -75,7 +67,9 @@ def from_dlpack(dlpack):
dlpack (PyCapsule): a PyCapsule object with the dltensor.
Returns:
out (Tensor): a tensor decoded from DLPack.
out (Tensor): a tensor decoded from DLPack. One thing to be noted, if we get
an input dltensor with data type as `bool`, we return the decoded
tensor as `uint8`.
Examples:
.. code-block:: python
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录