Unverified commit 2bf61284 authored by Zhang Ting, committed by GitHub

[AMP] fix static promote (#53439)

Parent 3fd2e765
@@ -23,11 +23,15 @@ _logger = get_logger(
 )
 # lookup_table fp16 is slower than fp32, though fp16 is supported.
-_extra_unsupported_list = {
+_extra_black_list = {
     'lookup_table',
     'lookup_table_v2',
     'scatter',
-    'scatter_grad',
+    'linear_interp_v2',
+    'nearest_interp_v2',
+    'bilinear_interp_v2',
+    'bicubic_interp_v2',
+    'trilinear_interp_v2',
 }
@@ -118,8 +122,7 @@ def _get_sys_unsupported_list(dtype):
 def _get_unsupported_list(dtype):
     # The set of ops that don't support fp16 calculation
     _, _sys_unsupported_list = _get_sys_unsupported_list(dtype)
-    unsupported_list = _extra_unsupported_list | _sys_unsupported_list
-    return unsupported_list
+    return _sys_unsupported_list


 # The three sets listed below are changed dynamically. They don't contain all
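Taken together, the two hunks above change how these ops are treated: they are no longer marked as *unsupported* (no usable fp16 kernel) but merely *black-listed* (fp16 works but fp32 is the default). A minimal sketch of that distinction, using plain Python sets rather than Paddle's real dispatch logic:

```python
# Minimal sketch (plain sets, not Paddle's actual dispatch). An op on the
# unsupported list can never run in fp16, while a black-listed op defaults
# to fp32 but can still be promoted by the user.
unsupported_list = {"some_fp32_only_op"}  # hypothetical op name
black_list = {"lookup_table", "lookup_table_v2", "scatter"}
white_list = set()  # ops the user explicitly promotes to fp16


def chosen_dtype(op_type):
    if op_type in unsupported_list:
        return "fp32"  # hard constraint: no fp16 kernel available
    if op_type in white_list:
        return "fp16"  # user override wins over the black list
    if op_type in black_list:
        return "fp32"  # soft default: slower or numerically risky in fp16
    return "fp16"


assert chosen_dtype("scatter") == "fp32"
white_list.add("lookup_table_v2")  # opting back in is now possible
assert chosen_dtype("lookup_table_v2") == "fp16"
```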
@@ -145,6 +148,32 @@ def _get_white_list(dtype):
     return white_list_for_dtype


+# The set of ops that support fp16 calculation and are considered numerically-
+# dangerous and whose effects may also be observed in downstream ops.
+black_list = {
+    'exp',
+    'square',
+    'log',
+    'mean',
+    'sum',
+    'cos_sim',
+    'softmax',
+    'softmax_with_cross_entropy',
+    'sigmoid_cross_entropy_with_logits',
+    'c_softmax_with_cross_entropy',
+    'cross_entropy',
+    'cross_entropy2',
+    # defaulting to fp32 avoids returning inf when the sum is larger than 65504
+    'reduce_sum',
+}
+
+
+def _get_black_list():
+    _black_list = copy.copy(black_list)
+    _black_list = _black_list | _extra_black_list
+    return _black_list


 class AutoMixedPrecisionLists:
     """
     AutoMixedPrecisionLists is a class for black/white list. It can update
@@ -170,7 +199,7 @@ class AutoMixedPrecisionLists:
         self._custom_white_list = custom_white_list
         self._custom_black_list = custom_black_list
         self.white_list = copy.copy(_get_white_list(self.amp_dtype))
-        self.black_list = copy.copy(black_list)
+        self.black_list = copy.copy(_get_black_list())
         self.gray_list = copy.copy(gray_list)
         self.unsupported_list = copy.copy(_get_unsupported_list(self.amp_dtype))
         self.black_varnames = copy.copy(custom_black_varnames)
@@ -196,8 +225,6 @@ class AutoMixedPrecisionLists:
                 elif op_name in self.gray_list:
                     self.gray_list.remove(op_name)
                 self.white_list.add(op_name)
-                if op_name in _extra_unsupported_list:
-                    self.unsupported_list.remove(op_name)
         if self._custom_black_list:
             for op_name in self._custom_black_list:
                 if op_name in self.white_list:
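The two deleted lines above were bookkeeping that is no longer needed: promoting an op through custom_white_list used to require removing it from the unsupported list as well, but the extra ops now live only in the black list. A usage sketch (assuming an installed Paddle build; exact list contents vary by device):

```python
import paddle

# custom_white_list moves an op out of the black list; after this change
# no extra unsupported_list cleanup is required.
amp_lists = paddle.static.amp.AutoMixedPrecisionLists(
    custom_white_list={"lookup_table_v2"}
)
assert "lookup_table_v2" in amp_lists.white_list
assert "lookup_table_v2" not in amp_lists.black_list
```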
@@ -217,33 +244,6 @@ class AutoMixedPrecisionLists:
     )


-# The set of ops that support fp16 calculation and are considered numerically-
-# dangerous and whose effects may also be observed in downstream ops.
-black_list = {
-    'exp',
-    'square',
-    'log',
-    'mean',
-    'sum',
-    'cos_sim',
-    'softmax',
-    'softmax_with_cross_entropy',
-    'sigmoid_cross_entropy_with_logits',
-    'c_softmax_with_cross_entropy',
-    'cross_entropy',
-    'cross_entropy2',
-    # fp16 is slower than fp32, though fp16 is supported.
-    'lookup_table',
-    'lookup_table_v2',
-    'linear_interp_v2',
-    'nearest_interp_v2',
-    'bilinear_interp_v2',
-    'bicubic_interp_v2',
-    'trilinear_interp_v2',
-    # default fp32 can avoid return inf when the sum value large than 65504
-    'reduce_sum',
-}
 # This set contains two types of ops. All ops supported fp16 calculation. One
 # of two types is considered numerically-safe, but may be made unsafe by an
 # upstream blacklist op. Another type do not have numerically-significant
......
@@ -425,15 +425,22 @@ def set_var_dst_dtype(
 def set_param_dtype(program, dtype, amp_lists, use_fp16_guard, level):
-    if level == "O1":
-        return
     keep_fp32_var_names = set()
+    if level == "O1":
+        return keep_fp32_var_names
     all_parameters = []
     for block in program.blocks:
         all_parameters.extend(block.all_parameters())
         ops = block.ops
         for op in ops:
-            if op_need_keep_fp32(op, amp_lists, use_fp16_guard):
+            # Currently, lookup_table is in black_list and unsupport_list, so its
+            # weight will be set to fp32 in step 1 of cast_model_to_fp16. But the
+            # weight may be used as matmul's input in transformer, so the weight
+            # is also in to_fp16_var_names.
+            # TODO(zhangting2020): consider fixing auto_parallel_fp16 and removing
+            # lookup_table from black_list and unsupport_list.
+            if op.type in ['lookup_table', 'lookup_table_v2']:
+                continue
+            if _need_keep_fp32(op, amp_lists.unsupported_list, use_fp16_guard):
                 for in_name in op.input_names:
                     keep_fp32_var_names = keep_fp32_var_names.union(
                         op.input(in_name)
@@ -451,6 +458,7 @@ def set_param_dtype(program, dtype, amp_lists, use_fp16_guard, level):
         if param.name not in keep_fp32_var_names:
             _logger.debug(f"-- set param {param.name} to {dtype} --.")
             param.desc.set_dtype(dtype)
+    return keep_fp32_var_names
 def op_need_keep_fp32(op, amp_lists, use_fp16_guard):
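After this hunk, set_param_dtype always returns the set of parameter names that must stay fp32, yielding an empty set on the O1 early exit instead of None, so callers can union the result unconditionally. A reduced sketch of the new contract (`params` and `fp32_param_names` are hypothetical stand-ins for the real block/op traversal):

```python
# Reduced sketch of the new control flow; not Paddle's real traversal.
def set_param_dtype_sketch(params, level, fp32_param_names):
    keep_fp32_var_names = set()
    if level == "O1":
        # O1 keeps parameters in fp32, so nothing is cast; the empty set
        # is still a valid value for callers to union with.
        return keep_fp32_var_names
    for name in params:
        if name in fp32_param_names:
            keep_fp32_var_names.add(name)
    return keep_fp32_var_names


assert set_param_dtype_sketch(["w"], "O1", {"w"}) == set()
assert set_param_dtype_sketch(["w"], "O2", {"w"}) == {"w"}
```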
@@ -607,15 +615,17 @@ def cast_model_to_fp16(
     keep_fp32_ops = set()
     keep_fp16_ops = set()
     to_fp16_var_names = set()
+    keep_fp32_var_names = set()

     # step 1: set params dtype.
-    set_param_dtype(
+    fp32_var_names = set_param_dtype(
         program,
         dtype=dest_type,
         amp_lists=amp_lists,
         use_fp16_guard=use_fp16_guard,
         level=level,
     )
+    keep_fp32_var_names = keep_fp32_var_names.union(fp32_var_names)

     def need_process(op):
         need_process = True
@@ -719,6 +729,8 @@ def cast_model_to_fp16(
             idx += num_cast_ops + 1

     _logger.debug("---- after cast model to fp16 ----")
     _logger.debug(program)
+
+    to_fp16_var_names.difference_update(keep_fp32_var_names)
     return to_fp16_var_names
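This added line is the heart of the promote fix: any name that step 1 pinned to fp32 (for example a lookup_table weight that also feeds a matmul) is stripped from to_fp16_var_names, so callers no longer cast it back. In plain set terms (variable names are illustrative):

```python
# Plain-set illustration of the final filtering step.
to_fp16_var_names = {"embedding_w", "fc_w", "fc_b"}
keep_fp32_var_names = {"embedding_w"}  # pinned to fp32 by set_param_dtype

to_fp16_var_names.difference_update(keep_fp32_var_names)
assert to_fp16_var_names == {"fc_w", "fc_b"}
```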
......
@@ -50,9 +50,10 @@ class TestAMPList(unittest.TestCase):
         self.check_if_op_not_in_list(
             self.custom_white_list, amp_list.black_list
         )
-        self.check_if_op_not_in_list(
-            self.custom_white_list, amp_list.unsupported_list
-        )
+        if paddle.amp.is_float16_supported():
+            self.check_if_op_not_in_list(
+                self.custom_white_list, amp_list.black_list
+            )

     def test_eager(self):
         if not paddle.amp.is_float16_supported():
......
@@ -48,7 +48,6 @@ class TestAMPPromote(AmpTestBase):
         max_iters = 2
         x_fp32 = np.random.random(size=[1, 1, 6, 6]).astype("float32")
-        print(main_program)
         losses_o1 = self.run_program(
             main_program,
             startup_program,
......
@@ -265,18 +265,20 @@ class TestProgramBF16(AmpTestBase):
         amp.debugging.collect_operator_stats(main_program)
         op_stats_list = amp.debugging._get_op_stats_list(main_program)
+        expected_fp32_calls = {"lookup_table_v2": 1}
         expected_bf16_calls = {
             "matmul_v2": 1,
             "elementwise_add": 1,
             "dropout": 1,
             "lookup_table_v2": 0,
-            "squared_l2_norm": 2,
-            "adamw": 2,
+            "squared_l2_norm": 3,
+            "adamw": 3,
         }
         self._check_optimizer(
             main_program,
             expected_bf16_calls["matmul_v2"]
-            + expected_bf16_calls["elementwise_add"],
+            + expected_bf16_calls["elementwise_add"]
+            + expected_fp32_calls["lookup_table_v2"],
         )
         self._check_op_calls(op_stats_list[0], expected_bf16_calls)
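The new expectations follow from lookup_table_v2 keeping its weight in fp32: the optimizer now touches three trainable parameters (the matmul weight, the elementwise_add bias, and the fp32 embedding table), so squared_l2_norm and adamw each run three times. A quick check of that arithmetic:

```python
# One optimizer-side op per trainable parameter.
expected_fp32_calls = {"lookup_table_v2": 1}  # embedding kept in fp32
expected_bf16_calls = {"matmul_v2": 1, "elementwise_add": 1}

num_params = (
    expected_bf16_calls["matmul_v2"]
    + expected_bf16_calls["elementwise_add"]
    + expected_fp32_calls["lookup_table_v2"]
)
assert num_params == 3  # matches "squared_l2_norm": 3 and "adamw": 3
```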
......
@@ -318,7 +318,10 @@ class TestImageClassification(unittest.TestCase):
             copy.copy(paddle.static.amp.fp16_lists.white_list)
             | paddle.static.amp.fp16_lists._only_supported_fp16_list
         )
-        black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
+        black_list = copy.copy(
+            paddle.static.amp.fp16_lists.black_list
+            | paddle.static.amp.fp16_lists._extra_black_list
+        )
         gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)

         amp_lists = paddle.static.amp.AutoMixedPrecisionLists()
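The same three-line substitution repeats in every numbered sub-case below: since the module-level black_list no longer carries the extra ops, the reference set the test compares against has to union _extra_black_list in explicitly, mirroring what _get_black_list() assembles. A sketch of the invariant the updated tests rely on (assuming an installed Paddle build):

```python
import paddle

# The expected black list equals the module-level set united with the
# extra ops, which is exactly what _get_black_list() returns.
expected_black = (
    paddle.static.amp.fp16_lists.black_list
    | paddle.static.amp.fp16_lists._extra_black_list
)
amp_lists = paddle.static.amp.AutoMixedPrecisionLists()
assert amp_lists.black_list == expected_black
```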
@@ -331,7 +334,10 @@ class TestImageClassification(unittest.TestCase):
             copy.copy(paddle.static.amp.fp16_lists.white_list)
             | paddle.static.amp.fp16_lists._only_supported_fp16_list
         )
-        black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
+        black_list = copy.copy(
+            paddle.static.amp.fp16_lists.black_list
+            | paddle.static.amp.fp16_lists._extra_black_list
+        )
         gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)

         # 1. w={'exp'}, b=None
@@ -348,7 +354,10 @@ class TestImageClassification(unittest.TestCase):
             copy.copy(paddle.static.amp.fp16_lists.white_list)
             | paddle.static.amp.fp16_lists._only_supported_fp16_list
         )
-        black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
+        black_list = copy.copy(
+            paddle.static.amp.fp16_lists.black_list
+            | paddle.static.amp.fp16_lists._extra_black_list
+        )
         gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)

         # 2. w={'tanh'}, b=None
@@ -365,7 +374,10 @@ class TestImageClassification(unittest.TestCase):
             copy.copy(paddle.static.amp.fp16_lists.white_list)
             | paddle.static.amp.fp16_lists._only_supported_fp16_list
         )
-        black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
+        black_list = copy.copy(
+            paddle.static.amp.fp16_lists.black_list
+            | paddle.static.amp.fp16_lists._extra_black_list
+        )
         gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)

         # 3. w={'lstm'}, b=None
@@ -381,7 +393,10 @@ class TestImageClassification(unittest.TestCase):
             copy.copy(paddle.static.amp.fp16_lists.white_list)
             | paddle.static.amp.fp16_lists._only_supported_fp16_list
         )
-        black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
+        black_list = copy.copy(
+            paddle.static.amp.fp16_lists.black_list
+            | paddle.static.amp.fp16_lists._extra_black_list
+        )
         gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)

         # 4. w=None, b={'conv2d'}
@@ -400,7 +415,10 @@ class TestImageClassification(unittest.TestCase):
             copy.copy(paddle.static.amp.fp16_lists.white_list)
             | paddle.static.amp.fp16_lists._only_supported_fp16_list
         )
-        black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
+        black_list = copy.copy(
+            paddle.static.amp.fp16_lists.black_list
+            | paddle.static.amp.fp16_lists._extra_black_list
+        )
         gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)

         # 5. w=None, b={'tanh'}
@@ -419,7 +437,10 @@ class TestImageClassification(unittest.TestCase):
             copy.copy(paddle.static.amp.fp16_lists.white_list)
             | paddle.static.amp.fp16_lists._only_supported_fp16_list
         )
-        black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
+        black_list = copy.copy(
+            paddle.static.amp.fp16_lists.black_list
+            | paddle.static.amp.fp16_lists._extra_black_list
+        )
         gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)

         # 6. w=None, b={'lstm'}
......