From 0815c0f141d1df2088ed3c5a5391662bb4484e3d Mon Sep 17 00:00:00 2001 From: Yang Yang Date: Sat, 10 Feb 2018 03:16:02 +0000 Subject: [PATCH] add assign op --- python/paddle/v2/fluid/backward.py | 36 ++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/python/paddle/v2/fluid/backward.py b/python/paddle/v2/fluid/backward.py index 34383827f..40c54bf22 100644 --- a/python/paddle/v2/fluid/backward.py +++ b/python/paddle/v2/fluid/backward.py @@ -228,8 +228,6 @@ def _callback_lookup_(op): self.param_grad_names = param_grad_names def __call__(self, block, context): - # move to parallel_do.py - # # TODO(tonyyang-svail): insert nccl init if not self.has_inserted_nccl_init: global_block = block.program.global_block() op_desc = global_block.desc.append_op() @@ -250,16 +248,30 @@ def _callback_lookup_(op): for o_param in current_op_desc.output_names(): for o_argu in current_op_desc.output(o_param): if o_argu in self.param_grad_names: - # print("reduce", o_argu) - op_desc = block.desc.append_op() - op_desc.set_type("ncclAllReduce") - op_desc.set_input("X", [o_argu]) - # FIXME(tonyyang-svail): - # Looks like nccl_com has been changed to nccl_com_0 - op_desc.set_input("Communicator", ['nccl_com_0']) - out_var = block.create_var() - op_desc.set_output("Out", [out_var.name]) - op_desc.set_attr("reduction", "ncclSum") + # # print("reduce", o_argu) + # op_desc = block.desc.append_op() + # op_desc.set_type("ncclAllReduce") + # op_desc.set_input("X", [o_argu]) + # + # # FIXME(tonyyang-svail): + # # Looks like nccl_com has been changed to nccl_com_0 + # op_desc.set_input("Communicator", ['nccl_com_0']) + # out_var = block.create_var() + # op_desc.set_output("Out", [out_var.name]) + # op_desc.set_attr("reduction", "ncclSum") + allreduce_out_name = o_argu + "__nccl_all_reduce__" + op_desc = _create_op_desc_( + "ncclAllReduce", { + "X": [o_argu], + "Communicator": ['nccl_com_0'] + }, {"Out": [allreduce_out_name]}, + {"reduction": "ncclSum"}) + block.desc.append_op().copy_from(op_desc) + + op_desc = _create_op_desc_( + "assign", {"X": [allreduce_out_name]}, + {"Out": [o_argu]}, {}) + block.desc.append_op().copy_from(op_desc) return ParallelDoCallBack(param_grad_names) else: -- GitLab