# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Iterable, Optional, Sequence, Tuple, Union

import numpy as np

from ..core._imperative_rt import CompNode
from ..core._imperative_rt.core2 import SymbolVar, apply, dtype_promotion
from ..core._wrap import as_device
from ..core.ops import builtin
from ..core.ops.builtin import Copy, Identity
from ..core.ops.special import Const
from ..core.tensor.array_method import _broadcast, _remove_axis
from ..core.tensor.utils import astensor1d, convert_inputs, get_device
from ..device import get_default_device
from ..tensor import Tensor
from .elemwise import ceil
24 25 26

# Public names exported by this module.
__all__ = [
    "arange",
    "broadcast_to",
    "concat",
    "cond_take",
    "cumsum",
    "expand_dims",
    "eye",
    "flatten",
    "full",
    "full_like",
    "gather",
    "linspace",
    "ones",
    "ones_like",
    "repeat",
    "reshape",
    "roll",
    "split",
    "squeeze",
    "stack",
    "scatter",
    "tile",
    "copy",
    "transpose",
    "where",
    "zeros",
    "zeros_like",
]


56
def eye(N, M=None, *, dtype="float32", device: Optional[CompNode] = None) -> Tensor:
    r"""Returns a 2D tensor with ones on the diagonal and zeros elsewhere.

    Args:
        N: first entry of the output shape; an integer or a tensor. When ``M`` is
            ``None`` this value alone is used as the shape operand of the operator.
        M: second entry of the output shape; an integer or a tensor. Default: ``None``.
        dtype: the desired data type of the output tensor. Default: ``float32``.
        device: the desired device of the output tensor. Default: if ``None``,
            use the default device (see :func:`~.megengine.get_default_device`).

    Returns:
        eye matrix.

    Examples:

        .. testcode::

            import numpy as np
            import megengine.functional as F

            out = F.eye(4, 6, dtype=np.float32)
            print(out.numpy())

        Outputs:

        .. testoutput::

            [[1. 0. 0. 0. 0. 0.]
             [0. 1. 0. 0. 0. 0.]
             [0. 0. 1. 0. 0. 0.]
             [0. 0. 0. 1. 0. 0.]]
    """
    # Build the shape operand: tensors are passed through (or packed with
    # astensor1d), plain Python values are wrapped into an int32 tensor.
    if M is None:
        if isinstance(N, Tensor):
            shape = N
        else:
            shape = Tensor(N, dtype="int32", device=device)
    elif isinstance(N, Tensor) or isinstance(M, Tensor):
        shape = astensor1d((N, M))
    else:
        shape = Tensor([N, M], dtype="int32", device=device)
    op = builtin.Eye(k=0, dtype=dtype, comp_node=device)
    (result,) = apply(op, shape)
    return result


101
def full(shape, value, dtype="float32", device=None) -> Tensor:
    r"""Creates a tensor of shape ``shape`` filled with ``value``.

    Args:
        shape: a list, tuple or integer defining the shape of the output tensor.
        value: the value to fill the output tensor with.
        dtype: the desired data type of the output tensor. Default: ``float32``.
        device: the desired device of the output tensor. Default: if ``None``,
            use the default device (see :func:`~.megengine.get_default_device`).

    Returns:
        output tensor.

    Examples:

        .. testcode::

            import numpy as np
            import megengine.functional as F

            out = F.full([2,3], 1.5)
            print(out.numpy())

        Outputs:

        .. testoutput::

            [[1.5 1.5 1.5]
             [1.5 1.5 1.5]]
    """
    if isinstance(shape, int):
        shape = (shape,)
    if device is None:
        device = get_default_device()
    (x,) = Const(value, dtype=dtype, device=device)()
    # An empty tuple is the shape of a scalar: return the constant as is.
    # A type check is used instead of ``shape is ()`` (identity with a literal)
    # and instead of ``shape == ()`` so that a non-tuple ``shape`` is never
    # compared element-wise.
    if isinstance(shape, tuple) and not shape:
        return x
    return broadcast_to(x, shape)
140 141


142
def ones(shape, dtype="float32", device=None) -> Tensor:
    r"""Returns a ones tensor with given shape.

    Args:
        shape: a list, tuple or integer defining the shape of the output tensor.
        dtype: the desired data type of the output tensor. Default: ``float32``.
        device: the desired device of the output tensor. Default: if ``None``,
            use the default device (see :func:`~.megengine.get_default_device`).

    Returns:
        output tensor.

    Examples:

        .. testcode::

            import megengine.functional as F

            out = F.ones((2, 1))
            print(out.numpy())

        Outputs:

        .. testoutput::

            [[1.]
             [1.]]
    """
    return full(shape, 1.0, dtype=dtype, device=device)


173
def zeros(shape, dtype="float32", device=None) -> Tensor:
    r"""Returns a zero tensor with given shape.

    Args:
        shape: a list, tuple or integer defining the shape of the output tensor.
        dtype: the desired data type of the output tensor. Default: ``float32``.
        device: the desired device of the output tensor. Default: if ``None``,
            use the default device (see :func:`~.megengine.get_default_device`).

    Returns:
        output tensor.
    """
    return full(shape, 0.0, dtype=dtype, device=device)


185
def zeros_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]:
    r"""Returns a zero tensor with the same shape as input tensor.

    Args:
        inp: input tensor.

    Returns:
        output tensor.

    Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor
            import megengine.functional as F

            inp = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
            out = F.zeros_like(inp)
            print(out.numpy())

        Outputs:

        .. testoutput::

            [[0 0 0]
             [0 0 0]]

    """
    return full_like(inp, 0.0)
