diff --git a/python/paddle/v2/fluid/backward.py b/python/paddle/v2/fluid/backward.py
index 29243c90e872ca4a7d1ce6f84f6297b865655da1..2946ef19678d0e4dfdeae8f5431973921b871436 100644
--- a/python/paddle/v2/fluid/backward.py
+++ b/python/paddle/v2/fluid/backward.py
@@ -199,6 +199,50 @@ def _remove_no_grad_branch_(op_descs, no_grad_set):
     return op_descs
 
 
+def _callback_lookup_(op):
+    """
+    Only used in _append_backward_ops_.
+    Build and return the callback for ops that need one. For example
+
+    parallel_do:           AllReduce
+
+    :param op: op that owns the sub-block being processed; only its type
+        and, for parallel_do, its 'parameters' input are read.
+    :return: a callable taking (block, context), or None if `op` needs
+        no callback.
+    """
+    if op.type != 'parallel_do':
+        return None
+
+    # A set keeps the per-output membership test in __call__ O(1).
+    param_grad_names = {n + "@GRAD" for n in op.input('parameters')}
+
+    class ParallelDoCallBack(object):
+        def __init__(self, param_grad_names):
+            self.has_inserted_nccl_init = False
+            self.param_grad_names = param_grad_names
+
+        def __call__(self, block, context):
+            # TODO(tonyyang-svail): insert nccl init
+
+            # `context` is the grad op desc just appended to `block`.
+            for o_param in context.output_names():
+                for o_argu in context.output(o_param):
+                    if o_argu in self.param_grad_names:
+                        # NOTE(review): fill_constant looks like a stand-in
+                        # for the eventual all-reduce op -- confirm.
+                        op_desc = block.desc.append_op()
+                        framework.Operator(
+                            block,
+                            type='fill_constant',
+                            desc=op_desc,
+                            inputs={},
+                            attrs={'shape': [1], },
+                            outputs={'Out': [block.create_var()]})
+
+    return ParallelDoCallBack(param_grad_names)
+
+
 def _append_backward_ops_(block,
                           ops,
                           target_block,
@@ -239,7 +283,8 @@ def _append_backward_ops_(block,
             sub_block = program.block(op.block_attr("sub_block"))
             grad_sub_block = program.create_block(parent_idx=sub_block.idx)
             _append_backward_ops_(sub_block, sub_block.ops, grad_sub_block,
-                                  no_grad_dict, grad_to_var)
+                                  no_grad_dict, grad_to_var,
+                                  _callback_lookup_(op))
             grad_sub_block_list.append(grad_sub_block.desc)
 
         # Getting op's corresponding grad_op
@@ -258,7 +303,7 @@ def _append_backward_ops_(block,
     for op_desc in grad_op_descs:
         new_op_desc = target_block.desc.append_op()
         new_op_desc.copy_from(op_desc)
-        callback(block=target_block, context=grad_to_var)
+        callback(block=target_block, context=new_op_desc)
 
 
 def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):