diff --git a/python/paddle/amp/auto_cast.py b/python/paddle/amp/auto_cast.py index 6d0bc89296a28c8776764aa375c00d9144b7dc17..b01522b3ef8e069e1ba17b015402b48b1abdee5c 100644 --- a/python/paddle/amp/auto_cast.py +++ b/python/paddle/amp/auto_cast.py @@ -700,7 +700,7 @@ def auto_cast( with paddle.amp.auto_cast(): conv = conv2d(data) - print(conv.dtype) # paddle.float32 + print(conv.dtype) # paddle.float16 with paddle.amp.auto_cast(enable=False): conv = conv2d(data) @@ -714,11 +714,11 @@ def auto_cast( b = paddle.rand([2,3]) with paddle.amp.auto_cast(custom_white_list={'elementwise_add'}): c = a + b - print(c.dtype) # paddle.float32 + print(c.dtype) # paddle.float16 with paddle.amp.auto_cast(custom_white_list={'elementwise_add'}, level='O2'): d = a + b - print(d.dtype) # paddle.float32 + print(d.dtype) # paddle.float16 """ return amp_guard(enable, custom_white_list, custom_black_list, level, dtype)