215 216


217
def ones_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]:
    r"""Returns a ones tensor with the same shape as input tensor.

    Args:
        inp: input tensor.

    Returns:
        output tensor.

    Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor
            import megengine.functional as F

            inp = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
            out = F.ones_like(inp)
            print(out.numpy())

        Outputs:

        .. testoutput::

            [[1 1 1]
             [1 1 1]]
    """
    return full_like(inp, 1.0)
246 247


248 249 250
def full_like(
    inp: Union[Tensor, SymbolVar], value: Union[int, float]
) -> Union[Tensor, SymbolVar]:
    r"""Returns a tensor filled with given value with the same shape as input tensor.

    The result uses the data type and device of ``inp``.

    Args:
        inp: input tensor.
        value: target value.

    Returns:
        output tensor.

    Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor
            import megengine.functional as F

            inp = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
            out = F.full_like(inp, 2)
            print(out.numpy())

        Outputs:

        .. testoutput::

            [[2 2 2]
             [2 2 2]]

    """
    (x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp)
    target_shape = inp.shape
    # Scalar input (empty-tuple shape): the constant already has the right
    # shape. A type check replaces the former ``inp.shape is ()`` identity
    # test and avoids comparing a non-tuple shape element-wise.
    if isinstance(target_shape, tuple) and not target_shape:
        return x
    return broadcast_to(x, target_shape)
284 285


286
def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
    r"""Broadcasts a tensor to given shape.

    Args:
        inp: input tensor.
        shape: target shape.

    Returns:
        output tensor.

    Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor
            import megengine.functional as F

            data = tensor(np.arange(0, 3, dtype=np.float32).reshape(3))
            out = F.broadcast_to(data, (2, 3))
            print(out.numpy())

        Outputs:

        .. testoutput::

            [[0. 1. 2.]
             [0. 1. 2.]]
    """
    return _broadcast(inp, shape)
316 317


318
def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor:
    r"""Concat some tensors

    Args:
        inps: input tensors to concat. The container must support ``len()`` and
            indexing (e.g. a list or tuple).
        axis: over which dimension the tensors are concatenated. Default: 0
        device: which device output will be. Default: None

    Returns:
        output tensor. When ``inps`` holds exactly one tensor, that tensor itself
        is returned without applying any operator.

    Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor
            import megengine.functional as F

            data1 = tensor(np.arange(0, 6, dtype=np.float32).reshape((2, 3)))
            data2 = tensor(np.arange(6, 12, dtype=np.float32).reshape((2, 3)))
            out = F.concat([data1, data2])
            print(out.numpy())

        Outputs:

        .. testoutput::

            [[ 0.  1.  2.]
             [ 3.  4.  5.]
             [ 6.  7.  8.]
             [ 9. 10. 11.]]
    """
    # Nothing to concatenate: hand back the single input unchanged.
    if len(inps) == 1:
        return inps[0]

    # FIXME: remove this convert_inputs
    inps = convert_inputs(*inps, device=device)
    if device is None:
        device = get_device(inps)
    device = as_device(device)
    (result,) = apply(builtin.Concat(axis=axis, comp_node=device.to_c()), *inps)
    return result


363
def stack(inps, axis=0, device=None):
    r"""Concats a sequence of tensors along a new axis.
    The input tensors must have the same shape.

    Args:
        inps: input tensors.
        axis: which axis will be concatenated.
        device: the device output will be. Default: None

    Returns:
        output concatenated tensor.

    Raises:
        ValueError: if the input tensors do not all have the same shape.

    Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor
            import megengine.functional as F

            x1 = tensor(np.arange(0, 3, dtype=np.float32).reshape((3)))
            x2 = tensor(np.arange(6, 9, dtype=np.float32).reshape((3)))
            out = F.stack([x1, x2], axis=0)
            print(out.numpy())

        Outputs:

        .. testoutput::

            [[0. 1. 2.]
             [6. 7. 8.]]
    """
    # The shape check is skipped when ``shape`` is itself an instance of the
    # input's class (presumably a symbolic shape that cannot be collected into
    # a set of plain tuples -- TODO confirm).
    if len(inps) > 0 and not isinstance(inps[0].shape, inps[0].__class__):
        shapes = {arr.shape for arr in inps}
        if len(shapes) != 1:
            raise ValueError("All input tensors must have the same shape")

    inps = [expand_dims(inp, axis=axis) for inp in inps]
    return concat(inps, axis=axis, device=device)
402 403 404


