提交 7e70802b 编写于 作者: S seiriosPlus

optimize init from pserver

上级 9ba2ae10
......@@ -193,7 +193,7 @@ void AsyncCommunicator::RecvNoBarrier() {
auto &var_name = iter.first;
VLOG(4) << "recv var " << var_name;
auto recv_functor = distributed::ParameterRecv<float>();
recv_functor(iter.second, *recv_scope_, false);
recv_functor(iter.second, *recv_scope_);
};
task_futures.emplace_back(recv_threadpool_->enqueue(std::move(recv_task)));
}
......@@ -700,7 +700,7 @@ void GeoCommunicator::RecvDense(const std::string &varname) {
auto &ctx = recv_varname_to_ctx_.at(varname);
auto recv = distributed::ParameterRecv<float>();
recv(ctx, *pserver_scope_, true);
recv(ctx, *pserver_scope_);
PADDLE_ENFORCE_EQ(
var_psrever->IsInitialized(), true,
......
......@@ -41,8 +41,67 @@ using SelectedRows = framework::SelectedRows;
using DDim = framework::DDim;
template <typename T>
void RecvSelectedRows(const CommContext &rpc_ctx,
const framework::Scope &scope) {
void RecvSparseLodTensor(const CommContext &rpc_ctx,
const framework::Scope &scope) {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto cpu_place = platform::CPUPlace();
auto &cpu_ctx = *pool.Get(cpu_place);
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);
std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();
std::vector<const float *> tensors;
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) {
auto &recv_var_name = rpc_ctx.splited_varnames[i];
auto *local_var = local_scope->Var(recv_var_name);
VLOG(4) << "recv " << recv_var_name << " from " << rpc_ctx.epmap[i];
// sparse param in recv_scope is LoDTensor
rets.push_back(rpc_client->AsyncGetVarNoBarrier(
rpc_ctx.epmap[i], cpu_ctx, *local_scope.get(), recv_var_name,
recv_var_name, recv_var_name));
const auto *value = local_var->Get<framework::LoDTensor>().data<float>();
tensors.push_back(value);
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::ExecutionTimeout(
"internal error in RPCClient"));
}
auto *merged_var = scope.FindVar(rpc_ctx.var_name);
if (merged_var == nullptr || !merged_var->IsInitialized()) {
PADDLE_THROW(
platform::errors::InvalidArgument("%s must initialized at first."));
}
auto dims1 = merged_var->Get<framework::LoDTensor>().dims()[1];
int64_t height = 0;
for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) {
auto *splited_var = local_scope->FindVar(rpc_ctx.splited_varnames[i]);
height += splited_var->Get<framework::LoDTensor>().dims()[0];
}
PADDLE_ENFORCE_EQ(merged_var->Get<framework::LoDTensor>().dims()[0], height,
"recved var must has same dims with local var");
auto *merged_t = merged_var->GetMutable<framework::LoDTensor>();
auto *merged_d = merged_t->mutable_data<float>(place)();
auto pserver_num = rpc_ctx.splited_varnames.size();
for (int x = 0; x < height; ++x) {
auto id = x % pserver_num;
auto idx = x / pserver_num;
std::memcpy(merged_d + x * dims1, tensors[id] + idx * dims1,
sizeof(float) * dims1);
}
}
template <typename T>
void RecvGeoSparseRecords(const CommContext &rpc_ctx,
const framework::Scope &scope) {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto cpu_place = platform::CPUPlace();
auto &cpu_ctx = *pool.Get(cpu_place);
......@@ -151,7 +210,8 @@ void RecvLodTensor(const CommContext &rpc_ctx, const framework::Scope &scope) {
template <typename T>
void ParameterRecv<T>::operator()(const CommContext &rpc_ctx,
const framework::Scope &scope, bool barrier) {
const framework::Scope &scope,
bool geo_records) {
VLOG(3) << "ParameterRecv in " << rpc_ctx.var_name;
PADDLE_ENFORCE_GE(rpc_ctx.origin_varnames.size(), 1,
......@@ -159,18 +219,21 @@ void ParameterRecv<T>::operator()(const CommContext &rpc_ctx,
"origin_varnames.size() >= 1 is permitted"));
if (rpc_ctx.is_sparse) {
RecvSelectedRows<T>(rpc_ctx, scope);
if (geo_records) {
RecvGeoSparseRecords()<T>(rpc_ctx, scope);
} else {
RecvSparseLodTensor()<T>(rpc_ctx, scope);
}
} else {
RecvLodTensor<T>(rpc_ctx, scope);
}
VLOG(3) << "ParameterRecv out " << rpc_ctx.var_name;
}
template <typename T>
void ParameterRecv<T>::operator()(const CommContext &rpc_ctx,
const framework::Scope &scope) {
this->operator()(rpc_ctx, scope, true);
this->operator()(rpc_ctx, scope, false);
}
template struct ParameterRecv<float>;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册