未验证 提交 3c612b92 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] fix max_pool3d_with_index interface under different mode (#45647)

* [Eager] fix max_pool3d_with_index interface under different mode

* fix mistake

* Add tests under legacy and supplement comments
上级 9b5e0154
...@@ -164,7 +164,7 @@ class TestAdaptiveMaxPool3DAPI(unittest.TestCase): ...@@ -164,7 +164,7 @@ class TestAdaptiveMaxPool3DAPI(unittest.TestCase):
assert np.allclose(res_5, self.res_5_np) assert np.allclose(res_5, self.res_5_np)
def test_dynamic_graph(self): def func_dynamic_graph(self):
for use_cuda in ([False, True] for use_cuda in ([False, True]
if core.is_compiled_with_cuda() else [False]): if core.is_compiled_with_cuda() else [False]):
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
...@@ -195,6 +195,11 @@ class TestAdaptiveMaxPool3DAPI(unittest.TestCase): ...@@ -195,6 +195,11 @@ class TestAdaptiveMaxPool3DAPI(unittest.TestCase):
assert np.allclose(out_5.numpy(), self.res_5_np) assert np.allclose(out_5.numpy(), self.res_5_np)
def test_dynamic_graph(self):
with paddle.fluid.framework._test_eager_guard():
self.func_dynamic_graph()
self.func_dynamic_graph()
class TestAdaptiveMaxPool3DClassAPI(unittest.TestCase): class TestAdaptiveMaxPool3DClassAPI(unittest.TestCase):
......
...@@ -1860,9 +1860,14 @@ def adaptive_max_pool3d(x, output_size, return_mask=False, name=None): ...@@ -1860,9 +1860,14 @@ def adaptive_max_pool3d(x, output_size, return_mask=False, name=None):
output_size[2] = in_w output_size[2] = in_w
if in_dynamic_mode(): if in_dynamic_mode():
pool_out = _legacy_C_ops.max_pool3d_with_index(x, 'pooling_type', 'max', if in_dygraph_mode():
'ksize', output_size, # By default, strides is [1,1,1] and paddings is [0, 0, 0]
'adaptive', True) pool_out = _C_ops.max_pool3d_with_index(x, output_size, [1, 1, 1],
[0, 0, 0], False, True)
elif _in_legacy_dygraph():
pool_out = _legacy_C_ops.max_pool3d_with_index(
x, 'pooling_type', 'max', 'ksize', output_size, 'adaptive',
True)
return pool_out if return_mask else pool_out[0] return pool_out if return_mask else pool_out[0]
l_type = 'max_pool3d_with_index' l_type = 'max_pool3d_with_index'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册