Unverified · Commit 435fc4f0 authored by liym27, committed by GitHub

[while grad] Support pruning ops in find_op_path for the while sub-block when appending backward (#25330)

Prune OPs that are not related to the loss in the while sub-block when constructing the backward OP path.
Parent aaa7cbd5
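To illustrate the behavior this commit targets, here is a minimal sketch (not part of the commit; the variable names and the extra `unused` loop variable are hypothetical) using the fluid 1.x API exercised by the tests below. With this change, ops in the while sub-block that do not feed the loss (here, the update of `unused`) are pruned from the backward op path built by `append_backward`.

```python
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid.backward import append_backward


def cond(i, x, unused):
    return i < 5


def body(i, x, unused):
    x = x + i            # on the path to the loss
    unused = unused * 2  # not related to the loss; pruned from the backward op path
    i = i + 1
    return [i, x, unused]


main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
    i = fluid.data(name='i', shape=[1], dtype='float32')
    i.stop_gradient = False
    x = fluid.data(name='x', shape=[1], dtype='float32')
    x.stop_gradient = False
    unused = fluid.data(name='u', shape=[1], dtype='float32')

    out = layers.while_loop(cond, body, [i, x, unused])
    mean = layers.mean(out[1])
    # Only the ops in the while sub-block that are related to `mean`
    # end up on the forward op path used to create the grad ops.
    append_backward(mean)
```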
......@@ -45,7 +45,7 @@ class ProgramStats(object):
input_names = []
for name in self.var_op_deps:
if len(self.var_op_deps[name]["var_as_output_ops"]) == 0 and \
len(self.var_op_deps[name]["var_as_input_ops"]) > 0:
if self.block.var(name).persistable:
continue
input_names.append(name)
......@@ -433,7 +433,7 @@ def _addup_repetitive_outputs_(op_descs, block_idx):
] + arg_names[arg_idx:]
new_name = var_name + "@RENAME@block" + str(block_idx) + "@" + \
str(var_rename_count[var_name])
var_rename_count[var_name] += 1
arg_names[arg_idx] = new_name
op_desc.set_output(param_name, arg_names)
......@@ -611,7 +611,7 @@ def _find_not_need_ops(grad_op_descs, forward_ops, input_grad_names_set):
not_need_op_descs_set = set(not_need_op_descs)
grad_op_descs_set = set(grad_op_descs)
# If a backward computational graph is simply one sub-graph header, the
# not_need_op_descs will be whole graph, this IF clause avoids it.
if grad_op_descs_set == not_need_op_descs_set:
return set()
return not_need_op_descs_set
......@@ -662,7 +662,7 @@ def _append_backward_ops_with_checkpoints_(
checkpoints_name = list(set(checkpoints_name))
local_block = block.program._create_block()
buffer_block = block.program._create_block()
# 0) deal with forward recomputing program descs
program_stat = ProgramStats(block, ops)
program_stat.modify_forward_desc_for_recompute()
program_stat.build_stats()
......@@ -797,32 +797,51 @@ def _append_backward_ops_with_checkpoints_(
return program_stat, checkpoints_name, vars_should_be_hold, recompute_segments
def _get_sub_block_path(sub_block, sub_block_op_desc, no_grad_set):
def _get_sub_block_path(sub_block,
sub_block_op_desc,
no_grad_set,
op_path_dict,
sub_block_target_names=None):
"""
Get output vars in subblock which will be assigned to parent block.
It is used to find the grad path in subblock
It is used to find the grad path in subblock.
Args:
sub_block(Block): The sub-block in which to get op path.
sub_block_op_desc: The op desc of the sub-block op such as 'while', 'conditional_block' and 'recurrent'.
no_grad_set(set): The set of no grad var name. no_grad_set will be changed.
op_path_dict(dict): op_path_dict will be changed.
key(int) block index
val(list) the op path of block(index)
sub_block_target_names(set): Target var names of sub-block.
Return:
The forward op path of sub-block corresponding to backward op.
"""
assert sub_block_op_desc.has_attr(
"sub_block") and sub_block.idx == sub_block_op_desc._block_attr_id(
"sub_block")
# TODO(huihuangzheng): add support for recurrent op and while op
if sub_block_op_desc.type == "conditional_block":
sub_outputs = []
sub_assign_to_out_ops = []
for var in sub_block_op_desc.output_arg_names:
assert isinstance(sub_block_target_names, (set, type(None)))
if sub_block_target_names is None:
sub_block_target_names = sub_block_op_desc.output_arg_names
# TODO(huihuangzheng): add support for recurrent op.
if sub_block_op_desc.type in ["conditional_block", "while"]:
# Step1: get the output vars in sub-block
sub_outputs = [
sub_block._var_recursive(var) for var in sub_block_target_names
]
for var in sub_block_target_names:
for op_desc in sub_block.ops:
if op_desc.type == "assign" and var in op_desc.output_arg_names:
sub_assign_to_out_ops.append(op_desc)
if var in op_desc.output_arg_names:
for name in op_desc.input_arg_names:
if sub_block.has_var(name):
sub_outputs.append(sub_block.var(name))
sub_outputs.append(sub_block._var_recursive(name))
# Step2: find op path of sub-block
is_while = sub_block_op_desc.type in ["while"]
sub_block_op_path = _find_op_path_(sub_block, sub_outputs, [],
no_grad_set)
# TODO better way than finding in list
for op_desc in sub_assign_to_out_ops:
if op_desc not in sub_block_op_path:
sub_block_op_path.append(op_desc)
no_grad_set, op_path_dict, is_while)
return sub_block_op_path
return sub_block.ops
......@@ -846,7 +865,8 @@ def _append_backward_ops_(block,
no_grad_dict,
grad_to_var,
callbacks=None,
input_grad_names_set=None):
input_grad_names_set=None,
op_path_dict=None):
"""
Create all grad ops, and insert them into given block
......@@ -864,6 +884,9 @@ def _append_backward_ops_(block,
input_grad_names_set(set): this set is used to store the gradients' name which is
generated by backward ops, and input_grad_names_set can help to prune the unnecessary
backward ops.
op_path_dict(dict): op_path_dict will be changed.
key(int) block index
val(list) the op path of block(index)
"""
if callbacks is not None:
assert (isinstance(callbacks, list))
......@@ -888,11 +911,10 @@ def _append_backward_ops_(block,
            # see following comments for why set None here.
pre_input_grad_names_set = copy.copy(input_grad_names_set)
input_grad_names_set = None
sub_block_path = _get_sub_block_path(sub_block, op,
no_grad_dict[sub_block.idx])
sub_block_path = op_path_dict[op._block_attr_id("sub_block")]
_append_backward_ops_(sub_block, sub_block_path, grad_sub_block,
no_grad_dict, grad_to_var, callbacks,
input_grad_names_set)
input_grad_names_set, op_path_dict)
input_grad_names_set = pre_input_grad_names_set
program._rollback()
......@@ -1013,7 +1035,7 @@ def _find_parent_op_(sub_block):
"sub_block") == sub_block_id:
return op
# NOTE(paddle-dev): When optimizer is added in conditional block,
# sub_block may not be found.
return None
......@@ -1072,7 +1094,7 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
if var != core.empty_var_name()
]
# If the outputs of grad op is empty, just remove it
if not outputs:
ops_to_remove.append(op_idx)
continue
......@@ -1358,7 +1380,10 @@ def append_backward(loss,
block_no_grad_set = set(
map(_strip_grad_suffix_, no_grad_dict[block_idx]))
op_path = _find_op_path_(block, [loss], [], block_no_grad_set)
op_path_dict = dict()
op_path = _find_op_path_(block, [loss], [], block_no_grad_set,
op_path_dict)
no_grad_vars = _find_no_grad_vars(block, op_path, [loss],
block_no_grad_set)
......@@ -1371,13 +1396,13 @@ def append_backward(loss,
# For double backward, input_grad_names is used for filtering
# some non-used gradients op(s).
# Todo(liym27): need a better design.
# TODO(liym27): need a better design.
# not support double grad in control flow sub-block now.
if not is_in_control_flow:
if program._appending_grad_times > 1:
input_grad_names_set = set([_append_grad_suffix_(loss.name)])
# Todo: support _append_backward_ops_with_checkpoints_ in
# TODO: support _append_backward_ops_with_checkpoints_ in
# sub-block (control flow)
if checkpoints != None and \
isinstance(checkpoints, list) and \
......@@ -1400,7 +1425,8 @@ def append_backward(loss,
no_grad_dict,
grad_to_var,
callbacks,
input_grad_names_set=input_grad_names_set)
input_grad_names_set=input_grad_names_set,
op_path_dict=op_path_dict)
grad_info_map = dict()
......@@ -1508,13 +1534,14 @@ def _get_output_names(cur_block, targets):
"""
block = targets[0].block if targets else cur_block
prog = cur_block.program
if _is_ancestor_block(block, cur_block):
return set()
current_output_names = set([out.name for out in targets])
# if `cur_block` is an ancestor of `targets[0].block`, run while loop
    # 1. If `targets` is in cur_block or an ancestor block of `cur_block`
if block.idx == cur_block.idx or _is_ancestor_block(block, cur_block):
return current_output_names
# 2. If `cur_block` is an ancestor of `targets[0].block`, run while loop
prog = cur_block.program
while block.idx != cur_block.idx:
assert block.parent_idx != -1
parent_block = prog.block(block.parent_idx)
......@@ -1554,12 +1581,32 @@ def _find_no_grad_vars(block, op_path, targets, no_grad_set):
return set(no_grad_var)
def _find_op_path_(block, outputs, inputs, no_grad_set):
def _find_op_path_(block,
targets,
inputs,
no_grad_set,
op_path_dict=None,
is_while=False):
"""
no_grad_set will also be changed
It is used to find the grad path in `block`.
Args:
block(Block): The block in which to get op path.
targets(list[Variable]): The target variables.
inputs(list[Variable]): The input variables.
no_grad_set(set): The set of no grad var name. no_grad_set will be changed.
        op_path_dict(dict): op_path_dict will be changed.
key(int) block index
val(list) the op path of block(index)
        is_while(bool): Whether or not `block` is a while block.
Return:
The forward op path of block corresponding to backward op.
"""
input_names = set([inp.name for inp in inputs])
output_names = _get_output_names(block, outputs)
output_names = _get_output_names(block, targets)
if op_path_dict is None:
op_path_dict = dict()
relevant_op_flags = [True] * len(block.ops)
......@@ -1576,6 +1623,15 @@ def _find_op_path_(block, outputs, inputs, no_grad_set):
relevant_op_flags[i] = False
for i, op in reversed(list(enumerate(block.ops))):
if op.has_attr("sub_block"):
sub_block_id = op._block_attr_id("sub_block")
sub_block = block.program.block(sub_block_id)
sub_block_target_names = output_names & set(op.output_arg_names)
sub_block_path = _get_sub_block_path(sub_block, op,
set(), op_path_dict,
sub_block_target_names)
op_path_dict[sub_block_id] = sub_block_path
if _some_in_set_(
op.desc.output_arg_names(),
output_names) and core.has_non_empty_grad_op_maker(op.type):
......@@ -1585,6 +1641,14 @@ def _find_op_path_(block, outputs, inputs, no_grad_set):
else:
relevant_op_flags[i] = False
if is_while:
        # If block is a while block, deal with its ops specifically again.
# TODO(liym27): Consider special types of ops.
for i, op in reversed(list(enumerate(block.ops))):
if relevant_op_flags[i] == False \
and _some_in_set_(op.desc.output_arg_names(),output_names):
relevant_op_flags[i] = True
op_path = [
block.ops[i] for i in range(len(block.ops)) if relevant_op_flags[i]
]
......@@ -1688,7 +1752,10 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
raise "input must be in the same program as targets"
block_no_grad_set = set(map(_strip_grad_suffix_, no_grad_dict[0]))
op_path = _find_op_path_(block, targets, inputs, block_no_grad_set)
op_path_dict = dict()
op_path = _find_op_path_(block, targets, inputs, block_no_grad_set,
op_path_dict)
no_grad_dict[0].update(list(map(_append_grad_suffix_, block_no_grad_set)))
grad_to_var = dict()
grad_info_map = dict()
......@@ -1698,7 +1765,8 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
block,
no_grad_dict,
grad_to_var,
input_grad_names_set=input_grad_names_set)
input_grad_names_set=input_grad_names_set,
op_path_dict=op_path_dict)
# Because calc_gradient may be called multiple times,
# we need rename the internal gradient variables so that they have
......
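The `calc_gradient` path above now builds the same `op_path_dict`. Below is a hedged usage sketch (hypothetical names; assuming the public `fluid.gradients` API, which dispatches to `calc_gradient`), showing gradients of a `while_loop` output going through the same sub-block pruning:

```python
import paddle.fluid as fluid
import paddle.fluid.layers as layers

main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
    x = fluid.data(name='x', shape=[1], dtype='float32')
    x.stop_gradient = False
    i = layers.fill_constant(shape=[1], dtype='float32', value=0.0)

    def cond(i, x):
        return i < 3

    def body(i, x):
        return [i + 1, x * 2]

    out = layers.while_loop(cond, body, [i, x])
    # fluid.gradients goes through calc_gradient, which now passes an
    # op_path_dict into _find_op_path_ before appending the backward ops.
    dx = fluid.gradients(out[1], x)
```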
......@@ -199,10 +199,16 @@ class TestApiWhileLoop_Backward(unittest.TestCase):
def cond(i, x):
return layers.less_than(i, eleven)
def body(i, x):
def body(j, x):
            # TODO: In the while block, if a var created in the parent block
            # participates in the gradient calculation, the gradient result
            # is incorrect because each step scope always returns the same
            # value generated by the last step.
            # Here we call the `assign` op in the while block to avoid this bug,
            # and we are working on fixing it in the next PR.
i = layers.assign(j)
x = layers.elementwise_mul(x=i, y=i)
i = layers.increment(i)
return [i, x]
j = layers.increment(j)
return [j, x]
main_program = Program()
startup_program = Program()
......@@ -232,7 +238,48 @@ class TestApiWhileLoop_Backward(unittest.TestCase):
'x': feed_x},
fetch_list=[mean.name, i.grad_name])
self.assertTrue(np.allclose(np.asarray(res[0]), data))
self.assertTrue(np.allclose(np.asarray(res[1]), i_grad))
self.assertTrue(
np.allclose(np.asarray(res[1]), i_grad),
msg=" \nres = \n{} \n\n ans = \n{}".format(res[1], i_grad))
def test_while_loop_backward2(self):
def cond(i, x):
return i < 5
def body(i, x):
x = x + i
i = i + 1
return [i, x]
main_program = Program()
startup_program = Program()
with fluid.program_guard(main_program, startup_program):
i = fluid.data(name='i', shape=[1], dtype='float32')
i.stop_gradient = False
x = fluid.data(name='x', shape=[1], dtype='float32')
x.stop_gradient = False
out = layers.while_loop(cond, body, [i, x])
mean = layers.mean(out[1])
append_backward(mean)
place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
feed_i = np.ones(1).astype('float32')
feed_x = np.ones(1).astype('float32')
data = np.asarray([11]).astype('float32')
i_grad = np.asarray([1]).astype('float32')
res = exe.run(main_program,
feed={'i': feed_i,
'x': feed_x},
fetch_list=[mean.name, i.grad_name])
self.assertTrue(np.allclose(np.asarray(res[0]), data))
self.assertTrue(
np.allclose(np.asarray(res[1]), i_grad),
msg=" \nres = \n{} \n\n ans = \n{}".format(res[1], i_grad))
class TestApiWhileLoop_NestedWithBackwardAndLoDTensorArray(unittest.TestCase):
......@@ -410,7 +457,7 @@ class TestApiWhileLoop_Error(unittest.TestCase):
ten = layers.fill_constant(shape=[1], dtype='int64', value=10)
ten_2d = layers.fill_constant(shape=[2, 2], dtype='int64', value=10)
# The type of `cond` in Op(while_loop) must be callable
def type_error_cond():
out = layers.while_loop(data, body, [data_1d])
......