diff --git a/python/paddle/fluid/tests/unittests/test_temporal_shift_op.py b/python/paddle/fluid/tests/unittests/test_temporal_shift_op.py index ead0b50c1ad0e500306c25c14b14d56efcb2b4f0..6b99e0ead08867f6242561ba24a449192f95090b 100644 --- a/python/paddle/fluid/tests/unittests/test_temporal_shift_op.py +++ b/python/paddle/fluid/tests/unittests/test_temporal_shift_op.py @@ -139,6 +139,31 @@ class TestTemporalShiftAPI(unittest.TestCase): x=input, seg_num=2, shift_ratio=0.2 ) + def test_static_fp16_gpu(self): + if paddle.fluid.core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + input = np.random.random([4, 4, 112, 112]).astype("float16") + + x = paddle.static.data( + name="x", shape=[4, 4, 112, 112], dtype="float16" + ) + + y = paddle.nn.functional.temporal_shift( + x=x, seg_num=2, shift_ratio=0.2 + ) + + exe = paddle.static.Executor(place) + res = exe.run( + paddle.static.default_main_program(), + feed={ + "x": input, + }, + fetch_list=[y], + ) + def test_error(self): def attr_data_format(): input = paddle.randn([6, 4, 2, 2]) diff --git a/python/paddle/nn/functional/extension.py b/python/paddle/nn/functional/extension.py index 533bf138a1a49da2624ae27f160df9dca097f172..67bc16ccddc6ace341232fd92c291f607f437cfa 100644 --- a/python/paddle/nn/functional/extension.py +++ b/python/paddle/nn/functional/extension.py @@ -379,7 +379,7 @@ def temporal_shift(x, seg_num, shift_ratio=0.25, name=None, data_format="NCHW"): else: helper = LayerHelper("temporal_shift", **locals()) check_variable_and_dtype( - x, 'x', ['float32', 'float64'], 'temporal_shift' + x, 'x', ['float16', 'float32', 'float64'], 'temporal_shift' ) check_type(seg_num, 'seg_num', int, 'temporal_shift') check_type(shift_ratio, 'shift_ratio', float, 'temporal_shift')