tensor.py 39.4 KB
Newer Older
1 2 3
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
4
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
5 6 7 8
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
from typing import Iterable, Optional, Sequence, Tuple, Union
10 11 12 13

import numpy as np

from ..core._imperative_rt import CompNode
14
from ..core._imperative_rt.core2 import SymbolVar, apply, dtype_promotion
15
from ..core._wrap import as_device
16
from ..core.ops import builtin
17
from ..core.ops.builtin import Copy, Identity
18
from ..core.ops.special import Const
19
from ..core.tensor.array_method import _broadcast, _remove_axis
20
from ..core.tensor.utils import astensor1d, convert_inputs, get_device
21 22
from ..device import get_default_device
from ..tensor import Tensor
23
from .elemwise import ceil
24 25 26

# Names exported by ``from <this module> import *``.
# NOTE(review): some of the listed names (e.g. "arange", "cumsum", "repeat",
# "roll", "tile", "copy") are not defined in the part of the file visible here;
# presumably they are defined further down -- confirm.
__all__ = [
    "arange",
    "broadcast_to",
    "concat",
    "cond_take",
    "cumsum",
    "expand_dims",
    "eye",
    "flatten",
    "full",
    "full_like",
    "gather",
    "linspace",
    "ones",
    "ones_like",
    "repeat",
    "reshape",
    "roll",
    "split",
    "squeeze",
    "stack",
    "scatter",
    "tile",
    "copy",
    "transpose",
    "where",
    "zeros",
    "zeros_like",
]


56
def eye(N, M=None, *, dtype="float32", device: Optional[CompNode] = None) -> Tensor:
    r"""Returns a 2D tensor with ones on the diagonal and zeros elsewhere.

    Args:
        N: number of rows; an integer or a tensor.
        M: number of columns; an integer or a tensor. Default: ``None``, in which
            case only ``N`` is handed to the operator as its shape operand.
        dtype: the desired data type of the output tensor. Default: ``float32``.
        device: the desired device of the output tensor. Default: if ``None``,
            use the default device (see :func:`~.megengine.get_default_device`).

    Returns:
        eye matrix.

    Examples:

        .. testcode::

            import numpy as np
            import megengine.functional as F

            out = F.eye(4, 6, dtype=np.float32)
            print(out.numpy())

        Outputs:

        .. testoutput::

            [[1. 0. 0. 0. 0. 0.]
             [0. 1. 0. 0. 0. 0.]
             [0. 0. 1. 0. 0. 0.]
             [0. 0. 0. 1. 0. 0.]]
    """
    if M is None:
        # Only one extent given: reuse a tensor as-is, wrap anything else.
        if isinstance(N, Tensor):
            shape = N
        else:
            shape = Tensor(N, dtype="int32", device=device)
    elif any(isinstance(extent, Tensor) for extent in (N, M)):
        # At least one extent is already a tensor: build a 1-d shape tensor from both.
        shape = astensor1d((N, M))
    else:
        shape = Tensor([N, M], dtype="int32", device=device)

    eye_op = builtin.Eye(k=0, dtype=dtype, comp_node=device)
    (out,) = apply(eye_op, shape)
    return out


101 102 103 104 105 106
def full(
    shape: Union[int, tuple, list],
    value: Union[bool, int, float, Tensor],
    dtype=None,
    device=None,
) -> Tensor:
    r"""Creates a tensor of shape ``shape`` filled with ``value``.

    Args:
        shape: output tensor shape. A bare integer is treated as a one-element shape.
        value: fill value.
        dtype: output tensor data type. If ``dtype`` is ``None``, the output tensor
            data type must be inferred from ``value``. If the value is an ``int``,
            the output tensor data type must be the default integer data type. If the
            value is a ``float``, the output tensor data type must be the default
            floating-point data type. If the value is a ``bool``, the output tensor
            must have boolean data type. Default: ``None``.
        device: device on which to place the created tensor. Default: ``None``,
            meaning the result of ``get_default_device()`` is used.

    Returns:
        a tensor where every element is equal to ``value``.

    Examples:

        .. testcode::

            import numpy as np
            import megengine.functional as F

            out = F.full([2,3], 1.5)
            print(out.numpy())

        Outputs:

        .. testoutput::

            [[1.5 1.5 1.5]
             [1.5 1.5 1.5]]
    """
    target_shape = (shape,) if isinstance(shape, int) else shape
    target_device = get_default_device() if device is None else device

    (scalar,) = Const(value, dtype=dtype, device=target_device)()

    # An empty list/tuple shape requests a scalar: return the constant untouched.
    wants_scalar = type(target_shape) in (list, tuple) and len(target_shape) == 0
    if wants_scalar:
        return scalar
    return broadcast_to(scalar, target_shape)
149 150


151 152 153 154 155 156 157
def ones(
    shape: Union[int, Tuple[int, ...]],
    *,
    dtype="float32",
    device: Optional[CompNode] = None
) -> Tensor:
    r"""Returns a new tensor of the requested shape with every element set to one.

    Args:
        shape (int or sequence of ints): the shape of the output tensor.

    Keyword args:
        dtype (:attr:`.Tensor.dtype`): output tensor data type. Default: ``float32``.
        device (:attr:`.Tensor.device`): device on which to place the created tensor. Default: ``None``.

    Returns:
        a tensor containing ones.

    Examples:

        .. testcode::

            import megengine.functional as F

            out = F.ones(5)
            print(out.numpy())
            out = F.ones((5, ), dtype='int32')
            print(out.numpy())
            out = F.ones((2, 2))
            print(out.numpy())
            out = F.ones([2, 1])
            print(out.numpy())

        Outputs:

        .. testoutput::

            [1. 1. 1. 1. 1.]
            [1 1 1 1 1]
            [[1. 1.]
             [1. 1.]]
            [[1.]
             [1.]]
    """
    # Delegates to ``full`` with a constant fill value of 1.0.
    fill_value = 1.0
    return full(shape, fill_value, dtype=dtype, device=device)


