未验证 提交 171ed2cf 编写于 作者: A arlesniak 提交者: GitHub

Added OpTestTool for BF16 (#33977)

* Added OpTestTool for BF16 convenience

* fixes after review, names changed to snake case.

* fixes after review, naming reflects cpu.
上级 4ce66826
...@@ -17,12 +17,14 @@ from __future__ import print_function ...@@ -17,12 +17,14 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16 from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16
import paddle.fluid.core as core import paddle.fluid.core as core
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.framework as framework import paddle.fluid.framework as framework
paddle.enable_static()
def reference_matmul(X, Y, transpose_X=False, transpose_Y=False): def reference_matmul(X, Y, transpose_X=False, transpose_Y=False):
"""Reference forward implementation using np.matmul.""" """Reference forward implementation using np.matmul."""
...@@ -236,6 +238,7 @@ class TestMatMulV2MatrixXMatrix5DTranposeYOneDNNOp( ...@@ -236,6 +238,7 @@ class TestMatMulV2MatrixXMatrix5DTranposeYOneDNNOp(
# BF16 TESTS # BF16 TESTS
def create_bf16_test_class(parent): def create_bf16_test_class(parent):
@OpTestTool.skip_if_not_cpu_bf16()
class TestMatMulV2Bf16OneDNNOp(parent): class TestMatMulV2Bf16OneDNNOp(parent):
def set_inputs(self, x, y): def set_inputs(self, x, y):
self.inputs = { self.inputs = {
...@@ -247,15 +250,7 @@ def create_bf16_test_class(parent): ...@@ -247,15 +250,7 @@ def create_bf16_test_class(parent):
self.attrs['mkldnn_data_type'] = "bfloat16" self.attrs['mkldnn_data_type'] = "bfloat16"
def test_check_output(self): def test_check_output(self):
if core.is_compiled_with_cuda(): self.check_output_with_place(core.CPUPlace())
self.skipTest(
"OneDNN doesn't support bf16 with CUDA, skipping UT" +
self.__class__.__name__)
elif not core.supports_bfloat16():
self.skipTest("Core doesn't support bf16, skipping UT" +
self.__class__.__name__)
else:
self.check_output_with_place(core.CPUPlace())
def test_check_grad(self): def test_check_grad(self):
pass pass
......
...@@ -16,16 +16,14 @@ from __future__ import print_function ...@@ -16,16 +16,14 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16 from paddle.fluid.tests.unittests.op_test import OpTestTool, OpTest, skip_check_grad_ci, convert_float_to_uint16
import paddle.fluid.core as core import paddle.fluid.core as core
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle import paddle
paddle.enable_static()
@unittest.skipIf(not core.supports_bfloat16(), @OpTestTool.skip_if_not_cpu_bf16()
"place does not support BF16 evaluation")
@unittest.skipIf(core.is_compiled_with_cuda(),
"core is compiled with CUDA which has no BF implementation")
class TestReduceSumDefaultBF16OneDNNOp(OpTest): class TestReduceSumDefaultBF16OneDNNOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "reduce_sum" self.op_type = "reduce_sum"
......
...@@ -32,7 +32,7 @@ import paddle.fluid.core as core ...@@ -32,7 +32,7 @@ import paddle.fluid.core as core
from paddle.fluid.backward import append_backward from paddle.fluid.backward import append_backward
from paddle.fluid.op import Operator from paddle.fluid.op import Operator
from paddle.fluid.executor import Executor from paddle.fluid.executor import Executor
from paddle.fluid.framework import Program, OpProtoHolder, Variable from paddle.fluid.framework import Program, OpProtoHolder, Variable, _current_expected_place
from paddle.fluid.tests.unittests.testsuite import ( from paddle.fluid.tests.unittests.testsuite import (
create_op, create_op,
set_input, set_input,
...@@ -1783,3 +1783,16 @@ class OpTest(unittest.TestCase): ...@@ -1783,3 +1783,16 @@ class OpTest(unittest.TestCase):
fetch_list, fetch_list,
scope=scope, scope=scope,
return_numpy=False))) return_numpy=False)))
class OpTestTool:
@classmethod
def skip_if(cls, condition: object, reason: str):
return unittest.skipIf(condition, reason)
@classmethod
def skip_if_not_cpu_bf16(cls):
return OpTestTool.skip_if(
not (isinstance(_current_expected_place(), core.CPUPlace) and
core.supports_bfloat16()),
"Place does not support BF16 evaluation")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册