def split(inp, nsplits_or_sections, axis=0):
    r"""Splits the input tensor into several smaller tensors.
    When nsplits_or_sections is int, the last tensor may be smaller than others.

    Args:
        inp: input tensor.
        nsplits_or_sections: number of sub tensors, or a sequence of strictly
            increasing division points along ``axis``.
        axis: which axis will be split.

    Returns:
        output tensor list.

    Raises:
        ValueError: if ``axis`` is not smaller than the number of dimensions, if the
            division points are not strictly increasing within ``(0, size)``, or if
            the number of sections is not in ``[1, size]``.

    Examples:

        .. testcode::

            import os
            import numpy as np
            from megengine import tensor
            import megengine.functional as F

            x = tensor(np.random.random((10, 20)), dtype=np.float32)
            y = F.split(x, 3)
            z = F.split(x, [6, 17], axis=1)

            print([i.numpy().shape for i in y])
            print([i.numpy().shape for i in z])

        Outputs:

        .. testoutput::

            [(4, 20), (3, 20), (3, 20)]
            [(10, 6), (10, 11), (10, 3)]
    """
    ndim = len(inp.shape)
    if axis >= ndim:
        raise ValueError("Invalid axis {}".format(axis))

    Ntotal = inp.shape[axis]

    if isinstance(nsplits_or_sections, Sequence):
        # Division points -> section sizes. Previously this branch only
        # validated the points and never defined ``partitions``, so the code
        # below failed with an UnboundLocalError.
        div_points = [0] + list(nsplits_or_sections) + [Ntotal]
        partitions = []
        for prev, cur in zip(div_points, div_points[1:]):
            if prev >= cur:
                raise ValueError(
                    "Invalid nsplits_or_sections: {}".format(nsplits_or_sections)
                )
            partitions.append(cur - prev)
    else:  # scalar
        Nsections = int(nsplits_or_sections)
        if Nsections <= 0:
            raise ValueError("Number sections must be larger than 0")
        if Nsections > Ntotal:
            raise ValueError(
                "The size {} at dim {} cannot be split into {} sections".format(
                    Ntotal, axis, Nsections
                )
            )
        # Distribute the remainder over the leading sections.
        partitions = [
            (Ntotal + Nsections - i - 1) // Nsections for i in range(Nsections)
        ]

    # The Split operator takes every section size as a tensor operand.
    partitions = [
        part
        if isinstance(part, (SymbolVar, Tensor))
        else Const(part, dtype="int32", device=inp.device)(inp)[0]
        for part in partitions
    ]
    op = builtin.Split(axis=axis)
    return apply(op, inp, *partitions)
481 482 483 484 485 486 487 488 489 490 491 492 493


def _get_idx(index, axis):
    r"""Builds the per-dimension index tuple used by :func:`gather` and :func:`scatter`.

    Args:
        index: index tensor.
        axis: the dimension that ``index`` addresses.

    Returns:
        a tuple with one flattened tensor per dimension of ``index``: for
        ``axis`` it is ``index`` itself flattened, for every other dimension it
        is the position along that dimension (``0 .. size - 1``) broadcast to
        ``index.shape``, flattened and cast to int32.
    """
    index_dims = len(index.shape)
    idx = []
    for dim in range(index_dims):
        if dim == axis:
            idx.append(index.reshape(-1))
            continue
        size = index.shape[dim]
        # Shape that is 1 everywhere except along ``dim`` so that the ramp
        # broadcasts against the full index shape.
        ramp_shape = [1] * index_dims
        ramp_shape[dim] = size
        ramp = linspace(0, size - 1, size, device=index.device,)
        ramp = (
            broadcast_to(ramp.reshape(*ramp_shape), index.shape)
            .reshape(-1)
            .astype(np.int32)
        )
        idx.append(ramp)
    return tuple(idx)


def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor:
    # TODO: rewrite doc
    r"""
    Gathers data from input tensor on axis using index.

    For a 3-D tensor, the output is specified by:

    .. code-block::

       out[i][j][k] = inp[index[i][j][k]][j][k] # if axis == 0
       out[i][j][k] = inp[i][index[i][j][k]][k] # if axis == 1
       out[i][j][k] = inp[i][j][index[i][j][k]] # if axis == 2

    if input tensor is a n-dimensional tensor with size
    :math:`(x_0,x_1,...,x_{i-1},x_i,x_{i+1},...,x_{n-1})` and axis=i,
    then index must be a n-dimensional tensor with size
    :math:`(x_0,x_1,...,x_{i-1},y,x_{i+1},...,x_{n-1})` where :math:`y\ge 1` and
    output will have the same size as index.

    Args:
        inp: input tensor.
        axis: along which axis to index.
        index: indices of elements to gather.

    Returns:
        output tensor.

    Raises:
        ValueError: if ``inp`` and ``index`` differ in number of dimensions, if
            ``axis`` is outside ``[0, ndim)``, or if their sizes differ on any
            dimension other than ``axis``.

    Examples:

        .. testcode::

            import megengine.functional as F
            from megengine import tensor

            inp = tensor([
                [1,2], [3,4], [5,6],
            ])
            index = tensor([[0,2], [1,0]])
            oup = F.gather(inp, 0, index)
            print(oup.numpy())

        Outputs:

        .. testoutput::

            [[1 6]
             [3 2]]
    """
    input_shape = inp.shape
    index_shape = index.shape
    input_dims = len(input_shape)
    index_dims = len(index_shape)
    if input_dims != index_dims:
        raise ValueError(
            "The index tensor must have same dimensions as input tensor, "
            "But the input dims:{}, the index dims:{}".format(input_dims, index_dims)
        )

    if axis < 0 or axis >= input_dims:
        raise ValueError(
            "Index axis {} is out of bounds, should in range [0 {})".format(
                axis, input_dims
            )
        )

    for i in range(input_dims):
        if i != axis and input_shape[i] != index_shape[i]:
            raise ValueError(
                "The input {} and index {} must have the same size apart from axis {}".format(
                    input_shape, index_shape, axis
                )
            )

    # Index with one flattened coordinate tensor per dimension, then restore
    # the shape of ``index``.
    idx = _get_idx(index, axis)
    return inp[idx].reshape(index.shape)  # pylint: disable=no-member


