# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order
import functools

from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin
from ..core.ops.builtin import Elemwise
from ..core.tensor import megbrain_graph, utils
from ..core.tensor.utils import isscalar, setscalar
from ..device import get_default_device
from ..jit.tracing import is_tracing
from ..tensor import Tensor

# Public element-wise API exported by ``from ... import *``.
__all__ = [
    "abs",
    "add",
    "acos",
    "asin",
    "atan",
    "atan2",
    "asinh",
    "acosh",
    "atanh",
    "ceil",
    "clip",
    "cos",
    "cosh",
    "div",
    "equal",
    "exp",
    "expm1",
    "floor",
    "floor_div",
    "greater",
    "greater_equal",
    "hswish",
    "hsigmoid",
    "left_shift",
    "less",
    "less_equal",
    "log",
    "log1p",
    "logical_and",
    "logical_not",
    "logical_or",
    "logical_xor",
    "maximum",
    "minimum",
    "mod",
    "mul",
    "neg",
    "not_equal",
    "pow",
    "relu",
    "relu6",
    "right_shift",
    "round",
    "sigmoid",
    "sin",
    "sinh",
    "sqrt",
    "square",
    "sub",
    "tan",
    "tanh",
]


def _elwise(*args, mode):
    """
    Apply an ``Elemwise`` operator with the given ``mode`` to ``args``.

    Operands are converted with ``utils.convert_inputs``. When none of
    ``args`` is a :class:`Tensor` or ``VarNode``, the first operand is first
    materialized as a tensor on the default device, using the dtype promoted
    from all operands, so the remaining operands are converted against it.

    :param args: operands (tensors, ``VarNode`` s or scalars).
    :param mode: an ``Elemwise.Mode`` member selecting the operation.
    :return: the result tensor; it is marked as a scalar when every converted
        operand is a scalar.
    """
    op = builtin.Elemwise(mode)
    has_tensor = any(
        isinstance(arg, (Tensor, megbrain_graph.VarNode)) for arg in args
    )
    if has_tensor:
        args = utils.convert_inputs(*args)
    else:
        dtype = utils.dtype_promotion(args)
        first_arg = Tensor(args[0], dtype=dtype, device=get_default_device())
        args = utils.convert_inputs(first_arg, *args[1:])
    # These modes are computed in float32, so operands are cast first.
    # Compare against the enum members callers in this module pass as
    # ``mode`` rather than relying on enum/lowercase-string equality.
    if mode in (
        Elemwise.Mode.TRUE_DIV,
        Elemwise.Mode.EXP,
        Elemwise.Mode.POW,
        Elemwise.Mode.LOG,
        Elemwise.Mode.EXPM1,
        Elemwise.Mode.LOG1P,
    ):
        args = tuple(arg.astype("float32") for arg in args)
    # Decide scalar-ness from the (converted) inputs before applying the op.
    all_scalar = all(isscalar(arg) for arg in args)
    (result,) = apply(op, *args)
    if all_scalar:
        setscalar(result)
    return result


def _elemwise_multi_type(*args, mode, **kwargs):
    """
    Apply an ``ElemwiseMultiType`` operator to ``args``.

    :param args: operands, converted with ``utils.convert_inputs``.
    :param mode: the ``ElemwiseMultiType`` mode.
    :param kwargs: extra keyword arguments forwarded to the operator.
    :return: the single output of the operator.
    """
    multi_op = builtin.ElemwiseMultiType(mode=mode, **kwargs)
    inputs = utils.convert_inputs(*args)
    [result] = apply(multi_op, *inputs)
    return result


# math operations


def add(x, y):
    """
    Element-wise `addition`.
    At least one operand should be tensor.

    Same for sub/mul/div/floor_div/pow/mod/atan2/equal/not_equal/less/less_equal/greater/greater_equal/maximum/minimum.

    :param x: input tensor.
    :param y: input tensor.
    :return: computed tensor.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
        y = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
        out = F.add(x, y)
        print(out.numpy())

    Outputs:

    .. testoutput::

        [[ 0.  2.  4.]
         [ 6.  8. 10.]]

    """
    return _elwise(x, y, mode=Elemwise.Mode.ADD)


def sub(x, y):
    """
    Element-wise `subtraction`: ``x - y``.

    :param x: input tensor.
    :param y: input tensor.
    :return: computed tensor.
    """
    return _elwise(x, y, mode=Elemwise.Mode.SUB)


def mul(x, y):
    """
    Element-wise `multiplication`: ``x * y``.

    :param x: input tensor.
    :param y: input tensor.
    :return: computed tensor.
    """
    return _elwise(x, y, mode=Elemwise.Mode.MUL)


def div(x, y):
    """
    Element-wise `(x / y)` (true division).

    :param x: dividend tensor.
    :param y: divisor tensor.
    :return: computed tensor.
    """
    return _elwise(x, y, mode=Elemwise.Mode.TRUE_DIV)


