Unverified · commit 7d95e598 authored by Zhang Ting, committed by GitHub

support float16 for temporal_shift op (#31432)

Parent 3a8ef10e
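For context before the hunks: temporal_shift implements the Temporal Shift Module, which views the input as (N, T, C, H, W) with T = seg_num and moves a fraction of the channels one step along the temporal axis. Below is a minimal NumPy sketch of the forward pass, reconstructed from the `src_it = it - 1` branch visible in KeTemporalShiftFw; the remaining branches are truncated in this diff, so their directions and the zero padding are assumptions based on the op's documented behavior.

```python
import numpy as np

def temporal_shift_ref(x, seg_num, shift_ratio):
    """NumPy sketch of temporal_shift's forward pass (zero padding at
    the temporal ends is assumed for out-of-range source frames)."""
    nt, c, h, w = x.shape
    x5 = x.reshape(-1, seg_num, c, h, w)      # (N, T, C, H, W)
    c1 = int(c * shift_ratio)                 # channel boundaries, computed
    c2 = int(c * 2 * shift_ratio)             # in integer math (see the fix)
    out = np.zeros_like(x5)
    out[:, 1:, :c1] = x5[:, :-1, :c1]         # ic <  c1: src_it = it - 1
    out[:, :-1, c1:c2] = x5[:, 1:, c1:c2]     # c1 <= ic < c2: src_it = it + 1
    out[:, :, c2:] = x5[:, :, c2:]            # remaining channels: unshifted
    return out.reshape(nt, c, h, w)
```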
@@ -33,8 +33,8 @@ __global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw,
   int ih = (tid % hw) / w;
   int iw = tid % w;
-  const int c1 = static_cast<T>(c * shift_ratio);
-  const int c2 = static_cast<T>(c * 2 * shift_ratio);
+  const int c1 = static_cast<int>(c * shift_ratio);
+  const int c2 = static_cast<int>(c * 2 * shift_ratio);
   if (ic < c1) {
     src_it = it - 1;
@@ -69,8 +69,8 @@ __global__ void KeTemporalShiftBw(const T* output_grad, T* input_grad,
   int ih = (tid % hw) / w;
   int iw = tid % w;
-  const int c1 = static_cast<T>(c * shift_ratio);
-  const int c2 = static_cast<T>(c * 2 * shift_ratio);
+  const int c1 = static_cast<int>(c * shift_ratio);
+  const int c2 = static_cast<int>(c * 2 * shift_ratio);
   if (ic < c1) {
     src_it = it - 1;
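The two hunks above are the core numerical fix: the channel boundaries c1 and c2 were previously routed through the kernel's value type T before landing in an int. For T = float or double that round trip is harmless, but with T = platform::float16 it would both depend on a float16-to-int conversion and compute the boundary in half precision, where integers above 2048 are no longer exactly representable. A small NumPy illustration of the truncation; the channel count is hypothetical, exaggerated to exceed float16's exact-integer range:

```python
import numpy as np

c, shift_ratio = 16404, 0.25                 # hypothetical channel count
print(int(c * shift_ratio))                  # 4101 -- correct boundary
print(int(np.float16(c * shift_ratio)))      # 4100 -- rounded in half precision
```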
@@ -163,8 +163,11 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(temporal_shift, ops::TemporalShiftOpCUDAKernel<float>,
-                        ops::TemporalShiftOpCUDAKernel<double>);
-REGISTER_OP_CUDA_KERNEL(temporal_shift_grad,
-                        ops::TemporalShiftGradOpCUDAKernel<float>,
-                        ops::TemporalShiftGradOpCUDAKernel<double>);
+REGISTER_OP_CUDA_KERNEL(
+    temporal_shift, ops::TemporalShiftOpCUDAKernel<float>,
+    ops::TemporalShiftOpCUDAKernel<double>,
+    ops::TemporalShiftOpCUDAKernel<paddle::platform::float16>);
+REGISTER_OP_CUDA_KERNEL(
+    temporal_shift_grad, ops::TemporalShiftGradOpCUDAKernel<float>,
+    ops::TemporalShiftGradOpCUDAKernel<double>,
+    ops::TemporalShiftGradOpCUDAKernel<paddle::platform::float16>);
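With the float16 kernels registered, the op accepts half-precision inputs through the public Python API. A usage sketch, assuming a CUDA build of Paddle (`paddle.nn.functional.temporal_shift` is the entry point the API test below also uses):

```python
import paddle
import paddle.nn.functional as F

paddle.set_device('gpu')
# (N*T, C, H, W) with T = seg_num = 2; fp16 input now dispatches to the
# newly registered TemporalShiftOpCUDAKernel<float16>.
x = paddle.randn([4, 8, 7, 7]).astype('float16')
y = F.temporal_shift(x, seg_num=2, shift_ratio=0.25)
print(y.dtype)  # paddle.float16
```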
@@ -40,7 +40,7 @@ class TestTemporalShift(OpTest):
     def setUp(self):
         self.initTestCase()
         self.op_type = 'temporal_shift'
-        x = np.random.random(self.x_shape).astype('float64')
+        x = np.random.random(self.x_shape).astype(self.dtype)
 
         self.attrs = {
             "seg_num": self.seg_num,
@@ -62,6 +62,7 @@ class TestTemporalShift(OpTest):
         self.x_shape = (6, 4, 4, 4)
         self.seg_num = 3
         self.shift_ratio = 0.25
+        self.dtype = 'float64'
 
 
 class TestTemporalShift2(TestTemporalShift):
@@ -78,6 +79,26 @@ class TestTemporalShift3(TestTemporalShift):
         self.shift_ratio = 0.3
 
 
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestTemporalShiftFP16(TestTemporalShift):
+    def initTestCase(self):
+        self.x_shape = (3, 10, 5, 5)
+        self.seg_num = 1
+        self.shift_ratio = 0.3
+        self.dtype = 'float16'
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        if core.is_float16_supported(place):
+            self.check_output_with_place(place)
+
+    def test_check_grad_ignore_uv(self):
+        place = core.CUDAPlace(0)
+        if core.is_float16_supported(place):
+            self.check_grad_with_place(place, ['X'], 'Out')
+
+
 class TestTemporalShiftAPI(unittest.TestCase):
     def test_api(self):
         input = paddle.randn([6, 4, 2, 2])
......
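One property worth noting about the new grad test: the backward kernel only routes gradients (each input-grad element is either copied from the output grad or zeroed), so with a sum loss the gradients are exactly zeros and ones and fp16 should match fp32 bit-for-bit. A hedged end-to-end sketch of that check, complementing `test_check_grad_ignore_uv` above (assumes a CUDA device and dygraph mode):

```python
import paddle
import paddle.nn.functional as F

paddle.set_device('gpu')
x32 = paddle.randn([6, 4, 2, 2])
x16 = x32.astype('float16')
for x in (x32, x16):
    x.stop_gradient = False
    F.temporal_shift(x, seg_num=3, shift_ratio=0.25).sum().backward()
# Gradients are all zeros and ones, so fp16 should match fp32 exactly.
print(abs(x32.gradient() - x16.gradient().astype('float32')).max())  # 0.0
```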