未验证 提交 266fcbe0 编写于 作者: W wangguanzhong 提交者: GitHub

support double in deformable conv (#35330)

* support double in deformable conv

* add double for dcn v2
上级 49797d85
......@@ -126,7 +126,8 @@ __global__ void ModulatedDeformableCol2imGpuKernel(
DmcnGetGradientWeight(cur_inv_h_data, cur_inv_w_data, cur_h + dy,
cur_w + dx, height, width);
atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
platform::CudaAtomicAdd(grad_im + cur_bottom_grad_pos,
weight * cur_top_grad);
}
}
}
......@@ -748,6 +749,8 @@ namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(deformable_conv,
ops::DeformableConvCUDAKernel<CUDA, float>);
ops::DeformableConvCUDAKernel<CUDA, float>,
ops::DeformableConvCUDAKernel<CUDA, double>);
REGISTER_OP_CUDA_KERNEL(deformable_conv_grad,
ops::DeformableConvGradCUDAKernel<CUDA, float>);
ops::DeformableConvGradCUDAKernel<CUDA, float>,
ops::DeformableConvGradCUDAKernel<CUDA, double>);
......@@ -307,6 +307,8 @@ REGISTER_OPERATOR(deformable_conv_v1, ops::DeformableConvV1Op,
REGISTER_OPERATOR(deformable_conv_v1_grad, ops::DeformableConvV1GradOp);
REGISTER_OP_CPU_KERNEL(deformable_conv_v1,
ops::DeformableConvV1CPUKernel<float>);
ops::DeformableConvV1CPUKernel<float>,
ops::DeformableConvV1CPUKernel<double>);
REGISTER_OP_CPU_KERNEL(deformable_conv_v1_grad,
ops::DeformableConvV1GradCPUKernel<float>);
ops::DeformableConvV1GradCPUKernel<float>,
ops::DeformableConvV1GradCPUKernel<double>);
......@@ -99,7 +99,8 @@ __global__ void DeformableCol2imCUDAKernel(
DmcnGetGradientWeight(cur_inv_h_data, cur_inv_w_data, cur_h + dy,
cur_w + dx, height, width);
atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
platform::CudaAtomicAdd(grad_im + cur_bottom_grad_pos,
weight * cur_top_grad);
}
}
}
......@@ -604,6 +605,8 @@ class DeformableConvV1GradCUDAKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(deformable_conv_v1,
ops::DeformableConvV1CUDAKernel<float>);
ops::DeformableConvV1CUDAKernel<float>,
ops::DeformableConvV1CUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(deformable_conv_v1_grad,
ops::DeformableConvV1GradCUDAKernel<float>);
ops::DeformableConvV1GradCUDAKernel<float>,
ops::DeformableConvV1GradCUDAKernel<double>);
......@@ -111,7 +111,7 @@ def dconv_im2col_gemm(input, offset, mask, filter, group, conv_param):
class TestModulatedDeformableConvOp(OpTest):
def setUp(self):
self.op_type = "deformable_conv"
self.dtype = np.float32
self.init_type()
self.init_group()
self.init_dilation()
self.init_test_case()
......@@ -183,6 +183,9 @@ class TestModulatedDeformableConvOp(OpTest):
def init_group(self):
    """Set the convolution group count used by this test case to 1."""
    self.groups = 1
def init_type(self):
    """Select the numpy dtype for this test case (float32 here)."""
    self.dtype = np.float32
class TestWithStride(TestModulatedDeformableConvOp):
def init_test_case(self):
......@@ -258,6 +261,32 @@ class TestWithGroup(TestModulatedDeformableConvOp):
self.groups = 2
class TestWithDouble(TestModulatedDeformableConvOp):
    """Variant of the deformable_conv test case that uses float64 data."""

    def init_type(self):
        # Switch the test dtype to double precision.
        self.dtype = np.float64

    def init_test_case(self):
        """Set the conv geometry and derive the offset/mask tensor shapes."""
        self.pad = [1, 1]
        self.stride = [1, 1]
        self.dilations = [1, 1]
        self.input_size = [2, 6, 4, 4]  # NCHW
        batch, in_channels, in_h, in_w = self.input_size

        # Input channels must split evenly across the conv groups.
        assert np.mod(in_channels, self.groups) == 0
        self.filter_size = [4, in_channels // self.groups, 3, 3]
        self.im2col_step = 1
        self.deformable_groups = 1

        # Offsets carry two values per kernel tap, the mask carries one.
        kernel_area = self.filter_size[2] * self.filter_size[3]
        offset_channels = 2 * self.deformable_groups * kernel_area
        mask_channels = self.deformable_groups * kernel_area
        self.offset_size = [batch, offset_channels, in_h, in_w]
        self.mask_size = [batch, mask_channels, in_h, in_w]
class TestModulatedDeformableConvInvalidInput(unittest.TestCase):
def test_error(self):
def test_invalid_input():
......
......@@ -108,7 +108,7 @@ def dconv_im2col_gemm(input, offset, filter, group, conv_param):
class TestModulatedDeformableConvOp(OpTest):
def setUp(self):
self.op_type = "deformable_conv_v1"
self.dtype = np.float32
self.init_type()
self.init_group()
self.init_dilation()
self.init_test_case()
......@@ -177,6 +177,9 @@ class TestModulatedDeformableConvOp(OpTest):
def init_group(self):
    """Set the convolution group count used by this test case to 1."""
    self.groups = 1
def init_type(self):
    """Select the numpy dtype for this test case (float32 here)."""
    self.dtype = np.float32
class TestWithStride(TestModulatedDeformableConvOp):
def init_test_case(self):
......@@ -253,6 +256,11 @@ class TestWithGroup(TestModulatedDeformableConvOp):
self.groups = 2
class TestWithDouble(TestModulatedDeformableConvOp):
    """Variant of the deformable_conv_v1 test case that uses float64 data."""

    def init_type(self):
        # Switch the test dtype to double precision.
        self.dtype = np.float64
class TestModulatedDeformableConvV1InvalidInput(unittest.TestCase):
def test_error(self):
def test_invalid_input():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册