未验证 提交 fb067fa4 编写于 作者: J juncaipeng 提交者: GitHub

Modify test framework, test=develop (#21789)

*use dtype to determine whether check_grade is needed, and delete useless class
上级 557bce77
...@@ -142,14 +142,14 @@ def get_numeric_gradient(place, ...@@ -142,14 +142,14 @@ def get_numeric_gradient(place,
return gradient_flat.reshape(tensor_to_check.shape()) return gradient_flat.reshape(tensor_to_check.shape())
class OpTestBase(unittest.TestCase): class OpTest(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
'''Fix random seeds to remove randomness from tests''' '''Fix random seeds to remove randomness from tests'''
cls._np_rand_state = np.random.get_state() cls._np_rand_state = np.random.get_state()
cls._py_rand_state = random.getstate() cls._py_rand_state = random.getstate()
cls.call_once = False cls.call_once = False
cls.dtype = "float32" cls.dtype = None
cls.outputs = {} cls.outputs = {}
np.random.seed(123) np.random.seed(123)
...@@ -165,6 +165,31 @@ class OpTestBase(unittest.TestCase): ...@@ -165,6 +165,31 @@ class OpTestBase(unittest.TestCase):
_set_use_system_allocator(cls._use_system_allocator) _set_use_system_allocator(cls._use_system_allocator)
if not hasattr(cls, "op_type"):
raise AssertionError(
"This test do not have op_type in class attrs,"
" please set self.__class__.op_type=the_real_op_type manually.")
# cases and ops do no need check_grad
if cls.__name__ in op_check_grad_white_list.NO_NEED_CHECK_GRAD_CASES \
or cls.op_type in op_check_grad_white_list.EMPTY_GRAD_OP_LIST:
return
# In order to pass ci, and case in NO_FP64_CHECK_GRAD_CASES and op in
# NO_FP64_CHECK_GRAD_OP_LIST should be fixed
if cls.op_type in op_accuracy_white_list.NO_FP64_CHECK_GRAD_OP_LIST:
return
if cls.dtype is None or (cls.dtype in [np.float16, np.int64, np.int32, np.int16] \
and not hasattr(cls, "exist_check_grad")):
raise AssertionError("This test of %s op needs check_grad." %
cls.op_type)
if cls.dtype in [np.float32, np.float64] and \
not hasattr(cls, 'exist_fp64_check_grad'):
raise AssertionError("This test of %s op needs fp64 check_grad." %
cls.op_type)
def try_call_once(self, data_type): def try_call_once(self, data_type):
if not self.call_once: if not self.call_once:
self.call_once = True self.call_once = True
...@@ -205,13 +230,15 @@ class OpTestBase(unittest.TestCase): ...@@ -205,13 +230,15 @@ class OpTestBase(unittest.TestCase):
dtype_list = [ dtype_list = [
np.dtype(np.float64), np.dtype(np.float32), np.dtype(np.float16), np.dtype(np.float64), np.dtype(np.float32), np.dtype(np.float16),
np.dtype(np.int64), np.dtype(np.int32), np.dtype(np.int16), np.dtype(np.int64), np.dtype(np.int32), np.dtype(np.int16),
np.dtype(np.int8) np.dtype(np.int8), np.dtype(np.uint8), np.dtype(np.bool)
] ]
# check the dtype in dtype_list in order, select the first dtype that in dtype_set # check the dtype in dtype_list in order, select the first dtype that in dtype_set
for dtype in dtype_list: for dtype in dtype_list:
if dtype in dtype_set: if dtype in dtype_set:
self.dtype = dtype self.dtype = dtype
break break
# save dtype in class attr
self.__class__.dtype = self.dtype
def feed_var(self, input_vars, place): def feed_var(self, input_vars, place):
feed_map = {} feed_map = {}
...@@ -1030,6 +1057,7 @@ class OpTestBase(unittest.TestCase): ...@@ -1030,6 +1057,7 @@ class OpTestBase(unittest.TestCase):
check_dygraph=True, check_dygraph=True,
inplace_atol=None, inplace_atol=None,
check_compile_vs_runtime=False): check_compile_vs_runtime=False):
self.__class__.op_type = self.op_type
places = self._get_places() places = self._get_places()
for place in places: for place in places:
res = self.check_output_with_place(place, atol, no_check_set, res = self.check_output_with_place(place, atol, no_check_set,
...@@ -1068,6 +1096,13 @@ class OpTestBase(unittest.TestCase): ...@@ -1068,6 +1096,13 @@ class OpTestBase(unittest.TestCase):
self.assertLessEqual(max_diff, max_relative_error, err_msg()) self.assertLessEqual(max_diff, max_relative_error, err_msg())
def _check_grad_helper(self):
self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
self.__class__.op_type = self.op_type
self.__class__.exist_check_grad = True
if self.dtype == np.float64:
self.__class__.exist_fp64_check_grad = True
def check_grad(self, def check_grad(self,
inputs_to_check, inputs_to_check,
output_names, output_names,
...@@ -1077,6 +1112,7 @@ class OpTestBase(unittest.TestCase): ...@@ -1077,6 +1112,7 @@ class OpTestBase(unittest.TestCase):
max_relative_error=0.005, max_relative_error=0.005,
user_defined_grads=None, user_defined_grads=None,
check_dygraph=True): check_dygraph=True):
self._check_grad_helper()
places = self._get_places() places = self._get_places()
for place in places: for place in places:
self.check_grad_with_place(place, inputs_to_check, output_names, self.check_grad_with_place(place, inputs_to_check, output_names,
...@@ -1099,6 +1135,8 @@ class OpTestBase(unittest.TestCase): ...@@ -1099,6 +1135,8 @@ class OpTestBase(unittest.TestCase):
op_outputs = self.outputs if hasattr(self, "outputs") else dict() op_outputs = self.outputs if hasattr(self, "outputs") else dict()
op_attrs = self.attrs if hasattr(self, "attrs") else dict() op_attrs = self.attrs if hasattr(self, "attrs") else dict()
self._check_grad_helper()
cache_list = None cache_list = None
if hasattr(self, "cache_name_list"): if hasattr(self, "cache_name_list"):
cache_list = self.cache_name_list cache_list = self.cache_name_list
...@@ -1289,161 +1327,3 @@ class OpTestBase(unittest.TestCase): ...@@ -1289,161 +1327,3 @@ class OpTestBase(unittest.TestCase):
return list( return list(
map(np.array, map(np.array,
executor.run(prog, feed_dict, fetch_list, return_numpy=False))) executor.run(prog, feed_dict, fetch_list, return_numpy=False)))
'''
The op test with int8 precision should inherit OpTestInt8.
'''
class OpTestInt8(OpTestBase):
pass
'''
The op test with float16 precision should inherit OpTestFp16,
which requires the test to call check_grad.
'''
class OpTestFp16(OpTestBase):
def check_output(self,
atol=1e-5,
no_check_set=None,
equal_nan=False,
check_dygraph=True,
inplace_atol=None,
check_compile_vs_runtime=False):
self.__class__.op_type = self.op_type
OpTestBase.check_output(self, atol, no_check_set, equal_nan,
check_dygraph, inplace_atol,
check_compile_vs_runtime)
def _check_grad_helper(self):
self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
assert self.dtype == np.float16, "The dtype of this test should be float16."
self.__class__.op_type = self.op_type
self.__class__.exist_check_grad = True
def check_grad(self,
inputs_to_check,
output_names,
no_grad_set=None,
numeric_grad_delta=0.005,
in_place=False,
max_relative_error=0.005,
user_defined_grads=None,
check_dygraph=True):
self._check_grad_helper()
OpTestBase.check_grad(self, inputs_to_check, output_names, no_grad_set,
numeric_grad_delta, in_place, max_relative_error,
user_defined_grads, check_dygraph)
def check_grad_with_place(self,
place,
inputs_to_check,
output_names,
no_grad_set=None,
numeric_grad_delta=0.005,
in_place=False,
max_relative_error=0.005,
user_defined_grads=None,
check_dygraph=True):
self._check_grad_helper()
OpTestBase.check_grad_with_place(
self, place, inputs_to_check, output_names, no_grad_set,
numeric_grad_delta, in_place, max_relative_error,
user_defined_grads, check_dygraph)
@classmethod
def tearDownClass(cls):
"""Restore random seeds"""
np.random.set_state(cls._np_rand_state)
random.setstate(cls._py_rand_state)
if not hasattr(cls, "exist_check_grad") \
and cls.__name__ not in op_check_grad_white_list.NO_NEED_CHECK_GRAD_CASES \
and cls.op_type not in op_check_grad_white_list.EMPTY_GRAD_OP_LIST:
raise AssertionError("This test of %s op needs check_grad." %
cls.op_type)
'''
The op test with float32/64 precision should inherit OpTest,
which requires the test to call check_grad with float64 precision.
'''
class OpTest(OpTestBase):
def check_output(self,
atol=1e-5,
no_check_set=None,
equal_nan=False,
check_dygraph=True,
inplace_atol=None,
check_compile_vs_runtime=False):
self.__class__.op_type = self.op_type
OpTestBase.check_output(self, atol, no_check_set, equal_nan,
check_dygraph, inplace_atol,
check_compile_vs_runtime)
def _check_grad_helper(self):
self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
assert self.dtype in [np.float16, np.float32, np.float64], \
"self.dtype = %s." % self.dtype
if self.dtype == np.float16 and \
self.op_type not in op_accuracy_white_list.FP16_CHECK_OP_LIST:
raise AssertionError("The dtype of this test should be float32 "
"or float64. op: %s dtype: %s." %
(self.op_type, self.dtype))
self.__class__.op_type = self.op_type
if self.dtype == np.float64:
self.__class__.exist_fp64_check_grad = True
def check_grad(self,
inputs_to_check,
output_names,
no_grad_set=None,
numeric_grad_delta=0.005,
in_place=False,
max_relative_error=0.005,
user_defined_grads=None,
check_dygraph=True):
self._check_grad_helper()
OpTestBase.check_grad(self, inputs_to_check, output_names, no_grad_set,
numeric_grad_delta, in_place, max_relative_error,
user_defined_grads, check_dygraph)
def check_grad_with_place(self,
place,
inputs_to_check,
output_names,
no_grad_set=None,
numeric_grad_delta=0.005,
in_place=False,
max_relative_error=0.005,
user_defined_grads=None,
check_dygraph=True):
self._check_grad_helper()
OpTestBase.check_grad_with_place(
self, place, inputs_to_check, output_names, no_grad_set,
numeric_grad_delta, in_place, max_relative_error,
user_defined_grads, check_dygraph)
@classmethod
def tearDownClass(cls):
"""Restore random seeds"""
np.random.set_state(cls._np_rand_state)
random.setstate(cls._py_rand_state)
# only for pass ci, but cases in NO_FP64_CHECK_GRAD_CASES
# and op in NO_FP64_CHECK_GRAD_OP_LIST should be fixed
if cls.__name__ not in op_accuracy_white_list.NO_FP64_CHECK_GRAD_CASES \
and not hasattr(cls, 'exist_fp64_check_grad') \
and cls.__name__ not in op_check_grad_white_list.NO_NEED_CHECK_GRAD_CASES \
and cls.op_type not in op_check_grad_white_list.EMPTY_GRAD_OP_LIST \
and cls.op_type not in op_accuracy_white_list.NO_FP64_CHECK_GRAD_OP_LIST:
raise AssertionError("This test of %s op needs fp64 check_grad." %
cls.op_type)
...@@ -108,6 +108,7 @@ def conv2dtranspose_forward_naive(input_, filter_, attrs): ...@@ -108,6 +108,7 @@ def conv2dtranspose_forward_naive(input_, filter_, attrs):
class TestConv2dTransposeOp(OpTest): class TestConv2dTransposeOp(OpTest):
def setUp(self): def setUp(self):
# init as conv transpose # init as conv transpose
self.dtype = np.float32
self.is_test = False self.is_test = False
self.use_cudnn = False self.use_cudnn = False
self.use_mkldnn = False self.use_mkldnn = False
......
...@@ -34,6 +34,7 @@ def fsp_matrix(a, b): ...@@ -34,6 +34,7 @@ def fsp_matrix(a, b):
return np.mean(a_r * b_r, axis=1) return np.mean(a_r * b_r, axis=1)
@unittest.skip("Disable temporarily.")
class TestFSPOp(OpTest): class TestFSPOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "fsp" self.op_type = "fsp"
...@@ -49,11 +50,9 @@ class TestFSPOp(OpTest): ...@@ -49,11 +50,9 @@ class TestFSPOp(OpTest):
self.a_shape = (2, 3, 5, 6) self.a_shape = (2, 3, 5, 6)
self.b_shape = (2, 4, 5, 6) self.b_shape = (2, 4, 5, 6)
@unittest.skip("Disable temporarily.")
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
@unittest.skip("Disable temporarily.")
def test_check_grad_normal(self): def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out') self.check_grad(['X', 'Y'], 'Out')
......
...@@ -12,16 +12,6 @@ ...@@ -12,16 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# For op in FP16_CHECK_OP_LIST, the op test of fp16 precision should inherit OpTestFp16
FP16_CHECK_OP_LIST = [
'abs', 'acos', 'asin', 'atan', 'brelu', 'concat', 'cos', 'elementwise_div',
'elementwise_mul', 'elu', 'exp', 'gelu', 'hard_shrink', 'hard_swish', 'log',
'logsigmoid', 'mean', 'mul', 'pad', 'pool2d', 'pow', 'reciprocal', 'relu',
'relu6', 'scale', 'sigmoid', 'sin', 'slice', 'soft_relu', 'softmax',
'softmax_with_cross_entropy', 'softshrink', 'softsign', 'sqrt', 'square',
'stanh', 'sum', 'swish', 'tanh', 'tanh_shrink', 'thresholded_relu'
]
# For op in NO_FP64_CHECK_GRAD_OP_LIST, the op test requires check_grad with fp64 precision # For op in NO_FP64_CHECK_GRAD_OP_LIST, the op test requires check_grad with fp64 precision
NO_FP64_CHECK_GRAD_OP_LIST = [ NO_FP64_CHECK_GRAD_OP_LIST = [
'abs', 'acos', 'add_position_encoding', 'affine_grid', 'asin', 'atan', 'abs', 'acos', 'add_position_encoding', 'affine_grid', 'asin', 'atan',
...@@ -61,6 +51,3 @@ NO_FP64_CHECK_GRAD_OP_LIST = [ ...@@ -61,6 +51,3 @@ NO_FP64_CHECK_GRAD_OP_LIST = [
'unsqueeze', 'unsqueeze2', 'unstack', 'var_conv_2d', 'warpctc', 'unsqueeze', 'unsqueeze2', 'unstack', 'var_conv_2d', 'warpctc',
'yolov3_loss' 'yolov3_loss'
] ]
# For cases in NO_FP64_CHECK_GRAD_CASES, the op test requires check_grad with fp64 precision
NO_FP64_CHECK_GRAD_CASES = ['TestFSPOp']
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册