# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# pylint: disable=too-many-lines
from typing import Optional, Sequence, Tuple, Union

from ..core._imperative_rt.core2 import apply
from ..core._imperative_rt.graph import VarNode
from ..core._trace_option import use_symbolic_shape
from ..core.ops import builtin
from ..core.ops.builtin import BatchNorm, Elemwise
from ..core.ops.special import Const
from ..core.tensor import megbrain_graph, utils
from ..core.tensor.array_method import _elwise_apply
from ..core.tensor.utils import astensor1d, astype, setscalar
from ..device import get_default_device
from ..distributed import WORLD, is_distributed
from ..random import uniform
from ..tensor import Tensor
from ..utils.deprecation import deprecated_func
from ..utils.tuple_function import _pair, _pair_nonzero, _triple, _triple_nonzero
from .debug_param import get_execution_strategy
from .distributed import all_reduce_sum
from .elemwise import _elwise, exp, floor, log, log1p, maximum, minimum
from .math import argsort, matmul, max, prod, sum
from .tensor import (
    broadcast_to,
    concat,
    expand_dims,
    full,
    ones,
    reshape,
    squeeze,
    zeros,
)

__all__ = [
    "adaptive_avg_pool2d",
    "adaptive_max_pool2d",
    "avg_pool2d",
    "batch_norm",
    "conv1d",
    "conv2d",
    "conv3d",
    "conv_transpose2d",
    "conv_transpose3d",
    "deformable_conv2d",
    "deformable_psroi_pooling",
    "dropout",
    "embedding",
    "hsigmoid",
    "hswish",
    "indexing_one_hot",
    "leaky_relu",
    "linear",
    "local_conv2d",
    "logsigmoid",
    "logsumexp",
    "logsoftmax",
    "max_pool2d",
    "one_hot",
    "prelu",
    "relu",
    "relu6",
    "remap",
    "resize",
    "sigmoid",
    "softmax",
    "softplus",
    "sync_batch_norm",
    "warp_affine",
    "warp_perspective",
]


def expand_hw(x):
    # NOTE: >1d array is accepted, as long as 1 <= size <= 2
    try:
        x = int(x)
        return [x, x]
    except (TypeError, ValueError):
        pass
    h, w = x
    return int(h), int(w)


def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
    """
    Applies a linear transformation to the input tensor.

    Refer to :class:`~.module.linear.Linear` for more information.

    :param inp: input tensor with shape `(N, in_features)`.
    :param weight: weight with shape `(out_features, in_features)`.
    :param bias: bias with shape `(out_features,)`.
        Default: None
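
    A minimal usage sketch (the shapes below are illustrative only):

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        inp = tensor(np.ones((2, 3), dtype=np.float32))
        weight = tensor(np.ones((4, 3), dtype=np.float32))
        out = F.linear(inp, weight)
        # out has shape (2, 4); each row equals inp @ weight.T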
    """
    ret = matmul(inp, weight, transpose_b=True)
    if bias is not None:
        ret += bias
    return ret


def conv1d(
    inp: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    stride: int = 1,
    padding: int = 0,
    dilation: int = 1,
    groups: int = 1,
    conv_mode="cross_correlation",
    compute_mode="default",
) -> Tensor:
    """1D convolution operation.

    Refer to :class:`~.Conv1d` for more information.

    :param inp: The feature map of the convolution operation
    :param weight: The convolution kernel.
    :param bias: The bias added to the result of convolution (if given)
    :param stride: Stride of the 1D convolution operation. Default: 1
    :param padding: Size of the paddings added to the input on both sides of its
        spatial dimensions. Only zero-padding is supported. Default: 0
    :param dilation: Dilation of the 1D convolution operation. Default: 1
    :param groups: number of groups to divide input and output channels into,
        so as to perform a "grouped convolution". When ``groups`` is not 1,
        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
        and the shape of weight should be ``(groups, out_channel // groups,
        in_channels // groups, kernel_size)``. Default: 1
    :type conv_mode: string or :class:`mgb.opr_param_defs.Convolution.Mode`
    :param conv_mode: Supports 'cross_correlation'. Default:
        'cross_correlation'.
    :type compute_mode: string or
        :class:`mgb.opr_param_defs.Convolution.ComputeMode`
    :param compute_mode: When set to 'default', no special requirements will be
        placed on the precision of intermediate results. When set to 'float32',
        float32 would be used for accumulator and intermediate result, but only
        effective when input and output are of float16 dtype.
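
    A minimal usage sketch (the shapes below are illustrative only):

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        inp = tensor(np.ones((2, 3, 8), dtype=np.float32))     # (N, C_in, L)
        weight = tensor(np.ones((4, 3, 3), dtype=np.float32))  # (C_out, C_in, K)
        out = F.conv1d(inp, weight)
        # out has shape (2, 4, 6) with stride 1 and no padding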

    """
    assert (
        conv_mode.lower() == "cross_correlation"
        or conv_mode.name == "CROSS_CORRELATION"
    )
    assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
    assert inp.ndim == 3, "the input dimension of conv1d should be 3"
    assert weight.ndim == 3, "the weight dimension of conv1d should be 3"

    inp = expand_dims(inp, 3)
    weight = expand_dims(weight, 3)
    if bias is not None:
        assert bias.ndim == 3, "the bias dimension of conv1d should be 3"
        bias = expand_dims(bias, 3)

    stride_h = stride
    pad_h = padding
    dilate_h = dilation

    sparse_type = "dense" if groups == 1 else "group"
    op = builtin.Convolution(
        stride_h=stride_h,
        stride_w=1,
        pad_h=pad_h,
        pad_w=0,
        dilate_h=dilate_h,
        dilate_w=1,
        strategy=get_execution_strategy(),
        mode=conv_mode,
        compute_mode=compute_mode,
        sparse=sparse_type,
    )
    inp, weight = utils.convert_inputs(inp, weight)
    (output,) = apply(op, inp, weight)
    if bias is not None:
        output += bias
    output = squeeze(output, 3)
    return output


def conv2d(
    inp: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    stride: Union[int, Tuple[int, int]] = 1,
    padding: Union[int, Tuple[int, int]] = 0,
    dilation: Union[int, Tuple[int, int]] = 1,
    groups: int = 1,
    conv_mode="cross_correlation",
    compute_mode="default",
) -> Tensor:
    """
    2D convolution operation.

    Refer to :class:`~.module.Conv2d` for more information.

    :param inp: feature map of the convolution operation.
    :param weight: convolution kernel.
    :param bias: bias added to the result of convolution (if given).
    :param stride: stride of the 2D convolution operation. Default: 1
    :param padding: size of the paddings added to the input on both sides of its
        spatial dimensions. Only zero-padding is supported. Default: 0
    :param dilation: dilation of the 2D convolution operation. Default: 1
    :param groups: number of groups into which the input and output channels are divided, 
        so as to perform a ``grouped convolution``. When ``groups`` is not 1,
        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
        and the shape of weight should be ``(groups, out_channel // groups,
        in_channels // groups, height, width)``. Default: 1
    :type conv_mode: string or :class:`Convolution.Mode`
    :param conv_mode: supports "cross_correlation". Default:
        "cross_correlation"
    :type compute_mode: string or
        :class:`Convolution.ComputeMode`
    :param compute_mode: when set to "default", no special requirements will be
        placed on the precision of intermediate results. When set to "float32",
        "float32" would be used for accumulator and intermediate result, but only
        effective when input and output are of float16 dtype.
    :return: output tensor.
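
    A minimal usage sketch (the shapes below are illustrative only):

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        inp = tensor(np.ones((1, 3, 4, 4), dtype=np.float32))     # (N, C_in, H, W)
        weight = tensor(np.ones((2, 3, 3, 3), dtype=np.float32))  # (C_out, C_in, KH, KW)
        out = F.conv2d(inp, weight)
        # out has shape (1, 2, 2, 2) with stride 1 and no padding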
    """
    assert (
        conv_mode.lower() == "cross_correlation"
        or conv_mode.name == "CROSS_CORRELATION"
    )
    assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"

    stride_h, stride_w = expand_hw(stride)
    pad_h, pad_w = expand_hw(padding)
    dilate_h, dilate_w = expand_hw(dilation)

    sparse_type = "dense" if groups == 1 else "group"
    op = builtin.Convolution(
        stride_h=stride_h,
        stride_w=stride_w,
        pad_h=pad_h,
        pad_w=pad_w,
        dilate_h=dilate_h,
        dilate_w=dilate_w,
        strategy=get_execution_strategy(),
        mode=conv_mode,
        compute_mode=compute_mode,
        sparse=sparse_type,
    )
    inp, weight = utils.convert_inputs(inp, weight)
    (output,) = apply(op, inp, weight)
    if bias is not None:
        output += bias
    return output


def conv3d(
    inp: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    stride: Union[int, Tuple[int, int, int]] = 1,
    padding: Union[int, Tuple[int, int, int]] = 0,
    dilation: Union[int, Tuple[int, int, int]] = 1,
    groups: int = 1,
    conv_mode: str = "cross_correlation",
) -> Tensor:
    """
    3D convolution operation.

    Refer to :class:`~.Conv3d` for more information.

    :param inp: feature map of the convolution operation.
    :param weight: convolution kernel.
    :param bias: bias added to the result of convolution (if given).
    :param stride: stride of the 3D convolution operation. Default: 1
    :param padding: size of the paddings added to the input on both sides of its
        spatial dimensions. Only zero-padding is supported. Default: 0
    :param dilation: dilation of the 3D convolution operation. Default: 1
    :param groups: number of groups into which the input and output channels are divided,
        so as to perform a ``grouped convolution``. When ``groups`` is not 1,
        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
        and the shape of weight should be ``(groups, out_channel // groups,
        in_channels // groups, depth, height, width)``. Default: 1
    :param conv_mode: supports "cross_correlation". Default:
        "cross_correlation"
    :return: output tensor.
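
    A minimal usage sketch (the shapes below are illustrative only):

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        inp = tensor(np.ones((1, 3, 4, 4, 4), dtype=np.float32))     # (N, C_in, D, H, W)
        weight = tensor(np.ones((2, 3, 3, 3, 3), dtype=np.float32))  # (C_out, C_in, KD, KH, KW)
        out = F.conv3d(inp, weight)
        # out has shape (1, 2, 2, 2, 2) with stride 1 and no padding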
    """
    assert conv_mode.lower() == "cross_correlation"

    D, H, W = 0, 1, 2

    pad = _triple(padding)
    stride = _triple_nonzero(stride)
    dilate = _triple_nonzero(dilation)

    sparse_type = "dense" if groups == 1 else "group"
    op = builtin.Convolution3D(
        pad_d=pad[D],
        pad_h=pad[H],
        pad_w=pad[W],
        stride_d=stride[D],
        stride_h=stride[H],
        stride_w=stride[W],
        dilate_d=dilate[D],
        dilate_h=dilate[H],
        dilate_w=dilate[W],
        strategy=get_execution_strategy(),
        mode=conv_mode,
        sparse=sparse_type,
    )
    inp, weight = utils.convert_inputs(inp, weight)
    (output,) = apply(op, inp, weight)
    if bias is not None:
        output += bias
    return output


def conv_transpose2d(
    inp: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    stride: Union[int, Tuple[int, int]] = 1,
    padding: Union[int, Tuple[int, int]] = 0,
    dilation: Union[int, Tuple[int, int]] = 1,
    groups: int = 1,
    conv_mode="cross_correlation",
    compute_mode="default",
) -> Tensor:
    """
    2D transposed convolution operation.

    Refer to :class:`~.ConvTranspose2d` for more information.

    :param inp: feature map of the convolution operation.
    :param weight: convolution kernel.
    :param bias: bias added to the result of convolution (if given).
    :param stride: stride of the 2D convolution operation. Default: 1
    :param padding: size of the paddings added to the input on both sides of its
        spatial dimensions. Only zero-padding is supported. Default: 0
    :param dilation: dilation of the 2D convolution operation. Default: 1
    :param groups: number of groups into which the input and output channels are divided, 
        so as to perform a ``grouped convolution``. When ``groups`` is not 1,
        ``in_channels`` and ``out_channels`` must be divisible by groups,
        and the shape of weight should be ``(groups, in_channels // groups,
        out_channels // groups, height, width)``. Default: 1
    :type conv_mode: string or :class:`Convolution.Mode`
    :param conv_mode: supports "cross_correlation". Default:
        "cross_correlation"
    :type compute_mode: string or
        :class:`Convolution.ComputeMode`
    :param compute_mode: when set to "default", no special requirements will be
        placed on the precision of intermediate results. When set to "float32",
        "float32" would be used for accumulator and intermediate result, but only
        effective when input and output are of float16 dtype.
    :return: output tensor.
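
    A minimal usage sketch; the weight layout ``(in_channels, out_channels,
    height, width)`` is assumed to follow :class:`~.ConvTranspose2d`, and the
    shapes below are illustrative only:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        inp = tensor(np.ones((1, 3, 2, 2), dtype=np.float32))
        weight = tensor(np.ones((3, 4, 3, 3), dtype=np.float32))
        out = F.conv_transpose2d(inp, weight)
        # out has shape (1, 4, 4, 4): (2 - 1) * 1 + 3 = 4 per spatial dim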
    """
    assert (
        conv_mode.lower() == "cross_correlation"
        or conv_mode.name == "CROSS_CORRELATION"
    )
    assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"

    if groups != 1:
        raise NotImplementedError("group transposed conv2d is not supported yet.")

    stride_h, stride_w = expand_hw(stride)
    pad_h, pad_w = expand_hw(padding)
    dilate_h, dilate_w = expand_hw(dilation)

    op = builtin.ConvolutionBackwardData(
        stride_h=stride_h,
        stride_w=stride_w,
        pad_h=pad_h,
        pad_w=pad_w,
        dilate_h=dilate_h,
        dilate_w=dilate_w,
        strategy=get_execution_strategy(),
    )
    weight, inp = utils.convert_inputs(weight, inp)
    (output,) = apply(op, weight, inp)
    if bias is not None:
        output += bias
    return output


def deformable_conv2d(
    inp: Tensor,
    weight: Tensor,
    offset: Tensor,
    mask: Tensor,
    bias: Optional[Tensor] = None,
    stride: Union[int, Tuple[int, int]] = 1,
    padding: Union[int, Tuple[int, int]] = 0,
    dilation: Union[int, Tuple[int, int]] = 1,
    groups: int = 1,
    conv_mode="cross_correlation",
    compute_mode="default",
) -> Tensor:
    """
    Deformable Convolution.

    :param inp: input feature map.
    :param weight: convolution kernel.
    :param offset: input offset to kernel, channel of this tensor should match the deformable settings.
    :param mask: input mask to kernel, channel of this tensor should match the deformable settings.
    :param bias: bias added to the result of convolution (if given).
    :param stride: stride of the 2D convolution operation. Default: 1
    :param padding: size of the paddings added to the input on both sides of its
        spatial dimensions. Only zero-padding is supported. Default: 0
    :param dilation: dilation of the 2D convolution operation. Default: 1
    :param groups: number of groups into which the input and output channels are divided, 
        so as to perform a ``grouped convolution``. When ``groups`` is not 1,
        ``in_channels`` and ``out_channels`` must be divisible by groups,
        and the shape of weight should be ``(groups, out_channel // groups,
        in_channels // groups, height, width)``. Default: 1
    :type conv_mode: string or :class:`Convolution.Mode`
    :param conv_mode: supports "cross_correlation". Default:
        "cross_correlation"
    :type compute_mode: string or
        :class:`Convolution.ComputeMode`
    :param compute_mode: when set to "default", no special requirements will be
        placed on the precision of intermediate results. When set to "float32",
        "float32" would be used for accumulator and intermediate result, but only
        effective when input and output are of float16 dtype.
    :return: output tensor.
    """
    assert (
        conv_mode.lower() == "cross_correlation"
        or conv_mode.name == "CROSS_CORRELATION"
    )
    assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"

    stride_h, stride_w = expand_hw(stride)
    pad_h, pad_w = expand_hw(padding)
    dilate_h, dilate_w = expand_hw(dilation)

    sparse_type = "dense" if groups == 1 else "group"
    op = builtin.DeformableConv(
        stride_h=stride_h,
        stride_w=stride_w,
        pad_h=pad_h,
        pad_w=pad_w,
        dilate_h=dilate_h,
        dilate_w=dilate_w,
        strategy=get_execution_strategy(),
        mode=conv_mode,
        compute_mode=compute_mode,
        sparse=sparse_type,
    )
    inp, weight, offset, mask = utils.convert_inputs(inp, weight, offset, mask)
    (output,) = apply(op, inp, weight, offset, mask)
    if bias is not None:
        output += bias
    return output


def local_conv2d(
    inp: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    stride: Union[int, Tuple[int, int]] = 1,
    padding: Union[int, Tuple[int, int]] = 0,
    dilation: Union[int, Tuple[int, int]] = 1,
    conv_mode="cross_correlation",
):
    """Applies a spatial 2D convolution over a grouped, channeled image with untied kernels."""
    assert (
        conv_mode.lower() == "cross_correlation"
        or conv_mode.name == "CROSS_CORRELATION"
    )

    stride_h, stride_w = expand_hw(stride)
    pad_h, pad_w = expand_hw(padding)
    dilate_h, dilate_w = expand_hw(dilation)

    op = builtin.GroupLocal(
        stride_h=stride_h,
        stride_w=stride_w,
        pad_h=pad_h,
        pad_w=pad_w,
        dilate_h=dilate_h,
        dilate_w=dilate_w,
        mode=conv_mode,
        compute_mode="default",
        sparse="dense",
    )
    inp, weight = utils.convert_inputs(inp, weight)
    (output,) = apply(op, inp, weight)
    if bias is not None:
        output += bias
    return output


def conv_transpose3d(
    inp: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    stride: Union[int, Tuple[int, int, int]] = 1,
    padding: Union[int, Tuple[int, int, int]] = 0,
    dilation: Union[int, Tuple[int, int, int]] = 1,
) -> Tensor:
    """
    3D transposed convolution operation. Only supports the case where groups = 1
    and conv_mode = "cross_correlation".

    Refer to :class:`~.ConvTranspose3d` for more information.

    :param inp: feature map of the convolution operation.
    :param weight: convolution kernel.
        weight usually has shape ``(in_channels, out_channels, depth, height, width)``.
    :param bias: bias added to the result of convolution (if given).
    :param stride: stride of the 3D convolution operation. Default: 1
    :param padding: size of the paddings added to the input on all sides of its
        spatial dimensions. Only zero-padding is supported. Default: 0
    :param dilation: dilation of the 3D convolution operation. Default: 1
    :return: output tensor.
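
    A minimal usage sketch (the shapes below are illustrative only):

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        inp = tensor(np.ones((1, 3, 2, 2, 2), dtype=np.float32))
        weight = tensor(np.ones((3, 4, 3, 3, 3), dtype=np.float32))  # (C_in, C_out, KD, KH, KW)
        out = F.conv_transpose3d(inp, weight)
        # out has shape (1, 4, 4, 4, 4)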
    """
    D, H, W = 0, 1, 2

    pad = _triple(padding)
    stride = _triple_nonzero(stride)
    dilate = _triple_nonzero(dilation)

    op = builtin.Convolution3DBackwardData(
        pad_d=pad[D],
        pad_h=pad[H],
        pad_w=pad[W],
        stride_d=stride[D],
        stride_h=stride[H],
        stride_w=stride[W],
        dilate_d=dilate[D],
        dilate_h=dilate[H],
        dilate_w=dilate[W],
        strategy=get_execution_strategy(),
    )
    weight, inp = utils.convert_inputs(weight, inp)
    (output,) = apply(op, weight, inp)
    if bias is not None:
        output += bias
    return output


def max_pool2d(
    inp: Tensor,
    kernel_size: Union[int, Tuple[int, int]],
    stride: Optional[Union[int, Tuple[int, int]]] = None,
    padding: Union[int, Tuple[int, int]] = 0,
) -> Tensor:
    """
    Applies a 2D max pooling over an input tensor.

    Refer to :class:`~.MaxPool2d` for more information.

    :param inp: input tensor.
    :param kernel_size: size of the window.
    :param stride: stride of the window. If not provided, its value is set to kernel_size.
        Default: None
    :param padding: implicit zero padding added on both sides. Default: 0
    :return: output tensor.
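
    A minimal usage sketch (the shapes below are illustrative only):

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        inp = tensor(np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4))
        out = F.max_pool2d(inp, kernel_size=2)
        # out has shape (1, 1, 2, 2); stride defaults to kernel_size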
    """
    if stride is None:
        stride = kernel_size
    window_h, window_w = _pair_nonzero(kernel_size)
    stride_h, stride_w = _pair_nonzero(stride)
    padding_h, padding_w = _pair(padding)

    op = builtin.Pooling(
        window_h=window_h,
        window_w=window_w,
        stride_h=stride_h,
        stride_w=stride_w,
        pad_h=padding_h,
        pad_w=padding_w,
        mode="max",
    )
    (output,) = apply(op, inp)
    return output


def avg_pool2d(
    inp: Tensor,
    kernel_size: Union[int, Tuple[int, int]],
    stride: Optional[Union[int, Tuple[int, int]]] = None,
    padding: Union[int, Tuple[int, int]] = 0,
    mode: str = "average_count_exclude_padding",
) -> Tensor:
    """
    Applies 2D average pooling over an input tensor.

    Refer to :class:`~.AvgPool2d` for more information.

    :param inp: input tensor.
    :param kernel_size: size of the window.
    :param stride: stride of the window. If not provided, its value is set to ``kernel_size``.
        Default: None
    :param padding: implicit zero padding added on both sides. Default: 0
    :param mode: whether padding values are counted in the average; set to
        "average" to include them. Default: "average_count_exclude_padding"
    :return: output tensor.
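
    A minimal usage sketch (the shapes below are illustrative only):

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        inp = tensor(np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4))
        out = F.avg_pool2d(inp, kernel_size=2)
        # out has shape (1, 1, 2, 2); each value is the mean of a 2x2 window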
    """
    if stride is None:
        stride = kernel_size
    window_h, window_w = _pair_nonzero(kernel_size)
    stride_h, stride_w = _pair_nonzero(stride)
    padding_h, padding_w = _pair(padding)

    op = builtin.Pooling(
        window_h=window_h,
        window_w=window_w,
        stride_h=stride_h,
        stride_w=stride_w,
        pad_h=padding_h,
        pad_w=padding_w,
        mode=mode,
    )
    (output,) = apply(op, inp)
    return output


def adaptive_max_pool2d(
    inp: Tensor, oshp: Union[Tuple[int, int], int, Tensor],
) -> Tensor:
    """
    Applies a 2D max adaptive pooling over an input.

    Refer to :class:`~.MaxAdaptivePool2d` for more information.

    :param inp: input tensor.
    :param oshp: `(OH, OW)` size of the output shape.
    :return: output tensor.
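
    A minimal usage sketch (the shapes below are illustrative only):

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        inp = tensor(np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4))
        out = F.adaptive_max_pool2d(inp, (2, 2))
        # out has shape (1, 1, 2, 2)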
    """
    if isinstance(oshp, int):
        oshp = (oshp, oshp)

    op = builtin.AdaptivePooling(mode="max", format="NCHW",)
    oshp = astensor1d(oshp, inp, dtype="int32", device=inp.device)
    (output,) = apply(op, inp, oshp)
    return output


def adaptive_avg_pool2d(
    inp: Tensor, oshp: Union[Tuple[int, int], int, Tensor],
) -> Tensor:
    """
    Applies a 2D average adaptive pooling over an input.

    Refer to :class:`~.AvgAdaptivePool2d` for more information.

    :param inp: input tensor.
    :param oshp: `(OH, OW)` size of the output shape.
    :return: output tensor.
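
    A minimal usage sketch (the shapes below are illustrative only):

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        inp = tensor(np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4))
        out = F.adaptive_avg_pool2d(inp, (2, 2))
        # out has shape (1, 1, 2, 2)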
    """
    if isinstance(oshp, int):
        oshp = (oshp, oshp)

    op = builtin.AdaptivePooling(mode="average", format="NCHW",)
    oshp = astensor1d(oshp, inp, dtype="int32", device=inp.device)
    (output,) = apply(op, inp, oshp)
    return output


def deformable_psroi_pooling(
    inp: Tensor,
    rois: Tensor,
    trans: Tensor,
    no_trans: bool,
    part_size: int,
    pooled_h: int,
    pooled_w: int,
    sample_per_part: int,
    spatial_scale: float,
    trans_std: float = 0.1,
):
    """
    Deformable PSROI(Position Sensitive Region of Interest) Pooling.

    :param inp: input feature map.
    :param rois: the rois for feature pooling.
    :param trans: input offset to psroi_pooling.
    :param no_trans: selects the phase of DeformablePSROIPooling. False for the
        1st phase, True for the 2nd phase.
    :param part_size: part size.
    :param sample_per_part: sample points of each part.
    :param pooled_h: height of the pooled output.
    :param pooled_w: width of the pooled output.
    :param spatial_scale: the spatial_scale w.r.t input image.
    :param trans_std: multiplier used in 2nd phase.
    """
    op = builtin.DeformablePSROIPooling(
        no_trans=no_trans,
        part_size=part_size,
        pooled_h=pooled_h,
        pooled_w=pooled_w,
        sample_per_part=sample_per_part,
        spatial_scale=spatial_scale,
        trans_std=trans_std,
    )
    output, _ = apply(op, inp, rois, trans)
    return output


def hswish(x):
    """
    Element-wise `x * relu6(x + 3) / 6`.

    :param x: input tensor.
    :return: computed tensor.

    Example:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(5).astype(np.float32))
        out = F.hswish(x)
        print(out.numpy().round(decimals=4))

    .. testoutput::

        [0.     0.6667 1.6667 3.     4.    ]

    """
    return _elwise(x, mode=Elemwise.Mode.H_SWISH)


def sigmoid(x):
    """Element-wise `1 / ( 1 + exp( -x ) )`."""
    return _elwise(x, mode=Elemwise.Mode.SIGMOID)


def hsigmoid(x):
    """Element-wise `relu6(x + 3) / 6`."""
    return relu6(x + 3) / 6


def relu(x):
    """Element-wise `max(x, 0)`."""
    return _elwise(x, mode=Elemwise.Mode.RELU)


def relu6(x):
    """Element-wise `min(max(x, 0), 6)`."""
    return minimum(maximum(x, 0), 6)


def prelu(inp: Tensor, weight: Tensor) -> Tensor:
    r"""
    Applies the element-wise PReLU function.

    Refer to :class:`~.PReLU` for more information.
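
    A minimal usage sketch (values are illustrative only):

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.array([-2.0, 0.0, 3.0], dtype=np.float32))
        weight = tensor(np.array([0.25], dtype=np.float32))
        out = F.prelu(x, weight)
        # out is approximately [-0.5, 0.0, 3.0]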
    """
    return maximum(inp, 0) + weight * minimum(inp, 0)


def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor:
    r"""
    Applies the element-wise leaky_relu function.

    Refer to :class:`~.LeakyReLU` for more information.
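
    A minimal usage sketch (values are illustrative only):

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.array([-8.0, 0.0, 2.0], dtype=np.float32))
        out = F.leaky_relu(x, negative_slope=0.01)
        # out is approximately [-0.08, 0.0, 2.0]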
    """
    return maximum(inp, 0) + negative_slope * minimum(inp, 0)


def softplus(inp: Tensor) -> Tensor:
    r"""
    Applies the element-wise function:

    .. math::
        \text{softplus}(x) = \log(1 + \exp(x))

    softplus is a smooth approximation to the ReLU function and can be used
    to constrain the output to be always positive.
    For numerical stability the implementation follows this transformation:

    .. math::
        \text{softplus}(x) = \log(1 + \exp(x))
                           = \log(1 + \exp(-\text{abs}(x))) + \max(x, 0)
                           = \log1p(\exp(-\text{abs}(x))) + \text{relu}(x)

    :param inp: input tensor.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(-3, 3, dtype=np.float32))
        y = F.softplus(x)
        print(y.numpy().round(decimals=4))

    Outputs:

    .. testoutput::

        [0.0486 0.1269 0.3133 0.6931 1.3133 2.1269]

    """
    return log1p(exp(-abs(inp))) + relu(inp)


def logsoftmax(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
    r"""
    Applies the :math:`\log(\text{softmax}(x))` function to an n-dimensional
    input tensor. The :math:`\text{logsoftmax}(x)` formulation can be simplified as:

    .. math::
        \text{logsoftmax}(x_{i}) = \log(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} )

    For numerical stability the implementation follows this transformation:

    .. math::
        \text{logsoftmax}(x)
        = \log (\frac{\exp (x)}{\sum_{i}(\exp (x_{i}))})
        = x - \log (\sum_{i}(\exp (x_{i})))
        = x - \text{logsumexp}(x)

    :param inp: input tensor.
    :param axis: axis along which :math:`\text{logsoftmax}(x)` will be applied.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(-5, 5, dtype=np.float32)).reshape(2,5)
        y = F.logsoftmax(x, axis=1)
        print(y.numpy().round(decimals=4))

    Outputs:

    .. testoutput::

        [[-4.4519 -3.4519 -2.4519 -1.4519 -0.4519]
         [-4.4519 -3.4519 -2.4519 -1.4519 -0.4519]]

    """
    return inp - logsumexp(inp, axis, keepdims=True)


def logsigmoid(inp: Tensor) -> Tensor:
    r"""
    Applies the element-wise function:

    .. math::
        \text{logsigmoid}(x) = \log(\frac{ 1 }{ 1 + \exp(-x)})
        = \log(1/(1 + \exp(-x)))
        = - \log(1 + \exp(-x))
        = - \text{softplus}(-x)

    :param inp: input tensor.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(-5, 5, dtype=np.float32))
        y = F.logsigmoid(x)
        print(y.numpy().round(decimals=4))

    Outputs:

    .. testoutput::

        [-5.0067 -4.0182 -3.0486 -2.1269 -1.3133 -0.6931 -0.3133 -0.1269 -0.0486
         -0.0181]

    """
    return -softplus(-inp)


def logsumexp(
    inp: Tensor, axis: Union[int, Sequence[int]], keepdims: bool = False
) -> Tensor:
    r"""
    Calculates the logarithm of the inputs' exponential sum along the given :attr:`axis`.

    .. math::

        \text{logsumexp}(x)= \log \sum_{j=1}^{n} \exp \left(x_{j}\right)

    For numerical stability, the implementation follows this transformation:

    .. math::

        \text{logsumexp}(x)= \log \sum_{j=1}^{n} \exp \left(x_{j}\right)
        = b + \log \sum_{j=1}^{n} \exp \left(x_{j}-b\right)

    where

    .. math::
        b = \max(x_j)

    :param inp: input tensor.
    :param axis: axis over which the sum is taken. It can be a single axis or a list of axes.
    :param keepdims: whether to retain :attr:`axis` or not for the output tensor.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(-5, 5, dtype=np.float32)).reshape(2,5)
        y = F.logsumexp(x, axis=1, keepdims=False)
        print(y.numpy().round(decimals=4))

    Outputs:

    .. testoutput::

        [-0.5481  4.4519]

    """
    max_value = max(inp.detach(), axis, keepdims=True)
    if keepdims:
        return max_value + log(sum(exp(inp - max_value), axis, keepdims))
    else:
        return squeeze(max_value, axis=None) + log(
            sum(exp(inp - max_value), axis, keepdims)
        )


def _get_softmax_axis(ndim: int) -> int:
    if ndim in (0, 1, 3):
        return 0
    return 1


def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor:
    r"""
    Applies a :math:`\text{softmax}(x)` function. :math:`\text{softmax}(x)` is defined as:

    .. math::
            \text{softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

    It is applied to all elements along axis, and rescales elements so that
    they stay in the range `[0, 1]` and sum to 1.

    See :class:`~megengine.module.activation.Softmax` for more details.

    :param inp: input tensor.
    :param axis: an axis along which :math:`\text{softmax}(x)` will be applied. By default,
        :math:`\text{softmax}(x)` will apply along the highest ranked axis.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(-5, 5, dtype=np.float32)).reshape(2,5)
        out = F.softmax(x)
        print(out.numpy().round(decimals=4))

    Outputs:

    .. testoutput::

        [[0.0117 0.0317 0.0861 0.2341 0.6364]
         [0.0117 0.0317 0.0861 0.2341 0.6364]]

    """
    if axis is None:
        axis = _get_softmax_axis(len(inp.shape))
    offset = inp.max(axis=axis, keepdims=True).detach()
    cached = exp(inp - offset)
    down = sum(cached, axis=axis, keepdims=True)
    return cached / down


def batch_norm(
    inp: Tensor,
    running_mean: Tensor = None,
    running_var: Tensor = None,
    weight: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    *,
    training: bool = False,
    momentum: float = 0.9,
    eps: float = 1e-5,
    inplace: bool = True
):
    r"""
    Applies batch normalization to the input.

    Refer to :class:`~.BatchNorm2d` and :class:`~.BatchNorm1d` for more information.

    :param inp: input tensor.
    :param running_mean: tensor to store running mean.
    :param running_var: tensor to store running variance.
    :param weight: scaling tensor in the learnable affine parameters.
        See :math:`\gamma` in :class:`~.BatchNorm2d`.
    :param bias: bias tensor in the learnable affine parameters.
        See :math:`\beta` in :class:`~.BatchNorm2d`.
    :param training: a boolean value to indicate whether batch norm is performed
        in training mode. Default: False
    :param momentum: value used for the ``running_mean`` and ``running_var``
        computation.
        Default: 0.9
    :param eps: a value added to the denominator for numerical stability.
        Default: 1e-5
    :param inplace: whether to update ``running_mean`` and ``running_var`` inplace or return new tensors
        Default: True
    :return: output tensor.
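
    A minimal sketch of training-mode usage without tracked statistics (the
    shape below is illustrative only; batch statistics are computed on the fly):

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        inp = tensor(np.random.normal(size=(2, 3, 4, 4)).astype(np.float32))
        out = F.batch_norm(inp, training=True)
        # out has the same shape as inp; each channel is normalized with
        # the statistics of the current batch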
    """
    if inp.ndim != 4:
        raise NotImplementedError("batch_norm for ndim != 4")

    C = inp.shape[1]

    def make_full_if_none(x, value):
        if x is None:
            (x,) = Const(value, dtype=inp.dtype, device=inp.device)()
            shape = utils.astensor1d(
                (1, C, 1, 1), inp, dtype="int32", device=inp.device
            )
            (result,) = apply(builtin.Broadcast(), x, shape)
            return result
        elif x.ndim == 1:
            shape = utils.astensor1d(
                (1, C, 1, 1), inp, dtype="int32", device=inp.device
            )
            (result,) = apply(builtin.Reshape(), x, shape)
            return result
        return x

    has_mean = running_mean is not None
    has_var = running_var is not None

    if not training:
        assert has_mean, "running_mean must be provided in inference mode"
        assert has_var, "running_var must be provided in inference mode"

    if has_mean and running_mean.ndim != 4:
        raise ValueError
    if has_var and running_var.ndim != 4:
        raise ValueError

    inp, weight, bias, running_mean, running_var = utils.convert_inputs(
        inp, weight, bias, running_mean, running_var
    )

    weight = make_full_if_none(weight, 1)
    bias = make_full_if_none(bias, 0)

    if not training:
        op = builtin.BatchNorm(
            fwd_mode=BatchNorm.FwdMode.INFERENCE, epsilon=eps, param_dim="dim_1c11"
        )
        ret = apply(op, inp, weight, bias, running_mean, running_var)[-1]
        return ret

    else:
        op = builtin.BatchNorm(
            avg_factor=1 - momentum, epsilon=eps, param_dim="dim_1c11"
        )
        if has_mean or has_var:
            running_mean = make_full_if_none(running_mean, 0)
            running_var = make_full_if_none(running_var, 1)
            new_mean, new_var, _, _, inp = apply(
                op, inp, weight, bias, running_mean, running_var
            )
            if not has_mean:
                new_mean = None
            if not has_var:
                new_var = None

            if inplace:
                if has_mean:
                    running_mean[...] = new_mean
                if has_var:
                    running_var[...] = new_var

                return inp
            else:
                return inp, new_mean, new_var
        else:
            (_, _, inp,) = apply(op, inp, weight, bias)
            return inp


def sync_batch_norm(
    inp: Tensor,
    running_mean: Tensor,
    running_var: Tensor,
    weight: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    training: bool = False,
    momentum: Union[float, Tensor] = 0.9,
    eps: float = 1e-5,
    eps_mode="additive",
    group=WORLD,
) -> Tensor:
    r"""
    Applies synchronized batch normalization to the input.

    Refer to :class:`~.BatchNorm2d` and :class:`~.BatchNorm1d` for more information.

    :param inp: input tensor.
    :param running_mean: tensor to store running mean.
    :param running_var: tensor to store running variance.
    :param weight: scaling tensor in the learnable affine parameters.
        See :math:`\gamma` in :class:`~.BatchNorm2d`.
    :param bias: bias tensor in the learnable affine parameters.
        See :math:`\beta` in :class:`~.BatchNorm2d`.
    :param training: a boolean value to indicate whether batch norm is performed
        in training mode. Default: False
    :param momentum: value used for the ``running_mean`` and ``running_var``
        computation.
        Default: 0.9
    :param eps: a value added to the denominator for numerical stability.
        Default: 1e-5
    :return: output tensor.
    """
    assert eps_mode.lower() in {"max", "additive"}, "unknown eps_mode: {}".format(
        eps_mode
    )
    _channels = inp.shape[1]
    _ndim = inp.ndim
    _device = inp.device
    _dtype = inp.dtype
    _param_shape = (1, _channels) + (1,) * (_ndim - 2)
    _reduce_axis = [0] + [i for i in range(2, _ndim)]

    if training:

        def _sum_on_channel(inp):
            return inp.sum(axis=_reduce_axis, keepdims=True)

        reduce_size = inp.shape[0]
        for i in range(2, _ndim):
            reduce_size = reduce_size * inp.shape[i]
        channel_x1s = _sum_on_channel(inp)
        channel_x2s = _sum_on_channel(inp ** 2)

        if is_distributed():
            # reduce all nodes' data to calculate mean and variance
            reduce_size = broadcast_to(
                Tensor(reduce_size).astype(dtype=_dtype), [1] * _ndim
            )
            stat = concat([reduce_size, channel_x1s, channel_x2s], axis=1)
            stat = all_reduce_sum(stat, group)
            reduce_size = stat[:, :1].reshape(1)
            channel_x1s = stat[:, 1 : 1 + _channels]
            channel_x2s = stat[:, 1 + _channels :]

        channel_mean = channel_x1s / reduce_size
        channel_variance = (
            channel_x1s ** 2 / (-reduce_size * reduce_size) + channel_x2s / reduce_size
        )
    else:
        assert running_var is not None and running_mean is not None
        channel_variance = running_var.reshape(*_param_shape)
        channel_mean = running_mean.reshape(*_param_shape)

    invsqrt_channel_variance = (
        maximum(channel_variance, eps) if eps_mode == "max" else channel_variance + eps
    ) ** -0.5

    if weight is not None:
        weight = weight.reshape(*_param_shape)
    if bias is not None:
        bias = bias.reshape(*_param_shape)

    # outvar = output * weight + bias
    # where output = inp * invsqrt_channel_variance + (
    #    -channel_mean * invsqrt_channel_variance
    # )
    # Manually expand output for gopt

    if weight is not None:
        inv_var_wt = invsqrt_channel_variance * weight
        neg_channel_mean = -channel_mean
        if bias is not None:
            outvar = inp * inv_var_wt + (neg_channel_mean * inv_var_wt + bias)
        else:
            outvar = inp * inv_var_wt + neg_channel_mean * inv_var_wt
    else:
        outvar = inp * invsqrt_channel_variance + (
            -channel_mean * invsqrt_channel_variance
        )
        if bias is not None:
            outvar = outvar + bias

    if training and running_var is not None and running_mean is not None:
        running_mean *= momentum
        running_mean += (1 - momentum) * channel_mean
        channel_variance_unbiased = channel_x1s ** 2 / (
            -reduce_size * (reduce_size - 1)
        ) + channel_x2s / (reduce_size - 1)
        running_var *= momentum
        running_var += (1 - momentum) * channel_variance_unbiased

    return outvar


def dropout(inp: Tensor, drop_prob: float, training: bool = True) -> Tensor:
    """
    Returns a new tensor where each of the elements is randomly set to zero
    with probability P = ``drop_prob``. The output tensor is rescaled if ``training`` is True.

    :param inp: input tensor.
    :param drop_prob: probability to drop (set to zero) a single element.
    :param training: the default behavior of ``dropout`` during training is to rescale the output,
        so that it can be replaced by an :class:`~.Identity` during inference. Default: True
    :return: the output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.ones(10, dtype=np.float32))
        out = F.dropout(x, 1./3.)
        print(out.numpy())

    Outputs:

    .. testoutput::
        :options: +SKIP

        [1.5 1.5 0.  1.5 1.5 1.5 1.5 1.5 1.5 1.5]

    """
    assert 0 <= drop_prob < 1
    rv = uniform(size=inp.shape)
    mask = rv > drop_prob
    inp *= mask.astype(inp.dtype)
    if training:
        inp *= 1 / (1 - drop_prob)
    return inp


def one_hot(inp: Tensor, num_classes: int) -> Tensor:
    r"""
    Performs one-hot encoding for the input tensor.

    :param inp: input tensor.
    :param num_classes: number of classes; it determines the last dimension of the output tensor.
    :return: output tensor.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(1, 4, dtype=np.int32))
        out = F.one_hot(x, num_classes=4)
        print(out.numpy())

    Outputs:

    .. testoutput::

        [[0 1 0 0]
         [0 0 1 0]
         [0 0 0 1]]

    """
    zeros_tensor = zeros(list(inp.shape) + [num_classes], inp.dtype, inp.device)
    ones_tensor = ones(list(inp.shape) + [1], inp.dtype, inp.device)

    op = builtin.IndexingSetOneHot(axis=inp.ndim)
    (result,) = apply(op, zeros_tensor, inp, ones_tensor)
    return result


def embedding(
    inp: Tensor,
    weight: Tensor,
    padding_idx: Optional[int] = None,
    max_norm: Optional[float] = None,
    norm_type: Optional[float] = None,
):
    """
    Applies lookup table for embedding.

    :param inp: tensor with indices.
    :param weight: learnable weight table to embed from.
    :param padding_idx: should be set to None, not supported now.
    :param max_norm: should be set to None, not supported now.
    :param norm_type: should be set to None, not supported now.
    :return: output tensor.

    Refer to :class:`~.Embedding` for more information.
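
    A minimal usage sketch (the shapes below are illustrative only):

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        indices = tensor(np.array([0, 2], dtype=np.int32))
        weight = tensor(np.arange(12, dtype=np.float32).reshape(4, 3))
        out = F.embedding(indices, weight)
        # out has shape (2, 3): rows 0 and 2 of the weight table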
    """
    if padding_idx is not None:
        raise ValueError("padding_idx is not supported now!")
    if max_norm is not None or norm_type is not None:
        raise ValueError("weight normalization is not supported now!")

    dest_shp = list(inp.shape) + [weight.shape[-1]]
    return weight[inp.reshape(-1)].reshape(dest_shp)


def indexing_one_hot(
    src: Tensor, index: Tensor, axis: int = 1, keepdims=False
) -> Tensor:
    r"""
    One-hot indexing for some axes.

    :param src: input tensor.
    :param index: index tensor.
    :param axis: axis of ``src`` along which ``index`` selects values. Default: 1
    :param keepdims: whether not to remove the axis in result. Default: False
    :return: output tensor.

    Examples:

    .. testcode::

        import megengine.functional as F
        from megengine import tensor

        src = tensor([[1.0, 2.0]])
        index = tensor([0])
        val = F.indexing_one_hot(src, index)
        print(val.numpy())

    Outputs:

    .. testoutput::

        [1.]

    """
    assert isinstance(src, Tensor), "src must be of Tensor type"
    op = builtin.IndexingOneHot(axis=axis)
    index = utils.convert_single_value(index, dtype="int32", device=src.device)
    (result,) = apply(op, src, index)
    if not keepdims:
        result = squeeze(result, axis)
    return result


interpolate = deprecated_func("1.3", "megengine.functional.vision", "interpolate", True)
roi_pooling = deprecated_func("1.3", "megengine.functional.vision", "roi_pooling", True)
roi_align = deprecated_func("1.3", "megengine.functional.vision", "roi_align", True)
nms = deprecated_func("1.3", "megengine.functional.vision", "nms", True)
resize = deprecated_func("1.3", "megengine.functional.vision", "resize", True)
remap = deprecated_func("1.3", "megengine.functional.vision", "remap", True)
nvof = deprecated_func("1.3", "megengine.functional.vision", "nvof", True)
warp_affine = deprecated_func("1.3", "megengine.functional.vision", "warp_affine", True)
warp_perspective = deprecated_func(
    "1.3", "megengine.functional.vision", "warp_perspective", True
)

from .loss import *  # isort:skip
from .quantized import conv_bias_activation  # isort:skip