From bb0427add03ce29b8013511f9cebf509e9de3585 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Thu, 28 Dec 2017 17:57:17 +0800 Subject: [PATCH] Add comments for functions in backward.py --- python/paddle/v2/fluid/backward.py | 77 ++++++++++++++++++++++++++---- 1 file changed, 67 insertions(+), 10 deletions(-) diff --git a/python/paddle/v2/fluid/backward.py b/python/paddle/v2/fluid/backward.py index 6966cc75804..b3c1bab298f 100644 --- a/python/paddle/v2/fluid/backward.py +++ b/python/paddle/v2/fluid/backward.py @@ -5,14 +5,17 @@ import collections __all__ = ['append_backward'] -def _rename_arg_(op_desc_list, old_name, new_name, begin_idx=None, - end_idx=None): +def _rename_arg_(op_descs, old_name, new_name, begin_idx=None, end_idx=None): + """ + Traverse all ops in op_descs[begin_idx : end_idx], + if any op has inputs/outputs named "old_name", rename it as 'new_name' + """ if begin_idx is None: begin_idx = 0 if end_idx is None: - end_idx = len(op_desc_list) + end_idx = len(op_descs) for i in range(begin_idx, end_idx): - op_desc = op_desc_list[i] + op_desc = op_descs[i] if isinstance(op_desc, tuple): op_desc = op_desc[0] op_desc.rename_input(old_name, new_name) @@ -20,6 +23,9 @@ def _rename_arg_(op_desc_list, old_name, new_name, begin_idx=None, def _create_op_desc_(op_type, inputs, outputs, attrs): + """ + Create a C++ OpDesc object with specified inputs, outputs and attributes. + """ op_desc = core.OpDesc() op_desc.set_type(op_type) for para, args in inputs.iteritems(): @@ -34,9 +40,12 @@ def _create_op_desc_(op_type, inputs, outputs, attrs): return op_desc -def _infer_var_data_type_(var_name, block): - grad_var = block.desc.find_var(var_name.encode("ascii")) - fwd_name = _strip_grad_suffix_(var_name.encode("ascii")) +def _infer_var_data_type_(grad_var_name, block): + """ + Infer the data type of given grad variable + """ + grad_var = block.desc.find_var(grad_var_name.encode("ascii")) + fwd_name = _strip_grad_suffix_(grad_var_name.encode("ascii")) if block.desc.has_var_recursive(fwd_name): fwd_var = block.desc.find_var_recursive(fwd_name.encode("ascii")) grad_var.set_dtype(fwd_var.dtype()) @@ -45,6 +54,9 @@ def _infer_var_data_type_(var_name, block): def _all_in_set_(cands, s): + """ + Test if all elements of 'cands' are in set 's' + """ for c in cands: if not c in s: return False @@ -52,18 +64,29 @@ def _all_in_set_(cands, s): def _strip_grad_suffix_(name): + """ + Strip the grad suffix from the given varibale name + e.g. x@GRAD ==> x + y@GRAD@RENAME@1 ==> y + """ pos = name.find(core.grad_var_suffix()) return name[:pos] if pos != -1 else name def _append_grad_suffix_(name): + """ + Append grad suffix to the given variable name + e.g. x ==> x@GRAD + """ return name + core.grad_var_suffix() def _addup_repetitive_outputs_(op_descs): - # In backward part, an variable my be the output of more than one ops. - # In this case, the variable should be the accumulation of all the outputs. - # We adopt adding `sum_op`s to implement the accumulate. + """ + In backward part, an variable may be the output of more than one ops. + In this case, the variable should be the accumulation of all the outputs. + `sum_op`s are added to implement the accumulate. + """ pending_sum_ops = [] var_rename_count = collections.defaultdict(int) renamed_vars = collections.defaultdict(list) @@ -109,6 +132,12 @@ def _addup_repetitive_outputs_(op_descs): def _remove_no_grad_branch_(op_descs, no_grad_set): + """ + Remove unnecessary grad ops + A grad op can be removed in two cases: + 1. all outputs of the grad op are in 'no_grad_set' + 2. (TODO) all grad inputs of the grad op are in 'no_grad_set' + """ # Remove ops whose outputs are all in no_grad_dict op_descs = filter( lambda op_desc: not _all_in_set_(op_desc.output_arg_names(), no_grad_set), @@ -133,6 +162,20 @@ def _append_backward_ops_(target, no_grad_dict, grad_to_var, callback=None): + """ + Create all grad ops, and insert them into given block + + Args: + target(Variable): the target variable of forward pass + block(Block): the block where forward ops are + target_block(Block): the block which is going to hold new generated grad ops + no_grad_dict(dict): + key(int) block index + val(set) a set of varibale names. These varibales have no gradient + grad_to_var(dict)(output argument): + key(str): grad variable name + val(str): corresponding forward variable name + """ grad_op_descs = [] program = block.program for op in reversed(block.ops): @@ -170,6 +213,20 @@ def _append_backward_ops_(target, def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map): + """ + Create new variables required by backward pass. + + Args: + block(Block): the block where new variables will be created + start_op_idx(int): Only variables required by ops in block.ops[start_op_idx : ] will be created + grad_to_var(dict): + key(str): grad variable name + val(str): corresponding forward variable name + In most cases, this dict is generated by _append_backward_ops_() + grad_info_map(dict)(output argument): + key(str): forward variable name + val(tuple): a tuple of (str, int), str is the corresponding grad name, int is the block index + """ for op_idx in range(start_op_idx, block.desc.op_size()): op_desc = block.desc.op(op_idx) if op_desc.has_attr("sub_block"): -- GitLab