From 12075f2a0b5197d8142cea98c80bf96e170ad722 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=98=A5=E4=B9=94?= <83450930+Liyulingyue@users.noreply.github.com> Date: Mon, 27 Feb 2023 15:55:12 +0800 Subject: [PATCH] support fp16 on temporal_shift (#50919) --- .../tests/unittests/test_temporal_shift_op.py | 25 +++++++++++++++++++ python/paddle/nn/functional/extension.py | 2 +- 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_temporal_shift_op.py b/python/paddle/fluid/tests/unittests/test_temporal_shift_op.py index ead0b50c1ad..6b99e0ead08 100644 --- a/python/paddle/fluid/tests/unittests/test_temporal_shift_op.py +++ b/python/paddle/fluid/tests/unittests/test_temporal_shift_op.py @@ -139,6 +139,31 @@ class TestTemporalShiftAPI(unittest.TestCase): x=input, seg_num=2, shift_ratio=0.2 ) + def test_static_fp16_gpu(self): + if paddle.fluid.core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + input = np.random.random([4, 4, 112, 112]).astype("float16") + + x = paddle.static.data( + name="x", shape=[4, 4, 112, 112], dtype="float16" + ) + + y = paddle.nn.functional.temporal_shift( + x=x, seg_num=2, shift_ratio=0.2 + ) + + exe = paddle.static.Executor(place) + res = exe.run( + paddle.static.default_main_program(), + feed={ + "x": input, + }, + fetch_list=[y], + ) + def test_error(self): def attr_data_format(): input = paddle.randn([6, 4, 2, 2]) diff --git a/python/paddle/nn/functional/extension.py b/python/paddle/nn/functional/extension.py index 533bf138a1a..67bc16ccddc 100644 --- a/python/paddle/nn/functional/extension.py +++ b/python/paddle/nn/functional/extension.py @@ -379,7 +379,7 @@ def temporal_shift(x, seg_num, shift_ratio=0.25, name=None, data_format="NCHW"): else: helper = LayerHelper("temporal_shift", **locals()) check_variable_and_dtype( - x, 'x', ['float32', 'float64'], 'temporal_shift' + x, 'x', ['float16', 'float32', 'float64'], 'temporal_shift' ) check_type(seg_num, 'seg_num', int, 'temporal_shift') check_type(shift_ratio, 'shift_ratio', float, 'temporal_shift') -- GitLab