Unverified commit 2bf61284, authored by: Zhang Ting, committed by: GitHub

[AMP] fix static promote (#53439)

Parent 3fd2e765
......@@ -23,11 +23,15 @@ _logger = get_logger(
)
# lookup_table fp16 is slower than fp32, though fp16 is supported.
_extra_unsupported_list = {
_extra_black_list = {
'lookup_table',
'lookup_table_v2',
'scatter',
'scatter_grad',
'linear_interp_v2',
'nearest_interp_v2',
'bilinear_interp_v2',
'bicubic_interp_v2',
'trilinear_interp_v2',
}
......@@ -118,8 +122,7 @@ def _get_sys_unsupported_list(dtype):
def _get_unsupported_list(dtype):
# The set of ops that don't support fp16 calculation
_, _sys_unsupported_list = _get_sys_unsupported_list(dtype)
unsupported_list = _extra_unsupported_list | _sys_unsupported_list
return unsupported_list
return _sys_unsupported_list
# The three sets listed below are changed dynamiclly. They don't contain all
......@@ -145,6 +148,32 @@ def _get_white_list(dtype):
return white_list_for_dtype
# The set of ops that support fp16 calculation and are considered numerically-
# dangerous and whose effects may also be observed in downstream ops.
black_list = {
'exp',
'square',
'log',
'mean',
'sum',
'cos_sim',
'softmax',
'softmax_with_cross_entropy',
'sigmoid_cross_entropy_with_logits',
'c_softmax_with_cross_entropy',
'cross_entropy',
'cross_entropy2',
# defaulting to fp32 avoids returning inf when the sum value is larger than 65504
'reduce_sum',
}
def _get_black_list():
_black_list = copy.copy(black_list)
_black_list = _black_list | _extra_black_list
return _black_list
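As an aside, a minimal self-contained sketch (using abridged stand-ins for the sets above, not the full module) of what this refactor changes: the slow-in-fp16 ops now live in `_extra_black_list` and are merged into the effective black list by `_get_black_list()`, instead of being folded into the unsupported list.

```python
# Illustrative sketch only; the set contents are abridged stand-ins for the
# module-level black_list and _extra_black_list shown in this diff.
import copy

black_list = {'exp', 'softmax', 'reduce_sum'}
_extra_black_list = {'lookup_table', 'lookup_table_v2', 'nearest_interp_v2'}

def _get_black_list():
    # Effective black list = numerically-dangerous ops + slow-in-fp16 ops.
    return copy.copy(black_list) | _extra_black_list

# The lookup/interp ops are blacklisted (kept in fp32) but no longer treated
# as "unsupported", so they no longer block promotion of surrounding ops.
assert 'lookup_table_v2' in _get_black_list()
assert 'exp' in _get_black_list()
```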
class AutoMixedPrecisionLists:
"""
AutoMixedPrecisionLists is a class for black/white list. It can update
......@@ -170,7 +199,7 @@ class AutoMixedPrecisionLists:
self._custom_white_list = custom_white_list
self._custom_black_list = custom_black_list
self.white_list = copy.copy(_get_white_list(self.amp_dtype))
self.black_list = copy.copy(black_list)
self.black_list = copy.copy(_get_black_list())
self.gray_list = copy.copy(gray_list)
self.unsupported_list = copy.copy(_get_unsupported_list(self.amp_dtype))
self.black_varnames = copy.copy(custom_black_varnames)
......@@ -196,8 +225,6 @@ class AutoMixedPrecisionLists:
elif op_name in self.gray_list:
self.gray_list.remove(op_name)
self.white_list.add(op_name)
if op_name in _extra_unsupported_list:
self.unsupported_list.remove(op_name)
if self._custom_black_list:
for op_name in self._custom_black_list:
if op_name in self.white_list:
......@@ -217,33 +244,6 @@ class AutoMixedPrecisionLists:
)
# The set of ops that support fp16 calculation and are considered numerically-
# dangerous and whose effects may also be observed in downstream ops.
black_list = {
'exp',
'square',
'log',
'mean',
'sum',
'cos_sim',
'softmax',
'softmax_with_cross_entropy',
'sigmoid_cross_entropy_with_logits',
'c_softmax_with_cross_entropy',
'cross_entropy',
'cross_entropy2',
# fp16 is slower than fp32, though fp16 is supported.
'lookup_table',
'lookup_table_v2',
'linear_interp_v2',
'nearest_interp_v2',
'bilinear_interp_v2',
'bicubic_interp_v2',
'trilinear_interp_v2',
# default fp32 can avoid return inf when the sum value large than 65504
'reduce_sum',
}
# This set contains two types of ops. All ops support fp16 calculation. One
# of the two types is considered numerically-safe, but may be made unsafe by an
# upstream blacklist op. The other type does not have numerically-significant
......
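For context, a hedged usage sketch (assuming a Paddle build where `paddle.static.amp` is available) of how the custom white list now interacts with these lists: promoting an op only needs to pull it out of the black/gray lists, since the extra ops are no longer injected into `unsupported_list`.

```python
# Hedged usage sketch; requires a Paddle installation with static AMP.
import paddle

amp_lists = paddle.static.amp.AutoMixedPrecisionLists(
    custom_white_list={'lookup_table_v2'}
)

# With this change, 'lookup_table_v2' starts out in the black list (via
# _extra_black_list) and is moved to the white list by the constructor.
assert 'lookup_table_v2' in amp_lists.white_list
assert 'lookup_table_v2' not in amp_lists.black_list
```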
......@@ -425,15 +425,22 @@ def set_var_dst_dtype(
def set_param_dtype(program, dtype, amp_lists, use_fp16_guard, level):
if level == "O1":
return
keep_fp32_var_names = set()
if level == "O1":
return keep_fp32_var_names
all_parameters = []
for block in program.blocks:
all_parameters.extend(block.all_parameters())
ops = block.ops
for op in ops:
if op_need_keep_fp32(op, amp_lists, use_fp16_guard):
# Currently, lookup_table is in black_list and unsupported_list, its weight will be
# set to fp32 in step 1 of cast_model_to_fp16. But the weight may be used as matmul's
# input in transformer, so the weight is also in to_fp16_var_names.
# TODO(zhangting2020): consider fixing auto_parallel_fp16 and removing lookup_table
# from black_list and unsupported_list.
if op.type in ['lookup_table', 'lookup_table_v2']:
continue
if _need_keep_fp32(op, amp_lists.unsupported_list, use_fp16_guard):
for in_name in op.input_names:
keep_fp32_var_names = keep_fp32_var_names.union(
op.input(in_name)
......@@ -451,6 +458,7 @@ def set_param_dtype(program, dtype, amp_lists, use_fp16_guard, level):
if param.name not in keep_fp32_var_names:
_logger.debug(f"-- set param {param.name} to {dtype} --.")
param.desc.set_dtype(dtype)
return keep_fp32_var_names
def op_need_keep_fp32(op, amp_lists, use_fp16_guard):
......@@ -607,15 +615,17 @@ def cast_model_to_fp16(
keep_fp32_ops = set()
keep_fp16_ops = set()
to_fp16_var_names = set()
keep_fp32_var_names = set()
# step 1: set params dtype.
set_param_dtype(
fp32_var_names = set_param_dtype(
program,
dtype=dest_type,
amp_lists=amp_lists,
use_fp16_guard=use_fp16_guard,
level=level,
)
keep_fp32_var_names = keep_fp32_var_names.union(fp32_var_names)
def need_process(op):
need_process = True
......@@ -719,6 +729,8 @@ def cast_model_to_fp16(
idx += num_cast_ops + 1
_logger.debug("---- after cast model to fp16 ----")
_logger.debug(program)
to_fp16_var_names.difference_update(keep_fp32_var_names)
return to_fp16_var_names
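A self-contained toy (made-up variable names, no real Paddle program) showing the bookkeeping these two hunks add: `set_param_dtype` now reports which parameters it kept in fp32, and `cast_model_to_fp16` subtracts those names from `to_fp16_var_names` before returning.

```python
# Toy illustration of the new set bookkeeping; the variable names are
# hypothetical and the real op-by-op casting loop is omitted.
def set_param_dtype_stub():
    # Pretend the embedding weight had to stay fp32.
    return {"word_embedding.w_0"}

to_fp16_var_names = {"word_embedding.w_0", "linear.w_0", "linear.b_0"}
keep_fp32_var_names = set()
keep_fp32_var_names |= set_param_dtype_stub()

# A variable cannot be both "cast to fp16" and "kept in fp32"; the fp32
# decision wins, so callers will not cast the embedding weight back to fp16.
to_fp16_var_names.difference_update(keep_fp32_var_names)
assert to_fp16_var_names == {"linear.w_0", "linear.b_0"}
```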
......
......@@ -50,9 +50,10 @@ class TestAMPList(unittest.TestCase):
self.check_if_op_not_in_list(
self.custom_white_list, amp_list.black_list
)
self.check_if_op_not_in_list(
self.custom_white_list, amp_list.unsupported_list
)
if paddle.amp.is_float16_supported():
self.check_if_op_not_in_list(
self.custom_white_list, amp_list.unsupported_list
)
def test_eager(self):
if not paddle.amp.is_float16_supported():
......
......@@ -48,7 +48,6 @@ class TestAMPPromote(AmpTestBase):
max_iters = 2
x_fp32 = np.random.random(size=[1, 1, 6, 6]).astype("float32")
print(main_program)
losses_o1 = self.run_program(
main_program,
startup_program,
......
......@@ -265,18 +265,20 @@ class TestProgramBF16(AmpTestBase):
amp.debugging.collect_operator_stats(main_program)
op_stats_list = amp.debugging._get_op_stats_list(main_program)
expected_fp32_calls = {"lookup_table_v2": 1}
expected_bf16_calls = {
"matmul_v2": 1,
"elementwise_add": 1,
"dropout": 1,
"lookup_table_v2": 0,
"squared_l2_norm": 2,
"adamw": 2,
"squared_l2_norm": 3,
"adamw": 3,
}
self._check_optimizer(
main_program,
expected_bf16_calls["matmul_v2"]
+ expected_bf16_calls["elementwise_add"],
+ expected_bf16_calls["elementwise_add"]
+ expected_fp32_calls["lookup_table_v2"],
)
self._check_op_calls(op_stats_list[0], expected_bf16_calls)
......
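The bump from 2 to 3 in the expected `squared_l2_norm`/`adamw` counts follows from the embedding weight now staying a trainable fp32 parameter that still passes through gradient clipping and the optimizer. A hypothetical tally (parameter labels are illustrative, not taken from the test network):

```python
# Hypothetical tally behind the updated expectations.
optimized_params = {
    "matmul_v2 weight": 1,        # expected_bf16_calls["matmul_v2"]
    "elementwise_add bias": 1,    # expected_bf16_calls["elementwise_add"]
    "lookup_table_v2 weight": 1,  # expected_fp32_calls["lookup_table_v2"], new here
}
# One squared_l2_norm and one adamw call per optimized parameter.
assert sum(optimized_params.values()) == 3
```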
......@@ -318,7 +318,10 @@ class TestImageClassification(unittest.TestCase):
copy.copy(paddle.static.amp.fp16_lists.white_list)
| paddle.static.amp.fp16_lists._only_supported_fp16_list
)
black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
black_list = copy.copy(
paddle.static.amp.fp16_lists.black_list
| paddle.static.amp.fp16_lists._extra_black_list
)
gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
amp_lists = paddle.static.amp.AutoMixedPrecisionLists()
......@@ -331,7 +334,10 @@ class TestImageClassification(unittest.TestCase):
copy.copy(paddle.static.amp.fp16_lists.white_list)
| paddle.static.amp.fp16_lists._only_supported_fp16_list
)
black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
black_list = copy.copy(
paddle.static.amp.fp16_lists.black_list
| paddle.static.amp.fp16_lists._extra_black_list
)
gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
# 1. w={'exp}, b=None
......@@ -348,7 +354,10 @@ class TestImageClassification(unittest.TestCase):
copy.copy(paddle.static.amp.fp16_lists.white_list)
| paddle.static.amp.fp16_lists._only_supported_fp16_list
)
black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
black_list = copy.copy(
paddle.static.amp.fp16_lists.black_list
| paddle.static.amp.fp16_lists._extra_black_list
)
gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
# 2. w={'tanh'}, b=None
......@@ -365,7 +374,10 @@ class TestImageClassification(unittest.TestCase):
copy.copy(paddle.static.amp.fp16_lists.white_list)
| paddle.static.amp.fp16_lists._only_supported_fp16_list
)
black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
black_list = copy.copy(
paddle.static.amp.fp16_lists.black_list
| paddle.static.amp.fp16_lists._extra_black_list
)
gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
# 3. w={'lstm'}, b=None
......@@ -381,7 +393,10 @@ class TestImageClassification(unittest.TestCase):
copy.copy(paddle.static.amp.fp16_lists.white_list)
| paddle.static.amp.fp16_lists._only_supported_fp16_list
)
black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
black_list = copy.copy(
paddle.static.amp.fp16_lists.black_list
| paddle.static.amp.fp16_lists._extra_black_list
)
gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
# 4. w=None, b={'conv2d'}
......@@ -400,7 +415,10 @@ class TestImageClassification(unittest.TestCase):
copy.copy(paddle.static.amp.fp16_lists.white_list)
| paddle.static.amp.fp16_lists._only_supported_fp16_list
)
black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
black_list = copy.copy(
paddle.static.amp.fp16_lists.black_list
| paddle.static.amp.fp16_lists._extra_black_list
)
gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
# 5. w=None, b={'tanh'}
......@@ -419,7 +437,10 @@ class TestImageClassification(unittest.TestCase):
copy.copy(paddle.static.amp.fp16_lists.white_list)
| paddle.static.amp.fp16_lists._only_supported_fp16_list
)
black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
black_list = copy.copy(
paddle.static.amp.fp16_lists.black_list
| paddle.static.amp.fp16_lists._extra_black_list
)
gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
# 6. w=None, b={'lstm'}
......