diff --git a/paddle/fluid/operators/distributed/communicator.cc b/paddle/fluid/operators/distributed/communicator.cc index 330a766c8cbae5a38de2f5cd1de99bb03438f37b..c144686a353cfe706e3960f12f26656a404e1df0 100644 --- a/paddle/fluid/operators/distributed/communicator.cc +++ b/paddle/fluid/operators/distributed/communicator.cc @@ -676,8 +676,7 @@ void GeoCommunicator::RecvSparse(const std::string &varname) { v_delta.resize(numel); auto cpu_ctx = paddle::platform::CPUDeviceContext(); - auto blas = math::GetBlas( - paddle::platform::CPUDeviceContext()); + auto blas = math::GetBlas(cpu_ctx); for (auto j = 0; j < static_cast(ids.size()); ++j) { blas.VSUB(dims1, t_psrever.data() + j * dims1, @@ -791,29 +790,35 @@ void GeoCommunicator::InitSparse() { framework::Scope &local_scope = send_scope_->NewScope(); - for (size_t i = 0; i < metas.size(); i++) { - auto &meta = metas[i]; + for (auto &meta : metas) { auto &ctx = recv_varname_to_ctx_.at(meta.name); auto pserver_num = ctx.splited_varnames.size(); - for (size_t j = 0; j < ctx.splited_varnames.size(); j++) { + + for (size_t i = 0; i < ctx.splited_varnames.size(); i++) { auto &recv_var_name = ctx.splited_varnames[i]; + auto *var = local_scope.Var(recv_var_name); + var->GetMutable(); distributed::VarHandlePtr ret; + ret = rpc_client->AsyncGetVarNoBarrier(ctx.epmap[i], cpu_ctx, local_scope, recv_var_name, recv_var_name); - width = recv_t.value().dims()[1]; + + auto *recv_var = local_scope.FindVar(recv_var_name); + auto &recv_t = recv_var->Get(); + + auto width = recv_t.dims()[1]; + auto rows = recv_t.dims()[0]; PADDLE_ENFORCE_EQ( width, meta.value_dims[0], platform::errors::InvalidArgument("sparse params do not match")); - auto *recv_var = local_scope->FindVar(recv_var_name); - auto &recv_t = recv_var->Get(); - std::vector ids; - std::transform(recv_t.rows().begin(), recv_t.rows().end(), - std::back_inserter(ids), - [&](int64_t id) { return id * pserver_num + i; }); + + for (int x = 0; x < rows; x++) { + ids.push_back(x * pserver_num + i); + } std::vector *>> values; auto *ins = distributed::LargeScaleKV::GetInstance(); @@ -824,18 +829,20 @@ void GeoCommunicator::InitSparse() { PADDLE_ENFORCE_NE(ret->Wait(), 0U, platform::errors::ExecutionTimeout( "internal error in RPCClient")); - auto blas = math::GetBlas(cpu_ctx); + auto blas = math::GetBlas( + paddle::platform::CPUDeviceContext()); for (size_t k = 0; k < ids.size(); ++k) { - blas.VCOPY(width, recv_t.value().data() + k * width, + blas.VCOPY(width, recv_t.data() + k * width, values[k][0]->data()); } + + local_scope.EraseVars({recv_var_name}); } } send_scope_->DeleteScope(&local_scope); - VLOG(3) << "GeoCommunicator init sparse " << varname << " done "; VLOG(3) << "init sparse variable done"; }