From 70e0e3d53f7375bd17fb8b9dd6ba0802990800ae Mon Sep 17 00:00:00 2001 From: lidanqing Date: Fri, 7 May 2021 08:18:28 +0200 Subject: [PATCH] [cherry-pick] Mechanism that converts startup_program initializers to BF16 (#32720) (#32764) * Add casting initializers for bf16 training * Changes after review * Correct test and add comment Co-authored-by: joanna.wozna.intel --- .../contrib/mixed_precision/bf16/amp_lists.py | 3 ++ .../contrib/mixed_precision/bf16/amp_utils.py | 51 ++++++++++++++++++- .../contrib/mixed_precision/bf16/decorator.py | 11 ++-- .../contrib/mixed_precision/fp16_utils.py | 30 +++++++---- .../fluid/contrib/tests/test_bf16_utils.py | 23 +++++++++ .../contrib/tests/test_model_cast_to_bf16.py | 28 ++++++---- python/paddle/fluid/layers/tensor.py | 10 ++-- .../fluid/tests/book/test_fit_a_line.py | 3 +- .../fluid/tests/book/test_word2vec_book.py | 2 +- 9 files changed, 131 insertions(+), 30 deletions(-) diff --git a/python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py b/python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py index 1cf54aa0838..3a4dc8ed9af 100644 --- a/python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py +++ b/python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py @@ -49,6 +49,7 @@ class AutoMixedPrecisionListsBF16(object): self.bf16_list = copy.copy(bf16_list) self.fp32_list = copy.copy(fp32_list) self.gray_list = copy.copy(gray_list) + self.bf16_initializer_list = copy.copy(bf16_initializer_list) self.unsupported_list = copy.copy(unsupported_list) self.fp32_varnames = copy.copy(custom_fp32_varnames) self._update_list() @@ -79,6 +80,8 @@ class AutoMixedPrecisionListsBF16(object): self.unsupported_list.add(op_name) +bf16_initializer_list = {'fill_constant', 'uniform_random'} + # always bf16 bf16_list = {'elementwise_add', } diff --git a/python/paddle/fluid/contrib/mixed_precision/bf16/amp_utils.py b/python/paddle/fluid/contrib/mixed_precision/bf16/amp_utils.py index 038479098a6..4551947e0fa 100644 --- a/python/paddle/fluid/contrib/mixed_precision/bf16/amp_utils.py +++ b/python/paddle/fluid/contrib/mixed_precision/bf16/amp_utils.py @@ -232,7 +232,52 @@ def bf16_guard(): yield -def cast_model_to_bf16(program, amp_lists=None, use_bf16_guard=True): +def are_post_ops_bf16(post_ops, keep_fp32_ops): + for post_op in post_ops: + for op in post_op: + if op.type in keep_fp32_ops: + return False + return True + + +def cast_initializers_to_bf16(startup_prog, + amp_lists, + block, + all_ops, + keep_fp32_ops, + to_bf16_var_names=None): + prepend_ops = startup_prog.global_block().ops + for op in prepend_ops: + if str(op.type) in amp_lists.bf16_initializer_list: + change_op = True + op_post_ops = [] + op_out_vars = [] + for out_name in op.output_names: + for out_var_name in op.output(out_name): + out_var = block.var(out_var_name) + post_op = find_true_post_op(all_ops, op, out_var_name, True) + + if out_var is None or out_var.type not in _valid_types: + change_op = False + break + op_post_ops.append(post_op) + op_out_vars.append(out_var) + + if change_op and are_post_ops_bf16(op_post_ops, keep_fp32_ops): + for out_var in op_out_vars: + if out_var.dtype == core.VarDesc.VarType.FP32: + out_var.desc.set_dtype(core.VarDesc.VarType.BF16) + if to_bf16_var_names is not None and out_var.name in to_bf16_var_names: + to_bf16_var_names.remove(out_var.name) + if op.has_attr('dtype') and op.attr( + 'dtype') == core.VarDesc.VarType.FP32: + op._set_attr('dtype', core.VarDesc.VarType.BF16) + + +def cast_model_to_bf16(program, + startup_prog=None, + amp_lists=None, + 
use_bf16_guard=True): """ Traverse all ops in the whole model and set their inputs and outputs to the bf16 data type. This function will do some special processing for @@ -329,6 +374,10 @@ def cast_model_to_bf16(program, amp_lists=None, use_bf16_guard=True): if op.has_attr('mkldnn_data_type'): op._set_attr('mkldnn_data_type', 'bfloat16') + if startup_prog is not None: + cast_initializers_to_bf16(startup_prog, amp_lists, global_block, + ops, keep_fp32_ops, to_bf16_var_names) + # process ops in keep_fp32_ops op_var_rename_map = [ collections.OrderedDict() for _ in range(len(program.blocks)) diff --git a/python/paddle/fluid/contrib/mixed_precision/bf16/decorator.py b/python/paddle/fluid/contrib/mixed_precision/bf16/decorator.py index 86b5a5df75d..32c8a1c3544 100644 --- a/python/paddle/fluid/contrib/mixed_precision/bf16/decorator.py +++ b/python/paddle/fluid/contrib/mixed_precision/bf16/decorator.py @@ -94,7 +94,8 @@ class OptimizerWithMixedPrecision(object): if self._use_pure_bf16: self._to_bf16_var_names = cast_model_to_bf16( - self._train_program, self._amp_lists, self._use_bf16_guard) + self._train_program, startup_program, self._amp_lists, + self._use_bf16_guard) else: rewrite_program_bf16(self._train_program, self._amp_lists) @@ -168,10 +169,12 @@ class OptimizerWithMixedPrecision(object): self._to_bf16_var_names) if test_program is not None: if self._use_pure_bf16: - cast_model_to_bf16(test_program, self._amp_lists, - self._use_bf16_guard) + cast_model_to_bf16( + test_program, + amp_lists=self._amp_lists, + use_bf16_guard=self._use_bf16_guard) elif use_bf16_test: - rewrite_program_bf16(test_program, self._amp_lists) + rewrite_program_bf16(test_program, amp_lists=self._amp_lists) def apply_gradients(self, params_grads): """ diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py index 65b62e7e5ab..16dfb2bd50c 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py @@ -157,7 +157,8 @@ def _insert_cast_post_op(block, op, idx, src_dtype, dest_dtype, target_name, return num_cast_ops assert target_var.dtype == src_dtype, \ - "The real dtype({}) is not equal to the src dtype({})".format(_dtype_to_str(target_var.dtype), _dtype_to_str(src_dtype)) + "The real dtype({}) is not equal to the src dtype({})".format( + _dtype_to_str(target_var.dtype), _dtype_to_str(src_dtype)) cast_name = target_var.name + '.cast_' + _dtype_to_str(dest_dtype) cast_var = block.vars.get(cast_name) @@ -209,7 +210,7 @@ def find_true_prev_op(ops, cur_op, var_name): return None -def find_true_post_op(ops, cur_op, var_name): +def find_true_post_op(ops, cur_op, var_name, search_all=False): """ if there are post ops, return them, if there is no post op, return None instead. @@ -217,11 +218,22 @@ def find_true_post_op(ops, cur_op, var_name): ops (list): A list of ops. cur_op (Operator): Current operator which has var_name variable. var_name (string): Variable name. + search_all (bool): The type of operator search. Use if \"cur_op\" is not in the \"ops\" set. """ post_op = [] - for idx, op in enumerate(ops): - if op == cur_op: - break + if search_all: + """ + \"cur_op\" do not have to be in list of \"ops\". E.g. \"cur_op\" can come + from startup_prog block and \"ops\" list from main_prog block. + By setting idx to -1, we'll start looking for post-ops from the top of the list. 
+ If search_all is False, assume that \"cur_op\" is in \"ops\" list, + so to reduce the time of search we can start iterating from \"cur_op\" idx. + """ + idx = -1 + else: + for idx, op in enumerate(ops): + if op == cur_op: + break for i in range(idx + 1, len(ops)): op = ops[i] @@ -270,7 +282,7 @@ def _need_keep_fp32(op, unsupported_op_list, use_fp16_guard): if use_fp16_guard: if op.has_attr("op_namescope") and \ - (_fp16_guard_pattern in op.attr("op_namescope")): + (_fp16_guard_pattern in op.attr("op_namescope")): # op in fp16 guard return False else: @@ -496,8 +508,8 @@ def rewrite_program(main_prog, amp_lists): black_op_set = set() for op in ops: - # NOTE(zhiqiu): 'create_py_reader' and 'read' is used in non-iterable DataLoder, - # we don't need to handle reader op and the input of 'create_py_reader' is not + # NOTE(zhiqiu): 'create_py_reader' and 'read' is used in non-iterable DataLoder, + # we don't need to handle reader op and the input of 'create_py_reader' is not # in block, which may result in errors. # See GeneratorLoader._init_non_iterable() for details. if op.type == 'create_py_reader' or op.type == 'read': @@ -612,7 +624,7 @@ def update_role_var_grad(main_prog, params_grads): raise ValueError("The cast op {0}'s output should not be" "used by a non-optimize op, however, it" "is used by {1}".format(op, post_ops[0])) - #add new op in the python and cpp at the same time + # add new op in the python and cpp at the same time new_op_desc = block.desc.append_op() new_op_desc.copy_from(op.desc) new_op = framework.Operator( diff --git a/python/paddle/fluid/contrib/tests/test_bf16_utils.py b/python/paddle/fluid/contrib/tests/test_bf16_utils.py index 2969b7ea11d..41aa5e5412d 100644 --- a/python/paddle/fluid/contrib/tests/test_bf16_utils.py +++ b/python/paddle/fluid/contrib/tests/test_bf16_utils.py @@ -139,6 +139,29 @@ class AMPTest2(unittest.TestCase): res = amp.bf16.amp_utils.find_true_post_op(block.ops, op1, "Y") assert (res == [op2]) + def test_find_true_post_op_with_search_all(self): + program = fluid.Program() + block = program.current_block() + startup_block = fluid.default_startup_program().global_block() + + var1 = block.create_var(name="X", shape=[3], dtype='float32') + var2 = block.create_var(name="Y", shape=[3], dtype='float32') + inititializer_op = startup_block._prepend_op( + type="fill_constant", + outputs={"Out": var1}, + attrs={"shape": var1.shape, + "dtype": var1.dtype, + "value": 1.0}) + + op1 = block.append_op( + type="abs", inputs={"X": [var1]}, outputs={"Out": [var2]}) + result = amp.bf16.amp_utils.find_true_post_op( + block.ops, inititializer_op, "X", search_all=False) + assert (len(result) == 0) + result = amp.bf16.amp_utils.find_true_post_op( + block.ops, inititializer_op, "X", search_all=True) + assert (result == [op1]) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py b/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py index af2c42d6b85..470073543c3 100644 --- a/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py +++ b/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py @@ -53,19 +53,27 @@ class TestModelCastBF16(unittest.TestCase): with fluid.program_guard(prog, startup_prog): yield - def get_static_graph_result(self, feed, fetch_list, amp_fun, - with_lod=False): + def get_static_graph_result(self, + feed, + fetch_list, + amp_fun, + with_lod=False, + startup_prog=None): exe = fluid.Executor(core.CPUPlace()) - exe.run(fluid.default_startup_program()) + 
exe.run(fluid.default_startup_program() + if startup_prog is None else startup_prog) prog = fluid.default_main_program() if amp_fun is not None: - amp_fun(prog) + if startup_prog is not None: + amp_fun(prog, startup_prog) + else: + amp_fun(prog) return exe.run(prog, feed=feed, fetch_list=fetch_list, return_numpy=(not with_lod)) - def _graph_common(self, _amp_fun): + def _graph_common(self, _amp_fun, startup_prog=None): size = 3 n = np.ones([size, size], dtype='float32') * 3.2 nn = np.ones([size, size], dtype='float32') * -2.7 @@ -122,7 +130,8 @@ class TestModelCastBF16(unittest.TestCase): self.get_static_graph_result( feed={'t': n, 'tt': nn}, fetch_list=[ret], - amp_fun=_amp_fun + amp_fun=_amp_fun, + startup_prog=startup_prog ) self.assertTrue( static_ret_bf16, np.ones( @@ -132,16 +141,17 @@ class TestModelCastBF16(unittest.TestCase): self._graph_common(lambda prog: amp.bf16.rewrite_program_bf16( prog, amp.bf16.AutoMixedPrecisionListsBF16( - custom_fp32_varnames={'elementwise_add_0.tmp_0'}), + custom_fp32_varnames={'elementwise_add_0.tmp_0'}) )) def test_graph_cast(self): - self._graph_common(lambda prog: amp.bf16.cast_model_to_bf16( + self._graph_common(lambda prog, startup_prog: amp.bf16.cast_model_to_bf16( prog, + startup_prog, amp.bf16.AutoMixedPrecisionListsBF16( custom_fp32_list={'elementwise_mul'}), use_bf16_guard=True - )) + ), startup_prog=fluid.default_startup_program()) if __name__ == '__main__': diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index 7dcce5efcfc..c0c07f593a3 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -231,13 +231,13 @@ def cast(x, dtype): out = core.ops.cast(x, 'in_dtype', x.dtype, 'out_dtype', dtype) return out - check_variable_and_dtype( - x, 'x', - ['bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'uint8'], - 'cast') + check_variable_and_dtype(x, 'x', [ + 'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'uint8', + 'uint16' + ], 'cast') check_dtype(dtype, 'dtype', [ 'bool', 'float16', 'float32', 'float64', 'int8', 'int32', 'int64', - 'uint8' + 'uint8', 'uint16' ], 'cast') helper = LayerHelper('cast', **locals()) diff --git a/python/paddle/fluid/tests/book/test_fit_a_line.py b/python/paddle/fluid/tests/book/test_fit_a_line.py index 1172ae0f0ea..12952462270 100644 --- a/python/paddle/fluid/tests/book/test_fit_a_line.py +++ b/python/paddle/fluid/tests/book/test_fit_a_line.py @@ -56,7 +56,8 @@ def train(use_cuda, save_dirname, is_local, use_bf16, pure_bf16): amp_lists=amp.bf16.AutoMixedPrecisionListsBF16(), use_bf16_guard=False, use_pure_bf16=pure_bf16) - sgd_optimizer.minimize(avg_cost) + sgd_optimizer.minimize( + avg_cost, startup_program=fluid.default_startup_program()) BATCH_SIZE = 20 diff --git a/python/paddle/fluid/tests/book/test_word2vec_book.py b/python/paddle/fluid/tests/book/test_word2vec_book.py index f16592a55cf..650ccc0776a 100644 --- a/python/paddle/fluid/tests/book/test_word2vec_book.py +++ b/python/paddle/fluid/tests/book/test_word2vec_book.py @@ -115,7 +115,7 @@ def train(target, use_bf16_guard=False, use_pure_bf16=pure_bf16) - sgd_optimizer.minimize(avg_cost) + sgd_optimizer.minimize(avg_cost, fluid.default_startup_program()) train_reader = paddle.batch( paddle.dataset.imikolov.train(word_dict, N), BATCH_SIZE) -- GitLab
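
Usage sketch: the intended call pattern for the new startup_prog argument of cast_model_to_bf16, in the spirit of the updated test_model_cast_to_bf16.py. The direct bf16 subpackage import, the toy program, and the dtype check are illustrative assumptions rather than code from this patch, and the run step assumes a Paddle build whose CPU (oneDNN) kernels support bfloat16.

    import paddle
    import paddle.fluid as fluid
    # Assumed import path; the patched tests reach the same functions through an `amp` alias.
    from paddle.fluid.contrib.mixed_precision import bf16 as amp_bf16

    paddle.enable_static()

    main_prog = fluid.Program()
    startup_prog = fluid.Program()
    with fluid.program_guard(main_prog, startup_prog):
        x = fluid.data(name='x', shape=[None, 3], dtype='float32')
        # create_parameter adds a uniform_random initializer op to startup_prog;
        # uniform_random is in the new bf16_initializer_list, so it is a cast candidate.
        w = fluid.layers.create_parameter(shape=[3], dtype='float32')
        y = fluid.layers.elementwise_add(x, w)  # elementwise_add is in bf16_list

    # New in this patch: passing the startup program as the second argument makes
    # cast_model_to_bf16 also run cast_initializers_to_bf16 on it, so the FP32
    # initializer whose output feeds the now-BF16 elementwise_add has its output
    # variable and its 'dtype' attribute switched to BF16 as well.
    amp_bf16.cast_model_to_bf16(
        main_prog,
        startup_prog,
        amp_bf16.AutoMixedPrecisionListsBF16(),
        use_bf16_guard=False)

    print(w.dtype)  # expected: VarType.BF16 (was VarType.FP32 before the pass)

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(startup_prog)  # the initializer now fills a BF16 parameter directly

Without startup_prog, only the main program is rewritten: the parameter is still initialized in FP32 and has to be cast to BF16 afterwards, which is the extra step this mechanism removes.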
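
For the pure-BF16 optimizer path the same conversion is driven through the decorator: OptimizerWithMixedPrecision now forwards the startup program it receives from minimize() to cast_model_to_bf16, which is why the book tests start passing startup_program explicitly. Below is a condensed sketch along the lines of the updated test_fit_a_line.py; the decorate_bf16 entry point, the fluid layers, and the amp_init call are assumptions of this sketch (only the decorator wiring and the startup_program argument are visible in the hunks above).

    import paddle
    import paddle.fluid as fluid
    from paddle.fluid.contrib.mixed_precision import bf16 as amp_bf16  # assumed import path

    paddle.enable_static()

    x = fluid.data(name='x', shape=[None, 13], dtype='float32')
    y = fluid.data(name='y', shape=[None, 1], dtype='float32')
    y_pred = fluid.layers.fc(input=x, size=1)
    avg_cost = fluid.layers.mean(
        fluid.layers.square_error_cost(input=y_pred, label=y))

    sgd = amp_bf16.decorate_bf16(  # decorator entry point; its name is not shown in these hunks
        fluid.optimizer.SGD(learning_rate=0.001),
        amp_lists=amp_bf16.AutoMixedPrecisionListsBF16(),
        use_bf16_guard=False,
        use_pure_bf16=True)

    # The startup program must now be passed explicitly: the decorator hands it to
    # cast_model_to_bf16, which runs cast_initializers_to_bf16 on its global block.
    sgd.minimize(avg_cost, startup_program=fluid.default_startup_program())

    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())
    # amp_init still casts any parameters whose initializers were kept in FP32
    # (cast_initializers_to_bf16 removes the converted ones from to_bf16_var_names).
    sgd.amp_init(place)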