From 5d289ef029a40cdbb103e5ac6cffeec0d8d99bec Mon Sep 17 00:00:00 2001 From: zhaojichen <zhaojichen1@huawei.com> Date: Tue, 21 Apr 2020 10:10:49 -0400 Subject: [PATCH] add AvgPooling layer --- mindspore/nn/layer/pooling.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/mindspore/nn/layer/pooling.py b/mindspore/nn/layer/pooling.py index 359e75a4c..a19ef06b7 100644 --- a/mindspore/nn/layer/pooling.py +++ b/mindspore/nn/layer/pooling.py @@ -19,6 +19,7 @@ from mindspore._checkparam import Validator as validator from ... import context from ..cell import Cell from ..._checkparam import Rel +from ..._checkparam import ParamValidator class _PoolNd(Cell): @@ -264,15 +265,15 @@ class AvgPool1d(_PoolNd): stride=1, pad_mode="valid"): super(AvgPool1d, self).__init__(kernel_size, stride, pad_mode) - validator.check_type('kernel_size', kernel_size, [int,]) - validator.check_type('stride', stride, [int,]) - self.padding = validator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME']) + ParamValidator.check_type('kernel_size', kernel_size, [int,]) + ParamValidator.check_type('stride', stride, [int,]) + self.pad_mode = ParamValidator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME']) if not isinstance(kernel_size, int): - validator.check_integer("kernel_size", kernel_size, 1, Rel.GE) + ParamValidator.check_integer("kernel_size", kernel_size, 1, Rel.GE) raise ValueError("kernel_size should be 1 int number but got {}". format(kernel_size)) if not isinstance(stride, int): - validator.check_integer("stride", stride, 1, Rel.GE) + ParamValidator.check_integer("stride", stride, 1, Rel.GE) raise ValueError("stride should be 1 int number but got {}".format(stride)) self.kernel_size = (1, kernel_size) self.stride = (1, stride) -- GitLab