未验证 提交 b9d739a7 编写于 作者: D Double_V 提交者: GitHub

fix pool bug, test=develop (#27537)

* fix pool bug, test=develop

* fix coverage,test=develop

* fix bug, test=develop
上级 86fa0432
...@@ -195,6 +195,23 @@ class TestPool1d_API(unittest.TestCase): ...@@ -195,6 +195,23 @@ class TestPool1d_API(unittest.TestCase):
result = max_pool1d_dg(input) result = max_pool1d_dg(input)
self.assertTrue(np.allclose(result.numpy(), result_np)) self.assertTrue(np.allclose(result.numpy(), result_np))
def check_max_dygraph_return_index_results(self, place):
with fluid.dygraph.guard(place):
input_np = np.random.random([2, 3, 32]).astype("float32")
input = fluid.dygraph.to_variable(input_np)
result, index = F.max_pool1d(
input, kernel_size=2, stride=2, padding=0, return_indices=True)
result_np = max_pool1D_forward_naive(
input_np, ksize=[2], strides=[2], paddings=[0])
self.assertTrue(np.allclose(result.numpy(), result_np))
max_pool1d_dg = paddle.nn.layer.MaxPool1d(
kernel_size=2, stride=None, padding=0)
result = max_pool1d_dg(input)
self.assertTrue(np.allclose(result.numpy(), result_np))
def check_max_dygraph_padding_same(self, place): def check_max_dygraph_padding_same(self, place):
with fluid.dygraph.guard(place): with fluid.dygraph.guard(place):
input_np = np.random.random([2, 3, 32]).astype("float32") input_np = np.random.random([2, 3, 32]).astype("float32")
...@@ -228,6 +245,7 @@ class TestPool1d_API(unittest.TestCase): ...@@ -228,6 +245,7 @@ class TestPool1d_API(unittest.TestCase):
self.check_avg_static_results(place) self.check_avg_static_results(place)
self.check_max_dygraph_padding_same(place) self.check_max_dygraph_padding_same(place)
self.check_avg_dygraph_padding_same(place) self.check_avg_dygraph_padding_same(place)
self.check_max_dygraph_return_index_results(place)
class TestPool2dError_API(unittest.TestCase): class TestPool2dError_API(unittest.TestCase):
......
...@@ -571,15 +571,26 @@ def max_pool1d(x, ...@@ -571,15 +571,26 @@ def max_pool1d(x,
padding = _expand_low_nd_padding(padding) padding = _expand_low_nd_padding(padding)
if in_dygraph_mode(): if in_dygraph_mode():
pool_out = core.ops.max_pool2d_with_index( if return_indices:
x, 'ksize', kernel_size, 'global_pooling', False, 'strides', stride, pool_out = core.ops.max_pool2d_with_index(
'paddings', padding, 'padding_algorithm', padding_algorithm, x, 'ksize', kernel_size, 'global_pooling', False, 'strides',
'use_cudnn', True, 'ceil_mode', ceil_mode, 'use_mkldnn', False, stride, 'paddings', padding, 'padding_algorithm',
'exclusive', True, 'data_format', data_format) padding_algorithm, 'use_cudnn', True, 'ceil_mode', ceil_mode,
return (squeeze(pool_out[0], [2]), squeeze( 'use_mkldnn', False, 'exclusive', True, 'data_format',
pool_out[1], [2])) if return_indices else squeeze(pool_out[0], [2]) data_format)
return (squeeze(pool_out[0], [2]), squeeze(
pool_out[1],
[2])) if return_indices else squeeze(pool_out[0], [2])
else:
pool_out = core.ops.pool2d(
x, 'pooling_type', 'max', 'ksize', kernel_size,
'global_pooling', False, 'padding_algorithm', padding_algorithm,
'strides', stride, 'paddings', padding, 'use_cudnn', True,
'ceil_mode', ceil_mode, 'use_mkldnn', False, 'exclusive', True,
'data_format', data_format)
return squeeze(pool_out, [2])
op_type = 'max_pool2d_with_index' op_type = 'max_pool2d_with_index' if return_indices else "pool2d"
helper = LayerHelper(op_type, **locals()) helper = LayerHelper(op_type, **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
pool_out = helper.create_variable_for_type_inference(dtype) pool_out = helper.create_variable_for_type_inference(dtype)
...@@ -696,7 +707,7 @@ def max_pool2d(x, ...@@ -696,7 +707,7 @@ def max_pool2d(x,
) )
if in_dygraph_mode(): if in_dygraph_mode():
if data_format == "NCHW": if return_indices:
output = core.ops.max_pool2d_with_index( output = core.ops.max_pool2d_with_index(
x, 'ksize', kernel_size, 'global_pooling', False, 'strides', x, 'ksize', kernel_size, 'global_pooling', False, 'strides',
stride, 'paddings', padding, 'padding_algorithm', stride, 'paddings', padding, 'padding_algorithm',
...@@ -704,7 +715,7 @@ def max_pool2d(x, ...@@ -704,7 +715,7 @@ def max_pool2d(x,
'use_mkldnn', False, 'exclusive', True, 'data_format', 'use_mkldnn', False, 'exclusive', True, 'data_format',
data_format) data_format)
return output if return_indices else output[0] return output if return_indices else output[0]
elif data_format == "NHWC" and not return_indices: else:
output = core.ops.pool2d( output = core.ops.pool2d(
x, 'pooling_type', 'max', 'ksize', kernel_size, x, 'pooling_type', 'max', 'ksize', kernel_size,
'global_pooling', False, 'padding_algorithm', padding_algorithm, 'global_pooling', False, 'padding_algorithm', padding_algorithm,
...@@ -713,7 +724,7 @@ def max_pool2d(x, ...@@ -713,7 +724,7 @@ def max_pool2d(x,
'data_format', data_format) 'data_format', data_format)
return output return output
op_type = 'max_pool2d_with_index' if data_format == "NCHW" else "pool2d" op_type = 'max_pool2d_with_index' if return_indices else "pool2d"
helper = LayerHelper(op_type, **locals()) helper = LayerHelper(op_type, **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
pool_out = helper.create_variable_for_type_inference(dtype) pool_out = helper.create_variable_for_type_inference(dtype)
...@@ -822,7 +833,7 @@ def max_pool3d(x, ...@@ -822,7 +833,7 @@ def max_pool3d(x,
) )
if in_dygraph_mode(): if in_dygraph_mode():
if data_format == "NCDHW": if return_indices:
output = core.ops.max_pool3d_with_index( output = core.ops.max_pool3d_with_index(
x, 'pooling_type', 'max', 'ksize', kernel_size, 'strides', x, 'pooling_type', 'max', 'ksize', kernel_size, 'strides',
stride, 'paddings', padding, 'global_pooling', False, stride, 'paddings', padding, 'global_pooling', False,
...@@ -830,7 +841,7 @@ def max_pool3d(x, ...@@ -830,7 +841,7 @@ def max_pool3d(x,
'ceil_mode', ceil_mode, 'use_mkldnn', False, 'exclusive', True, 'ceil_mode', ceil_mode, 'use_mkldnn', False, 'exclusive', True,
'data_format', data_format) 'data_format', data_format)
return output if return_indices else output[0] return output if return_indices else output[0]
elif data_format == "NDHWC" and not return_indices: else:
output = core.ops.pool3d( output = core.ops.pool3d(
x, 'pooling_type', 'max', 'ksize', kernel_size, x, 'pooling_type', 'max', 'ksize', kernel_size,
'global_pooling', False, 'padding_algorithm', padding_algorithm, 'global_pooling', False, 'padding_algorithm', padding_algorithm,
...@@ -839,7 +850,7 @@ def max_pool3d(x, ...@@ -839,7 +850,7 @@ def max_pool3d(x,
'data_format', data_format) 'data_format', data_format)
return output return output
op_type = "max_pool3d_with_index" if data_format == "NCDHW" else "pool3d" op_type = "max_pool3d_with_index" if return_indices else "pool3d"
helper = LayerHelper(op_type, **locals()) helper = LayerHelper(op_type, **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
pool_out = helper.create_variable_for_type_inference(dtype) pool_out = helper.create_variable_for_type_inference(dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册