diff --git a/python/paddle/v2/fluid/backward.py b/python/paddle/v2/fluid/backward.py
index e05750c5bd5a4e253c123e04b37613a5470034c6..b90949838ea5ccaae3111164214f15a8b5579e87 100644
--- a/python/paddle/v2/fluid/backward.py
+++ b/python/paddle/v2/fluid/backward.py
@@ -1,7 +1,6 @@
 from paddle.v2.fluid import framework as framework
 from . import core
 import collections
-import pdb
 
 __all__ = ['append_backward']
 
@@ -45,7 +44,7 @@ def _infer_var_data_type_(var_name, block):
         grad_var.set_dtype(core.DataType.FP32)
 
 
-def _is_all_in_set_(cands, s):
+def _all_in_set_(cands, s):
     for c in cands:
         if not c in s:
             return False
@@ -61,112 +60,114 @@ def _append_grad_suffix_(name):
     return name + core.grad_var_suffix()
 
 
-def _append_backward_ops_(target,
-                          block,
-                          target_block,
-                          no_grad_set,
-                          callback=None):
-    grad_op_descs = []
-    grad_to_var = dict()
-    program = block.program
-    for each_op in reversed(block.ops):
-        grad_sub_block_list = []
-        if each_op.has_attr("sub_block"):
-            sub_block_idx = each_op.block_attr("sub_block")
-            sub_block = program.block(sub_block_idx)
-            grad_sub_block = program.create_block(parent_idx=sub_block_idx)
-            sub_grad_to_var = _append_backward_ops_(
-                target, sub_block, grad_sub_block, no_grad_set, callback)
-            grad_to_var = dict(grad_to_var, **sub_grad_to_var)
-            grad_sub_block_list.append(grad_sub_block.desc)
-        grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
-            each_op.desc, no_grad_set[block.idx], grad_sub_block_list)
-        grad_op_descs.append(grad_op_desc)
-        grad_to_var = dict(grad_to_var, **op_grad_to_var)
-    # grad_op_descs = [[op1_g1, op1_g2], [op2_g], ...]
-    # flatten grad_op_descs
-    grad_op_descs = [op for sublist in grad_op_descs for op in sublist]  # ?????
-
+def _addup_repetitive_outputs_(op_descs):
+    # In the backward pass, a variable may be the output of more than one op.
+    # In this case, the variable should be the accumulation of all the outputs.
+    # We add `sum_op`s to implement the accumulation.
     pending_sum_ops = []
     var_rename_count = collections.defaultdict(int)
-    var_inputs = collections.defaultdict(list)
-    for idx, op_desc in enumerate(grad_op_descs):
+    renamed_vars = collections.defaultdict(list)
+    for idx, op_desc in enumerate(op_descs):
         for var_name in op_desc.input_arg_names():
-            if len(var_inputs[var_name]) > 1:
-                pending_sum_ops.append((_create_op_desc_(
-                    op_type="sum",
-                    inputs={"X": var_inputs[var_name]},
-                    outputs={"Out": [var_name]},
-                    attrs={}), idx))
-                var_inputs[var_name] = [var_name]
+            if len(renamed_vars[var_name]) > 1:
+                pending_sum_ops.append(
+                    (_create_op_desc_("sum", {"X": renamed_vars[var_name]},
+                                      {"Out": [var_name]}, {}), idx))
+                renamed_vars[var_name] = [var_name]
         for var_name in op_desc.output_arg_names():
-            if var_name in op_desc.input_arg_names():
-                # in place operator
+            if var_name == core.empty_var_name(
+            ) or var_name in op_desc.input_arg_names():
+                # empty variable or inplace op
                 continue
-            if var_name == core.empty_var_name() or len(var_inputs[
-                    var_name]) == 0:
+            if len(renamed_vars[var_name]) == 0:
                 # it's the first time we get the variable
-                var_inputs[var_name] = [var_name]
+                renamed_vars[var_name] = [var_name]
            else:
-                if len(var_inputs[var_name]) == 1:
+                if len(renamed_vars[var_name]) == 1:
                     new_name = var_name + "@RENAME@" + \
                         str(var_rename_count[var_name])
-                    var_rename_count[var_name] = var_rename_count[var_name] + 1
+                    var_rename_count[var_name] += 1
                     # rename original var_name
-                    var_inputs[var_name][0] = new_name
-                    _rename_arg_(grad_op_descs, var_name, new_name, 0, idx)
+                    renamed_vars[var_name][0] = new_name
+                    _rename_arg_(op_descs, var_name, new_name, 0, idx)
                     _rename_arg_(pending_sum_ops, var_name, new_name)
                 new_name = var_name + "@RENAME@" + \
                     str(var_rename_count[var_name])
-                var_rename_count[var_name] = var_rename_count[var_name] + 1
+                var_rename_count[var_name] += 1
                 op_desc.rename_output(var_name, new_name)
-                var_inputs[var_name].append(new_name)
-    for var_name, inputs in var_inputs.iteritems():
+                renamed_vars[var_name].append(new_name)
+    for var_name, inputs in renamed_vars.iteritems():
         if len(inputs) > 1:
             pending_sum_ops.append((_create_op_desc_(
-                op_type="sum",
-                inputs={"X": inputs},
-                outputs={"Out": [var_name]},
-                attrs={}), len(grad_op_descs)))
+                "sum", {"X": inputs}, {"Out": [var_name]}, {}), len(op_descs)))
     # sum_op descs are sorted according to their insert position
     for p in reversed(pending_sum_ops):
-        grad_op_descs.insert(p[1], p[0])
-    # Remove ops whose outputs are all in no_grad_set
-    grad_op_descs = filter(
-        lambda op_desc: not _is_all_in_set_(op_desc.output_arg_names(), no_grad_set[block.idx]),
-        grad_op_descs)
+        op_descs.insert(p[1], p[0])
+
+    return op_descs
+
+
+def _remove_no_grad_branch_(op_descs, no_grad_set):
+    # Remove ops whose outputs are all in no_grad_set
+    op_descs = filter(
+        lambda op_desc: not _all_in_set_(op_desc.output_arg_names(), no_grad_set),
+        op_descs)
     # Insert fill_zeros_like_op
     to_insert = []
-    for idx, op_desc in enumerate(grad_op_descs):
+    for idx, op_desc in enumerate(op_descs):
         for arg in op_desc.input_arg_names():
-            if core.grad_var_suffix() in arg and arg in no_grad_set[block.idx]:
-                to_insert.append((arg, idx))
-    for ele in reversed(to_insert):
-        arg = ele[0]
-        fill_zeros_like_op = _create_op_desc_(
-            "fill_zeros_like", {"X": [_strip_grad_suffix_(arg)]}, {"Y": [arg]},
-            {})
-        grad_op_descs.insert(ele[1], fill_zeros_like_op)
+            if core.grad_var_suffix() in arg and arg in no_grad_set:
+                to_insert.append((_create_op_desc_("fill_zeros_like", {
+                    "X": [_strip_grad_suffix_(arg)]
+                }, {"Y": [arg]}, {}), idx))
+
+    map(lambda p: op_descs.insert(p[1], p[0]), reversed(to_insert))
+
+    return op_descs
+
+
+def _append_backward_ops_(target,
+                          block,
+                          target_block,
+                          no_grad_dict,
+                          grad_to_var,
+                          callback=None):
+    grad_op_descs = []
+    program = block.program
+    for op in reversed(block.ops):
+        grad_sub_block_list = []
+        # If the op has its own sub-block, deal with the sub-block first
+        if op.has_attr("sub_block"):
+            sub_block = program.block(op.block_attr("sub_block"))
+            grad_sub_block = program.create_block(parent_idx=sub_block.idx)
+            _append_backward_ops_(target, sub_block, grad_sub_block,
+                                  no_grad_dict, grad_to_var, callback)
+            grad_sub_block_list.append(grad_sub_block.desc)
+
+        grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
+            op.desc, no_grad_dict[block.idx], grad_sub_block_list)
+        grad_op_descs.extend(grad_op_desc)
+        grad_to_var.update(op_grad_to_var)
+
+    grad_op_descs = _addup_repetitive_outputs_(grad_op_descs)
+
+    grad_op_descs = _remove_no_grad_branch_(grad_op_descs,
+                                            no_grad_dict[block.idx])
 
     if target_block.idx == 0:
-        grad_target_name = _append_grad_suffix_(target.name)
-        # target_block.desc.var(grad_target_name.encode("ascii"))
         grad_op_descs.insert(
             0,
-            _create_op_desc_(
-                op_type="fill_constant",
-                inputs={},
-                outputs={"Out": [grad_target_name]},
-                attrs={"shape": [1],
-                       "value": 1.0,
-                       "dtype": target.dtype}))
+            _create_op_desc_("fill_constant", {}, {
+                "Out": [_append_grad_suffix_(target.name)]
+            }, {"shape": [1],
+                "value": 1.0,
+                "dtype": target.dtype}))
+
+    # append op_desc in grad_op_descs to target_block
     for op_desc in grad_op_descs:
         new_op_desc = target_block.desc.append_op()
         new_op_desc.copy_from(op_desc)
-    return grad_to_var
-
 
 
 def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
     for op_idx in range(start_op_idx, block.desc.op_size()):
@@ -194,15 +195,15 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
             _infer_var_data_type_(arg, block)
 
 
-def append_backward(loss, parameter_list=None, no_grad_set=None):
+def append_backward(loss, parameter_list=None, no_grad_dict=None):
     """
     Create and add gradient Operators in BlockDesc to compute gradients of
     `loss` for parameters in parameter_list
 
     :param loss: an variable generated by cost function.
     :type loss: Variable
-    :param no_grad_set: variable that should not create gradient
-    :type no_grad_set: set
+    :param no_grad_dict: variables that should not create gradient
+    :type no_grad_dict: set|dict
     :param parameter_list: parameters that need to compute gradient and
                            update to optimize the lost.
     :type: list
@@ -212,8 +213,8 @@ def append_backward(loss, parameter_list=None, no_grad_set=None):
 
     assert isinstance(loss, framework.Variable)
     program = loss.block.program
-    if no_grad_set is None:
-        no_grad_set = dict()
+    if no_grad_dict is None:
+        no_grad_dict = dict()
         assert isinstance(program, framework.Program)
         for block in program.blocks:
             assert isinstance(block, framework.Block)
@@ -222,19 +223,21 @@ def append_backward(loss, parameter_list=None, no_grad_set=None):
                 assert isinstance(var, framework.Variable)
                 if var.stop_gradient:
                     block_no_grad_set.add(_append_grad_suffix_(var.name))
-            no_grad_set[block.idx] = block_no_grad_set
-    else:
-        # FIX ME
-        no_grad_set = {0: no_grad_set}
+            no_grad_dict[block.idx] = block_no_grad_set
+    elif isinstance(no_grad_dict, set):
+        no_grad_dict = {0: no_grad_dict}
 
     grad_info_map = dict()
     root_block = program.block(0)
 
     fwd_op_num = root_block.desc.op_size()
     current_block_idx = program.current_block_idx
-    grad_to_var = _append_backward_ops_(loss, root_block, root_block,
-                                        no_grad_set)
+    grad_to_var = dict()
+
+    _append_backward_ops_(loss, root_block, root_block, no_grad_dict,
+                          grad_to_var)
     _append_backward_vars_(root_block, fwd_op_num, grad_to_var, grad_info_map)
+
    program.current_block_idx = current_block_idx
    program.sync_with_cpp()
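
The comment block at the top of _addup_repetitive_outputs_ in the patch explains the core idea: when a gradient variable is written by more than one backward op, each writer gets a renamed copy and a sum op accumulates the copies back into the original name. The standalone sketch below illustrates that idea only; it is not part of the patch. It uses plain dicts as stand-in op descriptions instead of the real OpDesc / _create_op_desc_ API, the function name addup_repetitive_outputs is invented for the example, and it omits the step where the real code also rewrites later reads of the renamed variable.

import collections


def addup_repetitive_outputs(ops):
    # Each toy "op" is a dict: {"type": str, "inputs": [names], "outputs": [names]}.
    renamed = collections.defaultdict(list)      # original name -> live copies
    rename_count = collections.defaultdict(int)  # how many copies were created
    result = []
    for op in ops:
        op = {"type": op["type"],
              "inputs": list(op["inputs"]),
              "outputs": list(op["outputs"])}
        new_outputs = []
        for name in op["outputs"]:
            if not renamed[name]:
                # First writer keeps the original name.
                renamed[name].append(name)
                new_outputs.append(name)
                continue
            if len(renamed[name]) == 1:
                # A second writer appeared: retroactively rename the first copy.
                first = name + "@RENAME@0"
                for prev in result:
                    prev["outputs"] = [first if o == name else o
                                       for o in prev["outputs"]]
                renamed[name][0] = first
                rename_count[name] = 1
            new_name = name + "@RENAME@" + str(rename_count[name])
            rename_count[name] += 1
            renamed[name].append(new_name)
            new_outputs.append(new_name)
        op["outputs"] = new_outputs
        result.append(op)
    # Accumulate all renamed copies back into the original variable with a sum op.
    for name, copies in sorted(renamed.items()):
        if len(copies) > 1:
            result.append({"type": "sum", "inputs": copies, "outputs": [name]})
    return result


ops = [{"type": "mul_grad", "inputs": ["out1@GRAD"], "outputs": ["x@GRAD"]},
       {"type": "add_grad", "inputs": ["out2@GRAD"], "outputs": ["x@GRAD"]}]
for op in addup_repetitive_outputs(ops):
    print(op)

Running the sketch shows both writers producing x@GRAD@RENAME@0 and x@GRAD@RENAME@1, followed by a sum op that restores x@GRAD, which mirrors what the pending_sum_ops machinery in the patch does with real op descs.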
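
As context for the append_backward docstring above, a minimal usage sketch follows. The layer calls assume the fit-a-line style fluid layers available around this version and are illustrative rather than part of this patch; only append_backward itself comes from the file being changed.

import paddle.v2.fluid as fluid
from paddle.v2.fluid.backward import append_backward

x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1)
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(x=cost)

# Appends the gradient ops and gradient variables for avg_cost to the program;
# optimizers normally call this for you inside minimize().
append_backward(loss=avg_cost)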