提交 f2129b19 编写于 作者: Y Yang Yang

pass run time

上级 e9ddaaba
...@@ -199,6 +199,15 @@ def _remove_no_grad_branch_(op_descs, no_grad_set): ...@@ -199,6 +199,15 @@ def _remove_no_grad_branch_(op_descs, no_grad_set):
return op_descs return op_descs
import proto.framework_pb2 as framework_pb2
def serialize_op_decs(op_desc):
protostr = op_desc.serialize_to_string()
proto = framework_pb2.OpDesc.FromString(str(protostr))
return proto.__str__()
def _callback_lookup_(op): def _callback_lookup_(op):
""" """
Only used in _append_backward_ops_ Only used in _append_backward_ops_
...@@ -209,7 +218,6 @@ def _callback_lookup_(op): ...@@ -209,7 +218,6 @@ def _callback_lookup_(op):
:param op: :param op:
:return: callback function :return: callback function
""" """
print(op.type)
if op.type == 'parallel_do': if op.type == 'parallel_do':
param_names = set(op.input('parameters')) param_names = set(op.input('parameters'))
param_grad_names = [n + "@GRAD" for n in param_names] param_grad_names = [n + "@GRAD" for n in param_names]
...@@ -220,20 +228,38 @@ def _callback_lookup_(op): ...@@ -220,20 +228,38 @@ def _callback_lookup_(op):
self.param_grad_names = param_grad_names self.param_grad_names = param_grad_names
def __call__(self, block, context): def __call__(self, block, context):
# TODO(tonyyang-svail): insert nccl init # move to parallel_do.py
# # TODO(tonyyang-svail): insert nccl init
for o_param in context.output_names(): if not self.has_inserted_nccl_init:
for o_argu in context.output(o_param): global_block = block.program.global_block()
op_desc = global_block.desc.append_op()
var_desc = global_block.desc.var('nccl_com')
var_desc.set_type(core.VarDesc.VarType.NCCL_COM)
self.nccl_com = global_block.create_var(
name='nccl_com', type=core.VarDesc.VarType.NCCL_COM)
framework.Operator(
global_block,
type='ncclInit',
desc=op_desc,
inputs={},
outputs={'Communicator': [self.nccl_com]})
self.has_inserted_nccl_init = True
current_op_desc = context["__current_op_desc__"]
# print(serialize_op_decs(context))
for o_param in current_op_desc.output_names():
for o_argu in current_op_desc.output(o_param):
if o_argu in self.param_grad_names: if o_argu in self.param_grad_names:
print("reduce", o_argu) # print("reduce", o_argu)
op_desc = block.desc.append_op() op_desc = block.desc.append_op()
framework.Operator( op_desc.set_type("ncclAllReduce")
block, op_desc.set_input("X", [o_argu])
type='fill_constant', # FIXME(tonyyang-svail):
desc=op_desc, # Looks like nccl_com has been changed to nccl_com_0
inputs={}, op_desc.set_input("Communicator", ['nccl_com_0'])
attrs={'shape': [1], }, out_var = block.create_var()
outputs={'Out': [block.create_var()]}) op_desc.set_output("Out", [out_var.name])
op_desc.set_attr("reduction", "ncclSum")
return ParallelDoCallBack(param_grad_names) return ParallelDoCallBack(param_grad_names)
else: else:
...@@ -300,7 +326,8 @@ def _append_backward_ops_(block, ...@@ -300,7 +326,8 @@ def _append_backward_ops_(block,
for op_desc in grad_op_descs: for op_desc in grad_op_descs:
new_op_desc = target_block.desc.append_op() new_op_desc = target_block.desc.append_op()
new_op_desc.copy_from(op_desc) new_op_desc.copy_from(op_desc)
callback(block=target_block, context=new_op_desc) grad_to_var["__current_op_desc__"] = new_op_desc
callback(block=target_block, context=grad_to_var)
def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map): def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
...@@ -336,6 +363,8 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map): ...@@ -336,6 +363,8 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
continue continue
grad_info_map[grad_to_var[grad_var_name]] = (grad_var_name, block) grad_info_map[grad_to_var[grad_var_name]] = (grad_var_name, block)
# infer_shape and infer_type # infer_shape and infer_type
if op_desc.type() == 'ncclInit':
continue
op_desc.infer_var_type(block.desc) op_desc.infer_var_type(block.desc)
op_desc.infer_shape(block.desc) op_desc.infer_shape(block.desc)
for arg in op_desc.output_arg_names(): for arg in op_desc.output_arg_names():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册