From 1a0fc5d8dcab7e3e28c0e3463e8b97e0d90b28b2 Mon Sep 17 00:00:00 2001
From: fengjiayi
Date: Thu, 21 Dec 2017 15:43:47 +0800
Subject: [PATCH] Add the simple support of no_grad_set

---
 paddle/pybind/pybind.cc            |  3 +-
 python/paddle/v2/fluid/backward.py | 71 +++++++++++++++++++++---------
 2 files changed, 51 insertions(+), 23 deletions(-)

diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index d84d5efbcfa..b453dfbf891 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -314,7 +314,8 @@ All parameter, weight, gradient are variables in Paddle.
     InferenceOptimize(*(origin.Proto()), &pruned_desc);
     return new ProgramDescBind(pruned_desc);
   });
-  m.def("get_empty_var_name", []() { return framework::kEmptyVarName; });
+  m.def("empty_var_name", []() { return framework::kEmptyVarName; });
+  m.def("grad_var_suffix", []() { return framework::kGradVarSuffix; });
   m.def_submodule(
       "var_names",
       "The module will return special predefined variable name in Paddle")
diff --git a/python/paddle/v2/fluid/backward.py b/python/paddle/v2/fluid/backward.py
index b24e124e1e6..df2761d8024 100644
--- a/python/paddle/v2/fluid/backward.py
+++ b/python/paddle/v2/fluid/backward.py
@@ -32,12 +32,27 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
     return op_desc
 
 
-def backward_impl(target,
-                  block,
-                  target_block,
-                  no_grad_set,
-                  grad_info_map,
-                  callback=None):
+def _is_all_in_set_(cands, s):
+    for c in cands:
+        if not c in s:
+            return False
+    return True
+
+
+def _strip_grad_suffix_(name):
+    return name[:name.find(core.grad_var_suffix())]
+
+
+def _append_grad_suffix_(name):
+    return name + core.grad_var_suffix()
+
+
+def _backward_impl_(target,
+                    block,
+                    target_block,
+                    no_grad_set,
+                    grad_info_map,
+                    callback=None):
     grad_op_descs = []
     grad_to_var = dict()
     program = block.program
@@ -47,8 +62,8 @@ def backward_impl(target,
             sub_block_idx = each_op.block_attr("sub_block")
             sub_block = program.block(sub_block_idx)
             grad_sub_block = program.create_block(parent_idx=sub_block_idx)
-            backward_impl(target, sub_block, grad_sub_block, no_grad_set,
-                          grad_info_map, callback)
+            _backward_impl_(target, sub_block, grad_sub_block, no_grad_set,
+                            grad_info_map, callback)
             grad_sub_block_list.append(grad_sub_block)
         grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
             each_op.desc, no_grad_set[block.idx], grad_sub_block_list)
@@ -61,14 +76,14 @@ def backward_impl(target,
     pending_sum_ops = []
     var_rename_count = collections.defaultdict(int)
     var_inputs = collections.defaultdict(list)
-    for pos, op_desc in enumerate(grad_op_descs):
+    for idx, op_desc in enumerate(grad_op_descs):
         for var_name in op_desc.input_arg_names():
             if len(var_inputs[var_name]) > 1:
                 pending_sum_ops.append((_create_op_desc_(
                     op_type="sum_op",
                     inputs=var_inputs[var_name],
                     outputs=[var_name],
-                    attrs={}), pos))
+                    attrs={}), idx))
                 var_inputs[var_name] = [var_name]
         for var_name in op_desc.output_arg_names():
             if len(var_inputs[var_name]) == 0:
@@ -81,7 +96,7 @@
                     var_rename_count[var_name] = var_rename_count[var_name] + 1
                     # rename original var_name
                     var_inputs[var_name][0] = new_name
-                    _rename_arg_(grad_op_descs, var_name, new_name, 0, pos)
+                    _rename_arg_(grad_op_descs, var_name, new_name, 0, idx)
                     _rename_arg_(pending_sum_ops, var_name, new_name)

                 new_name = var_name + "@RENAME@" + \
@@ -96,18 +111,31 @@
                 inputs={"X": inputs},
                 outputs={"Out": var_name},
                 attrs={}), len(grad_op_descs)))
-    # TODO: remove op in no grad set
-    # From the append order, pending_sum_ops is guaranteed to be sorted by the sum ops' insertion positions
     for p in reversed(pending_sum_ops):
         grad_op_descs.insert(p[1], p[0])

+    # Remove ops whose outputs are all in no_grad_set
+    grad_op_descs = filter(
+        lambda op_desc: not _is_all_in_set_(op_desc.output_arg_names(), no_grad_set[block.idx]),
+        grad_op_descs)
+    # Insert fill_zeros_like_op
+    to_insert = []
+    for idx, op_desc in enumerate(grad_op_descs):
+        for arg in op_desc.input_arg_names():
+            if arg in no_grad_set[block.idx]:
+                to_insert.append((arg, idx))
+    for ele in reversed(to_insert):
+        arg = ele[0]
+        fill_zeros_like_op = _create_op_desc_(
+            "fill_zeros_like", {"X": [_strip_grad_suffix_(arg)]}, {"Y": [arg]},
+            {})
+        grad_op_descs.insert(ele[1], fill_zeros_like_op)
     # create new gradient variables in the target block desc
     for op_desc in grad_op_descs:
         for grad_var_name in op_desc.output_arg_names():
             grad_var_name = grad_var_name.encode("ascii")
             if target_block.desc.has_var(
-                    grad_var_name) or grad_var_name == core.get_empty_var_name(
-                    ):
+                    grad_var_name) or grad_var_name == core.empty_var_name():
                 continue
             target_block.desc.var(grad_var_name)
             if not grad_to_var.has_key(grad_var_name):
@@ -115,8 +143,8 @@ def backward_impl(target,
             grad_info_map[grad_to_var[grad_var_name]] = (grad_var_name,
                                                          target_block)
     if target_block.idx == 0:
-        grad_target_name = (target.name + "@GRAD")
-        target_block.desc.var(grad_target_name)
+        grad_target_name = _append_grad_suffix_(target.name)
+        target_block.desc.var(grad_target_name.encode("ascii"))
         grad_op_descs.insert(
             0,
             _create_op_desc_(
@@ -134,7 +162,6 @@
         op_desc.infer_shape(target_block.desc)
         target_block.desc.append_allocated_op(op_desc)

-    pdb.set_trace()
     target_block.sync_with_cpp()


@@ -165,14 +192,14 @@ def append_backward_ops(loss, parameter_list=None, no_grad_set=None):
         for var in block.vars.itervalues():
             assert isinstance(var, framework.Variable)
             if var.stop_gradient:
-                block_no_grad_set.add(var.name)
+                block_no_grad_set.add(_append_grad_suffix_(var.name))
         no_grad_set[block.idx] = block_no_grad_set

     grad_info_map = dict()
     root_block = loss.block.program.block(0)
-    pdb.set_trace()
-    backward_impl(loss, root_block, root_block, no_grad_set, grad_info_map)
-    pdb.set_trace()
+
+    _backward_impl_(loss, root_block, root_block, no_grad_set, grad_info_map)
+
     if parameter_list is not None:
         parameters = parameter_list
     else:
--
GitLab
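
Below is a minimal, standalone sketch (not part of the patch) of the two steps this change adds to _backward_impl_: dropping grad ops whose outputs all fall in the block's no_grad_set, then inserting fill_zeros_like ops for gradient inputs that are in the set. It models ops as plain dicts instead of Paddle OpDesc objects; the dict layout, helper names, and the GRAD_SUFFIX constant are illustrative assumptions, not Paddle APIs.

# Illustrative sketch only; plain dicts stand in for Paddle's OpDesc.
GRAD_SUFFIX = "@GRAD"


def _is_all_in_set(cands, s):
    return all(c in s for c in cands)


def _strip_grad_suffix(name):
    # "h@GRAD" -> "h"
    return name[:name.find(GRAD_SUFFIX)]


def prune_and_fill(grad_ops, no_grad_set):
    # Step 1: drop grad ops whose outputs are all in no_grad_set.
    grad_ops = [op for op in grad_ops
                if not _is_all_in_set(op["outputs"], no_grad_set)]
    # Step 2: for each remaining grad op that reads a gradient in
    # no_grad_set, insert a fill_zeros_like op producing that gradient.
    to_insert = []
    for idx, op in enumerate(grad_ops):
        for arg in op["inputs"]:
            if arg in no_grad_set:
                to_insert.append((arg, idx))
    for arg, idx in reversed(to_insert):
        fill_op = {"type": "fill_zeros_like",
                   "inputs": [_strip_grad_suffix(arg)],
                   "outputs": [arg]}
        grad_ops.insert(idx, fill_op)
    return grad_ops


if __name__ == "__main__":
    ops = [
        {"type": "mul_grad", "inputs": ["out@GRAD"], "outputs": ["h@GRAD"]},
        {"type": "relu_grad", "inputs": ["h@GRAD"], "outputs": ["x@GRAD"]},
    ]
    # With h@GRAD in no_grad_set, mul_grad is pruned and a fill_zeros_like
    # op is inserted so relu_grad still receives a (zero) gradient for h.
    for op in prune_and_fill(ops, {"h@GRAD"}):
        print(op)

This mirrors the behavior for a variable marked stop_gradient: its gradient name lands in the block's no_grad_set, the op that would compute it is removed, and downstream grad ops consume zeros instead.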