Commit 5b9dbbb9 authored by F fengjiayi

code clean

Parent febb7251
 from paddle.v2.fluid import framework as framework
 from . import core
 import collections
-import pdb

 __all__ = ['append_backward']

@@ -45,7 +44,7 @@ def _infer_var_data_type_(var_name, block):
         grad_var.set_dtype(core.DataType.FP32)


-def _is_all_in_set_(cands, s):
+def _all_in_set_(cands, s):
     for c in cands:
         if not c in s:
             return False
@@ -61,112 +60,114 @@ def _append_grad_suffix_(name):
     return name + core.grad_var_suffix()


-def _append_backward_ops_(target,
-                          block,
-                          target_block,
-                          no_grad_set,
-                          callback=None):
-    grad_op_descs = []
-    grad_to_var = dict()
-    program = block.program
-    for each_op in reversed(block.ops):
-        grad_sub_block_list = []
-        if each_op.has_attr("sub_block"):
-            sub_block_idx = each_op.block_attr("sub_block")
-            sub_block = program.block(sub_block_idx)
-            grad_sub_block = program.create_block(parent_idx=sub_block_idx)
-            sub_grad_to_var = _append_backward_ops_(
-                target, sub_block, grad_sub_block, no_grad_set, callback)
-            grad_to_var = dict(grad_to_var, **sub_grad_to_var)
-            grad_sub_block_list.append(grad_sub_block.desc)
-        grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
-            each_op.desc, no_grad_set[block.idx], grad_sub_block_list)
-        grad_op_descs.append(grad_op_desc)
-        grad_to_var = dict(grad_to_var, **op_grad_to_var)
-    # grad_op_descs = [[op1_g1, op1_g2], [op2_g], ...]
-    # flatten grad_op_descs
-    grad_op_descs = [op for sublist in grad_op_descs for op in sublist]  # ?????
+def _addup_repetitive_outputs_(op_descs):
+    # In backward part, an variable my be the output of more than one ops.
+    # In this case, the variable should be the accumulation of all the outputs.
+    # We adopt adding `sum_op`s to implement the accumulate.
     pending_sum_ops = []
     var_rename_count = collections.defaultdict(int)
-    var_inputs = collections.defaultdict(list)
-    for idx, op_desc in enumerate(grad_op_descs):
+    renamed_vars = collections.defaultdict(list)
+    for idx, op_desc in enumerate(op_descs):
         for var_name in op_desc.input_arg_names():
-            if len(var_inputs[var_name]) > 1:
-                pending_sum_ops.append((_create_op_desc_(
-                    op_type="sum",
-                    inputs={"X": var_inputs[var_name]},
-                    outputs={"Out": [var_name]},
-                    attrs={}), idx))
-                var_inputs[var_name] = [var_name]
+            if len(renamed_vars[var_name]) > 1:
+                pending_sum_ops.append(
+                    (_create_op_desc_("sum", {"X": renamed_vars[var_name]},
+                                      {"Out": [var_name]}, {}), idx))
+                renamed_vars[var_name] = [var_name]
         for var_name in op_desc.output_arg_names():
-            if var_name in op_desc.input_arg_names():
-                # in place operator
+            if var_name == core.empty_var_name(
+            ) or var_name in op_desc.input_arg_names():
+                # empty variable or inplace op
                 continue
-            if var_name == core.empty_var_name() or len(var_inputs[
-                    var_name]) == 0:
+            if len(renamed_vars[var_name]) == 0:
                 # it's the first time we get the variable
-                var_inputs[var_name] = [var_name]
+                renamed_vars[var_name] = [var_name]
             else:
-                if len(var_inputs[var_name]) == 1:
+                if len(renamed_vars[var_name]) == 1:
                     new_name = var_name + "@RENAME@" + \
                         str(var_rename_count[var_name])
-                    var_rename_count[var_name] = var_rename_count[var_name] + 1
+                    var_rename_count[var_name] += 1
                     # rename original var_name
-                    var_inputs[var_name][0] = new_name
-                    _rename_arg_(grad_op_descs, var_name, new_name, 0, idx)
+                    renamed_vars[var_name][0] = new_name
+                    _rename_arg_(op_descs, var_name, new_name, 0, idx)
                     _rename_arg_(pending_sum_ops, var_name, new_name)

                 new_name = var_name + "@RENAME@" + \
                     str(var_rename_count[var_name])
-                var_rename_count[var_name] = var_rename_count[var_name] + 1
+                var_rename_count[var_name] += 1
                 op_desc.rename_output(var_name, new_name)
-                var_inputs[var_name].append(new_name)
-    for var_name, inputs in var_inputs.iteritems():
+                renamed_vars[var_name].append(new_name)
+    for var_name, inputs in renamed_vars.iteritems():
         if len(inputs) > 1:
             pending_sum_ops.append((_create_op_desc_(
-                op_type="sum",
-                inputs={"X": inputs},
-                outputs={"Out": [var_name]},
-                attrs={}), len(grad_op_descs)))
+                "sum", {"X": inputs}, {"Out": [var_name]}, {}), len(op_descs)))
     # sum_op descs are sorted according to their insert position
     for p in reversed(pending_sum_ops):
-        grad_op_descs.insert(p[1], p[0])
-    # Remove ops whose outputs are all in no_grad_set
-    grad_op_descs = filter(
-        lambda op_desc: not _is_all_in_set_(op_desc.output_arg_names(), no_grad_set[block.idx]),
-        grad_op_descs)
+        op_descs.insert(p[1], p[0])
+
+    return op_descs
+
+
+def _remove_no_grad_branch_(op_descs, no_grad_set):
+    # Remove ops whose outputs are all in no_grad_dict
+    op_descs = filter(
+        lambda op_desc: not _all_in_set_(op_desc.output_arg_names(), no_grad_set),
+        op_descs)
     # Insert fill_zeros_like_op
     to_insert = []
-    for idx, op_desc in enumerate(grad_op_descs):
+    for idx, op_desc in enumerate(op_descs):
         for arg in op_desc.input_arg_names():
-            if core.grad_var_suffix() in arg and arg in no_grad_set[block.idx]:
-                to_insert.append((arg, idx))
-    for ele in reversed(to_insert):
-        arg = ele[0]
-        fill_zeros_like_op = _create_op_desc_(
-            "fill_zeros_like", {"X": [_strip_grad_suffix_(arg)]}, {"Y": [arg]},
-            {})
-        grad_op_descs.insert(ele[1], fill_zeros_like_op)
+            if core.grad_var_suffix() in arg and arg in no_grad_set:
+                to_insert.append((_create_op_desc_("fill_zeros_like", {
+                    "X": [_strip_grad_suffix_(arg)]
+                }, {"Y": [arg]}, {}), idx))
+
+    map(lambda p: op_descs.insert(p[1], p[0]), reversed(to_insert))
+
+    return op_descs
+
+
+def _append_backward_ops_(target,
+                          block,
+                          target_block,
+                          no_grad_dict,
+                          grad_to_var,
+                          callback=None):
+    grad_op_descs = []
+    program = block.program
+    for op in reversed(block.ops):
+        grad_sub_block_list = []
+        # If the op has its own sub-block, deal with the sub-block first
+        if op.has_attr("sub_block"):
+            sub_block = program.block(op.block_attr("sub_block"))
+            grad_sub_block = program.create_block(parent_idx=sub_block.idx)
+            _append_backward_ops_(target, sub_block, grad_sub_block,
+                                  no_grad_dict, grad_to_var, callback)
+            grad_sub_block_list.append(grad_sub_block.desc)
+
+        grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
+            op.desc, no_grad_dict[block.idx], grad_sub_block_list)
+        grad_op_descs.extend(grad_op_desc)
+        grad_to_var.update(op_grad_to_var)
+
+    grad_op_descs = _addup_repetitive_outputs_(grad_op_descs)
+
+    grad_op_descs = _remove_no_grad_branch_(grad_op_descs,
+                                            no_grad_dict[block.idx])

     if target_block.idx == 0:
-        grad_target_name = _append_grad_suffix_(target.name)
-        # target_block.desc.var(grad_target_name.encode("ascii"))
         grad_op_descs.insert(
             0,
-            _create_op_desc_(
-                op_type="fill_constant",
-                inputs={},
-                outputs={"Out": [grad_target_name]},
-                attrs={"shape": [1],
-                       "value": 1.0,
-                       "dtype": target.dtype}))
+            _create_op_desc_("fill_constant", {}, {
+                "Out": [_append_grad_suffix_(target.name)]
+            }, {"shape": [1],
+                "value": 1.0,
+                "dtype": target.dtype}))
+
+    # append op_desc in grad_op_descs to target_block
     for op_desc in grad_op_descs:
         new_op_desc = target_block.desc.append_op()
         new_op_desc.copy_from(op_desc)
-    return grad_to_var


 def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
     for op_idx in range(start_op_idx, block.desc.op_size()):
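
When several gradient ops write the same gradient variable, the new `_addup_repetitive_outputs_` helper renames each producer's output with an `@RENAME@<n>` suffix and inserts a `sum` op that writes the accumulated result back under the original name. The standalone sketch below illustrates that idea on toy op descriptions (plain dicts); the function name, the dict layout, and the `mul_grad`/`add_grad` type strings are made up for the example, and the real helper additionally chooses where to insert each `sum` op so that later readers see the accumulated value, and skips empty variables and inplace ops.

import collections


def addup_repetitive_outputs_demo(op_descs):
    # Count how many ops write each variable.
    out_count = collections.Counter(
        var for op in op_descs
        for var_list in op["outputs"].values() for var in var_list)

    # Rename every occurrence of a repeated output to <name>@RENAME@<n>.
    renamed = collections.defaultdict(list)
    for op in op_descs:
        for var_list in op["outputs"].values():
            for i, var in enumerate(var_list):
                if out_count[var] > 1:
                    new_name = "%s@RENAME@%d" % (var, len(renamed[var]))
                    var_list[i] = new_name
                    renamed[var].append(new_name)

    # Accumulate the renamed copies back into the original variable name.
    for var, copies in renamed.items():
        op_descs.append({
            "type": "sum",
            "inputs": {"X": copies},
            "outputs": {"Out": [var]},
        })
    return op_descs


ops = [
    {"type": "mul_grad", "inputs": {"X": ["a"]}, "outputs": {"Out": ["x@GRAD"]}},
    {"type": "add_grad", "inputs": {"X": ["b"]}, "outputs": {"Out": ["x@GRAD"]}},
]
print(addup_repetitive_outputs_demo(ops))
# The two producers now write x@GRAD@RENAME@0 and x@GRAD@RENAME@1, and a
# trailing sum op restores the accumulated value under the name x@GRAD.
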
@@ -194,15 +195,15 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
                 _infer_var_data_type_(arg, block)


-def append_backward(loss, parameter_list=None, no_grad_set=None):
+def append_backward(loss, parameter_list=None, no_grad_dict=None):
     """
     Create and add gradient Operators in BlockDesc to compute
     gradients of `loss` for parameters in parameter_list

     :param loss: an variable generated by cost function.
     :type loss: Variable
-    :param no_grad_set: variable that should not create gradient
-    :type no_grad_set: set
+    :param no_grad_dict: variable that should not create gradient
+    :type no_grad_dict: set
     :param parameter_list: parameters that need to compute gradient and
                            update to optimize the lost.
     :type: list
@@ -212,8 +213,8 @@ def append_backward(loss, parameter_list=None, no_grad_set=None):
     assert isinstance(loss, framework.Variable)

     program = loss.block.program
-    if no_grad_set is None:
-        no_grad_set = dict()
+    if no_grad_dict is None:
+        no_grad_dict = dict()
         assert isinstance(program, framework.Program)
         for block in program.blocks:
             assert isinstance(block, framework.Block)
@@ -222,19 +223,21 @@ def append_backward(loss, parameter_list=None, no_grad_set=None):
                 assert isinstance(var, framework.Variable)
                 if var.stop_gradient:
                     block_no_grad_set.add(_append_grad_suffix_(var.name))
-            no_grad_set[block.idx] = block_no_grad_set
-    else:
-        # FIX ME
-        no_grad_set = {0: no_grad_set}
+            no_grad_dict[block.idx] = block_no_grad_set
+    elif isinstance(no_grad_dict, set):
+        no_grad_dict = {0: no_grad_dict}

     grad_info_map = dict()
     root_block = program.block(0)

     fwd_op_num = root_block.desc.op_size()
     current_block_idx = program.current_block_idx
-    grad_to_var = _append_backward_ops_(loss, root_block, root_block,
-                                        no_grad_set)
+    grad_to_var = dict()
+
+    _append_backward_ops_(loss, root_block, root_block, no_grad_dict,
+                          grad_to_var)
+
     _append_backward_vars_(root_block, fwd_op_num, grad_to_var, grad_info_map)

     program.current_block_idx = current_block_idx
     program.sync_with_cpp()
...
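
For context, a minimal usage sketch of the reworked entry point, assuming the fluid layer helpers of this release (`fluid.layers.data` / `fc` / `square_error_cost` / `mean`) keep the keyword names shown here; the network itself is only illustrative.

import paddle.v2.fluid as fluid
from paddle.v2.fluid.backward import append_backward

x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1)
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(x=cost)

# Appends gradient ops for avg_cost to the default program. With this patch
# the third argument is named no_grad_dict; a plain set is still accepted and
# is treated as {0: the_set}, i.e. it applies to the root block.
append_backward(avg_cost)

Gradient variables listed in no_grad_dict (or derived from Variables with stop_gradient=True) are pruned by _remove_no_grad_branch_; where a pruned gradient is still consumed by a surviving op, a fill_zeros_like op supplies zeros in its place.
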