From cb5e0d20c2fd3167bc3e5678a5f09afad0b9bac6 Mon Sep 17 00:00:00 2001 From: zhupengyang <1165938320@qq.com> Date: Thu, 16 Jan 2020 15:00:14 +0800 Subject: [PATCH] skip cases with small shape (#22318) --- .../paddle/fluid/tests/unittests/op_test.py | 12 +++---- .../fluid/tests/unittests/test_prelu_op.py | 35 ++++++++++++------- .../white_list/check_shape_white_list.py | 2 -- 3 files changed, 28 insertions(+), 21 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index e5154e0c595..4e0fe464cd7 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -231,12 +231,12 @@ class OpTest(unittest.TestCase): "This test of %s op needs check_grad with fp64 precision." % cls.op_type) - if not get_numeric_gradient.is_large_shape \ - and cls.op_type not in check_shape_white_list.NOT_CHECK_OP_LIST \ - and cls.op_type not in check_shape_white_list.NEED_TO_FIX_OP_LIST: - raise AssertionError( - "Input's shape should be large than or equal to 100 for " + - cls.op_type + " Op.") + if not get_numeric_gradient.is_large_shape \ + and cls.op_type not in check_shape_white_list.NOT_CHECK_OP_LIST \ + and cls.op_type not in check_shape_white_list.NEED_TO_FIX_OP_LIST: + raise AssertionError( + "Input's shape should be large than or equal to 100 for " + + cls.op_type + " Op.") def try_call_once(self, data_type): if not self.call_once: diff --git a/python/paddle/fluid/tests/unittests/test_prelu_op.py b/python/paddle/fluid/tests/unittests/test_prelu_op.py index 190fa0f42ae..a30db2c4243 100644 --- a/python/paddle/fluid/tests/unittests/test_prelu_op.py +++ b/python/paddle/fluid/tests/unittests/test_prelu_op.py @@ -17,15 +17,16 @@ from __future__ import print_function import unittest import numpy as np import six -from op_test import OpTest +from op_test import OpTest, skip_check_grad_ci class PReluTest(OpTest): def setUp(self): + self.init_input_shape() + self.init_attr() self.op_type = "prelu" - self.initTestCase() - x_np = np.random.normal(size=(3, 5, 5, 10)).astype("float32") + x_np = np.random.uniform(-1, 1, self.x_shape).astype("float32") # Since zero point in prelu is not differentiable, avoid randomize # zero. x_np[np.abs(x_np) < 0.005] = 0.02 @@ -37,8 +38,8 @@ class PReluTest(OpTest): alpha_np = np.random.rand(1, x_np.shape[1], 1, 1).astype("float32") self.inputs = {'X': x_np, 'Alpha': alpha_np} else: - alpha_np = np.random.rand(1, x_np.shape[1], x_np.shape[2], \ - x_np.shape[3]).astype("float32") + alpha_np = np.random.rand(1, x_np.shape[1], x_np.shape[2], + x_np.shape[3]).astype("float32") self.inputs = {'X': x_np, 'Alpha': alpha_np} out_np = np.maximum(self.inputs['X'], 0.) @@ -47,7 +48,10 @@ class PReluTest(OpTest): assert out_np is not self.inputs['X'] self.outputs = {'Out': out_np} - def initTestCase(self): + def init_input_shape(self): + self.x_shape = (2, 100, 3, 4) + + def init_attr(self): self.attrs = {'mode': "channel"} def test_check_output(self): @@ -66,16 +70,21 @@ class PReluTest(OpTest): # TODO(minqiyang): Resume these test cases after fixing Python3 CI job issues if six.PY2: - class TestCase1(PReluTest): - def initTestCase(self): + @skip_check_grad_ci( + reason="[skip shape check] Input(Alpha) must be 1-D and only has one data in 'all' mode" + ) + class TestModeAll(PReluTest): + def init_input_shape(self): + self.x_shape = (2, 3, 4, 5) + + def init_attr(self): self.attrs = {'mode': "all"} - class TestCase2(PReluTest): - def initTestCase(self): - self.attrs = {'mode': "channel"} + class TestModeElt(PReluTest): + def init_input_shape(self): + self.x_shape = (3, 2, 5, 10) - class TestCase3(PReluTest): - def initTestCase(self): + def init_attr(self): self.attrs = {'mode': "element"} diff --git a/python/paddle/fluid/tests/unittests/white_list/check_shape_white_list.py b/python/paddle/fluid/tests/unittests/white_list/check_shape_white_list.py index f00c9e811d7..143c12047d0 100644 --- a/python/paddle/fluid/tests/unittests/white_list/check_shape_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/check_shape_white_list.py @@ -24,8 +24,6 @@ NOT_CHECK_OP_LIST = [ 'elementwise_min', 'elementwise_pow', 'fused_elemwise_activation', - # prelu op's input alpha must be 1-d and only has one data in 'all' mode - 'prelu' ] NEED_TO_FIX_OP_LIST = [ -- GitLab