未验证 提交 7b0692a6 编写于 作者: J juncaipeng 提交者: GitHub

remove skip_check in test_activation_mkldnn_op, test=develop (#22376)

* remove skip_check in test_activation_mkldnn_op, test=develop
上级 5f655d2c
...@@ -17,7 +17,7 @@ from __future__ import print_function ...@@ -17,7 +17,7 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci from paddle.fluid.tests.unittests.op_test import OpTest
from paddle.fluid.tests.unittests.test_activation_op import TestRelu, TestTanh, TestSqrt, TestAbs, TestLeakyRelu from paddle.fluid.tests.unittests.test_activation_op import TestRelu, TestTanh, TestSqrt, TestAbs, TestLeakyRelu
from mkldnn_op_test import check_if_mkldnn_primitives_exist_in_bwd from mkldnn_op_test import check_if_mkldnn_primitives_exist_in_bwd
...@@ -111,7 +111,6 @@ class TestMKLDNNAbsDim2(TestAbs): ...@@ -111,7 +111,6 @@ class TestMKLDNNAbsDim2(TestAbs):
['X'], 'Out', max_relative_error=0.007, check_dygraph=False) ['X'], 'Out', max_relative_error=0.007, check_dygraph=False)
@skip_check_grad_ci(reason="Use float32 in mkldnn relu op.")
class TestMKLDNNReluDim4(TestRelu): class TestMKLDNNReluDim4(TestRelu):
def setUp(self): def setUp(self):
super(TestMKLDNNReluDim4, self).setUp() super(TestMKLDNNReluDim4, self).setUp()
......
...@@ -208,12 +208,7 @@ class OpTest(unittest.TestCase): ...@@ -208,12 +208,7 @@ class OpTest(unittest.TestCase):
return True return True
def is_mkldnn_op_test(): def is_mkldnn_op_test():
if (hasattr(cls, "use_mkldnn") and cls.use_mkldnn == True) or \ return hasattr(cls, "use_mkldnn") and cls.use_mkldnn == True
(hasattr(cls, "attrs") and "use_mkldnn" in cls.attrs and \
cls.attrs["use_mkldnn"] == True):
return True
else:
return False
if not hasattr(cls, "op_type"): if not hasattr(cls, "op_type"):
raise AssertionError( raise AssertionError(
...@@ -321,8 +316,10 @@ class OpTest(unittest.TestCase): ...@@ -321,8 +316,10 @@ class OpTest(unittest.TestCase):
def _append_ops(self, block): def _append_ops(self, block):
self.__class__.op_type = self.op_type # for ci check, please not delete it for now self.__class__.op_type = self.op_type # for ci check, please not delete it for now
if hasattr(self, "use_mkldnn"): if (hasattr(self, "use_mkldnn") and self.use_mkldnn == True) or \
self.__class__.use_mkldnn = self.use_mkldnn (hasattr(self, "attrs") and "use_mkldnn" in self.attrs and \
self.attrs["use_mkldnn"] == True):
self.__class__.use_mkldnn = True
op_proto = OpProtoHolder.instance().get_op_proto(self.op_type) op_proto = OpProtoHolder.instance().get_op_proto(self.op_type)
"infer datatype from inputs and outputs for this test case" "infer datatype from inputs and outputs for this test case"
self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs) self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
...@@ -1189,8 +1186,10 @@ class OpTest(unittest.TestCase): ...@@ -1189,8 +1186,10 @@ class OpTest(unittest.TestCase):
check_dygraph=True, check_dygraph=True,
inplace_atol=None): inplace_atol=None):
self.__class__.op_type = self.op_type self.__class__.op_type = self.op_type
if hasattr(self, "use_mkldnn"): if (hasattr(self, "use_mkldnn") and self.use_mkldnn == True) or \
self.__class__.use_mkldnn = self.use_mkldnn (hasattr(self, "attrs") and "use_mkldnn" in self.attrs and \
self.attrs["use_mkldnn"] == True):
self.__class__.use_mkldnn = True
places = self._get_places() places = self._get_places()
for place in places: for place in places:
res = self.check_output_with_place(place, atol, no_check_set, res = self.check_output_with_place(place, atol, no_check_set,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册