未验证 提交 fb15aa1c 编写于 作者: C cc 提交者: GitHub

Ngraph op tests skip check grad ci (#22688)

* ngraph op test skip check grad ci, test=develop
上级 1b561da1
......@@ -17,11 +17,10 @@ from __future__ import print_function
import unittest, sys
sys.path.append("../")
import numpy as np
from op_test import OpTest, skip_check_grad_ci
from op_test import OpTest
from test_activation_op import TestAbs, TestGelu, TestSigmoid, TestSquare, TestRelu, TestTanh
# NOTE(review): this is a diff excerpt — original indentation was lost in
# extraction and setUp may continue past this excerpt; confirm against the
# full file before relying on this block's boundaries.
# Presumably marks the ngraph relu test as exempt from the check_grad CI
# requirement because the ngraph relu kernel runs in float32 (see the
# decorator's reason string) — TODO confirm against op_test.skip_check_grad_ci.
@skip_check_grad_ci(reason="Use float32 in ngraph relu op.")
class TestNGRAPHReluDim4(TestRelu):
def setUp(self):
# Delegate all test setup to the inherited TestRelu.setUp unchanged.
super(TestNGRAPHReluDim4, self).setUp()
......
......@@ -210,6 +210,9 @@ class OpTest(unittest.TestCase):
def is_mkldnn_op_test():
    # True only when the test class has explicitly opted into MKL-DNN,
    # i.e. a `use_mkldnn` attribute exists and compares equal to True.
    # (`cls` is the enclosing classmethod's class object.)
    return getattr(cls, "use_mkldnn", False) == True
def is_ngraph_op_test():
    # True only when the test class has explicitly opted into nGraph,
    # i.e. a `use_ngraph` attribute exists and compares equal to True.
    # (`cls` is the enclosing classmethod's class object.)
    return getattr(cls, "use_ngraph", False) == True
if not hasattr(cls, "op_type"):
raise AssertionError(
"This test do not have op_type in class attrs, "
......@@ -229,6 +232,7 @@ class OpTest(unittest.TestCase):
if cls.dtype in [np.float32, np.float64] \
and cls.op_type not in op_accuracy_white_list.NO_FP64_CHECK_GRAD_OP_LIST \
and not hasattr(cls, 'exist_fp64_check_grad') \
and not is_ngraph_op_test() \
and not is_mkldnn_op_test():
raise AssertionError(
"This test of %s op needs check_grad with fp64 precision." %
......@@ -320,6 +324,10 @@ class OpTest(unittest.TestCase):
(hasattr(self, "attrs") and "use_mkldnn" in self.attrs and \
self.attrs["use_mkldnn"] == True):
self.__class__.use_mkldnn = True
if fluid.core.is_compiled_with_ngraph() and \
fluid.core.globals()['FLAGS_use_ngraph']:
self.__class__.use_ngraph = True
op_proto = OpProtoHolder.instance().get_op_proto(self.op_type)
"infer datatype from inputs and outputs for this test case"
self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
......@@ -936,14 +944,16 @@ class OpTest(unittest.TestCase):
attrs_use_mkldnn = hasattr(
self,
'attrs') and bool(self.attrs.get('use_mkldnn', False))
flags_use_ngraph = fluid.core.globals()["FLAGS_use_ngraph"]
attrs_use_ngraph = hasattr(
self,
'attrs') and bool(self.attrs.get('use_ngraph', False))
if flags_use_mkldnn or attrs_use_mkldnn:
warnings.warn(
"check inplace_grad for ops using mkldnn is not supported"
)
continue
use_ngraph = fluid.core.is_compiled_with_ngraph(
) and fluid.core.globals()["FLAGS_use_ngraph"]
if use_ngraph:
if flags_use_ngraph or attrs_use_ngraph:
warnings.warn(
"check inplace_grad for ops using ngraph is not supported"
)
......@@ -1190,6 +1200,10 @@ class OpTest(unittest.TestCase):
(hasattr(self, "attrs") and "use_mkldnn" in self.attrs and \
self.attrs["use_mkldnn"] == True):
self.__class__.use_mkldnn = True
if fluid.core.is_compiled_with_ngraph() and \
fluid.core.globals()['FLAGS_use_ngraph']:
self.__class__.use_ngraph = True
places = self._get_places()
for place in places:
res = self.check_output_with_place(place, atol, no_check_set,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册