Unverified · Commit 6959eae5, authored by Yiqun Liu, committed by GitHub

Unify the static amp codes of fp16 and bf16. Reimplement #52694 in release/2.4. (#52697)

Parent d1e8b1e2
@@ -15,7 +15,7 @@
 from __future__ import print_function

 from . import decorator
-from .decorator import *
+from .decorator import decorate, amp_decorate
 from . import fp16_lists
 from .fp16_lists import *
 from . import fp16_utils
...
@@ -27,8 +27,8 @@ def check_finite_and_unscale(x, scale, name=None, float_status=None):
     $$Out = X / scale$$

     If any tensor in X contains Inf or Nan, the Out will generate a indicator.
     FoundInfinite will be 1 (True), and Out will not be scaled. In this case, the data of
     Out should not be used, and its data may not be deterministic.
     Otherwise, FoundInfinite will be 0 (False).

     Args:
@@ -38,75 +38,98 @@ def check_finite_and_unscale(x, scale, name=None, float_status=None):
     """
     check_type(x, 'x', (tuple, list), 'check_finite_and_unscale')
     for e in x:
-        check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'],
-                                 'check_finite_and_unscale')
+        check_variable_and_dtype(
+            e,
+            "x",
+            ['float16', 'float32', 'float64', 'uint16'],
+            'check_finite_and_unscale',
+        )

     helper = LayerHelper("check_finite_and_unscale", **locals())
     found_inf = helper.create_variable_for_type_inference(dtype='bool')

     inputs = {'X': x, 'Scale': scale}
     if core.is_compiled_with_npu():
-        check_variable_and_dtype(float_status, "float_status",
-                                 ['float16', 'float32'],
-                                 'check_finite_and_unscale')
+        check_variable_and_dtype(
+            float_status,
+            "float_status",
+            ['float16', 'float32'],
+            'check_finite_and_unscale',
+        )
         inputs['FloatStatus'] = float_status

     outputs = {'Out': x, 'FoundInfinite': found_inf}
-    helper.append_op(type='check_finite_and_unscale',
-                     inputs=inputs,
-                     outputs=outputs)
+    helper.append_op(
+        type='check_finite_and_unscale', inputs=inputs, outputs=outputs
+    )
     return x, found_inf


-def update_loss_scaling(x,
-                        found_inf,
-                        prev_loss_scaling,
-                        num_good_steps,
-                        num_bad_steps,
-                        incr_every_n_steps,
-                        decr_every_n_nan_or_inf,
-                        incr_ratio,
-                        decr_ratio,
-                        stop_update=False,
-                        name=None):
+def update_loss_scaling(
+    x,
+    found_inf,
+    prev_loss_scaling,
+    num_good_steps,
+    num_bad_steps,
+    incr_every_n_steps,
+    decr_every_n_nan_or_inf,
+    incr_ratio,
+    decr_ratio,
+    stop_update=False,
+    name=None,
+):
     """
     Update loss scaling according to overall gradients. If all gradients is
     finite after incr_every_n_steps, loss scaling will increase by incr_ratio.
     Otherwise, loss scaling will decrease by decr_ratio after
     decr_every_n_nan_or_inf steps and each step some gradients are infinite.

     Args:
         x(list|tuple): The input tensors of update_loss_scaling operator.
         found_inf (Variable): A boolean variable indicates whether
                               there is any infinite gradient.
         prev_loss_scaling (Variable): Previous loss scaling.
         num_good_steps (Variable): A variable accumulates good steps in which
                                    all gradients are finite.
         num_bad_steps (Variable): A variable accumulates bad steps in which
                                   some gradients are infinite.
         incr_every_n_steps (int): A variable represents increasing loss
                                   scaling every n consecutive steps with
                                   finite gradients.
         decr_every_n_nan_or_inf (int): A variable represents decreasing
                                        loss scaling every n accumulated
                                        steps with nan or inf gradients.
         incr_ratio(float): The multiplier to use when increasing the loss
                            scaling.
         decr_ratio(float): The less-than-one-multiplier to use when decreasing
                            loss scaling.
     """
-    check_variable_and_dtype(prev_loss_scaling, "prev_loss_scaling",
-                             ['float32', 'float64'], "update_loss_scaling")
+    check_variable_and_dtype(
+        prev_loss_scaling,
+        "prev_loss_scaling",
+        ['float32', 'float64'],
+        "update_loss_scaling",
+    )
     check_type(x, 'x', (tuple, list), 'update_loss_scaling')
     for e in x:
-        check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'],
-                                 'update_loss_scaling')
-        if e.dtype == core.VarDesc.VarType.FP16:
-            assert prev_loss_scaling.dtype == core.VarDesc.VarType.FP32, \
-                "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
+        check_variable_and_dtype(
+            e,
+            "x",
+            ['float16', 'float32', 'float64', 'uint16'],
+            'update_loss_scaling',
+        )
+        if (
+            e.dtype == core.VarDesc.VarType.FP16
+            or e.dtype == core.VarDesc.VarType.BF16
+        ):
+            assert (
+                prev_loss_scaling.dtype == core.VarDesc.VarType.FP32
+            ), "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
         else:
-            assert prev_loss_scaling.dtype == e.dtype, "The dtype of prev_loss_scaling should be equal to the dtype of x."
+            assert (
+                prev_loss_scaling.dtype == e.dtype
+            ), "The dtype of prev_loss_scaling should be equal to the dtype of x."

     helper = LayerHelper("update_loss_scaling", **locals())
@@ -115,14 +138,14 @@ def update_loss_scaling(x,
         'FoundInfinite': found_inf,
         'PrevLossScaling': prev_loss_scaling,
         'InGoodSteps': num_good_steps,
-        'InBadSteps': num_bad_steps
+        'InBadSteps': num_bad_steps,
     }

     outputs = {
         'Out': x,
         'LossScaling': prev_loss_scaling,
         'OutGoodSteps': num_good_steps,
-        'OutBadSteps': num_bad_steps
+        'OutBadSteps': num_bad_steps,
     }

     attrs = {
@@ -137,9 +160,8 @@ def update_loss_scaling(x,
     else:
         attrs['stop_update'] = stop_update

-    helper.append_op(type='update_loss_scaling',
-                     inputs=inputs,
-                     outputs=outputs,
-                     attrs=attrs)
+    helper.append_op(
+        type='update_loss_scaling', inputs=inputs, outputs=outputs, attrs=attrs
+    )
     return x
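As an aside, the semantics that the two docstrings above describe (unscale the gradients and check them, then adjust the loss scaling) can be sketched in plain Python/NumPy. This is only an illustration of the rules stated in the docstrings, not the operator kernels, and every name below is made up for the example:

import numpy as np


def check_finite_and_unscale_ref(xs, scale):
    # Out = X / scale, applied only when every gradient is finite; otherwise
    # FoundInfinite is reported and the outputs are left untouched.
    found_inf = any(not np.all(np.isfinite(x)) for x in xs)
    if not found_inf:
        xs = [x / scale for x in xs]
    return xs, found_inf


def update_loss_scaling_ref(found_inf, loss_scaling, num_good_steps,
                            num_bad_steps, incr_every_n_steps,
                            decr_every_n_nan_or_inf, incr_ratio, decr_ratio):
    if found_inf:
        # A step with Inf/NaN gradients resets the good-step counter; enough
        # accumulated bad steps shrink the loss scaling by decr_ratio.
        num_good_steps, num_bad_steps = 0, num_bad_steps + 1
        if num_bad_steps == decr_every_n_nan_or_inf:
            loss_scaling *= decr_ratio
            num_bad_steps = 0
    else:
        # A finite step resets the bad-step counter; enough consecutive good
        # steps grow the loss scaling by incr_ratio.
        num_good_steps, num_bad_steps = num_good_steps + 1, 0
        if num_good_steps == incr_every_n_steps:
            loss_scaling *= incr_ratio
            num_good_steps = 0
    return loss_scaling, num_good_steps, num_bad_steps


grads = [np.array([2.0, 4.0]), np.array([8.0, np.inf])]
_, found_inf = check_finite_and_unscale_ref(grads, 128.0)
print(found_inf)  # True: the Inf gradient means this step's outputs are unusable
print(update_loss_scaling_ref(found_inf, 1024.0, 7, 1, 1000, 2, 2.0, 0.5))
# (512.0, 0, 0): the second bad step in a row halves the loss scaling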
@@ -13,16 +13,47 @@
 # limitations under the License.

 import copy
 from ... import core

 __all__ = ["CustomOpLists", "AutoMixedPrecisionLists"]

 # lookup_table fp16 is slower than fp32, though fp16 is supported.
-_extra_unsupported_fp16_list = {
-    'lookup_table', 'lookup_table_v2', 'scatter', 'scatter_grad'
+_extra_unsupported_list = {
+    'lookup_table',
+    'lookup_table_v2',
+    'scatter',
+    'scatter_grad',
 }


+def _get_unsupported_list(dtype):
+    if dtype == "float16":
+        amp_dtype = core.VarDesc.VarType.FP16
+    elif dtype == "bfloat16":
+        amp_dtype = core.VarDesc.VarType.BF16
+    else:
+        raise ValueError(
+            "If enable AMP, dtype should be 'float16' or 'bfloat16'."
+        )
+
+    # The set of ops that don't support fp16 calculation
+    # lookup_table fp16 is slower than fp32, though fp16 is supported.
+    _sys_unsupported_list = []
+    # _sys_unsupported_bf16_list = []
+    if core.is_compiled_with_xpu():
+        _, _, _sys_unsupported_list = core.op_supported_infos('XPU', amp_dtype)
+    elif core.is_compiled_with_npu():
+        _, _, _sys_unsupported_list = core.op_supported_infos('NPU', amp_dtype)
+    elif core.is_compiled_with_mlu():
+        _, _, _sys_unsupported_list = core.op_supported_infos('MLU', amp_dtype)
+    else:
+        _, _, _sys_unsupported_list = core.op_supported_infos('GPU', amp_dtype)
+
+    unsupported_list = _extra_unsupported_list | _sys_unsupported_list
+    return unsupported_list
+
+
 class AutoMixedPrecisionLists(object):
     """
     AutoMixedPrecisionLists is a class for black/white list. It can update
@@ -36,16 +67,20 @@ class AutoMixedPrecisionLists(object):
         custom_black_varnames (set): Users' custom black varibles' names.
     """

-    def __init__(self,
-                 custom_white_list=None,
-                 custom_black_list=None,
-                 custom_black_varnames=None):
+    def __init__(
+        self,
+        custom_white_list=None,
+        custom_black_list=None,
+        custom_black_varnames=None,
+        dtype="float16",
+    ):
         self._custom_white_list = custom_white_list
         self._custom_black_list = custom_black_list
+        self.amp_dtype = dtype
         self.white_list = copy.copy(white_list)
         self.black_list = copy.copy(black_list)
         self.gray_list = copy.copy(gray_list)
-        self.unsupported_list = copy.copy(unsupported_fp16_list)
+        self.unsupported_list = copy.copy(_get_unsupported_list(self.amp_dtype))
         self.black_varnames = copy.copy(custom_black_varnames)
         self._update_list()
@@ -56,8 +91,9 @@ class AutoMixedPrecisionLists(object):
         if self._custom_white_list and self._custom_black_list:
             for op_name in self._custom_white_list:
                 if op_name in self._custom_black_list:
-                    raise ValueError("Custom white list overlap "
-                                     "custom black list")
+                    raise ValueError(
+                        "Custom white list overlap " "custom black list"
+                    )
         if self._custom_white_list:
             for op_name in self._custom_white_list:
                 if op_name in self.black_list:
@@ -65,7 +101,7 @@ class AutoMixedPrecisionLists(object):
                 elif op_name in self.gray_list:
                     self.gray_list.remove(op_name)
                 self.white_list.add(op_name)
-                if op_name in _extra_unsupported_fp16_list:
+                if op_name in _extra_unsupported_list:
                     self.unsupported_list.remove(op_name)
         if self._custom_black_list:
             for op_name in self._custom_black_list:
@@ -170,22 +206,4 @@ gray_list = {
     'fused_multi_transformer',
 }

-# The set of ops that don't support fp16 calculation
-# lookup_table fp16 is slower than fp32, though fp16 is supported.
-_sys_unsupported_fp16_list = []
-if core.is_compiled_with_xpu():
-    _, _, _sys_unsupported_fp16_list = core.op_supported_infos(
-        'XPU', core.VarDesc.VarType.FP16)
-elif core.is_compiled_with_npu():
-    _, _, _sys_unsupported_fp16_list = core.op_supported_infos(
-        'NPU', core.VarDesc.VarType.FP16)
-elif core.is_compiled_with_mlu():
-    _, _, _sys_unsupported_fp16_list = core.op_supported_infos(
-        'MLU', core.VarDesc.VarType.FP16)
-else:
-    _, _, _sys_unsupported_fp16_list = core.op_supported_infos(
-        'GPU', core.VarDesc.VarType.FP16)
-
-unsupported_fp16_list = _extra_unsupported_fp16_list | _sys_unsupported_fp16_list
-
 CustomOpLists = AutoMixedPrecisionLists
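With the lists unified, choosing bf16 versus fp16 comes down to the new dtype argument. A hedged usage sketch follows; the import path is assumed to be the fluid contrib AMP package of release/2.4 and may differ in other branches:

from paddle.fluid.contrib.mixed_precision.fp16_lists import (
    AutoMixedPrecisionLists,
)

# fp16 lists, the previous default behavior.
fp16_lists = AutoMixedPrecisionLists(custom_white_list={'lookup_table'})

# bf16 lists now reuse the same class; dtype selects which unsupported-op set
# _get_unsupported_list() queries via core.op_supported_infos.
bf16_lists = AutoMixedPrecisionLists(dtype="bfloat16")

# 'lookup_table' was whitelisted above, so it is removed from the fp16
# unsupported set but remains in the bf16 one.
print('lookup_table' in fp16_lists.unsupported_list)  # False
print('lookup_table' in bf16_lists.unsupported_list)  # True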