From 659cede029038be2fa094cdc6b319b65eb264b94 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=98=A5=E4=B9=94?= <83450930+Liyulingyue@users.noreply.github.com> Date: Mon, 27 Feb 2023 15:51:13 +0800 Subject: [PATCH] [fp16] support fp16 on AvgPool3D (#50920) * support fp16 on AvgPool3D * Apply suggestions from code review --- .../fluid/tests/unittests/test_pool3d_api.py | 26 +++++++++++++++++++ python/paddle/nn/functional/pooling.py | 4 ++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_pool3d_api.py b/python/paddle/fluid/tests/unittests/test_pool3d_api.py index e1d7543e7bc..2c069b9e844 100644 --- a/python/paddle/fluid/tests/unittests/test_pool3d_api.py +++ b/python/paddle/fluid/tests/unittests/test_pool3d_api.py @@ -366,6 +366,32 @@ class TestPool3D_API(unittest.TestCase): self.check_max_dygraph_ndhwc_results(place) self.check_max_dygraph_ceilmode_results(place) + def test_static_pf16_gpu(self): + if paddle.fluid.core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + input = np.random.random([1, 2, 3, 32, 32]).astype("float16") + + x = paddle.static.data( + name="x", shape=[1, 2, 3, 32, 32], dtype="float16" + ) + + m = paddle.nn.AvgPool3D(kernel_size=2, stride=2, padding=0) + y = m(x) + + exe = paddle.static.Executor(place) + res = exe.run( + paddle.static.default_main_program(), + feed={ + "x": input, + }, + fetch_list=[y], + ) + + assert np.array_equal(res[0].shape, [1, 2, 1, 16, 16]) + class TestPool3DError_API(unittest.TestCase): def test_error_api(self): diff --git a/python/paddle/nn/functional/pooling.py b/python/paddle/nn/functional/pooling.py index 03e7f202c53..d4a1e3c1456 100755 --- a/python/paddle/nn/functional/pooling.py +++ b/python/paddle/nn/functional/pooling.py @@ -511,7 +511,9 @@ def avg_pool3d( else: op_type = "pool3d" helper = LayerHelper(op_type, **locals()) - check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'max_pool3d') + check_variable_and_dtype( + x, 'x', ['float16', 'float32', 'float64'], 'avg_pool3d' + ) dtype = helper.input_dtype(input_param_name='x') pool_out = helper.create_variable_for_type_inference(dtype) outputs = {"Out": pool_out} -- GitLab