diff --git a/mindspore/nn/layer/__init__.py b/mindspore/nn/layer/__init__.py
index 098489a91df7def5b1722d8579d730d47d88cab7..b9f79b6cf72b022752c61d69037e47634f4615d4 100644
--- a/mindspore/nn/layer/__init__.py
+++ b/mindspore/nn/layer/__init__.py
@@ -24,7 +24,7 @@ from .conv import Conv2d, Conv2dTranspose
 from .lstm import LSTM
 from .basic import Dropout, Flatten, Dense, ClipByNorm, Norm, OneHot, Pad, Unfold
 from .embedding import Embedding
-from .pooling import AvgPool2d, MaxPool2d
+from .pooling import AvgPool2d, MaxPool2d, AvgPool1d
 from .image import ImageGradients, SSIM, PSNR
 
 __all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid',
@@ -35,6 +35,6 @@ __all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid',
            'LSTM',
            'Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot',
            'Embedding',
-           'AvgPool2d', 'MaxPool2d', 'Pad', 'Unfold',
+           'AvgPool2d', 'MaxPool2d', 'AvgPool1d', 'Pad', 'Unfold',
            'ImageGradients', 'SSIM', 'PSNR',
            ]
diff --git a/mindspore/nn/layer/pooling.py b/mindspore/nn/layer/pooling.py
index 53d97807cff0b4722a397eee5c5a852ec0d9f7ce..6cf06de029a8baf98674dfffae7ac355a09ed305 100644
--- a/mindspore/nn/layer/pooling.py
+++ b/mindspore/nn/layer/pooling.py
@@ -14,9 +14,12 @@
 # ============================================================================
 """pooling"""
 from mindspore.ops import operations as P
+from mindspore.ops import functional as F
 from mindspore._checkparam import Validator as validator
 from ... import context
 from ..cell import Cell
+from ..._checkparam import Rel
+from ..._checkparam import ParamValidator
 
 
 class _PoolNd(Cell):
@@ -208,3 +211,81 @@ class AvgPool2d(_PoolNd):
 
     def construct(self, x):
         return self.avg_pool(x)
+
+
+class AvgPool1d(_PoolNd):
+    r"""
+    Average pooling for temporal data.
+
+    Applies a 1D average pooling over an input Tensor which can be regarded as a composition of 1D input planes.
+
+    Typically the input is of shape :math:`(N_{in}, C_{in}, H_{in}, W_{in})`, AvgPool1d outputs
+    regional average in the :math:`(W_{in})`-dimension. Given kernel size
+    :math:`ks = w_{ker}` and stride :math:`s = s_0`, the operation is as follows.
+
+    .. math::
+        \text{output}(N_i, C_j, h_k, w) = \frac{1}{w_{ker}} \sum_{n=0}^{w_{ker}-1}
+        \text{input}(N_i, C_j, h_k, s_0 \times w + n)
+
+    Note:
+        pad_mode for training only supports "same" and "valid".
+
+    Args:
+        kernel_size (int): The size of kernel window used to take the average value, Default: 1.
+        stride (int): The distance of kernel moving, an int number that represents
+            the width of movement is stride, Default: 1.
+        pad_mode (str): The optional values for pad mode, is "same" or "valid", not case sensitive.
+            Default: "valid".
+
+            - same: Adopts the way of completion. Output height and width will be the same as
+              the input. Total number of padding will be calculated for horizontal and vertical
+              direction and evenly distributed to top and bottom, left and right if possible.
+              Otherwise, the last extra padding will be done from the bottom and the right side.
+
+            - valid: Adopts the way of discarding. The possibly largest height and width of output
+              will be returned without padding. Extra pixels will be discarded.
+
+
+    Inputs:
+        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
+
+    Outputs:
+        Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
+
+    Examples:
+        >>> pool = nn.AvgPool1d(kernel_size=3, stride=1)
+        >>> x = Tensor(np.random.randint(0, 10, [1, 2, 4, 4]), mindspore.float32)
+        >>> output = pool(x)
+        >>> output.shape()
+        (1, 2, 4, 2)
+    """
+
+    def __init__(self,
+                 kernel_size=1,
+                 stride=1,
+                 pad_mode="valid"):
+        super(AvgPool1d, self).__init__(kernel_size, stride, pad_mode)
+        ParamValidator.check_type('kernel_size', kernel_size, [int,])
+        ParamValidator.check_type('stride', stride, [int,])
+        self.pad_mode = ParamValidator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME'])
+        ParamValidator.check_integer("kernel_size", kernel_size, 1, Rel.GE)
+        ParamValidator.check_integer("stride", stride, 1, Rel.GE)
+        self.kernel_size = (1, kernel_size)
+        self.stride = (1, stride)
+        self.avg_pool = P.AvgPool(ksize=self.kernel_size,
+                                  strides=self.stride,
+                                  padding=self.pad_mode)
+        self.shape = F.shape
+        self.reduce_mean = P.ReduceMean(keep_dims=True)
+        self.slice = P.Slice()
+
+    def construct(self, x):
+        batch, channel, high, width = self.shape(x)
+        if width == self.kernel_size[1]:
+            x = self.reduce_mean(x, 3)
+        elif width - self.kernel_size[1] < self.stride[1]:
+            x = self.slice(x, (0, 0, 0, 0), (batch, channel, high, self.kernel_size[1]))
+            x = self.reduce_mean(x, 3)
+        else:
+            x = self.avg_pool(x)
+        return x
diff --git a/tests/ut/python/nn/test_pooling.py b/tests/ut/python/nn/test_pooling.py
index 10bb7632b277bc4c3ec1c8a39c83a4e79c007118..428e050ea2a2ebab4b4ea0280214ae88e8033a4e 100644
--- a/tests/ut/python/nn/test_pooling.py
+++ b/tests/ut/python/nn/test_pooling.py
@@ -56,3 +56,19 @@ def test_compile_max():
     net = MaxNet(3, stride=1, padding=0)
     x = Tensor(np.random.randint(0, 255, [1, 3, 6, 6]).astype(np.float32))
     _executor.compile(net, x)
+
+
+class Avg1dNet(nn.Cell):
+    def __init__(self,
+                 kernel_size,
+                 stride=None):
+        super(Avg1dNet, self).__init__()
+        self.avg1d = nn.AvgPool1d(kernel_size, stride)
+
+    def construct(self, x):
+        return self.avg1d(x)
+
+def test_avg1d():
+    net = Avg1dNet(3, 1)
+    input_ = Tensor(np.random.randint(0, 255, [1, 3, 6, 6]).astype(np.float32))
+    _executor.compile(net, input_)