def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor:
    # TODO: rewrite doc
    r"""
    Writes all values from the tensor source into input tensor
    at the indices specified in the index tensor.

    For each value in source, its output index is specified by its index
    in source for ``axis != dimension`` and by the corresponding value in
    index for ``axis = dimension``.

    For a 3-D tensor, input tensor is updated as:

    .. code-block::

       inp[index[i][j][k]][j][k] = source[i][j][k]  # if axis == 0
       inp[i][index[i][j][k]][k] = source[i][j][k]  # if axis == 1
       inp[i][j][index[i][j][k]] = source[i][j][k]  # if axis == 2

    ``inp``, ``index`` and ``source`` should have same number of dimensions.

    It is also required that ``source.shape(d) <= inp.shape(d)`` and ``index.shape(d) == source.shape(d)``
    for all dimensions ``d``.

    Moreover, the values of index must be between ``0`` and ``inp.shape(axis) - 1`` inclusive.

    Note:
        Please notice that, due to performance issues, the result is uncertain on the GPU device
        if several positions of source are scattered to the same destination position
        according to the index tensor.

        In the example below, if ``index[1][2]`` were changed from 1 to 0, ``oup[0][2]``
        could come either from ``source[0][2]`` (0.2256) or from ``source[1][2]`` (0.5939).

    Args:
        inp: inp tensor which to be scattered. It is written to in place.
        axis: axis along which to index.
        index: indices of elements to scatter.
        source: source element(s) to scatter.

    Returns:
        output tensor; this is ``inp`` itself after the update.

    Raises:
        ValueError: if the three tensors differ in number of dimensions, if ``axis``
            is outside ``[0, ndim)``, if ``source`` is larger than ``inp`` on any
            dimension, or if ``index`` and ``source`` differ in shape.

    Examples:

        .. testcode::

            import numpy as np
            import megengine.functional as F
            from megengine import tensor

            inp = tensor(np.zeros(shape=(3,5),dtype=np.float32))
            source = tensor([[0.9935,0.9465,0.2256,0.8926,0.4396],[0.7723,0.0718,0.5939,0.357,0.4576]])
            index = tensor([[0,2,0,2,1],[2,0,1,1,2]])
            oup = F.scatter(inp, 0, index,source)
            print(oup.numpy())

        Outputs:

        .. testoutput::

            [[0.9935 0.0718 0.2256 0.     0.    ]
             [0.     0.     0.5939 0.357  0.4396]
             [0.7723 0.9465 0.     0.8926 0.4576]]
    """
    input_shape = inp.shape
    index_shape = index.shape
    source_shape = source.shape
    input_dims = len(input_shape)
    index_dims = len(index_shape)
    source_dims = len(source_shape)

    if input_dims != index_dims or input_dims != source_dims:
        raise ValueError("The input, source and index tensor must have same dimensions")

    if axis < 0 or axis >= input_dims:
        raise ValueError(
            "Index axis {} is out of bounds, should in range [0 {})".format(
                axis, input_dims
            )
        )

    for i in range(source_dims):
        if source_shape[i] > input_shape[i]:
            raise ValueError(
                "The each shape size for source {} must be less than or equal to input {} ".format(
                    source_shape, input_shape
                )
            )

    for i in range(index_dims):
        if index_shape[i] != source_shape[i]:
            raise ValueError(
                "The each shape size for index {} must be equal to source {} ".format(
                    index_shape, source_shape
                )
            )

    # No separate "index <= input" check is needed: the two loops above already
    # guarantee index_shape[i] == source_shape[i] <= input_shape[i] for every
    # dimension, so the former third loop could never raise.

    idx = _get_idx(index, axis)
    inp[idx] = source.flatten()
    return inp


def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
    r"""Selects elements either from Tensor x or Tensor y, according to mask.

    .. math::

        \textrm{out}_i = x_i \textrm{ if } \textrm{mask}_i \textrm{ is True else } y_i

    Args:
        mask: a mask used for choosing ``x`` or ``y``; must be of bool dtype.
        x: first choice.
        y: second choice.

    Returns:
        output tensor, with the shape of ``x`` and the promoted dtype of ``x`` and ``y``.

    Raises:
        TypeError: if ``x``, ``y`` or ``mask`` is not a tensor.
        ValueError: if ``mask`` is not bool, or ``x`` and ``mask`` are on different devices.

    Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor
            import megengine.functional as F
            mask = tensor(np.array([[True, False], [False, True]], dtype=np.bool_))
            x = tensor(np.array([[1, np.inf], [np.nan, 4]],
                dtype=np.float32))
            y = tensor(np.array([[5, 6], [7, 8]], dtype=np.float32))
            out = F.where(mask, x, y)
            print(out.numpy())

        Outputs:

        .. testoutput::

            [[1. 6.]
             [7. 4.]]
    """

    if not isinstance(x, Tensor):
        raise TypeError("input x must be a tensor")
    if not isinstance(y, Tensor):
        raise TypeError("input y must be a tensor")
    if not isinstance(mask, Tensor):
        raise TypeError("mask must be a tensor")
    if mask.dtype != np.bool_:
        raise ValueError("mask must be bool")
    if x.device != mask.device:
        raise ValueError("ambiguous device: {} vs {}".format(x.device, mask.device))

    # Bring both choices to a common dtype before mixing them.
    dtype = dtype_promotion(x, y)
    if x.dtype != dtype:
        x = x.astype(dtype)
    if y.dtype != dtype:
        y = y.astype(dtype)

    # Take the selected elements of x and the complementary elements of y,
    # together with their flat positions.
    v0, index0 = cond_take(mask, x)
    v1, index1 = cond_take(~mask, y)

    # The concatenation only provides a buffer of the right size and dtype;
    # every slot is then overwritten at its original flat position.
    out = concat([v0, v1])

    out[index0] = v0
    out[index1] = v1
    out = out.reshape(x.shape)
    return out
