未验证 提交 12075f2a 编写于 作者: 张春乔 提交者: GitHub

support fp16 on temporal_shift (#50919)

上级 8a097399
......@@ -139,6 +139,31 @@ class TestTemporalShiftAPI(unittest.TestCase):
x=input, seg_num=2, shift_ratio=0.2
)
def test_static_fp16_gpu(self):
if paddle.fluid.core.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
input = np.random.random([4, 4, 112, 112]).astype("float16")
x = paddle.static.data(
name="x", shape=[4, 4, 112, 112], dtype="float16"
)
y = paddle.nn.functional.temporal_shift(
x=x, seg_num=2, shift_ratio=0.2
)
exe = paddle.static.Executor(place)
res = exe.run(
paddle.static.default_main_program(),
feed={
"x": input,
},
fetch_list=[y],
)
def test_error(self):
def attr_data_format():
input = paddle.randn([6, 4, 2, 2])
......
......@@ -379,7 +379,7 @@ def temporal_shift(x, seg_num, shift_ratio=0.25, name=None, data_format="NCHW"):
else:
helper = LayerHelper("temporal_shift", **locals())
check_variable_and_dtype(
x, 'x', ['float32', 'float64'], 'temporal_shift'
x, 'x', ['float16', 'float32', 'float64'], 'temporal_shift'
)
check_type(seg_num, 'seg_num', int, 'temporal_shift')
check_type(shift_ratio, 'shift_ratio', float, 'temporal_shift')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册