提交 fc8ba050 编写于 作者: M malin10

tmp

上级 a87a958b
...@@ -377,9 +377,8 @@ void SyncCommunicator::BarrierSend() { ...@@ -377,9 +377,8 @@ void SyncCommunicator::BarrierSend() {
} }
for (size_t i = 0; i < rets.size(); i++) { for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::External(
rets[i]->Wait(), 0U, "internal error in RPCClient"));
platform::errors::External("internal error in RPCClient"));
} }
VLOG(4) << "BarrierSend with SyncCommunicator"; VLOG(4) << "BarrierSend with SyncCommunicator";
...@@ -397,9 +396,8 @@ void SyncCommunicator::BarrierRecv() { ...@@ -397,9 +396,8 @@ void SyncCommunicator::BarrierRecv() {
} }
for (size_t i = 0; i < rets.size(); i++) { for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::External(
rets[i]->Wait(), 0U, "internal error in RPCClient"));
platform::errors::External("internal error in RPCClient"));
} }
VLOG(4) << "BarrierRecv with SyncCommunicator"; VLOG(4) << "BarrierRecv with SyncCommunicator";
...@@ -432,7 +430,7 @@ void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx, ...@@ -432,7 +430,7 @@ void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
} }
send_ids_to_queue_[varname] = send_ids_to_queue_[varname] =
std::make_shared<BlockingQueue<std::vector<int64_t>>>( std::make_shared<BlockingQueue<std::shared_ptr<SplitedSparseIds>>>(
send_queue_size_); send_queue_size_);
} }
} }
...@@ -489,12 +487,13 @@ void GeoCommunicator::Send(const std::vector<std::string> &var_names, ...@@ -489,12 +487,13 @@ void GeoCommunicator::Send(const std::vector<std::string> &var_names,
"Only LodTensor can be send in GeoCommunicator::Send")); "Only LodTensor can be send in GeoCommunicator::Send"));
} }
auto pserver_num = send_varname_to_ctx_.at[table_name].epmap.size(); auto pserver_num = send_varname_to_ctx_.at(table_name).epmap.size();
auto ids = std::make_shared<SplitedSparseIds>(pserver_num); auto ids = std::make_shared<SplitedSparseIds>(pserver_num);
auto &rows = var->Get<framework::SelectedRows>().rows();
// split rows index into output sparse vars // split rows index into output sparse vars
for (size_t i = 0; i < rows.size(); ++i) { for (size_t i = 0; i < rows.size(); ++i) {
auto ep_idx = rows[i] % pserver_num; auto ep_idx = rows[i] % pserver_num;
ids[ep_idx].add(rows[i]); ids->at(ep_idx).insert(rows[i]);
} }
queue->Push(ids); queue->Push(ids);
} }
...@@ -526,7 +525,8 @@ void GeoCommunicator::SendByCommunicator(int batches) { ...@@ -526,7 +525,8 @@ void GeoCommunicator::SendByCommunicator(int batches) {
for (auto &iter : send_varname_to_ctx_) { for (auto &iter : send_varname_to_ctx_) {
auto &var_name = iter.first; auto &var_name = iter.first;
auto &send_ctx = iter.second; auto &send_ctx = iter.second;
auto &pserver_num = send_ctx.epmap.size(); int pserver_num = static_cast<int>(send_ctx.epmap.size());
auto &ids_queue = send_ids_to_queue_.at(var_name);
splited_ids_vec_.clear(); splited_ids_vec_.clear();
for (int i = 0; i < batches; ++i) { for (int i = 0; i < batches; ++i) {
...@@ -534,7 +534,7 @@ void GeoCommunicator::SendByCommunicator(int batches) { ...@@ -534,7 +534,7 @@ void GeoCommunicator::SendByCommunicator(int batches) {
} }
if (send_ctx.is_sparse) { if (send_ctx.is_sparse) {
for (auto ep_idx = 0; ep_idx < pserver_num; ep_idx++) { for (int ep_idx = 0; ep_idx < pserver_num; ep_idx++) {
auto send_recv_task = [this, ep_idx, &var_name] { auto send_recv_task = [this, ep_idx, &var_name] {
if (var_name == STEP_COUNTER) { if (var_name == STEP_COUNTER) {
return; return;
...@@ -568,29 +568,50 @@ void GeoCommunicator::SendByCommunicator(int batches) { ...@@ -568,29 +568,50 @@ void GeoCommunicator::SendByCommunicator(int batches) {
void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx) { void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx) {
std::vector<int64_t> ids; std::vector<int64_t> ids;
auto &ids_queue = send_ids_to_queue_.at(varname);
auto send_varname = send_varname_to_ctx_.at[varname].splited_varnames[ep_idx]; auto &rpc_ctx = send_varname_to_ctx_.at(varname);
auto trainer_id = send_varname_to_ctx_.at[varname].trainer_id; VLOG(1) << rpc_ctx.print();
auto endpoint = send_varname_to_ctx_.at[varname].epmap[ep_idx]; auto send_varname = rpc_ctx.splited_varnames[ep_idx];
auto trainer_id = rpc_ctx.trainer_id;
auto endpoint = rpc_ctx.epmap[ep_idx];
auto pserver_num = rpc_ctx.epmap.size();
for (int i = 0; i < splited_ids_vec_.size(); ++i) { int batches = static_cast<int>(splited_ids_vec_.size());
std::copy((*splited_ids_vec_[i])[ep_idx].begin(), for (int i = 0; i < batches; ++i) {
(*splited_ids_vec_[i])[ep_idx].end(), back_inserter(ids)); std::copy(splited_ids_vec_[i].at(ep_idx).begin(),
splited_ids_vec_[i].at(ep_idx).end(), back_inserter(ids));
} }
auto size = ids.size(); auto size = ids.size();
std::set<int64_t> st(ids.begin(), ids.end()); std::set<int64_t> st(ids.begin(), ids.end());
ids.assign(st.begin(), st.end()); ids.assign(st.begin(), st.end());
VLOG(1) << "SendSparse receive var: " << varname << " unset: " << size
<< " set: " << ids.size(); std::stringstream list_str;
for (uint64_t i = 0; i < ids.size(); i++) {
list_str << ids[i] << ",";
}
VLOG(1) << "SendSparse receive var: " << send_varname << " unset: " << size
<< " set: " << ids.size() << ": " << list_str.str();
if (ids.empty()) { if (ids.empty()) {
LOG(WARNING) << "WARNING: GEO has nothing to send, return directly "; LOG(WARNING) << "WARNING: GEO has nothing to send, return directly ";
return; return;
} }
std::vector<size_t> outs_rows_idx;
if (!rpc_ctx.is_distributed) {
for (size_t i = 0; i < ids.size(); ++i) {
auto id = ids[i] / pserver_num;
outs_rows_idx.push_back(id);
}
} else {
for (size_t i = 0; i < ids.size(); ++i) {
outs_rows_idx.push_back(ids[i]);
}
}
auto *var_latest = recv_scope_->FindVar(varname); auto *var_latest = recv_scope_->FindVar(varname);
PADDLE_ENFORCE_EQ(var_latest->IsInitialized(), true, PADDLE_ENFORCE_EQ(var_latest->IsInitialized(), true,
...@@ -603,8 +624,10 @@ void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx) { ...@@ -603,8 +624,10 @@ void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx) {
auto cpu_ctx = paddle::platform::CPUDeviceContext(); auto cpu_ctx = paddle::platform::CPUDeviceContext();
auto *var_delta = delta_scope_->Var(send_varname); auto *var_delta = delta_scope_->Var(send_varname);
auto *t_delta = var_delta->GetMutable<framework::SelectedRows>(); auto *t_delta = var_delta->GetMutable<framework::SelectedRows>();
t_delta->set_height(ids.size());
t_delta->mutable_rows()->assign(ids.begin(), ids.end()); t_delta->set_height(rpc_ctx.height_sections[ep_idx]);
t_delta->mutable_rows()->assign(outs_rows_idx.begin(), outs_rows_idx.end());
auto *t_value = t_delta->mutable_value(); auto *t_value = t_delta->mutable_value();
t_value->mutable_data<float>( t_value->mutable_data<float>(
framework::make_ddim({static_cast<int64_t>(ids.size()), dims1}), framework::make_ddim({static_cast<int64_t>(ids.size()), dims1}),
...@@ -625,6 +648,7 @@ void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx) { ...@@ -625,6 +648,7 @@ void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx) {
values[j][0]->data()); values[j][0]->data());
} }
VLOG(1) << "begin to real send " << send_varname;
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &cpu_ctx_send = *pool.Get(platform::CPUPlace()); auto &cpu_ctx_send = *pool.Get(platform::CPUPlace());
distributed::RPCClient *rpc_client = distributed::RPCClient *rpc_client =
...@@ -632,7 +656,9 @@ void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx) { ...@@ -632,7 +656,9 @@ void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx) {
auto ret = rpc_client->AsyncSendVar(endpoint, cpu_ctx_send, auto ret = rpc_client->AsyncSendVar(endpoint, cpu_ctx_send,
*delta_scope_.get(), send_varname); *delta_scope_.get(), send_varname);
ret.wait(); VLOG(1) << "need to wait for send " << send_varname;
ret->Wait();
VLOG(1) << "finish to send " << send_varname;
} }
void GeoCommunicator::SendDense(const std::string &varname) { void GeoCommunicator::SendDense(const std::string &varname) {
...@@ -672,10 +698,11 @@ void GeoCommunicator::SendDense(const std::string &varname) { ...@@ -672,10 +698,11 @@ void GeoCommunicator::SendDense(const std::string &varname) {
void GeoCommunicator::RecvByCommunicator() { return; } void GeoCommunicator::RecvByCommunicator() { return; }
void GeoCommunicator::RecvSparse(const std::string &varname, int ep_idx) { void GeoCommunicator::RecvSparse(const std::string &varname, int ep_idx) {
auto train_id = recv_varname_to_ctx_.at(var_name).trainer_id; auto train_id = recv_varname_to_ctx_.at(varname).trainer_id;
auto endpoint = recv_varname_to_ctx_.at(var_name).epmap[ep_idx]; auto endpoint = recv_varname_to_ctx_.at(varname).epmap[ep_idx];
auto splited_var_name = auto splited_var_name =
send_varname_to_ctx_.at(varname).splited_varnames[ep_idx]; DeltaVarToVar(send_varname_to_ctx_.at(varname).splited_varnames[ep_idx]);
auto pserver_num = recv_varname_to_ctx_.at(varname).epmap.size();
VLOG(1) << "Begin to RecvSparse receive var: " << splited_var_name; VLOG(1) << "Begin to RecvSparse receive var: " << splited_var_name;
...@@ -683,6 +710,7 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int ep_idx) { ...@@ -683,6 +710,7 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int ep_idx) {
auto &cpu_ctx_recv = *pool.Get(platform::CPUPlace()); auto &cpu_ctx_recv = *pool.Get(platform::CPUPlace());
distributed::RPCClient *rpc_client = distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(train_id); distributed::RPCClient::GetInstance<RPCCLIENT_T>(train_id);
auto *var_psrever = pserver_scope_->Var(splited_var_name); auto *var_psrever = pserver_scope_->Var(splited_var_name);
auto handle = rpc_client->AsyncGetVar(endpoint, cpu_ctx_recv, auto handle = rpc_client->AsyncGetVar(endpoint, cpu_ctx_recv,
*pserver_scope_.get(), splited_var_name, *pserver_scope_.get(), splited_var_name,
...@@ -724,14 +752,15 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int ep_idx) { ...@@ -724,14 +752,15 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int ep_idx) {
auto blas = math::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx); auto blas = math::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx);
for (auto j = 0; j < static_cast<int>(ids.size()); ++j) { for (auto j = 0; j < static_cast<int>(ids.size()); ++j) {
auto id = ids[j] * pserver_num + ep_idx;
blas.VSUB(dims1, t_psrever.data<float>() + j * dims1, blas.VSUB(dims1, t_psrever.data<float>() + j * dims1,
old_values[j][0]->data(), v_delta.data() + j * dims1); old_values[j][0]->data(), v_delta.data() + j * dims1);
blas.VADD(dims1, t_latest->data<float>() + ids[j] * dims1, blas.VADD(dims1, t_latest->data<float>() + id * dims1,
v_delta.data() + j * dims1, v_delta.data() + j * dims1, t_latest->data<float>() + id * dims1);
t_latest->data<float>() + ids[j] * dims1);
blas.VCOPY(dims1, t_psrever.data<float>() + j * dims1, blas.VCOPY(dims1, t_psrever.data<float>() + j * dims1,
old_values[j][0]->data()); old_values[j][0]->data());
} }
VLOG(1) << "receive finish";
} }
void GeoCommunicator::RecvDense(const std::string &varname) { void GeoCommunicator::RecvDense(const std::string &varname) {
......
...@@ -282,7 +282,7 @@ class AsyncCommunicator : public Communicator { ...@@ -282,7 +282,7 @@ class AsyncCommunicator : public Communicator {
const RpcCtxMap &recv_varname_to_ctx, const RpcCtxMap &recv_varname_to_ctx,
Scope *recv_scope) override; Scope *recv_scope) override;
void MainThread(); virtual void MainThread();
void Send(const std::vector<std::string> &var_names, void Send(const std::vector<std::string> &var_names,
const std::vector<std::string> &var_tables, const std::vector<std::string> &var_tables,
...@@ -406,7 +406,7 @@ class GeoCommunicator : public AsyncCommunicator { ...@@ -406,7 +406,7 @@ class GeoCommunicator : public AsyncCommunicator {
void InitImpl(const RpcCtxMap &send_varname_to_ctx, void InitImpl(const RpcCtxMap &send_varname_to_ctx,
const RpcCtxMap &recv_varname_to_ctx, const RpcCtxMap &recv_varname_to_ctx,
Scope *recv_scope) override; Scope *recv_scope) override;
void MainThread() override;
void InitEnvs() { void InitEnvs() {
min_send_grad_num_before_recv_ = 0; min_send_grad_num_before_recv_ = 0;
...@@ -444,6 +444,19 @@ class GeoCommunicator : public AsyncCommunicator { ...@@ -444,6 +444,19 @@ class GeoCommunicator : public AsyncCommunicator {
void InitDense(const std::string varname); void InitDense(const std::string varname);
// Returns the delta-variable name for `var_name`, formed by appending the
// ".delta" suffix to it. The input string is not modified.
const std::string VarToDeltaVar(const std::string var_name) {
  return var_name + ".delta";
}
// Returns `var_name` with the first occurrence of ".delta" removed, i.e. the
// inverse of appending the ".delta" marker to a variable name.
//
// If `var_name` does not contain ".delta", it is returned unchanged.
const std::string DeltaVarToVar(const std::string var_name) {
  const std::string delta_marker = ".delta";
  std::string origin_name = var_name;
  const std::string::size_type marker_pos = origin_name.find(delta_marker);
  // find() yields npos when the marker is absent, and erase(npos, n) throws
  // std::out_of_range; only erase when the marker was actually found.
  if (marker_pos != std::string::npos) {
    origin_name.erase(marker_pos, delta_marker.size());
  }
  return origin_name;
}
private: private:
int trainers_; int trainers_;
std::string sparse_attrs_; std::string sparse_attrs_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册