未验证 提交 68f51239 编写于 作者: J jakpiase 提交者: GitHub

enabled bf16 tests in prelu (#34196)

上级 d8839292
......@@ -18,7 +18,7 @@ import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16
from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16
def ref_prelu(x, weight, mode):
......@@ -109,6 +109,7 @@ class TestPReluModeAllAlpha1DOneDNNOp(TestPReluModeAllOneDNNOp):
# BF16 TESTS
def create_bf16_test_class(parent):
@OpTestTool.skip_if_not_cpu_bf16()
class TestPReluBF16OneDNNOp(parent):
def set_inputs(self, ):
self.inputs = {
......@@ -143,42 +144,25 @@ def create_bf16_test_class(parent):
self.dout = dout
def test_check_output(self):
if core.is_compiled_with_cuda():
self.skipTest(
"OneDNN doesn't support bf16 with CUDA, skipping UT" +
self.__class__.__name__)
elif not core.supports_bfloat16():
self.skipTest("Core doesn't support bf16, skipping UT" +
self.__class__.__name__)
else:
self.check_output_with_place(core.CPUPlace())
self.check_output_with_place(core.CPUPlace())
def test_check_grad(self):
if core.is_compiled_with_cuda() or not core.supports_bfloat16():
self.skipTest(
"Core is compiled with cuda or doesn't support bf16, kipping UT"
+ self.__class__.__name__)
else:
self.calculate_grads()
self.check_grad_with_place(
core.CPUPlace(), ["X", "Alpha"],
"Out",
user_defined_grads=[self.dx, self.dalpha],
user_defined_grad_outputs=[
convert_float_to_uint16(self.dout)
])
self.calculate_grads()
self.check_grad_with_place(
core.CPUPlace(), ["X", "Alpha"],
"Out",
user_defined_grads=[self.dx, self.dalpha],
user_defined_grad_outputs=[convert_float_to_uint16(self.dout)])
cls_name = "{0}_{1}".format(parent.__name__, "BF16")
TestPReluBF16OneDNNOp.__name__ = cls_name
globals()[cls_name] = TestPReluBF16OneDNNOp
#TODO jakpiase
#enable bf16 tests back when oneDNN bf16 class will be ready
#create_bf16_test_class(TestPReluModeChannelOneDNNOp)
#create_bf16_test_class(TestPReluModeElementOneDNNOp)
#create_bf16_test_class(TestPReluModeChannel3DOneDNNOp)
#create_bf16_test_class(TestPReluModeChannelAlpha1DOneDNNOp)
create_bf16_test_class(TestPReluModeChannelOneDNNOp)
create_bf16_test_class(TestPReluModeElementOneDNNOp)
create_bf16_test_class(TestPReluModeChannel3DOneDNNOp)
create_bf16_test_class(TestPReluModeChannelAlpha1DOneDNNOp)
if __name__ == "__main__":
paddle.enable_static()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册