From 7b0692a610ba4c16eee05fab28a489b05688e218 Mon Sep 17 00:00:00 2001 From: juncaipeng <52520497+juncaipeng@users.noreply.github.com> Date: Tue, 21 Jan 2020 11:27:58 +0800 Subject: [PATCH] remove skip_check in test_activation_mkldnn_op, test=develop (#22376) * remove skip_check in test_activation_mkldnn_op, test=develop --- .../mkldnn/test_activation_mkldnn_op.py | 3 +-- .../paddle/fluid/tests/unittests/op_test.py | 19 +++++++++---------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_activation_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_activation_mkldnn_op.py index dcdbb4619b..c988e6275f 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_activation_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_activation_mkldnn_op.py @@ -17,7 +17,7 @@ from __future__ import print_function import unittest import numpy as np import paddle.fluid.core as core -from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci +from paddle.fluid.tests.unittests.op_test import OpTest from paddle.fluid.tests.unittests.test_activation_op import TestRelu, TestTanh, TestSqrt, TestAbs, TestLeakyRelu from mkldnn_op_test import check_if_mkldnn_primitives_exist_in_bwd @@ -111,7 +111,6 @@ class TestMKLDNNAbsDim2(TestAbs): ['X'], 'Out', max_relative_error=0.007, check_dygraph=False) -@skip_check_grad_ci(reason="Use float32 in mkldnn relu op.") class TestMKLDNNReluDim4(TestRelu): def setUp(self): super(TestMKLDNNReluDim4, self).setUp() diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 469ce7cb09..c9cdcf576d 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -208,12 +208,7 @@ class OpTest(unittest.TestCase): return True def is_mkldnn_op_test(): - if (hasattr(cls, "use_mkldnn") and cls.use_mkldnn == True) or \ - (hasattr(cls, "attrs") and "use_mkldnn" in cls.attrs and \ - cls.attrs["use_mkldnn"] == True): - return True - else: - return False + return hasattr(cls, "use_mkldnn") and cls.use_mkldnn == True if not hasattr(cls, "op_type"): raise AssertionError( @@ -321,8 +316,10 @@ class OpTest(unittest.TestCase): def _append_ops(self, block): self.__class__.op_type = self.op_type # for ci check, please not delete it for now - if hasattr(self, "use_mkldnn"): - self.__class__.use_mkldnn = self.use_mkldnn + if (hasattr(self, "use_mkldnn") and self.use_mkldnn == True) or \ + (hasattr(self, "attrs") and "use_mkldnn" in self.attrs and \ + self.attrs["use_mkldnn"] == True): + self.__class__.use_mkldnn = True op_proto = OpProtoHolder.instance().get_op_proto(self.op_type) "infer datatype from inputs and outputs for this test case" self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs) @@ -1189,8 +1186,10 @@ class OpTest(unittest.TestCase): check_dygraph=True, inplace_atol=None): self.__class__.op_type = self.op_type - if hasattr(self, "use_mkldnn"): - self.__class__.use_mkldnn = self.use_mkldnn + if (hasattr(self, "use_mkldnn") and self.use_mkldnn == True) or \ + (hasattr(self, "attrs") and "use_mkldnn" in self.attrs and \ + self.attrs["use_mkldnn"] == True): + self.__class__.use_mkldnn = True places = self._get_places() for place in places: res = self.check_output_with_place(place, atol, no_check_set, -- GitLab