提交 67057d13 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!541 add average pooling 1D

Merge pull request !541 from JichenZhao/avgpooling
......@@ -24,7 +24,7 @@ from .conv import Conv2d, Conv2dTranspose
from .lstm import LSTM
from .basic import Dropout, Flatten, Dense, ClipByNorm, Norm, OneHot, Pad, Unfold
from .embedding import Embedding
from .pooling import AvgPool2d, MaxPool2d
from .pooling import AvgPool2d, MaxPool2d, AvgPool1d
from .image import ImageGradients, SSIM, PSNR
__all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid',
......@@ -35,6 +35,6 @@ __all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid',
'LSTM',
'Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot',
'Embedding',
'AvgPool2d', 'MaxPool2d', 'Pad', 'Unfold',
'AvgPool2d', 'MaxPool2d', 'AvgPool1d', 'Pad', 'Unfold',
'ImageGradients', 'SSIM', 'PSNR',
]
......@@ -14,9 +14,12 @@
# ============================================================================
"""pooling"""
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore._checkparam import Validator as validator
from ... import context
from ..cell import Cell
from ..._checkparam import Rel
from ..._checkparam import ParamValidator
class _PoolNd(Cell):
......@@ -208,3 +211,81 @@ class AvgPool2d(_PoolNd):
def construct(self, x):
return self.avg_pool(x)
class AvgPool1d(_PoolNd):
r"""
Average pooling for temporal data.
Applies a 1D average pooling over an input Tensor which can be regarded as a composition of 1D input planes.
Typically the input is of shape :math:`(N_{in}, C_{in}, H_{in}, W_{in})`, AvgPool1d outputs
regional average in the :math:`(W_{in})`-dimension. Given kernel size
:math:`ks = w_{ker}` and stride :math:`s = s_0`, the operation is as follows.
.. math::
\text{output}(N_i, C_j, h_k, w) = \frac{1}{w_{ker}} \sum_{n=0}^{w_{ker}-1}
\text{input}(N_i, C_j, h_k, s_0 \times w + n)
Note:
pad_mode for training only supports "same" and "valid".
Args:
kernel_size (int): The size of kernel window used to take the average value, Default: 1.
stride (int): The distance of kernel moving, an int number that represents
the width of movement is strides, Default: 1.
pad_mode (str): The optional values for pad mode, is "same" or "valid", not case sensitive.
Default: "valid".
- same: Adopts the way of completion. Output height and width will be the same as
the input. Total number of padding will be calculated for horizontal and vertical
direction and evenly distributed to top and bottom, left and right if possible.
Otherwise, the last extra padding will be done from the bottom and the right side.
- valid: Adopts the way of discarding. The possibly largest height and width of output
will be return without padding. Extra pixels will be discarded.
Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
Outputs:
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
Examples:
>>> pool = nn.AvgPool1d(kernel_size=3, strides=1)
>>> x = Tensor(np.random.randint(0, 10, [1, 2, 4, 4]), mindspore.float32)
>>> output = pool(x)
>>> output.shape()
(1, 2, 4, 2)
"""
def __init__(self,
kernel_size=1,
stride=1,
pad_mode="valid"):
super(AvgPool1d, self).__init__(kernel_size, stride, pad_mode)
ParamValidator.check_type('kernel_size', kernel_size, [int,])
ParamValidator.check_type('stride', stride, [int,])
self.pad_mode = ParamValidator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME'])
ParamValidator.check_integer("kernel_size", kernel_size, 1, Rel.GE)
ParamValidator.check_integer("stride", stride, 1, Rel.GE)
self.kernel_size = (1, kernel_size)
self.stride = (1, stride)
self.avg_pool = P.AvgPool(ksize=self.kernel_size,
strides=self.stride,
padding=self.pad_mode)
self.shape = F.shape
self.reduce_mean = P.ReduceMean(keep_dims=True)
self.slice = P.Slice()
def construct(self, x):
batch, channel, high, width = self.shape(x)
if width == self.kernel_size[1]:
x = self.reduce_mean(x, 3)
elif width - self.kernel_size[1] < self.stride[1]:
x = self.slice(x, (0, 0, 0, 0), (batch, channel, high, self.kernel_size[1]))
x = self.reduce_mean(x, 3)
else:
x = self.avg_pool(x)
return x
......@@ -56,3 +56,19 @@ def test_compile_max():
net = MaxNet(3, stride=1, padding=0)
x = Tensor(np.random.randint(0, 255, [1, 3, 6, 6]).astype(np.float32))
_executor.compile(net, x)
class Avg1dNet(nn.Cell):
def __init__(self,
kernel_size,
stride=None):
super(Avg1dNet, self).__init__()
self.avg1d = nn.AvgPool1d(kernel_size, stride)
def construct(self, x):
return self.avg1d(x)
def test_avg1d():
net = Avg1dNet(3, 1)
input = Tensor(np.random.randint(0, 255, [1, 3, 6, 6]).astype(np.float32))
_executor.compile(net, input)
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册