diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py
index 47f13d214539ca3d7dc68ba61c324309397b6a92..06a78e73aed48c9570d97248c33f9b70a0e0797b 100644
--- a/imperative/python/megengine/functional/nn.py
+++ b/imperative/python/megengine/functional/nn.py
@@ -13,7 +13,7 @@ from ..core._imperative_rt import CompNode
 from ..core.ops import builtin
 from ..core.ops._internal import param_defs as P
 from ..core.ops.special import Const
-from ..core.tensor import utils
+from ..core.tensor import megbrain_graph, utils
 from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
 from ..core.tensor.utils import astensor1d
 from ..distributed import WORLD, is_distributed
@@ -27,6 +27,8 @@ from .tensor import add_axis, broadcast, concat, full, ones, remove_axis, reshap
 from .types import _pair, _pair_nonzero
 
 __all__ = [
+    "adaptive_avg_pool2d",
+    "adaptive_max_pool2d",
     "avg_pool2d",
     "batched_nms",
     "batch_norm2d",
@@ -324,6 +326,48 @@ def avg_pool2d(
     return output
 
 
+def adaptive_max_pool2d(
+    inp: Tensor, oshp: Union[Tuple[int, int], int, Tensor],
+) -> Tensor:
+    """Applies 2D adaptive max pooling over an input.
+
+    Refer to :class:`~.AdaptiveMaxPool2d` for more information.
+
+    :param inp: input tensor.
+    :param oshp: the target output shape ``(OH, OW)``; an int is expanded to ``(oshp, oshp)``.
+    :return: output tensor.
+    """
+    assert isinstance(inp, (Tensor, megbrain_graph.VarNode)), "inp must be Tensor type"
+    if isinstance(oshp, int):
+        oshp = (oshp, oshp)
+
+    op = builtin.AdaptivePooling(mode="MAX", format="NCHW",)
+    oshp = astensor1d(oshp, inp, dtype="int32", device=inp.device)
+    (output,) = apply(op, inp, oshp)
+    return output
+
+
+def adaptive_avg_pool2d(
+    inp: Tensor, oshp: Union[Tuple[int, int], int, Tensor],
+) -> Tensor:
+    """Applies 2D adaptive average pooling over an input.
+
+    Refer to :class:`~.AdaptiveAvgPool2d` for more information.
+
+    :param inp: input tensor.
+    :param oshp: the target output shape ``(OH, OW)``; an int is expanded to ``(oshp, oshp)``.
+    :return: output tensor.
+    """
+    assert isinstance(inp, (Tensor, megbrain_graph.VarNode)), "inp must be Tensor type"
+    if isinstance(oshp, int):
+        oshp = (oshp, oshp)
+
+    op = builtin.AdaptivePooling(mode="AVERAGE", format="NCHW",)
+    oshp = astensor1d(oshp, inp, dtype="int32", device=inp.device)
+    (output,) = apply(op, inp, oshp)
+    return output
+
+
 def prelu(inp: Tensor, weight: Tensor) -> Tensor:
     r"""
     Applies the element-wise PReLU function.
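A minimal usage sketch of the two functions added above (illustrative, not part of the patch; the values match the unit tests further down):

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    inp = mge.tensor(np.arange(16, dtype="float32").reshape(1, 1, 4, 4))

    # an int target is expanded to a square (oshp, oshp) output
    out_max = F.adaptive_max_pool2d(inp, 2)       # shape (1, 1, 2, 2)
    out_avg = F.adaptive_avg_pool2d(inp, (2, 2))  # shape (1, 1, 2, 2)

    print(out_max.numpy())  # [[[[ 5.  7.] [13. 15.]]]]
    print(out_avg.numpy())  # [[[[ 2.5  4.5] [10.5 12.5]]]]
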
diff --git a/imperative/python/megengine/module/__init__.py b/imperative/python/megengine/module/__init__.py
index 916000d08cbcacba0176cf28ac9b4a48072f757f..6c5b48fd8f463f5ba036c247f843f809f9c4a546 100644
--- a/imperative/python/megengine/module/__init__.py
+++ b/imperative/python/megengine/module/__init__.py
@@ -8,6 +8,7 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax
+from .adaptive_pooling import AdaptiveAvgPool2d, AdaptiveMaxPool2d
 from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm
 from .concat import Concat
 from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, LocalConv2d
diff --git a/imperative/python/megengine/module/adaptive_pooling.py b/imperative/python/megengine/module/adaptive_pooling.py
new file mode 100644
index 0000000000000000000000000000000000000000..99e7c57d272fdfb231dca5fc3a5f45100b57d83a
--- /dev/null
+++ b/imperative/python/megengine/module/adaptive_pooling.py
@@ -0,0 +1,114 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from abc import abstractmethod
+from typing import Tuple, Union
+
+from ..functional import adaptive_avg_pool2d, adaptive_max_pool2d
+from ..tensor import Parameter, Tensor
+from .module import Module
+
+
+class _AdaptivePoolNd(Module):
+    def __init__(
+        self, oshp: Union[Tuple[int, int], int, Tensor],
+    ):
+        super(_AdaptivePoolNd, self).__init__()
+        self.oshp = oshp
+
+    @abstractmethod
+    def forward(self, inp):
+        pass
+
+
+class AdaptiveMaxPool2d(_AdaptivePoolNd):
+    r"""Applies 2D adaptive max pooling over an input.
+
+    For instance, given an input of the size :math:`(N, C, H, W)` and
+    an output shape :math:`(OH, OW)`, this layer generates the output of
+    the size :math:`(N, C, OH, OW)` through a process described as:
+
+    .. math::
+        \begin{aligned}
+            out(N_i, C_j, h, w) ={} & \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1}
+                \text{input}(N_i, C_j, \text{stride[0]} \times h + m,
+                \text{stride[1]} \times w + n)
+        \end{aligned}
+
+    ``kernel_size`` and ``stride`` are inferred from the input and output shapes:
+        padding: (0, 0)
+        stride: (floor(IH / OH), floor(IW / OW))
+        kernel_size: (IH - (OH - 1) * stride_h, IW - (OW - 1) * stride_w)
+
+    Examples:
+
+    .. testcode::
+
+        import numpy as np
+        import megengine as mge
+        import megengine.module as M
+
+        m = M.AdaptiveMaxPool2d((2, 2))
+        inp = mge.tensor(np.arange(0, 16).astype("float32").reshape(1, 1, 4, 4))
+        oup = m(inp)
+        print(oup.numpy())
+
+    Outputs:
+
+    .. testoutput::
+
+        [[[[ 5.  7.]
+           [13. 15.]]]]
+
+    """
+
+    def forward(self, inp):
+        return adaptive_max_pool2d(inp, self.oshp)
+
+
+class AdaptiveAvgPool2d(_AdaptivePoolNd):
+    r"""Applies 2D adaptive average pooling over an input.
+
+    For instance, given an input of the size :math:`(N, C, H, W)` and
+    an output shape :math:`(OH, OW)`, this layer generates the output of
+    the size :math:`(N, C, OH, OW)` through a process described as:
+
+    .. math::
+
+        out(N_i, C_j, h, w) = \frac{1}{kH * kW} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1}
+            input(N_i, C_j, stride[0] \times h + m, stride[1] \times w + n)
+
+    ``kernel_size`` and ``stride`` are inferred from the input and output shapes:
+        padding: (0, 0)
+        stride: (floor(IH / OH), floor(IW / OW))
+        kernel_size: (IH - (OH - 1) * stride_h, IW - (OW - 1) * stride_w)
+
+    Examples:
+
+    .. testcode::
+
+        import numpy as np
+        import megengine as mge
+        import megengine.module as M
+
+        m = M.AdaptiveAvgPool2d((2, 2))
+        inp = mge.tensor(np.arange(0, 16).astype("float32").reshape(1, 1, 4, 4))
+        oup = m(inp)
+        print(oup.numpy())
+
+    Outputs:
+
+    .. testoutput::
+
+        [[[[ 2.5  4.5]
+           [10.5 12.5]]]]
+
+    """
+
+    def forward(self, inp):
+        return adaptive_avg_pool2d(inp, self.oshp)
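To make the inference rule in the docstrings concrete, here is a plain-NumPy reference (illustrative only; ref_adaptive_pool2d is a hypothetical helper, not part of MegEngine) that reproduces both testoutputs:

    import numpy as np

    def ref_adaptive_pool2d(x, oshp, reduce=np.max):
        """Pool x of shape (N, C, IH, IW) down to (N, C, OH, OW) using the inferred window."""
        n, c, ih, iw = x.shape
        oh, ow = oshp
        sh, sw = ih // oh, iw // ow                      # stride = floor(I / O)
        kh, kw = ih - (oh - 1) * sh, iw - (ow - 1) * sw  # kernel_size
        out = np.empty((n, c, oh, ow), dtype=x.dtype)
        for i in range(oh):
            for j in range(ow):
                win = x[:, :, i * sh : i * sh + kh, j * sw : j * sw + kw]
                out[:, :, i, j] = reduce(win, axis=(2, 3))
        return out

    x = np.arange(16, dtype="float32").reshape(1, 1, 4, 4)
    print(ref_adaptive_pool2d(x, (2, 2), np.max))   # [[[[ 5.  7.] [13. 15.]]]]
    print(ref_adaptive_pool2d(x, (2, 2), np.mean))  # [[[[ 2.5  4.5] [10.5 12.5]]]]
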
diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py
index 22a5d7142963e4455f18082b2ae9b9a1aae7cbd2..f3187ec4d0bf9c3bc36d7a24fa48727400ee550d 100644
--- a/imperative/python/test/unit/functional/test_functional.py
+++ b/imperative/python/test/unit/functional/test_functional.py
@@ -206,6 +206,66 @@ def test_roi_pooling():
     assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape)
 
 
+def test_adaptive_avg_pool2d():
+    inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4))
+    oshp = (2, 2)
+    grad = Grad().wrt(inp, callback=_save_to(inp))
+    outp = F.adaptive_avg_pool2d(inp, oshp,)
+    assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,)
+    np.testing.assert_equal(
+        outp.numpy(), np.array([[[[2.5, 4.5], [10.5, 12.5]]]], dtype=np.float32)
+    )
+
+    grad(outp, tensor(F.ones_like(outp)))
+    assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape)
+    np.testing.assert_equal(
+        inp.grad.numpy(),
+        np.array(
+            [
+                [
+                    [
+                        [0.25, 0.25, 0.25, 0.25],
+                        [0.25, 0.25, 0.25, 0.25],
+                        [0.25, 0.25, 0.25, 0.25],
+                        [0.25, 0.25, 0.25, 0.25],
+                    ]
+                ]
+            ],
+            dtype=np.float32,
+        ),
+    )
+
+
+def test_adaptive_max_pool2d():
+    inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4))
+    oshp = (2, 2)
+    grad = Grad().wrt(inp, callback=_save_to(inp))
+    outp = F.adaptive_max_pool2d(inp, oshp,)
+    assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,)
+    np.testing.assert_equal(
+        outp.numpy(), np.array([[[[5, 7], [13, 15]]]], dtype=np.float32)
+    )
+
+    grad(outp, tensor(F.ones_like(outp)))
+    assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape)
+    np.testing.assert_equal(
+        inp.grad.numpy(),
+        np.array(
+            [
+                [
+                    [
+                        [0.0, 0.0, 0.0, 0.0],
+                        [0.0, 1.0, 0.0, 1.0],
+                        [0.0, 0.0, 0.0, 0.0],
+                        [0.0, 1.0, 0.0, 1.0],
+                    ]
+                ]
+            ],
+            dtype=np.float32,
+        ),
+    )
+
+
 def test_one_hot():
     def onehot_low_dimension():
         inp = tensor(np.arange(1, 4, dtype=np.int32))
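The expected gradients in these tests encode the inferred non-overlapping 2x2 windows: AVERAGE spreads 1 / (kH * kW) = 0.25 uniformly over each window, while MAX routes the whole incoming gradient to each window's argmax (here the odd rows and columns, since values increase left-to-right, top-to-bottom). As a further sanity check (an assumed workflow, not part of the patch), adaptive pooling should agree with ordinary pooling whenever the output shape evenly divides the input shape, because the inferred windows then tile the input exactly:

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    inp = mge.tensor(np.random.rand(2, 3, 8, 8).astype("float32"))
    oh = ow = 4
    stride = 8 // oh                  # inferred stride: 2
    kernel = 8 - (oh - 1) * stride    # inferred kernel_size: 2

    adaptive = F.adaptive_max_pool2d(inp, (oh, ow))
    fixed = F.max_pool2d(inp, kernel_size=kernel, stride=stride, padding=0)
    np.testing.assert_allclose(adaptive.numpy(), fixed.numpy())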