未验证 提交 769c032f 编写于 作者: D dyning 提交者: GitHub

fix spp test (#22675)

上级 1a595d8e
...@@ -25,8 +25,11 @@ class TestSppOp(OpTest): ...@@ -25,8 +25,11 @@ class TestSppOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "spp" self.op_type = "spp"
self.init_test_case() self.init_test_case()
input = np.random.random(self.shape).astype("float64") nsize, csize, hsize, wsize = self.shape
nsize, csize, hsize, wsize = input.shape data = np.array(list(range(nsize * csize * hsize * wsize)))
input = data.reshape(self.shape)
input_random = np.random.random(self.shape).astype("float64")
input = input + input_random
out_level_flatten = [] out_level_flatten = []
for i in range(self.pyramid_height): for i in range(self.pyramid_height):
bins = np.power(2, i) bins = np.power(2, i)
...@@ -55,7 +58,6 @@ class TestSppOp(OpTest): ...@@ -55,7 +58,6 @@ class TestSppOp(OpTest):
'pyramid_height': self.pyramid_height, 'pyramid_height': self.pyramid_height,
'pooling_type': self.pool_type 'pooling_type': self.pool_type
} }
self.outputs = {'Out': output.astype('float64')} self.outputs = {'Out': output.astype('float64')}
def test_check_output(self): def test_check_output(self):
...@@ -65,7 +67,7 @@ class TestSppOp(OpTest): ...@@ -65,7 +67,7 @@ class TestSppOp(OpTest):
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out')
def init_test_case(self): def init_test_case(self):
self.shape = [3, 2, 4, 4] self.shape = [3, 2, 16, 16]
self.pyramid_height = 3 self.pyramid_height = 3
self.pool2D_forward_naive = max_pool2D_forward_naive self.pool2D_forward_naive = max_pool2D_forward_naive
self.pool_type = "max" self.pool_type = "max"
......
...@@ -29,7 +29,6 @@ NEED_TO_FIX_OP_LIST = [ ...@@ -29,7 +29,6 @@ NEED_TO_FIX_OP_LIST = [
'scatter', 'scatter',
'smooth_l1_loss', 'smooth_l1_loss',
'soft_relu', 'soft_relu',
'spp',
'squared_l2_distance', 'squared_l2_distance',
'tree_conv', 'tree_conv',
] ]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册