dlpack.py 3.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from ..fluid.core import LoDTensor
from ..fluid.framework import in_dygraph_mode
from ..fluid.data_feeder import check_type, check_dtype, convert_dtype


def to_dlpack(x):
    """
    Encodes a tensor to DLPack.

    Args:
        x (Tensor): The tensor to encode. Its data type must be one of
            bool, int32, int64, float32 or float64. In static-graph mode a
            LoDTensor is expected instead.

    Returns:
        PyCapsule: a capsule holding the DLPack ``dltensor``.

    Raises:
        TypeError: in dynamic-graph mode, if ``x`` is not a ``paddle.Tensor``
            or its dtype is not supported. In static-graph mode, the checks
            are delegated to ``check_type`` / ``check_dtype``.

    Examples:
        .. code-block:: python

            import paddle
            # x is a tensor with shape [2, 4]
            x = paddle.to_tensor([[0.2, 0.3, 0.5, 0.9],
                                  [0.1, 0.2, 0.6, 0.7]])
            dlpack = paddle.utils.dlpack.to_dlpack(x)
            print(dlpack)
            # <capsule object "dltensor" at 0x7f6103c681b0>
    """
    # Data types this function accepts for export, shared by both modes.
    supported_dtypes = ['bool', 'int32', 'int64', 'float32', 'float64']

    if not in_dygraph_mode():
        # Static-graph path: validate with the fluid helpers and let the
        # LoDTensor produce the capsule itself.
        check_type(x, 'x', (LoDTensor), 'to_dlpack')
        check_dtype(x._dtype(), 'x', supported_dtypes, 'to_dlpack')
        return x._to_dlpack()

    # Dynamic-graph path: validate the Python-level Tensor by hand.
    if not isinstance(x, paddle.Tensor):
        raise TypeError(
            "The type of 'x' in to_dlpack must be paddle.Tensor,"
            " but received {}.".format(type(x)))

    dtype_name = convert_dtype(x.dtype)
    if dtype_name not in supported_dtypes:
        raise TypeError(
            "the dtype of 'x' in to_dlpack must be any of [bool, int32, int64, "
            "float32, float64], but received {}.".format(dtype_name))

    # Unwrap to the underlying dense tensor, which knows how to export itself.
    underlying = x.value().get_tensor()
    return underlying._to_dlpack()


def from_dlpack(dlpack):
    """Decodes a DLPack to a tensor.

    Args:
        dlpack (PyCapsule): a PyCapsule object holding a ``dltensor``, e.g.
            the value returned by ``to_dlpack``.

    Returns:
        Tensor: the tensor decoded from the capsule. In dynamic-graph mode the
        decoded tensor is wrapped with ``paddle.to_tensor``; otherwise the
        result of ``paddle.fluid.core.from_dlpack`` is returned as-is.

    Raises:
        TypeError: if ``dlpack`` is not a PyCapsule object.

    Examples:
        .. code-block:: python

            import paddle
            # x is a tensor with shape [2, 4]
            x = paddle.to_tensor([[0.2, 0.3, 0.5, 0.9],
                                  [0.1, 0.2, 0.6, 0.7]])
            dlpack = paddle.utils.dlpack.to_dlpack(x)
            x = paddle.utils.dlpack.from_dlpack(dlpack)
            print(x)
            # Tensor(shape=[2, 4], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
            #        [[0.20000000, 0.30000001, 0.50000000, 0.89999998],
            #         [0.10000000, 0.20000000, 0.60000002, 0.69999999]])
            # (the reported place depends on the device the tensor lives on)
    """
    # PyCapsule is not exposed as a public Python type, so identify it by the
    # module and name of its type object.
    capsule_type = type(dlpack)
    is_capsule = (capsule_type.__module__ == 'builtins'
                  and capsule_type.__name__ == 'PyCapsule')
    if not is_capsule:
        raise TypeError(
            "The type of 'dlpack' in from_dlpack must be PyCapsule object,"
            " but received {}.".format(capsule_type))

    # Query the mode before decoding, matching the original call order.
    dygraph = in_dygraph_mode()
    decoded = paddle.fluid.core.from_dlpack(dlpack)
    if dygraph:
        # Dynamic-graph callers expect a paddle.Tensor.
        return paddle.to_tensor(decoded)
    return decoded