Commit 5b9dbbb9 authored by F fengjiayi

code clean

Parent febb7251
from paddle.v2.fluid import framework as framework
from . import core
import collections
import pdb
__all__ = ['append_backward']
......@@ -45,7 +44,7 @@ def _infer_var_data_type_(var_name, block):
grad_var.set_dtype(core.DataType.FP32)
def _is_all_in_set_(cands, s):
def _all_in_set_(cands, s):
for c in cands:
if not c in s:
return False
......@@ -61,112 +60,114 @@ def _append_grad_suffix_(name):
return name + core.grad_var_suffix()
def _append_backward_ops_(target,
block,
target_block,
no_grad_set,
callback=None):
grad_op_descs = []
grad_to_var = dict()
program = block.program
for each_op in reversed(block.ops):
grad_sub_block_list = []
if each_op.has_attr("sub_block"):
sub_block_idx = each_op.block_attr("sub_block")
sub_block = program.block(sub_block_idx)
grad_sub_block = program.create_block(parent_idx=sub_block_idx)
sub_grad_to_var = _append_backward_ops_(
target, sub_block, grad_sub_block, no_grad_set, callback)
grad_to_var = dict(grad_to_var, **sub_grad_to_var)
grad_sub_block_list.append(grad_sub_block.desc)
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
each_op.desc, no_grad_set[block.idx], grad_sub_block_list)
grad_op_descs.append(grad_op_desc)
grad_to_var = dict(grad_to_var, **op_grad_to_var)
# grad_op_descs = [[op1_g1, op1_g2], [op2_g], ...]
# flatten grad_op_descs
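# e.g. [[op1_g1, op1_g2], [op2_g]] becomes [op1_g1, op1_g2, op2_g]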
grad_op_descs = [op for sublist in grad_op_descs for op in sublist]
def _addup_repetitive_outputs_(op_descs):
# In the backward pass, a variable may be the output of more than one op.
# In that case, the variable should be the accumulation of all those outputs.
# We add `sum_op`s to do the accumulation.
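# For example (illustrative name), if two grad ops both write w, the two
# outputs are renamed to w@RENAME@0 and w@RENAME@1, and a sum_op with
# X=[w@RENAME@0, w@RENAME@1] and Out=[w] is appended.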
pending_sum_ops = []
var_rename_count = collections.defaultdict(int)
var_inputs = collections.defaultdict(list)
for idx, op_desc in enumerate(grad_op_descs):
renamed_vars = collections.defaultdict(list)
for idx, op_desc in enumerate(op_descs):
for var_name in op_desc.input_arg_names():
if len(var_inputs[var_name]) > 1:
pending_sum_ops.append((_create_op_desc_(
op_type="sum",
inputs={"X": var_inputs[var_name]},
outputs={"Out": [var_name]},
attrs={}), idx))
var_inputs[var_name] = [var_name]
if len(renamed_vars[var_name]) > 1:
pending_sum_ops.append(
(_create_op_desc_("sum", {"X": renamed_vars[var_name]},
{"Out": [var_name]}, {}), idx))
renamed_vars[var_name] = [var_name]
for var_name in op_desc.output_arg_names():
if var_name in op_desc.input_arg_names():
# in place operator
if var_name == core.empty_var_name(
) or var_name in op_desc.input_arg_names():
# empty variable or inplace op
continue
if var_name == core.empty_var_name() or len(var_inputs[
var_name]) == 0:
if len(renamed_vars[var_name]) == 0:
# it's the first time we get the variable
var_inputs[var_name] = [var_name]
renamed_vars[var_name] = [var_name]
else:
if len(var_inputs[var_name]) == 1:
if len(renamed_vars[var_name]) == 1:
new_name = var_name + "@RENAME@" + \
str(var_rename_count[var_name])
var_rename_count[var_name] = var_rename_count[var_name] + 1
var_rename_count[var_name] += 1
# rename original var_name
var_inputs[var_name][0] = new_name
_rename_arg_(grad_op_descs, var_name, new_name, 0, idx)
renamed_vars[var_name][0] = new_name
_rename_arg_(op_descs, var_name, new_name, 0, idx)
_rename_arg_(pending_sum_ops, var_name, new_name)
new_name = var_name + "@RENAME@" + \
str(var_rename_count[var_name])
var_rename_count[var_name] = var_rename_count[var_name] + 1
var_rename_count[var_name] += 1
op_desc.rename_output(var_name, new_name)
var_inputs[var_name].append(new_name)
for var_name, inputs in var_inputs.iteritems():
renamed_vars[var_name].append(new_name)
for var_name, inputs in renamed_vars.iteritems():
if len(inputs) > 1:
pending_sum_ops.append((_create_op_desc_(
op_type="sum",
inputs={"X": inputs},
outputs={"Out": [var_name]},
attrs={}), len(grad_op_descs)))
"sum", {"X": inputs}, {"Out": [var_name]}, {}), len(op_descs)))
# sum_op descs are sorted according to their insert position
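# iterate in reverse so that an insertion does not shift the positions
# recorded for the sum_ops that come before it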
for p in reversed(pending_sum_ops):
grad_op_descs.insert(p[1], p[0])
# Remove ops whose outputs are all in no_grad_set
grad_op_descs = filter(
lambda op_desc: not _is_all_in_set_(op_desc.output_arg_names(), no_grad_set[block.idx]),
grad_op_descs)
op_descs.insert(p[1], p[0])
return op_descs
def _remove_no_grad_branch_(op_descs, no_grad_set):
# Remove ops whose outputs are all in no_grad_dict
op_descs = filter(
lambda op_desc: not _all_in_set_(op_desc.output_arg_names(), no_grad_set),
op_descs)
# Insert fill_zeros_like_op
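# A grad input that lies in the no-grad set is no longer produced by any
# op, so it is fed with zeros shaped like its forward variable.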
to_insert = []
for idx, op_desc in enumerate(grad_op_descs):
for idx, op_desc in enumerate(op_descs):
for arg in op_desc.input_arg_names():
if core.grad_var_suffix() in arg and arg in no_grad_set[block.idx]:
to_insert.append((arg, idx))
for ele in reversed(to_insert):
arg = ele[0]
fill_zeros_like_op = _create_op_desc_(
"fill_zeros_like", {"X": [_strip_grad_suffix_(arg)]}, {"Y": [arg]},
{})
grad_op_descs.insert(ele[1], fill_zeros_like_op)
if core.grad_var_suffix() in arg and arg in no_grad_set:
to_insert.append((_create_op_desc_("fill_zeros_like", {
"X": [_strip_grad_suffix_(arg)]
}, {"Y": [arg]}, {}), idx))
map(lambda p: op_descs.insert(p[1], p[0]), reversed(to_insert))
return op_descs
def _append_backward_ops_(target,
block,
target_block,
no_grad_dict,
grad_to_var,
callback=None):
grad_op_descs = []
program = block.program
for op in reversed(block.ops):
grad_sub_block_list = []
# If the op has its own sub-block, deal with the sub-block first
if op.has_attr("sub_block"):
sub_block = program.block(op.block_attr("sub_block"))
grad_sub_block = program.create_block(parent_idx=sub_block.idx)
_append_backward_ops_(target, sub_block, grad_sub_block,
no_grad_dict, grad_to_var, callback)
grad_sub_block_list.append(grad_sub_block.desc)
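# the grad sub-block's desc is handed to get_grad_op_desc below so the
# generated grad op can reference the backward sub-block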
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op.desc, no_grad_dict[block.idx], grad_sub_block_list)
grad_op_descs.extend(grad_op_desc)
grad_to_var.update(op_grad_to_var)
grad_op_descs = _addup_repetitive_outputs_(grad_op_descs)
grad_op_descs = _remove_no_grad_branch_(grad_op_descs,
no_grad_dict[block.idx])
if target_block.idx == 0:
grad_target_name = _append_grad_suffix_(target.name)
# target_block.desc.var(grad_target_name.encode("ascii"))
grad_op_descs.insert(
0,
_create_op_desc_(
op_type="fill_constant",
inputs={},
outputs={"Out": [grad_target_name]},
attrs={"shape": [1],
"value": 1.0,
"dtype": target.dtype}))
_create_op_desc_("fill_constant", {}, {
"Out": [_append_grad_suffix_(target.name)]
}, {"shape": [1],
"value": 1.0,
"dtype": target.dtype}))
# append op_desc in grad_op_descs to target_block
for op_desc in grad_op_descs:
new_op_desc = target_block.desc.append_op()
new_op_desc.copy_from(op_desc)
return grad_to_var
def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
for op_idx in range(start_op_idx, block.desc.op_size()):
......@@ -194,15 +195,15 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
_infer_var_data_type_(arg, block)
def append_backward(loss, parameter_list=None, no_grad_set=None):
def append_backward(loss, parameter_list=None, no_grad_dict=None):
"""
Create and add gradient Operators in BlockDesc to compute
gradients of `loss` for parameters in parameter_list
:param loss: a variable generated by the cost function.
:type loss: Variable
:param no_grad_set: variables that should not have gradients created
:type no_grad_set: set
:param no_grad_dict: variables that should not have gradients created
:type no_grad_dict: set
:param parameter_list: parameters that need to compute gradient and
update to optimize the loss.
:type: list
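A minimal usage sketch (`avg_cost` is an illustrative scalar loss
Variable, not something defined in this module):

    append_backward(loss=avg_cost)

After the call, the gradient operators for `avg_cost` have been appended
to the program's blocks.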
......@@ -212,8 +213,8 @@ def append_backward(loss, parameter_list=None, no_grad_set=None):
assert isinstance(loss, framework.Variable)
program = loss.block.program
if no_grad_set is None:
no_grad_set = dict()
if no_grad_dict is None:
no_grad_dict = dict()
assert isinstance(program, framework.Program)
for block in program.blocks:
assert isinstance(block, framework.Block)
......@@ -222,19 +223,21 @@ def append_backward(loss, parameter_list=None, no_grad_set=None):
assert isinstance(var, framework.Variable)
if var.stop_gradient:
block_no_grad_set.add(_append_grad_suffix_(var.name))
no_grad_set[block.idx] = block_no_grad_set
else:
# FIX ME
no_grad_set = {0: no_grad_set}
no_grad_dict[block.idx] = block_no_grad_set
elif isinstance(no_grad_dict, set):
no_grad_dict = {0: no_grad_dict}
grad_info_map = dict()
root_block = program.block(0)
fwd_op_num = root_block.desc.op_size()
current_block_idx = program.current_block_idx
grad_to_var = _append_backward_ops_(loss, root_block, root_block,
no_grad_set)
grad_to_var = dict()
_append_backward_ops_(loss, root_block, root_block, no_grad_dict,
grad_to_var)
_append_backward_vars_(root_block, fwd_op_num, grad_to_var, grad_info_map)
program.current_block_idx = current_block_idx
program.sync_with_cpp()
......