未验证 提交 c8b90d8f 编写于 作者: B Bai Yifan 提交者: GitHub

fix deformable_conv small cases, test=develop (#22441)

上级 943cb8c6
......@@ -154,32 +154,11 @@ class TestModulatedDeformableConvOp(OpTest):
'Output',
max_relative_error=0.05)
def test_check_grad_no_filter(self):
self.check_grad(
['Input', 'Offset', 'Mask'],
'Output',
max_relative_error=0.1,
no_grad_set=set(['Filter']))
def test_check_grad_no_input(self):
self.check_grad(
['Filter', 'Offset', 'Mask'],
'Output',
max_relative_error=0.1,
no_grad_set=set(['Input']))
def test_check_grad_no_offset_no_mask(self):
self.check_grad(
['Input', 'Filter'],
'Output',
max_relative_error=0.1,
no_grad_set=set(['Offset', 'Mask']))
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
self.dilations = [1, 1]
self.input_size = [2, 4, 4, 4] # NCHW
self.input_size = [2, 8, 4, 4] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [4, f_c, 3, 3]
......@@ -229,7 +208,7 @@ class TestWithDilation(TestModulatedDeformableConvOp):
def init_test_case(self):
self.pad = [2, 2]
self.stride = [1, 1]
self.input_size = [2, 3, 4, 4] # NCHW
self.input_size = [4, 3, 4, 4] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
......@@ -250,14 +229,14 @@ class TestWithDilation(TestModulatedDeformableConvOp):
self.dilations = [2, 2]
class TestWith1x1(TestModulatedDeformableConvOp):
class TestWith3x3(TestModulatedDeformableConvOp):
def init_test_case(self):
self.pad = [0, 0]
self.pad = [1, 1]
self.stride = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 1, 1]
self.filter_size = [6, f_c, 3, 3]
self.im2col_step = 1
self.deformable_groups = 1
offset_c = 2 * self.deformable_groups * self.filter_size[
......
......@@ -21,7 +21,6 @@ NEED_TO_FIX_OP_LIST = [
'fused_elemwise_activation',
'bilinear_tensor_product',
'conv2d_transpose',
'deformable_conv',
'depthwise_conv2d_transpose',
'grid_sampler',
'hierarchical_sigmoid',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册