提交 bb0427ad 编写于 作者: F fengjiayi

Add comments for functions in backward.py

上级 18311767
...@@ -5,14 +5,17 @@ import collections ...@@ -5,14 +5,17 @@ import collections
__all__ = ['append_backward'] __all__ = ['append_backward']
def _rename_arg_(op_desc_list, old_name, new_name, begin_idx=None, def _rename_arg_(op_descs, old_name, new_name, begin_idx=None, end_idx=None):
end_idx=None): """
Traverse all ops in op_descs[begin_idx : end_idx],
if any op has inputs/outputs named 'old_name', rename them to 'new_name'
"""
if begin_idx is None: if begin_idx is None:
begin_idx = 0 begin_idx = 0
if end_idx is None: if end_idx is None:
end_idx = len(op_desc_list) end_idx = len(op_descs)
for i in range(begin_idx, end_idx): for i in range(begin_idx, end_idx):
op_desc = op_desc_list[i] op_desc = op_descs[i]
if isinstance(op_desc, tuple): if isinstance(op_desc, tuple):
op_desc = op_desc[0] op_desc = op_desc[0]
op_desc.rename_input(old_name, new_name) op_desc.rename_input(old_name, new_name)
...@@ -20,6 +23,9 @@ def _rename_arg_(op_desc_list, old_name, new_name, begin_idx=None, ...@@ -20,6 +23,9 @@ def _rename_arg_(op_desc_list, old_name, new_name, begin_idx=None,
def _create_op_desc_(op_type, inputs, outputs, attrs): def _create_op_desc_(op_type, inputs, outputs, attrs):
"""
Create a C++ OpDesc object with specified inputs, outputs and attributes.
"""
op_desc = core.OpDesc() op_desc = core.OpDesc()
op_desc.set_type(op_type) op_desc.set_type(op_type)
for para, args in inputs.iteritems(): for para, args in inputs.iteritems():
...@@ -34,9 +40,12 @@ def _create_op_desc_(op_type, inputs, outputs, attrs): ...@@ -34,9 +40,12 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
return op_desc return op_desc
def _infer_var_data_type_(var_name, block): def _infer_var_data_type_(grad_var_name, block):
grad_var = block.desc.find_var(var_name.encode("ascii")) """
fwd_name = _strip_grad_suffix_(var_name.encode("ascii")) Infer the data type of given grad variable
"""
grad_var = block.desc.find_var(grad_var_name.encode("ascii"))
fwd_name = _strip_grad_suffix_(grad_var_name.encode("ascii"))
if block.desc.has_var_recursive(fwd_name): if block.desc.has_var_recursive(fwd_name):
fwd_var = block.desc.find_var_recursive(fwd_name.encode("ascii")) fwd_var = block.desc.find_var_recursive(fwd_name.encode("ascii"))
grad_var.set_dtype(fwd_var.dtype()) grad_var.set_dtype(fwd_var.dtype())
...@@ -45,6 +54,9 @@ def _infer_var_data_type_(var_name, block): ...@@ -45,6 +54,9 @@ def _infer_var_data_type_(var_name, block):
def _all_in_set_(cands, s): def _all_in_set_(cands, s):
"""
Test if all elements of 'cands' are in set 's'
"""
for c in cands: for c in cands:
if not c in s: if not c in s:
return False return False
...@@ -52,18 +64,29 @@ def _all_in_set_(cands, s): ...@@ -52,18 +64,29 @@ def _all_in_set_(cands, s):
def _strip_grad_suffix_(name): def _strip_grad_suffix_(name):
"""
Strip the grad suffix from the given variable name
e.g. x@GRAD ==> x
y@GRAD@RENAME@1 ==> y
"""
pos = name.find(core.grad_var_suffix()) pos = name.find(core.grad_var_suffix())
return name[:pos] if pos != -1 else name return name[:pos] if pos != -1 else name
def _append_grad_suffix_(name): def _append_grad_suffix_(name):
"""
Append grad suffix to the given variable name
e.g. x ==> x@GRAD
"""
return name + core.grad_var_suffix() return name + core.grad_var_suffix()
def _addup_repetitive_outputs_(op_descs): def _addup_repetitive_outputs_(op_descs):
# In backward part, an variable my be the output of more than one ops. """
# In this case, the variable should be the accumulation of all the outputs. In backward part, a variable may be the output of more than one op.
# We adopt adding `sum_op`s to implement the accumulate. In this case, the variable should be the accumulation of all the outputs.
`sum_op`s are added to implement the accumulation.
"""
pending_sum_ops = [] pending_sum_ops = []
var_rename_count = collections.defaultdict(int) var_rename_count = collections.defaultdict(int)
renamed_vars = collections.defaultdict(list) renamed_vars = collections.defaultdict(list)
...@@ -109,6 +132,12 @@ def _addup_repetitive_outputs_(op_descs): ...@@ -109,6 +132,12 @@ def _addup_repetitive_outputs_(op_descs):
def _remove_no_grad_branch_(op_descs, no_grad_set): def _remove_no_grad_branch_(op_descs, no_grad_set):
"""
Remove unnecessary grad ops
A grad op can be removed in two cases:
1. all outputs of the grad op are in 'no_grad_set'
2. (TODO) all grad inputs of the grad op are in 'no_grad_set'
"""
# Remove ops whose outputs are all in no_grad_dict # Remove ops whose outputs are all in no_grad_dict
op_descs = filter( op_descs = filter(
lambda op_desc: not _all_in_set_(op_desc.output_arg_names(), no_grad_set), lambda op_desc: not _all_in_set_(op_desc.output_arg_names(), no_grad_set),
...@@ -133,6 +162,20 @@ def _append_backward_ops_(target, ...@@ -133,6 +162,20 @@ def _append_backward_ops_(target,
no_grad_dict, no_grad_dict,
grad_to_var, grad_to_var,
callback=None): callback=None):
"""
Create all grad ops, and insert them into given block
Args:
target(Variable): the target variable of forward pass
block(Block): the block where forward ops are
target_block(Block): the block which is going to hold newly generated grad ops
no_grad_dict(dict):
key(int) block index
val(set) a set of variable names. These variables have no gradient
grad_to_var(dict)(output argument):
key(str): grad variable name
val(str): corresponding forward variable name
"""
grad_op_descs = [] grad_op_descs = []
program = block.program program = block.program
for op in reversed(block.ops): for op in reversed(block.ops):
...@@ -170,6 +213,20 @@ def _append_backward_ops_(target, ...@@ -170,6 +213,20 @@ def _append_backward_ops_(target,
def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map): def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
"""
Create new variables required by backward pass.
Args:
block(Block): the block where new variables will be created
start_op_idx(int): Only variables required by ops in block.ops[start_op_idx : ] will be created
grad_to_var(dict):
key(str): grad variable name
val(str): corresponding forward variable name
In most cases, this dict is generated by _append_backward_ops_()
grad_info_map(dict)(output argument):
key(str): forward variable name
val(tuple): a tuple of (str, int), str is the corresponding grad name, int is the block index
"""
for op_idx in range(start_op_idx, block.desc.op_size()): for op_idx in range(start_op_idx, block.desc.op_size()):
op_desc = block.desc.op(op_idx) op_desc = block.desc.op(op_idx)
if op_desc.has_attr("sub_block"): if op_desc.has_attr("sub_block"):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册