Unverified · Commit 6959eae5, authored by Yiqun Liu, committed by GitHub

Unify the static amp codes of fp16 and bf16. Reimplement #52694 in release/2.4. (#52697)

Parent d1e8b1e2
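For orientation, the sketch below shows how the unified interface is intended to be used from user code. It is a minimal, hedged example: the dtype argument on AutoMixedPrecisionLists comes from the diff below, while the paddle.static.amp import path, the decorate() call, and its amp_lists parameter are assumed from the usual static AMP entry points and may differ slightly in release/2.4.

# Minimal usage sketch (assumed entry points, not part of this diff):
# one AutoMixedPrecisionLists class now serves both float16 and bfloat16.
import paddle

paddle.enable_static()

# Build per-dtype black/white lists; the unsupported-op set is derived from
# the requested dtype instead of a hard-coded fp16-only module-level list.
amp_lists = paddle.static.amp.AutoMixedPrecisionLists(dtype="bfloat16")

# Wrapping the optimizer is assumed to work the same way for both dtypes.
optimizer = paddle.optimizer.SGD(learning_rate=0.01)
optimizer = paddle.static.amp.decorate(optimizer, amp_lists=amp_lists)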
@@ -15,7 +15,7 @@
 from __future__ import print_function
 from . import decorator
-from .decorator import *
+from .decorator import decorate, amp_decorate
 from . import fp16_lists
 from .fp16_lists import *
 from . import fp16_utils
@@ -38,27 +38,35 @@ def check_finite_and_unscale(x, scale, name=None, float_status=None):
     """
     check_type(x, 'x', (tuple, list), 'check_finite_and_unscale')
     for e in x:
-        check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'],
-                                 'check_finite_and_unscale')
+        check_variable_and_dtype(
+            e,
+            "x",
+            ['float16', 'float32', 'float64', 'uint16'],
+            'check_finite_and_unscale',
+        )
 
     helper = LayerHelper("check_finite_and_unscale", **locals())
     found_inf = helper.create_variable_for_type_inference(dtype='bool')
 
     inputs = {'X': x, 'Scale': scale}
     if core.is_compiled_with_npu():
-        check_variable_and_dtype(float_status, "float_status",
-                                 ['float16', 'float32'],
-                                 'check_finite_and_unscale')
+        check_variable_and_dtype(
+            float_status,
+            "float_status",
+            ['float16', 'float32'],
+            'check_finite_and_unscale',
+        )
         inputs['FloatStatus'] = float_status
     outputs = {'Out': x, 'FoundInfinite': found_inf}
-    helper.append_op(type='check_finite_and_unscale',
-                     inputs=inputs,
-                     outputs=outputs)
+    helper.append_op(
+        type='check_finite_and_unscale', inputs=inputs, outputs=outputs
+    )
     return x, found_inf
 
 
-def update_loss_scaling(x,
-                        found_inf,
-                        prev_loss_scaling,
-                        num_good_steps,
+def update_loss_scaling(
+    x,
+    found_inf,
+    prev_loss_scaling,
+    num_good_steps,
@@ -68,7 +76,8 @@ def update_loss_scaling(x,
-                        incr_ratio,
-                        decr_ratio,
-                        stop_update=False,
-                        name=None):
+    incr_ratio,
+    decr_ratio,
+    stop_update=False,
+    name=None,
+):
     """
     Update loss scaling according to overall gradients. If all gradients is
     finite after incr_every_n_steps, loss scaling will increase by incr_ratio.
@@ -96,17 +105,31 @@ def update_loss_scaling(x,
         loss scaling.
     """
-    check_variable_and_dtype(prev_loss_scaling, "prev_loss_scaling",
-                             ['float32', 'float64'], "update_loss_scaling")
+    check_variable_and_dtype(
+        prev_loss_scaling,
+        "prev_loss_scaling",
+        ['float32', 'float64'],
+        "update_loss_scaling",
+    )
     check_type(x, 'x', (tuple, list), 'update_loss_scaling')
     for e in x:
-        check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'],
-                                 'update_loss_scaling')
-        if e.dtype == core.VarDesc.VarType.FP16:
-            assert prev_loss_scaling.dtype == core.VarDesc.VarType.FP32, \
-                "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
+        check_variable_and_dtype(
+            e,
+            "x",
+            ['float16', 'float32', 'float64', 'uint16'],
+            'update_loss_scaling',
+        )
+        if (
+            e.dtype == core.VarDesc.VarType.FP16
+            or e.dtype == core.VarDesc.VarType.BF16
+        ):
+            assert (
+                prev_loss_scaling.dtype == core.VarDesc.VarType.FP32
+            ), "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
         else:
-            assert prev_loss_scaling.dtype == e.dtype, "The dtype of prev_loss_scaling should be equal to the dtype of x."
+            assert (
+                prev_loss_scaling.dtype == e.dtype
+            ), "The dtype of prev_loss_scaling should be equal to the dtype of x."
 
     helper = LayerHelper("update_loss_scaling", **locals())
@@ -115,14 +138,14 @@ def update_loss_scaling(x,
         'FoundInfinite': found_inf,
         'PrevLossScaling': prev_loss_scaling,
         'InGoodSteps': num_good_steps,
-        'InBadSteps': num_bad_steps
+        'InBadSteps': num_bad_steps,
     }
 
     outputs = {
         'Out': x,
         'LossScaling': prev_loss_scaling,
         'OutGoodSteps': num_good_steps,
-        'OutBadSteps': num_bad_steps
+        'OutBadSteps': num_bad_steps,
     }
 
     attrs = {
@@ -137,9 +160,8 @@ def update_loss_scaling(x,
     else:
         attrs['stop_update'] = stop_update
 
-    helper.append_op(type='update_loss_scaling',
-                     inputs=inputs,
-                     outputs=outputs,
-                     attrs=attrs)
+    helper.append_op(
+        type='update_loss_scaling', inputs=inputs, outputs=outputs, attrs=attrs
+    )
 
     return x
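To make the semantics of the two ops above concrete, here is a plain-Python sketch of the dynamic loss scaling rule described in the update_loss_scaling docstring: the scale grows by incr_ratio after incr_every_n_steps consecutive finite steps and shrinks by decr_ratio after decr_every_n_nan_or_inf overflow steps. It only illustrates the update rule, not the actual kernel; the clamp of the decreased scale to 1.0 is an assumption about the kernel's behaviour.

# Reference sketch of the dynamic loss scaling update (illustrative only,
# not the actual C++ kernel). Names mirror the op attributes above.
def update_loss_scaling_ref(
    found_inf,
    loss_scaling,
    num_good_steps,
    num_bad_steps,
    incr_every_n_steps,
    decr_every_n_nan_or_inf,
    incr_ratio,
    decr_ratio,
):
    if found_inf:
        # An overflow step: reset the good-step counter and, after enough
        # consecutive bad steps, shrink the scale (clamped to 1.0 here,
        # which is an assumption about the kernel).
        num_good_steps, num_bad_steps = 0, num_bad_steps + 1
        if num_bad_steps == decr_every_n_nan_or_inf:
            loss_scaling = max(loss_scaling * decr_ratio, 1.0)
            num_bad_steps = 0
    else:
        # A finite step: after incr_every_n_steps good steps in a row,
        # grow the scale by incr_ratio.
        num_bad_steps, num_good_steps = 0, num_good_steps + 1
        if num_good_steps == incr_every_n_steps:
            loss_scaling = loss_scaling * incr_ratio
            num_good_steps = 0
    return loss_scaling, num_good_steps, num_bad_steps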
@@ -13,16 +13,47 @@
 # limitations under the License.
 
 import copy
 from ... import core
 
 __all__ = ["CustomOpLists", "AutoMixedPrecisionLists"]
 
 # lookup_table fp16 is slower than fp32, though fp16 is supported.
-_extra_unsupported_fp16_list = {
-    'lookup_table', 'lookup_table_v2', 'scatter', 'scatter_grad'
+_extra_unsupported_list = {
+    'lookup_table',
+    'lookup_table_v2',
+    'scatter',
+    'scatter_grad',
 }
 
+
+def _get_unsupported_list(dtype):
+    if dtype == "float16":
+        amp_dtype = core.VarDesc.VarType.FP16
+    elif dtype == "bfloat16":
+        amp_dtype = core.VarDesc.VarType.BF16
+    else:
+        raise ValueError(
+            "If enable AMP, dtype should be 'float16' or 'bfloat16'."
+        )
+
+    # The set of ops that don't support fp16 calculation
+    # lookup_table fp16 is slower than fp32, though fp16 is supported.
+    _sys_unsupported_list = []
+    # _sys_unsupported_bf16_list = []
+    if core.is_compiled_with_xpu():
+        _, _, _sys_unsupported_list = core.op_supported_infos('XPU', amp_dtype)
+    elif core.is_compiled_with_npu():
+        _, _, _sys_unsupported_list = core.op_supported_infos('NPU', amp_dtype)
+    elif core.is_compiled_with_mlu():
+        _, _, _sys_unsupported_list = core.op_supported_infos('MLU', amp_dtype)
+    else:
+        _, _, _sys_unsupported_list = core.op_supported_infos('GPU', amp_dtype)
+
+    unsupported_list = _extra_unsupported_list | _sys_unsupported_list
+    return unsupported_list
+
+
 class AutoMixedPrecisionLists(object):
     """
     AutoMixedPrecisionLists is a class for black/white list. It can update
@@ -36,16 +67,20 @@ class AutoMixedPrecisionLists(object):
         custom_black_varnames (set): Users' custom black varibles' names.
     """
 
-    def __init__(self,
-                 custom_white_list=None,
-                 custom_black_list=None,
-                 custom_black_varnames=None):
+    def __init__(
+        self,
+        custom_white_list=None,
+        custom_black_list=None,
+        custom_black_varnames=None,
+        dtype="float16",
+    ):
         self._custom_white_list = custom_white_list
         self._custom_black_list = custom_black_list
+        self.amp_dtype = dtype
         self.white_list = copy.copy(white_list)
         self.black_list = copy.copy(black_list)
         self.gray_list = copy.copy(gray_list)
-        self.unsupported_list = copy.copy(unsupported_fp16_list)
+        self.unsupported_list = copy.copy(_get_unsupported_list(self.amp_dtype))
         self.black_varnames = copy.copy(custom_black_varnames)
         self._update_list()
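A short sketch of what the new constructor argument changes in practice. The direct import of the private helper assumes the release/2.4 module path (python/paddle/fluid/contrib/mixed_precision/fp16_lists.py, inferred from the relative "from ... import core" above) and is shown only for illustration.

# Illustrative only: the unsupported-op set is now computed per dtype
# through _get_unsupported_list instead of the removed module-level
# unsupported_fp16_list (module path assumed for release/2.4).
from paddle.fluid.contrib.mixed_precision.fp16_lists import (
    AutoMixedPrecisionLists,
    _get_unsupported_list,
)

fp16_lists = AutoMixedPrecisionLists(dtype="float16")
bf16_lists = AutoMixedPrecisionLists(dtype="bfloat16")

# Each instance holds the set computed for its own dtype.
assert fp16_lists.unsupported_list == _get_unsupported_list("float16")
assert bf16_lists.unsupported_list == _get_unsupported_list("bfloat16")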
@@ -56,8 +91,9 @@ class AutoMixedPrecisionLists(object):
         if self._custom_white_list and self._custom_black_list:
             for op_name in self._custom_white_list:
                 if op_name in self._custom_black_list:
-                    raise ValueError("Custom white list overlap "
-                                     "custom black list")
+                    raise ValueError(
+                        "Custom white list overlap " "custom black list"
+                    )
         if self._custom_white_list:
             for op_name in self._custom_white_list:
                 if op_name in self.black_list:
@@ -65,7 +101,7 @@ class AutoMixedPrecisionLists(object):
                 elif op_name in self.gray_list:
                     self.gray_list.remove(op_name)
                 self.white_list.add(op_name)
-                if op_name in _extra_unsupported_fp16_list:
+                if op_name in _extra_unsupported_list:
                     self.unsupported_list.remove(op_name)
         if self._custom_black_list:
             for op_name in self._custom_black_list:
@@ -170,22 +206,4 @@ gray_list = {
     'fused_multi_transformer',
 }
 
-# The set of ops that don't support fp16 calculation
-# lookup_table fp16 is slower than fp32, though fp16 is supported.
-_sys_unsupported_fp16_list = []
-if core.is_compiled_with_xpu():
-    _, _, _sys_unsupported_fp16_list = core.op_supported_infos(
-        'XPU', core.VarDesc.VarType.FP16)
-elif core.is_compiled_with_npu():
-    _, _, _sys_unsupported_fp16_list = core.op_supported_infos(
-        'NPU', core.VarDesc.VarType.FP16)
-elif core.is_compiled_with_mlu():
-    _, _, _sys_unsupported_fp16_list = core.op_supported_infos(
-        'MLU', core.VarDesc.VarType.FP16)
-else:
-    _, _, _sys_unsupported_fp16_list = core.op_supported_infos(
-        'GPU', core.VarDesc.VarType.FP16)
-
-unsupported_fp16_list = _extra_unsupported_fp16_list | _sys_unsupported_fp16_list
-
 CustomOpLists = AutoMixedPrecisionLists
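Finally, a hedged sketch of how the custom white list interacts with the per-dtype unsupported set, following the _update_list logic shown earlier in this diff. CustomOpLists is the public alias kept at the end of the file; the module path is again assumed for release/2.4.

# Illustrative only: ops from _extra_unsupported_list (e.g. lookup_table)
# are dropped from unsupported_list when the user whitelists them, exactly
# as the _update_list hunk above does.
from paddle.fluid.contrib.mixed_precision.fp16_lists import CustomOpLists

default_lists = CustomOpLists(dtype="float16")
custom_lists = CustomOpLists(
    custom_white_list=['lookup_table'], dtype="float16"
)

assert 'lookup_table' in default_lists.unsupported_list
assert 'lookup_table' not in custom_lists.unsupported_list
assert 'lookup_table' in custom_lists.white_list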