198
def zeros(shape, dtype="float32", device=None) -> Tensor:
    r"""Returns a tensor of the given shape with every element set to zero.

    Args:
        shape: a list, tuple or integer defining the shape of the output tensor.
        dtype: the desired data type of the output tensor. Default: ``float32``.
        device: the desired device of the output tensor. Default: if ``None``,
            use the default device (see :func:`~.megengine.get_default_device`).

    Returns:
        a tensor containing zeros.
    """
    # Delegates to ``full`` with a constant fill value of 0.0.
    fill_value = 0.0
    return full(shape, fill_value, dtype=dtype, device=device)


210
def zeros_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]:
    r"""Returns a tensor filled with zeros with the same shape and data type as input tensor.

    Args:
        inp (Tensor): input tensor.

    Returns:
        a tensor containing zeros.

    Examples:
        >>> inp = F.arange(9, dtype='int32').reshape(3,3)
        >>> F.zeros_like(inp)
        Tensor([[0 0 0]
         [0 0 0]
         [0 0 0]], dtype=int32, device=xpux:0)
    """
    # ``full_like`` takes dtype, device and shape from ``inp``.
    return full_like(inp, 0.0)
227 228


229
def ones_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]:
    r"""Returns a tensor filled with ones that has the same shape as the input tensor.

    Args:
        inp: input tensor.

    Returns:
        a tensor containing ones.

    Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor
            import megengine.functional as F

            inp = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
            out = F.ones_like(inp)
            print(out.numpy())

        Outputs:

        .. testoutput::

            [[1 1 1]
             [1 1 1]]
    """
    # ``full_like`` takes dtype, device and shape from ``inp``.
    fill_value = 1.0
    return full_like(inp, fill_value)
258 259


260 261 262
def full_like(
    inp: Union[Tensor, SymbolVar], value: Union[int, float]
) -> Union[Tensor, SymbolVar]:
    r"""Returns a tensor filled with ``value`` that has the same shape as the input tensor.

    The constant is created with the data type and device of ``inp``.

    Args:
        inp: input tensor.
        value: target value.

    Returns:
        output tensor.

    Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor
            import megengine.functional as F

            inp = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
            out = F.full_like(inp, 2)
            print(out.numpy())

        Outputs:

        .. testoutput::

            [[2 2 2]
             [2 2 2]]

    """
    (filled,) = Const(value, dtype=inp.dtype, device=inp.device)(inp)
    # A 0-d input needs no broadcasting: the constant already has the right shape.
    return filled if inp.ndim == 0 else broadcast_to(filled, inp.shape)
296 297


298
def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
    r"""Broadcasts a tensor to the given shape.

    Args:
        inp: input tensor.
        shape: target shape.

    Returns:
        output tensor.

    Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor
            import megengine.functional as F

            data = tensor(np.arange(0, 3, dtype=np.float32).reshape(3))
            out = F.broadcast_to(data, (2, 3))
            print(out.numpy())

        Outputs:

        .. testoutput::

            [[0. 1. 2.]
             [0. 1. 2.]]
    """
    # Thin functional wrapper around the array-method implementation.
    broadcasted = _broadcast(inp, shape)
    return broadcasted
328 329


330
def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor:
    r"""Concatenates tensors along an existing axis.

    When exactly one tensor is given it is returned unchanged.

    Args:
        inps: input tensors to concat.
        axis: over which dimension the tensors are concatenated. Default: 0
        device: which device output will be. Default: None, in which case the
            device is derived from the inputs via ``get_device``.

    Returns:
        output tensor.

    Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor
            import megengine.functional as F

            data1 = tensor(np.arange(0, 6, dtype=np.float32).reshape((2, 3)))
            data2 = tensor(np.arange(6, 12, dtype=np.float32).reshape((2, 3)))
            out = F.concat([data1, data2])
            print(out.numpy())

        Outputs:

        .. testoutput::

            [[ 0.  1.  2.]
             [ 3.  4.  5.]
             [ 6.  7.  8.]
             [ 9. 10. 11.]]
    """
    if len(inps) == 1:
        return inps[0]

    # FIXME: remove this convert_inputs
    tensors = convert_inputs(*inps, device=device)
    out_device = as_device(get_device(tensors) if device is None else device)
    concat_op = builtin.Concat(axis=axis, comp_node=out_device.to_c())
    (merged,) = apply(concat_op, *tensors)
    return merged


375
def stack(inps, axis=0, device=None):
    r"""Joins a sequence of tensors along a newly inserted axis.
    The input tensors must have the same shape.

    Args:
        inps: input tensors.
        axis: which axis will be concatenated.
        device: the device output will be. Default: None

    Returns:
        output concatenated tensor.

    Raises:
        ValueError: if the inputs do not all share one shape (only checked when
            the first input's shape is not itself an instance of the input's class).

    Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor
            import megengine.functional as F

            x1 = tensor(np.arange(0, 3, dtype=np.float32).reshape((3)))
            x2 = tensor(np.arange(6, 9, dtype=np.float32).reshape((3)))
            out = F.stack([x1, x2], axis=0)
            print(out.numpy())

        Outputs:

        .. testoutput::

            [[0. 1. 2.]
             [6. 7. 8.]]
    """
    if len(inps) > 0:
        first = inps[0]
        # Shapes can only be compared as set members when they are plain values,
        # i.e. not objects of the same class as the inputs themselves.
        if not isinstance(first.shape, first.__class__):
            distinct_shapes = set()
            for tensor in inps:
                distinct_shapes.add(tensor.shape)
            if len(distinct_shapes) != 1:
                raise ValueError("All input tensors must have the same shape")

    expanded = [expand_dims(tensor, axis=axis) for tensor in inps]
    return concat(expanded, axis=axis, device=device)
