未验证 提交 13bbb6b6 编写于 作者: D Double_V 提交者: GitHub

add kernel, stride check (#35106)

* add kernel, stride check

* add unitest for param out of range

* delete max limit check
上级 8c73c1b5
...@@ -340,6 +340,36 @@ class TestPool2DError_API(unittest.TestCase): ...@@ -340,6 +340,36 @@ class TestPool2DError_API(unittest.TestCase):
self.assertRaises(ValueError, run7) self.assertRaises(ValueError, run7)
def run_kernel_out_of_range():
with fluid.dygraph.guard():
input_np = np.random.uniform(-1, 1,
[2, 3, 32]).astype(np.float32)
input_pd = fluid.dygraph.to_variable(input_np)
padding = 0
res_pd = F.avg_pool1d(
input_pd,
kernel_size=-1,
stride=2,
padding=padding,
ceil_mode=True)
self.assertRaises(ValueError, run_kernel_out_of_range)
def run_stride_out_of_range():
with fluid.dygraph.guard():
input_np = np.random.uniform(-1, 1,
[2, 3, 32]).astype(np.float32)
input_pd = fluid.dygraph.to_variable(input_np)
padding = 0
res_pd = F.avg_pool1d(
input_pd,
kernel_size=2,
stride=0,
padding=padding,
ceil_mode=True)
self.assertRaises(ValueError, run_stride_out_of_range)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -494,6 +494,36 @@ class TestPool2DError_API(unittest.TestCase): ...@@ -494,6 +494,36 @@ class TestPool2DError_API(unittest.TestCase):
self.assertRaises(ValueError, run9) self.assertRaises(ValueError, run9)
def run_kernel_out_of_range():
with fluid.dygraph.guard():
input_np = np.random.uniform(-1, 1,
[2, 3, 32, 32]).astype(np.float32)
input_pd = fluid.dygraph.to_variable(input_np)
res_pd = avg_pool2d(
input_pd,
kernel_size=[-1, 2],
stride=2,
padding=0,
ceil_mode=False,
data_format='NHWC')
self.assertRaises(ValueError, run_kernel_out_of_range)
def run_stride_out_of_range():
with fluid.dygraph.guard():
input_np = np.random.uniform(-1, 1,
[2, 3, 32, 32]).astype(np.float32)
input_pd = fluid.dygraph.to_variable(input_np)
res_pd = avg_pool2d(
input_pd,
kernel_size=3,
stride=[0, 2],
padding=0,
ceil_mode=False,
data_format='NHWC')
self.assertRaises(ValueError, run_stride_out_of_range)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -471,6 +471,34 @@ class TestPool3DError_API(unittest.TestCase): ...@@ -471,6 +471,34 @@ class TestPool3DError_API(unittest.TestCase):
self.assertRaises(ValueError, run10) self.assertRaises(ValueError, run10)
def run_kernel_out_of_range():
with fluid.dygraph.guard():
input_np = np.random.uniform(
-1, 1, [2, 3, 32, 32, 32]).astype(np.float32)
input_pd = fluid.dygraph.to_variable(input_np)
res_pd = avg_pool3d(
input_pd,
kernel_size=-1,
stride=2,
padding="VALID",
ceil_mode=True)
self.assertRaises(ValueError, run_kernel_out_of_range)
def run_size_out_of_range():
with fluid.dygraph.guard():
input_np = np.random.uniform(
-1, 1, [2, 3, 32, 32, 32]).astype(np.float32)
input_pd = fluid.dygraph.to_variable(input_np)
res_pd = avg_pool3d(
input_pd,
kernel_size=2,
stride=0,
padding="VALID",
ceil_mode=True)
self.assertRaises(ValueError, run_size_out_of_range)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -42,6 +42,17 @@ def _check_instance(x, x_name, types=(int, float)): ...@@ -42,6 +42,17 @@ def _check_instance(x, x_name, types=(int, float)):
format(types, x_name, type(x))) format(types, x_name, type(x)))
def _check_value_limitation(x, x_name, min_limit=1e-3):
def _check_value(x, x_name, min_limit=1e-3):
if isinstance(x, int) and min_limit is not None and x < min_limit:
raise ValueError(
"Excepted the input {} to be greater than {} but received x: {}. ".
format(x_name, min_limit, x))
for ele in x:
_check_value(ele, x_name)
def _zero_padding_in_batch_and_channel(padding, channel_last): def _zero_padding_in_batch_and_channel(padding, channel_last):
if channel_last: if channel_last:
return list(padding[0]) == [0, 0] and list(padding[-1]) == [0, 0] return list(padding[0]) == [0, 0] and list(padding[-1]) == [0, 0]
...@@ -211,6 +222,9 @@ def avg_pool1d(x, ...@@ -211,6 +222,9 @@ def avg_pool1d(x,
stride = utils.convert_to_list(stride, 1, 'pool_stride') stride = utils.convert_to_list(stride, 1, 'pool_stride')
stride = [1] + stride stride = [1] + stride
_check_value_limitation(kernel_size, "kernel_size", min_limit=1e-3)
_check_value_limitation(stride, "stride", min_limit=1e-3)
channel_last = _channel_last("NCL", 1) channel_last = _channel_last("NCL", 1)
padding, padding_algorithm = _update_padding_nd( padding, padding_algorithm = _update_padding_nd(
padding, 1, channel_last=channel_last, ceil_mode=ceil_mode) padding, 1, channel_last=channel_last, ceil_mode=ceil_mode)
...@@ -325,6 +339,9 @@ def avg_pool2d(x, ...@@ -325,6 +339,9 @@ def avg_pool2d(x,
else: else:
stride = utils.convert_to_list(stride, 2, 'pool_stride') stride = utils.convert_to_list(stride, 2, 'pool_stride')
_check_value_limitation(kernel_size, "kernel_size", min_limit=1e-3)
_check_value_limitation(stride, "stride", min_limit=1e-3)
channel_last = _channel_last(data_format, 2) channel_last = _channel_last(data_format, 2)
padding, padding_algorithm = _update_padding_nd( padding, padding_algorithm = _update_padding_nd(
padding, 2, channel_last, ceil_mode=ceil_mode) padding, 2, channel_last, ceil_mode=ceil_mode)
...@@ -448,6 +465,9 @@ def avg_pool3d(x, ...@@ -448,6 +465,9 @@ def avg_pool3d(x,
padding, padding_algorithm = _update_padding_nd( padding, padding_algorithm = _update_padding_nd(
padding, 3, channel_last=channel_last, ceil_mode=ceil_mode) padding, 3, channel_last=channel_last, ceil_mode=ceil_mode)
_check_value_limitation(kernel_size, "kernel_size", min_limit=1e-3)
_check_value_limitation(stride, "stride", min_limit=1e-3)
if in_dygraph_mode(): if in_dygraph_mode():
output = _C_ops.pool3d( output = _C_ops.pool3d(
x, 'pooling_type', 'avg', 'ksize', kernel_size, 'strides', stride, x, 'pooling_type', 'avg', 'ksize', kernel_size, 'strides', stride,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册