def floor_div(x, y):
    """
    Element-wise `floor(x / y)`.

    :param x: dividend tensor.
    :param y: divisor tensor.
    :return: computed tensor.
    """
    return _elwise(x, y, mode=Elemwise.Mode.FLOOR_DIV)


def neg(x):
    """
    Element-wise `negation`: ``-x``.

    :param x: input tensor.
    :return: computed tensor.
    """
    return _elwise(x, mode=Elemwise.Mode.NEGATE)


def pow(x, y):
    """
    Element-wise `power`: ``x ** y``.

    :param x: base tensor.
    :param y: exponent tensor.
    :return: computed tensor.
    """
    return _elwise(x, y, mode=Elemwise.Mode.POW)


def mod(x, y):
    """
    Element-wise `remainder of division`.

    :param x: dividend tensor.
    :param y: divisor tensor.
    :return: computed tensor.
    """
    return _elwise(x, y, mode=Elemwise.Mode.MOD)


def abs(x):
    """
    Element-wise `absolute value`.

    :param x: input tensor.
    :return: computed tensor.
    """
    return _elwise(x, mode=Elemwise.Mode.ABS)


def exp(x):
    """
    Element-wise `exponential`.

    :param x: input tensor.
    :return: computed tensor.
    """
    return _elwise(x, mode=Elemwise.Mode.EXP)


def expm1(x):
    """
    Element-wise `exp(x)-1`.

    :param x: input tensor.
    :return: computed tensor.
    """
    return _elwise(x, mode=Elemwise.Mode.EXPM1)


def log(x):
    """
    Element-wise `logarithm (base e)`.

    :param x: input tensor.
    :return: computed tensor.
    """
    return _elwise(x, mode=Elemwise.Mode.LOG)


def log1p(x):
    """
    Element-wise `log(x+1) (base e)`.

    :param x: input tensor.
    :return: computed tensor.
    """
    return _elwise(x, mode=Elemwise.Mode.LOG1P)


def sqrt(x: Tensor) -> Tensor:
    """
    Element-wise `sqrt`.
    Returns ``NaN`` for negative input value.

    Computed as ``x ** 0.5``.

    :param x: input tensor.
    :return: computed tensor.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
        out = F.sqrt(x)
        print(out.numpy().round(decimals=4))

    Outputs:

    .. testoutput::

        [[0.     1.     1.4142]
         [1.7321 2.     2.2361]]

    """
    return x ** 0.5


def square(x: Tensor) -> Tensor:
    """
    Returns a new tensor with the square of the elements of input tensor.

    Computed as ``x ** 2``.

    :param x: input tensor.
    :return: computed tensor.

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.functional as F

        data = mge.tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
        out = F.square(data)
        print(out.numpy().round(decimals=4))

    Outputs:

    .. testoutput::

        [[ 0.  1.  4.]
         [ 9. 16. 25.]]

    """
    return x ** 2


def round(x):
    """
    Element-wise `rounding to int`.

    :param x: input tensor.
    :return: computed tensor.
    """
    return _elwise(x, mode=Elemwise.Mode.ROUND)


def ceil(x):
    """
    Element-wise `ceiling`.

    :param x: input tensor.
    :return: computed tensor.
    """
    return _elwise(x, mode=Elemwise.Mode.CEIL)


def floor(x):
    """
    Element-wise `floor`.

    :param x: input tensor.
    :return: computed tensor.
    """
    return _elwise(x, mode=Elemwise.Mode.FLOOR)


def maximum(x, y):
    """
    Element-wise `maximum of array elements`.

    :param x: input tensor.
    :param y: input tensor.
    :return: computed tensor.
    """
    return _elwise(x, y, mode=Elemwise.Mode.MAX)


def minimum(x, y):
    """
    Element-wise `minimum of array elements`.

    :param x: input tensor.
    :param y: input tensor.
    :return: computed tensor.
    """
    return _elwise(x, y, mode=Elemwise.Mode.MIN)


# trigonometric functions


def cos(x):
    """
    Element-wise `cosine`.

    :param x: input tensor.
    :return: computed tensor.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
        out = F.cos(x)
        print(out.numpy().round(decimals=4))

    Outputs:

    .. testoutput::

        [[ 1.      0.5403 -0.4161]
         [-0.99   -0.6536  0.2837]]

    """
    return _elwise(x, mode=Elemwise.Mode.COS)


def sin(x):
    """
    Element-wise `sine`.

    :param x: input tensor.
    :return: computed tensor.
    """
    return _elwise(x, mode=Elemwise.Mode.SIN)


def tan(x):
    """
    Element-wise `tangent`, computed as ``sin(x) / cos(x)``.

    :param x: input tensor.
    :return: computed tensor.
    """
    return sin(x) / cos(x)


