未验证 提交 7c542c29 编写于 作者: Z zhangbo9674 提交者: GitHub

Improve some control conditions of bfloat16 OP unittest (#45639)

上级 b12c27eb
......@@ -108,8 +108,10 @@ class TestFP16ElementwiseAddOp(TestElementwiseAddOp):
@unittest.skipIf(
not core.is_compiled_with_cuda() or core.cudnn_version() < 8100,
"core is not compiled with CUDA and cudnn version need larger than 8.1.0")
not core.is_compiled_with_cuda() or core.cudnn_version() < 8100
or paddle.device.cuda.get_device_capability()[0] < 8,
"only support compiled with CUDA and cudnn version need larger than 8.1.0 and device's compute capability is at least 8.0"
)
class TestBF16ElementwiseAddOp(OpTest):
def setUp(self):
......
......@@ -62,9 +62,11 @@ class TestElementwiseOp(OpTest):
no_grad_set=set('Y'))
@unittest.skipIf(
core.is_compiled_with_cuda() and core.cudnn_version() < 8100,
"run test when gpu is availble and the minimum cudnn version is 8.1.0.")
@unittest.skipIf(core.is_compiled_with_cuda() and (
core.cudnn_version() < 8100
or paddle.device.cuda.get_device_capability()[0] < 8
), "run test when gpu is availble and the minimum cudnn version is 8.1.0 and gpu's compute capability is at least 8.0."
)
class TestElementwiseBF16Op(OpTest):
def setUp(self):
......
......@@ -417,7 +417,9 @@ class TestBF16ScaleBiasLayerNorm(unittest.TestCase):
return y_np, x_g_np, w_g_np, b_g_np
def test_main(self):
if (not core.is_compiled_with_cuda()) or (core.cudnn_version() < 8100):
if (not core.is_compiled_with_cuda()) or (
core.cudnn_version() <
8100) or (paddle.device.cuda.get_device_capability()[0] < 8):
return
x_np = np.random.random([10, 20]).astype('float32')
weight_np = np.random.random([20]).astype('float32')
......
......@@ -360,8 +360,10 @@ class TestSoftmaxBF16Op(OpTest):
@unittest.skipIf(
not core.is_compiled_with_cuda() or core.cudnn_version() < 8100,
"core is not compiled with CUDA and cudnn version need larger than 8.1.0")
not core.is_compiled_with_cuda() or core.cudnn_version() < 8100
or paddle.device.cuda.get_device_capability()[0] < 8,
"only support compiled with CUDA and cudnn version need larger than 8.1.0 and device's compute capability is at least 8.0"
)
class TestSoftmaxBF16CUDNNOp(TestSoftmaxBF16Op):
def init_cudnn(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册