未验证 提交 2143bd57 编写于 作者: Z zhupengyang 提交者: GitHub

enhance shape check if use user_defined_grads in check_grad (#22722)

上级 fa449c88
......@@ -70,8 +70,6 @@ def get_numeric_gradient(place,
tensor_to_check = scope.find_var(input_to_check).get_tensor()
tensor_size = product(tensor_to_check.shape())
if tensor_size < 100:
get_numeric_gradient.is_large_shape = False
tensor_to_check_dtype = tensor_to_check._dtype()
if tensor_to_check_dtype == core.VarDesc.VarType.FP32:
tensor_to_check_dtype = np.float32
......@@ -178,14 +176,13 @@ class OpTest(unittest.TestCase):
cls.call_once = False
cls.dtype = None
cls.outputs = {}
cls.input_shape_is_large = True
np.random.seed(123)
random.seed(124)
cls._use_system_allocator = _set_use_system_allocator(True)
get_numeric_gradient.is_large_shape = True
@classmethod
def tearDownClass(cls):
"""Restore random seeds"""
......@@ -238,7 +235,7 @@ class OpTest(unittest.TestCase):
"This test of %s op needs check_grad with fp64 precision." %
cls.op_type)
if not get_numeric_gradient.is_large_shape \
if not cls.input_shape_is_large \
and cls.op_type not in check_shape_white_list.NEED_TO_FIX_OP_LIST:
raise AssertionError(
"Input's shape should be large than or equal to 100 for " +
......@@ -1319,6 +1316,14 @@ class OpTest(unittest.TestCase):
raise AssertionError("no_grad_set must be None, op_type is " +
self.op_type + " Op.")
for input_to_check in inputs_to_check:
set_input(self.scope, self.op, self.inputs, place)
tensor_to_check = self.scope.find_var(input_to_check).get_tensor()
tensor_size = six.moves.reduce(lambda a, b: a * b,
tensor_to_check.shape(), 1)
if tensor_size < 100:
self.__class__.input_shape_is_large = False
if not type(output_names) is list:
output_names = [output_names]
......
......@@ -30,4 +30,5 @@ NEED_TO_FIX_OP_LIST = [
'soft_relu',
'squared_l2_distance',
'tree_conv',
'cvm',
]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册