def acos(x):
    """
    Element-wise `inverse cosine`.

    :param x: input tensor.
    :return: computed tensor.
    """
    return _elwise(x, mode=Elemwise.Mode.ACOS)


def asin(x):
    """
    Element-wise `inverse sine`.

    :param x: input tensor.
    :return: computed tensor.
    """
    return _elwise(x, mode=Elemwise.Mode.ASIN)


def atan(x):
    """
    Element-wise `inverse tangent`, computed as ``atan2(x, 1)``.

    :param x: input tensor.
    :return: computed tensor.
    """
    return _elwise(x, 1, mode=Elemwise.Mode.ATAN2)


def atan2(y, x):
    """
    Element-wise `2-argument arctangent` of ``y / x``.

    :param y: input tensor (first operand).
    :param x: input tensor (second operand).
    :return: computed tensor.
    """
    return _elwise(y, x, mode=Elemwise.Mode.ATAN2)


def cosh(x):
    r"""
    Element-wise `hyperbolic cosine`, computed as ``0.5 * (exp(x) + exp(-x))``.

    :param x: input tensor.
    :return: computed tensor.
    """
    return 0.5 * (exp(x) + exp(-x))


def sinh(x):
    r"""
    Element-wise `hyperbolic sine`.

    With ``u = expm1(x)`` (so ``u + 1 == exp(x)``), the result
    ``0.5 * u / (u + 1) * (u + 2)`` equals ``(exp(x) - exp(-x)) / 2``.

    :param x: input tensor.
    :return: computed tensor.
    """
    # expm1 is presumably used for better accuracy near x == 0.
    u = expm1(x)
    return 0.5 * u / (u + 1) * (u + 2)


def tanh(x):
    r"""
    Element-wise `hyperbolic tangent`.

    :param x: input tensor.
    :return: computed tensor.
    """
    return _elwise(x, mode=Elemwise.Mode.TANH)


def asinh(x):
    r"""
    Element-wise `inverse hyperbolic sine`, computed as
    ``log(x + sqrt(x ** 2 + 1))``.

    :param x: input tensor.
    :return: computed tensor.
    """
    # NOTE(review): for large negative x, ``x + sqrt(x ** 2 + 1)`` cancels
    # and may round to 0 (giving -inf); the odd symmetry
    # asinh(x) == -asinh(-x) could be used to avoid this.
    return log(x + (x ** 2 + 1) ** 0.5)


def acosh(x):
    r"""
    Element-wise `inverse hyperbolic cosine`, computed as
    ``log(x + sqrt(x ** 2 - 1))`` (real-valued for ``x >= 1``).

    :param x: input tensor.
    :return: computed tensor.
    """
    return log(x + (x ** 2 - 1) ** 0.5)


def atanh(x):
    r"""
    Element-wise `inverse hyperbolic tangent`.

    Uses ``0.5 * log((1 + x) / (1 - x)) == log1p(2 * x / (1 - x)) / 2``.

    :param x: input tensor.
    :return: computed tensor.
    """
    return log1p(2 * x / (1 - x)) / 2


# bit-twiddling functions


def left_shift(x, y):
    """
    Element-wise `bitwise binary: x << y`.

    :param x: input tensor, should be int.
    :param y: how many bits to be left-shifted.
    :return: computed tensor.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(0, 6, dtype=np.int32).reshape(2, 3))
        out = F.left_shift(x, 2)
        print(out.numpy())

    Outputs:

    .. testoutput::

        [[ 0  4  8]
         [12 16 20]]

    """
    return _elwise(x, y, mode=Elemwise.Mode.SHL)


def right_shift(x, y):
    """
    Element-wise `bitwise binary: x >> y`.

    :param x: input tensor, should be int.
    :param y: how many bits to be right-shifted.
    :return: computed tensor.
    """
    return _elwise(x, y, mode=Elemwise.Mode.SHR)


# logical functions


def logical_and(x, y):
    """
    Element-wise `logical and: x && y`.

    :param x: input tensor.
    :param y: input tensor.
    :return: computed tensor.
    """
    return _elwise(x, y, mode=Elemwise.Mode.AND)


def logical_not(x):
    """
    Element-wise `logical not: ~x`.

    :param x: input tensor.
    :return: computed tensor.
    """
    return _elwise(x, mode=Elemwise.Mode.NOT)


def logical_or(x, y):
    """
    Element-wise `logical or: x || y`.

    :param x: input tensor.
    :param y: input tensor.
    :return: computed tensor.
    """
    return _elwise(x, y, mode=Elemwise.Mode.OR)


def logical_xor(x, y):
    """
    Element-wise `logical xor: x ^ y`.

    :param x: input tensor.
    :param y: input tensor.
    :return: computed tensor.
    """
    return _elwise(x, y, mode=Elemwise.Mode.XOR)


