From 7d95e598c185fe90ab4cb566cda2367cb792b5a2 Mon Sep 17 00:00:00 2001
From: Zhang Ting
Date: Thu, 4 Mar 2021 18:54:55 +0800
Subject: [PATCH] support float16 for temporal_shift op (#31432)

---
 paddle/fluid/operators/temporal_shift_op.cu        | 21 +++++++++--------
 .../tests/unittests/test_temporal_shift_op.py      | 23 ++++++++++++++++++-
 2 files changed, 34 insertions(+), 10 deletions(-)

diff --git a/paddle/fluid/operators/temporal_shift_op.cu b/paddle/fluid/operators/temporal_shift_op.cu
index b61d9aeff7..4f2d7ce3cf 100644
--- a/paddle/fluid/operators/temporal_shift_op.cu
+++ b/paddle/fluid/operators/temporal_shift_op.cu
@@ -33,8 +33,8 @@ __global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw,
     int ih = (tid % hw) / w;
     int iw = tid % w;
 
-    const int c1 = static_cast<T>(c * shift_ratio);
-    const int c2 = static_cast<T>(c * 2 * shift_ratio);
+    const int c1 = static_cast<int>(c * shift_ratio);
+    const int c2 = static_cast<int>(c * 2 * shift_ratio);
 
     if (ic < c1) {
       src_it = it - 1;
@@ -69,8 +69,8 @@ __global__ void KeTemporalShiftBw(const T* output_grad, T* input_grad,
     int ih = (tid % hw) / w;
     int iw = tid % w;
 
-    const int c1 = static_cast<T>(c * shift_ratio);
-    const int c2 = static_cast<T>(c * 2 * shift_ratio);
+    const int c1 = static_cast<int>(c * shift_ratio);
+    const int c2 = static_cast<int>(c * 2 * shift_ratio);
 
     if (ic < c1) {
       src_it = it - 1;
@@ -163,8 +163,11 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(temporal_shift, ops::TemporalShiftOpCUDAKernel<float>,
-                        ops::TemporalShiftOpCUDAKernel<double>);
-REGISTER_OP_CUDA_KERNEL(temporal_shift_grad,
-                        ops::TemporalShiftGradOpCUDAKernel<float>,
-                        ops::TemporalShiftGradOpCUDAKernel<double>);
+REGISTER_OP_CUDA_KERNEL(
+    temporal_shift, ops::TemporalShiftOpCUDAKernel<float>,
+    ops::TemporalShiftOpCUDAKernel<double>,
+    ops::TemporalShiftOpCUDAKernel<paddle::platform::float16>);
+REGISTER_OP_CUDA_KERNEL(
+    temporal_shift_grad, ops::TemporalShiftGradOpCUDAKernel<float>,
+    ops::TemporalShiftGradOpCUDAKernel<double>,
+    ops::TemporalShiftGradOpCUDAKernel<paddle::platform::float16>);
diff --git a/python/paddle/fluid/tests/unittests/test_temporal_shift_op.py b/python/paddle/fluid/tests/unittests/test_temporal_shift_op.py
index 12eec2073b..050c38e549 100644
--- a/python/paddle/fluid/tests/unittests/test_temporal_shift_op.py
+++ b/python/paddle/fluid/tests/unittests/test_temporal_shift_op.py
@@ -40,7 +40,7 @@ class TestTemporalShift(OpTest):
     def setUp(self):
         self.initTestCase()
         self.op_type = 'temporal_shift'
-        x = np.random.random(self.x_shape).astype('float64')
+        x = np.random.random(self.x_shape).astype(self.dtype)
 
         self.attrs = {
             "seg_num": self.seg_num,
@@ -62,6 +62,7 @@ class TestTemporalShift(OpTest):
         self.x_shape = (6, 4, 4, 4)
         self.seg_num = 3
         self.shift_ratio = 0.25
+        self.dtype = 'float64'
 
 
 class TestTemporalShift2(TestTemporalShift):
@@ -78,6 +79,26 @@ class TestTemporalShift3(TestTemporalShift):
         self.shift_ratio = 0.3
 
 
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestTemporalShiftFP16(TestTemporalShift):
+    def initTestCase(self):
+        self.x_shape = (3, 10, 5, 5)
+        self.seg_num = 1
+        self.shift_ratio = 0.3
+        self.dtype = 'float16'
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        if core.is_float16_supported(place):
+            self.check_output_with_place(place)
+
+    def test_check_grad_ignore_uv(self):
+        place = core.CUDAPlace(0)
+        if core.is_float16_supported(place):
+            self.check_grad_with_place(place, ['X'], 'Out')
+
+
 class TestTemporalShiftAPI(unittest.TestCase):
     def test_api(self):
         input = paddle.randn([6, 4, 2, 2])
--
GitLab
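
As a quick way to exercise the float16 path this patch adds, the same functional API used by TestTemporalShiftAPI above can be called on a float16 tensor directly. A minimal sketch, assuming a CUDA build of Paddle with this patch applied and a GPU that supports float16:

    import paddle
    import paddle.nn.functional as F

    paddle.set_device('gpu')  # the float16 kernel is registered for CUDA only

    # Input is (N*T, C, H, W); the first dimension must be divisible by seg_num.
    x = paddle.randn([6, 4, 2, 2]).astype('float16')
    out = F.temporal_shift(x, seg_num=2, shift_ratio=0.2)
    print(out.dtype)  # paddle.float16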