755 756 757


def cond_take(mask: Tensor, x: Tensor) -> Tuple[Tensor, Tensor]:
    r"""Takes elements from data if specific condition is satisfied on mask.
    This operator has two outputs: the first is the elements taken,
    and the second is the indices corresponding to those elements;
    they are both 1-dimensional. High-dimension input would first be flattened.

    Args:
        mask: condition param; must be the same shape with data.
        x: input tensor from which to take elements.

    Returns:
        a tuple ``(values, indices)``.

    Raises:
        TypeError: if ``x`` or ``mask`` is not a tensor.
        ValueError: if ``mask`` is not bool, or ``x`` and ``mask`` are on different devices.

    Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor
            import megengine.functional as F
            mask = tensor(np.array([[True, False], [False, True]], dtype=np.bool_))
            x = tensor(np.array([[1, np.inf], [np.nan, 4]],
                dtype=np.float32))
            v, index = F.cond_take(mask, x)
            print(v.numpy(), index.numpy())

        Outputs:

        .. testoutput::

            [1. 4.] [0 3]
    """
    if not isinstance(x, (Tensor, SymbolVar)):
        raise TypeError("input must be a tensor")
    if not isinstance(mask, (Tensor, SymbolVar)):
        raise TypeError("mask must be a tensor")
    if mask.dtype != np.bool_:
        raise ValueError("mask must be bool")
    if x.device != mask.device:
        raise ValueError("ambiguous device: {} vs {}".format(x.device, mask.device))

    op = builtin.CondTake()
    v, index = apply(op, x, mask)
    return v, index


800
def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor:
    r"""Swaps shapes and strides according to given pattern.

    Args:
        inp: input tensor.
        pattern: a list of integers including 0, 1, ... , ``ndim``-1,
            and any number of ``'x'`` char in dimensions where this tensor should be broadcasted.
            For examples:

            * (``'x'``) -> make a 0d (scalar) into a 1d vector
            * (0, 1) -> identity for 2d vectors
            * (1, 0) -> inverts the first and second dimensions
            * (``'x'``, 0) -> make a row out of a 1d vector (N to 1xN)
            * (0, ``'x'``) -> make a column out of a 1d vector (N to Nx1)
            * (2, 0, 1) -> AxBxC to CxAxB
            * (0, ``'x'``, 1) -> AxB to Ax1xB
            * (1, ``'x'``, 0) -> AxB to Bx1xA
            * (1,) -> this removes dimensions 0. It must be a broadcastable dimension (1xA to A)

    Returns:
        output tensor.

    Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor
            import megengine.functional as F
            x = tensor(np.array([[1, 1], [0, 0]], dtype=np.int32))
            out = F.transpose(x, (1, 0))
            print(out.numpy())

        Outputs:

        .. testoutput::

            [[1 0]
             [1 0]]
    """
    # The tensor method expects -1 in place of each 'x' placeholder.
    return inp.transpose([-1 if dim == "x" else dim for dim in pattern])
841 842 843


def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor:
    r"""Reshapes a tensor to given target shape; total number of logical elements must
    remain unchanged

    Args:
        inp: input tensor.
        target_shape: target shape, it can contain an element of -1 representing ``unspec_axis``.

    Returns:
        output tensor.

    Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor
            import megengine.functional as F
            x = tensor(np.arange(12, dtype=np.int32))
            out = F.reshape(x, (3, 4))
            print(out.numpy())

        Outputs:

        .. testoutput::

            [[ 0  1  2  3]
             [ 4  5  6  7]
             [ 8  9 10 11]]
    """
    return inp.reshape(target_shape)
871 872


873
def flatten(inp: Tensor, start_axis: int = 0, end_axis: int = -1) -> Tensor:
    r"""Reshapes the tensor by flattening the sub-tensor from dimension ``start_axis`` to dimension ``end_axis``.

    Args:
        inp: input tensor.
        start_axis: start dimension that the sub-tensor to be flattened. Default: 0
        end_axis: end dimension that the sub-tensor to be flattened. Default: -1

    Returns:
        output tensor.

    Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor
            import megengine.functional as F

            inp_shape = (2, 2, 3, 3)
            x = tensor(
                np.arange(36, dtype=np.int32).reshape(inp_shape),
            )
            out = F.flatten(x, 2)
            print(x.numpy().shape)
            print(out.numpy().shape)

        Outputs:

        .. testoutput::

            (2, 2, 3, 3)
            (2, 2, 9)
    """
    # Keep the leading dimensions, collapse the flattened range into a single
    # -1 entry, then append whatever follows ``end_axis``.
    target_shape = tuple(inp.shape[i] for i in range(start_axis)) + (-1,)
    if end_axis != -1:
        target_shape += (*inp.shape[end_axis + 1 :],)
    return inp.reshape(*target_shape)


