creation.py 11.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16
import numpy as np

17
import paddle
18
from paddle import _C_ops, in_dynamic_mode
19
from paddle.fluid.data_feeder import convert_dtype
20 21 22 23 24 25
from paddle.fluid.framework import (
    _current_expected_place,
    _get_paddle_place,
    core,
    dygraph_only,
)
26
from paddle.fluid.layer_helper import LayerHelper
27
from paddle.tensor import max, to_tensor
28

29 30 31 32 33 34 35 36 37 38 39 40 41
# Names exported as this module's public API (used by ``from ... import *``).
__all__ = [
    'sparse_coo_tensor',
    'sparse_csr_tensor',
]


def _handle_dtype(data, dtype):
    """Return ``data`` cast to ``dtype`` when a different dtype is requested.

    ``data`` is handed back untouched when ``dtype`` is falsy (e.g. ``None``)
    or when ``convert_dtype`` maps ``dtype`` and ``data.dtype`` to the same
    value; otherwise the result of ``data.astype(...)`` is returned.
    """
    # Nothing requested: keep the data exactly as it is.
    if not dtype:
        return data
    wanted = convert_dtype(dtype)
    # Already the requested type: avoid a needless cast.
    if wanted == convert_dtype(data.dtype):
        return data
    return data.astype(wanted)


42
def _infer_dense_shape(indices, values):
    """Infer the smallest dense shape able to hold every given element.

    Args:
        indices(Tensor): 2-D index tensor; the maximum of each row (taken
            along axis 1) plus one gives the extent of that sparse dimension.
        values(Tensor): values tensor; every dimension after the first is
            appended unchanged to the inferred shape.

    Returns:
        list[int]: the per-row ``max + 1`` of ``indices`` followed by
        ``values.shape[1:]``, as plain Python integers.

    Raises:
        ValueError: if ``indices`` is not 2-D.
    """
    # Raise explicitly instead of using ``assert`` so that the check is not
    # stripped when Python runs with ``-O``.
    if len(indices.shape) != 2:
        raise ValueError("'indices' must be 2-D.")
    # ``max`` is the tensor reduction imported at the top of this file,
    # not the Python builtin.
    sparse_extents = (max(indices, axis=1) + 1).numpy()
    # Convert to plain ints so the shape prints cleanly in error messages and
    # no NumPy scalar types leak out to callers.
    return [int(extent) for extent in sparse_extents] + [
        int(extent) for extent in values.shape[1:]
    ]
50 51


52 53 54 55
def _get_place(place):
    """Resolve ``place`` to a place object, defaulting to the expected place.

    ``place`` is first passed through ``_get_paddle_place``. A ``None`` result
    is replaced by ``_current_expected_place()``; any other result must be an
    instance of one of the accepted ``core`` place classes.

    Raises:
        ValueError: if the resolved place is not one of the accepted classes.
    """
    resolved = _get_paddle_place(place)
    if resolved is None:
        return _current_expected_place()

    accepted = (
        core.Place,
        core.CPUPlace,
        core.CUDAPinnedPlace,
        core.CUDAPlace,
    )
    if isinstance(resolved, accepted):
        return resolved

    raise ValueError(
        "'place' must be any of paddle.Place, paddle.CPUPlace, paddle.CUDAPinnedPlace, paddle.CUDAPlace"
    )


65 66 67 68 69 70 71
def _check_indices_dtype(dtype):
    """Ensure ``dtype`` is one of paddle's int8/int16/int32/int64 types.

    Raises:
        TypeError: if ``dtype`` is any other type.
    """
    allowed = (paddle.int8, paddle.int16, paddle.int32, paddle.int64)
    if dtype in allowed:
        return
    raise TypeError(
        "the dtype of indices must be 'int8' or 'int16' or 'int32' or 'int64'"
    )


