From 77e289ae49387c5629cac7df5b6bc3d32ef10fbe Mon Sep 17 00:00:00 2001
From: Zhang Ting
Date: Thu, 6 Jul 2023 10:48:23 +0800
Subject: [PATCH] [AMP] modify default value for GradScaler (#54653)

---
 python/paddle/amp/amp_lists.py         | 14 +++++++++-----
 python/paddle/amp/grad_scaler.py       | 12 ++++++------
 python/paddle/static/amp/decorator.py  | 16 ++++++++--------
 python/paddle/static/amp/fp16_lists.py |  6 +++---
 test/amp/amp_base_models.py            |  2 +-
 5 files changed, 27 insertions(+), 23 deletions(-)

diff --git a/python/paddle/amp/amp_lists.py b/python/paddle/amp/amp_lists.py
index 43ff2a81be5..645370a28fa 100644
--- a/python/paddle/amp/amp_lists.py
+++ b/python/paddle/amp/amp_lists.py
@@ -75,8 +75,8 @@ FP16_BLACK_LIST = {
     'margin_cross_entropy',
 }
 
-# FP16 performance of grad op is worse than that of FP32. Use FP32 by default.
-FP16_EXTRA_BLACK_LIST = {
+# FP16/BF16 performance of grad op is worse than that of FP32. Use FP32 by default.
+EXTRA_BLACK_LIST = {
     'linear_interp_v2',
     'nearest_interp_v2',
     'bilinear_interp_v2',
@@ -112,9 +112,13 @@ def black_list():
     black_list = {
         "float16": {
             "OD": set(),
-            "O1": FP16_BLACK_LIST | FP16_EXTRA_BLACK_LIST,
-            "O2": FP16_EXTRA_BLACK_LIST,
+            "O1": FP16_BLACK_LIST | EXTRA_BLACK_LIST,
+            "O2": EXTRA_BLACK_LIST,
+        },
+        "bfloat16": {
+            "OD": set(),
+            "O1": BF16_BLACK_LIST | EXTRA_BLACK_LIST,
+            "O2": EXTRA_BLACK_LIST,
         },
-        "bfloat16": {"OD": set(), "O1": BF16_BLACK_LIST, "O2": set()},
     }
     return black_list
diff --git a/python/paddle/amp/grad_scaler.py b/python/paddle/amp/grad_scaler.py
index c22a324201c..b25c1ff7cf5 100644
--- a/python/paddle/amp/grad_scaler.py
+++ b/python/paddle/amp/grad_scaler.py
@@ -591,15 +591,15 @@ class GradScaler(AmpScaler):
 
     Args:
         enable(bool, optional): Enable loss scaling or not. Default is True.
-        init_loss_scaling (float, optional): The initial loss scaling factor. Default is 2**15.
+        init_loss_scaling (float, optional): The initial loss scaling factor. Default is 65536.0.
         incr_ratio(float, optional): The multiplier to use when increasing the loss scaling. Default is 2.0.
         decr_ratio(float, optional): The less-than-one-multiplier to use when decreasing the loss scaling. Default is 0.5.
         incr_every_n_steps(int, optional): Increases loss scaling every n consecutive
-                                steps with finite gradients. Default is 1000.
+                                steps with finite gradients. Default is 2000.
         decr_every_n_nan_or_inf(int, optional): Decreases loss scaling every n
-                                accumulated steps with nan or inf gradients. Default is 2.
+                                accumulated steps with nan or inf gradients. Default is 1.
         use_dynamic_loss_scaling(bool, optional): Whether to use dynamic loss scaling. If False, fixed loss_scaling is used. If True, the loss scaling is updated dynamicly. Default is True.
     Returns:
         An GradScaler object.
@@ -628,11 +628,11 @@ class GradScaler(AmpScaler):
     def __init__(
         self,
         enable=True,
-        init_loss_scaling=2.0**15,
+        init_loss_scaling=2.0**16,
         incr_ratio=2.0,
         decr_ratio=0.5,
-        incr_every_n_steps=1000,
-        decr_every_n_nan_or_inf=2,
+        incr_every_n_steps=2000,
+        decr_every_n_nan_or_inf=1,
         use_dynamic_loss_scaling=True,
     ):
         super().__init__(
diff --git a/python/paddle/static/amp/decorator.py b/python/paddle/static/amp/decorator.py
index a9d105b4db7..75e7f28955e 100644
--- a/python/paddle/static/amp/decorator.py
+++ b/python/paddle/static/amp/decorator.py
@@ -811,11 +811,11 @@ def decorate(  # noqa: F811
     dtype='float16',
     master_weight=None,
     master_grad=False,
-    init_loss_scaling=2**15,
-    incr_every_n_steps=1000,
-    decr_every_n_nan_or_inf=2,
+    init_loss_scaling=2**16,
+    incr_every_n_steps=2000,
+    decr_every_n_nan_or_inf=1,
     incr_ratio=2.0,
-    decr_ratio=0.8,
+    decr_ratio=0.5,
     use_dynamic_loss_scaling=None,
     use_amp_guard=False,
     use_promote=False,
@@ -841,15 +841,15 @@
             during weight updating. If master_grad is False, in O2 level optimizer
             will not use master grad. Default is False.
         init_loss_scaling(float, optional): The initial loss scaling factor.
-            Default is 32768.
+            Default is 65536.
         incr_every_n_steps(int, optional): Increases loss scaling every n
-            consecutive steps with finite gradients. Default is 1000.
+            consecutive steps with finite gradients. Default is 2000.
         decr_every_n_nan_or_inf(int, optional): Decreases loss scaling every n
-            accumulated steps with nan or inf gradients. Default is 2.
+            accumulated steps with nan or inf gradients. Default is 1.
         incr_ratio(float, optional): The multiplier to use when increasing the loss
             scaling. Default is 2.
         decr_ratio(float, optional): The less-than-one-multiplier to use when
-            decreasing the loss scaling. Default is 0.8.
+            decreasing the loss scaling. Default is 0.5.
         use_dynamic_loss_scaling(bool, None): Whether to use dynamic loss scaling.
             Default is None, which means True for float16, and False for bfloat16.
diff --git a/python/paddle/static/amp/fp16_lists.py b/python/paddle/static/amp/fp16_lists.py
index b057f1adf21..c3d8f20b04d 100644
--- a/python/paddle/static/amp/fp16_lists.py
+++ b/python/paddle/static/amp/fp16_lists.py
@@ -16,8 +16,8 @@ import copy
 import logging
 
 from paddle.amp.amp_lists import (
+    EXTRA_BLACK_LIST,
     FP16_BLACK_LIST,
-    FP16_EXTRA_BLACK_LIST,
     FP16_WHITE_LIST,
 )
 from paddle.fluid import core
@@ -28,7 +28,7 @@ _logger = get_logger(
 )
 
 black_list = FP16_BLACK_LIST
-_extra_black_list = FP16_EXTRA_BLACK_LIST
+_extra_black_list = EXTRA_BLACK_LIST
 white_list = FP16_WHITE_LIST
 
 
@@ -138,7 +138,7 @@ def _get_white_list(dtype):
 
 def _get_black_list():
     _black_list = copy.copy(FP16_BLACK_LIST)
-    _black_list = _black_list | FP16_EXTRA_BLACK_LIST
+    _black_list = _black_list | EXTRA_BLACK_LIST
     return _black_list
 
 
diff --git a/test/amp/amp_base_models.py b/test/amp/amp_base_models.py
index 6d08b0d1483..0c8c11ea5e5 100644
--- a/test/amp/amp_base_models.py
+++ b/test/amp/amp_base_models.py
@@ -182,7 +182,7 @@ def build_conv_model(
     model = SimpleConvNet()
     optimizer = _build_optimizer(use_amp=False, model=model)
     if use_amp and amp_dtype == "float16":
-        scaler = paddle.amp.GradScaler()
+        scaler = paddle.amp.GradScaler(init_loss_scaling=32768.0)
     else:
         scaler = None
     if use_amp and amp_level == "O2":
--
GitLab
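For reference, the sketch below is not part of the patch; it spells out the new GradScaler defaults that this change introduces, so constructing the scaler with these explicit arguments is equivalent to calling paddle.amp.GradScaler() with no arguments after the patch. The toy model, optimizer, and input data are hypothetical placeholders, and the loop assumes dynamic-graph float16 AMP on a device that supports it.

# Minimal usage sketch (not part of the patch). The explicit arguments mirror the
# new defaults from this change, so paddle.amp.GradScaler() behaves the same way
# after the patch. The network, optimizer, and input data are placeholders.
import paddle

model = paddle.nn.Linear(10, 10)
optimizer = paddle.optimizer.SGD(
    learning_rate=0.01, parameters=model.parameters()
)

scaler = paddle.amp.GradScaler(
    enable=True,
    init_loss_scaling=2.0**16,      # previously 2.0**15 (32768.0)
    incr_ratio=2.0,
    decr_ratio=0.5,
    incr_every_n_steps=2000,        # previously 1000
    decr_every_n_nan_or_inf=1,      # previously 2
    use_dynamic_loss_scaling=True,
)

data = paddle.rand([4, 10])
with paddle.amp.auto_cast(dtype='float16'):
    loss = paddle.mean(model(data))

scaled = scaler.scale(loss)   # multiply the loss by the current loss scaling
scaled.backward()             # produce scaled gradients
scaler.step(optimizer)        # unscale gradients and run the optimizer step
scaler.update()               # adjust the loss scaling for the next iteration
optimizer.clear_grad()

Under the new defaults the scaler starts higher (65536), backs off after a single step with nan/inf gradients, and waits 2000 consecutive finite-gradient steps before growing again, matching the values spelled out in the docstrings above.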