From da3e9d66ea08b07b2d3dab11f97af66ab337c478 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Thu, 5 Jan 2023 09:45:20 +0800 Subject: [PATCH] move fuild.dygraph.amp to paddle.amp (#49193) --- python/paddle/amp/__init__.py | 13 +- python/paddle/amp/auto_cast.py | 633 ++++++++++++++++- python/paddle/amp/grad_scaler.py | 559 ++++++++++++++- python/paddle/fluid/dygraph/__init__.py | 4 - python/paddle/fluid/dygraph/amp/__init__.py | 23 - python/paddle/fluid/dygraph/amp/auto_cast.py | 666 ------------------ .../paddle/fluid/dygraph/amp/loss_scaler.py | 589 ---------------- ...perative_auto_mixed_precision_for_eager.py | 40 +- .../unittests/test_low_precision_list.py | 2 +- .../paddle/jit/dy2static/partial_program.py | 5 +- python/setup.py.in | 1 - setup.py | 1 - 12 files changed, 1221 insertions(+), 1315 deletions(-) delete mode 100644 python/paddle/fluid/dygraph/amp/__init__.py delete mode 100644 python/paddle/fluid/dygraph/amp/auto_cast.py delete mode 100644 python/paddle/fluid/dygraph/amp/loss_scaler.py diff --git a/python/paddle/amp/__init__.py b/python/paddle/amp/__init__.py index 381aad8850b..349c6c4daa1 100644 --- a/python/paddle/amp/__init__.py +++ b/python/paddle/amp/__init__.py @@ -13,7 +13,18 @@ # limitations under the License. from .auto_cast import auto_cast # noqa: F401 -from .grad_scaler import GradScaler # noqa: F401 from .auto_cast import decorate # noqa: F401 +from .auto_cast import amp_guard # noqa: F401 +from .auto_cast import amp_decorate # noqa: F401 +from .auto_cast import low_precision_op_list # noqa: F401 +from .auto_cast import WHITE_LIST # noqa: F401 +from .auto_cast import BLACK_LIST # noqa: F401 +from .auto_cast import PURE_FP16_WHITE_LIST # noqa: F401 +from .auto_cast import PURE_FP16_BLACK_LIST # noqa: F401 + +from . import grad_scaler # noqa: F401 +from .grad_scaler import GradScaler # noqa: F401 +from .grad_scaler import AmpScaler # noqa: F401 +from .grad_scaler import OptimizerState # noqa: F401 __all__ = ['auto_cast', 'GradScaler', 'decorate'] diff --git a/python/paddle/amp/auto_cast.py b/python/paddle/amp/auto_cast.py index b26a585d5b4..6eb63040e08 100644 --- a/python/paddle/amp/auto_cast.py +++ b/python/paddle/amp/auto_cast.py @@ -12,9 +12,638 @@ # See the License for the specific language governing permissions and # limitations under the License. -from paddle.fluid.dygraph.amp import amp_decorate, amp_guard +import copy +import warnings + +import paddle +from paddle.fluid import core +from paddle.fluid.framework import _dygraph_tracer, dygraph_only +from paddle.fluid.wrapped_decorator import signature_safe_contextmanager + +AMP_LEVEL = core.AmpLevel + +# The set of ops that support fp16 calculation and are considered numerically- +# safe and performance-critical. These ops are always converted to fp16. +WHITE_LIST = { + 'conv2d', + 'matmul', + 'matmul_v2', + 'mul', + 'fake_quantize_dequantize_abs_max', + 'fake_quantize_dequantize_moving_average_abs_max', +} + +# The set of ops that support fp16 calculation and are considered numerically- +# dangerous and whose effects may also be observed in downstream ops. +BLACK_LIST = { + 'exp', + 'square', + 'log', + 'mean', + 'sum', + 'cos_sim', + 'softmax', + 'softmax_with_cross_entropy', + 'sigmoid_cross_entropy_with_logits', + 'c_softmax_with_cross_entropy', + 'cross_entropy', + 'cross_entropy2', + # default fp32 can avoid return inf when the sum value large than 65504 + 'reduce_sum', + # FP16 performance of grad op is worse than that of FP32. Use FP32 by default. + 'linear_interp_v2', + 'nearest_interp_v2', + 'bilinear_interp_v2', + 'bicubic_interp_v2', + 'trilinear_interp_v2', +} + +AMP_RELATED_FLAGS = [ + 'FLAGS_cudnn_exhaustive_search', + 'FLAGS_conv_workspace_size_limit', + 'FLAGS_cudnn_batchnorm_spatial_persistent', +] + +AMP_RELATED_FLAGS_SETTING = { + 'FLAGS_cudnn_exhaustive_search': 1, + 'FLAGS_conv_workspace_size_limit': 1000, + 'FLAGS_cudnn_batchnorm_spatial_persistent': 1, +} + +PURE_FP16_WHITE_LIST = set() +PURE_FP16_BLACK_LIST = { + 'lookup_table', + 'lookup_table_v2', + 'scatter', + 'scatter_grad', + # FP16 performance of grad op is worse than that of FP32. Use FP32 by default. + 'linear_interp_v2', + 'nearest_interp_v2', + 'bilinear_interp_v2', + 'bicubic_interp_v2', + 'trilinear_interp_v2', +} + +BF16_WHITE_LIST = {'conv2d', 'matmul_v2'} +BF16_BLACK_LIST = set() + +PURE_BF16_WHITE_LIST = set() +PURE_BF16_BLACK_LIST = set() + +_g_amp_state_ = None + + +def low_precision_op_list(): + op_list = paddle.fluid.core.get_low_precision_op_list() + op_count = 0 + print('<---------------- low precision op list ------------------->') + print('<---- op name ------|------- op count---------------------->') + for x in op_list: + print(' %-18s| %4d' % (x, op_list[x])) + op_count += 1 + print( + '<------------- low precision op num:{:5d} ----------------->'.format( + op_count + ) + ) + + +def amp_state(): + global _g_amp_state_ + return _g_amp_state_ + + +# NOTE(zhiqiu): similar as paddle.fluid.contrib.mixed_precision.fp16_lists.AutoMixedPrecisionLists._update_list +# The reason why not use AutoMixedPrecisionLists is that custom_black_varnames is not suitable for imperative mode. +def _update_list( + custom_white_list, custom_black_list, level='O1', dtype='float16' +): + """ + Update black and white list according to users' custom list. + """ + if dtype == 'float16': + if level == 'O1': + _white_list = copy.copy(WHITE_LIST) + _black_list = copy.copy(BLACK_LIST) + else: + _white_list = copy.copy(PURE_FP16_WHITE_LIST) + _black_list = copy.copy(PURE_FP16_BLACK_LIST) + else: + if level == 'O1': + _white_list = copy.copy(BF16_WHITE_LIST) + _black_list = copy.copy(BF16_BLACK_LIST) + else: + _white_list = copy.copy(PURE_BF16_WHITE_LIST) + _black_list = copy.copy(PURE_BF16_BLACK_LIST) + if custom_white_list and custom_black_list: + for op_name in custom_white_list: + if op_name in custom_black_list: + raise ValueError( + "Custom white list overlap " "custom black list" + ) + if custom_white_list: + for op_name in custom_white_list: + if op_name in _black_list: + _black_list.remove(op_name) + _white_list.add(op_name) + if custom_black_list: + for op_name in custom_black_list: + if op_name in _white_list: + _white_list.remove(op_name) + _black_list.add(op_name) + return _white_list, _black_list + + +def _in_amp_guard(): + """ + Judge whether current code block is in `amp_guard` context. + """ + tracer = _dygraph_tracer() + if tracer: + if tracer._amp_level == core.AmpLevel.O1: + return True + else: + return False + else: + return False + + +def _in_pure_fp16_guard(): + tracer = _dygraph_tracer() + return tracer and tracer._amp_level == core.AmpLevel.O2 + + +def _is_gpu_float16_supported(): + """ + Judge whether current gpu support float16 amp. + """ + prop = paddle.device.cuda.get_device_capability() + return prop[0] >= 7 + + +def _is_gpu_bfloat16_supported(): + """ + Judge whether current gpu support bfloat16 amp. + """ + prop = paddle.device.cuda.get_device_capability() + cuda_version = paddle.version.cuda() + if cuda_version is not None and cuda_version != 'False': + cuda_version_check = int(cuda_version.split('.')[0]) >= 11 + else: + cuda_version_check = False + return prop[0] >= 8 and cuda_version_check + + +@dygraph_only +def pure_fp16_initialize(models): + for idx in range(len(models)): + for layer in models[idx].sublayers(include_self=True): + layer._casted_by_pure_fp16 = True + if (layer._dtype == 'float16') or isinstance( + layer, + ( + paddle.nn.BatchNorm, + paddle.nn.BatchNorm1D, + paddle.nn.BatchNorm2D, + paddle.nn.BatchNorm3D, + paddle.nn.LayerNorm, + paddle.nn.SyncBatchNorm, + ), + ): + continue + if isinstance( + layer, + ( + paddle.incubate.nn.FusedFeedForward, + paddle.incubate.nn.FusedMultiHeadAttention, + ), + ): + layer._amp_decorate(dtype='float16') + continue + layer._to_impl( + dtype='float16', include_sublayers=False, floating_only=True + ) + return models + + +@dygraph_only +def pure_bf16_initialize(models): + for idx in range(len(models)): + for layer in models[idx].sublayers(include_self=True): + layer._to_impl( + dtype='bfloat16', include_sublayers=False, floating_only=True + ) + return models + + +def check_models(models): + for model in models: + if not isinstance(model, paddle.nn.Layer): + raise RuntimeError( + "Current train mode is pure fp16, models should be paddle.nn.Layer, but receive {}.".format( + type(model) + ) + ) + if isinstance(model, paddle.DataParallel): + raise RuntimeError( + "For distributed AMP training, you should first use paddle.amp.decorate() to decotate origin model, and then call paddle.DataParallel get distributed model." + ) + + +def _is_valid_optimizer(optimizer): + from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import ( + DygraphShardingOptimizer, + ) + + return isinstance( + optimizer, + ( + paddle.optimizer.Optimizer, + paddle.fluid.optimizer.Optimizer, + DygraphShardingOptimizer, + ), + ) + + +def check_optimizers(optimizers): + for optimizer in optimizers: + if not _is_valid_optimizer(optimizer): + raise RuntimeError( + "Current train mode is pure fp16, optimizers should be paddle.optimizer.Optimizer or paddle.fluid.optimizer.Optimizer or DygraphShardingOptimizer, but receive {}.".format( + type(optimizer) + ) + ) + + +@signature_safe_contextmanager +@dygraph_only +def amp_guard( + enable=True, + custom_white_list=None, + custom_black_list=None, + level='O1', + dtype='float16', +): + """ + Create a context which enables auto-mixed-precision(AMP) of operators executed in dynamic graph mode. + If enabled, the input data type (float32 or float16) of each operator is decided + by autocast algorithm for better performance. + + Commonly, it is used together with `GradScaler` to achieve Auto-Mixed-Precision in + imperative mode. It is used together with `decorator` to achieve Pure fp16 in imperative mode. + + Args: + enable(bool, optional): Enable auto-mixed-precision or not. Default is True. + custom_white_list(set|list|tuple, optional): The custom white_list. It's the set of ops that support + fp16 calculation and are considered numerically-safe and performance-critical. These ops + will be converted to fp16. + custom_black_list(set|list|tuple, optional): The custom black_list. The set of ops that support fp16 + calculation and are considered numerically-dangerous and whose effects may also be + observed in downstream ops. These ops will not be converted to fp16. + level(str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represent mixed precision, the input data type of each operator will be casted by white_list and black_list; + O2 represent Pure fp16, all operators parameters and input data will be casted to fp16, except operators in black_list, don't support fp16 kernel and batchnorm. Default is O1(amp) + dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'. + + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32') + conv2d = paddle.nn.Conv2D(3, 2, 3) + data = paddle.to_tensor(data) + with paddle.amp.amp_guard(): + conv = conv2d(data) + print(conv.dtype) # FP16 + with paddle.amp.amp_guard(enable=False): + conv = conv2d(data) + print(conv.dtype) # FP32 + + """ + amp_state = locals() + global _g_amp_state_ + original_state = _g_amp_state_ + _g_amp_state_ = amp_state + + # check amp_level: O0-O2 + level = level.upper() + if not (level in ['O0', 'O1', 'O2']): + raise ValueError( + "level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16/bf16 train mode." + ) + + # check amp_dtype: float16 or bfloat16 + dtype = dtype.lower() + if not (dtype in ['float16', 'bfloat16']): + raise ValueError("dtype should be 'float16' or 'bfloat16'.") + + # check tracer + tracer = _dygraph_tracer() + if not tracer: + raise ValueError( + "current_tracer is None, maybe it is not in imperative mode." + ) + + # check device_type: + # NOTE: Now, amp only support gpu for float16 and bfloat16, xpu for float16, mlu for float16, npu for float16. + # Maybe we will support cpu for bfloat16. + if enable and not ( + tracer._expected_place.is_gpu_place() + or tracer._expected_place.is_xpu_place() + or tracer._expected_place.is_mlu_place() + or tracer._expected_place.is_npu_place() + or tracer._expected_place.is_custom_place() + ): + warnings.warn( + 'amp_guard can only be enabled on CUDAPlace, XPUPlace, MLUPlace, NPUPlace, and CustomPlace, current place is %s, so it makes no effect.' + % tracer._expected_place + ) + enable = False + # For npu: + if tracer._expected_place.is_npu_place() and (dtype == 'bfloat16'): + warnings.warn('NPUPlace only support float16 amp.') + enable = False + # For xpu: + if tracer._expected_place.is_xpu_place() and (dtype == 'bfloat16'): + warnings.warn('XPUPlace only support float16 amp.') + enable = False + # For mlu: + if tracer._expected_place.is_mlu_place() and (dtype == 'bfloat16'): + warnings.warn('MLUPlace only support float16 amp.') + enable = False + # For custom device: + if tracer._expected_place.is_custom_place() and (dtype == 'bfloat16'): + warnings.warn('CustomPlace only support float16 amp.') + enable = False + # For gpu float16: Compute Capability should >= 7. + # For gpu bfloat16: Compute Capability should >= 8 & CUDA Version should >= 11. + if tracer._expected_place.is_gpu_place(): + if (dtype == 'float16') and not _is_gpu_float16_supported(): + prop = paddle.device.cuda.get_device_capability() + warnings.warn( + "For float16, amp only support NVIDIA GPU with Compute Capability 7.0 or higher, current GPU is: %s, with Compute Capability: %d.%d." + % (paddle.device.cuda.get_device_name(), prop[0], prop[1]) + ) + elif (dtype == 'bfloat16') and not _is_gpu_bfloat16_supported(): + prop = paddle.device.cuda.get_device_capability() + cuda_version = paddle.version.cuda() + warnings.warn( + "For bfloat16, amp only support NVIDIA GPU with Compute Capability 8.0 or higher and CUDA Version 11.0 or higher, current GPU is: %s, with Compute Capability: %d.%d, current CUDA Version is: %s." + % ( + paddle.device.cuda.get_device_name(), + prop[0], + prop[1], + cuda_version, + ) + ) + + amp_dtype = dtype + + if level == 'O1': + amp_level = AMP_LEVEL.O1 + if dtype == 'float16': + _white_list = WHITE_LIST + _black_list = BLACK_LIST + elif dtype == 'bfloat16': + _white_list = BF16_WHITE_LIST + _black_list = BF16_BLACK_LIST + + elif level == 'O2': + amp_level = AMP_LEVEL.O2 + if dtype == 'float16': + _white_list = PURE_FP16_WHITE_LIST + _black_list = PURE_FP16_BLACK_LIST + elif dtype == 'bfloat16': + _white_list = BF16_WHITE_LIST + _black_list = BF16_BLACK_LIST + elif level == 'O0': + amp_level = AMP_LEVEL.O0 + if dtype == 'float16': + _white_list = WHITE_LIST + _black_list = BLACK_LIST + elif dtype == 'bfloat16': + _white_list = BF16_WHITE_LIST + _black_list = BF16_BLACK_LIST + + if custom_white_list or custom_black_list: + _white_list, _black_list = _update_list( + custom_white_list, custom_black_list, level, dtype + ) + + if not enable: + amp_level = AMP_LEVEL.O0 + amp_dtype = "float32" + + if tracer: + # enable auto_cast + original_amp_level = tracer._amp_level + tracer._amp_level = amp_level + + # set amp op list + original_white_list, original_black_list = tracer._get_amp_op_list() + tracer._set_amp_op_list(_white_list, _black_list) + + # TODO(zhiqiu) set amp related flags automatically in this guard + # Currently, if FLAGS_cudnn_batchnorm_spatial_persistent is set True in amp_guard, + # batch_norm can run in fast mode, but batch_norm_grad can not if backward if not executed insise amp_guard. + # So, users need to set related flags manually. + + # original_flags = get_flags(AMP_RELATED_FLAGS) + # set_flags(AMP_RELATED_FLAGS_SETTING) + + # set amp dtype + original_amp_dtype = tracer._amp_dtype + tracer._amp_dtype = amp_dtype + + # restore status + try: + yield + finally: + if tracer: + _g_amp_state_ = original_state + tracer._amp_level = original_amp_level + tracer._set_amp_op_list(original_white_list, original_black_list) + # set_flags(original_flags) + tracer._amp_dtype = original_amp_dtype + + +class StateDictHook: + def __init__(self, save_dtype): + self._save_dtype = save_dtype + + def __call__(self, state_dict): + for key in state_dict: + param = state_dict[key] + if paddle.is_floating_point(param): + param_applied = paddle.cast(param, self._save_dtype) + param_applied.name = param.name + state_dict[key] = param_applied + + +def _set_multi_precision(optimizer, multi_precision): + from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import ( + DygraphShardingOptimizer, + ) + + optimizer = ( + optimizer._inner_optimizer + if isinstance(optimizer, DygraphShardingOptimizer) + else optimizer + ) + if hasattr(optimizer, "_multi_precision"): + optimizer._multi_precision = multi_precision + + +@dygraph_only +def amp_decorate( + models, + optimizers=None, + level='O1', + dtype='float16', + master_weight=None, + save_dtype=None, +): + """ + Decorate models and optimizers for auto-mixed-precision. When level is O1(amp), the decorate will do nothing. + When level is O2(pure fp16), the decorate will cast all parameters of models to FP16, except BatchNorm and LayerNorm. + + Commonly, it is used together with `amp_guard` to achieve Pure fp16 in imperative mode. -__all__ = [] + Args: + models(Layer|list of Layer, optional): The defined models by user, models must be either a single model or a list of models. Default is None. + optimizers(Optimizer|list of Optimizer, optional): The defined optimizers by user, optimizers must be either a single optimizer or a list of optimizers. Default is None. + level(str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represent mixed precision, the decorator will do nothing; + O2 represent Pure fp16/bf16, the decorator will cast all parameters of models to FP16/BF16, except BatchNorm and LayerNorm. Default is O1(amp) + dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'. + master_weight(bool, optinal): For level='O2', whether to use multi-precision during weight updating. If master_weight is None, in O2 level optimizer will use multi-precision. Default is None. + save_dtype(float, optional): The save model parameter dtype when use `paddle.save` or `paddle.jit.save`,it should be float16, bfloat16, float32, float64 or None. + The save_dtype will not change model parameters dtype, it just change the state_dict dtype. When save_dtype is None, the save dtype is same as model dtype. Default is None. + + Examples: + + .. code-block:: python + + # required: gpu + # Demo1: single model and optimizer: + import paddle + + model = paddle.nn.Conv2D(3, 2, 3, bias_attr=False) + optimizer = paddle.optimizer.SGD(parameters=model.parameters()) + + model, optimizer = paddle.amp.amp_decorate(models=model, optimizers=optimizer, level='O2') + + data = paddle.rand([10, 3, 32, 32]) + + with paddle.amp.amp_guard(enable=True, custom_white_list=None, custom_black_list=None, level='O2'): + output = model(data) + print(output.dtype) # FP16 + + # required: gpu + # Demo2: multi models and optimizers: + model2 = paddle.nn.Conv2D(3, 2, 3, bias_attr=False) + optimizer2 = paddle.optimizer.Adam(parameters=model2.parameters()) + + models, optimizers = paddle.amp.amp_decorate(models=[model, model2], optimizers=[optimizer, optimizer2], level='O2') + + data = paddle.rand([10, 3, 32, 32]) + + with paddle.amp.amp_guard(enable=True, custom_white_list=None, custom_black_list=None, level='O2'): + output = models[0](data) + output2 = models[1](data) + print(output.dtype) # FP16 + print(output2.dtype) # FP16 + + # required: gpu + # Demo3: optimizers is None: + model3 = paddle.nn.Conv2D(3, 2, 3, bias_attr=False) + optimizer3 = paddle.optimizer.Adam(parameters=model2.parameters()) + + model = paddle.amp.amp_decorate(models=model3, level='O2') + + data = paddle.rand([10, 3, 32, 32]) + + with paddle.amp.amp_guard(enable=True, custom_white_list=None, custom_black_list=None, level='O2'): + output = model(data) + print(output.dtype) # FP16 + """ + if not (level in ['O1', 'O2']): + raise ValueError( + "level should be O1 or O2, O1 represent AMP train mode, O2 represent Pure fp16 train mode." + ) + + if level == 'O1': + if optimizers is None: + return models + else: + return models, optimizers + + models_is_list = False + if isinstance(models, paddle.nn.Layer): + models_is_list = False + models = [models] + check_models(models) + elif isinstance(models, list): + check_models(models) + models_is_list = True + else: + raise TypeError( + "models must be either a single model or a list of models." + ) + if dtype == 'float16': + models = pure_fp16_initialize(models=models) + elif dtype == 'bfloat16': + models = pure_bf16_initialize(models=models) + else: + raise TypeError("dtype only support float16 or bfloat16.") + + if optimizers is not None: + # check optimizers + optimizers_is_list = False + if _is_valid_optimizer(optimizers): + optimizers_is_list = False + optimizers = [optimizers] + check_optimizers(optimizers) + elif isinstance(optimizers, list): + check_optimizers(optimizers) + optimizers_is_list = True + else: + raise TypeError( + "optimizers must be either a single optimizer or a list of optimizers." + ) + # support master_weight + use_multi_precision = not (master_weight is False) + for opt in optimizers: + _set_multi_precision(opt, use_multi_precision) + + if save_dtype is not None: + if not (save_dtype in ['float16', 'bfloat16', 'float32', 'float64']): + raise ValueError( + "save_dtype can only be float16 float32 or float64, but your input save_dtype is %s." + % save_dtype + ) + for idx in range(len(models)): + for layer in models[idx].sublayers(include_self=True): + layer.register_state_dict_hook(StateDictHook(save_dtype)) + + if models_is_list: + if optimizers is not None: + if optimizers_is_list: + return models, optimizers + else: + return models, optimizers[0] + else: + return models + else: + if optimizers is not None: + if optimizers_is_list: + return models[0], optimizers + else: + return models[0], optimizers[0] + else: + return models[0] def auto_cast( diff --git a/python/paddle/amp/grad_scaler.py b/python/paddle/amp/grad_scaler.py index d847da5455d..85e6f6efc6b 100644 --- a/python/paddle/amp/grad_scaler.py +++ b/python/paddle/amp/grad_scaler.py @@ -12,17 +12,572 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from collections import defaultdict +from enum import Enum -from paddle.fluid.dygraph.amp import AmpScaler, OptimizerState +import numpy as np -__all__ = [] +from paddle import _legacy_C_ops +from paddle.fluid import core, in_dygraph_mode +from paddle.fluid.data_feeder import check_type +from paddle.fluid.dygraph import to_variable +from paddle.fluid.framework import _dygraph_tracer, dygraph_only + + +class OptimizerState(Enum): + INIT = 0 + UNSCALED = 1 + STEPPED = 2 def _refresh_optimizer_state(): return {"state": OptimizerState.INIT} +class AmpScaler: + """ + AmpScaler is used for Auto-Mixed-Precision training/inferring in imperative + mode. It controls the scaling of loss, helps avoiding numerical overflow. + The object of this class has seventeen methods `scale()`, `unscale_()`, `minimize()` and `get`/`set` api of parameters. + + `scale()` is used to multiply the loss by a scale ratio. + `unscale_()` is used to unscale the gradients of parameters, multiplies the gradients of parameters by 1/(scale ratio) + `minimize()` is similar as `optimizer.minimize()`, performs parameters updating, and it will update the loss_scaling. + + Commonly, it is used together with `amp_guard` to achieve Auto-Mixed-Precision in + imperative mode. + + Args: + enable(bool, optional): Enable loss scaling or not. Default is True. + init_loss_scaling (float, optional): The initial loss scaling factor. Default is 2**15. + incr_ratio(float, optional): The multiplier to use when increasing the loss + scaling. Default is 2.0. + decr_ratio(float, optional): The less-than-one-multiplier to use when decreasing + the loss scaling. Default is 0.5. + incr_every_n_steps(int, optional): Increases loss scaling every n consecutive + steps with finite gradients. Default is 1000. + decr_every_n_nan_or_inf(int, optional): Decreases loss scaling every n + accumulated steps with nan or inf gradients. Default is 2. + use_dynamic_loss_scaling(bool, optional): Whether to use dynamic loss scaling. If False, fixed loss_scaling is used. If True, the loss scaling is updated dynamicly. Default is True. + Returns: + An AmpScaler object. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32') + model = paddle.nn.Conv2D(3, 2, 3) + optimizer = paddle.optimizer.SGDOptimizer( + learning_rate=0.01, parameter_list=model.parameters()) + scaler = paddle.amp.AmpScaler(init_loss_scaling=1024) + data = paddle.to_tensor(data) + with paddle.amp.amp_guard(): + conv = model(data) + loss = paddle.mean(conv) + scaled = scaler.scale(loss) + scaled.backward() + scaler.minimize(optimizer, scaled) + """ + + @dygraph_only + def __init__( + self, + enable=True, + init_loss_scaling=2.0**15, + incr_ratio=2.0, + decr_ratio=0.5, + incr_every_n_steps=1000, + decr_every_n_nan_or_inf=1, + use_dynamic_loss_scaling=True, + ): + + tracer = _dygraph_tracer() + if not tracer: + raise ValueError( + "current_tracer is None, maybe it is not in imperative mode." + ) + + if enable and not ( + tracer._expected_place.is_gpu_place() + or tracer._expected_place.is_xpu_place() + or tracer._expected_place.is_mlu_place() + or tracer._expected_place.is_npu_place() + or tracer._expected_place.is_custom_place() + ): + warnings.warn( + 'AmpScaler can only be enabled on CUDAPlace, XPUPlace, MLUPlace, NPUPlace and CustomPlace, current place is %s, so it makes no effect.' + % tracer._expected_place + ) + enable = False + + self._enable = enable + + if self._enable: + assert incr_ratio > 1.0, "The incr_ratio must be > 1.0." + assert decr_ratio < 1.0, "The decr_ratio must be < 1.0." + + self._init_loss_scaling = init_loss_scaling + self._incr_ratio = incr_ratio + self._decr_ratio = decr_ratio + self._incr_every_n_steps = incr_every_n_steps + self._decr_every_n_nan_or_inf = decr_every_n_nan_or_inf + self._incr_count = 0 + self._decr_count = 0 + self._use_dynamic_loss_scaling = use_dynamic_loss_scaling + + self._found_inf = to_variable(np.array([0]).astype(np.bool_)) + self._temp_found_inf_fp16 = to_variable( + np.array([0]).astype(np.bool_) + ) + self._temp_found_inf_bf16 = to_variable( + np.array([0]).astype(np.bool_) + ) + self._temp_found_inf_fp32 = to_variable( + np.array([0]).astype(np.bool_) + ) + self._scale = to_variable( + np.array([self._init_loss_scaling]).astype(np.float32) + ) + self._cache_founf_inf = None + self._optimizer_states = defaultdict(_refresh_optimizer_state) + + def scale(self, var): + """ + Multiplies a Tensor by the scale factor and returns scaled outputs. + If this instance of :class:`AmpScaler` is not enabled, output are returned unmodified. + + Args: + var (Tensor): The Tensor to scale. + Returns: + The scaled Tensor or original Tensor. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32') + model = paddle.nn.Conv2D(3, 2, 3) + optimizer = paddle.optimizer.SGDOptimizer( + learning_rate=0.01, parameter_list=model.parameters()) + scaler = paddle.amp.AmpScaler(init_loss_scaling=1024) + data = paddle.to_tensor(data) + with paddle.amp.amp_guard(): + conv = model(data) + loss = paddle.mean(conv) + scaled = scaler.scale(loss) + scaled.backward() + scaler.minimize(optimizer, scaled) + """ + check_type(var, "var", core.VarBase, 'AmpScaler.scale()') + + if not self._enable: + return var + + return var * self._scale + + def minimize(self, optimizer, *args, **kwargs): + """ + This function is similar as `Optimizer.minimize()`, which performs parameters updating. + + If the scaled gradients of parameters contains NAN or INF, the parameters updating is skipped. + Otherwise, if `unscale_()` has not been called, it first unscales the scaled gradients of parameters, then updates the parameters. + + Finally, the loss scaling ratio is updated. + + Args: + optimizer(Optimizer): The optimizer used to update parameters. + args: Arguments, which will be forward to `optimizer.minimize()`. + kwargs: Keyword arguments, which will be forward to `Optimizer.minimize()`. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32') + model = paddle.nn.Conv2D(3, 2, 3) + optimizer = paddle.optimizer.SGDOptimizer( + learning_rate=0.01, parameter_list=model.parameters()) + scaler = paddle.amp.AmpScaler(init_loss_scaling=1024) + data = paddle.to_tensor(data) + with paddle.amp.amp_guard(): + conv = model(data) + loss = paddle.mean(conv) + scaled = scaler.scale(loss) + scaled.backward() + scaler.minimize(optimizer, scaled) + """ + if not self._enable: + return optimizer.minimize(*args, **kwargs) + + optimizer_state = self._optimizer_states[id(optimizer)] + + # unscale the grad + if optimizer_state["state"] is OptimizerState.INIT: + self._unscale(optimizer) + + optimize_ops, params_grads = (None, None) + + if self._found_inf: + self._cache_founf_inf = True + else: + optimize_ops, params_grads = optimizer.minimize(*args, **kwargs) + self._cache_founf_inf = False + + if self._use_dynamic_loss_scaling: + # uopdate the scale + self._update() + + self._optimizer_states = defaultdict(_refresh_optimizer_state) + + return optimize_ops, params_grads + + def _unscale(self, optimizer): + """ + Unscale the gradients of parameters, multiplies the gradients of parameters by 1/(loss scaling ratio). + If this instance of :class:`GradScaler` is not enabled, output are returned unmodified. + Args: + optimizer(Optimizer): The optimizer used to update parameters. + Returns: + The unscaled parameters or original parameters. + """ + if not self._enable: + return + + optimizer_state = self._optimizer_states[id(optimizer)] + + if optimizer_state["state"] is OptimizerState.UNSCALED: + raise RuntimeError( + "unscale_() has already been called on this optimizer since the last update()." + ) + elif optimizer_state["state"] is OptimizerState.STEPPED: + raise RuntimeError("unscale_() is being called after step().") + + if getattr(optimizer, '_param_groups', None) and isinstance( + optimizer._param_groups[0], dict + ): + param_grads = [] + param_grads_fp16 = [] + param_grads_bf16 = [] + param_grads_fp32 = [] + for group in optimizer._param_groups: + for param in group['params']: + if param._grad_ivar() is not None: + param_grads.append(param._grad_ivar()) + if ( + param._grad_ivar().dtype + == core.VarDesc.VarType.FP16 + ): + param_grads_fp16.append(param._grad_ivar()) + elif ( + param._grad_ivar().dtype + == core.VarDesc.VarType.BF16 + ): + param_grads_bf16.append(param._grad_ivar()) + else: + param_grads_fp32.append(param._grad_ivar()) + else: + if in_dygraph_mode(): + # It is very time-consuming to call c++ functions in a loop on the python side. + # We put this part of the code on the c++ side to improve the speed in eager mode. + ( + param_grads_fp16, + param_grads_bf16, + param_grads_fp32, + ) = core.eager.get_grads_lists(optimizer._parameter_list) + else: + # Keep the original code to support legacy mode. + # Delete the else branch when the legacy mode exits. + param_grads = [ + param._grad_ivar() + for param in optimizer._parameter_list + if param._grad_ivar() is not None + ] + param_grads_fp16 = [ + param + for param in param_grads + if param.dtype == core.VarDesc.VarType.FP16 + ] + param_grads_bf16 = [ + param + for param in param_grads + if param.dtype == core.VarDesc.VarType.BF16 + ] + param_grads_fp32 = [ + param + for param in param_grads + if param.dtype == core.VarDesc.VarType.FP32 + ] + if core.is_compiled_with_npu(): + float_status = _legacy_C_ops.alloc_float_status() + _legacy_C_ops.clear_float_status(float_status, float_status) + + if len(param_grads_fp16): + _legacy_C_ops.check_finite_and_unscale( + param_grads_fp16, + self._scale, + float_status, + param_grads_fp16, + self._temp_found_inf_fp16, + ) + if len(param_grads_bf16): + _legacy_C_ops.check_finite_and_unscale( + param_grads_bf16, + self._scale, + float_status, + param_grads_bf16, + self._temp_found_inf_bf16, + ) + if len(param_grads_fp32): + _legacy_C_ops.check_finite_and_unscale( + param_grads_fp32, + self._scale, + float_status, + param_grads_fp32, + self._temp_found_inf_fp32, + ) + else: + if len(param_grads_fp16): + _legacy_C_ops.check_finite_and_unscale( + param_grads_fp16, + self._scale, + param_grads_fp16, + self._temp_found_inf_fp16, + ) + if len(param_grads_bf16): + _legacy_C_ops.check_finite_and_unscale( + param_grads_bf16, + self._scale, + param_grads_bf16, + self._temp_found_inf_bf16, + ) + if len(param_grads_fp32): + _legacy_C_ops.check_finite_and_unscale( + param_grads_fp32, + self._scale, + param_grads_fp32, + self._temp_found_inf_fp32, + ) + + self._found_inf = ( + self._temp_found_inf_fp16 + or self._temp_found_inf_bf16 + or self._temp_found_inf_fp32 + ) + + optimizer_state["state"] = OptimizerState.UNSCALED + + def _update(self): + """ + Updates the loss_scaling. + """ + if not self._enable: + return + + if self._cache_founf_inf: + self._incr_count = 0 + self._decr_count = self._decr_count + 1 + if self._decr_count == self._decr_every_n_nan_or_inf: + print( + 'Found inf or nan, current scale is: {}, decrease to: {}*{}'.format( + float(self._scale), + float(self._scale), + float(self._decr_ratio), + ) + ) + self._scale = self._scale * self._decr_ratio + self._decr_count = 0 + else: + self._decr_count = 0 + self._incr_count = self._incr_count + 1 + if self._incr_count == self._incr_every_n_steps: + self._scale = self._scale * self._incr_ratio + self._incr_count = 0 + + return + + def is_enable(self): + """ + Enable loss scaling or not. + + Returns: + bool: enable loss scaling return True else return False. + """ + return self._enable + + def is_use_dynamic_loss_scaling(self): + """ + Whether to use dynamic loss scaling. + + Returns: + bool: if fixed loss_scaling is used return False, if the loss scaling is updated dynamicly return true. + """ + return self._use_dynamic_loss_scaling + + def get_init_loss_scaling(self): + """ + Return the initial loss scaling factor. + + Reurns: + float: the initial loss scaling factor. + """ + return self._init_loss_scaling + + def set_init_loss_scaling(self, new_init_loss_scaling): + """ + Set the initial loss scaling factor by `new_init_loss_scaling`. + + Args: + new_init_loss_scaling(int): The new_init_loss_scaling used to update initial loss scaling factor.s + """ + self._init_loss_scaling = new_init_loss_scaling + self._scale = to_variable( + np.array([self._init_loss_scaling]).astype(np.float32) + ) + + def get_incr_ratio(self): + """ + Return the multiplier to use when increasing the loss scaling. + + Reurns: + float: the multiplier to use when increasing the loss scaling. + """ + return self._incr_ratio + + def set_incr_ratio(self, new_incr_ratio): + """ + Set the multiplier to use when increasing the loss scaling by `new_incr_ratio`, `new_incr_ratio` should > 1.0. + + Args: + new_incr_ratio(float): The new_incr_ratio used to update the multiplier to use when increasing the loss scaling. + """ + assert new_incr_ratio > 1.0, "The new_incr_ratio must be > 1.0." + self._incr_ratio = new_incr_ratio + + def get_decr_ratio(self): + """ + Get the less-than-one-multiplier to use when decreasing the loss scaling. + + Reurns: + float: the less-than-one-multiplier to use when decreasing the loss scaling. + """ + return self._decr_ratio + + def set_decr_ratio(self, new_decr_ratio): + """ + Set the less-than-one-multiplier to use when decreasing the loss scaling by `new_incr_ratio`, `new_decr_ratio` should < 1.0. + + Args: + new_decr_ratio(float): The new_decr_ratio used to update the less-than-one-multiplier to use when decreasing the loss scaling. + """ + assert new_decr_ratio < 1.0, "The new_decr_ratio must be < 1.0." + self._decr_ratio = new_decr_ratio + + def get_incr_every_n_steps(self): + """ + Return the num `n`, `n` represent increases loss scaling every `n` consecutive steps with finite gradients. + + Reurns: + int: the num `n`, `n` represent increases loss scaling every `n` consecutive steps with finite gradients. + """ + return self._incr_every_n_steps + + def set_incr_every_n_steps(self, new_incr_every_n_steps): + """ + Set the num `n` by `new_incr_every_n_steps`, `n` represent increases loss scaling every `n` consecutive steps with finite gradients. + + Args: + new_incr_every_n_steps(int): The new_incr_every_n_steps used to update the num `n`, `n` represent increases loss scaling every `n` consecutive steps with finite gradients. + """ + self._incr_every_n_steps = new_incr_every_n_steps + + def get_decr_every_n_nan_or_inf(self): + """ + Return the num `n`, `n` represent decreases loss scaling every `n` accumulated steps with nan or inf gradients. + + Reurns: + int: the num `n`, `n` represent decreases loss scaling every `n` accumulated steps with nan or inf gradients. + """ + return self._decr_every_n_nan_or_inf + + def set_decr_every_n_nan_or_inf(self, new_decr_every_n_nan_or_inf): + """ + Set the num `n` by `new_decr_every_n_nan_or_inf`, `n` represent decreases loss scaling every `n` accumulated steps with nan or inf gradients. + + Args: + new_decr_every_n_nan_or_inf(int): The new_decr_every_n_nan_or_inf used to update the num `n`, `n` represent decreases loss scaling every `n` accumulated steps with nan or inf gradients. + """ + self._decr_every_n_nan_or_inf = new_decr_every_n_nan_or_inf + + def state_dict(self): + """ + Returns the state of the scaler as a `dict`, If this instance is not enabled, returns an empty dict. + + Reurns: + A dict of scaler includes: + scale (tensor): The loss scaling factor. + incr_ratio(float): The multiplier to use when increasing the loss scaling. + decr_ratio(float): The less-than-one-multiplier to use when decreasing the loss scaling. + incr_every_n_steps(int): Increases loss scaling every n consecutive steps with finite gradients. + decr_every_n_nan_or_inf(int): Decreases loss scaling every n accumulated steps with nan or inf gradients. + incr_count(int): The number of recent consecutive unskipped steps. + decr_count(int): The number of recent consecutive skipped steps. + use_dynamic_loss_scaling(bool): Whether to use dynamic loss scaling. If False, fixed loss_scaling is used. If True, the loss scaling is updated dynamicly. Default is True. + """ + return ( + { + "scale": self._scale.numpy(), + "incr_ratio": self._incr_ratio, + "decr_ratio": self._decr_ratio, + "incr_every_n_steps": self._incr_every_n_steps, + "decr_every_n_nan_or_inf": self._decr_every_n_nan_or_inf, + "incr_count": self._incr_count, + "decr_count": self._decr_count, + "use_dynamic_loss_scaling": self._use_dynamic_loss_scaling, + } + if self._enable + else {} + ) + + def load_state_dict(self, state_dict): + """ + Loads the scaler state. + + Args: + state_dict(dict): scaler state. Should be an object returned from a call to `AmpScaler.state_dict()`. + """ + if not self._enable: + return + + if len(state_dict) == 0: + raise RuntimeError( + "The input state dict is empty, possibly because it was saved " + "from a disabled instance of GradScaler." + ) + + self._init_loss_scaling = state_dict["scale"][0] + self._scale = to_variable( + np.array([self._init_loss_scaling]).astype(np.float32) + ) + self._incr_ratio = state_dict["incr_ratio"] + self._decr_ratio = state_dict["decr_ratio"] + self._incr_every_n_steps = state_dict["incr_every_n_steps"] + self._decr_every_n_nan_or_inf = state_dict["decr_every_n_nan_or_inf"] + self._incr_count = state_dict["incr_count"] + self._decr_count = state_dict["decr_count"] + self._use_dynamic_loss_scaling = state_dict["use_dynamic_loss_scaling"] + + class GradScaler(AmpScaler): """ GradScaler is used for Auto-Mixed-Precision training in dynamic graph mode. diff --git a/python/paddle/fluid/dygraph/__init__.py b/python/paddle/fluid/dygraph/__init__.py index d9f6034b732..18c8d34cf62 100644 --- a/python/paddle/fluid/dygraph/__init__.py +++ b/python/paddle/fluid/dygraph/__init__.py @@ -28,9 +28,6 @@ from .parallel import * from . import learning_rate_scheduler from .learning_rate_scheduler import * -from . import amp -from .amp import * - from .math_op_patch import monkey_patch_math_varbase __all__ = [] @@ -38,4 +35,3 @@ __all__ += layers.__all__ __all__ += base.__all__ __all__ += parallel.__all__ __all__ += learning_rate_scheduler.__all__ -__all__ += amp.__all__ diff --git a/python/paddle/fluid/dygraph/amp/__init__.py b/python/paddle/fluid/dygraph/amp/__init__.py deleted file mode 100644 index e86c5a20c5a..00000000000 --- a/python/paddle/fluid/dygraph/amp/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from . import auto_cast -from .auto_cast import * - -from . import loss_scaler -from .loss_scaler import * - -__all__ = [] -__all__ += auto_cast.__all__ -__all__ += loss_scaler.__all__ diff --git a/python/paddle/fluid/dygraph/amp/auto_cast.py b/python/paddle/fluid/dygraph/amp/auto_cast.py deleted file mode 100644 index fd5131ce070..00000000000 --- a/python/paddle/fluid/dygraph/amp/auto_cast.py +++ /dev/null @@ -1,666 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from paddle.fluid.wrapped_decorator import ( - signature_safe_contextmanager, - wrap_decorator, -) -from paddle.fluid import core -import contextlib -from paddle.fluid.framework import ( - Variable, - OpProtoHolder, - Parameter, - _dygraph_tracer, - dygraph_only, - set_flags, - get_flags, -) -import warnings -import copy -import functools -import paddle -import operator -import types - -AMP_LEVEL = core.AmpLevel - -__all__ = ['amp_guard', 'amp_decorate'] - -# The set of ops that support fp16 calculation and are considered numerically- -# safe and performance-critical. These ops are always converted to fp16. -WHITE_LIST = { - 'conv2d', - 'matmul', - 'matmul_v2', - 'mul', - 'fake_quantize_dequantize_abs_max', - 'fake_quantize_dequantize_moving_average_abs_max', -} - -# The set of ops that support fp16 calculation and are considered numerically- -# dangerous and whose effects may also be observed in downstream ops. -BLACK_LIST = { - 'exp', - 'square', - 'log', - 'mean', - 'sum', - 'cos_sim', - 'softmax', - 'softmax_with_cross_entropy', - 'sigmoid_cross_entropy_with_logits', - 'c_softmax_with_cross_entropy', - 'cross_entropy', - 'cross_entropy2', - # default fp32 can avoid return inf when the sum value large than 65504 - 'reduce_sum', - # FP16 performance of grad op is worse than that of FP32. Use FP32 by default. - 'linear_interp_v2', - 'nearest_interp_v2', - 'bilinear_interp_v2', - 'bicubic_interp_v2', - 'trilinear_interp_v2', -} - -AMP_RELATED_FLAGS = [ - 'FLAGS_cudnn_exhaustive_search', - 'FLAGS_conv_workspace_size_limit', - 'FLAGS_cudnn_batchnorm_spatial_persistent', -] - -AMP_RELATED_FLAGS_SETTING = { - 'FLAGS_cudnn_exhaustive_search': 1, - 'FLAGS_conv_workspace_size_limit': 1000, - 'FLAGS_cudnn_batchnorm_spatial_persistent': 1, -} - -PURE_FP16_WHITE_LIST = set() -PURE_FP16_BLACK_LIST = { - 'lookup_table', - 'lookup_table_v2', - 'scatter', - 'scatter_grad', - # FP16 performance of grad op is worse than that of FP32. Use FP32 by default. - 'linear_interp_v2', - 'nearest_interp_v2', - 'bilinear_interp_v2', - 'bicubic_interp_v2', - 'trilinear_interp_v2', -} - -BF16_WHITE_LIST = {'conv2d', 'matmul_v2'} -BF16_BLACK_LIST = set() - -PURE_BF16_WHITE_LIST = set() -PURE_BF16_BLACK_LIST = set() - -_g_amp_state_ = None - - -def low_precision_op_list(): - op_list = paddle.fluid.core.get_low_precision_op_list() - op_count = 0 - print('<---------------- low precision op list ------------------->') - print('<---- op name ------|------- op count---------------------->') - for x in op_list: - print(' %-18s| %4d' % (x, op_list[x])) - op_count += 1 - print( - '<------------- low precision op num:{:5d} ----------------->'.format( - op_count - ) - ) - - -def amp_state(): - global _g_amp_state_ - return _g_amp_state_ - - -# NOTE(zhiqiu): similar as paddle.fluid.contrib.mixed_precision.fp16_lists.AutoMixedPrecisionLists._update_list -# The reason why not use AutoMixedPrecisionLists is that custom_black_varnames is not suitable for imperative mode. -def _update_list( - custom_white_list, custom_black_list, level='O1', dtype='float16' -): - """ - Update black and white list according to users' custom list. - """ - if dtype == 'float16': - if level == 'O1': - _white_list = copy.copy(WHITE_LIST) - _black_list = copy.copy(BLACK_LIST) - else: - _white_list = copy.copy(PURE_FP16_WHITE_LIST) - _black_list = copy.copy(PURE_FP16_BLACK_LIST) - else: - if level == 'O1': - _white_list = copy.copy(BF16_WHITE_LIST) - _black_list = copy.copy(BF16_BLACK_LIST) - else: - _white_list = copy.copy(PURE_BF16_WHITE_LIST) - _black_list = copy.copy(PURE_BF16_BLACK_LIST) - if custom_white_list and custom_black_list: - for op_name in custom_white_list: - if op_name in custom_black_list: - raise ValueError( - "Custom white list overlap " "custom black list" - ) - if custom_white_list: - for op_name in custom_white_list: - if op_name in _black_list: - _black_list.remove(op_name) - _white_list.add(op_name) - if custom_black_list: - for op_name in custom_black_list: - if op_name in _white_list: - _white_list.remove(op_name) - _black_list.add(op_name) - return _white_list, _black_list - - -def _in_amp_guard(): - """ - Judge whether current code block is in `amp_guard` context. - """ - tracer = _dygraph_tracer() - if tracer: - if tracer._amp_level == core.AmpLevel.O1: - return True - else: - return False - else: - return False - - -def _in_pure_fp16_guard(): - tracer = _dygraph_tracer() - return tracer and tracer._amp_level == core.AmpLevel.O2 - - -def _is_gpu_float16_supported(): - """ - Judge whether current gpu support float16 amp. - """ - prop = paddle.device.cuda.get_device_capability() - return prop[0] >= 7 - - -def _is_gpu_bfloat16_supported(): - """ - Judge whether current gpu support bfloat16 amp. - """ - prop = paddle.device.cuda.get_device_capability() - cuda_version = paddle.version.cuda() - if cuda_version is not None and cuda_version != 'False': - cuda_version_check = int(cuda_version.split('.')[0]) >= 11 - else: - cuda_version_check = False - return prop[0] >= 8 and cuda_version_check - - -@dygraph_only -def pure_fp16_initialize(models): - for idx in range(len(models)): - for layer in models[idx].sublayers(include_self=True): - layer._casted_by_pure_fp16 = True - if (layer._dtype == 'float16') or isinstance( - layer, - ( - paddle.nn.BatchNorm, - paddle.nn.BatchNorm1D, - paddle.nn.BatchNorm2D, - paddle.nn.BatchNorm3D, - paddle.nn.LayerNorm, - paddle.nn.SyncBatchNorm, - ), - ): - continue - if isinstance( - layer, - ( - paddle.incubate.nn.FusedFeedForward, - paddle.incubate.nn.FusedMultiHeadAttention, - ), - ): - layer._amp_decorate(dtype='float16') - continue - layer._to_impl( - dtype='float16', include_sublayers=False, floating_only=True - ) - return models - - -@dygraph_only -def pure_bf16_initialize(models): - for idx in range(len(models)): - for layer in models[idx].sublayers(include_self=True): - layer._to_impl( - dtype='bfloat16', include_sublayers=False, floating_only=True - ) - return models - - -def check_models(models): - for model in models: - if not isinstance(model, paddle.nn.Layer): - raise RuntimeError( - "Current train mode is pure fp16, models should be paddle.nn.Layer, but receive {}.".format( - type(model) - ) - ) - if isinstance(model, paddle.DataParallel): - raise RuntimeError( - "For distributed AMP training, you should first use paddle.amp.decorate() to decotate origin model, and then call paddle.DataParallel get distributed model." - ) - - -def _is_valid_optimizer(optimizer): - from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import ( - DygraphShardingOptimizer, - ) - - return isinstance( - optimizer, - ( - paddle.optimizer.Optimizer, - paddle.fluid.optimizer.Optimizer, - DygraphShardingOptimizer, - ), - ) - - -def check_optimizers(optimizers): - for optimizer in optimizers: - if not _is_valid_optimizer(optimizer): - raise RuntimeError( - "Current train mode is pure fp16, optimizers should be paddle.optimizer.Optimizer or paddle.fluid.optimizer.Optimizer or DygraphShardingOptimizer, but receive {}.".format( - type(optimizer) - ) - ) - - -@signature_safe_contextmanager -@dygraph_only -def amp_guard( - enable=True, - custom_white_list=None, - custom_black_list=None, - level='O1', - dtype='float16', -): - """ - :api_attr: imperative - - Create a context which enables auto-mixed-precision(AMP) of operators executed in dynamic graph mode. - If enabled, the input data type (float32 or float16) of each operator is decided - by autocast algorithm for better performance. - - Commonly, it is used together with `GradScaler` to achieve Auto-Mixed-Precision in - imperative mode. It is used together with `decorator` to achieve Pure fp16 in imperative mode. - - Args: - enable(bool, optional): Enable auto-mixed-precision or not. Default is True. - custom_white_list(set|list|tuple, optional): The custom white_list. It's the set of ops that support - fp16 calculation and are considered numerically-safe and performance-critical. These ops - will be converted to fp16. - custom_black_list(set|list|tuple, optional): The custom black_list. The set of ops that support fp16 - calculation and are considered numerically-dangerous and whose effects may also be - observed in downstream ops. These ops will not be converted to fp16. - level(str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represent mixed precision, the input data type of each operator will be casted by white_list and black_list; - O2 represent Pure fp16, all operators parameters and input data will be casted to fp16, except operators in black_list, don't support fp16 kernel and batchnorm. Default is O1(amp) - dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'. - - - Examples: - - .. code-block:: python - - import numpy as np - import paddle - - data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32') - with paddle.fluid.dygraph.guard(): - conv2d = paddle.fluid.dygraph.Conv2D(3, 2, 3) - data = paddle.fluid.dygraph.to_variable(data) - with paddle.fluid.dygraph.amp_guard(): - conv = conv2d(data) - print(conv.dtype) # FP16 - with paddle.fluid.dygraph.amp_guard(enable=False): - conv = conv2d(data) - print(conv.dtype) # FP32 - - """ - amp_state = locals() - global _g_amp_state_ - original_state = _g_amp_state_ - _g_amp_state_ = amp_state - - # check amp_level: O0-O2 - level = level.upper() - if not (level in ['O0', 'O1', 'O2']): - raise ValueError( - "level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16/bf16 train mode." - ) - - # check amp_dtype: float16 or bfloat16 - dtype = dtype.lower() - if not (dtype in ['float16', 'bfloat16']): - raise ValueError("dtype should be 'float16' or 'bfloat16'.") - - # check tracer - tracer = _dygraph_tracer() - if not tracer: - raise ValueError( - "current_tracer is None, maybe it is not in imperative mode." - ) - - # check device_type: - # NOTE: Now, amp only support gpu for float16 and bfloat16, xpu for float16, mlu for float16, npu for float16. - # Maybe we will support cpu for bfloat16. - if enable and not ( - tracer._expected_place.is_gpu_place() - or tracer._expected_place.is_xpu_place() - or tracer._expected_place.is_mlu_place() - or tracer._expected_place.is_npu_place() - or tracer._expected_place.is_custom_place() - ): - warnings.warn( - 'amp_guard can only be enabled on CUDAPlace, XPUPlace, MLUPlace, NPUPlace, and CustomPlace, current place is %s, so it makes no effect.' - % tracer._expected_place - ) - enable = False - # For npu: - if tracer._expected_place.is_npu_place() and (dtype == 'bfloat16'): - warnings.warn('NPUPlace only support float16 amp.') - enable = False - # For xpu: - if tracer._expected_place.is_xpu_place() and (dtype == 'bfloat16'): - warnings.warn('XPUPlace only support float16 amp.') - enable = False - # For mlu: - if tracer._expected_place.is_mlu_place() and (dtype == 'bfloat16'): - warnings.warn('MLUPlace only support float16 amp.') - enable = False - # For custom device: - if tracer._expected_place.is_custom_place() and (dtype == 'bfloat16'): - warnings.warn('CustomPlace only support float16 amp.') - enable = False - # For gpu float16: Compute Capability should >= 7. - # For gpu bfloat16: Compute Capability should >= 8 & CUDA Version should >= 11. - if tracer._expected_place.is_gpu_place(): - if (dtype == 'float16') and not _is_gpu_float16_supported(): - prop = paddle.device.cuda.get_device_capability() - warnings.warn( - "For float16, amp only support NVIDIA GPU with Compute Capability 7.0 or higher, current GPU is: %s, with Compute Capability: %d.%d." - % (paddle.device.cuda.get_device_name(), prop[0], prop[1]) - ) - elif (dtype == 'bfloat16') and not _is_gpu_bfloat16_supported(): - prop = paddle.device.cuda.get_device_capability() - cuda_version = paddle.version.cuda() - warnings.warn( - "For bfloat16, amp only support NVIDIA GPU with Compute Capability 8.0 or higher and CUDA Version 11.0 or higher, current GPU is: %s, with Compute Capability: %d.%d, current CUDA Version is: %s." - % ( - paddle.device.cuda.get_device_name(), - prop[0], - prop[1], - cuda_version, - ) - ) - - amp_dtype = dtype - - if level == 'O1': - amp_level = AMP_LEVEL.O1 - if dtype == 'float16': - _white_list = WHITE_LIST - _black_list = BLACK_LIST - elif dtype == 'bfloat16': - _white_list = BF16_WHITE_LIST - _black_list = BF16_BLACK_LIST - - elif level == 'O2': - amp_level = AMP_LEVEL.O2 - if dtype == 'float16': - _white_list = PURE_FP16_WHITE_LIST - _black_list = PURE_FP16_BLACK_LIST - elif dtype == 'bfloat16': - _white_list = BF16_WHITE_LIST - _black_list = BF16_BLACK_LIST - elif level == 'O0': - amp_level = AMP_LEVEL.O0 - if dtype == 'float16': - _white_list = WHITE_LIST - _black_list = BLACK_LIST - elif dtype == 'bfloat16': - _white_list = BF16_WHITE_LIST - _black_list = BF16_BLACK_LIST - - if custom_white_list or custom_black_list: - _white_list, _black_list = _update_list( - custom_white_list, custom_black_list, level, dtype - ) - - if not enable: - amp_level = AMP_LEVEL.O0 - amp_dtype = "float32" - - if tracer: - # enable auto_cast - original_amp_level = tracer._amp_level - tracer._amp_level = amp_level - - # set amp op list - original_white_list, original_black_list = tracer._get_amp_op_list() - tracer._set_amp_op_list(_white_list, _black_list) - - # TODO(zhiqiu) set amp related flags automatically in this guard - # Currently, if FLAGS_cudnn_batchnorm_spatial_persistent is set True in amp_guard, - # batch_norm can run in fast mode, but batch_norm_grad can not if backward if not executed insise amp_guard. - # So, users need to set related flags manually. - - # original_flags = get_flags(AMP_RELATED_FLAGS) - # set_flags(AMP_RELATED_FLAGS_SETTING) - - # set amp dtype - original_amp_dtype = tracer._amp_dtype - tracer._amp_dtype = amp_dtype - - # restore status - try: - yield - finally: - if tracer: - _g_amp_state_ = original_state - tracer._amp_level = original_amp_level - tracer._set_amp_op_list(original_white_list, original_black_list) - # set_flags(original_flags) - tracer._amp_dtype = original_amp_dtype - - -class StateDictHook: - def __init__(self, save_dtype): - self._save_dtype = save_dtype - - def __call__(self, state_dict): - for key in state_dict: - param = state_dict[key] - with paddle.fluid.dygraph.guard(): - if paddle.is_floating_point(param): - param_applied = paddle.cast(param, self._save_dtype) - param_applied.name = param.name - state_dict[key] = param_applied - - -def _set_multi_precision(optimizer, multi_precision): - from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import ( - DygraphShardingOptimizer, - ) - - optimizer = ( - optimizer._inner_optimizer - if isinstance(optimizer, DygraphShardingOptimizer) - else optimizer - ) - if hasattr(optimizer, "_multi_precision"): - optimizer._multi_precision = multi_precision - - -@dygraph_only -def amp_decorate( - models, - optimizers=None, - level='O1', - dtype='float16', - master_weight=None, - save_dtype=None, -): - """ - Decorate models and optimizers for auto-mixed-precision. When level is O1(amp), the decorate will do nothing. - When level is O2(pure fp16), the decorate will cast all parameters of models to FP16, except BatchNorm and LayerNorm. - - Commonly, it is used together with `amp_guard` to achieve Pure fp16 in imperative mode. - - Args: - models(Layer|list of Layer, optional): The defined models by user, models must be either a single model or a list of models. Default is None. - optimizers(Optimizer|list of Optimizer, optional): The defined optimizers by user, optimizers must be either a single optimizer or a list of optimizers. Default is None. - level(str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represent mixed precision, the decorator will do nothing; - O2 represent Pure fp16/bf16, the decorator will cast all parameters of models to FP16/BF16, except BatchNorm and LayerNorm. Default is O1(amp) - dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'. - master_weight(bool, optinal): For level='O2', whether to use multi-precision during weight updating. If master_weight is None, in O2 level optimizer will use multi-precision. Default is None. - save_dtype(float, optional): The save model parameter dtype when use `paddle.save` or `paddle.jit.save`,it should be float16, bfloat16, float32, float64 or None. - The save_dtype will not change model parameters dtype, it just change the state_dict dtype. When save_dtype is None, the save dtype is same as model dtype. Default is None. - - Examples: - - .. code-block:: python - - # required: gpu - # Demo1: single model and optimizer: - import paddle - - model = paddle.nn.Conv2D(3, 2, 3, bias_attr=False) - optimizer = paddle.optimizer.SGD(parameters=model.parameters()) - - model, optimizer = paddle.fluid.dygraph.amp_decorate(models=model, optimizers=optimizer, level='O2') - - data = paddle.rand([10, 3, 32, 32]) - - with paddle.fluid.dygraph.amp_guard(enable=True, custom_white_list=None, custom_black_list=None, level='O2'): - output = model(data) - print(output.dtype) # FP16 - - # required: gpu - # Demo2: multi models and optimizers: - model2 = paddle.nn.Conv2D(3, 2, 3, bias_attr=False) - optimizer2 = paddle.optimizer.Adam(parameters=model2.parameters()) - - models, optimizers = paddle.fluid.dygraph.amp_decorate(models=[model, model2], optimizers=[optimizer, optimizer2], level='O2') - - data = paddle.rand([10, 3, 32, 32]) - - with paddle.fluid.dygraph.amp_guard(enable=True, custom_white_list=None, custom_black_list=None, level='O2'): - output = models[0](data) - output2 = models[1](data) - print(output.dtype) # FP16 - print(output2.dtype) # FP16 - - # required: gpu - # Demo3: optimizers is None: - model3 = paddle.nn.Conv2D(3, 2, 3, bias_attr=False) - optimizer3 = paddle.optimizer.Adam(parameters=model2.parameters()) - - model = paddle.fluid.dygraph.amp_decorate(models=model3, level='O2') - - data = paddle.rand([10, 3, 32, 32]) - - with paddle.fluid.dygraph.amp_guard(enable=True, custom_white_list=None, custom_black_list=None, level='O2'): - output = model(data) - print(output.dtype) # FP16 - """ - if not (level in ['O1', 'O2']): - raise ValueError( - "level should be O1 or O2, O1 represent AMP train mode, O2 represent Pure fp16 train mode." - ) - - if level == 'O1': - if optimizers is None: - return models - else: - return models, optimizers - - models_is_list = False - if isinstance(models, paddle.nn.Layer): - models_is_list = False - models = [models] - check_models(models) - elif isinstance(models, list): - check_models(models) - models_is_list = True - else: - raise TypeError( - "models must be either a single model or a list of models." - ) - if dtype == 'float16': - models = pure_fp16_initialize(models=models) - elif dtype == 'bfloat16': - models = pure_bf16_initialize(models=models) - else: - raise TypeError("dtype only support float16 or bfloat16.") - - if optimizers is not None: - # check optimizers - optimizers_is_list = False - if _is_valid_optimizer(optimizers): - optimizers_is_list = False - optimizers = [optimizers] - check_optimizers(optimizers) - elif isinstance(optimizers, list): - check_optimizers(optimizers) - optimizers_is_list = True - else: - raise TypeError( - "optimizers must be either a single optimizer or a list of optimizers." - ) - # support master_weight - use_multi_precision = not (master_weight is False) - for opt in optimizers: - _set_multi_precision(opt, use_multi_precision) - - if save_dtype is not None: - if not (save_dtype in ['float16', 'bfloat16', 'float32', 'float64']): - raise ValueError( - "save_dtype can only be float16 float32 or float64, but your input save_dtype is %s." - % save_dtype - ) - for idx in range(len(models)): - for layer in models[idx].sublayers(include_self=True): - layer.register_state_dict_hook(StateDictHook(save_dtype)) - - if models_is_list: - if optimizers is not None: - if optimizers_is_list: - return models, optimizers - else: - return models, optimizers[0] - else: - return models - else: - if optimizers is not None: - if optimizers_is_list: - return models[0], optimizers - else: - return models[0], optimizers[0] - else: - return models[0] diff --git a/python/paddle/fluid/dygraph/amp/loss_scaler.py b/python/paddle/fluid/dygraph/amp/loss_scaler.py deleted file mode 100644 index 3ecdd7019b1..00000000000 --- a/python/paddle/fluid/dygraph/amp/loss_scaler.py +++ /dev/null @@ -1,589 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from paddle.fluid import core -from paddle.fluid.dygraph import to_variable -from paddle.fluid.framework import ( - _varbase_creator, - _dygraph_tracer, - dygraph_only, -) -from paddle.fluid.data_feeder import check_type -from ...wrapped_decorator import signature_safe_contextmanager, wrap_decorator -import warnings -import numpy as np -from paddle import _C_ops, _legacy_C_ops -from collections import defaultdict -from enum import Enum -from paddle.fluid import in_dygraph_mode - -__all__ = ['AmpScaler', 'OptimizerState'] - - -class OptimizerState(Enum): - INIT = 0 - UNSCALED = 1 - STEPPED = 2 - - -def _refresh_optimizer_state(): - return {"state": OptimizerState.INIT} - - -class AmpScaler: - """ - :api_attr: imperative - - AmpScaler is used for Auto-Mixed-Precision training/inferring in imperative - mode. It controls the scaling of loss, helps avoiding numerical overflow. - The object of this class has seventeen methods `scale()`, `unscale_()`, `minimize()` and `get`/`set` api of parameters. - - `scale()` is used to multiply the loss by a scale ratio. - `unscale_()` is used to unscale the gradients of parameters, multiplies the gradients of parameters by 1/(scale ratio) - `minimize()` is similar as `optimizer.minimize()`, performs parameters updating, and it will update the loss_scaling. - - Commonly, it is used together with `amp_guard` to achieve Auto-Mixed-Precision in - imperative mode. - - Args: - enable(bool, optional): Enable loss scaling or not. Default is True. - init_loss_scaling (float, optional): The initial loss scaling factor. Default is 2**15. - incr_ratio(float, optional): The multiplier to use when increasing the loss - scaling. Default is 2.0. - decr_ratio(float, optional): The less-than-one-multiplier to use when decreasing - the loss scaling. Default is 0.5. - incr_every_n_steps(int, optional): Increases loss scaling every n consecutive - steps with finite gradients. Default is 1000. - decr_every_n_nan_or_inf(int, optional): Decreases loss scaling every n - accumulated steps with nan or inf gradients. Default is 2. - use_dynamic_loss_scaling(bool, optional): Whether to use dynamic loss scaling. If False, fixed loss_scaling is used. If True, the loss scaling is updated dynamicly. Default is True. - Returns: - An AmpScaler object. - - Examples: - - .. code-block:: python - - import numpy as np - import paddle.fluid as fluid - - data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32') - with fluid.dygraph.guard(): - model = fluid.dygraph.Conv2D(3, 2, 3) - optimizer = fluid.optimizer.SGDOptimizer( - learning_rate=0.01, parameter_list=model.parameters()) - scaler = fluid.dygraph.AmpScaler(init_loss_scaling=1024) - data = fluid.dygraph.to_variable(data) - with fluid.dygraph.amp_guard(): - conv = model(data) - loss = fluid.layers.reduce_mean(conv) - scaled = scaler.scale(loss) - scaled.backward() - scaler.minimize(optimizer, scaled) - """ - - @dygraph_only - def __init__( - self, - enable=True, - init_loss_scaling=2.0**15, - incr_ratio=2.0, - decr_ratio=0.5, - incr_every_n_steps=1000, - decr_every_n_nan_or_inf=1, - use_dynamic_loss_scaling=True, - ): - - tracer = _dygraph_tracer() - if not tracer: - raise ValueError( - "current_tracer is None, maybe it is not in imperative mode." - ) - - if enable and not ( - tracer._expected_place.is_gpu_place() - or tracer._expected_place.is_xpu_place() - or tracer._expected_place.is_mlu_place() - or tracer._expected_place.is_npu_place() - or tracer._expected_place.is_custom_place() - ): - warnings.warn( - 'AmpScaler can only be enabled on CUDAPlace, XPUPlace, MLUPlace, NPUPlace and CustomPlace, current place is %s, so it makes no effect.' - % tracer._expected_place - ) - enable = False - - self._enable = enable - - if self._enable: - assert incr_ratio > 1.0, "The incr_ratio must be > 1.0." - assert decr_ratio < 1.0, "The decr_ratio must be < 1.0." - - self._init_loss_scaling = init_loss_scaling - self._incr_ratio = incr_ratio - self._decr_ratio = decr_ratio - self._incr_every_n_steps = incr_every_n_steps - self._decr_every_n_nan_or_inf = decr_every_n_nan_or_inf - self._incr_count = 0 - self._decr_count = 0 - self._use_dynamic_loss_scaling = use_dynamic_loss_scaling - - self._found_inf = to_variable(np.array([0]).astype(np.bool_)) - self._temp_found_inf_fp16 = to_variable( - np.array([0]).astype(np.bool_) - ) - self._temp_found_inf_bf16 = to_variable( - np.array([0]).astype(np.bool_) - ) - self._temp_found_inf_fp32 = to_variable( - np.array([0]).astype(np.bool_) - ) - self._scale = to_variable( - np.array([self._init_loss_scaling]).astype(np.float32) - ) - self._cache_founf_inf = None - self._optimizer_states = defaultdict(_refresh_optimizer_state) - - def scale(self, var): - """ - Multiplies a variable(Tensor) by the scale factor and returns scaled outputs. - If this instance of :class:`AmpScaler` is not enabled, output are returned unmodified. - - Args: - var (Variable): The variable to scale. - Returns: - The scaled variable or original variable. - - Examples: - - .. code-block:: python - - import numpy as np - import paddle.fluid as fluid - - data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32') - with fluid.dygraph.guard(): - model = fluid.dygraph.Conv2D(3, 2, 3) - optimizer = fluid.optimizer.SGDOptimizer( - learning_rate=0.01, parameter_list=model.parameters()) - scaler = fluid.dygraph.AmpScaler(init_loss_scaling=1024) - data = fluid.dygraph.to_variable(data) - with fluid.dygraph.amp_guard(): - conv = model(data) - loss = fluid.layers.reduce_mean(conv) - scaled = scaler.scale(loss) - scaled.backward() - scaler.minimize(optimizer, scaled) - """ - check_type(var, "var", core.VarBase, 'AmpScaler.scale()') - - if not self._enable: - return var - - return var * self._scale - - def minimize(self, optimizer, *args, **kwargs): - """ - This function is similar as `Optimizer.minimize()`, which performs parameters updating. - - If the scaled gradients of parameters contains NAN or INF, the parameters updating is skipped. - Otherwise, if `unscale_()` has not been called, it first unscales the scaled gradients of parameters, then updates the parameters. - - Finally, the loss scaling ratio is updated. - - Args: - optimizer(Optimizer): The optimizer used to update parameters. - args: Arguments, which will be forward to `optimizer.minimize()`. - kwargs: Keyword arguments, which will be forward to `Optimizer.minimize()`. - - Examples: - - .. code-block:: python - - import numpy as np - import paddle.fluid as fluid - - data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32') - with fluid.dygraph.guard(): - model = fluid.dygraph.Conv2D(3, 2, 3) - optimizer = fluid.optimizer.SGDOptimizer( - learning_rate=0.01, parameter_list=model.parameters()) - scaler = fluid.dygraph.AmpScaler(init_loss_scaling=1024) - data = fluid.dygraph.to_variable(data) - with fluid.dygraph.amp_guard(): - conv = model(data) - loss = fluid.layers.reduce_mean(conv) - scaled = scaler.scale(loss) - scaled.backward() - scaler.minimize(optimizer, scaled) - """ - if not self._enable: - return optimizer.minimize(*args, **kwargs) - - optimizer_state = self._optimizer_states[id(optimizer)] - - # unscale the grad - if optimizer_state["state"] is OptimizerState.INIT: - self._unscale(optimizer) - - optimize_ops, params_grads = (None, None) - - if self._found_inf: - self._cache_founf_inf = True - else: - optimize_ops, params_grads = optimizer.minimize(*args, **kwargs) - self._cache_founf_inf = False - - if self._use_dynamic_loss_scaling: - # uopdate the scale - self._update() - - self._optimizer_states = defaultdict(_refresh_optimizer_state) - - return optimize_ops, params_grads - - def _unscale(self, optimizer): - """ - Unscale the gradients of parameters, multiplies the gradients of parameters by 1/(loss scaling ratio). - If this instance of :class:`GradScaler` is not enabled, output are returned unmodified. - Args: - optimizer(Optimizer): The optimizer used to update parameters. - Returns: - The unscaled parameters or original parameters. - """ - if not self._enable: - return - - optimizer_state = self._optimizer_states[id(optimizer)] - - if optimizer_state["state"] is OptimizerState.UNSCALED: - raise RuntimeError( - "unscale_() has already been called on this optimizer since the last update()." - ) - elif optimizer_state["state"] is OptimizerState.STEPPED: - raise RuntimeError("unscale_() is being called after step().") - - if getattr(optimizer, '_param_groups', None) and isinstance( - optimizer._param_groups[0], dict - ): - param_grads = [] - param_grads_fp16 = [] - param_grads_bf16 = [] - param_grads_fp32 = [] - for group in optimizer._param_groups: - for param in group['params']: - if param._grad_ivar() is not None: - param_grads.append(param._grad_ivar()) - if ( - param._grad_ivar().dtype - == core.VarDesc.VarType.FP16 - ): - param_grads_fp16.append(param._grad_ivar()) - elif ( - param._grad_ivar().dtype - == core.VarDesc.VarType.BF16 - ): - param_grads_bf16.append(param._grad_ivar()) - else: - param_grads_fp32.append(param._grad_ivar()) - else: - if in_dygraph_mode(): - # It is very time-consuming to call c++ functions in a loop on the python side. - # We put this part of the code on the c++ side to improve the speed in eager mode. - ( - param_grads_fp16, - param_grads_bf16, - param_grads_fp32, - ) = core.eager.get_grads_lists(optimizer._parameter_list) - else: - # Keep the original code to support legacy mode. - # Delete the else branch when the legacy mode exits. - param_grads = [ - param._grad_ivar() - for param in optimizer._parameter_list - if param._grad_ivar() is not None - ] - param_grads_fp16 = [ - param - for param in param_grads - if param.dtype == core.VarDesc.VarType.FP16 - ] - param_grads_bf16 = [ - param - for param in param_grads - if param.dtype == core.VarDesc.VarType.BF16 - ] - param_grads_fp32 = [ - param - for param in param_grads - if param.dtype == core.VarDesc.VarType.FP32 - ] - if core.is_compiled_with_npu(): - float_status = _legacy_C_ops.alloc_float_status() - _legacy_C_ops.clear_float_status(float_status, float_status) - - if len(param_grads_fp16): - _legacy_C_ops.check_finite_and_unscale( - param_grads_fp16, - self._scale, - float_status, - param_grads_fp16, - self._temp_found_inf_fp16, - ) - if len(param_grads_bf16): - _legacy_C_ops.check_finite_and_unscale( - param_grads_bf16, - self._scale, - float_status, - param_grads_bf16, - self._temp_found_inf_bf16, - ) - if len(param_grads_fp32): - _legacy_C_ops.check_finite_and_unscale( - param_grads_fp32, - self._scale, - float_status, - param_grads_fp32, - self._temp_found_inf_fp32, - ) - else: - if len(param_grads_fp16): - _legacy_C_ops.check_finite_and_unscale( - param_grads_fp16, - self._scale, - param_grads_fp16, - self._temp_found_inf_fp16, - ) - if len(param_grads_bf16): - _legacy_C_ops.check_finite_and_unscale( - param_grads_bf16, - self._scale, - param_grads_bf16, - self._temp_found_inf_bf16, - ) - if len(param_grads_fp32): - _legacy_C_ops.check_finite_and_unscale( - param_grads_fp32, - self._scale, - param_grads_fp32, - self._temp_found_inf_fp32, - ) - - self._found_inf = ( - self._temp_found_inf_fp16 - or self._temp_found_inf_bf16 - or self._temp_found_inf_fp32 - ) - - optimizer_state["state"] = OptimizerState.UNSCALED - - def _update(self): - """ - Updates the loss_scaling. - """ - if not self._enable: - return - - if self._cache_founf_inf: - self._incr_count = 0 - self._decr_count = self._decr_count + 1 - if self._decr_count == self._decr_every_n_nan_or_inf: - print( - 'Found inf or nan, current scale is: {}, decrease to: {}*{}'.format( - float(self._scale), - float(self._scale), - float(self._decr_ratio), - ) - ) - self._scale = self._scale * self._decr_ratio - self._decr_count = 0 - else: - self._decr_count = 0 - self._incr_count = self._incr_count + 1 - if self._incr_count == self._incr_every_n_steps: - self._scale = self._scale * self._incr_ratio - self._incr_count = 0 - - return - - def is_enable(self): - """ - Enable loss scaling or not. - - Returns: - bool: enable loss scaling return True else return False. - """ - return self._enable - - def is_use_dynamic_loss_scaling(self): - """ - Whether to use dynamic loss scaling. - - Returns: - bool: if fixed loss_scaling is used return False, if the loss scaling is updated dynamicly return true. - """ - return self._use_dynamic_loss_scaling - - def get_init_loss_scaling(self): - """ - Return the initial loss scaling factor. - - Reurns: - float: the initial loss scaling factor. - """ - return self._init_loss_scaling - - def set_init_loss_scaling(self, new_init_loss_scaling): - """ - Set the initial loss scaling factor by `new_init_loss_scaling`. - - Args: - new_init_loss_scaling(int): The new_init_loss_scaling used to update initial loss scaling factor.s - """ - self._init_loss_scaling = new_init_loss_scaling - self._scale = to_variable( - np.array([self._init_loss_scaling]).astype(np.float32) - ) - - def get_incr_ratio(self): - """ - Return the multiplier to use when increasing the loss scaling. - - Reurns: - float: the multiplier to use when increasing the loss scaling. - """ - return self._incr_ratio - - def set_incr_ratio(self, new_incr_ratio): - """ - Set the multiplier to use when increasing the loss scaling by `new_incr_ratio`, `new_incr_ratio` should > 1.0. - - Args: - new_incr_ratio(float): The new_incr_ratio used to update the multiplier to use when increasing the loss scaling. - """ - assert new_incr_ratio > 1.0, "The new_incr_ratio must be > 1.0." - self._incr_ratio = new_incr_ratio - - def get_decr_ratio(self): - """ - Get the less-than-one-multiplier to use when decreasing the loss scaling. - - Reurns: - float: the less-than-one-multiplier to use when decreasing the loss scaling. - """ - return self._decr_ratio - - def set_decr_ratio(self, new_decr_ratio): - """ - Set the less-than-one-multiplier to use when decreasing the loss scaling by `new_incr_ratio`, `new_decr_ratio` should < 1.0. - - Args: - new_decr_ratio(float): The new_decr_ratio used to update the less-than-one-multiplier to use when decreasing the loss scaling. - """ - assert new_decr_ratio < 1.0, "The new_decr_ratio must be < 1.0." - self._decr_ratio = new_decr_ratio - - def get_incr_every_n_steps(self): - """ - Return the num `n`, `n` represent increases loss scaling every `n` consecutive steps with finite gradients. - - Reurns: - int: the num `n`, `n` represent increases loss scaling every `n` consecutive steps with finite gradients. - """ - return self._incr_every_n_steps - - def set_incr_every_n_steps(self, new_incr_every_n_steps): - """ - Set the num `n` by `new_incr_every_n_steps`, `n` represent increases loss scaling every `n` consecutive steps with finite gradients. - - Args: - new_incr_every_n_steps(int): The new_incr_every_n_steps used to update the num `n`, `n` represent increases loss scaling every `n` consecutive steps with finite gradients. - """ - self._incr_every_n_steps = new_incr_every_n_steps - - def get_decr_every_n_nan_or_inf(self): - """ - Return the num `n`, `n` represent decreases loss scaling every `n` accumulated steps with nan or inf gradients. - - Reurns: - int: the num `n`, `n` represent decreases loss scaling every `n` accumulated steps with nan or inf gradients. - """ - return self._decr_every_n_nan_or_inf - - def set_decr_every_n_nan_or_inf(self, new_decr_every_n_nan_or_inf): - """ - Set the num `n` by `new_decr_every_n_nan_or_inf`, `n` represent decreases loss scaling every `n` accumulated steps with nan or inf gradients. - - Args: - new_decr_every_n_nan_or_inf(int): The new_decr_every_n_nan_or_inf used to update the num `n`, `n` represent decreases loss scaling every `n` accumulated steps with nan or inf gradients. - """ - self._decr_every_n_nan_or_inf = new_decr_every_n_nan_or_inf - - def state_dict(self): - """ - Returns the state of the scaler as a `dict`, If this instance is not enabled, returns an empty dict. - - Reurns: - A dict of scaler includes: - scale (tensor): The loss scaling factor. - incr_ratio(float): The multiplier to use when increasing the loss scaling. - decr_ratio(float): The less-than-one-multiplier to use when decreasing the loss scaling. - incr_every_n_steps(int): Increases loss scaling every n consecutive steps with finite gradients. - decr_every_n_nan_or_inf(int): Decreases loss scaling every n accumulated steps with nan or inf gradients. - incr_count(int): The number of recent consecutive unskipped steps. - decr_count(int): The number of recent consecutive skipped steps. - use_dynamic_loss_scaling(bool): Whether to use dynamic loss scaling. If False, fixed loss_scaling is used. If True, the loss scaling is updated dynamicly. Default is True. - """ - return ( - { - "scale": self._scale.numpy(), - "incr_ratio": self._incr_ratio, - "decr_ratio": self._decr_ratio, - "incr_every_n_steps": self._incr_every_n_steps, - "decr_every_n_nan_or_inf": self._decr_every_n_nan_or_inf, - "incr_count": self._incr_count, - "decr_count": self._decr_count, - "use_dynamic_loss_scaling": self._use_dynamic_loss_scaling, - } - if self._enable - else {} - ) - - def load_state_dict(self, state_dict): - """ - Loads the scaler state. - - Args: - state_dict(dict): scaler state. Should be an object returned from a call to `AmpScaler.state_dict()`. - """ - if not self._enable: - return - - if len(state_dict) == 0: - raise RuntimeError( - "The input state dict is empty, possibly because it was saved " - "from a disabled instance of GradScaler." - ) - - self._init_loss_scaling = state_dict["scale"][0] - self._scale = to_variable( - np.array([self._init_loss_scaling]).astype(np.float32) - ) - self._incr_ratio = state_dict["incr_ratio"] - self._decr_ratio = state_dict["decr_ratio"] - self._incr_every_n_steps = state_dict["incr_every_n_steps"] - self._decr_every_n_nan_or_inf = state_dict["decr_every_n_nan_or_inf"] - self._incr_count = state_dict["incr_count"] - self._decr_count = state_dict["decr_count"] - self._use_dynamic_loss_scaling = state_dict["use_dynamic_loss_scaling"] diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision_for_eager.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision_for_eager.py index 9be07f9279d..3c97b3ca152 100644 --- a/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision_for_eager.py +++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision_for_eager.py @@ -60,10 +60,10 @@ class TestAutoCast(unittest.TestCase): with fluid.dygraph.guard(): conv2d = paddle.nn.Conv2D(3, 2, 3, bias_attr=False) data = fluid.dygraph.to_variable(data) - with fluid.dygraph.amp_guard(True): + with paddle.amp.amp_guard(True): out_fp16 = conv2d(data) - with fluid.dygraph.amp_guard(False): + with paddle.amp.amp_guard(False): out_fp32 = conv2d(data) self.assertTrue(data.dtype == fluid.core.VarDesc.VarType.FP32) @@ -77,7 +77,7 @@ class TestAutoCast(unittest.TestCase): data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32') with fluid.dygraph.guard(): data = fluid.dygraph.to_variable(data) - with fluid.dygraph.amp_guard(True): + with paddle.amp.amp_guard(True): out_fp32 = paddle.mean(data) self.assertTrue(data.dtype == fluid.core.VarDesc.VarType.FP32) @@ -89,9 +89,9 @@ class TestAutoCast(unittest.TestCase): def custom_op_list(self): with fluid.dygraph.guard(): tracer = fluid.framework._dygraph_tracer() - base_white_list = fluid.dygraph.amp.auto_cast.WHITE_LIST - base_black_list = fluid.dygraph.amp.auto_cast.BLACK_LIST - with fluid.dygraph.amp_guard( + base_white_list = paddle.amp.WHITE_LIST + base_black_list = paddle.amp.BLACK_LIST + with paddle.amp.amp_guard( custom_white_list=["log"], custom_black_list=["conv2d"] ): white_list, black_list = tracer._get_amp_op_list() @@ -105,9 +105,9 @@ class TestAutoCast(unittest.TestCase): == (set(base_black_list) - {"log"}) | {"conv2d"} ) - base_white_list = fluid.dygraph.amp.auto_cast.PURE_FP16_WHITE_LIST - base_black_list = fluid.dygraph.amp.auto_cast.PURE_FP16_BLACK_LIST - with fluid.dygraph.amp_guard( + base_white_list = paddle.amp.PURE_FP16_WHITE_LIST + base_black_list = paddle.amp.PURE_FP16_BLACK_LIST + with paddle.amp.amp_guard( custom_white_list=["log"], custom_black_list=["conv2d"], level='O2', @@ -138,7 +138,7 @@ class TestAutoCast(unittest.TestCase): stride=2, act='relu', ) - with fluid.dygraph.amp_guard( + with paddle.amp.amp_guard( custom_white_list=["conv2d"], custom_black_list=["conv2d"] ): inp = fluid.dygraph.to_variable(inp_np) @@ -154,13 +154,13 @@ class TestAutoCast(unittest.TestCase): with fluid.dygraph.guard(): conv2d = paddle.nn.Conv2D(3, 2, 3, bias_attr=False) data = fluid.dygraph.to_variable(data) - with fluid.dygraph.amp_guard(True): + with paddle.amp.amp_guard(True): out_amp_fp16 = conv2d(data) out_amp_fp32 = paddle.expand_as( out_amp_fp16, out_amp_fp16 ) # expand_as_v2 has no fp16 kernel - with fluid.dygraph.amp_guard(True, level='O2'): + with paddle.amp.amp_guard(True, level='O2'): out_purefp16_fp16 = conv2d(data) out_purefp16_fp32 = paddle.expand_as( out_purefp16_fp16, out_purefp16_fp16 @@ -184,7 +184,7 @@ class TestAutoCast(unittest.TestCase): with fluid.dygraph.guard(): conv2d = paddle.nn.Conv2D(3, 2, 3, bias_attr=False) data = fluid.dygraph.to_variable(data) - with fluid.dygraph.amp_guard(level='O'): + with paddle.amp.amp_guard(level='O'): out = conv2d(data) self.assertRaises(ValueError, func) @@ -197,7 +197,7 @@ class TestAmpScaler(unittest.TestCase): def scale(self): with fluid.dygraph.guard(): data = paddle.rand([10, 1024]) - scaler = paddle.fluid.dygraph.AmpScaler(init_loss_scaling=1024) + scaler = paddle.amp.AmpScaler(init_loss_scaling=1024) scaled_data = scaler.scale(data) self.assertEqual( np.array_equal(scaled_data.numpy(), data.numpy() * 1024), True @@ -223,7 +223,7 @@ class TestAmpScaler(unittest.TestCase): optimizer = fluid.optimizer.SGDOptimizer( learning_rate=0.01, parameter_list=model.parameters() ) - scaler = fluid.dygraph.AmpScaler(init_loss_scaling=1024) + scaler = paddle.amp.AmpScaler(init_loss_scaling=1024) data = fluid.dygraph.to_variable(inp_np) out = model(data) @@ -332,7 +332,7 @@ class TestAmpScaler(unittest.TestCase): optimizer = fluid.optimizer.SGDOptimizer( learning_rate=0.01, parameter_list=model.parameters() ) - scaler = fluid.dygraph.AmpScaler(init_loss_scaling=1024) + scaler = paddle.amp.AmpScaler(init_loss_scaling=1024) data = fluid.dygraph.to_variable(inp_np) out = model(data) @@ -1262,12 +1262,12 @@ class TestResnet(unittest.TestCase): dy_param_init_value[param.name] = param.numpy() program = None - scaler = paddle.fluid.dygraph.AmpScaler( + scaler = paddle.amp.AmpScaler( enable=enable_amp, init_loss_scaling=2.0**10 ) if enable_amp and (level == 'O2'): - resnet, optimizer = paddle.fluid.dygraph.amp_decorate( + resnet, optimizer = paddle.amp.amp_decorate( models=resnet, optimizers=optimizer, level='O2' ) @@ -1290,9 +1290,7 @@ class TestResnet(unittest.TestCase): img = fluid.dygraph.to_variable(dy_x_data) label = fluid.dygraph.to_variable(y_data) label.stop_gradient = True - with paddle.fluid.dygraph.amp_guard( - enable=enable_amp, level=level - ): + with paddle.amp.amp_guard(enable=enable_amp, level=level): out = resnet(img) loss = paddle.nn.functional.cross_entropy( diff --git a/python/paddle/fluid/tests/unittests/test_low_precision_list.py b/python/paddle/fluid/tests/unittests/test_low_precision_list.py index 7099fbe168c..afa737bfed4 100644 --- a/python/paddle/fluid/tests/unittests/test_low_precision_list.py +++ b/python/paddle/fluid/tests/unittests/test_low_precision_list.py @@ -28,7 +28,7 @@ class TestAMPList(unittest.TestCase): with paddle.amp.auto_cast(): conv = conv2d(data) c = a + b - paddle.fluid.dygraph.amp.auto_cast.low_precision_op_list() + paddle.amp.low_precision_op_list() op_list = paddle.fluid.core.get_low_precision_op_list() print(conv.dtype) if conv.dtype == paddle.float16: diff --git a/python/paddle/jit/dy2static/partial_program.py b/python/paddle/jit/dy2static/partial_program.py index 805c7b743b0..d9cc9c390dd 100644 --- a/python/paddle/jit/dy2static/partial_program.py +++ b/python/paddle/jit/dy2static/partial_program.py @@ -18,6 +18,7 @@ import numpy as np import paddle from paddle import _legacy_C_ops +from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard from paddle.fluid import backward, core, framework, program_guard from paddle.fluid.compiler import BuildStrategy from paddle.fluid.contrib.mixed_precision.decorator import ( @@ -28,10 +29,6 @@ from paddle.fluid.contrib.mixed_precision.fp16_utils import ( rewrite_program, ) from paddle.fluid.dygraph import layers -from paddle.fluid.dygraph.amp.auto_cast import ( - _in_amp_guard, - _in_pure_fp16_guard, -) from paddle.fluid.dygraph.base import switch_to_static_graph from paddle.fluid.executor import ( _is_dy2st_enable_standalone_executor, diff --git a/python/setup.py.in b/python/setup.py.in index f3e7b9d8c6e..cb906b54458 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -331,7 +331,6 @@ packages=['paddle', 'paddle.inference.contrib.utils', 'paddle.fluid', 'paddle.fluid.dygraph', - 'paddle.fluid.dygraph.amp', 'paddle.fluid.proto', 'paddle.fluid.proto.profiler', 'paddle.fluid.distributed', diff --git a/setup.py b/setup.py index 5dd71fa7eb9..879adb4935d 100644 --- a/setup.py +++ b/setup.py @@ -1202,7 +1202,6 @@ def get_setup_parameters(): 'paddle.inference.contrib.utils', 'paddle.fluid', 'paddle.fluid.dygraph', - 'paddle.fluid.dygraph.amp', 'paddle.fluid.proto', 'paddle.fluid.proto.profiler', 'paddle.fluid.distributed', -- GitLab