diff --git a/python/paddle/amp/auto_cast.py b/python/paddle/amp/auto_cast.py index a8873b137f3d08ea19cfdda1f65f715b71bfcff7..19007bccc48cab3cc341f51d8bb64dd43f3294f4 100644 --- a/python/paddle/amp/auto_cast.py +++ b/python/paddle/amp/auto_cast.py @@ -370,6 +370,7 @@ def amp_guard( "For float16, amp only support NVIDIA GPU with Compute Capability 7.0 or higher, current GPU is: %s, with Compute Capability: %d.%d." % (paddle.device.cuda.get_device_name(), prop[0], prop[1]) ) + enable = False elif (dtype == 'bfloat16') and not _is_gpu_bfloat16_supported(): prop = paddle.device.cuda.get_device_capability() cuda_version = paddle.version.cuda() @@ -382,6 +383,7 @@ def amp_guard( cuda_version, ) ) + enable = False amp_dtype = dtype amp_global_state().amp_dtype = amp_dtype @@ -572,6 +574,46 @@ def amp_decorate( else: return models, optimizers + # check tracer + tracer = _dygraph_tracer() + if not tracer: + raise ValueError( + "current_tracer is None, maybe it is not in imperative mode." + ) + + # check device_type: + if not ( + tracer._expected_place.is_gpu_place() + or tracer._expected_place.is_xpu_place() + or tracer._expected_place.is_custom_place() + ): + if optimizers is None: + return models + else: + return models, optimizers + # For xpu: + if tracer._expected_place.is_xpu_place() and (dtype == 'bfloat16'): + if optimizers is None: + return models + else: + return models, optimizers + # For custom device: + if tracer._expected_place.is_custom_place() and (dtype == 'bfloat16'): + if optimizers is None: + return models + else: + return models, optimizers + # For gpu float16: Compute Capability should >= 7. + # For gpu bfloat16: Compute Capability should >= 8 & CUDA Version should >= 11. + if tracer._expected_place.is_gpu_place(): + if (dtype == 'float16' and not _is_gpu_float16_supported()) or ( + dtype == 'bfloat16' and not _is_gpu_bfloat16_supported() + ): + if optimizers is None: + return models + else: + return models, optimizers + models_is_list = False if isinstance(models, paddle.nn.Layer): models_is_list = False