未验证 提交 b9d739a7 编写于 作者: D Double_V 提交者: GitHub

fix pool bug, test=develop (#27537)

* fix pool bug, test=develop

* fix coverage,test=develop

* fix bug, test=develop
上级 86fa0432
......@@ -195,6 +195,23 @@ class TestPool1d_API(unittest.TestCase):
result = max_pool1d_dg(input)
self.assertTrue(np.allclose(result.numpy(), result_np))
def check_max_dygraph_return_index_results(self, place):
with fluid.dygraph.guard(place):
input_np = np.random.random([2, 3, 32]).astype("float32")
input = fluid.dygraph.to_variable(input_np)
result, index = F.max_pool1d(
input, kernel_size=2, stride=2, padding=0, return_indices=True)
result_np = max_pool1D_forward_naive(
input_np, ksize=[2], strides=[2], paddings=[0])
self.assertTrue(np.allclose(result.numpy(), result_np))
max_pool1d_dg = paddle.nn.layer.MaxPool1d(
kernel_size=2, stride=None, padding=0)
result = max_pool1d_dg(input)
self.assertTrue(np.allclose(result.numpy(), result_np))
def check_max_dygraph_padding_same(self, place):
with fluid.dygraph.guard(place):
input_np = np.random.random([2, 3, 32]).astype("float32")
......@@ -228,6 +245,7 @@ class TestPool1d_API(unittest.TestCase):
self.check_avg_static_results(place)
self.check_max_dygraph_padding_same(place)
self.check_avg_dygraph_padding_same(place)
self.check_max_dygraph_return_index_results(place)
class TestPool2dError_API(unittest.TestCase):
......
......@@ -571,15 +571,26 @@ def max_pool1d(x,
padding = _expand_low_nd_padding(padding)
if in_dygraph_mode():
pool_out = core.ops.max_pool2d_with_index(
x, 'ksize', kernel_size, 'global_pooling', False, 'strides', stride,
'paddings', padding, 'padding_algorithm', padding_algorithm,
'use_cudnn', True, 'ceil_mode', ceil_mode, 'use_mkldnn', False,
'exclusive', True, 'data_format', data_format)
return (squeeze(pool_out[0], [2]), squeeze(
pool_out[1], [2])) if return_indices else squeeze(pool_out[0], [2])
if return_indices:
pool_out = core.ops.max_pool2d_with_index(
x, 'ksize', kernel_size, 'global_pooling', False, 'strides',
stride, 'paddings', padding, 'padding_algorithm',
padding_algorithm, 'use_cudnn', True, 'ceil_mode', ceil_mode,
'use_mkldnn', False, 'exclusive', True, 'data_format',
data_format)
return (squeeze(pool_out[0], [2]), squeeze(
pool_out[1],
[2])) if return_indices else squeeze(pool_out[0], [2])
else:
pool_out = core.ops.pool2d(
x, 'pooling_type', 'max', 'ksize', kernel_size,
'global_pooling', False, 'padding_algorithm', padding_algorithm,
'strides', stride, 'paddings', padding, 'use_cudnn', True,
'ceil_mode', ceil_mode, 'use_mkldnn', False, 'exclusive', True,
'data_format', data_format)
return squeeze(pool_out, [2])
op_type = 'max_pool2d_with_index'
op_type = 'max_pool2d_with_index' if return_indices else "pool2d"
helper = LayerHelper(op_type, **locals())
dtype = helper.input_dtype()
pool_out = helper.create_variable_for_type_inference(dtype)
......@@ -696,7 +707,7 @@ def max_pool2d(x,
)
if in_dygraph_mode():
if data_format == "NCHW":
if return_indices:
output = core.ops.max_pool2d_with_index(
x, 'ksize', kernel_size, 'global_pooling', False, 'strides',
stride, 'paddings', padding, 'padding_algorithm',
......@@ -704,7 +715,7 @@ def max_pool2d(x,
'use_mkldnn', False, 'exclusive', True, 'data_format',
data_format)
return output if return_indices else output[0]
elif data_format == "NHWC" and not return_indices:
else:
output = core.ops.pool2d(
x, 'pooling_type', 'max', 'ksize', kernel_size,
'global_pooling', False, 'padding_algorithm', padding_algorithm,
......@@ -713,7 +724,7 @@ def max_pool2d(x,
'data_format', data_format)
return output
op_type = 'max_pool2d_with_index' if data_format == "NCHW" else "pool2d"
op_type = 'max_pool2d_with_index' if return_indices else "pool2d"
helper = LayerHelper(op_type, **locals())
dtype = helper.input_dtype()
pool_out = helper.create_variable_for_type_inference(dtype)
......@@ -822,7 +833,7 @@ def max_pool3d(x,
)
if in_dygraph_mode():
if data_format == "NCDHW":
if return_indices:
output = core.ops.max_pool3d_with_index(
x, 'pooling_type', 'max', 'ksize', kernel_size, 'strides',
stride, 'paddings', padding, 'global_pooling', False,
......@@ -830,7 +841,7 @@ def max_pool3d(x,
'ceil_mode', ceil_mode, 'use_mkldnn', False, 'exclusive', True,
'data_format', data_format)
return output if return_indices else output[0]
elif data_format == "NDHWC" and not return_indices:
else:
output = core.ops.pool3d(
x, 'pooling_type', 'max', 'ksize', kernel_size,
'global_pooling', False, 'padding_algorithm', padding_algorithm,
......@@ -839,7 +850,7 @@ def max_pool3d(x,
'data_format', data_format)
return output
op_type = "max_pool3d_with_index" if data_format == "NCDHW" else "pool3d"
op_type = "max_pool3d_with_index" if return_indices else "pool3d"
helper = LayerHelper(op_type, **locals())
dtype = helper.input_dtype()
pool_out = helper.create_variable_for_type_inference(dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册