diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_prelu_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_prelu_mkldnn_op.py index 5489bf109dd54aea3440e66811f75960ed117fc7..901aa200a377511bcd7e69cc4309e813a1ff8a47 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_prelu_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_prelu_mkldnn_op.py @@ -18,7 +18,7 @@ import unittest import numpy as np import paddle import paddle.fluid.core as core -from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16 +from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16 def ref_prelu(x, weight, mode): @@ -109,6 +109,7 @@ class TestPReluModeAllAlpha1DOneDNNOp(TestPReluModeAllOneDNNOp): # BF16 TESTS def create_bf16_test_class(parent): + @OpTestTool.skip_if_not_cpu_bf16() class TestPReluBF16OneDNNOp(parent): def set_inputs(self, ): self.inputs = { @@ -143,42 +144,25 @@ def create_bf16_test_class(parent): self.dout = dout def test_check_output(self): - if core.is_compiled_with_cuda(): - self.skipTest( - "OneDNN doesn't support bf16 with CUDA, skipping UT" + - self.__class__.__name__) - elif not core.supports_bfloat16(): - self.skipTest("Core doesn't support bf16, skipping UT" + - self.__class__.__name__) - else: - self.check_output_with_place(core.CPUPlace()) + self.check_output_with_place(core.CPUPlace()) def test_check_grad(self): - if core.is_compiled_with_cuda() or not core.supports_bfloat16(): - self.skipTest( - "Core is compiled with cuda or doesn't support bf16, kipping UT" - + self.__class__.__name__) - else: - self.calculate_grads() - self.check_grad_with_place( - core.CPUPlace(), ["X", "Alpha"], - "Out", - user_defined_grads=[self.dx, self.dalpha], - user_defined_grad_outputs=[ - convert_float_to_uint16(self.dout) - ]) + self.calculate_grads() + self.check_grad_with_place( + core.CPUPlace(), ["X", "Alpha"], + "Out", + user_defined_grads=[self.dx, self.dalpha], + user_defined_grad_outputs=[convert_float_to_uint16(self.dout)]) cls_name = "{0}_{1}".format(parent.__name__, "BF16") TestPReluBF16OneDNNOp.__name__ = cls_name globals()[cls_name] = TestPReluBF16OneDNNOp -#TODO jakpiase -#enable bf16 tests back when oneDNN bf16 class will be ready -#create_bf16_test_class(TestPReluModeChannelOneDNNOp) -#create_bf16_test_class(TestPReluModeElementOneDNNOp) -#create_bf16_test_class(TestPReluModeChannel3DOneDNNOp) -#create_bf16_test_class(TestPReluModeChannelAlpha1DOneDNNOp) +create_bf16_test_class(TestPReluModeChannelOneDNNOp) +create_bf16_test_class(TestPReluModeElementOneDNNOp) +create_bf16_test_class(TestPReluModeChannel3DOneDNNOp) +create_bf16_test_class(TestPReluModeChannelAlpha1DOneDNNOp) if __name__ == "__main__": paddle.enable_static()