提交 c0b84874 编写于 作者: S seiriosPlus

geo sparse init from pserver

上级 426c3e13
......@@ -781,24 +781,60 @@ void GeoCommunicator::InitSparse() {
LargeScaleKV::Init(metas);
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id_);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto cpu_place = platform::CPUPlace();
auto &cpu_ctx = *pool.Get(cpu_place);
framework::Scope &local_scope = send_scope_->NewScope();
for (size_t i = 0; i < metas.size(); i++) {
auto &varname = metas[i].name;
auto &dict = dicts[i];
auto &meta = metas[i];
auto &ctx = recv_varname_to_ctx_.at(meta.name);
auto pserver_num = ctx.splited_varnames.size();
for (size_t j = 0; j < ctx.splited_varnames.size(); j++) {
auto &recv_var_name = ctx.splited_varnames[i];
std::vector<int64_t> ids;
ids.reserve(dict);
distributed::VarHandlePtr ret;
ret = rpc_client->AsyncGetVarNoBarrier(endpoints[i], cpu_ctx, local_scope,
recv_var_name, recv_var_name);
width = recv_t.value().dims()[1];
for (auto j = 0; j < dict; ++j) {
ids.push_back(j);
}
PADDLE_ENFORCE_EQ(
width, meta.value_dims[0],
platform::errors::InvalidArgument("sparse params do not match"));
auto *recv_var = local_scope->FindVar(recv_var_name);
auto &recv_t = recv_var->Get<framework::SelectedRows>();
auto *ins = distributed::LargeScaleKV::GetInstance();
ins->Get(varname)->Init(ids);
std::vector<int64_t> ids;
std::transform(recv_t.rows().begin(), recv_t.rows().end(),
std::back_inserter(ids),
[&](int64_t id) { return id * pserver_num + i; });
VLOG(3) << "GeoCommunicator init sparse " << varname << " with size "
<< ids.size();
std::vector<std::vector<std::vector<float> *>> values;
auto *ins = distributed::LargeScaleKV::GetInstance();
ins->Get(meta.name)->Init(ids);
ins->Get(meta.name)->Get(ids, {"Param"}, &values);
PADDLE_ENFORCE_NE(ret->Wait(), 0U, platform::errors::ExecutionTimeout(
"internal error in RPCClient"));
auto blas = math::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx);
for (size_t k = 0; k < ids.size(); ++k) {
blas.VCOPY(width, recv_t.value().data<float>() + k * width,
values[k][0]->data());
}
}
}
send_scope_->DeleteScope(&local_scope);
VLOG(3) << "GeoCommunicator init sparse " << varname << " done ";
VLOG(3) << "init sparse variable done";
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册