提交 985e4bae 编写于 作者: D dyning 提交者: Zhang Ting

fix unittest for spp op, test=develop (#22030)

fix unittest for spp op
上级 c53b62eb
...@@ -62,11 +62,10 @@ class TestSppOp(OpTest): ...@@ -62,11 +62,10 @@ class TestSppOp(OpTest):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
if self.pool_type != "avg": self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', max_relative_error=0.05)
def init_test_case(self): def init_test_case(self):
self.shape = [4, 2, 4, 4] self.shape = [3, 2, 4, 4]
self.pyramid_height = 3 self.pyramid_height = 3
self.pool2D_forward_naive = max_pool2D_forward_naive self.pool2D_forward_naive = max_pool2D_forward_naive
self.pool_type = "max" self.pool_type = "max"
...@@ -74,7 +73,7 @@ class TestSppOp(OpTest): ...@@ -74,7 +73,7 @@ class TestSppOp(OpTest):
class TestCase2(TestSppOp): class TestCase2(TestSppOp):
def init_test_case(self): def init_test_case(self):
self.shape = [3, 2, 4, 4] self.shape = [3, 2, 16, 16]
self.pyramid_height = 3 self.pyramid_height = 3
self.pool2D_forward_naive = avg_pool2D_forward_naive self.pool2D_forward_naive = avg_pool2D_forward_naive
self.pool_type = "avg" self.pool_type = "avg"
......
...@@ -28,9 +28,9 @@ NO_FP64_CHECK_GRAD_OP_LIST = [ ...@@ -28,9 +28,9 @@ NO_FP64_CHECK_GRAD_OP_LIST = [
'reduce_min', 'relu', 'reshape2', 'roi_perspective_transform', 'row_conv', 'reduce_min', 'relu', 'reshape2', 'roi_perspective_transform', 'row_conv',
'scale', 'scatter', 'sequence_conv', 'sequence_pool', 'sequence_reverse', 'scale', 'scatter', 'sequence_conv', 'sequence_pool', 'sequence_reverse',
'sequence_slice', 'sequence_topk_avg_pooling', 'shuffle_channel', 'sigmoid', 'sequence_slice', 'sequence_topk_avg_pooling', 'shuffle_channel', 'sigmoid',
'smooth_l1_loss', 'softmax', 'spectral_norm', 'spp', 'sqrt', 'smooth_l1_loss', 'softmax', 'spectral_norm', 'sqrt', 'squared_l2_distance',
'squared_l2_distance', 'squared_l2_norm', 'tanh', 'transpose2', 'squared_l2_norm', 'tanh', 'transpose2', 'trilinear_interp', 'var_conv_2d',
'trilinear_interp', 'var_conv_2d', 'warpctc' 'warpctc'
] ]
NO_FP16_CHECK_GRAD_OP_LIST = [ NO_FP16_CHECK_GRAD_OP_LIST = [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册