Unverified commit 70e0e3d5 authored by lidanqing, committed by GitHub

[cherry-pick] Mechanism that converts startup_program initializers to BF16 (#32720) (#32764)

* Add casting initializers for bf16 training

* Changes after review

* Correct test and add comment
Co-authored-by: joanna.wozna.intel <joanna.wozna@intel.com>
Parent 5fdd85ba
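In practice this mechanism is driven through the BF16 optimizer decorator: when pure BF16 training is enabled and the startup program is handed to minimize(), the decorator forwards it to cast_model_to_bf16(), which rewrites the matching initializer ops. A minimal, hedged usage sketch (the decorate_bf16 entry point and the import alias mirror the test files touched by this commit; exact module paths may differ between Paddle releases):

import paddle
import paddle.fluid as fluid
import paddle.fluid.contrib.mixed_precision as amp  # assumed import path

paddle.enable_static()

main_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = fluid.data(name='x', shape=[None, 13], dtype='float32')
    y = fluid.data(name='y', shape=[None, 1], dtype='float32')
    pred = fluid.layers.fc(input=x, size=1)
    loss = fluid.layers.mean(fluid.layers.square_error_cost(input=pred, label=y))

    sgd = fluid.optimizer.SGD(learning_rate=0.001)
    sgd = amp.bf16.decorate_bf16(
        sgd,
        amp_lists=amp.bf16.AutoMixedPrecisionListsBF16(),
        use_bf16_guard=False,
        use_pure_bf16=True)
    # Passing the startup program is what activates the new
    # cast_initializers_to_bf16 pass on fill_constant / uniform_random ops.
    sgd.minimize(loss, startup_program=startup_prog)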
@@ -49,6 +49,7 @@ class AutoMixedPrecisionListsBF16(object):
         self.bf16_list = copy.copy(bf16_list)
         self.fp32_list = copy.copy(fp32_list)
         self.gray_list = copy.copy(gray_list)
+        self.bf16_initializer_list = copy.copy(bf16_initializer_list)
         self.unsupported_list = copy.copy(unsupported_list)
         self.fp32_varnames = copy.copy(custom_fp32_varnames)
         self._update_list()
@@ -79,6 +80,8 @@ class AutoMixedPrecisionListsBF16(object):
                 self.unsupported_list.add(op_name)
+bf16_initializer_list = {'fill_constant', 'uniform_random'}
 # always bf16
 bf16_list = {'elementwise_add', }
......
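The initializer set sits on the list object next to the existing bf16/fp32/gray lists, so the casting pass can query it per op type. A quick sketch (the `amp` import alias mirrors the unit tests in this commit and is an assumption):

import paddle.fluid.contrib.mixed_precision as amp  # assumed import path

lists = amp.bf16.AutoMixedPrecisionListsBF16(
    custom_fp32_list={'elementwise_mul'})
print(lists.bf16_initializer_list)  # {'fill_constant', 'uniform_random'}
print(lists.bf16_list)              # always-BF16 ops, e.g. 'elementwise_add'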
@@ -232,7 +232,52 @@ def bf16_guard():
         yield
-def cast_model_to_bf16(program, amp_lists=None, use_bf16_guard=True):
+def are_post_ops_bf16(post_ops, keep_fp32_ops):
+    for post_op in post_ops:
+        for op in post_op:
+            if op.type in keep_fp32_ops:
+                return False
+    return True
+
+
+def cast_initializers_to_bf16(startup_prog,
+                              amp_lists,
+                              block,
+                              all_ops,
+                              keep_fp32_ops,
+                              to_bf16_var_names=None):
+    prepend_ops = startup_prog.global_block().ops
+    for op in prepend_ops:
+        if str(op.type) in amp_lists.bf16_initializer_list:
+            change_op = True
+            op_post_ops = []
+            op_out_vars = []
+            for out_name in op.output_names:
+                for out_var_name in op.output(out_name):
+                    out_var = block.var(out_var_name)
+                    post_op = find_true_post_op(all_ops, op, out_var_name, True)
+                    if out_var is None or out_var.type not in _valid_types:
+                        change_op = False
+                        break
+                    op_post_ops.append(post_op)
+                    op_out_vars.append(out_var)
+
+            if change_op and are_post_ops_bf16(op_post_ops, keep_fp32_ops):
+                for out_var in op_out_vars:
+                    if out_var.dtype == core.VarDesc.VarType.FP32:
+                        out_var.desc.set_dtype(core.VarDesc.VarType.BF16)
+                        if to_bf16_var_names is not None and out_var.name in to_bf16_var_names:
+                            to_bf16_var_names.remove(out_var.name)
+
+                if op.has_attr('dtype') and op.attr(
+                        'dtype') == core.VarDesc.VarType.FP32:
+                    op._set_attr('dtype', core.VarDesc.VarType.BF16)
+
+
+def cast_model_to_bf16(program,
+                       startup_prog=None,
+                       amp_lists=None,
+                       use_bf16_guard=True):
     """
     Traverse all ops in the whole model and set their inputs and outputs
     to the bf16 data type. This function will do some special processing for
@@ -329,6 +374,10 @@ def cast_model_to_bf16(program, amp_lists=None, use_bf16_guard=True):
             if op.has_attr('mkldnn_data_type'):
                 op._set_attr('mkldnn_data_type', 'bfloat16')
+    if startup_prog is not None:
+        cast_initializers_to_bf16(startup_prog, amp_lists, global_block,
+                                  ops, keep_fp32_ops, to_bf16_var_names)
+
     # process ops in keep_fp32_ops
     op_var_rename_map = [
         collections.OrderedDict() for _ in range(len(program.blocks))
......
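The net effect can be checked directly on a startup program: when every consumer of an initializer's output runs in BF16, the initializer's dtype attribute (and the matching FP32 variable in the main block) is switched to BF16, and the variable is dropped from to_bf16_var_names so it is not cast a second time later. A hedged sketch, not taken from the patch; the import alias and the assumption that the parameter's default initializer lowers to a uniform_random op are mine:

import paddle
import paddle.fluid as fluid
from paddle.fluid import core
import paddle.fluid.contrib.mixed_precision as amp  # assumed import path

paddle.enable_static()
main_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = fluid.data(name='x', shape=[None, 3], dtype='float32')
    w = fluid.layers.create_parameter(shape=[3], dtype='float32')
    out = fluid.layers.elementwise_add(x, w)  # 'elementwise_add' is in bf16_list

amp.bf16.cast_model_to_bf16(
    main_prog, startup_prog,
    amp.bf16.AutoMixedPrecisionListsBF16(),
    use_bf16_guard=False)

init_op = startup_prog.global_block().ops[0]  # the op that fills `w` (assumed uniform_random)
# Expected: the dtype attribute was rewritten from FP32 to BF16 because the
# only consumer of `w` (elementwise_add) is itself a BF16 op.
print(init_op.type, init_op.attr('dtype') == core.VarDesc.VarType.BF16)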
@@ -94,7 +94,8 @@ class OptimizerWithMixedPrecision(object):
             if self._use_pure_bf16:
                 self._to_bf16_var_names = cast_model_to_bf16(
-                    self._train_program, self._amp_lists, self._use_bf16_guard)
+                    self._train_program, startup_program, self._amp_lists,
+                    self._use_bf16_guard)
             else:
                 rewrite_program_bf16(self._train_program, self._amp_lists)
@@ -168,10 +169,12 @@ class OptimizerWithMixedPrecision(object):
                                     self._to_bf16_var_names)
         if test_program is not None:
             if self._use_pure_bf16:
-                cast_model_to_bf16(test_program, self._amp_lists,
-                                   self._use_bf16_guard)
+                cast_model_to_bf16(
+                    test_program,
+                    amp_lists=self._amp_lists,
+                    use_bf16_guard=self._use_bf16_guard)
             elif use_bf16_test:
-                rewrite_program_bf16(test_program, self._amp_lists)
+                rewrite_program_bf16(test_program, amp_lists=self._amp_lists)

     def apply_gradients(self, params_grads):
         """
......
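For completeness, the runtime side of the sketch from the commit description above (this continues those variable names; amp_init and its keyword arguments are inferred from the context lines of this hunk, so treat them as an assumption):

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_prog)  # the rewritten initializers now emit BF16 values directly
sgd.amp_init(place)    # casts whatever parameters remain in _to_bf16_var_names

When an inference program is also passed (e.g. sgd.amp_init(place, test_program=test_prog, use_bf16_test=True)), the hunk above shows it being rewritten through cast_model_to_bf16 or rewrite_program_bf16 with keyword arguments only, since no startup program is involved on the inference side.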
@@ -157,7 +157,8 @@ def _insert_cast_post_op(block, op, idx, src_dtype, dest_dtype, target_name,
         return num_cast_ops

     assert target_var.dtype == src_dtype, \
-        "The real dtype({}) is not equal to the src dtype({})".format(_dtype_to_str(target_var.dtype), _dtype_to_str(src_dtype))
+        "The real dtype({}) is not equal to the src dtype({})".format(
+            _dtype_to_str(target_var.dtype), _dtype_to_str(src_dtype))

     cast_name = target_var.name + '.cast_' + _dtype_to_str(dest_dtype)
     cast_var = block.vars.get(cast_name)
@@ -209,7 +210,7 @@ def find_true_prev_op(ops, cur_op, var_name):
     return None


-def find_true_post_op(ops, cur_op, var_name):
+def find_true_post_op(ops, cur_op, var_name, search_all=False):
     """
     if there are post ops, return them, if there is no post op,
     return None instead.
@@ -217,11 +218,22 @@ def find_true_post_op(ops, cur_op, var_name):
         ops (list): A list of ops.
         cur_op (Operator): Current operator which has var_name variable.
         var_name (string): Variable name.
+        search_all (bool): Whether to search the whole \"ops\" list for post ops. Use it if \"cur_op\" is not in the \"ops\" set.
     """
     post_op = []
-    for idx, op in enumerate(ops):
-        if op == cur_op:
-            break
+    if search_all:
+        """
+        \"cur_op\" does not have to be in the \"ops\" list. E.g. \"cur_op\" can come
+        from the startup_prog block while the \"ops\" list comes from the main_prog block.
+        By setting idx to -1, we'll start looking for post-ops from the top of the list.
+        If search_all is False, assume that \"cur_op\" is in the \"ops\" list,
+        so to reduce the search time we can start iterating from the \"cur_op\" idx.
+        """
+        idx = -1
+    else:
+        for idx, op in enumerate(ops):
+            if op == cur_op:
+                break
     for i in range(idx + 1, len(ops)):
         op = ops[i]
@@ -270,7 +282,7 @@ def _need_keep_fp32(op, unsupported_op_list, use_fp16_guard):
     if use_fp16_guard:
         if op.has_attr("op_namescope") and \
-            (_fp16_guard_pattern in op.attr("op_namescope")):
+                (_fp16_guard_pattern in op.attr("op_namescope")):
             # op in fp16 guard
             return False
         else:
@@ -496,8 +508,8 @@ def rewrite_program(main_prog, amp_lists):
     black_op_set = set()
     for op in ops:

-        # NOTE(zhiqiu): 'create_py_reader' and 'read' is used in non-iterable DataLoder, 
-        # we don't need to handle reader op and the input of 'create_py_reader' is not 
+        # NOTE(zhiqiu): 'create_py_reader' and 'read' is used in non-iterable DataLoder,
+        # we don't need to handle reader op and the input of 'create_py_reader' is not
         # in block, which may result in errors.
         # See GeneratorLoader._init_non_iterable() for details.
         if op.type == 'create_py_reader' or op.type == 'read':
@@ -612,7 +624,7 @@ def update_role_var_grad(main_prog, params_grads):
                 raise ValueError("The cast op {0}'s output should not be"
                                  "used by a non-optimize op, however, it"
                                  "is used by {1}".format(op, post_ops[0]))
-            #add new op in the python and cpp at the same time
+            # add new op in the python and cpp at the same time
             new_op_desc = block.desc.append_op()
             new_op_desc.copy_from(op.desc)
             new_op = framework.Operator(
......
@@ -139,6 +139,29 @@ class AMPTest2(unittest.TestCase):
         res = amp.bf16.amp_utils.find_true_post_op(block.ops, op1, "Y")
         assert (res == [op2])

+    def test_find_true_post_op_with_search_all(self):
+        program = fluid.Program()
+        block = program.current_block()
+        startup_block = fluid.default_startup_program().global_block()
+        var1 = block.create_var(name="X", shape=[3], dtype='float32')
+        var2 = block.create_var(name="Y", shape=[3], dtype='float32')
+        inititializer_op = startup_block._prepend_op(
+            type="fill_constant",
+            outputs={"Out": var1},
+            attrs={"shape": var1.shape,
+                   "dtype": var1.dtype,
+                   "value": 1.0})
+        op1 = block.append_op(
+            type="abs", inputs={"X": [var1]}, outputs={"Out": [var2]})
+        result = amp.bf16.amp_utils.find_true_post_op(
+            block.ops, inititializer_op, "X", search_all=False)
+        assert (len(result) == 0)
+        result = amp.bf16.amp_utils.find_true_post_op(
+            block.ops, inititializer_op, "X", search_all=True)
+        assert (result == [op1])
+

 if __name__ == '__main__':
     unittest.main()
@@ -53,19 +53,27 @@ class TestModelCastBF16(unittest.TestCase):
         with fluid.program_guard(prog, startup_prog):
             yield

-    def get_static_graph_result(self, feed, fetch_list, amp_fun,
-                                with_lod=False):
+    def get_static_graph_result(self,
+                                feed,
+                                fetch_list,
+                                amp_fun,
+                                with_lod=False,
+                                startup_prog=None):
         exe = fluid.Executor(core.CPUPlace())
-        exe.run(fluid.default_startup_program())
+        exe.run(fluid.default_startup_program()
+                if startup_prog is None else startup_prog)
         prog = fluid.default_main_program()
         if amp_fun is not None:
-            amp_fun(prog)
+            if startup_prog is not None:
+                amp_fun(prog, startup_prog)
+            else:
+                amp_fun(prog)
         return exe.run(prog,
                        feed=feed,
                        fetch_list=fetch_list,
                        return_numpy=(not with_lod))

-    def _graph_common(self, _amp_fun):
+    def _graph_common(self, _amp_fun, startup_prog=None):
         size = 3
         n = np.ones([size, size], dtype='float32') * 3.2
         nn = np.ones([size, size], dtype='float32') * -2.7
@@ -122,7 +130,8 @@ class TestModelCastBF16(unittest.TestCase):
                 self.get_static_graph_result(
                     feed={'t': n, 'tt': nn},
                     fetch_list=[ret],
-                    amp_fun=_amp_fun
+                    amp_fun=_amp_fun,
+                    startup_prog=startup_prog
                 )
             self.assertTrue(
                 static_ret_bf16, np.ones(
@@ -132,16 +141,17 @@ class TestModelCastBF16(unittest.TestCase):
         self._graph_common(lambda prog: amp.bf16.rewrite_program_bf16(
             prog,
             amp.bf16.AutoMixedPrecisionListsBF16(
-                custom_fp32_varnames={'elementwise_add_0.tmp_0'}),
+                custom_fp32_varnames={'elementwise_add_0.tmp_0'})
         ))

     def test_graph_cast(self):
-        self._graph_common(lambda prog: amp.bf16.cast_model_to_bf16(
+        self._graph_common(lambda prog, startup_prog: amp.bf16.cast_model_to_bf16(
             prog,
+            startup_prog,
             amp.bf16.AutoMixedPrecisionListsBF16(
                 custom_fp32_list={'elementwise_mul'}),
             use_bf16_guard=True
-        ))
+        ), startup_prog=fluid.default_startup_program())


 if __name__ == '__main__':
......
@@ -231,13 +231,13 @@ def cast(x, dtype):
         out = core.ops.cast(x, 'in_dtype', x.dtype, 'out_dtype', dtype)
         return out

-    check_variable_and_dtype(
-        x, 'x',
-        ['bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'uint8'],
-        'cast')
+    check_variable_and_dtype(x, 'x', [
+        'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'uint8',
+        'uint16'
+    ], 'cast')
     check_dtype(dtype, 'dtype', [
         'bool', 'float16', 'float32', 'float64', 'int8', 'int32', 'int64',
-        'uint8'
+        'uint8', 'uint16'
     ], 'cast')

     helper = LayerHelper('cast', **locals())
......
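'uint16' is the storage dtype Paddle uses for bfloat16 tensors in static graphs, so admitting it in both checks makes an explicit BF16 cast expressible at the Python layer. A small sketch (static mode assumed):

import paddle
import paddle.fluid as fluid

paddle.enable_static()
x = fluid.data(name='x', shape=[None, 3], dtype='float32')
x_bf16 = fluid.layers.cast(x, 'uint16')        # rejected by check_dtype before this change
x_back = fluid.layers.cast(x_bf16, 'float32')  # 'uint16' input now passes check_variable_and_dtype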
@@ -56,7 +56,8 @@ def train(use_cuda, save_dirname, is_local, use_bf16, pure_bf16):
             amp_lists=amp.bf16.AutoMixedPrecisionListsBF16(),
             use_bf16_guard=False,
             use_pure_bf16=pure_bf16)
-    sgd_optimizer.minimize(avg_cost)
+    sgd_optimizer.minimize(
+        avg_cost, startup_program=fluid.default_startup_program())

     BATCH_SIZE = 20
......
@@ -115,7 +115,7 @@ def train(target,
             use_bf16_guard=False,
             use_pure_bf16=pure_bf16)

-    sgd_optimizer.minimize(avg_cost)
+    sgd_optimizer.minimize(avg_cost, fluid.default_startup_program())

     train_reader = paddle.batch(
         paddle.dataset.imikolov.train(word_dict, N), BATCH_SIZE)
......