Unverified commit 6959eae5, authored by Yiqun Liu and committed by GitHub

Unify the static amp codes of fp16 and bf16. Reimplement #52694 in release/2.4. (#52697)

Parent commit: d1e8b1e2
@@ -15,7 +15,7 @@
from __future__ import print_function
from . import decorator
from .decorator import *
from .decorator import decorate, amp_decorate
from . import fp16_lists
from .fp16_lists import *
from . import fp16_utils
@@ -27,8 +27,8 @@ def check_finite_and_unscale(x, scale, name=None, float_status=None):
$$Out = X / scale$$
    If any tensor in X contains Inf or NaN, Out will generate an indicator:
    FoundInfinite will be 1 (True), and Out will not be scaled. In this case,
    the data of Out should not be used, and its data may not be deterministic.
Otherwise, FoundInfinite will be 0 (False).
Args:
@@ -38,75 +38,98 @@ def check_finite_and_unscale(x, scale, name=None, float_status=None):
"""
check_type(x, 'x', (tuple, list), 'check_finite_and_unscale')
for e in x:
check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'],
'check_finite_and_unscale')
check_variable_and_dtype(
e,
"x",
['float16', 'float32', 'float64', 'uint16'],
'check_finite_and_unscale',
)
helper = LayerHelper("check_finite_and_unscale", **locals())
found_inf = helper.create_variable_for_type_inference(dtype='bool')
inputs = {'X': x, 'Scale': scale}
if core.is_compiled_with_npu():
check_variable_and_dtype(float_status, "float_status",
['float16', 'float32'],
'check_finite_and_unscale')
check_variable_and_dtype(
float_status,
"float_status",
['float16', 'float32'],
'check_finite_and_unscale',
)
inputs['FloatStatus'] = float_status
outputs = {'Out': x, 'FoundInfinite': found_inf}
helper.append_op(type='check_finite_and_unscale',
inputs=inputs,
outputs=outputs)
helper.append_op(
type='check_finite_and_unscale', inputs=inputs, outputs=outputs
)
return x, found_inf
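# Reference sketch (not part of this diff): the contract of
# check_finite_and_unscale expressed in plain NumPy. The helper name and the
# sample arrays below are hypothetical, chosen only to illustrate the
# docstring above; this is not the actual kernel.
import numpy as np

def _check_finite_and_unscale_ref(xs, scale):
    # Report whether any gradient contains Inf/NaN; unscale by 1/scale only
    # when all values are finite. When found_inf is True, the real op leaves
    # the outputs unspecified, so they must not be used.
    found_inf = any(not np.all(np.isfinite(x)) for x in xs)
    if not found_inf:
        xs = [x / scale for x in xs]
    return xs, found_inf

grads = [np.array([1.0, 2.0]), np.array([np.inf, 3.0])]
_, found_inf = _check_finite_and_unscale_ref(grads, np.float32(128.0))
print(found_inf)  # True -> this optimizer step should be skipped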
def update_loss_scaling(x,
found_inf,
prev_loss_scaling,
num_good_steps,
num_bad_steps,
incr_every_n_steps,
decr_every_n_nan_or_inf,
incr_ratio,
decr_ratio,
stop_update=False,
name=None):
def update_loss_scaling(
x,
found_inf,
prev_loss_scaling,
num_good_steps,
num_bad_steps,
incr_every_n_steps,
decr_every_n_nan_or_inf,
incr_ratio,
decr_ratio,
stop_update=False,
name=None,
):
"""
    Update loss scaling according to the overall gradients. If all gradients are
    finite after incr_every_n_steps consecutive steps, the loss scaling will be
    increased by incr_ratio. Otherwise, the loss scaling will be decreased by
    decr_ratio after decr_every_n_nan_or_inf accumulated steps in which some
    gradients are infinite.
Args:
        x (list|tuple): The input tensors of the update_loss_scaling operator.
        found_inf (Variable): A boolean variable indicating whether there is
            any infinite gradient.
        prev_loss_scaling (Variable): Previous loss scaling.
        num_good_steps (Variable): A variable that accumulates the consecutive
            steps in which all gradients are finite.
        num_bad_steps (Variable): A variable that accumulates the steps in
            which some gradients are infinite.
        incr_every_n_steps (int): Increase the loss scaling every n consecutive
            steps with finite gradients.
        decr_every_n_nan_or_inf (int): Decrease the loss scaling every n
            accumulated steps with nan or inf gradients.
        incr_ratio (float): The multiplier to use when increasing the loss
            scaling.
        decr_ratio (float): The less-than-one multiplier to use when decreasing
            the loss scaling.
"""
check_variable_and_dtype(prev_loss_scaling, "prev_loss_scaling",
['float32', 'float64'], "update_loss_scaling")
check_variable_and_dtype(
prev_loss_scaling,
"prev_loss_scaling",
['float32', 'float64'],
"update_loss_scaling",
)
check_type(x, 'x', (tuple, list), 'update_loss_scaling')
for e in x:
check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'],
'update_loss_scaling')
if e.dtype == core.VarDesc.VarType.FP16:
assert prev_loss_scaling.dtype == core.VarDesc.VarType.FP32, \
"The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
check_variable_and_dtype(
e,
"x",
['float16', 'float32', 'float64', 'uint16'],
'update_loss_scaling',
)
if (
e.dtype == core.VarDesc.VarType.FP16
or e.dtype == core.VarDesc.VarType.BF16
):
assert (
prev_loss_scaling.dtype == core.VarDesc.VarType.FP32
), "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
else:
assert prev_loss_scaling.dtype == e.dtype, "The dtype of prev_loss_scaling should be equal to the dtype of x."
assert (
prev_loss_scaling.dtype == e.dtype
), "The dtype of prev_loss_scaling should be equal to the dtype of x."
helper = LayerHelper("update_loss_scaling", **locals())
@@ -115,14 +138,14 @@ def update_loss_scaling(x,
'FoundInfinite': found_inf,
'PrevLossScaling': prev_loss_scaling,
'InGoodSteps': num_good_steps,
'InBadSteps': num_bad_steps
'InBadSteps': num_bad_steps,
}
outputs = {
'Out': x,
'LossScaling': prev_loss_scaling,
'OutGoodSteps': num_good_steps,
'OutBadSteps': num_bad_steps
'OutBadSteps': num_bad_steps,
}
attrs = {
@@ -137,9 +160,8 @@ def update_loss_scaling(x,
else:
attrs['stop_update'] = stop_update
helper.append_op(type='update_loss_scaling',
inputs=inputs,
outputs=outputs,
attrs=attrs)
helper.append_op(
type='update_loss_scaling', inputs=inputs, outputs=outputs, attrs=attrs
)
return x
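# Reference sketch (not part of this diff): the dynamic loss-scaling policy
# described in the docstring above, in plain Python. The function name is
# hypothetical; the real op updates its input/output variables in place on
# the device.
def _update_loss_scaling_ref(
    found_inf,
    loss_scaling,
    good_steps,
    bad_steps,
    incr_every_n_steps,
    decr_every_n_nan_or_inf,
    incr_ratio,
    decr_ratio,
):
    if found_inf:
        # A step with Inf/NaN gradients resets the good-step counter.
        good_steps = 0
        bad_steps += 1
        if bad_steps == decr_every_n_nan_or_inf:
            loss_scaling *= decr_ratio  # shrink after too many bad steps
            bad_steps = 0
    else:
        bad_steps = 0
        good_steps += 1
        if good_steps == incr_every_n_steps:
            loss_scaling *= incr_ratio  # grow after a long finite run
            good_steps = 0
    return loss_scaling, good_steps, bad_steps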
@@ -13,16 +13,47 @@
# limitations under the License.
import copy
from ... import core
__all__ = ["CustomOpLists", "AutoMixedPrecisionLists"]
# lookup_table fp16 is slower than fp32, though fp16 is supported.
_extra_unsupported_fp16_list = {
'lookup_table', 'lookup_table_v2', 'scatter', 'scatter_grad'
_extra_unsupported_list = {
'lookup_table',
'lookup_table_v2',
'scatter',
'scatter_grad',
}
def _get_unsupported_list(dtype):
if dtype == "float16":
amp_dtype = core.VarDesc.VarType.FP16
elif dtype == "bfloat16":
amp_dtype = core.VarDesc.VarType.BF16
else:
raise ValueError(
"If enable AMP, dtype should be 'float16' or 'bfloat16'."
)
# The set of ops that don't support fp16 calculation
# lookup_table fp16 is slower than fp32, though fp16 is supported.
_sys_unsupported_list = []
# _sys_unsupported_bf16_list = []
if core.is_compiled_with_xpu():
_, _, _sys_unsupported_list = core.op_supported_infos('XPU', amp_dtype)
elif core.is_compiled_with_npu():
_, _, _sys_unsupported_list = core.op_supported_infos('NPU', amp_dtype)
elif core.is_compiled_with_mlu():
_, _, _sys_unsupported_list = core.op_supported_infos('MLU', amp_dtype)
else:
_, _, _sys_unsupported_list = core.op_supported_infos('GPU', amp_dtype)
unsupported_list = _extra_unsupported_list | _sys_unsupported_list
return unsupported_list
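# A sketch of how the helper above is consumed (hypothetical check; the real
# caller is AutoMixedPrecisionLists below). The hand-picked extras are
# excluded unconditionally, even on devices where a kernel exists:
unsupported = _get_unsupported_list("bfloat16")
assert 'lookup_table' in unsupported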
class AutoMixedPrecisionLists(object):
"""
    AutoMixedPrecisionLists is a class for black/white lists. It can update
@@ -36,16 +67,20 @@ class AutoMixedPrecisionLists(object):
    custom_black_varnames (set): Users' custom black variables' names.
"""
def __init__(self,
custom_white_list=None,
custom_black_list=None,
custom_black_varnames=None):
def __init__(
self,
custom_white_list=None,
custom_black_list=None,
custom_black_varnames=None,
dtype="float16",
):
self._custom_white_list = custom_white_list
self._custom_black_list = custom_black_list
self.amp_dtype = dtype
self.white_list = copy.copy(white_list)
self.black_list = copy.copy(black_list)
self.gray_list = copy.copy(gray_list)
self.unsupported_list = copy.copy(unsupported_fp16_list)
self.unsupported_list = copy.copy(_get_unsupported_list(self.amp_dtype))
self.black_varnames = copy.copy(custom_black_varnames)
self._update_list()
@@ -56,8 +91,9 @@ class AutoMixedPrecisionLists(object):
if self._custom_white_list and self._custom_black_list:
for op_name in self._custom_white_list:
if op_name in self._custom_black_list:
raise ValueError("Custom white list overlap "
"custom black list")
raise ValueError(
"Custom white list overlap " "custom black list"
)
if self._custom_white_list:
for op_name in self._custom_white_list:
if op_name in self.black_list:
@@ -65,7 +101,7 @@ class AutoMixedPrecisionLists(object):
elif op_name in self.gray_list:
self.gray_list.remove(op_name)
self.white_list.add(op_name)
if op_name in _extra_unsupported_fp16_list:
if op_name in _extra_unsupported_list:
self.unsupported_list.remove(op_name)
if self._custom_black_list:
for op_name in self._custom_black_list:
@@ -170,22 +206,4 @@ gray_list = {
'fused_multi_transformer',
}
# The set of ops that don't support fp16 calculation
# lookup_table fp16 is slower than fp32, though fp16 is supported.
_sys_unsupported_fp16_list = []
if core.is_compiled_with_xpu():
_, _, _sys_unsupported_fp16_list = core.op_supported_infos(
'XPU', core.VarDesc.VarType.FP16)
elif core.is_compiled_with_npu():
_, _, _sys_unsupported_fp16_list = core.op_supported_infos(
'NPU', core.VarDesc.VarType.FP16)
elif core.is_compiled_with_mlu():
_, _, _sys_unsupported_fp16_list = core.op_supported_infos(
'MLU', core.VarDesc.VarType.FP16)
else:
_, _, _sys_unsupported_fp16_list = core.op_supported_infos(
'GPU', core.VarDesc.VarType.FP16)
unsupported_fp16_list = _extra_unsupported_fp16_list | _sys_unsupported_fp16_list
CustomOpLists = AutoMixedPrecisionLists
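# Usage sketch (not part of this diff; assuming the release/2.4 module layout
# shown above, so the import path may differ elsewhere). The new dtype
# argument selects which unsupported-op list is queried:
from paddle.fluid.contrib.mixed_precision.fp16_lists import (
    AutoMixedPrecisionLists,
)

amp_lists = AutoMixedPrecisionLists(
    custom_black_list={'softmax'},
    dtype="bfloat16",  # new in this change; previously implicitly fp16
)
print('softmax' in amp_lists.black_list)  # True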