未验证 提交 acfdb8b3 编写于 作者: X xiongkun 提交者: GitHub

fix test-einsum-v2 unittest in cuda 11.7 (#44772)

上级 65a3530c
......@@ -17,6 +17,7 @@ import contextlib
import unittest
import paddle
from paddle.fluid import core
from paddle.fluid.dygraph.amp.auto_cast import _is_gpu_bfloat16_supported
import os
......@@ -529,16 +530,15 @@ class TestBF16(unittest.TestCase):
"""
def test_shape(self):
cuda_major = paddle.version.cuda().split('.')[0].strip()
if paddle.is_compiled_with_cuda() and int(cuda_major) >= 11:
""" MatmulKernel support bfloat16 only if cuda_major > 11.0.
if paddle.is_compiled_with_cuda() and _is_gpu_bfloat16_supported():
""" MatmulKernel support bfloat16 only if cuda_major >= 11.0 and Compute Capability >= 8.0
"""
A = paddle.to_tensor(np.array([1.0, 2.0])).astype(paddle.bfloat16)
A = A.cuda()
B = paddle.to_tensor(np.array([2.0, 3.0])).astype(paddle.bfloat16)
B = B.cuda()
C = paddle.einsum('i,i->', A, B)
self.assertEqual(C.item(), 8.0)
self.assertEqual(C.astype(paddle.float32).item(), 8.0)
class TestComplex(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册