913
def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
    r"""Adds dimension before given axis.

    Args:
        inp: input tensor.
        axis: place of new axes.

    Returns:
        output tensor.

    Raises:
        IndexError: if a negative axis is given while the number of dimensions
            of ``inp`` is unknown.

    Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor
            import megengine.functional as F

            x = tensor([1, 2])
            out = F.expand_dims(x, 0)
            print(out.numpy().shape)

        Outputs:

        .. testoutput::

            (1, 2)
    """

    def get_axes():
        # A single int-like value becomes a one-element list; anything else is
        # treated as an iterable of int-like values.
        try:
            return [int(axis)]
        except (TypeError, ValueError):
            pass
        return list(map(int, axis))

    axis = get_axes()
    try:
        # Negative axes count from the end of the *output* shape.
        ndim = inp.ndim + len(axis)
        axis = sorted(i + ndim if i < 0 else i for i in axis)
    except ValueError:
        # Reached when the ndim lookup fails; negative axes cannot be
        # normalised in that case.
        if any(ind < 0 for ind in axis):
            raise IndexError(
                "Does not support negative index when tensor's ndim is unknown"
            )
        axis = sorted(axis)
    assert axis, "axis could not be empty"
    if inp._isscalar():
        assert axis[0] == 0, "invalid axis {} for ndim 0".format(axis[0])
        if len(axis) == 1:
            # Only one new axis on a scalar: a copy with the scalar flag
            # cleared is returned instead of applying AddAxis.
            inp = copy(inp, device=None)
            inp._unsetscalar()
            return inp
        axis = axis[1:]
    op = builtin.AddAxis(axis=axis)
    (result,) = apply(op, inp)
    return result


972
def squeeze(inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None) -> Tensor:
    r"""Removes dimension of shape 1.

    Args:
        inp: input tensor.
        axis: place of axis to be removed.

    Returns:
        output tensor.

    Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor
            import megengine.functional as F

            x = tensor(np.array([1, 2], dtype=np.int32).reshape(1, 1, 2, 1))
            out = F.squeeze(x, 3)
            print(out.numpy().shape)

        Outputs:

        .. testoutput::

            (1, 1, 2)
    """
    return _remove_axis(inp, axis)
1001 1002 1003 1004 1005 1006 1007 1008 1009


def linspace(
    start: Union[int, float, Tensor],
    stop: Union[int, float, Tensor],
    num: Union[int, Tensor],
    dtype="float32",
    device: Optional[CompNode] = None,
) -> Tensor:
    r"""Returns equally spaced numbers over a specified interval.

    Args:
        start: starting value of the sequence, should be scalar.
        stop: last value of the sequence, should be scalar.
        num: number of values to generate.
        dtype: result data type. Default: ``float32``.
        device: device on which the result is generated. When ``None``, the
            device of the first of ``start``, ``stop`` and ``num`` that carries
            a ``device`` attribute is used. Default: None

    Returns:
        generated tensor.

    Raises:
        ValueError: if ``device`` and the devices carried by ``start``,
            ``stop`` and ``num`` do not all agree.
        TypeError: if some, but not all, of ``start``, ``stop`` and ``num``
            are ``SymbolVar``.

    Examples:

        .. testcode::

            import numpy as np
            import megengine.functional as F

            a = F.linspace(3, 10, 5)
            print(a.numpy())

        Outputs:

        .. testoutput::

            [ 3.    4.75  6.5   8.25 10.  ]
    """
    # Settle on a single device: an explicit ``device`` wins, otherwise the
    # first input that carries one decides. Any disagreement is an error.
    for item in (start, stop, num):
        cur_device = getattr(item, "device", None)
        if device is None:
            device = cur_device
        elif not (cur_device is None or device == cur_device):
            # A bare string cannot be raised; use a real exception type.
            raise ValueError("ambiguous device for linspace opr")

    is_symbolvar = [isinstance(x, SymbolVar) for x in (start, stop, num)]
    if any(is_symbolvar) and not all(is_symbolvar):
        raise TypeError("start, stop and num should all be VarNode or none of them")

    # Wrap plain Python values so that all three operands are tensor-like.
    if not isinstance(start, (Tensor, SymbolVar)):
        start = Tensor(start, device=device)
    if not isinstance(stop, (Tensor, SymbolVar)):
        stop = Tensor(stop, device=device)
    if not isinstance(num, (Tensor, SymbolVar)):
        num = Tensor(num, device=device)

    op = builtin.Linspace(comp_node=device)
    (result,) = apply(op, start, stop, num)
    # The result is only cast when a dtype other than float32 is requested.
    if np.dtype(dtype) != np.float32:
        return result.astype(dtype)
    return result


def arange(
    start: Union[int, float, Tensor] = 0,
    stop: Optional[Union[int, float, Tensor]] = None,
    step: Union[int, float, Tensor] = 1,
    dtype="float32",
    device: Optional[CompNode] = None,
) -> Tensor:
    r"""Returns a tensor with values from start to stop with adjacent interval step.

    Args:
        start: starting value of the sequence, should be scalar. When ``stop``
            is omitted, this value is used as ``stop`` and the sequence starts
            at 0.
        stop: ending value of the sequence, should be scalar.
        step: gap between each pair of adjacent values. Default: 1
        dtype: result data type. Default: ``float32``.
        device: device on which the result is generated. Default: None

    Returns:
        generated tensor.

    Examples:

        .. testcode::

            import numpy as np
            import megengine.functional as F

            a = F.arange(5)
            print(a.numpy())

        Outputs:

        .. testoutput::

            [0. 1. 2. 3. 4.]
    """
    if stop is None:
        # Single-argument form: arange(n) counts from 0 up to n.
        start, stop = 0, start

    # Build the operands on the requested device. ``linspace`` rejects inputs
    # whose device differs from an explicitly requested one, so leaving the
    # operands on a different (presumably default) device would make a
    # non-default ``device`` unusable.
    start = Tensor(start, dtype="float32", device=device)
    stop = Tensor(stop, dtype="float32", device=device)
    step = Tensor(step, dtype="float32", device=device)

    # Number of generated values, then the value of the last one, so that the
    # range can be expressed as an inclusive linspace.
    num = ceil((stop - start) / step)
    stop = start + step * (num - 1)
    result = linspace(start, stop, num, device=device)
    if np.dtype(dtype) != np.float32:
        return result.astype(dtype)
    return result
