diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index a9c66ec2396d748853e507cd1e95b945a0182ffa..0e4c0522df9a7069320f1396309266e64d6ab10a 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -70,8 +70,6 @@ def get_numeric_gradient(place, tensor_to_check = scope.find_var(input_to_check).get_tensor() tensor_size = product(tensor_to_check.shape()) - if tensor_size < 100: - get_numeric_gradient.is_large_shape = False tensor_to_check_dtype = tensor_to_check._dtype() if tensor_to_check_dtype == core.VarDesc.VarType.FP32: tensor_to_check_dtype = np.float32 @@ -178,14 +176,13 @@ class OpTest(unittest.TestCase): cls.call_once = False cls.dtype = None cls.outputs = {} + cls.input_shape_is_large = True np.random.seed(123) random.seed(124) cls._use_system_allocator = _set_use_system_allocator(True) - get_numeric_gradient.is_large_shape = True - @classmethod def tearDownClass(cls): """Restore random seeds""" @@ -238,7 +235,7 @@ class OpTest(unittest.TestCase): "This test of %s op needs check_grad with fp64 precision." % cls.op_type) - if not get_numeric_gradient.is_large_shape \ + if not cls.input_shape_is_large \ and cls.op_type not in check_shape_white_list.NEED_TO_FIX_OP_LIST: raise AssertionError( "Input's shape should be large than or equal to 100 for " + @@ -1319,6 +1316,14 @@ class OpTest(unittest.TestCase): raise AssertionError("no_grad_set must be None, op_type is " + self.op_type + " Op.") + for input_to_check in inputs_to_check: + set_input(self.scope, self.op, self.inputs, place) + tensor_to_check = self.scope.find_var(input_to_check).get_tensor() + tensor_size = six.moves.reduce(lambda a, b: a * b, + tensor_to_check.shape(), 1) + if tensor_size < 100: + self.__class__.input_shape_is_large = False + if not type(output_names) is list: output_names = [output_names] diff --git a/python/paddle/fluid/tests/unittests/white_list/check_shape_white_list.py b/python/paddle/fluid/tests/unittests/white_list/check_shape_white_list.py index 9061daa0fe7958dc3057f40daf8b637baf371595..317abe124ea33c01491422941cb7c41bf59d9ace 100644 --- a/python/paddle/fluid/tests/unittests/white_list/check_shape_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/check_shape_white_list.py @@ -30,4 +30,5 @@ NEED_TO_FIX_OP_LIST = [ 'soft_relu', 'squared_l2_distance', 'tree_conv', + 'cvm', ]