未验证 提交 da3e9d66 编写于 作者: Z zhangkaihuo 提交者: GitHub

move fuild.dygraph.amp to paddle.amp (#49193)

上级 343bff7b
......@@ -13,7 +13,18 @@
# limitations under the License.
from .auto_cast import auto_cast # noqa: F401
from .grad_scaler import GradScaler # noqa: F401
from .auto_cast import decorate # noqa: F401
from .auto_cast import amp_guard # noqa: F401
from .auto_cast import amp_decorate # noqa: F401
from .auto_cast import low_precision_op_list # noqa: F401
from .auto_cast import WHITE_LIST # noqa: F401
from .auto_cast import BLACK_LIST # noqa: F401
from .auto_cast import PURE_FP16_WHITE_LIST # noqa: F401
from .auto_cast import PURE_FP16_BLACK_LIST # noqa: F401
from . import grad_scaler # noqa: F401
from .grad_scaler import GradScaler # noqa: F401
from .grad_scaler import AmpScaler # noqa: F401
from .grad_scaler import OptimizerState # noqa: F401
__all__ = ['auto_cast', 'GradScaler', 'decorate']
......@@ -12,9 +12,638 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid.dygraph.amp import amp_decorate, amp_guard
import copy
import warnings
import paddle
from paddle.fluid import core
from paddle.fluid.framework import _dygraph_tracer, dygraph_only
from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
AMP_LEVEL = core.AmpLevel
# The set of ops that support fp16 calculation and are considered numerically-
# safe and performance-critical. These ops are always converted to fp16.
WHITE_LIST = {
'conv2d',
'matmul',
'matmul_v2',
'mul',
'fake_quantize_dequantize_abs_max',
'fake_quantize_dequantize_moving_average_abs_max',
}
# The set of ops that support fp16 calculation and are considered numerically-
# dangerous and whose effects may also be observed in downstream ops.
BLACK_LIST = {
'exp',
'square',
'log',
'mean',
'sum',
'cos_sim',
'softmax',
'softmax_with_cross_entropy',
'sigmoid_cross_entropy_with_logits',
'c_softmax_with_cross_entropy',
'cross_entropy',
'cross_entropy2',
# default fp32 can avoid return inf when the sum value large than 65504
'reduce_sum',
# FP16 performance of grad op is worse than that of FP32. Use FP32 by default.
'linear_interp_v2',
'nearest_interp_v2',
'bilinear_interp_v2',
'bicubic_interp_v2',
'trilinear_interp_v2',
}
AMP_RELATED_FLAGS = [
'FLAGS_cudnn_exhaustive_search',
'FLAGS_conv_workspace_size_limit',
'FLAGS_cudnn_batchnorm_spatial_persistent',
]
AMP_RELATED_FLAGS_SETTING = {
'FLAGS_cudnn_exhaustive_search': 1,
'FLAGS_conv_workspace_size_limit': 1000,
'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
}
PURE_FP16_WHITE_LIST = set()
PURE_FP16_BLACK_LIST = {
'lookup_table',
'lookup_table_v2',
'scatter',
'scatter_grad',
# FP16 performance of grad op is worse than that of FP32. Use FP32 by default.
'linear_interp_v2',
'nearest_interp_v2',
'bilinear_interp_v2',
'bicubic_interp_v2',
'trilinear_interp_v2',
}
BF16_WHITE_LIST = {'conv2d', 'matmul_v2'}
BF16_BLACK_LIST = set()
PURE_BF16_WHITE_LIST = set()
PURE_BF16_BLACK_LIST = set()
_g_amp_state_ = None
def low_precision_op_list():
op_list = paddle.fluid.core.get_low_precision_op_list()
op_count = 0
print('<---------------- low precision op list ------------------->')
print('<---- op name ------|------- op count---------------------->')
for x in op_list:
print(' %-18s| %4d' % (x, op_list[x]))
op_count += 1
print(
'<------------- low precision op num:{:5d} ----------------->'.format(
op_count
)
)
def amp_state():
global _g_amp_state_
return _g_amp_state_
# NOTE(zhiqiu): similar as paddle.fluid.contrib.mixed_precision.fp16_lists.AutoMixedPrecisionLists._update_list
# The reason why not use AutoMixedPrecisionLists is that custom_black_varnames is not suitable for imperative mode.
def _update_list(
custom_white_list, custom_black_list, level='O1', dtype='float16'
):
"""
Update black and white list according to users' custom list.
"""
if dtype == 'float16':
if level == 'O1':
_white_list = copy.copy(WHITE_LIST)
_black_list = copy.copy(BLACK_LIST)
else:
_white_list = copy.copy(PURE_FP16_WHITE_LIST)
_black_list = copy.copy(PURE_FP16_BLACK_LIST)
else:
if level == 'O1':
_white_list = copy.copy(BF16_WHITE_LIST)
_black_list = copy.copy(BF16_BLACK_LIST)
else:
_white_list = copy.copy(PURE_BF16_WHITE_LIST)
_black_list = copy.copy(PURE_BF16_BLACK_LIST)
if custom_white_list and custom_black_list:
for op_name in custom_white_list:
if op_name in custom_black_list:
raise ValueError(
"Custom white list overlap " "custom black list"
)
if custom_white_list:
for op_name in custom_white_list:
if op_name in _black_list:
_black_list.remove(op_name)
_white_list.add(op_name)
if custom_black_list:
for op_name in custom_black_list:
if op_name in _white_list:
_white_list.remove(op_name)
_black_list.add(op_name)
return _white_list, _black_list
def _in_amp_guard():
"""
Judge whether current code block is in `amp_guard` context.
"""
tracer = _dygraph_tracer()
if tracer:
if tracer._amp_level == core.AmpLevel.O1:
return True
else:
return False
else:
return False
def _in_pure_fp16_guard():
tracer = _dygraph_tracer()
return tracer and tracer._amp_level == core.AmpLevel.O2
def _is_gpu_float16_supported():
"""
Judge whether current gpu support float16 amp.
"""
prop = paddle.device.cuda.get_device_capability()
return prop[0] >= 7
def _is_gpu_bfloat16_supported():
"""
Judge whether current gpu support bfloat16 amp.
"""
prop = paddle.device.cuda.get_device_capability()
cuda_version = paddle.version.cuda()
if cuda_version is not None and cuda_version != 'False':
cuda_version_check = int(cuda_version.split('.')[0]) >= 11
else:
cuda_version_check = False
return prop[0] >= 8 and cuda_version_check
@dygraph_only
def pure_fp16_initialize(models):
for idx in range(len(models)):
for layer in models[idx].sublayers(include_self=True):
layer._casted_by_pure_fp16 = True
if (layer._dtype == 'float16') or isinstance(
layer,
(
paddle.nn.BatchNorm,
paddle.nn.BatchNorm1D,
paddle.nn.BatchNorm2D,
paddle.nn.BatchNorm3D,
paddle.nn.LayerNorm,
paddle.nn.SyncBatchNorm,
),
):
continue
if isinstance(
layer,
(
paddle.incubate.nn.FusedFeedForward,
paddle.incubate.nn.FusedMultiHeadAttention,
),
):
layer._amp_decorate(dtype='float16')
continue
layer._to_impl(
dtype='float16', include_sublayers=False, floating_only=True
)
return models
@dygraph_only
def pure_bf16_initialize(models):
for idx in range(len(models)):
for layer in models[idx].sublayers(include_self=True):
layer._to_impl(
dtype='bfloat16', include_sublayers=False, floating_only=True
)
return models
def check_models(models):
for model in models:
if not isinstance(model, paddle.nn.Layer):
raise RuntimeError(
"Current train mode is pure fp16, models should be paddle.nn.Layer, but receive {}.".format(
type(model)
)
)
if isinstance(model, paddle.DataParallel):
raise RuntimeError(
"For distributed AMP training, you should first use paddle.amp.decorate() to decotate origin model, and then call paddle.DataParallel get distributed model."
)
def _is_valid_optimizer(optimizer):
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
DygraphShardingOptimizer,
)
return isinstance(
optimizer,
(
paddle.optimizer.Optimizer,
paddle.fluid.optimizer.Optimizer,
DygraphShardingOptimizer,
),
)
def check_optimizers(optimizers):
for optimizer in optimizers:
if not _is_valid_optimizer(optimizer):
raise RuntimeError(
"Current train mode is pure fp16, optimizers should be paddle.optimizer.Optimizer or paddle.fluid.optimizer.Optimizer or DygraphShardingOptimizer, but receive {}.".format(
type(optimizer)
)
)
@signature_safe_contextmanager
@dygraph_only
def amp_guard(
enable=True,
custom_white_list=None,
custom_black_list=None,
level='O1',
dtype='float16',
):
"""
Create a context which enables auto-mixed-precision(AMP) of operators executed in dynamic graph mode.
If enabled, the input data type (float32 or float16) of each operator is decided
by autocast algorithm for better performance.
Commonly, it is used together with `GradScaler` to achieve Auto-Mixed-Precision in
imperative mode. It is used together with `decorator` to achieve Pure fp16 in imperative mode.
Args:
enable(bool, optional): Enable auto-mixed-precision or not. Default is True.
custom_white_list(set|list|tuple, optional): The custom white_list. It's the set of ops that support
fp16 calculation and are considered numerically-safe and performance-critical. These ops
will be converted to fp16.
custom_black_list(set|list|tuple, optional): The custom black_list. The set of ops that support fp16
calculation and are considered numerically-dangerous and whose effects may also be
observed in downstream ops. These ops will not be converted to fp16.
level(str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represent mixed precision, the input data type of each operator will be casted by white_list and black_list;
O2 represent Pure fp16, all operators parameters and input data will be casted to fp16, except operators in black_list, don't support fp16 kernel and batchnorm. Default is O1(amp)
dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'.
Examples:
.. code-block:: python
import numpy as np
import paddle
data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
conv2d = paddle.nn.Conv2D(3, 2, 3)
data = paddle.to_tensor(data)
with paddle.amp.amp_guard():
conv = conv2d(data)
print(conv.dtype) # FP16
with paddle.amp.amp_guard(enable=False):
conv = conv2d(data)
print(conv.dtype) # FP32
"""
amp_state = locals()
global _g_amp_state_
original_state = _g_amp_state_
_g_amp_state_ = amp_state
# check amp_level: O0-O2
level = level.upper()
if not (level in ['O0', 'O1', 'O2']):
raise ValueError(
"level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16/bf16 train mode."
)
# check amp_dtype: float16 or bfloat16
dtype = dtype.lower()
if not (dtype in ['float16', 'bfloat16']):
raise ValueError("dtype should be 'float16' or 'bfloat16'.")
# check tracer
tracer = _dygraph_tracer()
if not tracer:
raise ValueError(
"current_tracer is None, maybe it is not in imperative mode."
)
# check device_type:
# NOTE: Now, amp only support gpu for float16 and bfloat16, xpu for float16, mlu for float16, npu for float16.
# Maybe we will support cpu for bfloat16.
if enable and not (
tracer._expected_place.is_gpu_place()
or tracer._expected_place.is_xpu_place()
or tracer._expected_place.is_mlu_place()
or tracer._expected_place.is_npu_place()
or tracer._expected_place.is_custom_place()
):
warnings.warn(
'amp_guard can only be enabled on CUDAPlace, XPUPlace, MLUPlace, NPUPlace, and CustomPlace, current place is %s, so it makes no effect.'
% tracer._expected_place
)
enable = False
# For npu:
if tracer._expected_place.is_npu_place() and (dtype == 'bfloat16'):
warnings.warn('NPUPlace only support float16 amp.')
enable = False
# For xpu:
if tracer._expected_place.is_xpu_place() and (dtype == 'bfloat16'):
warnings.warn('XPUPlace only support float16 amp.')
enable = False
# For mlu:
if tracer._expected_place.is_mlu_place() and (dtype == 'bfloat16'):
warnings.warn('MLUPlace only support float16 amp.')
enable = False
# For custom device:
if tracer._expected_place.is_custom_place() and (dtype == 'bfloat16'):
warnings.warn('CustomPlace only support float16 amp.')
enable = False
# For gpu float16: Compute Capability should >= 7.
# For gpu bfloat16: Compute Capability should >= 8 & CUDA Version should >= 11.
if tracer._expected_place.is_gpu_place():
if (dtype == 'float16') and not _is_gpu_float16_supported():
prop = paddle.device.cuda.get_device_capability()
warnings.warn(
"For float16, amp only support NVIDIA GPU with Compute Capability 7.0 or higher, current GPU is: %s, with Compute Capability: %d.%d."
% (paddle.device.cuda.get_device_name(), prop[0], prop[1])
)
elif (dtype == 'bfloat16') and not _is_gpu_bfloat16_supported():
prop = paddle.device.cuda.get_device_capability()
cuda_version = paddle.version.cuda()
warnings.warn(
"For bfloat16, amp only support NVIDIA GPU with Compute Capability 8.0 or higher and CUDA Version 11.0 or higher, current GPU is: %s, with Compute Capability: %d.%d, current CUDA Version is: %s."
% (
paddle.device.cuda.get_device_name(),
prop[0],
prop[1],
cuda_version,
)
)
amp_dtype = dtype
if level == 'O1':
amp_level = AMP_LEVEL.O1
if dtype == 'float16':
_white_list = WHITE_LIST
_black_list = BLACK_LIST
elif dtype == 'bfloat16':
_white_list = BF16_WHITE_LIST
_black_list = BF16_BLACK_LIST
elif level == 'O2':
amp_level = AMP_LEVEL.O2
if dtype == 'float16':
_white_list = PURE_FP16_WHITE_LIST
_black_list = PURE_FP16_BLACK_LIST
elif dtype == 'bfloat16':
_white_list = BF16_WHITE_LIST
_black_list = BF16_BLACK_LIST
elif level == 'O0':
amp_level = AMP_LEVEL.O0
if dtype == 'float16':
_white_list = WHITE_LIST
_black_list = BLACK_LIST
elif dtype == 'bfloat16':
_white_list = BF16_WHITE_LIST
_black_list = BF16_BLACK_LIST
if custom_white_list or custom_black_list:
_white_list, _black_list = _update_list(
custom_white_list, custom_black_list, level, dtype
)
if not enable:
amp_level = AMP_LEVEL.O0
amp_dtype = "float32"
if tracer:
# enable auto_cast
original_amp_level = tracer._amp_level
tracer._amp_level = amp_level
# set amp op list
original_white_list, original_black_list = tracer._get_amp_op_list()
tracer._set_amp_op_list(_white_list, _black_list)
# TODO(zhiqiu) set amp related flags automatically in this guard
# Currently, if FLAGS_cudnn_batchnorm_spatial_persistent is set True in amp_guard,
# batch_norm can run in fast mode, but batch_norm_grad can not if backward if not executed insise amp_guard.
# So, users need to set related flags manually.
# original_flags = get_flags(AMP_RELATED_FLAGS)
# set_flags(AMP_RELATED_FLAGS_SETTING)
# set amp dtype
original_amp_dtype = tracer._amp_dtype
tracer._amp_dtype = amp_dtype
# restore status
try:
yield
finally:
if tracer:
_g_amp_state_ = original_state
tracer._amp_level = original_amp_level
tracer._set_amp_op_list(original_white_list, original_black_list)
# set_flags(original_flags)
tracer._amp_dtype = original_amp_dtype
class StateDictHook:
def __init__(self, save_dtype):
self._save_dtype = save_dtype
def __call__(self, state_dict):
for key in state_dict:
param = state_dict[key]
if paddle.is_floating_point(param):
param_applied = paddle.cast(param, self._save_dtype)
param_applied.name = param.name
state_dict[key] = param_applied
def _set_multi_precision(optimizer, multi_precision):
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
DygraphShardingOptimizer,
)
optimizer = (
optimizer._inner_optimizer
if isinstance(optimizer, DygraphShardingOptimizer)
else optimizer
)
if hasattr(optimizer, "_multi_precision"):
optimizer._multi_precision = multi_precision
@dygraph_only
def amp_decorate(
models,
optimizers=None,
level='O1',
dtype='float16',
master_weight=None,
save_dtype=None,
):
"""
Decorate models and optimizers for auto-mixed-precision. When level is O1(amp), the decorate will do nothing.
When level is O2(pure fp16), the decorate will cast all parameters of models to FP16, except BatchNorm and LayerNorm.
Commonly, it is used together with `amp_guard` to achieve Pure fp16 in imperative mode.
__all__ = []
Args:
models(Layer|list of Layer, optional): The defined models by user, models must be either a single model or a list of models. Default is None.
optimizers(Optimizer|list of Optimizer, optional): The defined optimizers by user, optimizers must be either a single optimizer or a list of optimizers. Default is None.
level(str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represent mixed precision, the decorator will do nothing;
O2 represent Pure fp16/bf16, the decorator will cast all parameters of models to FP16/BF16, except BatchNorm and LayerNorm. Default is O1(amp)
dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'.
master_weight(bool, optinal): For level='O2', whether to use multi-precision during weight updating. If master_weight is None, in O2 level optimizer will use multi-precision. Default is None.
save_dtype(float, optional): The save model parameter dtype when use `paddle.save` or `paddle.jit.save`,it should be float16, bfloat16, float32, float64 or None.
The save_dtype will not change model parameters dtype, it just change the state_dict dtype. When save_dtype is None, the save dtype is same as model dtype. Default is None.
Examples:
.. code-block:: python
# required: gpu
# Demo1: single model and optimizer:
import paddle
model = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
optimizer = paddle.optimizer.SGD(parameters=model.parameters())
model, optimizer = paddle.amp.amp_decorate(models=model, optimizers=optimizer, level='O2')
data = paddle.rand([10, 3, 32, 32])
with paddle.amp.amp_guard(enable=True, custom_white_list=None, custom_black_list=None, level='O2'):
output = model(data)
print(output.dtype) # FP16
# required: gpu
# Demo2: multi models and optimizers:
model2 = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
optimizer2 = paddle.optimizer.Adam(parameters=model2.parameters())
models, optimizers = paddle.amp.amp_decorate(models=[model, model2], optimizers=[optimizer, optimizer2], level='O2')
data = paddle.rand([10, 3, 32, 32])
with paddle.amp.amp_guard(enable=True, custom_white_list=None, custom_black_list=None, level='O2'):
output = models[0](data)
output2 = models[1](data)
print(output.dtype) # FP16
print(output2.dtype) # FP16
# required: gpu
# Demo3: optimizers is None:
model3 = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
optimizer3 = paddle.optimizer.Adam(parameters=model2.parameters())
model = paddle.amp.amp_decorate(models=model3, level='O2')
data = paddle.rand([10, 3, 32, 32])
with paddle.amp.amp_guard(enable=True, custom_white_list=None, custom_black_list=None, level='O2'):
output = model(data)
print(output.dtype) # FP16
"""
if not (level in ['O1', 'O2']):
raise ValueError(
"level should be O1 or O2, O1 represent AMP train mode, O2 represent Pure fp16 train mode."
)
if level == 'O1':
if optimizers is None:
return models
else:
return models, optimizers
models_is_list = False
if isinstance(models, paddle.nn.Layer):
models_is_list = False
models = [models]
check_models(models)
elif isinstance(models, list):
check_models(models)
models_is_list = True
else:
raise TypeError(
"models must be either a single model or a list of models."
)
if dtype == 'float16':
models = pure_fp16_initialize(models=models)
elif dtype == 'bfloat16':
models = pure_bf16_initialize(models=models)
else:
raise TypeError("dtype only support float16 or bfloat16.")
if optimizers is not None:
# check optimizers
optimizers_is_list = False
if _is_valid_optimizer(optimizers):
optimizers_is_list = False
optimizers = [optimizers]
check_optimizers(optimizers)
elif isinstance(optimizers, list):
check_optimizers(optimizers)
optimizers_is_list = True
else:
raise TypeError(
"optimizers must be either a single optimizer or a list of optimizers."
)
# support master_weight
use_multi_precision = not (master_weight is False)
for opt in optimizers:
_set_multi_precision(opt, use_multi_precision)
if save_dtype is not None:
if not (save_dtype in ['float16', 'bfloat16', 'float32', 'float64']):
raise ValueError(
"save_dtype can only be float16 float32 or float64, but your input save_dtype is %s."
% save_dtype
)
for idx in range(len(models)):
for layer in models[idx].sublayers(include_self=True):
layer.register_state_dict_hook(StateDictHook(save_dtype))
if models_is_list:
if optimizers is not None:
if optimizers_is_list:
return models, optimizers
else:
return models, optimizers[0]
else:
return models
else:
if optimizers is not None:
if optimizers_is_list:
return models[0], optimizers
else:
return models[0], optimizers[0]
else:
return models[0]
def auto_cast(
......
......@@ -12,17 +12,572 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from collections import defaultdict
from enum import Enum
from paddle.fluid.dygraph.amp import AmpScaler, OptimizerState
import numpy as np
__all__ = []
from paddle import _legacy_C_ops
from paddle.fluid import core, in_dygraph_mode
from paddle.fluid.data_feeder import check_type
from paddle.fluid.dygraph import to_variable
from paddle.fluid.framework import _dygraph_tracer, dygraph_only
class OptimizerState(Enum):
INIT = 0
UNSCALED = 1
STEPPED = 2
def _refresh_optimizer_state():
return {"state": OptimizerState.INIT}
class AmpScaler:
"""
AmpScaler is used for Auto-Mixed-Precision training/inferring in imperative
mode. It controls the scaling of loss, helps avoiding numerical overflow.
The object of this class has seventeen methods `scale()`, `unscale_()`, `minimize()` and `get`/`set` api of parameters.
`scale()` is used to multiply the loss by a scale ratio.
`unscale_()` is used to unscale the gradients of parameters, multiplies the gradients of parameters by 1/(scale ratio)
`minimize()` is similar as `optimizer.minimize()`, performs parameters updating, and it will update the loss_scaling.
Commonly, it is used together with `amp_guard` to achieve Auto-Mixed-Precision in
imperative mode.
Args:
enable(bool, optional): Enable loss scaling or not. Default is True.
init_loss_scaling (float, optional): The initial loss scaling factor. Default is 2**15.
incr_ratio(float, optional): The multiplier to use when increasing the loss
scaling. Default is 2.0.
decr_ratio(float, optional): The less-than-one-multiplier to use when decreasing
the loss scaling. Default is 0.5.
incr_every_n_steps(int, optional): Increases loss scaling every n consecutive
steps with finite gradients. Default is 1000.
decr_every_n_nan_or_inf(int, optional): Decreases loss scaling every n
accumulated steps with nan or inf gradients. Default is 2.
use_dynamic_loss_scaling(bool, optional): Whether to use dynamic loss scaling. If False, fixed loss_scaling is used. If True, the loss scaling is updated dynamicly. Default is True.
Returns:
An AmpScaler object.
Examples:
.. code-block:: python
import numpy as np
import paddle
data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
model = paddle.nn.Conv2D(3, 2, 3)
optimizer = paddle.optimizer.SGDOptimizer(
learning_rate=0.01, parameter_list=model.parameters())
scaler = paddle.amp.AmpScaler(init_loss_scaling=1024)
data = paddle.to_tensor(data)
with paddle.amp.amp_guard():
conv = model(data)
loss = paddle.mean(conv)
scaled = scaler.scale(loss)
scaled.backward()
scaler.minimize(optimizer, scaled)
"""
@dygraph_only
def __init__(
self,
enable=True,
init_loss_scaling=2.0**15,
incr_ratio=2.0,
decr_ratio=0.5,
incr_every_n_steps=1000,
decr_every_n_nan_or_inf=1,
use_dynamic_loss_scaling=True,
):
tracer = _dygraph_tracer()
if not tracer:
raise ValueError(
"current_tracer is None, maybe it is not in imperative mode."
)
if enable and not (
tracer._expected_place.is_gpu_place()
or tracer._expected_place.is_xpu_place()
or tracer._expected_place.is_mlu_place()
or tracer._expected_place.is_npu_place()
or tracer._expected_place.is_custom_place()
):
warnings.warn(
'AmpScaler can only be enabled on CUDAPlace, XPUPlace, MLUPlace, NPUPlace and CustomPlace, current place is %s, so it makes no effect.'
% tracer._expected_place
)
enable = False
self._enable = enable
if self._enable:
assert incr_ratio > 1.0, "The incr_ratio must be > 1.0."
assert decr_ratio < 1.0, "The decr_ratio must be < 1.0."
self._init_loss_scaling = init_loss_scaling
self._incr_ratio = incr_ratio
self._decr_ratio = decr_ratio
self._incr_every_n_steps = incr_every_n_steps
self._decr_every_n_nan_or_inf = decr_every_n_nan_or_inf
self._incr_count = 0
self._decr_count = 0
self._use_dynamic_loss_scaling = use_dynamic_loss_scaling
self._found_inf = to_variable(np.array([0]).astype(np.bool_))
self._temp_found_inf_fp16 = to_variable(
np.array([0]).astype(np.bool_)
)
self._temp_found_inf_bf16 = to_variable(
np.array([0]).astype(np.bool_)
)
self._temp_found_inf_fp32 = to_variable(
np.array([0]).astype(np.bool_)
)
self._scale = to_variable(
np.array([self._init_loss_scaling]).astype(np.float32)
)
self._cache_founf_inf = None
self._optimizer_states = defaultdict(_refresh_optimizer_state)
def scale(self, var):
"""
Multiplies a Tensor by the scale factor and returns scaled outputs.
If this instance of :class:`AmpScaler` is not enabled, output are returned unmodified.
Args:
var (Tensor): The Tensor to scale.
Returns:
The scaled Tensor or original Tensor.
Examples:
.. code-block:: python
import numpy as np
import paddle
data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
model = paddle.nn.Conv2D(3, 2, 3)
optimizer = paddle.optimizer.SGDOptimizer(
learning_rate=0.01, parameter_list=model.parameters())
scaler = paddle.amp.AmpScaler(init_loss_scaling=1024)
data = paddle.to_tensor(data)
with paddle.amp.amp_guard():
conv = model(data)
loss = paddle.mean(conv)
scaled = scaler.scale(loss)
scaled.backward()
scaler.minimize(optimizer, scaled)
"""
check_type(var, "var", core.VarBase, 'AmpScaler.scale()')
if not self._enable:
return var
return var * self._scale
def minimize(self, optimizer, *args, **kwargs):
"""
This function is similar as `Optimizer.minimize()`, which performs parameters updating.
If the scaled gradients of parameters contains NAN or INF, the parameters updating is skipped.
Otherwise, if `unscale_()` has not been called, it first unscales the scaled gradients of parameters, then updates the parameters.
Finally, the loss scaling ratio is updated.
Args:
optimizer(Optimizer): The optimizer used to update parameters.
args: Arguments, which will be forward to `optimizer.minimize()`.
kwargs: Keyword arguments, which will be forward to `Optimizer.minimize()`.
Examples:
.. code-block:: python
import numpy as np
import paddle
data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
model = paddle.nn.Conv2D(3, 2, 3)
optimizer = paddle.optimizer.SGDOptimizer(
learning_rate=0.01, parameter_list=model.parameters())
scaler = paddle.amp.AmpScaler(init_loss_scaling=1024)
data = paddle.to_tensor(data)
with paddle.amp.amp_guard():
conv = model(data)
loss = paddle.mean(conv)
scaled = scaler.scale(loss)
scaled.backward()
scaler.minimize(optimizer, scaled)
"""
if not self._enable:
return optimizer.minimize(*args, **kwargs)
optimizer_state = self._optimizer_states[id(optimizer)]
# unscale the grad
if optimizer_state["state"] is OptimizerState.INIT:
self._unscale(optimizer)
optimize_ops, params_grads = (None, None)
if self._found_inf:
self._cache_founf_inf = True
else:
optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
self._cache_founf_inf = False
if self._use_dynamic_loss_scaling:
# uopdate the scale
self._update()
self._optimizer_states = defaultdict(_refresh_optimizer_state)
return optimize_ops, params_grads
def _unscale(self, optimizer):
"""
Unscale the gradients of parameters, multiplies the gradients of parameters by 1/(loss scaling ratio).
If this instance of :class:`GradScaler` is not enabled, output are returned unmodified.
Args:
optimizer(Optimizer): The optimizer used to update parameters.
Returns:
The unscaled parameters or original parameters.
"""
if not self._enable:
return
optimizer_state = self._optimizer_states[id(optimizer)]
if optimizer_state["state"] is OptimizerState.UNSCALED:
raise RuntimeError(
"unscale_() has already been called on this optimizer since the last update()."
)
elif optimizer_state["state"] is OptimizerState.STEPPED:
raise RuntimeError("unscale_() is being called after step().")
if getattr(optimizer, '_param_groups', None) and isinstance(
optimizer._param_groups[0], dict
):
param_grads = []
param_grads_fp16 = []
param_grads_bf16 = []
param_grads_fp32 = []
for group in optimizer._param_groups:
for param in group['params']:
if param._grad_ivar() is not None:
param_grads.append(param._grad_ivar())
if (
param._grad_ivar().dtype
== core.VarDesc.VarType.FP16
):
param_grads_fp16.append(param._grad_ivar())
elif (
param._grad_ivar().dtype
== core.VarDesc.VarType.BF16
):
param_grads_bf16.append(param._grad_ivar())
else:
param_grads_fp32.append(param._grad_ivar())
else:
if in_dygraph_mode():
# It is very time-consuming to call c++ functions in a loop on the python side.
# We put this part of the code on the c++ side to improve the speed in eager mode.
(
param_grads_fp16,
param_grads_bf16,
param_grads_fp32,
) = core.eager.get_grads_lists(optimizer._parameter_list)
else:
# Keep the original code to support legacy mode.
# Delete the else branch when the legacy mode exits.
param_grads = [
param._grad_ivar()
for param in optimizer._parameter_list
if param._grad_ivar() is not None
]
param_grads_fp16 = [
param
for param in param_grads
if param.dtype == core.VarDesc.VarType.FP16
]
param_grads_bf16 = [
param
for param in param_grads
if param.dtype == core.VarDesc.VarType.BF16
]
param_grads_fp32 = [
param
for param in param_grads
if param.dtype == core.VarDesc.VarType.FP32
]
if core.is_compiled_with_npu():
float_status = _legacy_C_ops.alloc_float_status()
_legacy_C_ops.clear_float_status(float_status, float_status)
if len(param_grads_fp16):
_legacy_C_ops.check_finite_and_unscale(
param_grads_fp16,
self._scale,
float_status,
param_grads_fp16,
self._temp_found_inf_fp16,
)
if len(param_grads_bf16):
_legacy_C_ops.check_finite_and_unscale(
param_grads_bf16,
self._scale,
float_status,
param_grads_bf16,
self._temp_found_inf_bf16,
)
if len(param_grads_fp32):
_legacy_C_ops.check_finite_and_unscale(
param_grads_fp32,
self._scale,
float_status,
param_grads_fp32,
self._temp_found_inf_fp32,
)
else:
if len(param_grads_fp16):
_legacy_C_ops.check_finite_and_unscale(
param_grads_fp16,
self._scale,
param_grads_fp16,
self._temp_found_inf_fp16,
)
if len(param_grads_bf16):
_legacy_C_ops.check_finite_and_unscale(
param_grads_bf16,
self._scale,
param_grads_bf16,
self._temp_found_inf_bf16,
)
if len(param_grads_fp32):
_legacy_C_ops.check_finite_and_unscale(
param_grads_fp32,
self._scale,
param_grads_fp32,
self._temp_found_inf_fp32,
)
self._found_inf = (
self._temp_found_inf_fp16
or self._temp_found_inf_bf16
or self._temp_found_inf_fp32
)
optimizer_state["state"] = OptimizerState.UNSCALED
def _update(self):
"""
Updates the loss_scaling.
"""
if not self._enable:
return
if self._cache_founf_inf:
self._incr_count = 0
self._decr_count = self._decr_count + 1
if self._decr_count == self._decr_every_n_nan_or_inf:
print(
'Found inf or nan, current scale is: {}, decrease to: {}*{}'.format(
float(self._scale),
float(self._scale),
float(self._decr_ratio),
)
)
self._scale = self._scale * self._decr_ratio
self._decr_count = 0
else:
self._decr_count = 0
self._incr_count = self._incr_count + 1
if self._incr_count == self._incr_every_n_steps:
self._scale = self._scale * self._incr_ratio
self._incr_count = 0
return
def is_enable(self):
"""
Enable loss scaling or not.
Returns:
bool: enable loss scaling return True else return False.
"""
return self._enable
def is_use_dynamic_loss_scaling(self):
"""
Whether to use dynamic loss scaling.
Returns:
bool: if fixed loss_scaling is used return False, if the loss scaling is updated dynamicly return true.
"""
return self._use_dynamic_loss_scaling
def get_init_loss_scaling(self):
"""
Return the initial loss scaling factor.
Reurns:
float: the initial loss scaling factor.
"""
return self._init_loss_scaling
def set_init_loss_scaling(self, new_init_loss_scaling):
"""
Set the initial loss scaling factor by `new_init_loss_scaling`.
Args:
new_init_loss_scaling(int): The new_init_loss_scaling used to update initial loss scaling factor.s
"""
self._init_loss_scaling = new_init_loss_scaling
self._scale = to_variable(
np.array([self._init_loss_scaling]).astype(np.float32)
)
def get_incr_ratio(self):
"""
Return the multiplier to use when increasing the loss scaling.
Reurns:
float: the multiplier to use when increasing the loss scaling.
"""
return self._incr_ratio
def set_incr_ratio(self, new_incr_ratio):
"""
Set the multiplier to use when increasing the loss scaling by `new_incr_ratio`, `new_incr_ratio` should > 1.0.
Args:
new_incr_ratio(float): The new_incr_ratio used to update the multiplier to use when increasing the loss scaling.
"""
assert new_incr_ratio > 1.0, "The new_incr_ratio must be > 1.0."
self._incr_ratio = new_incr_ratio
def get_decr_ratio(self):
"""
Get the less-than-one-multiplier to use when decreasing the loss scaling.
Reurns:
float: the less-than-one-multiplier to use when decreasing the loss scaling.
"""
return self._decr_ratio
def set_decr_ratio(self, new_decr_ratio):
"""
Set the less-than-one-multiplier to use when decreasing the loss scaling by `new_incr_ratio`, `new_decr_ratio` should < 1.0.
Args:
new_decr_ratio(float): The new_decr_ratio used to update the less-than-one-multiplier to use when decreasing the loss scaling.
"""
assert new_decr_ratio < 1.0, "The new_decr_ratio must be < 1.0."
self._decr_ratio = new_decr_ratio
def get_incr_every_n_steps(self):
"""
Return the num `n`, `n` represent increases loss scaling every `n` consecutive steps with finite gradients.
Reurns:
int: the num `n`, `n` represent increases loss scaling every `n` consecutive steps with finite gradients.
"""
return self._incr_every_n_steps
def set_incr_every_n_steps(self, new_incr_every_n_steps):
"""
Set the num `n` by `new_incr_every_n_steps`, `n` represent increases loss scaling every `n` consecutive steps with finite gradients.
Args:
new_incr_every_n_steps(int): The new_incr_every_n_steps used to update the num `n`, `n` represent increases loss scaling every `n` consecutive steps with finite gradients.
"""
self._incr_every_n_steps = new_incr_every_n_steps
def get_decr_every_n_nan_or_inf(self):
"""
Return the num `n`, `n` represent decreases loss scaling every `n` accumulated steps with nan or inf gradients.
Reurns:
int: the num `n`, `n` represent decreases loss scaling every `n` accumulated steps with nan or inf gradients.
"""
return self._decr_every_n_nan_or_inf
def set_decr_every_n_nan_or_inf(self, new_decr_every_n_nan_or_inf):
"""
Set the num `n` by `new_decr_every_n_nan_or_inf`, `n` represent decreases loss scaling every `n` accumulated steps with nan or inf gradients.
Args:
new_decr_every_n_nan_or_inf(int): The new_decr_every_n_nan_or_inf used to update the num `n`, `n` represent decreases loss scaling every `n` accumulated steps with nan or inf gradients.
"""
self._decr_every_n_nan_or_inf = new_decr_every_n_nan_or_inf
def state_dict(self):
"""
Returns the state of the scaler as a `dict`, If this instance is not enabled, returns an empty dict.
Reurns:
A dict of scaler includes:
scale (tensor): The loss scaling factor.
incr_ratio(float): The multiplier to use when increasing the loss scaling.
decr_ratio(float): The less-than-one-multiplier to use when decreasing the loss scaling.
incr_every_n_steps(int): Increases loss scaling every n consecutive steps with finite gradients.
decr_every_n_nan_or_inf(int): Decreases loss scaling every n accumulated steps with nan or inf gradients.
incr_count(int): The number of recent consecutive unskipped steps.
decr_count(int): The number of recent consecutive skipped steps.
use_dynamic_loss_scaling(bool): Whether to use dynamic loss scaling. If False, fixed loss_scaling is used. If True, the loss scaling is updated dynamicly. Default is True.
"""
return (
{
"scale": self._scale.numpy(),
"incr_ratio": self._incr_ratio,
"decr_ratio": self._decr_ratio,
"incr_every_n_steps": self._incr_every_n_steps,
"decr_every_n_nan_or_inf": self._decr_every_n_nan_or_inf,
"incr_count": self._incr_count,
"decr_count": self._decr_count,
"use_dynamic_loss_scaling": self._use_dynamic_loss_scaling,
}
if self._enable
else {}
)
def load_state_dict(self, state_dict):
"""
Loads the scaler state.
Args:
state_dict(dict): scaler state. Should be an object returned from a call to `AmpScaler.state_dict()`.
"""
if not self._enable:
return
if len(state_dict) == 0:
raise RuntimeError(
"The input state dict is empty, possibly because it was saved "
"from a disabled instance of GradScaler."
)
self._init_loss_scaling = state_dict["scale"][0]
self._scale = to_variable(
np.array([self._init_loss_scaling]).astype(np.float32)
)
self._incr_ratio = state_dict["incr_ratio"]
self._decr_ratio = state_dict["decr_ratio"]
self._incr_every_n_steps = state_dict["incr_every_n_steps"]
self._decr_every_n_nan_or_inf = state_dict["decr_every_n_nan_or_inf"]
self._incr_count = state_dict["incr_count"]
self._decr_count = state_dict["decr_count"]
self._use_dynamic_loss_scaling = state_dict["use_dynamic_loss_scaling"]
class GradScaler(AmpScaler):
"""
GradScaler is used for Auto-Mixed-Precision training in dynamic graph mode.
......
......@@ -28,9 +28,6 @@ from .parallel import *
from . import learning_rate_scheduler
from .learning_rate_scheduler import *
from . import amp
from .amp import *
from .math_op_patch import monkey_patch_math_varbase
__all__ = []
......@@ -38,4 +35,3 @@ __all__ += layers.__all__
__all__ += base.__all__
__all__ += parallel.__all__
__all__ += learning_rate_scheduler.__all__
__all__ += amp.__all__
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import auto_cast
from .auto_cast import *
from . import loss_scaler
from .loss_scaler import *
__all__ = []
__all__ += auto_cast.__all__
__all__ += loss_scaler.__all__
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid.wrapped_decorator import (
signature_safe_contextmanager,
wrap_decorator,
)
from paddle.fluid import core
import contextlib
from paddle.fluid.framework import (
Variable,
OpProtoHolder,
Parameter,
_dygraph_tracer,
dygraph_only,
set_flags,
get_flags,
)
import warnings
import copy
import functools
import paddle
import operator
import types
AMP_LEVEL = core.AmpLevel
__all__ = ['amp_guard', 'amp_decorate']
# The set of ops that support fp16 calculation and are considered numerically-
# safe and performance-critical. These ops are always converted to fp16.
WHITE_LIST = {
'conv2d',
'matmul',
'matmul_v2',
'mul',
'fake_quantize_dequantize_abs_max',
'fake_quantize_dequantize_moving_average_abs_max',
}
# The set of ops that support fp16 calculation and are considered numerically-
# dangerous and whose effects may also be observed in downstream ops.
BLACK_LIST = {
'exp',
'square',
'log',
'mean',
'sum',
'cos_sim',
'softmax',
'softmax_with_cross_entropy',
'sigmoid_cross_entropy_with_logits',
'c_softmax_with_cross_entropy',
'cross_entropy',
'cross_entropy2',
# default fp32 can avoid return inf when the sum value large than 65504
'reduce_sum',
# FP16 performance of grad op is worse than that of FP32. Use FP32 by default.
'linear_interp_v2',
'nearest_interp_v2',
'bilinear_interp_v2',
'bicubic_interp_v2',
'trilinear_interp_v2',
}
AMP_RELATED_FLAGS = [
'FLAGS_cudnn_exhaustive_search',
'FLAGS_conv_workspace_size_limit',
'FLAGS_cudnn_batchnorm_spatial_persistent',
]
AMP_RELATED_FLAGS_SETTING = {
'FLAGS_cudnn_exhaustive_search': 1,
'FLAGS_conv_workspace_size_limit': 1000,
'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
}
PURE_FP16_WHITE_LIST = set()
PURE_FP16_BLACK_LIST = {
'lookup_table',
'lookup_table_v2',
'scatter',
'scatter_grad',
# FP16 performance of grad op is worse than that of FP32. Use FP32 by default.
'linear_interp_v2',
'nearest_interp_v2',
'bilinear_interp_v2',
'bicubic_interp_v2',
'trilinear_interp_v2',
}
BF16_WHITE_LIST = {'conv2d', 'matmul_v2'}
BF16_BLACK_LIST = set()
PURE_BF16_WHITE_LIST = set()
PURE_BF16_BLACK_LIST = set()
_g_amp_state_ = None
def low_precision_op_list():
op_list = paddle.fluid.core.get_low_precision_op_list()
op_count = 0
print('<---------------- low precision op list ------------------->')
print('<---- op name ------|------- op count---------------------->')
for x in op_list:
print(' %-18s| %4d' % (x, op_list[x]))
op_count += 1
print(
'<------------- low precision op num:{:5d} ----------------->'.format(
op_count
)
)
def amp_state():
global _g_amp_state_
return _g_amp_state_
# NOTE(zhiqiu): similar as paddle.fluid.contrib.mixed_precision.fp16_lists.AutoMixedPrecisionLists._update_list
# The reason why not use AutoMixedPrecisionLists is that custom_black_varnames is not suitable for imperative mode.
def _update_list(
custom_white_list, custom_black_list, level='O1', dtype='float16'
):
"""
Update black and white list according to users' custom list.
"""
if dtype == 'float16':
if level == 'O1':
_white_list = copy.copy(WHITE_LIST)
_black_list = copy.copy(BLACK_LIST)
else:
_white_list = copy.copy(PURE_FP16_WHITE_LIST)
_black_list = copy.copy(PURE_FP16_BLACK_LIST)
else:
if level == 'O1':
_white_list = copy.copy(BF16_WHITE_LIST)
_black_list = copy.copy(BF16_BLACK_LIST)
else:
_white_list = copy.copy(PURE_BF16_WHITE_LIST)
_black_list = copy.copy(PURE_BF16_BLACK_LIST)
if custom_white_list and custom_black_list:
for op_name in custom_white_list:
if op_name in custom_black_list:
raise ValueError(
"Custom white list overlap " "custom black list"
)
if custom_white_list:
for op_name in custom_white_list:
if op_name in _black_list:
_black_list.remove(op_name)
_white_list.add(op_name)
if custom_black_list:
for op_name in custom_black_list:
if op_name in _white_list:
_white_list.remove(op_name)
_black_list.add(op_name)
return _white_list, _black_list
def _in_amp_guard():
"""
Judge whether current code block is in `amp_guard` context.
"""
tracer = _dygraph_tracer()
if tracer:
if tracer._amp_level == core.AmpLevel.O1:
return True
else:
return False
else:
return False
def _in_pure_fp16_guard():
tracer = _dygraph_tracer()
return tracer and tracer._amp_level == core.AmpLevel.O2
def _is_gpu_float16_supported():
"""
Judge whether current gpu support float16 amp.
"""
prop = paddle.device.cuda.get_device_capability()
return prop[0] >= 7
def _is_gpu_bfloat16_supported():
"""
Judge whether current gpu support bfloat16 amp.
"""
prop = paddle.device.cuda.get_device_capability()
cuda_version = paddle.version.cuda()
if cuda_version is not None and cuda_version != 'False':
cuda_version_check = int(cuda_version.split('.')[0]) >= 11
else:
cuda_version_check = False
return prop[0] >= 8 and cuda_version_check
@dygraph_only
def pure_fp16_initialize(models):
for idx in range(len(models)):
for layer in models[idx].sublayers(include_self=True):
layer._casted_by_pure_fp16 = True
if (layer._dtype == 'float16') or isinstance(
layer,
(
paddle.nn.BatchNorm,
paddle.nn.BatchNorm1D,
paddle.nn.BatchNorm2D,
paddle.nn.BatchNorm3D,
paddle.nn.LayerNorm,
paddle.nn.SyncBatchNorm,
),
):
continue
if isinstance(
layer,
(
paddle.incubate.nn.FusedFeedForward,
paddle.incubate.nn.FusedMultiHeadAttention,
),
):
layer._amp_decorate(dtype='float16')
continue
layer._to_impl(
dtype='float16', include_sublayers=False, floating_only=True
)
return models
@dygraph_only
def pure_bf16_initialize(models):
for idx in range(len(models)):
for layer in models[idx].sublayers(include_self=True):
layer._to_impl(
dtype='bfloat16', include_sublayers=False, floating_only=True
)
return models
def check_models(models):
for model in models:
if not isinstance(model, paddle.nn.Layer):
raise RuntimeError(
"Current train mode is pure fp16, models should be paddle.nn.Layer, but receive {}.".format(
type(model)
)
)
if isinstance(model, paddle.DataParallel):
raise RuntimeError(
"For distributed AMP training, you should first use paddle.amp.decorate() to decotate origin model, and then call paddle.DataParallel get distributed model."
)
def _is_valid_optimizer(optimizer):
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
DygraphShardingOptimizer,
)
return isinstance(
optimizer,
(
paddle.optimizer.Optimizer,
paddle.fluid.optimizer.Optimizer,
DygraphShardingOptimizer,
),
)
def check_optimizers(optimizers):
for optimizer in optimizers:
if not _is_valid_optimizer(optimizer):
raise RuntimeError(
"Current train mode is pure fp16, optimizers should be paddle.optimizer.Optimizer or paddle.fluid.optimizer.Optimizer or DygraphShardingOptimizer, but receive {}.".format(
type(optimizer)
)
)
@signature_safe_contextmanager
@dygraph_only
def amp_guard(
enable=True,
custom_white_list=None,
custom_black_list=None,
level='O1',
dtype='float16',
):
"""
:api_attr: imperative
Create a context which enables auto-mixed-precision(AMP) of operators executed in dynamic graph mode.
If enabled, the input data type (float32 or float16) of each operator is decided
by autocast algorithm for better performance.
Commonly, it is used together with `GradScaler` to achieve Auto-Mixed-Precision in
imperative mode. It is used together with `decorator` to achieve Pure fp16 in imperative mode.
Args:
enable(bool, optional): Enable auto-mixed-precision or not. Default is True.
custom_white_list(set|list|tuple, optional): The custom white_list. It's the set of ops that support
fp16 calculation and are considered numerically-safe and performance-critical. These ops
will be converted to fp16.
custom_black_list(set|list|tuple, optional): The custom black_list. The set of ops that support fp16
calculation and are considered numerically-dangerous and whose effects may also be
observed in downstream ops. These ops will not be converted to fp16.
level(str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represent mixed precision, the input data type of each operator will be casted by white_list and black_list;
O2 represent Pure fp16, all operators parameters and input data will be casted to fp16, except operators in black_list, don't support fp16 kernel and batchnorm. Default is O1(amp)
dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'.
Examples:
.. code-block:: python
import numpy as np
import paddle
data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
with paddle.fluid.dygraph.guard():
conv2d = paddle.fluid.dygraph.Conv2D(3, 2, 3)
data = paddle.fluid.dygraph.to_variable(data)
with paddle.fluid.dygraph.amp_guard():
conv = conv2d(data)
print(conv.dtype) # FP16
with paddle.fluid.dygraph.amp_guard(enable=False):
conv = conv2d(data)
print(conv.dtype) # FP32
"""
amp_state = locals()
global _g_amp_state_
original_state = _g_amp_state_
_g_amp_state_ = amp_state
# check amp_level: O0-O2
level = level.upper()
if not (level in ['O0', 'O1', 'O2']):
raise ValueError(
"level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16/bf16 train mode."
)
# check amp_dtype: float16 or bfloat16
dtype = dtype.lower()
if not (dtype in ['float16', 'bfloat16']):
raise ValueError("dtype should be 'float16' or 'bfloat16'.")
# check tracer
tracer = _dygraph_tracer()
if not tracer:
raise ValueError(
"current_tracer is None, maybe it is not in imperative mode."
)
# check device_type:
# NOTE: Now, amp only support gpu for float16 and bfloat16, xpu for float16, mlu for float16, npu for float16.
# Maybe we will support cpu for bfloat16.
if enable and not (
tracer._expected_place.is_gpu_place()
or tracer._expected_place.is_xpu_place()
or tracer._expected_place.is_mlu_place()
or tracer._expected_place.is_npu_place()
or tracer._expected_place.is_custom_place()
):
warnings.warn(
'amp_guard can only be enabled on CUDAPlace, XPUPlace, MLUPlace, NPUPlace, and CustomPlace, current place is %s, so it makes no effect.'
% tracer._expected_place
)
enable = False
# For npu:
if tracer._expected_place.is_npu_place() and (dtype == 'bfloat16'):
warnings.warn('NPUPlace only support float16 amp.')
enable = False
# For xpu:
if tracer._expected_place.is_xpu_place() and (dtype == 'bfloat16'):
warnings.warn('XPUPlace only support float16 amp.')
enable = False
# For mlu:
if tracer._expected_place.is_mlu_place() and (dtype == 'bfloat16'):
warnings.warn('MLUPlace only support float16 amp.')
enable = False
# For custom device:
if tracer._expected_place.is_custom_place() and (dtype == 'bfloat16'):
warnings.warn('CustomPlace only support float16 amp.')
enable = False
# For gpu float16: Compute Capability should >= 7.
# For gpu bfloat16: Compute Capability should >= 8 & CUDA Version should >= 11.
if tracer._expected_place.is_gpu_place():
if (dtype == 'float16') and not _is_gpu_float16_supported():
prop = paddle.device.cuda.get_device_capability()
warnings.warn(
"For float16, amp only support NVIDIA GPU with Compute Capability 7.0 or higher, current GPU is: %s, with Compute Capability: %d.%d."
% (paddle.device.cuda.get_device_name(), prop[0], prop[1])
)
elif (dtype == 'bfloat16') and not _is_gpu_bfloat16_supported():
prop = paddle.device.cuda.get_device_capability()
cuda_version = paddle.version.cuda()
warnings.warn(
"For bfloat16, amp only support NVIDIA GPU with Compute Capability 8.0 or higher and CUDA Version 11.0 or higher, current GPU is: %s, with Compute Capability: %d.%d, current CUDA Version is: %s."
% (
paddle.device.cuda.get_device_name(),
prop[0],
prop[1],
cuda_version,
)
)
amp_dtype = dtype
if level == 'O1':
amp_level = AMP_LEVEL.O1
if dtype == 'float16':
_white_list = WHITE_LIST
_black_list = BLACK_LIST
elif dtype == 'bfloat16':
_white_list = BF16_WHITE_LIST
_black_list = BF16_BLACK_LIST
elif level == 'O2':
amp_level = AMP_LEVEL.O2
if dtype == 'float16':
_white_list = PURE_FP16_WHITE_LIST
_black_list = PURE_FP16_BLACK_LIST
elif dtype == 'bfloat16':
_white_list = BF16_WHITE_LIST
_black_list = BF16_BLACK_LIST
elif level == 'O0':
amp_level = AMP_LEVEL.O0
if dtype == 'float16':
_white_list = WHITE_LIST
_black_list = BLACK_LIST
elif dtype == 'bfloat16':
_white_list = BF16_WHITE_LIST
_black_list = BF16_BLACK_LIST
if custom_white_list or custom_black_list:
_white_list, _black_list = _update_list(
custom_white_list, custom_black_list, level, dtype
)
if not enable:
amp_level = AMP_LEVEL.O0
amp_dtype = "float32"
if tracer:
# enable auto_cast
original_amp_level = tracer._amp_level
tracer._amp_level = amp_level
# set amp op list
original_white_list, original_black_list = tracer._get_amp_op_list()
tracer._set_amp_op_list(_white_list, _black_list)
# TODO(zhiqiu) set amp related flags automatically in this guard
# Currently, if FLAGS_cudnn_batchnorm_spatial_persistent is set True in amp_guard,
# batch_norm can run in fast mode, but batch_norm_grad can not if backward if not executed insise amp_guard.
# So, users need to set related flags manually.
# original_flags = get_flags(AMP_RELATED_FLAGS)
# set_flags(AMP_RELATED_FLAGS_SETTING)
# set amp dtype
original_amp_dtype = tracer._amp_dtype
tracer._amp_dtype = amp_dtype
# restore status
try:
yield
finally:
if tracer:
_g_amp_state_ = original_state
tracer._amp_level = original_amp_level
tracer._set_amp_op_list(original_white_list, original_black_list)
# set_flags(original_flags)
tracer._amp_dtype = original_amp_dtype
class StateDictHook:
def __init__(self, save_dtype):
self._save_dtype = save_dtype
def __call__(self, state_dict):
for key in state_dict:
param = state_dict[key]
with paddle.fluid.dygraph.guard():
if paddle.is_floating_point(param):
param_applied = paddle.cast(param, self._save_dtype)
param_applied.name = param.name
state_dict[key] = param_applied
def _set_multi_precision(optimizer, multi_precision):
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
DygraphShardingOptimizer,
)
optimizer = (
optimizer._inner_optimizer
if isinstance(optimizer, DygraphShardingOptimizer)
else optimizer
)
if hasattr(optimizer, "_multi_precision"):
optimizer._multi_precision = multi_precision
@dygraph_only
def amp_decorate(
models,
optimizers=None,
level='O1',
dtype='float16',
master_weight=None,
save_dtype=None,
):
"""
Decorate models and optimizers for auto-mixed-precision. When level is O1(amp), the decorate will do nothing.
When level is O2(pure fp16), the decorate will cast all parameters of models to FP16, except BatchNorm and LayerNorm.
Commonly, it is used together with `amp_guard` to achieve Pure fp16 in imperative mode.
Args:
models(Layer|list of Layer, optional): The defined models by user, models must be either a single model or a list of models. Default is None.
optimizers(Optimizer|list of Optimizer, optional): The defined optimizers by user, optimizers must be either a single optimizer or a list of optimizers. Default is None.
level(str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represent mixed precision, the decorator will do nothing;
O2 represent Pure fp16/bf16, the decorator will cast all parameters of models to FP16/BF16, except BatchNorm and LayerNorm. Default is O1(amp)
dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'.
master_weight(bool, optinal): For level='O2', whether to use multi-precision during weight updating. If master_weight is None, in O2 level optimizer will use multi-precision. Default is None.
save_dtype(float, optional): The save model parameter dtype when use `paddle.save` or `paddle.jit.save`,it should be float16, bfloat16, float32, float64 or None.
The save_dtype will not change model parameters dtype, it just change the state_dict dtype. When save_dtype is None, the save dtype is same as model dtype. Default is None.
Examples:
.. code-block:: python
# required: gpu
# Demo1: single model and optimizer:
import paddle
model = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
optimizer = paddle.optimizer.SGD(parameters=model.parameters())
model, optimizer = paddle.fluid.dygraph.amp_decorate(models=model, optimizers=optimizer, level='O2')
data = paddle.rand([10, 3, 32, 32])
with paddle.fluid.dygraph.amp_guard(enable=True, custom_white_list=None, custom_black_list=None, level='O2'):
output = model(data)
print(output.dtype) # FP16
# required: gpu
# Demo2: multi models and optimizers:
model2 = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
optimizer2 = paddle.optimizer.Adam(parameters=model2.parameters())
models, optimizers = paddle.fluid.dygraph.amp_decorate(models=[model, model2], optimizers=[optimizer, optimizer2], level='O2')
data = paddle.rand([10, 3, 32, 32])
with paddle.fluid.dygraph.amp_guard(enable=True, custom_white_list=None, custom_black_list=None, level='O2'):
output = models[0](data)
output2 = models[1](data)
print(output.dtype) # FP16
print(output2.dtype) # FP16
# required: gpu
# Demo3: optimizers is None:
model3 = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
optimizer3 = paddle.optimizer.Adam(parameters=model2.parameters())
model = paddle.fluid.dygraph.amp_decorate(models=model3, level='O2')
data = paddle.rand([10, 3, 32, 32])
with paddle.fluid.dygraph.amp_guard(enable=True, custom_white_list=None, custom_black_list=None, level='O2'):
output = model(data)
print(output.dtype) # FP16
"""
if not (level in ['O1', 'O2']):
raise ValueError(
"level should be O1 or O2, O1 represent AMP train mode, O2 represent Pure fp16 train mode."
)
if level == 'O1':
if optimizers is None:
return models
else:
return models, optimizers
models_is_list = False
if isinstance(models, paddle.nn.Layer):
models_is_list = False
models = [models]
check_models(models)
elif isinstance(models, list):
check_models(models)
models_is_list = True
else:
raise TypeError(
"models must be either a single model or a list of models."
)
if dtype == 'float16':
models = pure_fp16_initialize(models=models)
elif dtype == 'bfloat16':
models = pure_bf16_initialize(models=models)
else:
raise TypeError("dtype only support float16 or bfloat16.")
if optimizers is not None:
# check optimizers
optimizers_is_list = False
if _is_valid_optimizer(optimizers):
optimizers_is_list = False
optimizers = [optimizers]
check_optimizers(optimizers)
elif isinstance(optimizers, list):
check_optimizers(optimizers)
optimizers_is_list = True
else:
raise TypeError(
"optimizers must be either a single optimizer or a list of optimizers."
)
# support master_weight
use_multi_precision = not (master_weight is False)
for opt in optimizers:
_set_multi_precision(opt, use_multi_precision)
if save_dtype is not None:
if not (save_dtype in ['float16', 'bfloat16', 'float32', 'float64']):
raise ValueError(
"save_dtype can only be float16 float32 or float64, but your input save_dtype is %s."
% save_dtype
)
for idx in range(len(models)):
for layer in models[idx].sublayers(include_self=True):
layer.register_state_dict_hook(StateDictHook(save_dtype))
if models_is_list:
if optimizers is not None:
if optimizers_is_list:
return models, optimizers
else:
return models, optimizers[0]
else:
return models
else:
if optimizers is not None:
if optimizers_is_list:
return models[0], optimizers
else:
return models[0], optimizers[0]
else:
return models[0]
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid import core
from paddle.fluid.dygraph import to_variable
from paddle.fluid.framework import (
_varbase_creator,
_dygraph_tracer,
dygraph_only,
)
from paddle.fluid.data_feeder import check_type
from ...wrapped_decorator import signature_safe_contextmanager, wrap_decorator
import warnings
import numpy as np
from paddle import _C_ops, _legacy_C_ops
from collections import defaultdict
from enum import Enum
from paddle.fluid import in_dygraph_mode
__all__ = ['AmpScaler', 'OptimizerState']
class OptimizerState(Enum):
INIT = 0
UNSCALED = 1
STEPPED = 2
def _refresh_optimizer_state():
return {"state": OptimizerState.INIT}
class AmpScaler:
"""
:api_attr: imperative
AmpScaler is used for Auto-Mixed-Precision training/inferring in imperative
mode. It controls the scaling of loss, helps avoiding numerical overflow.
The object of this class has seventeen methods `scale()`, `unscale_()`, `minimize()` and `get`/`set` api of parameters.
`scale()` is used to multiply the loss by a scale ratio.
`unscale_()` is used to unscale the gradients of parameters, multiplies the gradients of parameters by 1/(scale ratio)
`minimize()` is similar as `optimizer.minimize()`, performs parameters updating, and it will update the loss_scaling.
Commonly, it is used together with `amp_guard` to achieve Auto-Mixed-Precision in
imperative mode.
Args:
enable(bool, optional): Enable loss scaling or not. Default is True.
init_loss_scaling (float, optional): The initial loss scaling factor. Default is 2**15.
incr_ratio(float, optional): The multiplier to use when increasing the loss
scaling. Default is 2.0.
decr_ratio(float, optional): The less-than-one-multiplier to use when decreasing
the loss scaling. Default is 0.5.
incr_every_n_steps(int, optional): Increases loss scaling every n consecutive
steps with finite gradients. Default is 1000.
decr_every_n_nan_or_inf(int, optional): Decreases loss scaling every n
accumulated steps with nan or inf gradients. Default is 2.
use_dynamic_loss_scaling(bool, optional): Whether to use dynamic loss scaling. If False, fixed loss_scaling is used. If True, the loss scaling is updated dynamicly. Default is True.
Returns:
An AmpScaler object.
Examples:
.. code-block:: python
import numpy as np
import paddle.fluid as fluid
data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
with fluid.dygraph.guard():
model = fluid.dygraph.Conv2D(3, 2, 3)
optimizer = fluid.optimizer.SGDOptimizer(
learning_rate=0.01, parameter_list=model.parameters())
scaler = fluid.dygraph.AmpScaler(init_loss_scaling=1024)
data = fluid.dygraph.to_variable(data)
with fluid.dygraph.amp_guard():
conv = model(data)
loss = fluid.layers.reduce_mean(conv)
scaled = scaler.scale(loss)
scaled.backward()
scaler.minimize(optimizer, scaled)
"""
@dygraph_only
def __init__(
self,
enable=True,
init_loss_scaling=2.0**15,
incr_ratio=2.0,
decr_ratio=0.5,
incr_every_n_steps=1000,
decr_every_n_nan_or_inf=1,
use_dynamic_loss_scaling=True,
):
tracer = _dygraph_tracer()
if not tracer:
raise ValueError(
"current_tracer is None, maybe it is not in imperative mode."
)
if enable and not (
tracer._expected_place.is_gpu_place()
or tracer._expected_place.is_xpu_place()
or tracer._expected_place.is_mlu_place()
or tracer._expected_place.is_npu_place()
or tracer._expected_place.is_custom_place()
):
warnings.warn(
'AmpScaler can only be enabled on CUDAPlace, XPUPlace, MLUPlace, NPUPlace and CustomPlace, current place is %s, so it makes no effect.'
% tracer._expected_place
)
enable = False
self._enable = enable
if self._enable:
assert incr_ratio > 1.0, "The incr_ratio must be > 1.0."
assert decr_ratio < 1.0, "The decr_ratio must be < 1.0."
self._init_loss_scaling = init_loss_scaling
self._incr_ratio = incr_ratio
self._decr_ratio = decr_ratio
self._incr_every_n_steps = incr_every_n_steps
self._decr_every_n_nan_or_inf = decr_every_n_nan_or_inf
self._incr_count = 0
self._decr_count = 0
self._use_dynamic_loss_scaling = use_dynamic_loss_scaling
self._found_inf = to_variable(np.array([0]).astype(np.bool_))
self._temp_found_inf_fp16 = to_variable(
np.array([0]).astype(np.bool_)
)
self._temp_found_inf_bf16 = to_variable(
np.array([0]).astype(np.bool_)
)
self._temp_found_inf_fp32 = to_variable(
np.array([0]).astype(np.bool_)
)
self._scale = to_variable(
np.array([self._init_loss_scaling]).astype(np.float32)
)
self._cache_founf_inf = None
self._optimizer_states = defaultdict(_refresh_optimizer_state)
def scale(self, var):
"""
Multiplies a variable(Tensor) by the scale factor and returns scaled outputs.
If this instance of :class:`AmpScaler` is not enabled, output are returned unmodified.
Args:
var (Variable): The variable to scale.
Returns:
The scaled variable or original variable.
Examples:
.. code-block:: python
import numpy as np
import paddle.fluid as fluid
data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
with fluid.dygraph.guard():
model = fluid.dygraph.Conv2D(3, 2, 3)
optimizer = fluid.optimizer.SGDOptimizer(
learning_rate=0.01, parameter_list=model.parameters())
scaler = fluid.dygraph.AmpScaler(init_loss_scaling=1024)
data = fluid.dygraph.to_variable(data)
with fluid.dygraph.amp_guard():
conv = model(data)
loss = fluid.layers.reduce_mean(conv)
scaled = scaler.scale(loss)
scaled.backward()
scaler.minimize(optimizer, scaled)
"""
check_type(var, "var", core.VarBase, 'AmpScaler.scale()')
if not self._enable:
return var
return var * self._scale
def minimize(self, optimizer, *args, **kwargs):
"""
This function is similar as `Optimizer.minimize()`, which performs parameters updating.
If the scaled gradients of parameters contains NAN or INF, the parameters updating is skipped.
Otherwise, if `unscale_()` has not been called, it first unscales the scaled gradients of parameters, then updates the parameters.
Finally, the loss scaling ratio is updated.
Args:
optimizer(Optimizer): The optimizer used to update parameters.
args: Arguments, which will be forward to `optimizer.minimize()`.
kwargs: Keyword arguments, which will be forward to `Optimizer.minimize()`.
Examples:
.. code-block:: python
import numpy as np
import paddle.fluid as fluid
data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
with fluid.dygraph.guard():
model = fluid.dygraph.Conv2D(3, 2, 3)
optimizer = fluid.optimizer.SGDOptimizer(
learning_rate=0.01, parameter_list=model.parameters())
scaler = fluid.dygraph.AmpScaler(init_loss_scaling=1024)
data = fluid.dygraph.to_variable(data)
with fluid.dygraph.amp_guard():
conv = model(data)
loss = fluid.layers.reduce_mean(conv)
scaled = scaler.scale(loss)
scaled.backward()
scaler.minimize(optimizer, scaled)
"""
if not self._enable:
return optimizer.minimize(*args, **kwargs)
optimizer_state = self._optimizer_states[id(optimizer)]
# unscale the grad
if optimizer_state["state"] is OptimizerState.INIT:
self._unscale(optimizer)
optimize_ops, params_grads = (None, None)
if self._found_inf:
self._cache_founf_inf = True
else:
optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
self._cache_founf_inf = False
if self._use_dynamic_loss_scaling:
# uopdate the scale
self._update()
self._optimizer_states = defaultdict(_refresh_optimizer_state)
return optimize_ops, params_grads
def _unscale(self, optimizer):
"""
Unscale the gradients of parameters, multiplies the gradients of parameters by 1/(loss scaling ratio).
If this instance of :class:`GradScaler` is not enabled, output are returned unmodified.
Args:
optimizer(Optimizer): The optimizer used to update parameters.
Returns:
The unscaled parameters or original parameters.
"""
if not self._enable:
return
optimizer_state = self._optimizer_states[id(optimizer)]
if optimizer_state["state"] is OptimizerState.UNSCALED:
raise RuntimeError(
"unscale_() has already been called on this optimizer since the last update()."
)
elif optimizer_state["state"] is OptimizerState.STEPPED:
raise RuntimeError("unscale_() is being called after step().")
if getattr(optimizer, '_param_groups', None) and isinstance(
optimizer._param_groups[0], dict
):
param_grads = []
param_grads_fp16 = []
param_grads_bf16 = []
param_grads_fp32 = []
for group in optimizer._param_groups:
for param in group['params']:
if param._grad_ivar() is not None:
param_grads.append(param._grad_ivar())
if (
param._grad_ivar().dtype
== core.VarDesc.VarType.FP16
):
param_grads_fp16.append(param._grad_ivar())
elif (
param._grad_ivar().dtype
== core.VarDesc.VarType.BF16
):
param_grads_bf16.append(param._grad_ivar())
else:
param_grads_fp32.append(param._grad_ivar())
else:
if in_dygraph_mode():
# It is very time-consuming to call c++ functions in a loop on the python side.
# We put this part of the code on the c++ side to improve the speed in eager mode.
(
param_grads_fp16,
param_grads_bf16,
param_grads_fp32,
) = core.eager.get_grads_lists(optimizer._parameter_list)
else:
# Keep the original code to support legacy mode.
# Delete the else branch when the legacy mode exits.
param_grads = [
param._grad_ivar()
for param in optimizer._parameter_list
if param._grad_ivar() is not None
]
param_grads_fp16 = [
param
for param in param_grads
if param.dtype == core.VarDesc.VarType.FP16
]
param_grads_bf16 = [
param
for param in param_grads
if param.dtype == core.VarDesc.VarType.BF16
]
param_grads_fp32 = [
param
for param in param_grads
if param.dtype == core.VarDesc.VarType.FP32
]
if core.is_compiled_with_npu():
float_status = _legacy_C_ops.alloc_float_status()
_legacy_C_ops.clear_float_status(float_status, float_status)
if len(param_grads_fp16):
_legacy_C_ops.check_finite_and_unscale(
param_grads_fp16,
self._scale,
float_status,
param_grads_fp16,
self._temp_found_inf_fp16,
)
if len(param_grads_bf16):
_legacy_C_ops.check_finite_and_unscale(
param_grads_bf16,
self._scale,
float_status,
param_grads_bf16,
self._temp_found_inf_bf16,
)
if len(param_grads_fp32):
_legacy_C_ops.check_finite_and_unscale(
param_grads_fp32,
self._scale,
float_status,
param_grads_fp32,
self._temp_found_inf_fp32,
)
else:
if len(param_grads_fp16):
_legacy_C_ops.check_finite_and_unscale(
param_grads_fp16,
self._scale,
param_grads_fp16,
self._temp_found_inf_fp16,
)
if len(param_grads_bf16):
_legacy_C_ops.check_finite_and_unscale(
param_grads_bf16,
self._scale,
param_grads_bf16,
self._temp_found_inf_bf16,
)
if len(param_grads_fp32):
_legacy_C_ops.check_finite_and_unscale(
param_grads_fp32,
self._scale,
param_grads_fp32,
self._temp_found_inf_fp32,
)
self._found_inf = (
self._temp_found_inf_fp16
or self._temp_found_inf_bf16
or self._temp_found_inf_fp32
)
optimizer_state["state"] = OptimizerState.UNSCALED
def _update(self):
"""
Updates the loss_scaling.
"""
if not self._enable:
return
if self._cache_founf_inf:
self._incr_count = 0
self._decr_count = self._decr_count + 1
if self._decr_count == self._decr_every_n_nan_or_inf:
print(
'Found inf or nan, current scale is: {}, decrease to: {}*{}'.format(
float(self._scale),
float(self._scale),
float(self._decr_ratio),
)
)
self._scale = self._scale * self._decr_ratio
self._decr_count = 0
else:
self._decr_count = 0
self._incr_count = self._incr_count + 1
if self._incr_count == self._incr_every_n_steps:
self._scale = self._scale * self._incr_ratio
self._incr_count = 0
return
def is_enable(self):
"""
Enable loss scaling or not.
Returns:
bool: enable loss scaling return True else return False.
"""
return self._enable
def is_use_dynamic_loss_scaling(self):
"""
Whether to use dynamic loss scaling.
Returns:
bool: if fixed loss_scaling is used return False, if the loss scaling is updated dynamicly return true.
"""
return self._use_dynamic_loss_scaling
def get_init_loss_scaling(self):
"""
Return the initial loss scaling factor.
Reurns:
float: the initial loss scaling factor.
"""
return self._init_loss_scaling
def set_init_loss_scaling(self, new_init_loss_scaling):
"""
Set the initial loss scaling factor by `new_init_loss_scaling`.
Args:
new_init_loss_scaling(int): The new_init_loss_scaling used to update initial loss scaling factor.s
"""
self._init_loss_scaling = new_init_loss_scaling
self._scale = to_variable(
np.array([self._init_loss_scaling]).astype(np.float32)
)
def get_incr_ratio(self):
"""
Return the multiplier to use when increasing the loss scaling.
Reurns:
float: the multiplier to use when increasing the loss scaling.
"""
return self._incr_ratio
def set_incr_ratio(self, new_incr_ratio):
"""
Set the multiplier to use when increasing the loss scaling by `new_incr_ratio`, `new_incr_ratio` should > 1.0.
Args:
new_incr_ratio(float): The new_incr_ratio used to update the multiplier to use when increasing the loss scaling.
"""
assert new_incr_ratio > 1.0, "The new_incr_ratio must be > 1.0."
self._incr_ratio = new_incr_ratio
def get_decr_ratio(self):
"""
Get the less-than-one-multiplier to use when decreasing the loss scaling.
Reurns:
float: the less-than-one-multiplier to use when decreasing the loss scaling.
"""
return self._decr_ratio
def set_decr_ratio(self, new_decr_ratio):
"""
Set the less-than-one-multiplier to use when decreasing the loss scaling by `new_incr_ratio`, `new_decr_ratio` should < 1.0.
Args:
new_decr_ratio(float): The new_decr_ratio used to update the less-than-one-multiplier to use when decreasing the loss scaling.
"""
assert new_decr_ratio < 1.0, "The new_decr_ratio must be < 1.0."
self._decr_ratio = new_decr_ratio
def get_incr_every_n_steps(self):
"""
Return the num `n`, `n` represent increases loss scaling every `n` consecutive steps with finite gradients.
Reurns:
int: the num `n`, `n` represent increases loss scaling every `n` consecutive steps with finite gradients.
"""
return self._incr_every_n_steps
def set_incr_every_n_steps(self, new_incr_every_n_steps):
"""
Set the num `n` by `new_incr_every_n_steps`, `n` represent increases loss scaling every `n` consecutive steps with finite gradients.
Args:
new_incr_every_n_steps(int): The new_incr_every_n_steps used to update the num `n`, `n` represent increases loss scaling every `n` consecutive steps with finite gradients.
"""
self._incr_every_n_steps = new_incr_every_n_steps
def get_decr_every_n_nan_or_inf(self):
"""
Return the num `n`, `n` represent decreases loss scaling every `n` accumulated steps with nan or inf gradients.
Reurns:
int: the num `n`, `n` represent decreases loss scaling every `n` accumulated steps with nan or inf gradients.
"""
return self._decr_every_n_nan_or_inf
def set_decr_every_n_nan_or_inf(self, new_decr_every_n_nan_or_inf):
"""
Set the num `n` by `new_decr_every_n_nan_or_inf`, `n` represent decreases loss scaling every `n` accumulated steps with nan or inf gradients.
Args:
new_decr_every_n_nan_or_inf(int): The new_decr_every_n_nan_or_inf used to update the num `n`, `n` represent decreases loss scaling every `n` accumulated steps with nan or inf gradients.
"""
self._decr_every_n_nan_or_inf = new_decr_every_n_nan_or_inf
def state_dict(self):
"""
Returns the state of the scaler as a `dict`, If this instance is not enabled, returns an empty dict.
Reurns:
A dict of scaler includes:
scale (tensor): The loss scaling factor.
incr_ratio(float): The multiplier to use when increasing the loss scaling.
decr_ratio(float): The less-than-one-multiplier to use when decreasing the loss scaling.
incr_every_n_steps(int): Increases loss scaling every n consecutive steps with finite gradients.
decr_every_n_nan_or_inf(int): Decreases loss scaling every n accumulated steps with nan or inf gradients.
incr_count(int): The number of recent consecutive unskipped steps.
decr_count(int): The number of recent consecutive skipped steps.
use_dynamic_loss_scaling(bool): Whether to use dynamic loss scaling. If False, fixed loss_scaling is used. If True, the loss scaling is updated dynamicly. Default is True.
"""
return (
{
"scale": self._scale.numpy(),
"incr_ratio": self._incr_ratio,
"decr_ratio": self._decr_ratio,
"incr_every_n_steps": self._incr_every_n_steps,
"decr_every_n_nan_or_inf": self._decr_every_n_nan_or_inf,
"incr_count": self._incr_count,
"decr_count": self._decr_count,
"use_dynamic_loss_scaling": self._use_dynamic_loss_scaling,
}
if self._enable
else {}
)
def load_state_dict(self, state_dict):
"""
Loads the scaler state.
Args:
state_dict(dict): scaler state. Should be an object returned from a call to `AmpScaler.state_dict()`.
"""
if not self._enable:
return
if len(state_dict) == 0:
raise RuntimeError(
"The input state dict is empty, possibly because it was saved "
"from a disabled instance of GradScaler."
)
self._init_loss_scaling = state_dict["scale"][0]
self._scale = to_variable(
np.array([self._init_loss_scaling]).astype(np.float32)
)
self._incr_ratio = state_dict["incr_ratio"]
self._decr_ratio = state_dict["decr_ratio"]
self._incr_every_n_steps = state_dict["incr_every_n_steps"]
self._decr_every_n_nan_or_inf = state_dict["decr_every_n_nan_or_inf"]
self._incr_count = state_dict["incr_count"]
self._decr_count = state_dict["decr_count"]
self._use_dynamic_loss_scaling = state_dict["use_dynamic_loss_scaling"]
......@@ -60,10 +60,10 @@ class TestAutoCast(unittest.TestCase):
with fluid.dygraph.guard():
conv2d = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
data = fluid.dygraph.to_variable(data)
with fluid.dygraph.amp_guard(True):
with paddle.amp.amp_guard(True):
out_fp16 = conv2d(data)
with fluid.dygraph.amp_guard(False):
with paddle.amp.amp_guard(False):
out_fp32 = conv2d(data)
self.assertTrue(data.dtype == fluid.core.VarDesc.VarType.FP32)
......@@ -77,7 +77,7 @@ class TestAutoCast(unittest.TestCase):
data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
with fluid.dygraph.guard():
data = fluid.dygraph.to_variable(data)
with fluid.dygraph.amp_guard(True):
with paddle.amp.amp_guard(True):
out_fp32 = paddle.mean(data)
self.assertTrue(data.dtype == fluid.core.VarDesc.VarType.FP32)
......@@ -89,9 +89,9 @@ class TestAutoCast(unittest.TestCase):
def custom_op_list(self):
with fluid.dygraph.guard():
tracer = fluid.framework._dygraph_tracer()
base_white_list = fluid.dygraph.amp.auto_cast.WHITE_LIST
base_black_list = fluid.dygraph.amp.auto_cast.BLACK_LIST
with fluid.dygraph.amp_guard(
base_white_list = paddle.amp.WHITE_LIST
base_black_list = paddle.amp.BLACK_LIST
with paddle.amp.amp_guard(
custom_white_list=["log"], custom_black_list=["conv2d"]
):
white_list, black_list = tracer._get_amp_op_list()
......@@ -105,9 +105,9 @@ class TestAutoCast(unittest.TestCase):
== (set(base_black_list) - {"log"}) | {"conv2d"}
)
base_white_list = fluid.dygraph.amp.auto_cast.PURE_FP16_WHITE_LIST
base_black_list = fluid.dygraph.amp.auto_cast.PURE_FP16_BLACK_LIST
with fluid.dygraph.amp_guard(
base_white_list = paddle.amp.PURE_FP16_WHITE_LIST
base_black_list = paddle.amp.PURE_FP16_BLACK_LIST
with paddle.amp.amp_guard(
custom_white_list=["log"],
custom_black_list=["conv2d"],
level='O2',
......@@ -138,7 +138,7 @@ class TestAutoCast(unittest.TestCase):
stride=2,
act='relu',
)
with fluid.dygraph.amp_guard(
with paddle.amp.amp_guard(
custom_white_list=["conv2d"], custom_black_list=["conv2d"]
):
inp = fluid.dygraph.to_variable(inp_np)
......@@ -154,13 +154,13 @@ class TestAutoCast(unittest.TestCase):
with fluid.dygraph.guard():
conv2d = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
data = fluid.dygraph.to_variable(data)
with fluid.dygraph.amp_guard(True):
with paddle.amp.amp_guard(True):
out_amp_fp16 = conv2d(data)
out_amp_fp32 = paddle.expand_as(
out_amp_fp16, out_amp_fp16
) # expand_as_v2 has no fp16 kernel
with fluid.dygraph.amp_guard(True, level='O2'):
with paddle.amp.amp_guard(True, level='O2'):
out_purefp16_fp16 = conv2d(data)
out_purefp16_fp32 = paddle.expand_as(
out_purefp16_fp16, out_purefp16_fp16
......@@ -184,7 +184,7 @@ class TestAutoCast(unittest.TestCase):
with fluid.dygraph.guard():
conv2d = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
data = fluid.dygraph.to_variable(data)
with fluid.dygraph.amp_guard(level='O'):
with paddle.amp.amp_guard(level='O'):
out = conv2d(data)
self.assertRaises(ValueError, func)
......@@ -197,7 +197,7 @@ class TestAmpScaler(unittest.TestCase):
def scale(self):
with fluid.dygraph.guard():
data = paddle.rand([10, 1024])
scaler = paddle.fluid.dygraph.AmpScaler(init_loss_scaling=1024)
scaler = paddle.amp.AmpScaler(init_loss_scaling=1024)
scaled_data = scaler.scale(data)
self.assertEqual(
np.array_equal(scaled_data.numpy(), data.numpy() * 1024), True
......@@ -223,7 +223,7 @@ class TestAmpScaler(unittest.TestCase):
optimizer = fluid.optimizer.SGDOptimizer(
learning_rate=0.01, parameter_list=model.parameters()
)
scaler = fluid.dygraph.AmpScaler(init_loss_scaling=1024)
scaler = paddle.amp.AmpScaler(init_loss_scaling=1024)
data = fluid.dygraph.to_variable(inp_np)
out = model(data)
......@@ -332,7 +332,7 @@ class TestAmpScaler(unittest.TestCase):
optimizer = fluid.optimizer.SGDOptimizer(
learning_rate=0.01, parameter_list=model.parameters()
)
scaler = fluid.dygraph.AmpScaler(init_loss_scaling=1024)
scaler = paddle.amp.AmpScaler(init_loss_scaling=1024)
data = fluid.dygraph.to_variable(inp_np)
out = model(data)
......@@ -1262,12 +1262,12 @@ class TestResnet(unittest.TestCase):
dy_param_init_value[param.name] = param.numpy()
program = None
scaler = paddle.fluid.dygraph.AmpScaler(
scaler = paddle.amp.AmpScaler(
enable=enable_amp, init_loss_scaling=2.0**10
)
if enable_amp and (level == 'O2'):
resnet, optimizer = paddle.fluid.dygraph.amp_decorate(
resnet, optimizer = paddle.amp.amp_decorate(
models=resnet, optimizers=optimizer, level='O2'
)
......@@ -1290,9 +1290,7 @@ class TestResnet(unittest.TestCase):
img = fluid.dygraph.to_variable(dy_x_data)
label = fluid.dygraph.to_variable(y_data)
label.stop_gradient = True
with paddle.fluid.dygraph.amp_guard(
enable=enable_amp, level=level
):
with paddle.amp.amp_guard(enable=enable_amp, level=level):
out = resnet(img)
loss = paddle.nn.functional.cross_entropy(
......
......@@ -28,7 +28,7 @@ class TestAMPList(unittest.TestCase):
with paddle.amp.auto_cast():
conv = conv2d(data)
c = a + b
paddle.fluid.dygraph.amp.auto_cast.low_precision_op_list()
paddle.amp.low_precision_op_list()
op_list = paddle.fluid.core.get_low_precision_op_list()
print(conv.dtype)
if conv.dtype == paddle.float16:
......
......@@ -18,6 +18,7 @@ import numpy as np
import paddle
from paddle import _legacy_C_ops
from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard
from paddle.fluid import backward, core, framework, program_guard
from paddle.fluid.compiler import BuildStrategy
from paddle.fluid.contrib.mixed_precision.decorator import (
......@@ -28,10 +29,6 @@ from paddle.fluid.contrib.mixed_precision.fp16_utils import (
rewrite_program,
)
from paddle.fluid.dygraph import layers
from paddle.fluid.dygraph.amp.auto_cast import (
_in_amp_guard,
_in_pure_fp16_guard,
)
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.executor import (
_is_dy2st_enable_standalone_executor,
......
......@@ -331,7 +331,6 @@ packages=['paddle',
'paddle.inference.contrib.utils',
'paddle.fluid',
'paddle.fluid.dygraph',
'paddle.fluid.dygraph.amp',
'paddle.fluid.proto',
'paddle.fluid.proto.profiler',
'paddle.fluid.distributed',
......
......@@ -1202,7 +1202,6 @@ def get_setup_parameters():
'paddle.inference.contrib.utils',
'paddle.fluid',
'paddle.fluid.dygraph',
'paddle.fluid.dygraph.amp',
'paddle.fluid.proto',
'paddle.fluid.proto.profiler',
'paddle.fluid.distributed',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册