diff --git a/python/paddle/fluid/dygraph/amp/auto_cast.py b/python/paddle/fluid/dygraph/amp/auto_cast.py index 01d64550321d5e96d3dddeebb2509e3d96f3237b..37134764e9d1c8a7ffbd0a2f53668d16f90a1683 100644 --- a/python/paddle/fluid/dygraph/amp/auto_cast.py +++ b/python/paddle/fluid/dygraph/amp/auto_cast.py @@ -149,7 +149,7 @@ def _is_gpu_bfloat16_supported(): """ prop = paddle.device.cuda.get_device_capability() cuda_version = paddle.version.cuda() - if cuda_version is not None: + if cuda_version is not None and cuda_version != 'False': cuda_version_check = int(cuda_version.split('.')[0]) >= 11 else: cuda_version_check = False