提交 2bd84d67 编写于 作者: M Megvii Engine Team

feat(mge): add adaptive pooling python wrapper

GitOrigin-RevId: 789f1511ec76e41bfb7cd8e6430da527af288570
上级 edb32495
......@@ -13,7 +13,7 @@ from ..core._imperative_rt import CompNode
from ..core.ops import builtin
from ..core.ops._internal import param_defs as P
from ..core.ops.special import Const
from ..core.tensor import utils
from ..core.tensor import megbrain_graph, utils
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
from ..core.tensor.utils import astensor1d
from ..distributed import WORLD, is_distributed
......@@ -27,6 +27,8 @@ from .tensor import add_axis, broadcast, concat, full, ones, remove_axis, reshap
from .types import _pair, _pair_nonzero
__all__ = [
"adaptive_avg_pool2d",
"adaptive_max_pool2d",
"avg_pool2d",
"batched_nms",
"batch_norm2d",
......@@ -324,6 +326,48 @@ def avg_pool2d(
return output
def adaptive_max_pool2d(
inp: Tensor, oshp: Union[Tuple[int, int], int, Tensor],
) -> Tensor:
"""Applies a 2D max adaptive pooling over an input.
Refer to :class:`~.MaxAdaptivePool2d` for more information.
:param inp: The input tensor.
:param oshp: (OH, OW) size of the output shape.
:return: output tensor.
"""
assert isinstance(inp, (Tensor, megbrain_graph.VarNode)), "inp must be Tensor type"
if isinstance(oshp, int):
oshp = (oshp, oshp)
op = builtin.AdaptivePooling(mode="MAX", format="NCHW",)
oshp = astensor1d(oshp, inp, dtype="int32", device=inp.device)
(output,) = apply(op, inp, oshp)
return output
def adaptive_avg_pool2d(
inp: Tensor, oshp: Union[Tuple[int, int], int, Tensor],
) -> Tensor:
"""Applies a 2D average adaptive pooling over an input.
Refer to :class:`~.AvgAdaptivePool2d` for more information.
:param inp: The input tensor.
:param oshp: (OH, OW) size of the output shape.
:return: output tensor.
"""
assert isinstance(inp, (Tensor, megbrain_graph.VarNode)), "inp must be Tensor type"
if isinstance(oshp, int):
oshp = (oshp, oshp)
op = builtin.AdaptivePooling(mode="AVERAGE", format="NCHW",)
oshp = astensor1d(oshp, inp, dtype="int32", device=inp.device)
(output,) = apply(op, inp, oshp)
return output
def prelu(inp: Tensor, weight: Tensor) -> Tensor:
r"""
Applies the element-wise PReLU function.
......
......@@ -8,6 +8,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax
from .adaptive_pooling import AdaptiveAvgPool2d, AdaptiveMaxPool2d
from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm
from .concat import Concat
from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, LocalConv2d
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import abstractmethod
from typing import Tuple, Union
from ..functional import adaptive_avg_pool2d, adaptive_max_pool2d
from ..tensor import Parameter, Tensor
from .module import Module
class _AdaptivePoolNd(Module):
def __init__(
self, oshp: Union[Tuple[int, int], int, Tensor],
):
super(_AdaptivePoolNd, self).__init__()
self.oshp = oshp
@abstractmethod
def forward(self, inp):
pass
class AdaptiveMaxPool2d(_AdaptivePoolNd):
r"""Applies a 2D max adaptive pooling over an input.
For instance, given an input of the size :math:`(N, C, H, W)` and
an output shape :math:`(OH, OW)`, this layer generates the output of
the size :math:`(N, C, OH, OW)` through a process described as:
.. math::
\begin{aligned}
out(N_i, C_j, h, w) ={} & \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1}
\text{input}(N_i, C_j, \text{stride[0]} \times h + m,
\text{stride[1]} \times w + n)
\end{aligned}
Kernel_size and stride can be inferred from input shape and out shape:
padding: (0, 0)
stride: (floor(IH / OH), floor(IW / OW))
kernel_size: (IH - (OH - 1) * stride_h, IW - (OW - 1) * stride_w)
Examples:
.. testcode::
import numpy as np
import megengine as mge
import megengine.module as M
m = M.AdaptiveMaxPool2d((2, 2))
inp = mge.tensor(np.arange(0, 16).astype("float32").reshape(1, 1, 4, 4))
oup = m(inp)
print(oup.numpy())
Outputs:
.. testoutput::
[[[[5. 7.]
[13. 15.]]]]
"""
def forward(self, inp):
return adaptive_max_pool2d(inp, self.oshp)
class AdaptiveAvgPool2d(_AdaptivePoolNd):
r"""Applies a 2D average pooling over an input.
For instance, given an input of the size :math:`(N, C, H, W)` and
an output shape :math:`(OH, OW)`, this layer generates the output of
the size :math:`(N, C, OH, OW)` through a process described as:
.. math::
out(N_i, C_j, h, w) = \frac{1}{kH * kW} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1}
input(N_i, C_j, stride[0] \times h + m, stride[1] \times w + n)
Kernel_size and stride can be inferred from input shape and out shape:
padding: (0, 0)
stride: (floor(IH / OH), floor(IW / OW))
kernel_size: (IH - (OH - 1) * stride_h, IW - (OW - 1) * stride_w)
Examples:
.. testcode::
import numpy as np
import megengine as mge
import megengine.module as M
m = M.AdaptiveAvgPool2d((2, 2))
inp = mge.tensor(np.arange(0, 16).astype("float32").reshape(1, 1, 4, 4))
oup = m(inp)
print(oup.numpy())
Outputs:
.. testoutput::
[[[[2.5 4.5]
[10.5 12.5]]]]
"""
def forward(self, inp):
return adaptive_avg_pool2d(inp, self.oshp)
......@@ -206,6 +206,66 @@ def test_roi_pooling():
assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape)
def test_adaptive_avg_pool2d():
inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4))
oshp = (2, 2)
grad = Grad().wrt(inp, callback=_save_to(inp))
outp = F.adaptive_avg_pool2d(inp, oshp,)
assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,)
np.testing.assert_equal(
outp.numpy(), np.array([[[[2.5, 4.5], [10.5, 12.5]]]], dtype=np.float32)
)
grad(outp, tensor(F.ones_like(outp)))
assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape)
np.testing.assert_equal(
inp.grad.numpy(),
np.array(
[
[
[
[0.25, 0.25, 0.25, 0.25],
[0.25, 0.25, 0.25, 0.25],
[0.25, 0.25, 0.25, 0.25],
[0.25, 0.25, 0.25, 0.25],
]
]
],
dtype=np.float32,
),
)
def test_adaptive_max_pool2d():
inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4))
oshp = (2, 2)
grad = Grad().wrt(inp, callback=_save_to(inp))
outp = F.adaptive_max_pool2d(inp, oshp,)
assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,)
np.testing.assert_equal(
outp.numpy(), np.array([[[[5, 7], [13, 15]]]], dtype=np.float32)
)
grad(outp, tensor(F.ones_like(outp)))
assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape)
np.testing.assert_equal(
inp.grad.numpy(),
np.array(
[
[
[
[0.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 1.0],
[0.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 1.0],
]
]
],
dtype=np.float32,
),
)
def test_one_hot():
def onehot_low_dimension():
inp = tensor(np.arange(1, 4, dtype=np.int32))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册