Unverified · Commit 68f51239, authored by jakpiase, committed by GitHub

enabled bf16 tests in prelu (#34196)

Parent: d8839292
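Summary of the diff below: the hand-rolled CUDA/bfloat16 skip logic inside the generated bf16 test class is removed, the shared OpTestTool.skip_if_not_cpu_bf16() decorator is applied to the class instead, and the four previously commented-out bf16 test registrations are re-enabled.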
@@ -18,7 +18,7 @@ import unittest
 import numpy as np
 import paddle
 import paddle.fluid.core as core
-from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16
+from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16
 
 
 def ref_prelu(x, weight, mode):
@@ -109,6 +109,7 @@ class TestPReluModeAllAlpha1DOneDNNOp(TestPReluModeAllOneDNNOp):
 
 # BF16 TESTS
 def create_bf16_test_class(parent):
+    @OpTestTool.skip_if_not_cpu_bf16()
     class TestPReluBF16OneDNNOp(parent):
         def set_inputs(self, ):
             self.inputs = {
@@ -143,42 +144,25 @@ def create_bf16_test_class(parent):
             self.dout = dout
 
         def test_check_output(self):
-            if core.is_compiled_with_cuda():
-                self.skipTest(
-                    "OneDNN doesn't support bf16 with CUDA, skipping UT" +
-                    self.__class__.__name__)
-            elif not core.supports_bfloat16():
-                self.skipTest("Core doesn't support bf16, skipping UT" +
-                              self.__class__.__name__)
-            else:
-                self.check_output_with_place(core.CPUPlace())
+            self.check_output_with_place(core.CPUPlace())
 
         def test_check_grad(self):
-            if core.is_compiled_with_cuda() or not core.supports_bfloat16():
-                self.skipTest(
-                    "Core is compiled with cuda or doesn't support bf16, kipping UT"
-                    + self.__class__.__name__)
-            else:
-                self.calculate_grads()
-                self.check_grad_with_place(
-                    core.CPUPlace(), ["X", "Alpha"],
-                    "Out",
-                    user_defined_grads=[self.dx, self.dalpha],
-                    user_defined_grad_outputs=[
-                        convert_float_to_uint16(self.dout)
-                    ])
+            self.calculate_grads()
+            self.check_grad_with_place(
+                core.CPUPlace(), ["X", "Alpha"],
+                "Out",
+                user_defined_grads=[self.dx, self.dalpha],
+                user_defined_grad_outputs=[convert_float_to_uint16(self.dout)])
 
     cls_name = "{0}_{1}".format(parent.__name__, "BF16")
     TestPReluBF16OneDNNOp.__name__ = cls_name
     globals()[cls_name] = TestPReluBF16OneDNNOp
 
 
-#TODO jakpiase
-#enable bf16 tests back when oneDNN bf16 class will be ready
-#create_bf16_test_class(TestPReluModeChannelOneDNNOp)
-#create_bf16_test_class(TestPReluModeElementOneDNNOp)
-#create_bf16_test_class(TestPReluModeChannel3DOneDNNOp)
-#create_bf16_test_class(TestPReluModeChannelAlpha1DOneDNNOp)
+create_bf16_test_class(TestPReluModeChannelOneDNNOp)
+create_bf16_test_class(TestPReluModeElementOneDNNOp)
+create_bf16_test_class(TestPReluModeChannel3DOneDNNOp)
+create_bf16_test_class(TestPReluModeChannelAlpha1DOneDNNOp)
 
 if __name__ == "__main__":
     paddle.enable_static()
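For context, OpTestTool.skip_if_not_cpu_bf16() consolidates at class level the same checks the deleted test bodies performed by hand. A minimal sketch of such a decorator, reusing the core.is_compiled_with_cuda() / core.supports_bfloat16() checks from the removed code (an illustrative re-implementation, not PaddlePaddle's actual op_test helper):

    import unittest

    import paddle.fluid.core as core


    class OpTestTool:
        @classmethod
        def skip_if(cls, condition, reason):
            # unittest.skipIf marks the whole decorated class as skipped
            # when `condition` is truthy, so no test method ever runs.
            return unittest.skipIf(condition, reason)

        @classmethod
        def skip_if_not_cpu_bf16(cls):
            # Skip when the build targets CUDA or the CPU lacks bfloat16
            # support; mirrors the removed in-method checks. The reason
            # string is illustrative.
            return cls.skip_if(
                core.is_compiled_with_cuda() or not core.supports_bfloat16(),
                "Place does not support BF16 evaluation")

Applying the decorator once to the generated class replaces the repeated skipTest() branches in test_check_output and test_check_grad, so each test body shrinks to the success path only.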