# comparison functions


def equal(x, y):
    """
    Element-wise `(x == y)`.

    :param x: input tensor 1.
    :param y: input tensor 2.
    :return: computed tensor.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
        y = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
        out = F.equal(x, y)
        print(out.numpy())

    Outputs:

    .. testoutput::

        [[1. 1. 1.]
         [1. 1. 1.]]

    """
    return _elwise(x, y, mode=Elemwise.Mode.EQ)


def not_equal(x, y):
    """
    Element-wise `(x != y)`.

    Unlike the other comparison functions in this module, this delegates to
    the operands' ``!=`` operator instead of ``_elwise``; when both operands
    are python scalars the result is therefore a python ``bool``.

    :param x: input tensor 1.
    :param y: input tensor 2.
    :return: computed tensor.
    """
    return x != y


def less(x, y):
    """
    Element-wise `(x < y)`.

    :param x: input tensor 1.
    :param y: input tensor 2.
    :return: computed tensor.
    """
    return _elwise(x, y, mode=Elemwise.Mode.LT)


def less_equal(x, y):
    """
    Element-wise `(x <= y)`.

    :param x: input tensor 1.
    :param y: input tensor 2.
    :return: computed tensor.
    """
    return _elwise(x, y, mode=Elemwise.Mode.LEQ)


def greater(x, y):
    """
    Element-wise `(x > y)`, computed as ``y < x``.

    :param x: input tensor 1.
    :param y: input tensor 2.
    :return: computed tensor.
    """
    return _elwise(y, x, mode=Elemwise.Mode.LT)


def greater_equal(x, y):
    """
    Element-wise `(x >= y)`, computed as ``y <= x``.

    :param x: input tensor 1.
    :param y: input tensor 2.
    :return: computed tensor.
    """
    return _elwise(y, x, mode=Elemwise.Mode.LEQ)


# other functions


def hswish(x):
    """
    Element-wise `x * relu6(x + 3) / 6`.

    :param x: input tensor.
    :return: computed tensor.

    Example:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(5).astype(np.float32))
        out = F.hswish(x)
        print(out.numpy().round(decimals=4))

    .. testoutput::

        [0.     0.6667 1.6667 3.     4.    ]

    """
    return _elwise(x, mode=Elemwise.Mode.H_SWISH)


def hsigmoid(x):
    """
    Element-wise `relu6(x + 3) / 6`.

    :param x: input tensor.
    :return: computed tensor.
    """
    return relu6(x + 3) / 6


def relu(x):
    """
    Element-wise `max(x, 0)`.

    :param x: input tensor.
    :return: computed tensor.
    """
    return _elwise(x, mode=Elemwise.Mode.RELU)


def relu6(x):
    """
    Element-wise `min(max(x, 0), 6)`.

    :param x: input tensor.
    :return: computed tensor.
    """
    return minimum(maximum(x, 0), 6)


def sigmoid(x):
    """
    Element-wise `1 / ( 1 + exp( -x ) )`.

    :param x: input tensor.
    :return: computed tensor.
    """
    return _elwise(x, mode=Elemwise.Mode.SIGMOID)


def clip(x: Tensor, lower=None, upper=None) -> Tensor:
    r"""
    Clamps all elements in input tensor into the range `[` :attr:`lower`, :attr:`upper` `]` and returns
    a resulting tensor:

    .. math::
        y_i = \begin{cases}
            \text{lower} & \text{if } x_i < \text{lower} \\
            x_i & \text{if } \text{lower} \leq x_i \leq \text{upper} \\
            \text{upper} & \text{if } x_i > \text{upper}
        \end{cases}

    :param x: input tensor.
    :param lower: lower-bound of the range to be clamped to.
    :param upper: upper-bound of the range to be clamped to.
    :return: output clamped tensor.
    :raises AssertionError: if both bounds are ``None``, or (outside tracing)
        if ``lower > upper``.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        a = tensor(np.arange(5).astype(np.int32))
        print(F.clip(a, 2, 4).numpy())
        print(F.clip(a, lower=3).numpy())
        print(F.clip(a, upper=3).numpy())

    Outputs:

    .. testoutput::

        [2 2 2 3 4]
        [3 3 3 3 4]
        [0 1 2 3 3]

    """
    assert (
        lower is not None or upper is not None
    ), "At least one of 'lower' or 'upper' must not be None"
    if lower is None:
        return minimum(x, upper)
    if upper is None:
        return maximum(x, lower)
    # The bound ordering check is skipped while tracing; presumably the
    # bounds may not be concrete values there.
    if not is_tracing():
        assert lower <= upper, "clip lower bound is bigger than upper bound"
    return minimum(maximum(x, lower), upper)