未验证 提交 71d3b06c 编写于 作者: W wangxinxin08 提交者: GitHub

fix unittest of conv2d due to V100 do not support bfloat16 (#42496)

上级 e052fde7
...@@ -172,9 +172,9 @@ def create_test_cudnn_fp16_class(parent, grad_check=True): ...@@ -172,9 +172,9 @@ def create_test_cudnn_fp16_class(parent, grad_check=True):
def create_test_cudnn_bf16_class(parent): def create_test_cudnn_bf16_class(parent):
@unittest.skipIf( @unittest.skipIf(
not core.is_compiled_with_cuda() or core.cudnn_version() < 8100, not core.is_compiled_with_cuda() or
"core is not compiled with CUDA and cudnn version need larger than 8.1.0" not core.is_bfloat16_supported(core.CUDAPlace(0)),
) "core is not compiled with CUDA and do not support bfloat16")
class TestConv2DCUDNNBF16(parent): class TestConv2DCUDNNBF16(parent):
def get_numeric_grad(self, place, check_name): def get_numeric_grad(self, place, check_name):
scope = core.Scope() scope = core.Scope()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册