1110 1111 1112


def repeat(inp: Tensor, repeats: int, axis: Optional[int] = None):
    r"""Repeats every element of a tensor a fixed number of times.

    Args:
        inp: input tensor.
        repeats: how many times each element is repeated; must be at least 1.
        axis: the axis along which values are repeated; must be non-negative
            and within the input's dimensions. When ``None``, the input is
            flattened first and a flat tensor is returned. Default: None

    Returns:
        output tensor.

    Examples:

        .. testcode::

            import numpy as np
            import megengine.functional as F
            from megengine import tensor

            x = tensor([[1, 2], [3, 4]], np.int32)
            y = F.repeat(x, 2, axis=0)
            print(y.numpy())

        Outputs:

        .. testoutput::

            [[1 2]
             [1 2]
             [3 4]
             [3 4]]
    """
    if axis is None:
        # No axis given: operate on the flattened input.
        inp = inp.reshape(-1)
        axis = 0
    if inp._isscalar():
        inp._unsetscalar()
    shape = astensor1d(inp.shape, inp, dtype="int32", device=inp.device)
    # assume inp.ndim is not changed during trace
    last_axis = len(shape) - 1
    assert 0 <= axis <= last_axis
    assert repeats >= 1

    # Three shapes drive the computation:
    #   src_parts      - input with a unit dimension inserted right after `axis`
    #   expanded_parts - the same, with that unit dimension grown to `repeats`
    #   dst_parts      - final shape, where `axis` and the new dimension merge
    src_parts = [shape[: axis + 1], [1,]]
    expanded_parts = [shape[: axis + 1], [repeats,]]
    dst_parts = [shape[:axis]] if axis != 0 else []
    dst_parts.append(shape[axis] * repeats)
    if last_axis >= axis + 1:
        # Dimensions after `axis` are carried through untouched.
        for parts in (src_parts, expanded_parts, dst_parts):
            parts.append(shape[axis + 1 :])

    expanded = broadcast_to(inp.reshape(concat(src_parts)), concat(expanded_parts))
    return expanded.reshape(concat(dst_parts))


def _tile_one_dim(inp, rep, axis):
    r"""Tiles ``inp`` ``rep`` times along one ``axis``.

    A unit dimension is inserted in front of ``axis``, broadcast to ``rep``,
    and then merged into ``axis`` by a final reshape.

    Args:
        inp: input tensor.
        rep: number of repetitions along ``axis``.
        axis: index of the dimension to tile.

    Returns:
        tensor whose size along ``axis`` is multiplied by ``rep``.
    """
    shape = astensor1d(inp.shape, inp, dtype="int32", device=inp.device)
    # assume inp.ndim is not changed during trace
    last_axis = len(shape) - 1

    src_parts, tiled_parts, dst_parts = [], [], []
    if axis != 0:
        # Dimensions in front of `axis` are common to all three shapes.
        for parts in (src_parts, tiled_parts, dst_parts):
            parts.append(shape[:axis])
    src_parts += [[1,], shape[axis:]]
    tiled_parts += [rep, shape[axis:]]
    dst_parts.append(shape[axis] * rep)
    if last_axis >= axis + 1:
        dst_parts.append(shape[axis + 1 :])

    tiled = broadcast_to(inp.reshape(concat(src_parts)), concat(tiled_parts))
    return tiled.reshape(concat(dst_parts))


def tile(inp: Tensor, reps: Iterable[int]):
    r"""Construct an array by repeating ``inp`` the number of times given by ``reps``. If reps has length d,
    the result will have dimension of ``max(d, inp.ndim)``. It is required that ``d >= inp.ndim``. If ``inp.ndim < d``,
    ``inp`` is promoted to be ``d``-dimensional by prepending new axis.

    Args:
        inp: input tensor.
        reps: The number of repetitions of inp along each axis.

    Returns:
        output tensor.

    Raises:
        AssertionError: if ``reps`` has fewer entries than ``inp`` has
            dimensions.

    Examples:

        .. testcode::

            import numpy as np
            import megengine.functional as F
            from megengine import tensor

            x = tensor([[1, 2], [3, 4]], np.int32)
            y = F.tile(x, (2,1))
            print(y.numpy())

        Outputs:

        .. testoutput::

            [[1 2]
             [3 4]
             [1 2]
             [3 4]]
    """
    shape = astensor1d(inp.shape, inp, dtype="int32", device=inp.device)
    reps = astensor1d(reps, inp, dtype="int32", device=inp.device)
    l_shape = len(shape)
    l_reps = len(reps)
    assert (
        l_reps >= l_shape
    ), "Number of dimensions of tiled dims can not be smaller than number of dimensions of tensor"

    # Number of leading entries of `reps` that have no matching input axis.
    n_extra = l_reps - l_shape

    # Tile the existing axes first, using the trailing entries of `reps`.
    for i in range(l_shape):
        inp = _tile_one_dim(inp, reps[i + n_extra], i)

    if n_extra > 0:
        shape = inp.shape
        # Slice by an explicit count: a ``reps[:-l_shape]`` slice would be
        # empty when ``l_shape`` is 0 and drop every leading repetition.
        extra = reps[:n_extra]
        extra_ones = ones_like(extra)
        # Prepend unit axes, then broadcast them to the requested counts.
        base_shape = concat([extra_ones, shape])
        full_shape = concat([extra, shape])
        inp = broadcast_to(inp.reshape(base_shape), full_shape).reshape(full_shape)

    return inp
