From 7e70802bb28a2cdbeb1d8056dbfc1904c3e4fbc0 Mon Sep 17 00:00:00 2001
From: seiriosPlus <tangwei12@baidu.com>
Date: Fri, 21 Aug 2020 18:01:26 +0800
Subject: [PATCH] optimize init from pserver

---
 .../operators/distributed/communicator.cc     |  4 +-
 .../operators/distributed/parameter_recv.cc   | 75 +++++++++++++++++--
 2 files changed, 71 insertions(+), 8 deletions(-)

diff --git a/paddle/fluid/operators/distributed/communicator.cc b/paddle/fluid/operators/distributed/communicator.cc
index 817071d3a57..993bf6c988e 100644
--- a/paddle/fluid/operators/distributed/communicator.cc
+++ b/paddle/fluid/operators/distributed/communicator.cc
@@ -193,7 +193,7 @@ void AsyncCommunicator::RecvNoBarrier() {
       auto &var_name = iter.first;
       VLOG(4) << "recv var " << var_name;
       auto recv_functor = distributed::ParameterRecv<float>();
-      recv_functor(iter.second, *recv_scope_, false);
+      recv_functor(iter.second, *recv_scope_);
     };
     task_futures.emplace_back(recv_threadpool_->enqueue(std::move(recv_task)));
   }
@@ -700,7 +700,7 @@ void GeoCommunicator::RecvDense(const std::string &varname) {
 
   auto &ctx = recv_varname_to_ctx_.at(varname);
   auto recv = distributed::ParameterRecv<float>();
-  recv(ctx, *pserver_scope_, true);
+  recv(ctx, *pserver_scope_);
 
   PADDLE_ENFORCE_EQ(
       var_psrever->IsInitialized(), true,
diff --git a/paddle/fluid/operators/distributed/parameter_recv.cc b/paddle/fluid/operators/distributed/parameter_recv.cc
index 03b9000999b..fa68562b08c 100644
--- a/paddle/fluid/operators/distributed/parameter_recv.cc
+++ b/paddle/fluid/operators/distributed/parameter_recv.cc
@@ -41,8 +41,67 @@ using SelectedRows = framework::SelectedRows;
 using DDim = framework::DDim;
 
 template <typename T>
-void RecvSelectedRows(const CommContext &rpc_ctx,
-                      const framework::Scope &scope) {
+void RecvSparseLodTensor(const CommContext &rpc_ctx,
+                         const framework::Scope &scope) {
+  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+  auto cpu_place = platform::CPUPlace();
+  auto &cpu_ctx = *pool.Get(cpu_place);
+
+  distributed::RPCClient *rpc_client =
+      distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);
+
+  std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();
+  std::vector<const float *> tensors;
+  std::vector<distributed::VarHandlePtr> rets;
+  for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) {
+    auto &recv_var_name = rpc_ctx.splited_varnames[i];
+    auto *local_var = local_scope->Var(recv_var_name);
+    VLOG(4) << "recv " << recv_var_name << " from " << rpc_ctx.epmap[i];
+    // sparse param in recv_scope is LoDTensor
+    rets.push_back(rpc_client->AsyncGetVarNoBarrier(
+        rpc_ctx.epmap[i], cpu_ctx, *local_scope.get(), recv_var_name,
+        recv_var_name, recv_var_name));
+
+    const auto *value = local_var->Get<framework::LoDTensor>().data<float>();
+    tensors.push_back(value);
+  }
+
+  for (size_t i = 0; i < rets.size(); i++) {
+    PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::ExecutionTimeout(
+                                               "internal error in RPCClient"));
+  }
+
+  auto *merged_var = scope.FindVar(rpc_ctx.var_name);
+
+  if (merged_var == nullptr || !merged_var->IsInitialized()) {
+    PADDLE_THROW(
+        platform::errors::InvalidArgument("%s must initialized at first."));
+  }
+  auto dims1 = merged_var->Get<framework::LoDTensor>().dims()[1];
+  int64_t height = 0;
+  for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) {
+    auto *splited_var = local_scope->FindVar(rpc_ctx.splited_varnames[i]);
+    height += splited_var->Get<framework::LoDTensor>().dims()[0];
+  }
+
+  PADDLE_ENFORCE_EQ(merged_var->Get<framework::LoDTensor>().dims()[0], height,
+                    "recved var must has same dims with local var");
+
+  auto *merged_t = merged_var->GetMutable<framework::LoDTensor>();
+  auto *merged_d = merged_t->mutable_data<float>(cpu_place);
+
+  auto pserver_num = rpc_ctx.splited_varnames.size();
+  for (int x = 0; x < height; ++x) {
+    auto id = x % pserver_num;
+    auto idx = x / pserver_num;
+    std::memcpy(merged_d + x * dims1, tensors[id] + idx * dims1,
+                sizeof(float) * dims1);
+  }
+}
+
+template <typename T>
+void RecvGeoSparseRecords(const CommContext &rpc_ctx,
+                          const framework::Scope &scope) {
   platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
   auto cpu_place = platform::CPUPlace();
   auto &cpu_ctx = *pool.Get(cpu_place);
@@ -151,7 +210,8 @@ void RecvLodTensor(const CommContext &rpc_ctx, const framework::Scope &scope) {
 
 template <typename T>
 void ParameterRecv<T>::operator()(const CommContext &rpc_ctx,
-                                  const framework::Scope &scope, bool barrier) {
+                                  const framework::Scope &scope,
+                                  bool geo_records) {
   VLOG(3) << "ParameterRecv in " << rpc_ctx.var_name;
 
   PADDLE_ENFORCE_GE(rpc_ctx.origin_varnames.size(), 1,
@@ -159,18 +219,21 @@ void ParameterRecv<T>::operator()(const CommContext &rpc_ctx,
                        "origin_varnames.size() >= 1 is permitted"));
 
   if (rpc_ctx.is_sparse) {
-    RecvSelectedRows<T>(rpc_ctx, scope);
+    if (geo_records) {
+      RecvGeoSparseRecords<T>(rpc_ctx, scope);
+    } else {
+      RecvSparseLodTensor<T>(rpc_ctx, scope);
+    }
   } else {
     RecvLodTensor<T>(rpc_ctx, scope);
   }
 
   VLOG(3) << "ParameterRecv out " << rpc_ctx.var_name;
 }
-
 template <typename T>
 void ParameterRecv<T>::operator()(const CommContext &rpc_ctx,
                                   const framework::Scope &scope) {
-  this->operator()(rpc_ctx, scope, true);
+  this->operator()(rpc_ctx, scope, false);
 }
 
 template struct ParameterRecv<float>;
--
GitLab