From 6ad6b47a3677147aa767374488906c606c4d4fab Mon Sep 17 00:00:00 2001 From: seiriosPlus Date: Thu, 20 Aug 2020 20:54:48 +0800 Subject: [PATCH] fix parame recv for sparse --- paddle/fluid/operators/distributed/communicator.cc | 2 -- paddle/fluid/operators/distributed/parameter_recv.cc | 11 ++++++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/distributed/communicator.cc b/paddle/fluid/operators/distributed/communicator.cc index 194731d631c..23920e1233b 100644 --- a/paddle/fluid/operators/distributed/communicator.cc +++ b/paddle/fluid/operators/distributed/communicator.cc @@ -628,9 +628,7 @@ void GeoCommunicator::RecvByCommunicator() { if (recv_ctx.is_sparse) { RecvSparse(var_name); } else { - VLOG(1) << "recv dense " << var_name << " begin"; RecvDense(var_name); - VLOG(1) << "recv dense " << var_name << " done"; } }; tasks.emplace_back(send_threadpool_->enqueue(std::move(recv_task))); diff --git a/paddle/fluid/operators/distributed/parameter_recv.cc b/paddle/fluid/operators/distributed/parameter_recv.cc index 5409ec54987..376ab6f402d 100644 --- a/paddle/fluid/operators/distributed/parameter_recv.cc +++ b/paddle/fluid/operators/distributed/parameter_recv.cc @@ -84,9 +84,14 @@ void RecvSelectedRows(const CommContext &rpc_ctx, ids_num += recv_t.rows().size(); width = recv_t.value().dims()[1]; - std::transform(recv_t.rows().begin(), recv_t.rows().end(), - std::back_inserter(all_ids), - [&](int64_t id) { return id * pserver_num + i; }); + if (rpc_ctx.is_distributed) { + std::copy(recv_t.rows().begin(), recv_t.rows().end(), + std::back_inserter(all_ids)); + } else { + std::transform(recv_t.rows().begin(), recv_t.rows().end(), + std::back_inserter(all_ids), + [&](int64_t id) { return i * pserver_num + id; }); + } } auto *var = scope.FindVar(rpc_ctx.var_name); -- GitLab