Commit 1a0fc5d8 authored by F fengjiayi

Add the simple support of no_grad_set

Parent 278ac7be
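As a rough illustration of what the new no_grad_set handling in backward.py does, here is a self-contained sketch: gradient ops whose outputs all fall inside the no-grad set are dropped, and any surviving input that falls inside the set is fed by a fill_zeros_like op instead. Op descriptions are modelled as plain dicts and the "@GRAD" suffix is assumed; this is not the real OpDesc API.

```python
GRAD_SUFFIX = "@GRAD"  # assumed value of framework::kGradVarSuffix


def prune_no_grad(grad_op_descs, no_grad_set):
    # Drop gradient ops whose outputs are *all* in the no-grad set.
    kept = [op for op in grad_op_descs
            if not all(out in no_grad_set for out in op["outputs"])]
    # Feed any remaining input that lives in the no-grad set with zeros of
    # the right shape, via a fill_zeros_like op on the forward variable.
    result = []
    for op in kept:
        for arg in op["inputs"]:
            if arg in no_grad_set:
                forward_name = arg[:arg.find(GRAD_SUFFIX)]
                result.append({"type": "fill_zeros_like",
                               "inputs": [forward_name],
                               "outputs": [arg]})
        result.append(op)
    return result


# y has stop_gradient set, so y@GRAD is in the no-grad set: op2_grad (whose
# only output is y@GRAD) is removed and op1_grad reads zeros instead.
ops = [{"type": "op2_grad", "inputs": ["z@GRAD"], "outputs": ["y@GRAD"]},
       {"type": "op1_grad", "inputs": ["y@GRAD"], "outputs": ["x@GRAD"]}]
for op in prune_no_grad(ops, no_grad_set={"y@GRAD"}):
    print(op)
```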
@@ -314,7 +314,8 @@ All parameter, weight, gradient are variables in Paddle.
         InferenceOptimize(*(origin.Proto()), &pruned_desc);
         return new ProgramDescBind(pruned_desc);
       });
-  m.def("get_empty_var_name", []() { return framework::kEmptyVarName; });
+  m.def("empty_var_name", []() { return framework::kEmptyVarName; });
+  m.def("grad_var_suffix", []() { return framework::kGradVarSuffix; });
   m.def_submodule(
       "var_names",
       "The module will return special predefined variable name in Paddle")
...
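The two bindings added above are what the Python side below calls as core.empty_var_name() and core.grad_var_suffix(). A hedged usage snippet; the import path is an assumption for this revision:

```python
from paddle.v2.fluid import core  # assumption: module path at this revision

suffix = core.grad_var_suffix()   # framework::kGradVarSuffix
empty = core.empty_var_name()     # framework::kEmptyVarName

# Gradient variables are named by appending the suffix to the forward name.
print("fc_0.w" + suffix)
```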
@@ -32,12 +32,27 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
     return op_desc


-def backward_impl(target,
-                  block,
-                  target_block,
-                  no_grad_set,
-                  grad_info_map,
-                  callback=None):
+def _is_all_in_set_(cands, s):
+    for c in cands:
+        if not c in s:
+            return False
+    return True
+
+
+def _strip_grad_suffix_(name):
+    return name[:name.find(core.grad_var_suffix())]
+
+
+def _append_grad_suffix_(name):
+    return name + core.grad_var_suffix()
+
+
+def _backward_impl_(target,
+                    block,
+                    target_block,
+                    no_grad_set,
+                    grad_info_map,
+                    callback=None):
     grad_op_descs = []
     grad_to_var = dict()
     program = block.program
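For reference, the behaviour of the three new helpers on concrete names, with core.grad_var_suffix() stubbed to "@GRAD" (the concrete value of the suffix is an assumption here):

```python
GRAD_SUFFIX = "@GRAD"  # stand-in for core.grad_var_suffix()


def _append_grad_suffix_(name):
    return name + GRAD_SUFFIX


def _strip_grad_suffix_(name):
    return name[:name.find(GRAD_SUFFIX)]


def _is_all_in_set_(cands, s):
    for c in cands:
        if not c in s:
            return False
    return True


assert _append_grad_suffix_("fc_0.w") == "fc_0.w@GRAD"
assert _strip_grad_suffix_("fc_0.w@GRAD") == "fc_0.w"
assert _is_all_in_set_(["a@GRAD", "b@GRAD"], {"a@GRAD", "b@GRAD", "c@GRAD"})
assert not _is_all_in_set_(["a@GRAD", "d@GRAD"], {"a@GRAD"})
```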
@@ -47,8 +62,8 @@ def backward_impl(target,
             sub_block_idx = each_op.block_attr("sub_block")
             sub_block = program.block(sub_block_idx)
             grad_sub_block = program.create_block(parent_idx=sub_block_idx)
-            backward_impl(target, sub_block, grad_sub_block, no_grad_set,
-                          grad_info_map, callback)
+            _backward_impl_(target, sub_block, grad_sub_block, no_grad_set,
+                            grad_info_map, callback)
             grad_sub_block_list.append(grad_sub_block)
         grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
             each_op.desc, no_grad_set[block.idx], grad_sub_block_list)
@@ -61,14 +76,14 @@ def backward_impl(target,
     pending_sum_ops = []
     var_rename_count = collections.defaultdict(int)
     var_inputs = collections.defaultdict(list)
-    for pos, op_desc in enumerate(grad_op_descs):
+    for idx, op_desc in enumerate(grad_op_descs):
         for var_name in op_desc.input_arg_names():
             if len(var_inputs[var_name]) > 1:
                 pending_sum_ops.append((_create_op_desc_(
                     op_type="sum_op",
                     inputs=var_inputs[var_name],
                     outputs=[var_name],
-                    attrs={}), pos))
+                    attrs={}), idx))
                 var_inputs[var_name] = [var_name]
         for var_name in op_desc.output_arg_names():
             if len(var_inputs[var_name]) == 0:
@@ -81,7 +96,7 @@ def backward_impl(target,
                     var_rename_count[var_name] = var_rename_count[var_name] + 1
                     # rename original var_name
                     var_inputs[var_name][0] = new_name
-                    _rename_arg_(grad_op_descs, var_name, new_name, 0, pos)
+                    _rename_arg_(grad_op_descs, var_name, new_name, 0, idx)
                     _rename_arg_(pending_sum_ops, var_name, new_name)
                 new_name = var_name + "@RENAME@" + \
@@ -96,18 +111,31 @@ def backward_impl(target,
                 inputs={"X": inputs},
                 outputs={"Out": var_name},
                 attrs={}), len(grad_op_descs)))
-    # TODO: remove op in no grad set
     # The append order guarantees that pending_sum_ops is sorted by each
     # sum_op's insertion position
     for p in reversed(pending_sum_ops):
         grad_op_descs.insert(p[1], p[0])
+    # Remove ops whose outputs are all in no_grad_set
+    grad_op_descs = filter(
+        lambda op_desc: not _is_all_in_set_(op_desc.output_arg_names(), no_grad_set[block.idx]),
+        grad_op_descs)
+    # Insert fill_zeros_like_op
+    to_insert = []
+    for idx, op_desc in enumerate(grad_op_descs):
+        for arg in op_desc.input_arg_names():
+            if arg in no_grad_set[block.idx]:
+                to_insert.append((arg, idx))
+    for ele in reversed(to_insert):
+        arg = ele[0]
+        fill_zeros_like_op = _create_op_desc_(
+            "fill_zeros_like", {"X": [_strip_grad_suffix_(arg)]}, {"Y": [arg]},
+            {})
+        grad_op_descs.insert(ele[1], fill_zeros_like_op)
     # create new gradient variables in the target block desc
     for op_desc in grad_op_descs:
         for grad_var_name in op_desc.output_arg_names():
             grad_var_name = grad_var_name.encode("ascii")
-            if target_block.desc.has_var(
-                    grad_var_name) or grad_var_name == core.get_empty_var_name(
-            ):
+            if target_block.desc.has_var(
+                    grad_var_name) or grad_var_name == core.empty_var_name():
                 continue
             target_block.desc.var(grad_var_name)
             if not grad_to_var.has_key(grad_var_name):
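The pending_sum_ops machinery touched in the hunks above (only partly visible here) handles gradients written by more than one op: each extra write goes to a fresh "@RENAME@<n>" variable and a sum op folding the copies back is inserted later. The real code renames incrementally and positions the sums at the next reader; the following much simplified, self-contained sketch only shows the end result on a toy example, again with ops as plain dicts:

```python
import collections


def accumulate_duplicate_writes(grad_op_descs):
    # Record which ops write each gradient variable.
    writers = collections.defaultdict(list)
    for i, op in enumerate(grad_op_descs):
        for out in op["outputs"]:
            writers[out].append(i)

    result = [dict(op, outputs=list(op["outputs"])) for op in grad_op_descs]
    pending_sum_ops = []
    for var, idxs in writers.items():
        if len(idxs) <= 1:
            continue
        # Redirect every write to a fresh @RENAME@<n> copy ...
        renamed = []
        for n, i in enumerate(idxs):
            new_name = var + "@RENAME@" + str(n)
            result[i]["outputs"] = [new_name if o == var else o
                                    for o in result[i]["outputs"]]
            renamed.append(new_name)
        # ... and pend a sum op that folds the copies back into `var`
        # right after the last write.
        pending_sum_ops.append((idxs[-1] + 1, {"type": "sum",
                                               "inputs": renamed,
                                               "outputs": [var]}))

    for pos, op in sorted(pending_sum_ops, key=lambda p: p[0], reverse=True):
        result.insert(pos, op)
    return result


# x@GRAD is written by two gradient ops, so both writes are renamed and a
# sum op re-creates x@GRAD from the copies.
ops = [{"type": "mul_grad", "inputs": ["c@GRAD"], "outputs": ["x@GRAD"]},
       {"type": "add_grad", "inputs": ["d@GRAD"], "outputs": ["x@GRAD"]}]
for op in accumulate_duplicate_writes(ops):
    print(op)
```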
@@ -115,8 +143,8 @@ def backward_impl(target,
             grad_info_map[grad_to_var[grad_var_name]] = (grad_var_name,
                                                          target_block)
     if target_block.idx == 0:
-        grad_target_name = (target.name + "@GRAD")
-        target_block.desc.var(grad_target_name)
+        grad_target_name = _append_grad_suffix_(target.name)
+        target_block.desc.var(grad_target_name.encode("ascii"))
         grad_op_descs.insert(
             0,
             _create_op_desc_(
@@ -134,7 +162,6 @@ def backward_impl(target,
         op_desc.infer_shape(target_block.desc)
         target_block.desc.append_allocated_op(op_desc)
-    pdb.set_trace()
     target_block.sync_with_cpp()
@@ -165,14 +192,14 @@ def append_backward_ops(loss, parameter_list=None, no_grad_set=None):
         for var in block.vars.itervalues():
             assert isinstance(var, framework.Variable)
             if var.stop_gradient:
-                block_no_grad_set.add(var.name)
+                block_no_grad_set.add(_append_grad_suffix_(var.name))
         no_grad_set[block.idx] = block_no_grad_set

     grad_info_map = dict()
     root_block = loss.block.program.block(0)
-    pdb.set_trace()
-    backward_impl(loss, root_block, root_block, no_grad_set, grad_info_map)
-    pdb.set_trace()
+    _backward_impl_(loss, root_block, root_block, no_grad_set, grad_info_map)
     if parameter_list is not None:
         parameters = parameter_list
     else:
...
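Finally, the last hunk is where the per-block no-grad set consumed by _backward_impl_ is built: every variable flagged with stop_gradient contributes its gradient name, keyed by block index. A self-contained sketch with toy variable and block objects (these toy classes stand in for the framework classes and are not the real API):

```python
GRAD_SUFFIX = "@GRAD"  # assumed value of core.grad_var_suffix()


class Var(object):
    def __init__(self, name, stop_gradient=False):
        self.name = name
        self.stop_gradient = stop_gradient


class Block(object):
    def __init__(self, idx, variables):
        self.idx = idx
        self.vars = dict((v.name, v) for v in variables)


def build_no_grad_set(blocks):
    no_grad_set = dict()
    for block in blocks:
        block_no_grad_set = set()
        for var in block.vars.values():
            if var.stop_gradient:
                # Store the *gradient* name; that is what the generated
                # grad op descriptions refer to.
                block_no_grad_set.add(var.name + GRAD_SUFFIX)
        no_grad_set[block.idx] = block_no_grad_set
    return no_grad_set


blocks = [Block(0, [Var("image", stop_gradient=True), Var("fc_0.w")])]
print(build_no_grad_set(blocks))  # -> {0: set(['image@GRAD'])} on Python 2
```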