Commit 9ba2ae10 authored by: S seiriosPlus

optimize init from pserver

Parent 4b4e558a
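
What this commit changes, as the hunks below show: `AsyncCommunicator::InitImpl` now ends with a new `InitParams()` call that pulls the current parameters from the pserver via `RecvNoBarrier()`, and `GeoCommunicator::Init` is renamed to `InitParams()` to match. In the geo path, dense initialization receives into the global `recv_scope_` and copies the result into `old_scope_`, while sparse initialization replaces the per-pserver `AsyncGetVarNoBarrier` loop with a single `ParameterRecv` into the global scope followed by a row-wise copy into the `LargeScaleKV` table. On the Python side, `delet_extra_optimizes_pass` is removed from the startup-program rewriting in both optimizer classes. Short reconstructed sketches are interleaved after the relevant hunks below.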
......@@ -74,8 +74,12 @@ void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
} else {
recv_threadpool_.reset(new ::ThreadPool(thread_pool_size_));
}
InitParams();
}
void AsyncCommunicator::InitParams() { RecvNoBarrier(); }
AsyncCommunicator::~AsyncCommunicator() {
running_ = false;
if (main_thread_) main_thread_->join();
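
A minimal sketch of the async path after this hunk, with the thread-pool setup elided; it only restates what the added lines do (a barrier-free receive at the end of initialization):

```cpp
// Sketch (async path), reconstructed from the hunk above; setup details elided.
void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
                                 const RpcCtxMap &recv_varname_to_ctx,
                                 Scope *recv_scope) {
  // ... store the send/recv contexts and create the thread pools ...
  InitParams();  // new: fetch the current parameters before training starts
}

// "Init from pserver" for the async communicator is simply a receive without
// the usual barrier handshake.
void AsyncCommunicator::InitParams() { RecvNoBarrier(); }
```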
......@@ -721,7 +725,7 @@ void GeoCommunicator::RecvDense(const std::string &varname) {
t_timestamp->data<float>());
}
void GeoCommunicator::Init() {
void GeoCommunicator::InitParams() {
std::vector<std::future<void>> tasks;
tasks.reserve(recv_varname_to_ctx_.size());
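
The body of the renamed `GeoCommunicator::InitParams` is mostly outside this hunk. The sketch below is an assumption about its structure, based on the visible `tasks` vector and the following hunks: dense variables are initialized concurrently on the thread pool, then `InitSparse()` runs once. The `ctx.is_sparse` flag and the use of `recv_threadpool_` are assumptions, not shown in the diff.

```cpp
// Hedged sketch of GeoCommunicator::InitParams; the exact body is elided in
// the diff, so the dispatch below is an assumed reconstruction.
void GeoCommunicator::InitParams() {
  std::vector<std::future<void>> tasks;
  tasks.reserve(recv_varname_to_ctx_.size());

  for (auto &iter : recv_varname_to_ctx_) {
    auto &var_name = iter.first;
    auto &ctx = iter.second;
    if (ctx.is_sparse) continue;  // assumed flag; tables go through InitSparse()

    // Initialize each dense parameter in parallel.
    tasks.emplace_back(
        recv_threadpool_->enqueue([this, &var_name] { InitDense(var_name); }));
  }
  for (auto &task : tasks) {
    task.wait();
  }

  InitSparse();  // embedding tables are pulled and loaded into LargeScaleKV
}
```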
......@@ -744,12 +748,17 @@ void GeoCommunicator::Init() {
}
void GeoCommunicator::InitDense(const std::string varname) {
auto *var = old_scope_->Var(varname);
var->GetMutable<framework::LoDTensor>();
auto &ctx = recv_varname_to_ctx_.at(varname);
auto recv = distributed::ParameterRecv<float>();
recv(ctx, *old_scope_);
recv(ctx, *recv_scope_);
auto *global_var = recv_scope_->FindVar(varname);
global_var->GetMutable<framework::LoDTensor>();
auto *old_var = old_scope_->Var(varname);
old_var->GetMutable<framework::LoDTensor>();
framework::CopyVariable(*global_var, old_var);
VLOG(1) << "init dense variable " << varname << " done";
}
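
`InitDense` now receives into the global scope and then snapshots the value into `old_scope_`. The added lines, regrouped with comments (a sketch that only restates the hunk above):

```cpp
// Sketch of the new InitDense flow as shown in the hunk above.
void GeoCommunicator::InitDense(const std::string varname) {
  auto &ctx = recv_varname_to_ctx_.at(varname);

  // 1. Pull the parameter from the pserver directly into the global scope.
  auto recv = distributed::ParameterRecv<float>();
  recv(ctx, *recv_scope_);

  // 2. Copy it into old_scope_, which later serves as the baseline for the
  //    geo delta computation, so deltas start from the server's values.
  auto *global_var = recv_scope_->FindVar(varname);
  global_var->GetMutable<framework::LoDTensor>();
  auto *old_var = old_scope_->Var(varname);
  old_var->GetMutable<framework::LoDTensor>();
  framework::CopyVariable(*global_var, old_var);

  VLOG(1) << "init dense variable " << varname << " done";
}
```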
......@@ -781,68 +790,43 @@ void GeoCommunicator::InitSparse() {
LargeScaleKV::Init(metas);
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id_);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto cpu_place = platform::CPUPlace();
auto &cpu_ctx = *pool.Get(cpu_place);
framework::Scope &local_scope = send_scope_->NewScope();
for (auto &meta : metas) {
auto &ctx = recv_varname_to_ctx_.at(meta.name);
auto pserver_num = ctx.splited_varnames.size();
for (size_t i = 0; i < ctx.splited_varnames.size(); i++) {
auto &recv_var_name = ctx.splited_varnames[i];
auto *var = local_scope.Var(recv_var_name);
var->GetMutable<framework::LoDTensor>();
distributed::VarHandlePtr ret;
auto recv = distributed::ParameterRecv<float>();
ret = rpc_client->AsyncGetVarNoBarrier(ctx.epmap[i], cpu_ctx, local_scope,
recv_var_name, recv_var_name);
auto *global_var = recv_scope_->FindVar(meta.name);
auto global_value = global_var->Get<framework::LoDTensor>();
auto rows = global_value.dims()[0];
auto dim1 = global_value.dims()[1];
auto *recv_var = local_scope.FindVar(recv_var_name);
auto &recv_t = recv_var->Get<framework::LoDTensor>();
recv(ctx, *recv_scope_);
VLOG(1) << "recv " << meta.name << " with global scope for init";
auto width = recv_t.dims()[1];
auto rows = recv_t.dims()[0];
auto n_rows = global_var->Get<framework::LoDTensor>().dims()[0];
PADDLE_ENFORCE_EQ(
width, meta.value_dims[0],
platform::errors::InvalidArgument("sparse params do not match"));
std::vector<int64_t> ids;
rows, n_rows,
platform::errors::InvalidArgument(
"global var: %s origin dim must equal recved rows", meta.name));
for (int x = 0; x < rows; x++) {
ids.push_back(x * pserver_num + i);
}
std::vector<int64_t> ids(rows);
std::iota(ids.begin(), ids.end(), 0);
std::vector<std::vector<std::vector<float> *>> values;
auto *ins = distributed::LargeScaleKV::GetInstance();
std::vector<std::vector<std::vector<float> *>> values;
ins->Get(meta.name)->Init(ids);
ins->Get(meta.name)->Get(ids, {"Param"}, &values);
PADDLE_ENFORCE_NE(ret->Wait(), 0U, platform::errors::ExecutionTimeout(
"internal error in RPCClient"));
auto blas = math::GetBlas<platform::CPUDeviceContext, float>(
paddle::platform::CPUDeviceContext());
for (size_t k = 0; k < ids.size(); ++k) {
blas.VCOPY(width, recv_t.data<float>() + k * width,
values[k][0]->data());
}
local_scope.EraseVars({recv_var_name});
for (auto &id : ids) {
blas.VCOPY(dim1, global_value.data<float>() + id * dim1,
values[id][0]->data());
}
}
send_scope_->DeleteScope(&local_scope);
VLOG(3) << "init sparse variable done";
}
......
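
The `InitSparse` rewrite is easier to read without the deleted lines interleaved. Below is a reconstruction of the new per-table loop body, based on the added lines in the hunk above; it is a sketch rather than a verbatim extract, and the final copy loop is written with a single index (`id`) to be self-consistent:

```cpp
// Sketch of the new per-table loop body in GeoCommunicator::InitSparse(),
// reconstructed from the added lines above (for one `meta` in `metas`).
auto &ctx = recv_varname_to_ctx_.at(meta.name);
auto recv = distributed::ParameterRecv<float>();

// Shape of the table as declared in the trainer program.
auto *global_var = recv_scope_->FindVar(meta.name);
auto global_value = global_var->Get<framework::LoDTensor>();
auto rows = global_value.dims()[0];
auto dim1 = global_value.dims()[1];

// One ParameterRecv assembles all pserver splits into the global scope; no
// per-pserver RPC and no manual id remapping (x * pserver_num + i) is needed.
recv(ctx, *recv_scope_);
VLOG(1) << "recv " << meta.name << " with global scope for init";

auto n_rows = global_var->Get<framework::LoDTensor>().dims()[0];
PADDLE_ENFORCE_EQ(
    rows, n_rows,
    platform::errors::InvalidArgument(
        "global var: %s origin dim must equal recved rows", meta.name));

// Row ids are now simply 0..rows-1.
std::vector<int64_t> ids(rows);
std::iota(ids.begin(), ids.end(), 0);

// Initialize the LargeScaleKV table and copy every row's "Param" value in.
auto *ins = distributed::LargeScaleKV::GetInstance();
std::vector<std::vector<std::vector<float> *>> values;
ins->Get(meta.name)->Init(ids);
ins->Get(meta.name)->Get(ids, {"Param"}, &values);

auto blas = math::GetBlas<platform::CPUDeviceContext, float>(
    paddle::platform::CPUDeviceContext());
for (auto &id : ids) {
  blas.VCOPY(dim1, global_value.data<float>() + id * dim1,
             values[id][0]->data());
}
```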
......@@ -19,6 +19,7 @@ limitations under the License. */
#include <deque>
#include <map>
#include <memory>
#include <numeric>
#include <set>
#include <string>
#include <unordered_map>
......@@ -29,6 +30,7 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/distributed/communicator_common.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/large_scale_kv.h"
......@@ -279,6 +281,8 @@ class AsyncCommunicator : public Communicator {
const RpcCtxMap &recv_varname_to_ctx,
Scope *recv_scope) override;
void InitParams();
void MainThread();
void Send(const std::vector<std::string> &var_names,
......@@ -435,7 +439,7 @@ class GeoCommunicator : public AsyncCommunicator {
void RecvDense(const std::string &varname);
void Init();
void InitParams();
void InitSparse();
......
......@@ -65,7 +65,6 @@ class AsyncMetaOptimizer(MetaOptimizerBase):
# for startup program
_startup = worker.fake_init_ops_pass(_startup, compiled_config)
_startup = worker.init_from_server_pass(_startup, compiled_config)
_startup = worker.delet_extra_optimizes_pass(_startup,
compiled_config)
else:
......
......@@ -771,7 +771,6 @@ class ParameterServerOptimizer(DistributedOptimizer):
# for startup program
_startup = worker.fake_init_ops_pass(_startup, compiled_config)
_startup = worker.init_from_server_pass(_startup, compiled_config)
_startup = worker.delet_extra_optimizes_pass(_startup,
compiled_config)
else:
......