1254 1255 1256


def copy(inp, device=None):
    r"""Copies a tensor, optionally onto another device.

    Args:
        inp: input tensor.
        device: destination device. When ``None``, an ``Identity`` op is
            applied to ``inp`` instead of a ``Copy`` op. Default: None

    Returns:
        the first output of the applied op.

    Examples:

        .. testcode::

            import numpy as np
            import platform
            from megengine import tensor
            from megengine.device import get_device_count
            import megengine.functional as F

            x = tensor([1, 2, 3], np.int32)
            if 1 == get_device_count("gpu"):
                y = F.copy(x, "cpu1")
                print(y.numpy())
            else:
                y = F.copy(x, "xpu1")
                print(y.numpy())

        Outputs:

        .. testoutput::

            [1 2 3]
    """
    if device is not None:
        # Resolve the destination into the form expected by the Copy op.
        target = as_device(device).to_c()
        return apply(Copy(comp_node=target), inp)[0]
    return apply(Identity(), inp)[0]
1290 1291 1292 1293 1294 1295 1296


def roll(
    inp: Tensor,
    shift: Union[int, Iterable[int]],
    axis: Optional[Union[int, Iterable[int]]] = None,
):
    r"""Roll the tensor along the given axis(or axes). Elements that are shifted
    beyond the last position are re-introduced at the first position.

    Args:
        inp: input tensor.
        shift: the number of places by which the elements of the tensor are
            shifted. If shift is a tuple, axis must be a tuple of the same size,
            and each axis will be rolled by the corresponding shift value.
            When the length of the rolled axis is a static positive integer,
            the shift is taken modulo that length, so shifts whose magnitude
            exceeds the axis length wrap around.
        axis: axis along which to roll. If axis is not specified, the tensor
            will be flattened before rolling and then restored to the original shape.
            Duplicate axes is allowed if it is a tuple. Default: None.

    Returns:
        rolled tensor with the same shape as ``inp``.

    Raises:
        AssertionError: if ``shift`` is an int while ``axis`` is not, if
            ``shift`` and ``axis`` differ in length, or if an axis is out of
            range.

    Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor
            import megengine.functional as F

            x = tensor([[1,2],[3,4],[5,6]], np.int32)
            y = F.roll(x, 1, 0)
            print(y.numpy())

        Outputs:

        .. testoutput::

            [[5 6]
             [1 2]
             [3 4]]
    """
    shp_bak = None
    if axis is None:
        # Roll over the flattened tensor and restore the shape at the end.
        shp_bak = inp.shape
        inp = inp.flatten()
        axis = 0
    shp = inp.shape
    dim = len(shp)
    if isinstance(shift, int):
        assert isinstance(axis, int)
        shift, axis = [shift,], [axis,]
    assert len(shift) == len(axis)
    out = inp
    for shift_, axis_ in zip(shift, axis):
        axis_normalized_ = axis_ + dim if axis_ < 0 else axis_
        assert (
            dim > axis_normalized_ >= 0
        ), "axis out of range (expected to be in range of [{}, {}], but got {})".format(
            -dim, dim - 1, axis_
        )
        size = shp[axis_normalized_]
        if isinstance(size, (int, np.integer)) and size > 0:
            # Wrap the shift into [0, size) so that |shift| >= size does not
            # produce a negative or out-of-range split point. Only done when
            # the axis length is a plain integer.
            shift_ = shift_ % size
        if shift_ == 0:
            continue
        # Cut the axis in two pieces and swap them.
        if shift_ > 0:
            a, b = split(out, [size - shift_,], axis=axis_normalized_)
        else:
            a, b = split(out, [-shift_,], axis=axis_normalized_)
        out = concat((b, a), axis=axis_normalized_)
    if shp_bak is not None:
        out = out.reshape(shp_bak)
    return out
M
Megvii Engine Team 已提交
1361 1362 1363


def cumsum(inp: Tensor, axis: int):
    r"""Returns the running (cumulative) sum of elements along an axis.

    Args:
        inp: input tensor; must be a ``Tensor`` instance.
        axis: axis along which the cumulative sum is taken; must satisfy
            ``0 <= axis < inp.ndim``.

    Returns:
        output tensor.

    Examples:

        .. testcode::

            from megengine import tensor
            import megengine.functional as F

            x = tensor([[1, 2, 3], [4, 5, 6]], "int32")
            y = F.cumsum(x, 1)
            print(y.numpy())

        Outputs:

        .. testoutput::

            [[ 1  3  6]
             [ 4  9 15]]
    """
    assert isinstance(inp, Tensor), "input of cumsum must be type of Tensor"
    assert 0 <= axis < inp.ndim, "input axis {} out of bound".format(axis)
    # Inclusive, forward-direction cumulative sum.
    cumsum_op = builtin.Cumsum(axis=axis, exclusive=False, reverse=False)
    outputs = apply(cumsum_op, inp)
    return outputs[0]