Commit 2bd84d67 authored by Megvii Engine Team

feat(mge): add adaptive pooling python wrapper

GitOrigin-RevId: 789f1511ec76e41bfb7cd8e6430da527af288570
Parent: edb32495
@@ -13,7 +13,7 @@ from ..core._imperative_rt import CompNode
from ..core.ops import builtin
from ..core.ops._internal import param_defs as P
from ..core.ops.special import Const
-from ..core.tensor import utils
+from ..core.tensor import megbrain_graph, utils
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
from ..core.tensor.utils import astensor1d
from ..distributed import WORLD, is_distributed
@@ -27,6 +27,8 @@ from .tensor import add_axis, broadcast, concat, full, ones, remove_axis, reshap
from .types import _pair, _pair_nonzero

__all__ = [
+    "adaptive_avg_pool2d",
+    "adaptive_max_pool2d",
    "avg_pool2d",
    "batched_nms",
    "batch_norm2d",
@@ -324,6 +326,48 @@ def avg_pool2d(
    return output

def adaptive_max_pool2d(
    inp: Tensor, oshp: Union[Tuple[int, int], int, Tensor],
) -> Tensor:
    """Applies a 2D max adaptive pooling over an input.

    Refer to :class:`~.AdaptiveMaxPool2d` for more information.

    :param inp: the input tensor.
    :param oshp: `(OH, OW)`, the target output shape; a single int is expanded
        to a square `(oshp, oshp)`.
    :return: output tensor.
    """
    assert isinstance(inp, (Tensor, megbrain_graph.VarNode)), "inp must be Tensor type"
    if isinstance(oshp, int):
        oshp = (oshp, oshp)

    op = builtin.AdaptivePooling(mode="MAX", format="NCHW",)
    oshp = astensor1d(oshp, inp, dtype="int32", device=inp.device)
    (output,) = apply(op, inp, oshp)
    return output

def adaptive_avg_pool2d(
    inp: Tensor, oshp: Union[Tuple[int, int], int, Tensor],
) -> Tensor:
    """Applies a 2D average adaptive pooling over an input.

    Refer to :class:`~.AdaptiveAvgPool2d` for more information.

    :param inp: the input tensor.
    :param oshp: `(OH, OW)`, the target output shape; a single int is expanded
        to a square `(oshp, oshp)`.
    :return: output tensor.
    """
    assert isinstance(inp, (Tensor, megbrain_graph.VarNode)), "inp must be Tensor type"
    if isinstance(oshp, int):
        oshp = (oshp, oshp)

    op = builtin.AdaptivePooling(mode="AVERAGE", format="NCHW",)
    oshp = astensor1d(oshp, inp, dtype="int32", device=inp.device)
    (output,) = apply(op, inp, oshp)
    return output
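For orientation, a minimal usage sketch of the two wrappers added above (assumes a standard MegEngine build with `megengine.functional` imported as `F`; the printed values follow from the 4x4 -> 2x2 example used in the docstrings below):

import numpy as np
import megengine as mge
import megengine.functional as F

# A 1x1x4x4 ramp input; pooling to (2, 2) covers four non-overlapping 2x2 windows.
inp = mge.tensor(np.arange(16, dtype="float32").reshape(1, 1, 4, 4))

print(F.adaptive_max_pool2d(inp, (2, 2)).numpy())  # [[[[ 5.  7.] [13. 15.]]]]
print(F.adaptive_avg_pool2d(inp, 2).numpy())       # an int oshp expands to (2, 2)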

def prelu(inp: Tensor, weight: Tensor) -> Tensor:
    r"""
    Applies the element-wise PReLU function.
......
@@ -8,6 +8,7 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax
+from .adaptive_pooling import AdaptiveAvgPool2d, AdaptiveMaxPool2d
from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm
from .concat import Concat
from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, LocalConv2d
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import abstractmethod
from typing import Tuple, Union

from ..functional import adaptive_avg_pool2d, adaptive_max_pool2d
from ..tensor import Parameter, Tensor
from .module import Module

class _AdaptivePoolNd(Module):
    def __init__(
        self, oshp: Union[Tuple[int, int], int, Tensor],
    ):
        super(_AdaptivePoolNd, self).__init__()
        self.oshp = oshp

    @abstractmethod
    def forward(self, inp):
        pass

class AdaptiveMaxPool2d(_AdaptivePoolNd):
    r"""Applies a 2D max adaptive pooling over an input.

    For instance, given an input of the size :math:`(N, C, H, W)` and
    an output shape :math:`(OH, OW)`, this layer generates the output of
    the size :math:`(N, C, OH, OW)` through a process described as:

    .. math::

        \begin{aligned}
            out(N_i, C_j, h, w) ={} & \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1}
                \text{input}(N_i, C_j, \text{stride[0]} \times h + m,
                \text{stride[1]} \times w + n)
        \end{aligned}

    ``kernel_size`` and ``stride`` are inferred from the input and output shapes:

    * padding: ``(0, 0)``
    * stride: ``(floor(IH / OH), floor(IW / OW))``
    * kernel_size: ``(IH - (OH - 1) * stride_h, IW - (OW - 1) * stride_w)``

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        m = M.AdaptiveMaxPool2d((2, 2))
        inp = mge.tensor(np.arange(0, 16).astype("float32").reshape(1, 1, 4, 4))
        oup = m(inp)
        print(oup.numpy())

    Outputs:

    .. testoutput::

        [[[[ 5.  7.]
           [13. 15.]]]]
    """

    def forward(self, inp):
        return adaptive_max_pool2d(inp, self.oshp)

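The stride/kernel inference rule stated in the docstring can be checked with a few lines of plain Python (a sketch; `_infer_pool_params` is a hypothetical helper, not part of this commit):

def _infer_pool_params(ih, iw, oh, ow):
    # padding is always (0, 0); stride uses floor division
    stride = (ih // oh, iw // ow)
    kernel = (ih - (oh - 1) * stride[0], iw - (ow - 1) * stride[1])
    return stride, kernel

# For the 4x4 -> 2x2 docstring example: stride (2, 2) and kernel (2, 2),
# i.e. four non-overlapping 2x2 windows, which yields the maxima 5, 7, 13, 15.
print(_infer_pool_params(4, 4, 2, 2))  # ((2, 2), (2, 2))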
class AdaptiveAvgPool2d(_AdaptivePoolNd):
    r"""Applies a 2D average adaptive pooling over an input.

    For instance, given an input of the size :math:`(N, C, H, W)` and
    an output shape :math:`(OH, OW)`, this layer generates the output of
    the size :math:`(N, C, OH, OW)` through a process described as:

    .. math::

        out(N_i, C_j, h, w) = \frac{1}{kH \times kW} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1}
            \text{input}(N_i, C_j, \text{stride[0]} \times h + m,
            \text{stride[1]} \times w + n)

    ``kernel_size`` and ``stride`` are inferred from the input and output shapes:

    * padding: ``(0, 0)``
    * stride: ``(floor(IH / OH), floor(IW / OW))``
    * kernel_size: ``(IH - (OH - 1) * stride_h, IW - (OW - 1) * stride_w)``

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        m = M.AdaptiveAvgPool2d((2, 2))
        inp = mge.tensor(np.arange(0, 16).astype("float32").reshape(1, 1, 4, 4))
        oup = m(inp)
        print(oup.numpy())

    Outputs:

    .. testoutput::

        [[[[ 2.5  4.5]
           [10.5 12.5]]]]
    """

    def forward(self, inp):
        return adaptive_avg_pool2d(inp, self.oshp)
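One design consequence worth noting: when the output shape divides the input shape evenly, adaptive average pooling reduces to ordinary average pooling with kernel_size == stride == IH // OH. A sketch of that sanity check (assumes `F.avg_pool2d` from the same functional module; this is not a test from the commit):

import numpy as np
import megengine as mge
import megengine.functional as F

inp = mge.tensor(np.random.rand(2, 3, 8, 8).astype("float32"))
adaptive = F.adaptive_avg_pool2d(inp, (4, 4))       # 8 / 4 = 2, so stride == kernel == 2
plain = F.avg_pool2d(inp, kernel_size=2, stride=2)
np.testing.assert_allclose(adaptive.numpy(), plain.numpy(), rtol=1e-6)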
@@ -206,6 +206,66 @@ def test_roi_pooling():
    assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape)

def test_adaptive_avg_pool2d():
    inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4))
    oshp = (2, 2)
    grad = Grad().wrt(inp, callback=_save_to(inp))
    outp = F.adaptive_avg_pool2d(inp, oshp,)
    assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,)
    np.testing.assert_equal(
        outp.numpy(), np.array([[[[2.5, 4.5], [10.5, 12.5]]]], dtype=np.float32)
    )

    grad(outp, tensor(F.ones_like(outp)))
    assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape)
    np.testing.assert_equal(
        inp.grad.numpy(),
        np.array(
            [
                [
                    [
                        [0.25, 0.25, 0.25, 0.25],
                        [0.25, 0.25, 0.25, 0.25],
                        [0.25, 0.25, 0.25, 0.25],
                        [0.25, 0.25, 0.25, 0.25],
                    ]
                ]
            ],
            dtype=np.float32,
        ),
    )

def test_adaptive_max_pool2d():
    inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4))
    oshp = (2, 2)
    grad = Grad().wrt(inp, callback=_save_to(inp))
    outp = F.adaptive_max_pool2d(inp, oshp,)
    assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,)
    np.testing.assert_equal(
        outp.numpy(), np.array([[[[5, 7], [13, 15]]]], dtype=np.float32)
    )

    grad(outp, tensor(F.ones_like(outp)))
    assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape)
    np.testing.assert_equal(
        inp.grad.numpy(),
        np.array(
            [
                [
                    [
                        [0.0, 0.0, 0.0, 0.0],
                        [0.0, 1.0, 0.0, 1.0],
                        [0.0, 0.0, 0.0, 0.0],
                        [0.0, 1.0, 0.0, 1.0],
                    ]
                ]
            ],
            dtype=np.float32,
        ),
    )
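The expected gradients above follow from the pooling definitions: average pooling spreads each output gradient uniformly over its 2x2 window (1 / (kH * kW) = 0.25), while max pooling routes it entirely to the window's argmax. A NumPy sketch of the max-pool case (editorial illustration, not part of the commit):

import numpy as np

x = np.arange(16, dtype=np.float32).reshape(4, 4)
grad_in = np.zeros_like(x)

# Backprop an all-ones output gradient through four non-overlapping 2x2 windows.
for i in range(2):
    for j in range(2):
        win = (slice(2 * i, 2 * i + 2), slice(2 * j, 2 * j + 2))
        r, c = np.unravel_index(np.argmax(x[win]), (2, 2))
        grad_in[2 * i + r, 2 * j + c] += 1.0

print(grad_in)  # reproduces the 0/1 pattern asserted in test_adaptive_max_pool2d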

def test_one_hot():
    def onehot_low_dimension():
        inp = tensor(np.arange(1, 4, dtype=np.int32))
......