未验证 提交 e052fde7 编写于 作者: W wawltor 提交者: GitHub

fix the v100 cuda11.2 matmul_v2 and elementwise_div bug (#42479)

上级 a3917625
...@@ -2972,6 +2972,10 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -2972,6 +2972,10 @@ All parameter, weight, gradient are variables in Paddle.
// Only GPUs with Compute Capability >= 53 support float16 // Only GPUs with Compute Capability >= 53 support float16
return platform::GetGPUComputeCapability(place.device) >= 53; return platform::GetGPUComputeCapability(place.device) >= 53;
}); });
m.def("is_bfloat16_supported", [](const platform::CUDAPlace &place) -> bool {
// Only GPUs with Compute Capability >= 80 support bfloat16
return platform::GetGPUComputeCapability(place.device) >= 80;
});
#endif #endif
m.def("set_feed_variable", m.def("set_feed_variable",
......
...@@ -60,9 +60,9 @@ class ElementwiseDivOp(OpTest): ...@@ -60,9 +60,9 @@ class ElementwiseDivOp(OpTest):
pass pass
@unittest.skipIf( @unittest.skipIf(not core.is_compiled_with_cuda() or
not core.is_compiled_with_cuda() or core.cudnn_version() < 8100, not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA and cudnn version need larger than 8.1.0") "core is not compiled with CUDA and not support the bfloat16")
class TestElementwiseDivOpBF16(OpTest): class TestElementwiseDivOpBF16(OpTest):
def setUp(self): def setUp(self):
self.op_type = "elementwise_div" self.op_type = "elementwise_div"
......
...@@ -385,9 +385,9 @@ create_test_fp16_class(TestMatMulOp17) ...@@ -385,9 +385,9 @@ create_test_fp16_class(TestMatMulOp17)
def create_test_bf16_class(parent, atol=0.01): def create_test_bf16_class(parent, atol=0.01):
@unittest.skipIf( @unittest.skipIf(
not core.is_compiled_with_cuda() or core.cudnn_version() < 8100, not core.is_compiled_with_cuda() or
"core is not compiled with CUDA and cudnn version need larger than 8.1.0" not core.is_bfloat16_supported(core.CUDAPlace(0)),
) "core is not compiled with CUDA and not support the bfloat16")
class TestMatMulOpBf16Case(parent): class TestMatMulOpBf16Case(parent):
def get_numeric_grad(self, place, check_name): def get_numeric_grad(self, place, check_name):
scope = core.Scope() scope = core.Scope()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册