未验证 提交 266fcbe0 编写于 作者: W wangguanzhong 提交者: GitHub

support double in deformable conv (#35330)

* support double in deformable conv

* add double for dcn v2
上级 49797d85
...@@ -126,7 +126,8 @@ __global__ void ModulatedDeformableCol2imGpuKernel( ...@@ -126,7 +126,8 @@ __global__ void ModulatedDeformableCol2imGpuKernel(
DmcnGetGradientWeight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, DmcnGetGradientWeight(cur_inv_h_data, cur_inv_w_data, cur_h + dy,
cur_w + dx, height, width); cur_w + dx, height, width);
atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); platform::CudaAtomicAdd(grad_im + cur_bottom_grad_pos,
weight * cur_top_grad);
} }
} }
} }
...@@ -748,6 +749,8 @@ namespace ops = paddle::operators; ...@@ -748,6 +749,8 @@ namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext; using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(deformable_conv, REGISTER_OP_CUDA_KERNEL(deformable_conv,
ops::DeformableConvCUDAKernel<CUDA, float>); ops::DeformableConvCUDAKernel<CUDA, float>,
ops::DeformableConvCUDAKernel<CUDA, double>);
REGISTER_OP_CUDA_KERNEL(deformable_conv_grad, REGISTER_OP_CUDA_KERNEL(deformable_conv_grad,
ops::DeformableConvGradCUDAKernel<CUDA, float>); ops::DeformableConvGradCUDAKernel<CUDA, float>,
ops::DeformableConvGradCUDAKernel<CUDA, double>);
...@@ -307,6 +307,8 @@ REGISTER_OPERATOR(deformable_conv_v1, ops::DeformableConvV1Op, ...@@ -307,6 +307,8 @@ REGISTER_OPERATOR(deformable_conv_v1, ops::DeformableConvV1Op,
REGISTER_OPERATOR(deformable_conv_v1_grad, ops::DeformableConvV1GradOp); REGISTER_OPERATOR(deformable_conv_v1_grad, ops::DeformableConvV1GradOp);
REGISTER_OP_CPU_KERNEL(deformable_conv_v1, REGISTER_OP_CPU_KERNEL(deformable_conv_v1,
ops::DeformableConvV1CPUKernel<float>); ops::DeformableConvV1CPUKernel<float>,
ops::DeformableConvV1CPUKernel<double>);
REGISTER_OP_CPU_KERNEL(deformable_conv_v1_grad, REGISTER_OP_CPU_KERNEL(deformable_conv_v1_grad,
ops::DeformableConvV1GradCPUKernel<float>); ops::DeformableConvV1GradCPUKernel<float>,
ops::DeformableConvV1GradCPUKernel<double>);
...@@ -99,7 +99,8 @@ __global__ void DeformableCol2imCUDAKernel( ...@@ -99,7 +99,8 @@ __global__ void DeformableCol2imCUDAKernel(
DmcnGetGradientWeight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, DmcnGetGradientWeight(cur_inv_h_data, cur_inv_w_data, cur_h + dy,
cur_w + dx, height, width); cur_w + dx, height, width);
atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); platform::CudaAtomicAdd(grad_im + cur_bottom_grad_pos,
weight * cur_top_grad);
} }
} }
} }
...@@ -604,6 +605,8 @@ class DeformableConvV1GradCUDAKernel : public framework::OpKernel<T> { ...@@ -604,6 +605,8 @@ class DeformableConvV1GradCUDAKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(deformable_conv_v1, REGISTER_OP_CUDA_KERNEL(deformable_conv_v1,
ops::DeformableConvV1CUDAKernel<float>); ops::DeformableConvV1CUDAKernel<float>,
ops::DeformableConvV1CUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(deformable_conv_v1_grad, REGISTER_OP_CUDA_KERNEL(deformable_conv_v1_grad,
ops::DeformableConvV1GradCUDAKernel<float>); ops::DeformableConvV1GradCUDAKernel<float>,
ops::DeformableConvV1GradCUDAKernel<double>);
...@@ -111,7 +111,7 @@ def dconv_im2col_gemm(input, offset, mask, filter, group, conv_param): ...@@ -111,7 +111,7 @@ def dconv_im2col_gemm(input, offset, mask, filter, group, conv_param):
class TestModulatedDeformableConvOp(OpTest): class TestModulatedDeformableConvOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "deformable_conv" self.op_type = "deformable_conv"
self.dtype = np.float32 self.init_type()
self.init_group() self.init_group()
self.init_dilation() self.init_dilation()
self.init_test_case() self.init_test_case()
...@@ -183,6 +183,9 @@ class TestModulatedDeformableConvOp(OpTest): ...@@ -183,6 +183,9 @@ class TestModulatedDeformableConvOp(OpTest):
def init_group(self): def init_group(self):
self.groups = 1 self.groups = 1
def init_type(self):
    """Set the element dtype used by this test case to np.float32.

    Subclasses override this hook to select a different dtype.
    """
    self.dtype = np.float32
class TestWithStride(TestModulatedDeformableConvOp): class TestWithStride(TestModulatedDeformableConvOp):
def init_test_case(self): def init_test_case(self):
...@@ -258,6 +261,32 @@ class TestWithGroup(TestModulatedDeformableConvOp): ...@@ -258,6 +261,32 @@ class TestWithGroup(TestModulatedDeformableConvOp):
self.groups = 2 self.groups = 2
class TestWithDouble(TestModulatedDeformableConvOp):
    """Variant of the base test case that sets the dtype to np.float64."""

    def init_type(self):
        # Override the base hook so this case uses double precision.
        self.dtype = np.float64

    def init_test_case(self):
        """Set the conv hyper-parameters and the input/filter/offset/mask shapes."""
        self.pad = [1, 1]
        self.stride = [1, 1]
        self.dilations = [1, 1]

        self.input_size = [2, 6, 4, 4]  # NCHW
        batch, in_channels, in_h, in_w = self.input_size

        # The input channels must split evenly across the groups.
        assert np.mod(in_channels, self.groups) == 0
        channels_per_group = in_channels // self.groups
        self.filter_size = [4, channels_per_group, 3, 3]

        self.im2col_step = 1
        self.deformable_groups = 1

        # Offset carries two values per kernel element, mask carries one.
        kernel_area = self.filter_size[2] * self.filter_size[3]
        offset_channels = 2 * self.deformable_groups * kernel_area
        mask_channels = self.deformable_groups * kernel_area

        self.offset_size = [batch, offset_channels, in_h, in_w]
        self.mask_size = [batch, mask_channels, in_h, in_w]
class TestModulatedDeformableConvInvalidInput(unittest.TestCase): class TestModulatedDeformableConvInvalidInput(unittest.TestCase):
def test_error(self): def test_error(self):
def test_invalid_input(): def test_invalid_input():
......
...@@ -108,7 +108,7 @@ def dconv_im2col_gemm(input, offset, filter, group, conv_param): ...@@ -108,7 +108,7 @@ def dconv_im2col_gemm(input, offset, filter, group, conv_param):
class TestModulatedDeformableConvOp(OpTest): class TestModulatedDeformableConvOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "deformable_conv_v1" self.op_type = "deformable_conv_v1"
self.dtype = np.float32 self.init_type()
self.init_group() self.init_group()
self.init_dilation() self.init_dilation()
self.init_test_case() self.init_test_case()
...@@ -177,6 +177,9 @@ class TestModulatedDeformableConvOp(OpTest): ...@@ -177,6 +177,9 @@ class TestModulatedDeformableConvOp(OpTest):
def init_group(self): def init_group(self):
self.groups = 1 self.groups = 1
def init_type(self):
    """Set the element dtype used by this test case to np.float32.

    Subclasses override this hook to select a different dtype.
    """
    self.dtype = np.float32
class TestWithStride(TestModulatedDeformableConvOp): class TestWithStride(TestModulatedDeformableConvOp):
def init_test_case(self): def init_test_case(self):
...@@ -253,6 +256,11 @@ class TestWithGroup(TestModulatedDeformableConvOp): ...@@ -253,6 +256,11 @@ class TestWithGroup(TestModulatedDeformableConvOp):
self.groups = 2 self.groups = 2
class TestWithDouble(TestModulatedDeformableConvOp):
    """Variant of the base test case that sets the dtype to np.float64."""

    def init_type(self):
        # Override the base hook so this case uses double precision.
        self.dtype = np.float64
class TestModulatedDeformableConvV1InvalidInput(unittest.TestCase): class TestModulatedDeformableConvV1InvalidInput(unittest.TestCase):
def test_error(self): def test_error(self):
def test_invalid_input(): def test_invalid_input():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册