Unverified commit 70e0e3d5, authored by lidanqing, committed by GitHub

[cherry-pick] Mechanism that converts startup_program initializers to BF16 (#32720) (#32764)

* Add casting initializers for bf16 training

* Changes after review

* Correct test and add comment
Co-authored-by: joanna.wozna.intel <joanna.wozna@intel.com>
Parent commit: 5fdd85ba
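For context, a minimal end-to-end sketch of the mechanism from user code. The tiny network, the variable names, and the `paddle.fluid.contrib.mixed_precision` import path are illustrative assumptions; the decorate/minimize calls mirror the fit_a_line and word2vec changes at the bottom of this diff:

import paddle
import paddle.fluid as fluid
from paddle.fluid.contrib import mixed_precision as amp

paddle.enable_static()

# Assumed toy network: one FC layer whose parameters are initialized in the
# default startup program (weight -> uniform_random, bias -> fill_constant).
x = fluid.data(name='x', shape=[None, 3], dtype='float32')
y = fluid.layers.fc(input=x, size=1)
loss = fluid.layers.reduce_mean(y)

sgd = amp.bf16.decorate_bf16(
    fluid.optimizer.SGD(learning_rate=0.001),
    amp_lists=amp.bf16.AutoMixedPrecisionListsBF16(),
    use_bf16_guard=False,
    use_pure_bf16=True)

# Passing the startup program is what triggers the new initializer casting:
# cast_model_to_bf16() forwards it to cast_initializers_to_bf16().
sgd.minimize(loss, startup_program=fluid.default_startup_program())

# Inspect which startup ops now carry a BF16 dtype attribute.
for op in fluid.default_startup_program().global_block().ops:
    print(op.type, op.attr('dtype') if op.has_attr('dtype') else '-')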
@@ -49,6 +49,7 @@ class AutoMixedPrecisionListsBF16(object):
         self.bf16_list = copy.copy(bf16_list)
         self.fp32_list = copy.copy(fp32_list)
         self.gray_list = copy.copy(gray_list)
+        self.bf16_initializer_list = copy.copy(bf16_initializer_list)
         self.unsupported_list = copy.copy(unsupported_list)
         self.fp32_varnames = copy.copy(custom_fp32_varnames)
         self._update_list()
@@ -79,6 +80,8 @@ class AutoMixedPrecisionListsBF16(object):
                 self.unsupported_list.add(op_name)
 
+
+bf16_initializer_list = {'fill_constant', 'uniform_random'}
 # always bf16
 bf16_list = {'elementwise_add', }
......
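From the caller's side, the new attribute is populated on every list object; a small sketch (the custom_fp32_list value is only illustrative):

from paddle.fluid.contrib import mixed_precision as amp

lists = amp.bf16.AutoMixedPrecisionListsBF16(
    custom_fp32_list={'elementwise_mul'})

# Default contents of the new module-level set copied in __init__ above.
print(lists.bf16_initializer_list)   # {'fill_constant', 'uniform_random'}
print('elementwise_add' in lists.bf16_list)   # True
print('elementwise_mul' in lists.fp32_list)   # True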
@@ -232,7 +232,52 @@ def bf16_guard():
     yield
 
-def cast_model_to_bf16(program, amp_lists=None, use_bf16_guard=True):
+def are_post_ops_bf16(post_ops, keep_fp32_ops):
+    for post_op in post_ops:
+        for op in post_op:
+            if op.type in keep_fp32_ops:
+                return False
+    return True
+
+
+def cast_initializers_to_bf16(startup_prog,
+                              amp_lists,
+                              block,
+                              all_ops,
+                              keep_fp32_ops,
+                              to_bf16_var_names=None):
+    prepend_ops = startup_prog.global_block().ops
+    for op in prepend_ops:
+        if str(op.type) in amp_lists.bf16_initializer_list:
+            change_op = True
+            op_post_ops = []
+            op_out_vars = []
+            for out_name in op.output_names:
+                for out_var_name in op.output(out_name):
+                    out_var = block.var(out_var_name)
+                    post_op = find_true_post_op(all_ops, op, out_var_name, True)
+                    if out_var is None or out_var.type not in _valid_types:
+                        change_op = False
+                        break
+                    op_post_ops.append(post_op)
+                    op_out_vars.append(out_var)
+
+            if change_op and are_post_ops_bf16(op_post_ops, keep_fp32_ops):
+                for out_var in op_out_vars:
+                    if out_var.dtype == core.VarDesc.VarType.FP32:
+                        out_var.desc.set_dtype(core.VarDesc.VarType.BF16)
+                        if to_bf16_var_names is not None and out_var.name in to_bf16_var_names:
+                            to_bf16_var_names.remove(out_var.name)
+                if op.has_attr('dtype') and op.attr(
+                        'dtype') == core.VarDesc.VarType.FP32:
+                    op._set_attr('dtype', core.VarDesc.VarType.BF16)
+
+
+def cast_model_to_bf16(program,
+                       startup_prog=None,
+                       amp_lists=None,
+                       use_bf16_guard=True):
     """
     Traverse all ops in the whole model and set their inputs and outputs
     to the bf16 data type. This function will do some special processing for
@@ -329,6 +374,10 @@ def cast_model_to_bf16(program, amp_lists=None, use_bf16_guard=True):
                 if op.has_attr('mkldnn_data_type'):
                     op._set_attr('mkldnn_data_type', 'bfloat16')
 
+    if startup_prog is not None:
+        cast_initializers_to_bf16(startup_prog, amp_lists, global_block,
+                                  ops, keep_fp32_ops, to_bf16_var_names)
+
     # process ops in keep_fp32_ops
     op_var_rename_map = [
         collections.OrderedDict() for _ in range(len(program.blocks))
......
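A sketch of calling the extended signature directly on a hand-built pair of programs (the toy program is an assumption; the call shape matches test_graph_cast further down in this diff):

import paddle
import paddle.fluid as fluid
from paddle.fluid.contrib import mixed_precision as amp

paddle.enable_static()
main_prog, startup_prog = fluid.Program(), fluid.Program()

with fluid.program_guard(main_prog, startup_prog):
    x = fluid.data(name='x', shape=[None, 3], dtype='float32')
    # create_parameter puts a uniform_random initializer into startup_prog.
    w = fluid.layers.create_parameter(shape=[3], dtype='float32')
    out = fluid.layers.elementwise_add(x, w)

# With startup_prog supplied, cast_initializers_to_bf16() rewrites w's
# initializer to BF16 because its only consumer (elementwise_add) is not
# kept in FP32.
to_bf16_var_names = amp.bf16.cast_model_to_bf16(
    main_prog,
    startup_prog,
    amp.bf16.AutoMixedPrecisionListsBF16(),
    use_bf16_guard=False)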
@@ -94,7 +94,8 @@ class OptimizerWithMixedPrecision(object):
         if self._use_pure_bf16:
             self._to_bf16_var_names = cast_model_to_bf16(
-                self._train_program, self._amp_lists, self._use_bf16_guard)
+                self._train_program, startup_program, self._amp_lists,
+                self._use_bf16_guard)
         else:
             rewrite_program_bf16(self._train_program, self._amp_lists)
@@ -168,10 +169,12 @@ class OptimizerWithMixedPrecision(object):
                                     self._to_bf16_var_names)
         if test_program is not None:
             if self._use_pure_bf16:
-                cast_model_to_bf16(test_program, self._amp_lists,
-                                   self._use_bf16_guard)
+                cast_model_to_bf16(
+                    test_program,
+                    amp_lists=self._amp_lists,
+                    use_bf16_guard=self._use_bf16_guard)
             elif use_bf16_test:
-                rewrite_program_bf16(test_program, self._amp_lists)
+                rewrite_program_bf16(test_program, amp_lists=self._amp_lists)
 
     def apply_gradients(self, params_grads):
         """
......
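The test_program branch above sits in the decorated optimizer's amp_init step. A sketch of the intended call order in pure BF16 mode (the network is illustrative, and amp_init's keyword names are assumed from the use_bf16_test flag visible above and the fp16 decorator it mirrors):

import paddle
import paddle.fluid as fluid
from paddle.fluid.contrib import mixed_precision as amp

paddle.enable_static()

x = fluid.data(name='x', shape=[None, 3], dtype='float32')
y = fluid.layers.fc(input=x, size=1)
loss = fluid.layers.reduce_mean(y)
test_program = fluid.default_main_program().clone(for_test=True)

opt = amp.bf16.decorate_bf16(
    fluid.optimizer.SGD(learning_rate=0.001),
    use_bf16_guard=False,
    use_pure_bf16=True)
opt.minimize(loss, startup_program=fluid.default_startup_program())

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

# In pure BF16 mode amp_init casts the FP32 parameters to BF16 and, through
# the branch above, also applies cast_model_to_bf16() to the test program.
opt.amp_init(place, test_program=test_program, use_bf16_test=True)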
@@ -157,7 +157,8 @@ def _insert_cast_post_op(block, op, idx, src_dtype, dest_dtype, target_name,
         return num_cast_ops
 
     assert target_var.dtype == src_dtype, \
-        "The real dtype({}) is not equal to the src dtype({})".format(_dtype_to_str(target_var.dtype), _dtype_to_str(src_dtype))
+        "The real dtype({}) is not equal to the src dtype({})".format(
+            _dtype_to_str(target_var.dtype), _dtype_to_str(src_dtype))
 
     cast_name = target_var.name + '.cast_' + _dtype_to_str(dest_dtype)
     cast_var = block.vars.get(cast_name)
@@ -209,7 +210,7 @@ def find_true_prev_op(ops, cur_op, var_name):
     return None
 
-def find_true_post_op(ops, cur_op, var_name):
+def find_true_post_op(ops, cur_op, var_name, search_all=False):
     """
     if there are post ops, return them, if there is no post op,
     return None instead.
@@ -217,11 +218,22 @@ def find_true_post_op(ops, cur_op, var_name):
         ops (list): A list of ops.
         cur_op (Operator): Current operator which has var_name variable.
         var_name (string): Variable name.
+        search_all (bool): The type of operator search. Use it when \"cur_op\" is not in the \"ops\" set.
     """
     post_op = []
-    for idx, op in enumerate(ops):
-        if op == cur_op:
-            break
+    if search_all:
+        """
+        \"cur_op\" does not have to be in the list of \"ops\". E.g. \"cur_op\" can come
+        from the startup_prog block and the \"ops\" list from the main_prog block.
+        By setting idx to -1, we'll start looking for post ops from the top of the list.
+        If search_all is False, assume that \"cur_op\" is in the \"ops\" list,
+        so to reduce the search time we can start iterating from the \"cur_op\" idx.
+        """
+        idx = -1
+    else:
+        for idx, op in enumerate(ops):
+            if op == cur_op:
+                break
 
     for i in range(idx + 1, len(ops)):
         op = ops[i]
@@ -270,7 +282,7 @@ def _need_keep_fp32(op, unsupported_op_list, use_fp16_guard):
     if use_fp16_guard:
         if op.has_attr("op_namescope") and \
                 (_fp16_guard_pattern in op.attr("op_namescope")):
             # op in fp16 guard
             return False
         else:
@@ -496,8 +508,8 @@ def rewrite_program(main_prog, amp_lists):
     black_op_set = set()
     for op in ops:
 
         # NOTE(zhiqiu): 'create_py_reader' and 'read' is used in non-iterable DataLoder,
         # we don't need to handle reader op and the input of 'create_py_reader' is not
         # in block, which may result in errors.
         # See GeneratorLoader._init_non_iterable() for details.
         if op.type == 'create_py_reader' or op.type == 'read':
@@ -612,7 +624,7 @@ def update_role_var_grad(main_prog, params_grads):
                 raise ValueError("The cast op {0}'s output should not be"
                                  "used by a non-optimize op, however, it"
                                  "is used by {1}".format(op, post_ops[0]))
-                #add new op in the python and cpp at the same time
+                # add new op in the python and cpp at the same time
                 new_op_desc = block.desc.append_op()
                 new_op_desc.copy_from(op.desc)
                 new_op = framework.Operator(
......
@@ -139,6 +139,29 @@ class AMPTest2(unittest.TestCase):
         res = amp.bf16.amp_utils.find_true_post_op(block.ops, op1, "Y")
         assert (res == [op2])
 
+    def test_find_true_post_op_with_search_all(self):
+        program = fluid.Program()
+        block = program.current_block()
+        startup_block = fluid.default_startup_program().global_block()
+        var1 = block.create_var(name="X", shape=[3], dtype='float32')
+        var2 = block.create_var(name="Y", shape=[3], dtype='float32')
+        inititializer_op = startup_block._prepend_op(
+            type="fill_constant",
+            outputs={"Out": var1},
+            attrs={"shape": var1.shape,
+                   "dtype": var1.dtype,
+                   "value": 1.0})
+        op1 = block.append_op(
+            type="abs", inputs={"X": [var1]}, outputs={"Out": [var2]})
+        result = amp.bf16.amp_utils.find_true_post_op(
+            block.ops, inititializer_op, "X", search_all=False)
+        assert (len(result) == 0)
+        result = amp.bf16.amp_utils.find_true_post_op(
+            block.ops, inititializer_op, "X", search_all=True)
+        assert (result == [op1])
+
 
 if __name__ == '__main__':
     unittest.main()
@@ -53,19 +53,27 @@ class TestModelCastBF16(unittest.TestCase):
             with fluid.program_guard(prog, startup_prog):
                 yield
 
-    def get_static_graph_result(self, feed, fetch_list, amp_fun,
-                                with_lod=False):
+    def get_static_graph_result(self,
+                                feed,
+                                fetch_list,
+                                amp_fun,
+                                with_lod=False,
+                                startup_prog=None):
         exe = fluid.Executor(core.CPUPlace())
-        exe.run(fluid.default_startup_program())
+        exe.run(fluid.default_startup_program()
+                if startup_prog is None else startup_prog)
         prog = fluid.default_main_program()
         if amp_fun is not None:
-            amp_fun(prog)
+            if startup_prog is not None:
+                amp_fun(prog, startup_prog)
+            else:
+                amp_fun(prog)
         return exe.run(prog,
                        feed=feed,
                        fetch_list=fetch_list,
                        return_numpy=(not with_lod))
 
-    def _graph_common(self, _amp_fun):
+    def _graph_common(self, _amp_fun, startup_prog=None):
         size = 3
         n = np.ones([size, size], dtype='float32') * 3.2
         nn = np.ones([size, size], dtype='float32') * -2.7
@@ -122,7 +130,8 @@ class TestModelCastBF16(unittest.TestCase):
             self.get_static_graph_result(
                 feed={'t': n, 'tt': nn},
                 fetch_list=[ret],
-                amp_fun=_amp_fun
+                amp_fun=_amp_fun,
+                startup_prog=startup_prog
             )
         self.assertTrue(
             static_ret_bf16, np.ones(
@@ -132,16 +141,17 @@ class TestModelCastBF16(unittest.TestCase):
         self._graph_common(lambda prog: amp.bf16.rewrite_program_bf16(
             prog,
             amp.bf16.AutoMixedPrecisionListsBF16(
-                custom_fp32_varnames={'elementwise_add_0.tmp_0'}),
+                custom_fp32_varnames={'elementwise_add_0.tmp_0'})
         ))
 
     def test_graph_cast(self):
-        self._graph_common(lambda prog: amp.bf16.cast_model_to_bf16(
+        self._graph_common(lambda prog, startup_prog: amp.bf16.cast_model_to_bf16(
             prog,
+            startup_prog,
             amp.bf16.AutoMixedPrecisionListsBF16(
                 custom_fp32_list={'elementwise_mul'}),
             use_bf16_guard=True
-        ))
+        ), startup_prog=fluid.default_startup_program())
 
 
 if __name__ == '__main__':
......
@@ -231,13 +231,13 @@ def cast(x, dtype):
             out = core.ops.cast(x, 'in_dtype', x.dtype, 'out_dtype', dtype)
             return out
 
-    check_variable_and_dtype(
-        x, 'x',
-        ['bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'uint8'],
-        'cast')
+    check_variable_and_dtype(x, 'x', [
+        'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'uint8',
+        'uint16'
+    ], 'cast')
     check_dtype(dtype, 'dtype', [
         'bool', 'float16', 'float32', 'float64', 'int8', 'int32', 'int64',
-        'uint8'
+        'uint8', 'uint16'
     ], 'cast')
 
     helper = LayerHelper('cast', **locals())
......
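The two dtype whitelists above now accept 'uint16', which Paddle uses as the Python-level dtype alias for BF16 tensors, so cast no longer rejects BF16 endpoints. A minimal static-graph sketch (names are illustrative; only graph construction is shown):

import paddle
import paddle.fluid as fluid

paddle.enable_static()

x = fluid.data(name='x', shape=[None, 3], dtype='float32')
# 'uint16' passes check_dtype now, so an FP32 -> BF16 cast can be built ...
x_bf16 = fluid.layers.cast(x, 'uint16')
# ... and the BF16 variable passes check_variable_and_dtype on the way back.
x_fp32 = fluid.layers.cast(x_bf16, 'float32')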
@@ -56,7 +56,8 @@ def train(use_cuda, save_dirname, is_local, use_bf16, pure_bf16):
             amp_lists=amp.bf16.AutoMixedPrecisionListsBF16(),
             use_bf16_guard=False,
             use_pure_bf16=pure_bf16)
-    sgd_optimizer.minimize(avg_cost)
+    sgd_optimizer.minimize(
+        avg_cost, startup_program=fluid.default_startup_program())
 
     BATCH_SIZE = 20
......
@@ -115,7 +115,7 @@ def train(target,
             use_bf16_guard=False,
             use_pure_bf16=pure_bf16)
 
-    sgd_optimizer.minimize(avg_cost)
+    sgd_optimizer.minimize(avg_cost, fluid.default_startup_program())
 
     train_reader = paddle.batch(
         paddle.dataset.imikolov.train(word_dict, N), BATCH_SIZE)
......