提交 1a0fc5d8 编写于 作者: F fengjiayi

Add the simple support of no_grad_set

上级 278ac7be
......@@ -314,7 +314,8 @@ All parameter, weight, gradient are variables in Paddle.
InferenceOptimize(*(origin.Proto()), &pruned_desc);
return new ProgramDescBind(pruned_desc);
});
m.def("get_empty_var_name", []() { return framework::kEmptyVarName; });
m.def("empty_var_name", []() { return framework::kEmptyVarName; });
m.def("grad_var_suffix", []() { return framework::kGradVarSuffix; });
m.def_submodule(
"var_names",
"The module will return special predefined variable name in Paddle")
......
......@@ -32,7 +32,22 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
return op_desc
def backward_impl(target,
def _is_all_in_set_(cands, s):
for c in cands:
if not c in s:
return False
return True
def _strip_grad_suffix_(name):
return name[:name.find(core.grad_var_suffix())]
def _append_grad_suffix_(name):
return name + core.grad_var_suffix()
def _backward_impl_(target,
block,
target_block,
no_grad_set,
......@@ -47,7 +62,7 @@ def backward_impl(target,
sub_block_idx = each_op.block_attr("sub_block")
sub_block = program.block(sub_block_idx)
grad_sub_block = program.create_block(parent_idx=sub_block_idx)
backward_impl(target, sub_block, grad_sub_block, no_grad_set,
_backward_impl_(target, sub_block, grad_sub_block, no_grad_set,
grad_info_map, callback)
grad_sub_block_list.append(grad_sub_block)
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
......@@ -61,14 +76,14 @@ def backward_impl(target,
pending_sum_ops = []
var_rename_count = collections.defaultdict(int)
var_inputs = collections.defaultdict(list)
for pos, op_desc in enumerate(grad_op_descs):
for idx, op_desc in enumerate(grad_op_descs):
for var_name in op_desc.input_arg_names():
if len(var_inputs[var_name]) > 1:
pending_sum_ops.append((_create_op_desc_(
op_type="sum_op",
inputs=var_inputs[var_name],
outputs=[var_name],
attrs={}), pos))
attrs={}), idx))
var_inputs[var_name] = [var_name]
for var_name in op_desc.output_arg_names():
if len(var_inputs[var_name]) == 0:
......@@ -81,7 +96,7 @@ def backward_impl(target,
var_rename_count[var_name] = var_rename_count[var_name] + 1
# rename original var_name
var_inputs[var_name][0] = new_name
_rename_arg_(grad_op_descs, var_name, new_name, 0, pos)
_rename_arg_(grad_op_descs, var_name, new_name, 0, idx)
_rename_arg_(pending_sum_ops, var_name, new_name)
new_name = var_name + "@RENAME@" + \
......@@ -96,18 +111,31 @@ def backward_impl(target,
inputs={"X": inputs},
outputs={"Out": var_name},
attrs={}), len(grad_op_descs)))
# TODO: remove op in no grad set
# 根据append的顺序可以看出pending_sum_ops一定是根据sum_op的插入位置排序的
for p in reversed(pending_sum_ops):
grad_op_descs.insert(p[1], p[0])
# Remove ops whose outputs are all in no_grad_set
grad_op_descs = filter(
lambda op_desc: not _is_all_in_set_(op_desc.output_arg_names(), no_grad_set[block.idx]),
grad_op_descs)
# Insert fill_zeros_like_op
to_insert = []
for idx, op_desc in enumerate(grad_op_descs):
for arg in op_desc.input_arg_names():
if arg in no_grad_set[block.idx]:
to_insert.append((arg, idx))
for ele in reversed(to_insert):
arg = ele[0]
fill_zeros_like_op = _create_op_desc_(
"fill_zeros_like", {"X": [_strip_grad_suffix_(arg)]}, {"Y": [arg]},
{})
grad_op_descs.insert(ele[1], fill_zeros_like_op)
# create new gradient variables in the target block desc
for op_desc in grad_op_descs:
for grad_var_name in op_desc.output_arg_names():
grad_var_name = grad_var_name.encode("ascii")
if target_block.desc.has_var(
grad_var_name) or grad_var_name == core.get_empty_var_name(
):
grad_var_name) or grad_var_name == core.empty_var_name():
continue
target_block.desc.var(grad_var_name)
if not grad_to_var.has_key(grad_var_name):
......@@ -115,8 +143,8 @@ def backward_impl(target,
grad_info_map[grad_to_var[grad_var_name]] = (grad_var_name,
target_block)
if target_block.idx == 0:
grad_target_name = (target.name + "@GRAD")
target_block.desc.var(grad_target_name)
grad_target_name = _append_grad_suffix_(target.name)
target_block.desc.var(grad_target_name.encode("ascii"))
grad_op_descs.insert(
0,
_create_op_desc_(
......@@ -134,7 +162,6 @@ def backward_impl(target,
op_desc.infer_shape(target_block.desc)
target_block.desc.append_allocated_op(op_desc)
pdb.set_trace()
target_block.sync_with_cpp()
......@@ -165,14 +192,14 @@ def append_backward_ops(loss, parameter_list=None, no_grad_set=None):
for var in block.vars.itervalues():
assert isinstance(var, framework.Variable)
if var.stop_gradient:
block_no_grad_set.add(var.name)
block_no_grad_set.add(_append_grad_suffix_(var.name))
no_grad_set[block.idx] = block_no_grad_set
grad_info_map = dict()
root_block = loss.block.program.block(0)
pdb.set_trace()
backward_impl(loss, root_block, root_block, no_grad_set, grad_info_map)
pdb.set_trace()
_backward_impl_(loss, root_block, root_block, no_grad_set, grad_info_map)
if parameter_list is not None:
parameters = parameter_list
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册