未验证 提交 48ec02f1 编写于 作者: S Siming Dai 提交者: GitHub

Add public api for dlpack. (#35620)

上级 787209f7
...@@ -17,6 +17,11 @@ from ..fluid.core import LoDTensor ...@@ -17,6 +17,11 @@ from ..fluid.core import LoDTensor
from ..fluid.framework import in_dygraph_mode from ..fluid.framework import in_dygraph_mode
from ..fluid.data_feeder import check_type, check_dtype, convert_dtype from ..fluid.data_feeder import check_type, check_dtype, convert_dtype
__all__ = [
'to_dlpack',
'from_dlpack',
]
def to_dlpack(x): def to_dlpack(x):
""" """
...@@ -63,7 +68,8 @@ def to_dlpack(x): ...@@ -63,7 +68,8 @@ def to_dlpack(x):
def from_dlpack(dlpack): def from_dlpack(dlpack):
"""Decodes a DLPack to a tensor. """
Decodes a DLPack to a tensor.
Args: Args:
dlpack (PyCapsule): a PyCapsule object with the dltensor. dlpack (PyCapsule): a PyCapsule object with the dltensor.
...@@ -82,8 +88,8 @@ def from_dlpack(dlpack): ...@@ -82,8 +88,8 @@ def from_dlpack(dlpack):
x = paddle.utils.dlpack.from_dlpack(dlpack) x = paddle.utils.dlpack.from_dlpack(dlpack)
print(x) print(x)
# Tensor(shape=[2, 4], dtype=float32, place=CUDAPlace(0), stop_gradient=True, # Tensor(shape=[2, 4], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
[[0.20000000, 0.30000001, 0.50000000, 0.89999998], # [[0.20000000, 0.30000001, 0.50000000, 0.89999998],
[0.10000000, 0.20000000, 0.60000002, 0.69999999]]) # [0.10000000, 0.20000000, 0.60000002, 0.69999999]])
""" """
t = type(dlpack) t = type(dlpack)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册