From ef28e7deba4e9516815229e9c47e1f09b16a1f1e Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Wed, 14 Mar 2018 14:21:01 +0800 Subject: [PATCH] refine parallel_do_grad --- paddle/fluid/operators/assign_op.cc | 1 + paddle/fluid/operators/nccl_op.cu.cc | 2 ++ python/paddle/fluid/backward.py | 9 ++++++--- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/assign_op.cc b/paddle/fluid/operators/assign_op.cc index 39ae3c0040d..d372213e1b6 100644 --- a/paddle/fluid/operators/assign_op.cc +++ b/paddle/fluid/operators/assign_op.cc @@ -56,6 +56,7 @@ class AssignFunctor { private: void copy_tensor(const framework::LoDTensor &lod_tensor, framework::LoDTensor *out) const { + if (lod_tensor.numel() == 0) return; auto &out_tensor = *out; TensorCopy(lod_tensor, lod_tensor.place(), dev_ctx_, &out_tensor); out_tensor.set_lod(lod_tensor.lod()); diff --git a/paddle/fluid/operators/nccl_op.cu.cc b/paddle/fluid/operators/nccl_op.cu.cc index 4d83a70e733..ad623e1fe0f 100644 --- a/paddle/fluid/operators/nccl_op.cu.cc +++ b/paddle/fluid/operators/nccl_op.cu.cc @@ -106,6 +106,8 @@ class NCCLReduceKernel : public framework::OpKernel { T* recvbuffer = nullptr; if (root == gpu_id) { recvbuffer = out->mutable_data(ctx.GetPlace()); + } else { + out->Resize(framework::make_ddim({0})); } VLOG(3) << "gpu : " << gpu_id << " invoke reduce. send " << x->numel() << " recv " << out->numel(); diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index b6f20daee3a..7af6ed1463a 100644 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -248,12 +248,15 @@ def _callback_lookup_(op): if o_argu in self.param_grad_names: allreduce_out_name = o_argu + "__nccl_all_reduce__" op_desc = _create_op_desc_( - "ncclAllReduce", { + "ncclReduce", + { "X": [o_argu], "Communicator": ['nccl_com__do_not_change_'] - }, {"Out": [allreduce_out_name]}, - {"reduction": "ncclSum"}) + }, + {"Out": [allreduce_out_name]}, + {"reduction": "ncclSum", + "root": 0}, ) block.desc.append_op().copy_from(op_desc) op_desc = _create_op_desc_( -- GitLab