提交 1c91574b 编写于 作者: Y Yang Yang

backward insert callback pass compile

上级 36da5295
...@@ -199,6 +199,47 @@ def _remove_no_grad_branch_(op_descs, no_grad_set): ...@@ -199,6 +199,47 @@ def _remove_no_grad_branch_(op_descs, no_grad_set):
return op_descs return op_descs
def _callback_lookup_(op):
"""
Only used in _append_backward_ops_
Build and returns a callback function for certain op. For example
parallel_do: AllReduce
:param op:
:return: callback function
"""
print(op.type)
if op.type == 'parallel_do':
param_names = set(op.input('parameters'))
param_grad_names = [n + "@GRAD" for n in param_names]
class ParallelDoCallBack(object):
def __init__(self, param_grad_names):
self.has_inserted_nccl_init = False
self.param_grad_names = param_grad_names
def __call__(self, block, context):
# TODO(tonyyang-svail): insert nccl init
for o_param in context.output_names():
for o_argu in context.output(o_param):
if o_argu in self.param_grad_names:
print("reduce", o_argu)
op_desc = block.desc.append_op()
framework.Operator(
block,
type='fill_constant',
desc=op_desc,
inputs={},
attrs={'shape': [1], },
outputs={'Out': [block.create_var()]})
return ParallelDoCallBack(param_grad_names)
else:
return None
def _append_backward_ops_(block, def _append_backward_ops_(block,
ops, ops,
target_block, target_block,
...@@ -239,7 +280,8 @@ def _append_backward_ops_(block, ...@@ -239,7 +280,8 @@ def _append_backward_ops_(block,
sub_block = program.block(op.block_attr("sub_block")) sub_block = program.block(op.block_attr("sub_block"))
grad_sub_block = program.create_block(parent_idx=sub_block.idx) grad_sub_block = program.create_block(parent_idx=sub_block.idx)
_append_backward_ops_(sub_block, sub_block.ops, grad_sub_block, _append_backward_ops_(sub_block, sub_block.ops, grad_sub_block,
no_grad_dict, grad_to_var) no_grad_dict, grad_to_var,
_callback_lookup_(op))
grad_sub_block_list.append(grad_sub_block.desc) grad_sub_block_list.append(grad_sub_block.desc)
# Getting op's corresponding grad_op # Getting op's corresponding grad_op
...@@ -258,7 +300,7 @@ def _append_backward_ops_(block, ...@@ -258,7 +300,7 @@ def _append_backward_ops_(block,
for op_desc in grad_op_descs: for op_desc in grad_op_descs:
new_op_desc = target_block.desc.append_op() new_op_desc = target_block.desc.append_op()
new_op_desc.copy_from(op_desc) new_op_desc.copy_from(op_desc)
callback(block=target_block, context=grad_to_var) callback(block=target_block, context=new_op_desc)
def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map): def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册