diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py
index 406fdb672d498abae4a45b5f6276b8d5dbd95998..469ce7cb09e20e0b1526bfc64a883b374d8fe61a 100644
--- a/python/paddle/fluid/tests/unittests/op_test.py
+++ b/python/paddle/fluid/tests/unittests/op_test.py
@@ -198,7 +198,7 @@ class OpTest(unittest.TestCase):
             all_op_kernels = core._get_all_register_op_kernels()
             grad_op = op_type + '_grad'
             if grad_op in all_op_kernels.keys():
-                if hasattr(cls, "use_mkldnn") and cls.use_mkldnn == True:
+                if is_mkldnn_op_test():
                     grad_op_kernels = all_op_kernels[grad_op]
                     for grad_op_kernel in grad_op_kernels:
                         if 'MKLDNN' in grad_op_kernel:
@@ -207,6 +207,14 @@ class OpTest(unittest.TestCase):
                     return False
             return True
 
+        def is_mkldnn_op_test():
+            if (hasattr(cls, "use_mkldnn") and cls.use_mkldnn == True) or \
+                (hasattr(cls, "attrs") and "use_mkldnn" in cls.attrs and \
+                cls.attrs["use_mkldnn"] == True):
+                return True
+            else:
+                return False
+
         if not hasattr(cls, "op_type"):
             raise AssertionError(
                 "This test do not have op_type in class attrs, "
@@ -226,7 +234,7 @@ class OpTest(unittest.TestCase):
         if cls.dtype in [np.float32, np.float64] \
             and cls.op_type not in op_accuracy_white_list.NO_FP64_CHECK_GRAD_OP_LIST \
             and not hasattr(cls, 'exist_fp64_check_grad') \
-            and (not hasattr(cls, "use_mkldnn") or cls.use_mkldnn == False):
+            and not is_mkldnn_op_test():
             raise AssertionError(
                 "This test of %s op needs check_grad with fp64 precision." %
                 cls.op_type)
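
For context, a minimal standalone sketch of the check the patch consolidates into the nested is_mkldnn_op_test() helper: a test class counts as an MKLDNN op test if it either sets a use_mkldnn class attribute or carries use_mkldnn=True in its attrs dict. The helper name uses_mkldnn, the dummy test classes, and the truthiness-based simplification below are illustrative assumptions, not part of the PaddlePaddle sources.

# Illustrative sketch (not the PaddlePaddle code): the same two-way detection
# that the diff factors into is_mkldnn_op_test(), written as a standalone
# function so it can be exercised without an OpTest subclass.
def uses_mkldnn(test_cls):
    """Return True if the test class opts into MKLDNN, either via a
    use_mkldnn class attribute or via an attrs dict entry."""
    if getattr(test_cls, "use_mkldnn", False):
        return True
    attrs = getattr(test_cls, "attrs", None)
    return bool(attrs) and bool(attrs.get("use_mkldnn", False))


class _FlagStyleTest:           # hypothetical: sets the flag as a class attribute
    use_mkldnn = True


class _AttrsStyleTest:          # hypothetical: sets the flag through the attrs dict
    attrs = {"use_mkldnn": True}


class _PlainTest:               # hypothetical: no MKLDNN involvement at all
    pass


if __name__ == "__main__":
    assert uses_mkldnn(_FlagStyleTest)
    assert uses_mkldnn(_AttrsStyleTest)
    assert not uses_mkldnn(_PlainTest)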