diff --git a/python/paddle/fluid/dygraph/amp/auto_cast.py b/python/paddle/fluid/dygraph/amp/auto_cast.py index 7b70daa1507d87e51d9911687e8d7814480280c8..642ac5e26a25e7f8f5baee8055f926a0dd369ccb 100644 --- a/python/paddle/fluid/dygraph/amp/auto_cast.py +++ b/python/paddle/fluid/dygraph/amp/auto_cast.py @@ -218,6 +218,13 @@ def amp_guard(enable=True, % tracer._expected_place) enable = False + if tracer._expected_place.is_gpu_place(): + prop = paddle.device.cuda.get_device_capability() + if prop[0] < 7: + warnings.warn( + "AMP only support NVIDIA GPU with Compute Capability 7.0 or higher, current GPU is: %s, with Compute Capability: %d.%d." + % (paddle.device.cuda.get_device_name(), prop[0], prop[1])) + if level == 'O1': amp_level = AMP_LEVEL.O1 _white_list = WHITE_LIST