72 73 74
def sparse_coo_tensor(
    indices, values, shape=None, dtype=None, place=None, stop_gradient=True
):
    r"""
    Constructs a sparse ``paddle.Tensor`` in coordinate format according to the indices
    and values of the specified non-zero elements.

    Args:
        indices(list|tuple|ndarray|Tensor): the indices of non-zero elements.
            Can be a list, tuple, numpy\.ndarray, paddle\.Tensor. The indices must be 2-D.
        values(list|tuple|ndarray|Tensor): Initial values for the tensor.
            Can be a scalar, list, tuple, numpy\.ndarray, paddle\.Tensor.
        shape(list|tuple, optional): The shape of the sparse tensor also represents the shape of
            original dense tensor. If not provided the smallest shape will be inferred to
            hold all elements.
        dtype(str|np.dtype, optional): The desired data type of returned tensor. Can be 'bool' , 'float16' ,
            'float32' , 'float64' , 'int8' , 'int16' , 'int32' , 'int64' , 'uint8',
            'complex64' , 'complex128'. Default: None, infers dtype from ``values``
            except for python float number which gets dtype from ``get_default_dtype`` .
        place(CPUPlace|CUDAPinnedPlace|CUDAPlace|str, optional): The place to allocate Tensor. Can be
            CPUPlace, CUDAPinnedPlace, CUDAPlace. Default: None, means global place. If ``place`` is
            string, It can be ``cpu``, ``gpu:x`` and ``gpu_pinned``, where ``x`` is the index of the GPUs.
        stop_gradient(bool, optional): Whether to block the gradient propagation of Autograd. Default: True.

    Returns:
        Tensor: A Tensor constructed from ``indices`` and ``values`` .

    Raises:
        ValueError: If ``indices`` is not 2-D, if the number of columns of ``indices``
            differs from ``values.shape[0]``, if ``len(shape)`` is not the number of
            sparse dimensions plus the number of dense dimensions, or if any dimension
            of ``shape`` is smaller than the minimum required to hold all elements.
        TypeError: If the dtype of ``indices`` is not int8, int16, int32 or int64.

    Examples:

        .. code-block:: python

            import paddle

            indices = [[0, 1, 2], [1, 2, 0]]
            values = [1.0, 2.0, 3.0]
            dense_shape = [3, 3]
            coo = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape)
            # print(coo)
            # Tensor(shape=[3, 3], dtype=paddle.float32, place=Place(gpu:0), stop_gradient=True,
            #       indices=[[0, 1, 2],
            #                [1, 2, 0]],
            #       values=[1., 2., 3.])
    """

    if in_dynamic_mode():
        place = _get_place(place)

        # Indices never take part in autograd, so they are always created
        # with stop_gradient=True.
        if not isinstance(indices, core.eager.Tensor):
            indices = to_tensor(
                indices, dtype=None, place=place, stop_gradient=True
            )
        if not isinstance(values, core.eager.Tensor):
            values = to_tensor(values, dtype, place, stop_gradient)
        if len(indices.shape) != 2:
            raise ValueError("'indices' must be 2-D.")

        # indices is laid out as (sparse_dim, nnz).
        nnz = indices.shape[1]
        sparse_dim = indices.shape[0]

        _check_indices_dtype(indices.dtype)

        if nnz != values.shape[0]:
            raise ValueError(
                "the indices and values must have same number of non-zero, but get {} and {}".format(
                    nnz, values.shape[0]
                )
            )

        # Every dimension of values after the first is a dense dimension.
        dense_dim = len(values.shape) - 1

        # Move inputs that were passed in as tensors onto the requested place.
        if not indices.place._equals(place):
            indices = indices._copy_to(place, False)

        if not values.place._equals(place):
            values = values._copy_to(place, False)
        values = _handle_dtype(values, dtype)
        values.stop_gradient = stop_gradient

        min_shape = _infer_dense_shape(indices, values)

        if shape is None:
            shape = min_shape
        else:
            shape = list(shape)
            # Check the rank first so the per-dimension comparison below
            # always pairs up matching dimensions.
            if len(shape) != sparse_dim + dense_dim:
                raise ValueError(
                    "the number of dimensions(len(shape) must be sparse_dim({}) + dense_dim({}), but get {}".format(
                        sparse_dim, dense_dim, len(shape)
                    )
                )
            # Compare dimension by dimension. A plain ``shape < min_shape``
            # compares the lists lexicographically and would, for example,
            # accept [3, 1] although [2, 3] is required.
            if any(given < needed for given, needed in zip(shape, min_shape)):
                raise ValueError(
                    "the minimun shape required is {}, but get {}".format(
                        min_shape, shape
                    )
                )

        return _C_ops.sparse_sparse_coo_tensor(values, indices, shape)

    else:
        # Static-graph path: no validation or shape inference is performed
        # here, so ``shape`` has to be supplied by the caller.
        op_type = 'sparse_sparse_coo_tensor'
        inputs = {'values': values, 'indices': indices}
        # Work on a copy so the caller's sequence is not modified and a tuple
        # (allowed by the documented signature) can be used as well.
        shape = list(shape)
        if shape[0] is None:
            shape[0] = -1
        attrs = {'shape': shape}
        helper = LayerHelper(op_type)
        out = helper.create_sparse_variable_for_type_inference(dtype)
        helper.append_op(
            type=op_type, inputs=inputs, outputs={'out': out}, attrs=attrs
        )
        return out
183 184


