From c0b848742bcaf9b7e6b91f6f516e452acc0cfdff Mon Sep 17 00:00:00 2001 From: seiriosPlus Date: Fri, 21 Aug 2020 11:32:20 +0800 Subject: [PATCH] geo sparse init from pserver --- .../operators/distributed/communicator.cc | 58 +++++++++++++++---- 1 file changed, 47 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/operators/distributed/communicator.cc b/paddle/fluid/operators/distributed/communicator.cc index 23920e1233b..1ebb1ef0592 100644 --- a/paddle/fluid/operators/distributed/communicator.cc +++ b/paddle/fluid/operators/distributed/communicator.cc @@ -781,24 +781,60 @@ void GeoCommunicator::InitSparse() { LargeScaleKV::Init(metas); + distributed::RPCClient *rpc_client = + distributed::RPCClient::GetInstance(trainer_id_); + + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto cpu_place = platform::CPUPlace(); + auto &cpu_ctx = *pool.Get(cpu_place); + + framework::Scope &local_scope = send_scope_->NewScope(); + for (size_t i = 0; i < metas.size(); i++) { - auto &varname = metas[i].name; - auto &dict = dicts[i]; + auto &meta = metas[i]; + auto &ctx = recv_varname_to_ctx_.at(meta.name); + auto pserver_num = ctx.splited_varnames.size(); + for (size_t j = 0; j < ctx.splited_varnames.size(); j++) { + auto &recv_var_name = ctx.splited_varnames[i]; - std::vector ids; - ids.reserve(dict); + distributed::VarHandlePtr ret; + ret = rpc_client->AsyncGetVarNoBarrier(endpoints[i], cpu_ctx, local_scope, + recv_var_name, recv_var_name); + width = recv_t.value().dims()[1]; - for (auto j = 0; j < dict; ++j) { - ids.push_back(j); - } + PADDLE_ENFORCE_EQ( + width, meta.value_dims[0], + platform::errors::InvalidArgument("sparse params do not match")); + + auto *recv_var = local_scope->FindVar(recv_var_name); + auto &recv_t = recv_var->Get(); - auto *ins = distributed::LargeScaleKV::GetInstance(); - ins->Get(varname)->Init(ids); + std::vector ids; + std::transform(recv_t.rows().begin(), recv_t.rows().end(), + std::back_inserter(ids), + [&](int64_t id) { return id * pserver_num + i; }); - VLOG(3) << "GeoCommunicator init sparse " << varname << " with size " - << ids.size(); + std::vector *>> values; + auto *ins = distributed::LargeScaleKV::GetInstance(); + + ins->Get(meta.name)->Init(ids); + ins->Get(meta.name)->Get(ids, {"Param"}, &values); + + PADDLE_ENFORCE_NE(ret->Wait(), 0U, platform::errors::ExecutionTimeout( + "internal error in RPCClient")); + + auto blas = math::GetBlas(cpu_ctx); + + for (size_t k = 0; k < ids.size(); ++k) { + blas.VCOPY(width, recv_t.value().data() + k * width, + values[k][0]->data()); + } + } } + send_scope_->DeleteScope(&local_scope); + + VLOG(3) << "GeoCommunicator init sparse " << varname << " done "; VLOG(3) << "init sparse variable done"; } -- GitLab