Commit f2129b19 authored by Yang Yang

pass run time

Parent e9ddaaba
@@ -199,6 +199,15 @@ def _remove_no_grad_branch_(op_descs, no_grad_set):
     return op_descs
 
+
+import proto.framework_pb2 as framework_pb2
+
+
+def serialize_op_decs(op_desc):
+    protostr = op_desc.serialize_to_string()
+    proto = framework_pb2.OpDesc.FromString(str(protostr))
+    return proto.__str__()
+
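Note on the hunk above: serialize_op_decs round-trips a C++-side OpDesc through its protobuf wire format to obtain a human-readable text dump. A toy sketch of the same round trip using only the protoc-generated classes (module and field names assumed to follow Paddle's framework.proto; this is an illustration, not the code under review):

```python
# Sketch only: framework_pb2 is the module protoc generates from
# Paddle's framework.proto; OpDesc and its fields follow that schema.
import proto.framework_pb2 as framework_pb2

desc = framework_pb2.OpDesc()
desc.type = "ncclAllReduce"

wire = desc.SerializeToString()                 # wire-format bytes
parsed = framework_pb2.OpDesc.FromString(wire)  # parse them back
print(parsed)  # text-format dump, handy when debugging generated ops
```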
 def _callback_lookup_(op):
     """
     Only used in _append_backward_ops_
@@ -209,7 +218,6 @@ def _callback_lookup_(op):
     :param op:
     :return: callback function
     """
-    print(op.type)
     if op.type == 'parallel_do':
         param_names = set(op.input('parameters'))
         param_grad_names = [n + "@GRAD" for n in param_names]
@@ -220,20 +228,38 @@
                 self.param_grad_names = param_grad_names
 
             def __call__(self, block, context):
-                # TODO(tonyyang-svail): insert nccl init
-                for o_param in context.output_names():
-                    for o_argu in context.output(o_param):
+                # move to parallel_do.py
+                # # TODO(tonyyang-svail): insert nccl init
+                if not self.has_inserted_nccl_init:
+                    global_block = block.program.global_block()
+                    op_desc = global_block.desc.append_op()
+                    var_desc = global_block.desc.var('nccl_com')
+                    var_desc.set_type(core.VarDesc.VarType.NCCL_COM)
+                    self.nccl_com = global_block.create_var(
+                        name='nccl_com', type=core.VarDesc.VarType.NCCL_COM)
+                    framework.Operator(
+                        global_block,
+                        type='ncclInit',
+                        desc=op_desc,
+                        inputs={},
+                        outputs={'Communicator': [self.nccl_com]})
+                    self.has_inserted_nccl_init = True
+
+                current_op_desc = context["__current_op_desc__"]
+                # print(serialize_op_decs(context))
+                for o_param in current_op_desc.output_names():
+                    for o_argu in current_op_desc.output(o_param):
                         if o_argu in self.param_grad_names:
-                            print("reduce", o_argu)
+                            # print("reduce", o_argu)
                             op_desc = block.desc.append_op()
+                            framework.Operator(
+                                block,
+                                type='fill_constant',
+                                desc=op_desc,
+                                inputs={},
+                                attrs={'shape': [1], },
+                                outputs={'Out': [block.create_var()]})
                             op_desc.set_type("ncclAllReduce")
                             op_desc.set_input("X", [o_argu])
+                            # FIXME(tonyyang-svail):
+                            #     Looks like nccl_com has been changed to nccl_com_0
+                            op_desc.set_input("Communicator", ['nccl_com_0'])
                             out_var = block.create_var()
                             op_desc.set_output("Out", [out_var.name])
                             op_desc.set_attr("reduction", "ncclSum")
 
         return ParallelDoCallBack(param_grad_names)
     else:
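Note on _callback_lookup_: for a parallel_do op it returns a ParallelDoCallBack instance that the backward pass invokes after every gradient op it appends; the callback lazily inserts one ncclInit into the global block, then emits an ncclAllReduce for each parameter gradient it encounters. A runnable toy model of that control flow (plain Python; every name below is hypothetical rather than Paddle API):

```python
# Toy model of the callback contract in this diff; blocks are plain
# lists and "ops" are dicts, so this runs anywhere.
class ParamGradAllReduceCallback(object):
    def __init__(self, param_grad_names):
        self.param_grad_names = set(param_grad_names)
        self.has_inserted_nccl_init = False  # mirrors the flag above

    def __call__(self, block, context):
        # lazy one-time communicator init, like the ncclInit insertion
        if not self.has_inserted_nccl_init:
            block.append(("ncclInit",))
            self.has_inserted_nccl_init = True
        # the backward pass hands over the op it just appended
        current_op = context["__current_op_desc__"]
        for out_name in current_op["outputs"]:
            if out_name in self.param_grad_names:
                block.append(("ncclAllReduce", out_name))


block = []
callback = ParamGradAllReduceCallback(["fc_w@GRAD"])
for grad_op in [{"outputs": ["tmp@GRAD"]}, {"outputs": ["fc_w@GRAD"]}]:
    block.append(("grad_op", grad_op["outputs"][0]))
    callback(block=block, context={"__current_op_desc__": grad_op})
print(block)
# [('grad_op', 'tmp@GRAD'), ('ncclInit',), ('grad_op', 'fc_w@GRAD'),
#  ('ncclAllReduce', 'fc_w@GRAD')]
```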
@@ -300,7 +326,8 @@ def _append_backward_ops_(block,
         for op_desc in grad_op_descs:
             new_op_desc = target_block.desc.append_op()
             new_op_desc.copy_from(op_desc)
-            callback(block=target_block, context=new_op_desc)
+            grad_to_var["__current_op_desc__"] = new_op_desc
+            callback(block=target_block, context=grad_to_var)
def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
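Note on the hunk above: the callback contract is widened so that context is no longer the bare new_op_desc but the whole grad_to_var dict, with the just-appended op passed under the reserved key "__current_op_desc__". A minimal sketch of that handoff, with dicts standing in for the C++ descriptors (all scaffolding hypothetical):

```python
# Sketch of the widened contract: a dict stands in for grad_to_var,
# and ops are dicts rather than C++ OpDescs.
def append_backward_ops(grad_op_descs, target_block, grad_to_var, callback):
    for op_desc in grad_op_descs:
        new_op_desc = dict(op_desc)  # stands in for append_op + copy_from
        target_block.append(new_op_desc)
        # reserved key carries the current op alongside the grad mapping
        grad_to_var["__current_op_desc__"] = new_op_desc
        callback(block=target_block, context=grad_to_var)


def debug_callback(block, context):
    print(context["__current_op_desc__"])

ops, mapping = [], {"x@GRAD": "x"}
append_backward_ops([{"type": "mul_grad"}], ops, mapping, debug_callback)
```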
@@ -336,6 +363,8 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
                 continue
             grad_info_map[grad_to_var[grad_var_name]] = (grad_var_name, block)
         # infer_shape and infer_type
+        if op_desc.type() == 'ncclInit':
+            continue
         op_desc.infer_var_type(block.desc)
         op_desc.infer_shape(block.desc)
         for arg in op_desc.output_arg_names():
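Note on the final hunk: ncclInit is excluded from infer_var_type/infer_shape because its only output is the NCCL communicator, which has no tensor type or shape to infer. A self-contained sketch of the guard, with stand-in inference functions (all names hypothetical):

```python
# Stand-ins for block.desc and the C++ inference entry points.
def infer_var_type(op, block_desc):
    block_desc[op["out"]] = "lod_tensor"

def infer_shape(op, block_desc):
    block_desc[op["out"] + "@shape"] = op.get("shape", [1])

SKIP_INFERENCE = {"ncclInit"}  # communicator output has no tensor shape

def infer_all(op_descs, block_desc):
    for op in op_descs:
        if op["type"] in SKIP_INFERENCE:
            continue  # same early-continue as the diff above
        infer_var_type(op, block_desc)
        infer_shape(op, block_desc)

block_desc = {}
infer_all([{"type": "ncclInit", "out": "nccl_com"},
           {"type": "fill_constant", "out": "x", "shape": [1]}], block_desc)
print(block_desc)  # only 'x' gets a type and shape
```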
......