185
# TODO: need to support shape is None
186
@dygraph_only
def sparse_csr_tensor(
    crows, cols, values, shape, dtype=None, place=None, stop_gradient=True
):
    r"""
    Constructs a sparse ``paddle.Tensor`` in CSR(Compressed Sparse Row) format according to the
    ``crows``, ``cols`` and ``values``.
    Currently, the crows and cols of each batch must be incremental.

    Args:
        crows(list|tuple|ndarray|Tensor): 1-D array, each element in the rows represents the
            starting position of the first non-zero element of each row in values.
            Can be a list, tuple, numpy\.ndarray, paddle\.Tensor.
        cols(list|tuple|ndarray|Tensor): 1-D array, the column of non-zero elements.
            Can be a list, tuple, numpy\.ndarray, paddle\.Tensor.
        values(list|tuple|ndarray|Tensor): 1-D array, the non-zero elements.
            Can be a scalar, list, tuple, numpy\.ndarray, paddle\.Tensor.
        shape(list|tuple): The shape of the sparse tensor also represents the shape of
            original dense tensor. It must have 2 or 3 dimensions.
        dtype(str|np.dtype, optional): The desired data type of returned tensor. Can be 'bool' , 'float16' ,
            'float32' , 'float64' , 'int8' , 'int16' , 'int32' , 'int64' , 'uint8',
            'complex64' , 'complex128'. Default: None, infers dtype from ``values``
            except for python float number which gets dtype from ``get_default_dtype`` .
        place(CPUPlace|CUDAPinnedPlace|CUDAPlace|str, optional): The place to allocate Tensor. Can be
            CPUPlace, CUDAPinnedPlace, CUDAPlace. Default: None, means global place. If ``place`` is
            string, It can be ``cpu``, ``gpu:x`` and ``gpu_pinned``, where ``x`` is the index of the GPUs.
        stop_gradient(bool, optional): Whether to block the gradient propagation of Autograd. Default: True.

    Returns:
        Tensor: A Tensor constructed from ``crows``, ``cols`` and ``values`` .

    Examples:

        .. code-block:: python

            >>> import paddle

            >>> crows = [0, 2, 3, 5]
            >>> cols = [1, 3, 2, 0, 1]
            >>> values = [1, 2, 3, 4, 5]
            >>> dense_shape = [3, 4]
            >>> csr = paddle.sparse.sparse_csr_tensor(crows, cols, values, dense_shape)
            >>> print(csr)
            Tensor(shape=[3, 4], dtype=paddle.int64, place=Place(cpu), stop_gradient=True,
                  crows=[0, 2, 3, 5],
                  cols=[1, 3, 2, 0, 1],
                  values=[1, 2, 3, 4, 5])
    """

    place = _get_place(place)

    def _on_target_place(tensor):
        # Copy only when the tensor is not already on the requested place.
        if tensor.place._equals(place):
            return tensor
        return tensor._copy_to(place, False)

    # Wrap raw inputs as tensors; crows/cols are always created with
    # stop_gradient=True, only values follows the caller's setting.
    if not isinstance(crows, core.eager.Tensor):
        crows = to_tensor(crows, dtype=None, place=place, stop_gradient=True)
    if not isinstance(cols, core.eager.Tensor):
        cols = to_tensor(cols, dtype=None, place=place, stop_gradient=True)
    if not isinstance(values, core.eager.Tensor):
        values = to_tensor(values, dtype, place, stop_gradient)

    _check_indices_dtype(crows.dtype)
    _check_indices_dtype(cols.dtype)

    ndim = len(shape)
    if ndim not in (2, 3):
        raise ValueError(
            "SparseCsrTensor only support 2-D or 3-D matrix. but get shape {}".format(
                shape
            )
        )
    # The row count is the second-to-last entry of shape.
    rows = shape[ndim - 2]

    crows = _on_target_place(crows)
    cols = _on_target_place(cols)
    values = _on_target_place(values)

    values = _handle_dtype(values, dtype)
    values.stop_gradient = stop_gradient

    if not all(len(t.shape) == 1 for t in (crows, cols, values)):
        raise ValueError("The 'crows', 'cols' and 'values' must be 1-D.")

    if len(cols) != len(values):
        raise ValueError("the length of cols must be same as length of values")

    crows_len = crows.shape[0]
    if ndim == 3:
        # 3-D input: only require crows to split evenly into chunks of
        # rows + 1 entries.
        if crows_len % (rows + 1) != 0:
            raise ValueError(
                "The length({}) of crows must be divisible the rows({})+1 of matrix.".format(
                    crows_len, rows
                )
            )
    else:
        # 2-D input: crows has rows + 1 entries, starts at 0 and ends at the
        # number of stored values.
        if crows_len != rows + 1:
            raise ValueError(
                "The length({}) of crows must be equal to the rows({})+1 of matrix.".format(
                    crows_len, rows
                )
            )
        if crows[0] != 0:
            raise ValueError("the 0th value of crows must be 0")

        if crows[-1] != values.shape[0]:
            raise ValueError(
                "the last value of crows must be equal the number of non-zero"
            )
    # TODO(zkh2016): check whether the value in crows and cols is legal

    return core.eager.sparse_csr_tensor(
        crows, cols, values, shape, stop_gradient
    )