提交 4b4e558a 编写于 作者: S seiriosPlus

geo sparse init from pserver

上级 69c349a0
...@@ -676,8 +676,7 @@ void GeoCommunicator::RecvSparse(const std::string &varname) { ...@@ -676,8 +676,7 @@ void GeoCommunicator::RecvSparse(const std::string &varname) {
v_delta.resize(numel); v_delta.resize(numel);
auto cpu_ctx = paddle::platform::CPUDeviceContext(); auto cpu_ctx = paddle::platform::CPUDeviceContext();
auto blas = math::GetBlas<platform::CPUDeviceContext, float>( auto blas = math::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx);
paddle::platform::CPUDeviceContext());
for (auto j = 0; j < static_cast<int>(ids.size()); ++j) { for (auto j = 0; j < static_cast<int>(ids.size()); ++j) {
blas.VSUB(dims1, t_psrever.data<float>() + j * dims1, blas.VSUB(dims1, t_psrever.data<float>() + j * dims1,
...@@ -791,29 +790,35 @@ void GeoCommunicator::InitSparse() { ...@@ -791,29 +790,35 @@ void GeoCommunicator::InitSparse() {
framework::Scope &local_scope = send_scope_->NewScope(); framework::Scope &local_scope = send_scope_->NewScope();
for (size_t i = 0; i < metas.size(); i++) { for (auto &meta : metas) {
auto &meta = metas[i];
auto &ctx = recv_varname_to_ctx_.at(meta.name); auto &ctx = recv_varname_to_ctx_.at(meta.name);
auto pserver_num = ctx.splited_varnames.size(); auto pserver_num = ctx.splited_varnames.size();
for (size_t j = 0; j < ctx.splited_varnames.size(); j++) {
for (size_t i = 0; i < ctx.splited_varnames.size(); i++) {
auto &recv_var_name = ctx.splited_varnames[i]; auto &recv_var_name = ctx.splited_varnames[i];
auto *var = local_scope.Var(recv_var_name);
var->GetMutable<framework::LoDTensor>();
distributed::VarHandlePtr ret; distributed::VarHandlePtr ret;
ret = rpc_client->AsyncGetVarNoBarrier(ctx.epmap[i], cpu_ctx, local_scope, ret = rpc_client->AsyncGetVarNoBarrier(ctx.epmap[i], cpu_ctx, local_scope,
recv_var_name, recv_var_name); recv_var_name, recv_var_name);
width = recv_t.value().dims()[1];
auto *recv_var = local_scope.FindVar(recv_var_name);
auto &recv_t = recv_var->Get<framework::LoDTensor>();
auto width = recv_t.dims()[1];
auto rows = recv_t.dims()[0];
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
width, meta.value_dims[0], width, meta.value_dims[0],
platform::errors::InvalidArgument("sparse params do not match")); platform::errors::InvalidArgument("sparse params do not match"));
auto *recv_var = local_scope->FindVar(recv_var_name);
auto &recv_t = recv_var->Get<framework::SelectedRows>();
std::vector<int64_t> ids; std::vector<int64_t> ids;
std::transform(recv_t.rows().begin(), recv_t.rows().end(),
std::back_inserter(ids), for (int x = 0; x < rows; x++) {
[&](int64_t id) { return id * pserver_num + i; }); ids.push_back(x * pserver_num + i);
}
std::vector<std::vector<std::vector<float> *>> values; std::vector<std::vector<std::vector<float> *>> values;
auto *ins = distributed::LargeScaleKV::GetInstance(); auto *ins = distributed::LargeScaleKV::GetInstance();
...@@ -824,18 +829,20 @@ void GeoCommunicator::InitSparse() { ...@@ -824,18 +829,20 @@ void GeoCommunicator::InitSparse() {
PADDLE_ENFORCE_NE(ret->Wait(), 0U, platform::errors::ExecutionTimeout( PADDLE_ENFORCE_NE(ret->Wait(), 0U, platform::errors::ExecutionTimeout(
"internal error in RPCClient")); "internal error in RPCClient"));
auto blas = math::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx); auto blas = math::GetBlas<platform::CPUDeviceContext, float>(
paddle::platform::CPUDeviceContext());
for (size_t k = 0; k < ids.size(); ++k) { for (size_t k = 0; k < ids.size(); ++k) {
blas.VCOPY(width, recv_t.value().data<float>() + k * width, blas.VCOPY(width, recv_t.data<float>() + k * width,
values[k][0]->data()); values[k][0]->data());
} }
local_scope.EraseVars({recv_var_name});
} }
} }
send_scope_->DeleteScope(&local_scope); send_scope_->DeleteScope(&local_scope);
VLOG(3) << "GeoCommunicator init sparse " << varname << " done ";
VLOG(3) << "init sparse variable done"; VLOG(3) << "init sparse variable done";
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册