From 769c032fc4e421df2ebc8c4f748cb8c8e136843c Mon Sep 17 00:00:00 2001 From: dyning Date: Fri, 21 Feb 2020 12:24:14 +0800 Subject: [PATCH] fix spp test (#22675) --- python/paddle/fluid/tests/unittests/test_spp_op.py | 10 ++++++---- .../unittests/white_list/check_shape_white_list.py | 1 - 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_spp_op.py b/python/paddle/fluid/tests/unittests/test_spp_op.py index 1a134b58de..4a7ea97cfb 100644 --- a/python/paddle/fluid/tests/unittests/test_spp_op.py +++ b/python/paddle/fluid/tests/unittests/test_spp_op.py @@ -25,8 +25,11 @@ class TestSppOp(OpTest): def setUp(self): self.op_type = "spp" self.init_test_case() - input = np.random.random(self.shape).astype("float64") - nsize, csize, hsize, wsize = input.shape + nsize, csize, hsize, wsize = self.shape + data = np.array(list(range(nsize * csize * hsize * wsize))) + input = data.reshape(self.shape) + input_random = np.random.random(self.shape).astype("float64") + input = input + input_random out_level_flatten = [] for i in range(self.pyramid_height): bins = np.power(2, i) @@ -55,7 +58,6 @@ class TestSppOp(OpTest): 'pyramid_height': self.pyramid_height, 'pooling_type': self.pool_type } - self.outputs = {'Out': output.astype('float64')} def test_check_output(self): @@ -65,7 +67,7 @@ class TestSppOp(OpTest): self.check_grad(['X'], 'Out') def init_test_case(self): - self.shape = [3, 2, 4, 4] + self.shape = [3, 2, 16, 16] self.pyramid_height = 3 self.pool2D_forward_naive = max_pool2D_forward_naive self.pool_type = "max" diff --git a/python/paddle/fluid/tests/unittests/white_list/check_shape_white_list.py b/python/paddle/fluid/tests/unittests/white_list/check_shape_white_list.py index 34c58fda7e..7cc5445e02 100644 --- a/python/paddle/fluid/tests/unittests/white_list/check_shape_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/check_shape_white_list.py @@ -29,7 +29,6 @@ NEED_TO_FIX_OP_LIST = [ 'scatter', 'smooth_l1_loss', 'soft_relu', - 'spp', 'squared_l2_distance', 'tree_conv', ] -- GitLab