提交 0815c0f1 编写于 作者: Y Yang Yang

add assign op

上级 f2129b19
...@@ -228,8 +228,6 @@ def _callback_lookup_(op): ...@@ -228,8 +228,6 @@ def _callback_lookup_(op):
self.param_grad_names = param_grad_names self.param_grad_names = param_grad_names
def __call__(self, block, context): def __call__(self, block, context):
# move to parallel_do.py
# # TODO(tonyyang-svail): insert nccl init
if not self.has_inserted_nccl_init: if not self.has_inserted_nccl_init:
global_block = block.program.global_block() global_block = block.program.global_block()
op_desc = global_block.desc.append_op() op_desc = global_block.desc.append_op()
...@@ -250,16 +248,30 @@ def _callback_lookup_(op): ...@@ -250,16 +248,30 @@ def _callback_lookup_(op):
for o_param in current_op_desc.output_names(): for o_param in current_op_desc.output_names():
for o_argu in current_op_desc.output(o_param): for o_argu in current_op_desc.output(o_param):
if o_argu in self.param_grad_names: if o_argu in self.param_grad_names:
# print("reduce", o_argu) # # print("reduce", o_argu)
op_desc = block.desc.append_op() # op_desc = block.desc.append_op()
op_desc.set_type("ncclAllReduce") # op_desc.set_type("ncclAllReduce")
op_desc.set_input("X", [o_argu]) # op_desc.set_input("X", [o_argu])
# FIXME(tonyyang-svail): #
# Looks like nccl_com has been changed to nccl_com_0 # # FIXME(tonyyang-svail):
op_desc.set_input("Communicator", ['nccl_com_0']) # # Looks like nccl_com has been changed to nccl_com_0
out_var = block.create_var() # op_desc.set_input("Communicator", ['nccl_com_0'])
op_desc.set_output("Out", [out_var.name]) # out_var = block.create_var()
op_desc.set_attr("reduction", "ncclSum") # op_desc.set_output("Out", [out_var.name])
# op_desc.set_attr("reduction", "ncclSum")
allreduce_out_name = o_argu + "__nccl_all_reduce__"
op_desc = _create_op_desc_(
"ncclAllReduce", {
"X": [o_argu],
"Communicator": ['nccl_com_0']
}, {"Out": [allreduce_out_name]},
{"reduction": "ncclSum"})
block.desc.append_op().copy_from(op_desc)
op_desc = _create_op_desc_(
"assign", {"X": [allreduce_out_name]},
{"Out": [o_argu]}, {})
block.desc.append_op().copy_from(op_desc)
return ParallelDoCallBack(param_grad_names) return ParallelDoCallBack(param_grad_names)
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册