414 415 416


def split(inp, nsplits_or_sections, axis=0):
    r"""Splits the input tensor into several smaller tensors.

    When ``nsplits_or_sections`` is an int, the sections are as even as possible:
    if the axis size is not divisible, the leading sections are one element
    larger than the trailing ones. When it is a sequence, its entries are the
    positions along ``axis`` at which the tensor is cut.

    Args:
        inp: input tensor.
        nsplits_or_sections: number of sub tensors or sections information list.
        axis: which axis will be splited.

    Returns:
        output tensor list.

    Raises:
        ValueError: if ``axis`` is not smaller than the number of dimensions, if
            the cut positions are decreasing, or if the section count is not in
            ``[1, size of axis]``.

    Examples:

        .. testcode::

            import os
            import numpy as np
            from megengine import tensor
            import megengine.functional as F

            x = tensor(np.random.random((10, 20)), dtype=np.float32)
            y = F.split(x, 3)
            z = F.split(x, [6, 17], axis=1)

            print([i.numpy().shape for i in y])
            print([i.numpy().shape for i in z])

        Outputs:

        .. testoutput::

            [(4, 20), (3, 20), (3, 20)]
            [(10, 6), (10, 11), (10, 3)]
    """
    rank = len(inp.shape)
    if axis >= rank:
        raise ValueError("Invalid axis {}".format(axis))

    axis_size = inp.shape[axis]

    if isinstance(nsplits_or_sections, Sequence):
        # Cut positions: each section size is the gap between neighbouring bounds.
        bounds = [0] + list(nsplits_or_sections) + [axis_size]
        sizes = []
        for lower, upper in zip(bounds, bounds[1:]):
            if lower > upper:
                raise ValueError(
                    "Invalid nsplits_or_secions: {}".format(nsplits_or_sections)
                )
            sizes.append(upper - lower)
    else:  # scalar
        count = int(nsplits_or_sections)
        if count <= 0:
            raise ValueError("Number sections must be larger than 0")
        if count > axis_size:
            raise ValueError(
                "The size {} at dim {} cannot be split into {} sections".format(
                    axis_size, axis, count
                )
            )
        # Ceil-style division so the remainder is spread over the first sections.
        sizes = [(axis_size + count - k - 1) // count for k in range(count)]

    # The operator takes the section sizes as operands: wrap plain numbers as
    # int32 constants on the input's device, pass tensors through untouched.
    size_operands = []
    for size in sizes:
        if isinstance(size, (SymbolVar, Tensor)):
            size_operands.append(size)
        else:
            size_operands.append(Const(size, dtype="int32", device=inp.device)(inp)[0])

    split_op = builtin.Split(axis=axis)
    return apply(split_op, inp, *size_operands)
495 496 497 498 499 500 501 502 503 504 505 506 507


def _get_idx(index, axis):
    """Builds a tuple of flattened per-dimension index tensors for ``index``.

    For dimension ``axis`` the flattened ``index`` itself is used; for every
    other dimension ``d`` the entry is the sequence ``0 .. index.shape[d] - 1``
    broadcast to ``index.shape``, flattened and cast to int32.

    Args:
        index: index tensor.
        axis: the dimension whose coordinates are taken from ``index``.

    Returns:
        a tuple with one flattened tensor per dimension of ``index``.
    """
    full_shape = index.shape
    rank = len(full_shape)
    coords = []
    for dim in range(rank):
        if dim == axis:
            coords.append(index.reshape(-1))
            continue
        extent = full_shape[dim]
        # Shape that is 1 everywhere except along ``dim`` so the ramp broadcasts.
        ramp_shape = [1] * rank
        ramp_shape[dim] = extent
        ramp = linspace(0, extent - 1, extent, device=index.device,)
        expanded = broadcast_to(ramp.reshape(*ramp_shape), full_shape)
        coords.append(expanded.reshape(-1).astype(np.int32))
    return tuple(coords)


def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor:
    # TODO: rewrite doc
    r"""
    Gathers data from input tensor on axis using index.

    For a 3-D tensor, the output is specified by:

    .. code-block::

       out[i][j][k] = inp[index[i][j][k]][j][k] # if axis == 0
       out[i][j][k] = inp[i][index[i][j][k]][k] # if axis == 1
       out[i][j][k] = inp[i][j][index[i][j][k]] # if axis == 2

    if input tensor is a n-dimensional tensor with size
    :math:`(x_0,x_1,...,x_{i-1},x_i,x_{i+1},...,x_{n-1})` and axis=i,
    then index must be a n-dimensional tensor with size
    :math:`(x_0,x_1,...,x_{i-1},y,x_{i+1},...,x_{n-1})` where :math:`y\ge 1` and
    output will have the same size as index.

    Args:
        inp: input tensor.
        axis: along which axis to index; must lie in ``[0, ndim)``.
        index: indices of elements to gather.

    Returns:
        output tensor.

    Raises:
        ValueError: if ``inp`` and ``index`` differ in number of dimensions, if
            ``axis`` is out of range, or if their sizes differ on any dimension
            other than ``axis``.

    Examples:

        .. testcode::

            import megengine.functional as F
            from megengine import tensor

            inp = tensor([
                [1,2], [3,4], [5,6],
            ])
            index = tensor([[0,2], [1,0]])
            oup = F.gather(inp, 0, index)
            print(oup.numpy())

        Outputs:

        .. testoutput::

            [[1 6]
             [3 2]]
    """
    input_shape, index_shape = inp.shape, index.shape
    input_dims, index_dims = len(input_shape), len(index_shape)

    if input_dims != index_dims:
        raise ValueError(
            "The index tensor must have same dimensions as input tensor, "
            "But the input dims:{}, the index dims:{}".format(input_dims, index_dims)
        )

    if not 0 <= axis < input_dims:
        raise ValueError(
            "Index axis {} is output of bounds, should in range [0 {})".format(
                axis, input_dims
            )
        )

    # Every dimension except the gather axis must agree in size.
    for dim in range(input_dims):
        if dim == axis:
            continue
        if input_shape[dim] != index_shape[dim]:
            raise ValueError(
                "The input {} and index {} must have the same size apart from axis {}".format(
                    input_shape, index_shape, axis
                )
            )

    coords = _get_idx(index, axis)
    return inp[coords].reshape(index.shape)  # pylint: disable=no-member


def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor:
    # TODO: rewrite doc
    r"""
    Writes all values from the tensor source into input tensor
    at the indices specified in the index tensor.

    For each value in source, its output index is specified by its index
    in source for ``axis != dimension`` and by the corresponding value in
    index for ``axis = dimension``.

    For a 3-D tensor, input tensor is updated as:

    .. code-block::

       inp[index[i][j][k]][j][k] = source[i][j][k]  # if axis == 0
       inp[i][index[i][j][k]][k] = source[i][j][k]  # if axis == 1
       inp[i][j][index[i][j][k]] = source[i][j][k]  # if axis == 2

    ``inp``, ``index`` and ``source`` should have same number of dimensions.

    It is also required that ``source.shape(d) <= inp.shape(d)`` and ``index.shape(d) == source.shape(d)``
    for all dimensions ``d``.

    Moreover, the values of index must be between ``0`` and ``inp.shape(axis) - 1`` inclusive.

    The values are written into ``inp`` by item assignment and ``inp`` itself
    is returned.

    Note:
        Please notice that, due to performance issues, the result is uncertain on the GPU device
        if scattering different positions from source to the same destination position
        regard to index tensor.

        Check the following examples, the oup[0][2] is maybe
        from source[0][2] which value is 0.2256 or source[1][2] which value is 0.5939
        if set the index[1][2] from 1 to 0.

    Args:
        inp: inp tensor which to be scattered.
        axis: axis along which to index; must lie in ``[0, ndim)``.
        index: indices of elements to scatter.
        source: source element(s) to scatter.

    Returns:
        output tensor.

    Raises:
        ValueError: if the three tensors differ in number of dimensions, if
            ``axis`` is out of range, or if the shape constraints above are violated.

    Examples:

        .. testcode::

            import numpy as np
            import megengine.functional as F
            from megengine import tensor

            inp = tensor(np.zeros(shape=(3,5),dtype=np.float32))
            source = tensor([[0.9935,0.9465,0.2256,0.8926,0.4396],[0.7723,0.0718,0.5939,0.357,0.4576]])
            index = tensor([[0,2,0,2,1],[2,0,1,1,2]])
            oup = F.scatter(inp, 0, index,source)
            print(oup.numpy())

        Outputs:

        .. testoutput::

            [[0.9935 0.0718 0.2256 0.     0.    ]
             [0.     0.     0.5939 0.357  0.4396]
             [0.7723 0.9465 0.     0.8926 0.4576]]
    """
    input_shape, index_shape, source_shape = inp.shape, index.shape, source.shape
    input_dims = len(input_shape)
    index_dims = len(index_shape)
    source_dims = len(source_shape)

    if not (input_dims == index_dims and input_dims == source_dims):
        raise ValueError("The input, source and index tensor must have same dimensions")

    if not 0 <= axis < input_dims:
        raise ValueError(
            "Index axis {} is output of bounds, should in range [0 {})".format(
                axis, input_dims
            )
        )

    # The three checks run as separate passes so that, when several constraints
    # are violated, the reported error follows this fixed order.
    for dim in range(source_dims):
        if source_shape[dim] > input_shape[dim]:
            raise ValueError(
                "The each shape size for source {} must be less than or equal to input {} ".format(
                    source_shape, input_shape
                )
            )

    for dim in range(index_dims):
        if index_shape[dim] != source_shape[dim]:
            raise ValueError(
                "The each shape size for index {} must be equal to source {} ".format(
                    index_shape, source_shape
                )
            )

    for dim in range(index_dims):
        if dim == axis:
            continue
        if index_shape[dim] > input_shape[dim]:
            raise ValueError(
                "The index {} must be less than or equal to input {} size apart from axis {}".format(
                    index_shape, input_shape, axis
                )
            )

    coords = _get_idx(index, axis)
    inp[coords] = source.flatten()
    return inp


def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
    r"""Selects elements either from Tensor x or Tensor y, according to mask.

    .. math::

        \textrm{out}_i = x_i \textrm{ if } \textrm{mask}_i \textrm{ is True else } y_i

    ``x`` and ``y`` are first cast to their promoted common data type; the
    result has the shape of ``x``.

    Args:
        mask: a boolean mask used for choosing ``x`` or ``y``.
        x: first choice.
        y: second choice.

    Returns:
        output tensor.

    Raises:
        TypeError: if ``mask``, ``x`` or ``y`` is not a :class:`Tensor`.
        ValueError: if ``mask`` is not of boolean data type, or if ``x`` and
            ``mask`` live on different devices.

    Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor
            import megengine.functional as F
            mask = tensor(np.array([[True, False], [False, True]], dtype=np.bool_))
            x = tensor(np.array([[1, np.inf], [np.nan, 4]],
                dtype=np.float32))
            y = tensor(np.array([[5, 6], [7, 8]], dtype=np.float32))
            out = F.where(mask, x, y)
            print(out.numpy())

        Outputs:

        .. testoutput::

            [[1. 6.]
             [7. 4.]]
    """

    if not isinstance(x, Tensor):
        raise TypeError("input x must be a tensor")
    if not isinstance(y, Tensor):
        raise TypeError("input y must be a tensor")
    if not isinstance(mask, Tensor):
        raise TypeError("mask must be a tensor")
    if mask.dtype != np.bool_:
        raise ValueError("mask must be bool")
    if x.device != mask.device:
        raise ValueError("ambiguous device: {} vs {}".format(x.device, mask.device))

    # Bring both choices to one common dtype before mixing their elements.
    dtype = dtype_promotion(x, y)
    if x.dtype != dtype:
        x = x.astype(dtype)
    if y.dtype != dtype:
        y = y.astype(dtype)

    # Take the selected elements of x and the complementary elements of y,
    # each together with their flat positions.
    v0, index0 = cond_take(mask, x)
    v1, index1 = cond_take(~mask, y)

    # The concatenation only serves as a flat buffer of the right total size;
    # the two assignments below put every element at its original position.
    out = concat([v0, v1])

    out[index0] = v0
    out[index1] = v1
    out = out.reshape(x.shape)
    return out
769 770 771


def cond_take(mask: Tensor, x: Tensor) -> Tuple[Tensor, Tensor]:
    r"""Takes elements from data if specific condition is satisfied on mask.
    This operator has two outputs: the first is the elements taken,
    and the second is the indices corresponding to those elements;
    they are both 1-dimensional. High-dimension input would first be flattened.

    Args:
        mask: condition param; must be the same shape with data.
        x: input tensor from which to take elements.

    Returns:
        a pair ``(values, indices)``: the elements taken and their indices.

    Raises:
        TypeError: if ``x`` or ``mask`` is neither a ``Tensor`` nor a ``SymbolVar``.
        ValueError: if ``mask`` is not of boolean data type, or if ``x`` and
            ``mask`` live on different devices.

    Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor
            import megengine.functional as F
            mask = tensor(np.array([[True, False], [False, True]], dtype=np.bool_))
            x = tensor(np.array([[1, np.inf], [np.nan, 4]],
                dtype=np.float32))
            v, index = F.cond_take(mask, x)
            print(v.numpy(), index.numpy())

        Outputs:

        .. testoutput::

            [1. 4.] [0 3]
    """
    if not isinstance(x, (Tensor, SymbolVar)):
        raise TypeError("input must be a tensor")
    if not isinstance(mask, (Tensor, SymbolVar)):
        raise TypeError("mask must be a tensor")
    if mask.dtype != np.bool_:
        raise ValueError("mask must be bool")
    if x.device != mask.device:
        raise ValueError("ambiguous device: {} vs {}".format(x.device, mask.device))

    # The operator takes the data first and the mask second.
    op = builtin.CondTake()
    v, index = apply(op, x, mask)
    return v, index


814
def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor:
    r"""Swaps shapes and strides according to given pattern.

    Args:
        inp: input tensor.
        pattern: a list of integers including 0, 1, ... , ``ndim``-1,
            and any number of ``'x'`` char in dimensions where this tensor should be broadcasted.
            For examples:

            * (``'x'``) -> make a 0d (scalar) into a 1d vector
            * (0, 1) -> identity for 2d vectors
            * (1, 0) -> inverts the first and second dimensions
            * (``'x'``, 0) -> make a row out of a 1d vector (N to 1xN)
            * (0, ``'x'``) -> make a column out of a 1d vector (N to Nx1)
            * (2, 0, 1) -> AxBxC to CxAxB
            * (0, ``'x'``, 1) -> AxB to Ax1xB
            * (1, ``'x'``, 0) -> AxB to Bx1xA
            * (1,) -> this removes dimensions 0. It must be a broadcastable dimension (1xA to A)

    Returns:
        output tensor.

    Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor
            import megengine.functional as F
            x = tensor(np.array([[1, 1], [0, 0]], dtype=np.int32))
            out = F.transpose(x, (1, 0))
            print(out.numpy())

        Outputs:

        .. testoutput::

            [[1 0]
             [1 0]]
    """
    # The tensor method expects -1 wherever the pattern has the ``'x'`` marker.
    axes = [-1 if entry == "x" else entry for entry in pattern]
    return inp.transpose(axes)
855 856 857


def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor:
    r"""Reshapes a tensor to the given target shape; the total number of logical
    elements must remain unchanged.

    Args:
        inp: input tensor.
        target_shape: target shape, it can contain an element of -1 representing ``unspec_axis``.

    Returns:
        output tensor.

    Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor
            import megengine.functional as F
            x = tensor(np.arange(12, dtype=np.int32))
            out = F.reshape(x, (3, 4))
            print(out.numpy())

        Outputs:

        .. testoutput::

            [[ 0  1  2  3]
             [ 4  5  6  7]
             [ 8  9 10 11]]
    """
    # Thin functional wrapper around the tensor method of the same name.
    reshaped = inp.reshape(target_shape)
    return reshaped
885 886


887
def flatten(inp: Tensor, start_axis: int = 0, end_axis: int = -1) -> Tensor:
    r"""Reshapes the tensor by flattening the sub-tensor from dimension ``start_axis`` to dimension ``end_axis``.

    Dimensions before ``start_axis`` and after ``end_axis`` are kept; the
    dimensions in between (both ends included) are merged into one.

    Args:
        inp: input tensor.
        start_axis: start dimension that the sub-tensor to be flattened. A negative
            value counts from the last dimension. Default: 0
        end_axis: end dimension that the sub-tensor to be flattened. A negative
            value counts from the last dimension. Default: -1

    Returns:
        output tensor.

    Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor
            import megengine.functional as F

            inp_shape = (2, 2, 3, 3)
            x = tensor(
                np.arange(36, dtype=np.int32).reshape(inp_shape),
            )
            out = F.flatten(x, 2)
            print(x.numpy().shape)
            print(out.numpy().shape)

        Outputs:

        .. testoutput::

            (2, 2, 3, 3)
            (2, 2, 9)
    """
    # ``range`` below yields nothing for a negative bound, which used to make a
    # negative ``start_axis`` behave like 0; resolve it to a position first.
    if start_axis < 0:
        start_axis += len(inp.shape)

    # Leading dimensions are kept, the flattened span becomes a single -1.
    target_shape = tuple(inp.shape[i] for i in range(start_axis)) + (-1,)
    if end_axis != -1:
        # Slicing resolves a negative ``end_axis`` by itself.
        target_shape += (*inp.shape[end_axis + 1 :],)
    return inp.reshape(*target_shape)


927
def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
    r"""Inserts new dimensions at the given positions.

    Args:
        inp: input tensor.
        axis: place of new axes; a single integer or a sequence of integers.
            Negative entries are resolved against the number of dimensions of
            the result (``inp.ndim`` plus the number of new axes).

    Returns:
        output tensor.

    Raises:
        IndexError: if a negative axis is given while reading ``inp.ndim``
            raises ``ValueError``.

    Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor
            import megengine.functional as F

            x = tensor([1, 2])
            out = F.expand_dims(x, 0)
            print(out.numpy().shape)

        Outputs:

        .. testoutput::

            (1, 2)
    """
    # Accept either one integer-like value or an iterable of them.
    axes = None
    try:
        axes = [int(axis)]
    except (TypeError, ValueError):
        pass
    if axes is None:
        axes = [int(entry) for entry in axis]

    try:
        out_ndim = inp.ndim + len(axes)
        axes = sorted(entry + out_ndim if entry < 0 else entry for entry in axes)
    except ValueError:
        if any(entry < 0 for entry in axes):
            raise IndexError(
                "Does not support negative index when tensor's ndim is unknown"
            )
        axes = sorted(axes)

    assert axes, "axis could not be empty"
    if inp._isscalar():
        assert axes[0] == 0, "invalid axis {} for ndim 0".format(axes[0])
        if len(axes) == 1:
            # Scalar input with one new axis: copy it and clear the scalar flag.
            inp = copy(inp, device=None)
            inp._unsetscalar()
            return inp
        # The first new axis is dropped from the list handed to the operator.
        axes = axes[1:]

    add_axis_op = builtin.AddAxis(axis=axes)
    (expanded,) = apply(add_axis_op, inp)
    return expanded


986
def squeeze(inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None) -> Tensor:
    r"""Removes dimension of shape 1.

    Args:
        inp: input tensor.
        axis: place of axis to be removed. Default: ``None``.

    Returns:
        output tensor.

    Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor
            import megengine.functional as F

            x = tensor(np.array([1, 2], dtype=np.int32).reshape(1, 1, 2, 1))
            out = F.squeeze(x, 3)
            print(out.numpy().shape)

        Outputs:

        .. testoutput::

            (1, 1, 2)
    """
    # The work is delegated to the shared array-method helper.
    squeezed = _remove_axis(inp, axis)
    return squeezed
1015 1016 1017 1018 1019 1020 1021 1022 1023


def linspace(
    start: Union[int, float, Tensor],
    stop: Union[int, float, Tensor],
    num: Union[int, Tensor],
    dtype="float32",
    device: Optional[CompNode] = None,
) -> Tensor:
    r"""Returns equally spaced numbers over a specified interval.

    Args:
        start: starting value of the sequence, should be scalar.
        stop: last value of the sequence, should be scalar.
        num: number of values to generate.
        dtype: result data type. Default: ``float32``.
        device: device used for the operator and for converting non-tensor
            arguments. Default: if ``None``, the first ``device`` attribute found
            on ``start``, ``stop`` or ``num`` is used.

    Returns:
        generated tensor.

    Raises:
        ValueError: if ``device`` and the ``device`` attribute of an argument disagree.
        TypeError: if only some of ``start``, ``stop`` and ``num`` are ``SymbolVar``.

    Examples:

        .. testcode::

            import numpy as np
            import megengine.functional as F

            a = F.linspace(3, 10, 5)
            print(a.numpy())

        Outputs:

        .. testoutput::

            [ 3.    4.75  6.5   8.25 10.  ]
    """
    # Settle on one device: an explicit ``device`` wins, otherwise the first
    # argument carrying a ``device`` attribute decides; any disagreement is an error.
    for item in (start, stop, num):
        cur_device = getattr(item, "device", None)
        if device is None:
            device = cur_device
        elif not (cur_device is None or device == cur_device):
            # Must be a real exception type: raising a bare string would itself
            # fail with an unrelated TypeError.
            raise ValueError("ambiguous device for linspace opr")

    is_symbolvar = [isinstance(x, SymbolVar) for x in (start, stop, num)]
    if any(is_symbolvar) and not all(is_symbolvar):
        raise TypeError("start, stop and num should all be VarNode or none of them")

    # Plain Python numbers are wrapped into tensors on the chosen device.
    if not isinstance(start, (Tensor, SymbolVar)):
        start = Tensor(start, device=device)
    if not isinstance(stop, (Tensor, SymbolVar)):
        stop = Tensor(stop, device=device)
    if not isinstance(num, (Tensor, SymbolVar)):
        num = Tensor(num, device=device)

    op = builtin.Linspace(comp_node=device)
    (result,) = apply(op, start, stop, num)
    # The result is only cast when a dtype other than float32 is requested.
    if np.dtype(dtype) != np.float32:
        return result.astype(dtype)
    return result


def arange(
    start: Union[int, float, Tensor] = 0,
    stop: Optional[Union[int, float, Tensor]] = None,
    step: Union[int, float, Tensor] = 1,
    dtype="float32",
    device: Optional[CompNode] = None,
) -> Tensor:
    r"""Returns evenly spaced values within the half-open interval ``[start, stop)`` as a one-dimensional tensor.

    Note:
        This function cannot guarantee that the interval does not include the stop value in those cases
        where step is not an integer and floating-point rounding errors affect the length of the output tensor.

    Args:
        start: if ``stop`` is specified, the start of interval (inclusive); otherwise,
            the end of the interval (exclusive). If ``stop`` is not specified, the default starting value is ``0``.
        stop: the end of the interval. Default: ``None``.
        step: the distance between two adjacent elements ( ``out[i+1] - out[i]`` ). Must not be 0 ;
            may be negative, this results in an empty tensor if stop >= start . Default: 1 .

    Keyword args:
        dtype( :attr:`.Tensor.dtype` ): output tensor data type. Default: ``float32``.
        device( :attr:`.Tensor.device` ): device on which to place the created tensor. Default: ``None``.

    Returns:
        A one-dimensional tensor containing evenly spaced values.

        The length of the output tensor must be ``ceil((stop-start)/step)``
        if ``stop - start`` and ``step`` have the same sign, and length 0 otherwise.

    Examples:

        >>> F.arange(5)
        Tensor([0. 1. 2. 3. 4.], device=xpux:0)
        >>> F.arange(1, 4)
        Tensor([1. 2. 3.], device=xpux:0)

    """
    # A single positional argument is the exclusive upper bound; counting starts at 0.
    lower, upper = (0, start) if stop is None else (start, stop)

    first = Tensor(lower, dtype="float32")
    bound = Tensor(upper, dtype="float32")
    stride = Tensor(step, dtype="float32")

    # Express the range as an inclusive linspace: ``count`` points ending at
    # the last value that lies before ``bound``.
    count = ceil((bound - first) / stride)
    last = first + stride * (count - 1)
    values = linspace(first, last, count, device=device)
    return values.astype(dtype) if np.dtype(dtype) != np.float32 else values
1128 1129 1130


def repeat(inp: Tensor, repeats: int, axis: Optional[int] = None):
    r"""Repeat elements of an array.

    Args:
        inp: input tensor.
        repeats: the number of repetitions for each element, must be at least 1.
        axis: the axis along which to repeat values; must be non-negative. By
            default, use the flattened input array, and return a flat output array.

    Returns:
        output tensor.

    Examples:

        .. testcode::

            import numpy as np
            import megengine.functional as F
            from megengine import tensor

            x = tensor([[1, 2], [3, 4]], np.int32)
            y = F.repeat(x, 2, axis=0)
            print(y.numpy())

        Outputs:

        .. testoutput::

            [[1 2]
             [1 2]
             [3 4]
             [3 4]]
    """
    if axis is None:
        # no axis given: operate on the flattened input
        inp = inp.reshape(-1)
        axis = 0
    if inp._isscalar():
        inp._unsetscalar()
    shape = astensor1d(inp.shape, inp, dtype="int32", device=inp.device)
    # assume inp.ndim is not changed during trace
    last_axis = len(shape) - 1
    assert 0 <= axis <= last_axis
    assert repeats >= 1

    # Insert a length-1 axis right after ``axis``, broadcast it to ``repeats``,
    # then fold it back into ``axis``.
    upto_axis = shape[: axis + 1]
    base_parts = [upto_axis, [1]]
    bcast_parts = [upto_axis, [repeats]]
    target_parts = [shape[:axis]] if axis != 0 else []
    target_parts.append(shape[axis] * repeats)
    if axis < last_axis:
        after_axis = shape[axis + 1 :]
        base_parts.append(after_axis)
        bcast_parts.append(after_axis)
        target_parts.append(after_axis)

    expanded = inp.reshape(concat(base_parts))
    repeated = broadcast_to(expanded, concat(bcast_parts))
    return repeated.reshape(concat(target_parts))


def _tile_one_dim(inp, rep, axis):
    r"""Tiles ``inp`` ``rep`` times along the single axis ``axis``.

    A length-1 axis is inserted in front of ``axis``, broadcast to ``rep`` and
    then merged into ``axis``, so whole slices are repeated back to back.
    """
    shape = astensor1d(inp.shape, inp, dtype="int32", device=inp.device)
    # assume inp.ndim is not changed during trace
    last_axis = len(shape) - 1

    from_axis = shape[axis:]
    base_parts = [[1], from_axis]
    bcast_parts = [rep, from_axis]
    target_parts = [shape[axis] * rep]
    if axis != 0:
        # leading axes are carried through unchanged in all three shapes
        before_axis = shape[:axis]
        base_parts.insert(0, before_axis)
        bcast_parts.insert(0, before_axis)
        target_parts.insert(0, before_axis)
    if axis < last_axis:
        target_parts.append(shape[axis + 1 :])

    expanded = inp.reshape(concat(base_parts))
    tiled = broadcast_to(expanded, concat(bcast_parts))
    return tiled.reshape(concat(target_parts))


def tile(inp: Tensor, reps: Iterable[int]):
    r"""Construct an array by repeating ``inp`` the number of times given by ``reps``. If reps has length d,
    the result will have dimension of ``max(d, inp.ndim)``. It is required that ``d >= inp.ndim``. If ``inp.ndim < d``,
    ``inp`` is promoted to be ``d``-dimensional by prepending new axis.

    Args:
        inp: input tensor.
        reps: The number of repetitions of inp along each axis.

    Returns:
        output tensor.


    Examples:

        .. testcode::

            import numpy as np
            import megengine.functional as F
            from megengine import tensor

            x = tensor([[1, 2], [3, 4]], np.int32)
            y = F.tile(x, (2,1))
            print(y.numpy())

        Outputs:

        .. testoutput::

            [[1 2]
             [3 4]
             [1 2]
             [3 4]]
    """
    shape = astensor1d(inp.shape, inp, dtype="int32", device=inp.device)
    reps = astensor1d(reps, inp, dtype="int32", device=inp.device)
    l_shape = len(shape)
    l_reps = len(reps)
    assert (
        l_reps >= l_shape
    ), "Number of dimensions of tiled dims can not be smaller than number of dimensions of tensor"
    # number of leading entries of ``reps`` that have no matching input axis
    n_extra = l_reps - l_shape

    # The trailing ``l_shape`` entries of ``reps`` line up with the existing axes.
    for i in range(l_shape):
        inp = _tile_one_dim(inp, reps[i + n_extra], i)

    if n_extra > 0:
        shape = inp.shape
        # Slice by explicit count: ``reps[:-l_shape]`` degenerates to the empty
        # ``reps[:0]`` when ``l_shape == 0`` and would drop every repetition.
        extra = reps[:n_extra]
        extra_ones = ones_like(extra)
        # Prepend length-1 axes, then broadcast them to the extra repetitions.
        base_shape = concat([extra_ones, shape])
        bcast_shape = concat([extra, shape])
        target_shape = concat([extra, shape])
        inp = broadcast_to(inp.reshape(base_shape), bcast_shape).reshape(target_shape)

    return inp
1272 1273 1274


def copy(inp, device=None):
    r"""Copies tensor to another device.

    Args:
        inp: input tensor.
        device: destination device. If ``None``, an ``Identity`` op is applied
            to ``inp`` instead of a ``Copy`` op.

    Returns:
        output tensor.

    Examples:

        .. testcode::

            import numpy as np
            import platform
            from megengine import tensor
            from megengine.device import get_device_count
            import megengine.functional as F

            x = tensor([1, 2, 3], np.int32)
            if 1 == get_device_count("gpu"):
                y = F.copy(x, "cpu1")
                print(y.numpy())
            else:
                y = F.copy(x, "xpu1")
                print(y.numpy())

        Outputs:

        .. testoutput::

            [1 2 3]
    """
    if device is not None:
        copy_op = Copy(comp_node=as_device(device).to_c())
    else:
        copy_op = Identity()
    return apply(copy_op, inp)[0]
1308 1309 1310 1311 1312 1313 1314


def roll(
    inp: Tensor,
    shift: Union[int, Iterable[int]],
    axis: Optional[Union[int, Iterable[int]]] = None,
):
    r"""Roll the tensor along the given axis(or axes). Elements that are shifted
    beyond the last position are re-introduced at the first position.

    Args:
        inp: input tensor.
        shift: the number of places by which the elements of the tensor are
            shifted. If shift is a tuple, axis must be a tuple of the same size,
            and each axis will be rolled by the corresponding shift value.
        axis: axis along which to roll. If axis is not specified, the tensor
            will be flattened before rolling and then restored to the original shape.
            Duplicate axes is allowed if it is a tuple. Default: None.

    Returns:
        output tensor. When every effective shift is zero, no split/concat is
        performed and the input (reshaped if ``axis`` is ``None``) is returned.

    Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor
            import megengine.functional as F

            x = tensor([[1,2],[3,4],[5,6]], np.int32)
            y = F.roll(x, 1, 0)
            print(y.numpy())

        Outputs:

        .. testoutput::

            [[5 6]
             [1 2]
             [3 4]]
    """
    shp_bak = None
    if axis is None:
        # roll over the flattened tensor and restore the shape afterwards
        shp_bak = inp.shape
        inp = inp.flatten()
        axis = 0
    shp = inp.shape
    dim = len(shp)
    if isinstance(shift, int):
        assert isinstance(axis, int)
        shift, axis = [shift], [axis]
    assert len(shift) == len(axis)
    out = inp
    for shift_, axis_ in zip(shift, axis):
        axis_normalized_ = axis_ + dim if axis_ < 0 else axis_
        assert (
            dim > axis_normalized_ >= 0
        ), "axis out of range (expected to be in range of [{}, {}], but got {})".format(
            -dim, dim - 1, axis_
        )
        if shift_ == 0:
            continue
        size = shp[axis_normalized_]
        # ``%`` with a positive modulus is never negative, so the effective
        # shift lies in ``[0, size)``; an empty axis cannot be rolled at all.
        shift_normalized_ = 0 if size == 0 else shift_ % size
        if shift_normalized_ == 0:
            # a full turn (or an empty axis) leaves the data unchanged
            continue
        # Move the last ``shift_normalized_`` entries in front of the rest.
        a, b = split(out, [size - shift_normalized_], axis=axis_normalized_)
        out = concat((b, a), axis=axis_normalized_)
    if shp_bak is not None:
        out = out.reshape(shp_bak)
    return out
M
Megvii Engine Team 已提交
1380 1381 1382


def cumsum(inp: Tensor, axis: int):
    r"""Computes the cumulative sum of elements along given axis.

    Args:
        inp: input tensor.
        axis: axis along which cumsum is performed; must satisfy
            ``0 <= axis < inp.ndim``.

    Returns:
        output tensor.

    Examples:

        .. testcode::

            from megengine import tensor
            import megengine.functional as F

            x = tensor([[1, 2, 3], [4, 5, 6]], "int32")
            y = F.cumsum(x, 1)
            print(y.numpy())

        Outputs:

        .. testoutput::

            [[ 1  3  6]
             [ 4  9 15]]
    """
    assert isinstance(inp, Tensor), "input of cumsum must be type of Tensor"
    assert 0 <= axis < inp.ndim, "input axis {} out of bound".format(axis)
    # inclusive, forward-direction running sum
    cumsum_op = builtin.Cumsum(axis=axis, exclusive=False, reverse=False)
    outputs = apply(cumsum_op, inp)
    return outputs[0]