提交 fc8ba050 编写于 作者: M malin10

tmp

上级 a87a958b
......@@ -377,9 +377,8 @@ void SyncCommunicator::BarrierSend() {
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE_NE(
rets[i]->Wait(), 0U,
platform::errors::External("internal error in RPCClient"));
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::External(
"internal error in RPCClient"));
}
VLOG(4) << "BarrierSend with SyncCommunicator";
......@@ -397,9 +396,8 @@ void SyncCommunicator::BarrierRecv() {
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE_NE(
rets[i]->Wait(), 0U,
platform::errors::External("internal error in RPCClient"));
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::External(
"internal error in RPCClient"));
}
VLOG(4) << "BarrierRecv with SyncCommunicator";
......@@ -432,7 +430,7 @@ void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
}
send_ids_to_queue_[varname] =
std::make_shared<BlockingQueue<std::vector<int64_t>>>(
std::make_shared<BlockingQueue<std::shared_ptr<SplitedSparseIds>>>(
send_queue_size_);
}
}
......@@ -489,12 +487,13 @@ void GeoCommunicator::Send(const std::vector<std::string> &var_names,
"Only LodTensor can be send in GeoCommunicator::Send"));
}
auto pserver_num = send_varname_to_ctx_.at[table_name].epmap.size();
auto pserver_num = send_varname_to_ctx_.at(table_name).epmap.size();
auto ids = std::make_shared<SplitedSparseIds>(pserver_num);
auto &rows = var->Get<framework::SelectedRows>().rows();
// split rows index into output sparse vars
for (size_t i = 0; i < rows.size(); ++i) {
auto ep_idx = rows[i] % pserver_num;
ids[ep_idx].add(rows[i]);
ids->at(ep_idx).insert(rows[i]);
}
queue->Push(ids);
}
......@@ -526,7 +525,8 @@ void GeoCommunicator::SendByCommunicator(int batches) {
for (auto &iter : send_varname_to_ctx_) {
auto &var_name = iter.first;
auto &send_ctx = iter.second;
auto &pserver_num = send_ctx.epmap.size();
int pserver_num = static_cast<int>(send_ctx.epmap.size());
auto &ids_queue = send_ids_to_queue_.at(var_name);
splited_ids_vec_.clear();
for (int i = 0; i < batches; ++i) {
......@@ -534,7 +534,7 @@ void GeoCommunicator::SendByCommunicator(int batches) {
}
if (send_ctx.is_sparse) {
for (auto ep_idx = 0; ep_idx < pserver_num; ep_idx++) {
for (int ep_idx = 0; ep_idx < pserver_num; ep_idx++) {
auto send_recv_task = [this, ep_idx, &var_name] {
if (var_name == STEP_COUNTER) {
return;
......@@ -568,29 +568,50 @@ void GeoCommunicator::SendByCommunicator(int batches) {
void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx) {
std::vector<int64_t> ids;
auto &ids_queue = send_ids_to_queue_.at(varname);
auto send_varname = send_varname_to_ctx_.at[varname].splited_varnames[ep_idx];
auto trainer_id = send_varname_to_ctx_.at[varname].trainer_id;
auto endpoint = send_varname_to_ctx_.at[varname].epmap[ep_idx];
auto &rpc_ctx = send_varname_to_ctx_.at(varname);
VLOG(1) << rpc_ctx.print();
auto send_varname = rpc_ctx.splited_varnames[ep_idx];
auto trainer_id = rpc_ctx.trainer_id;
auto endpoint = rpc_ctx.epmap[ep_idx];
auto pserver_num = rpc_ctx.epmap.size();
for (int i = 0; i < splited_ids_vec_.size(); ++i) {
std::copy((*splited_ids_vec_[i])[ep_idx].begin(),
(*splited_ids_vec_[i])[ep_idx].end(), back_inserter(ids));
int batches = static_cast<int>(splited_ids_vec_.size());
for (int i = 0; i < batches; ++i) {
std::copy(splited_ids_vec_[i].at(ep_idx).begin(),
splited_ids_vec_[i].at(ep_idx).end(), back_inserter(ids));
}
auto size = ids.size();
std::set<int64_t> st(ids.begin(), ids.end());
ids.assign(st.begin(), st.end());
VLOG(1) << "SendSparse receive var: " << varname << " unset: " << size
<< " set: " << ids.size();
std::stringstream list_str;
for (uint64_t i = 0; i < ids.size(); i++) {
list_str << ids[i] << ",";
}
VLOG(1) << "SendSparse receive var: " << send_varname << " unset: " << size
<< " set: " << ids.size() << ": " << list_str.str();
if (ids.empty()) {
LOG(WARNING) << "WARNING: GEO has nothing to send, return directly ";
return;
}
std::vector<size_t> outs_rows_idx;
if (!rpc_ctx.is_distributed) {
for (size_t i = 0; i < ids.size(); ++i) {
auto id = ids[i] / pserver_num;
outs_rows_idx.push_back(id);
}
} else {
for (size_t i = 0; i < ids.size(); ++i) {
outs_rows_idx.push_back(ids[i]);
}
}
auto *var_latest = recv_scope_->FindVar(varname);
PADDLE_ENFORCE_EQ(var_latest->IsInitialized(), true,
......@@ -603,8 +624,10 @@ void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx) {
auto cpu_ctx = paddle::platform::CPUDeviceContext();
auto *var_delta = delta_scope_->Var(send_varname);
auto *t_delta = var_delta->GetMutable<framework::SelectedRows>();
t_delta->set_height(ids.size());
t_delta->mutable_rows()->assign(ids.begin(), ids.end());
t_delta->set_height(rpc_ctx.height_sections[ep_idx]);
t_delta->mutable_rows()->assign(outs_rows_idx.begin(), outs_rows_idx.end());
auto *t_value = t_delta->mutable_value();
t_value->mutable_data<float>(
framework::make_ddim({static_cast<int64_t>(ids.size()), dims1}),
......@@ -625,6 +648,7 @@ void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx) {
values[j][0]->data());
}
VLOG(1) << "begin to real send " << send_varname;
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &cpu_ctx_send = *pool.Get(platform::CPUPlace());
distributed::RPCClient *rpc_client =
......@@ -632,7 +656,9 @@ void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx) {
auto ret = rpc_client->AsyncSendVar(endpoint, cpu_ctx_send,
*delta_scope_.get(), send_varname);
ret.wait();
VLOG(1) << "need to wait for send " << send_varname;
ret->Wait();
VLOG(1) << "finish to send " << send_varname;
}
void GeoCommunicator::SendDense(const std::string &varname) {
......@@ -672,10 +698,11 @@ void GeoCommunicator::SendDense(const std::string &varname) {
// The body performs no work and returns immediately.
void GeoCommunicator::RecvByCommunicator() {}
void GeoCommunicator::RecvSparse(const std::string &varname, int ep_idx) {
auto train_id = recv_varname_to_ctx_.at(var_name).trainer_id;
auto endpoint = recv_varname_to_ctx_.at(var_name).epmap[ep_idx];
auto train_id = recv_varname_to_ctx_.at(varname).trainer_id;
auto endpoint = recv_varname_to_ctx_.at(varname).epmap[ep_idx];
auto splited_var_name =
send_varname_to_ctx_.at(varname).splited_varnames[ep_idx];
DeltaVarToVar(send_varname_to_ctx_.at(varname).splited_varnames[ep_idx]);
auto pserver_num = recv_varname_to_ctx_.at(varname).epmap.size();
VLOG(1) << "Begin to RecvSparse receive var: " << splited_var_name;
......@@ -683,6 +710,7 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int ep_idx) {
auto &cpu_ctx_recv = *pool.Get(platform::CPUPlace());
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(train_id);
auto *var_psrever = pserver_scope_->Var(splited_var_name);
auto handle = rpc_client->AsyncGetVar(endpoint, cpu_ctx_recv,
*pserver_scope_.get(), splited_var_name,
......@@ -724,14 +752,15 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int ep_idx) {
auto blas = math::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx);
for (auto j = 0; j < static_cast<int>(ids.size()); ++j) {
auto id = ids[j] * pserver_num + ep_idx;
blas.VSUB(dims1, t_psrever.data<float>() + j * dims1,
old_values[j][0]->data(), v_delta.data() + j * dims1);
blas.VADD(dims1, t_latest->data<float>() + ids[j] * dims1,
v_delta.data() + j * dims1,
t_latest->data<float>() + ids[j] * dims1);
blas.VADD(dims1, t_latest->data<float>() + id * dims1,
v_delta.data() + j * dims1, t_latest->data<float>() + id * dims1);
blas.VCOPY(dims1, t_psrever.data<float>() + j * dims1,
old_values[j][0]->data());
}
VLOG(1) << "receive finish";
}
void GeoCommunicator::RecvDense(const std::string &varname) {
......
......@@ -282,7 +282,7 @@ class AsyncCommunicator : public Communicator {
const RpcCtxMap &recv_varname_to_ctx,
Scope *recv_scope) override;
void MainThread();
virtual void MainThread();
void Send(const std::vector<std::string> &var_names,
const std::vector<std::string> &var_tables,
......@@ -406,7 +406,7 @@ class GeoCommunicator : public AsyncCommunicator {
void InitImpl(const RpcCtxMap &send_varname_to_ctx,
const RpcCtxMap &recv_varname_to_ctx,
Scope *recv_scope) override;
void MainThread() override;
void InitEnvs() {
min_send_grad_num_before_recv_ = 0;
......@@ -444,6 +444,19 @@ class GeoCommunicator : public AsyncCommunicator {
void InitDense(const std::string varname);
// Returns `var_name` with the ".delta" suffix appended. The argument is taken
// by value, so the caller's string is never modified.
const std::string VarToDeltaVar(const std::string var_name) {
  return var_name + ".delta";
}
// Returns `var_name` with the first occurrence of ".delta" removed.
//
// The marker may sit anywhere in the name (not only at the end); only the
// first occurrence is stripped. If the marker is absent the name is returned
// unchanged: find() yields std::string::npos in that case, and passing npos
// to std::string::erase as a position would throw std::out_of_range.
const std::string DeltaVarToVar(const std::string var_name) {
  constexpr char kDeltaMarker[] = ".delta";
  // sizeof counts the trailing '\0'; drop it to get the marker length.
  constexpr std::string::size_type kDeltaMarkerLen = sizeof(kDeltaMarker) - 1;

  std::string origin_name = var_name;
  const std::string::size_type pos = origin_name.find(kDeltaMarker);
  if (pos != std::string::npos) {
    origin_name.erase(pos, kDeltaMarkerLen);
  }
  return origin_name;
}
private:
int trainers_;
std::string sparse_attrs_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册