提交 a87a958b 编写于 作者: M malin10

test=develop, bug fix

上级 9ded7565
...@@ -13,13 +13,16 @@ See the License for the specific language governing permissions and ...@@ -13,13 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/distributed/communicator.h" #include "paddle/fluid/operators/distributed/communicator.h"
#include <gflags/gflags.h> #include <gflags/gflags.h>
#include <paddle/fluid/framework/program_desc.h> #include <paddle/fluid/framework/program_desc.h>
#include <algorithm> #include <algorithm>
#include <chrono> // NOLINT #include <chrono> // NOLINT
#include <map> #include <map>
#include <thread> // NOLINT #include <thread> // NOLINT
#include <unordered_set> #include <unordered_set>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
...@@ -374,8 +377,9 @@ void SyncCommunicator::BarrierSend() { ...@@ -374,8 +377,9 @@ void SyncCommunicator::BarrierSend() {
} }
for (size_t i = 0; i < rets.size(); i++) { for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::External( PADDLE_ENFORCE_NE(
"internal error in RPCClient")); rets[i]->Wait(), 0U,
platform::errors::External("internal error in RPCClient"));
} }
VLOG(4) << "BarrierSend with SyncCommunicator"; VLOG(4) << "BarrierSend with SyncCommunicator";
...@@ -393,8 +397,9 @@ void SyncCommunicator::BarrierRecv() { ...@@ -393,8 +397,9 @@ void SyncCommunicator::BarrierRecv() {
} }
for (size_t i = 0; i < rets.size(); i++) { for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::External( PADDLE_ENFORCE_NE(
"internal error in RPCClient")); rets[i]->Wait(), 0U,
platform::errors::External("internal error in RPCClient"));
} }
VLOG(4) << "BarrierRecv with SyncCommunicator"; VLOG(4) << "BarrierRecv with SyncCommunicator";
...@@ -484,13 +489,36 @@ void GeoCommunicator::Send(const std::vector<std::string> &var_names, ...@@ -484,13 +489,36 @@ void GeoCommunicator::Send(const std::vector<std::string> &var_names,
"Only LodTensor can be send in GeoCommunicator::Send")); "Only LodTensor can be send in GeoCommunicator::Send"));
} }
std::vector<int64_t> ids; auto pserver_num = send_varname_to_ctx_.at[table_name].epmap.size();
auto &rows = var->Get<framework::SelectedRows>().rows(); auto ids = std::make_shared<SplitedSparseIds>(pserver_num);
ids.assign(rows.begin(), rows.end()); // split rows index into output sparse vars
for (size_t i = 0; i < rows.size(); ++i) {
auto ep_idx = rows[i] % pserver_num;
ids[ep_idx].add(rows[i]);
}
queue->Push(ids); queue->Push(ids);
} }
} }
void GeoCommunicator::MainThread() {
VLOG(3) << "MainThread start and wait";
while (waiting_ && running_) {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
VLOG(3) << "wait for running";
}
while (running_) {
int meet = Meet();
VLOG(1) << "async_meet: " << meet;
SendGlobalStep(meet);
SendByCommunicator(meet);
}
VLOG(1) << "geo-communicator stopped, send thread exit";
}
void GeoCommunicator::SendByCommunicator(int batches) { void GeoCommunicator::SendByCommunicator(int batches) {
std::vector<std::future<void>> tasks; std::vector<std::future<void>> tasks;
tasks.reserve(send_varname_to_ctx_.size()); tasks.reserve(send_varname_to_ctx_.size());
...@@ -498,21 +526,39 @@ void GeoCommunicator::SendByCommunicator(int batches) { ...@@ -498,21 +526,39 @@ void GeoCommunicator::SendByCommunicator(int batches) {
for (auto &iter : send_varname_to_ctx_) { for (auto &iter : send_varname_to_ctx_) {
auto &var_name = iter.first; auto &var_name = iter.first;
auto &send_ctx = iter.second; auto &send_ctx = iter.second;
auto &pserver_num = send_ctx.epmap.size();
auto send_task = [this, batches, &var_name, &send_ctx] { splited_ids_vec_.clear();
if (var_name == STEP_COUNTER) { for (int i = 0; i < batches; ++i) {
return; splited_ids_vec_.push_back(*(ids_queue->Pop()));
} }
if (send_ctx.is_sparse) { if (send_ctx.is_sparse) {
SendSparse(var_name, batches); for (auto ep_idx = 0; ep_idx < pserver_num; ep_idx++) {
} else { auto send_recv_task = [this, ep_idx, &var_name] {
if (var_name == STEP_COUNTER) {
return;
}
SendSparse(var_name, ep_idx);
RecvSparse(var_name, ep_idx);
};
tasks.emplace_back(
send_threadpool_->enqueue(std::move(send_recv_task)));
}
} else {
auto send_recv_task = [this, &var_name, &send_ctx] {
if (var_name == STEP_COUNTER) {
return;
}
VLOG(1) << "send dense " << var_name << " begin"; VLOG(1) << "send dense " << var_name << " begin";
SendDense(var_name); SendDense(var_name);
VLOG(1) << "send dense " << var_name << " done"; VLOG(1) << "send dense " << var_name << " done";
} VLOG(1) << "recv dense " << var_name << " begin";
}; RecvDense(var_name);
tasks.emplace_back(send_threadpool_->enqueue(std::move(send_task))); VLOG(1) << "recv dense " << var_name << " done";
};
tasks.emplace_back(send_threadpool_->enqueue(std::move(send_recv_task)));
}
} }
for (auto &task : tasks) { for (auto &task : tasks) {
...@@ -520,13 +566,17 @@ void GeoCommunicator::SendByCommunicator(int batches) { ...@@ -520,13 +566,17 @@ void GeoCommunicator::SendByCommunicator(int batches) {
} }
} }
void GeoCommunicator::SendSparse(const std::string &varname, int batches) { void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx) {
std::vector<int64_t> ids; std::vector<int64_t> ids;
auto &ids_queue = send_ids_to_queue_.at(varname); auto &ids_queue = send_ids_to_queue_.at(varname);
for (int i = 0; i < batches; ++i) { auto send_varname = send_varname_to_ctx_.at[varname].splited_varnames[ep_idx];
auto pop_ids = ids_queue->Pop(); auto trainer_id = send_varname_to_ctx_.at[varname].trainer_id;
std::copy(pop_ids.begin(), pop_ids.end(), back_inserter(ids)); auto endpoint = send_varname_to_ctx_.at[varname].epmap[ep_idx];
for (int i = 0; i < splited_ids_vec_.size(); ++i) {
std::copy((*splited_ids_vec_[i])[ep_idx].begin(),
(*splited_ids_vec_[i])[ep_idx].end(), back_inserter(ids));
} }
auto size = ids.size(); auto size = ids.size();
...@@ -551,7 +601,7 @@ void GeoCommunicator::SendSparse(const std::string &varname, int batches) { ...@@ -551,7 +601,7 @@ void GeoCommunicator::SendSparse(const std::string &varname, int batches) {
auto dims1 = t_latest.dims()[1]; auto dims1 = t_latest.dims()[1];
auto cpu_ctx = paddle::platform::CPUDeviceContext(); auto cpu_ctx = paddle::platform::CPUDeviceContext();
auto *var_delta = delta_scope_->Var(varname); auto *var_delta = delta_scope_->Var(send_varname);
auto *t_delta = var_delta->GetMutable<framework::SelectedRows>(); auto *t_delta = var_delta->GetMutable<framework::SelectedRows>();
t_delta->set_height(ids.size()); t_delta->set_height(ids.size());
t_delta->mutable_rows()->assign(ids.begin(), ids.end()); t_delta->mutable_rows()->assign(ids.begin(), ids.end());
...@@ -575,9 +625,14 @@ void GeoCommunicator::SendSparse(const std::string &varname, int batches) { ...@@ -575,9 +625,14 @@ void GeoCommunicator::SendSparse(const std::string &varname, int batches) {
values[j][0]->data()); values[j][0]->data());
} }
auto &ctx = send_varname_to_ctx_.at(varname); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto send = distributed::ParameterSend<float>(); auto &cpu_ctx_send = *pool.Get(platform::CPUPlace());
send(ctx, *delta_scope_, true, 1); distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id);
auto ret = rpc_client->AsyncSendVar(endpoint, cpu_ctx_send,
*delta_scope_.get(), send_varname);
ret.wait();
} }
void GeoCommunicator::SendDense(const std::string &varname) { void GeoCommunicator::SendDense(const std::string &varname) {
...@@ -614,39 +669,29 @@ void GeoCommunicator::SendDense(const std::string &varname) { ...@@ -614,39 +669,29 @@ void GeoCommunicator::SendDense(const std::string &varname) {
send(ctx, *delta_scope_, true, 1); send(ctx, *delta_scope_, true, 1);
} }
void GeoCommunicator::RecvByCommunicator() { void GeoCommunicator::RecvByCommunicator() { return; }
std::vector<std::future<void>> tasks;
tasks.reserve(recv_varname_to_ctx_.size());
for (auto &iter : recv_varname_to_ctx_) { void GeoCommunicator::RecvSparse(const std::string &varname, int ep_idx) {
auto &var_name = iter.first; auto train_id = recv_varname_to_ctx_.at(var_name).trainer_id;
auto &recv_ctx = iter.second; auto endpoint = recv_varname_to_ctx_.at(var_name).epmap[ep_idx];
auto splited_var_name =
send_varname_to_ctx_.at(varname).splited_varnames[ep_idx];
auto recv_task = [this, &var_name, &recv_ctx] { VLOG(1) << "Begin to RecvSparse receive var: " << splited_var_name;
if (recv_ctx.is_sparse) {
RecvSparse(var_name);
} else {
VLOG(1) << "recv dense " << var_name << " begin";
RecvDense(var_name);
VLOG(1) << "recv dense " << var_name << " done";
}
};
tasks.emplace_back(send_threadpool_->enqueue(std::move(recv_task)));
}
for (auto &task : tasks) {
task.wait();
}
}
void GeoCommunicator::RecvSparse(const std::string &varname) { platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
VLOG(1) << "RecvSparse receive var: " << varname; auto &cpu_ctx_recv = *pool.Get(platform::CPUPlace());
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(train_id);
auto *var_psrever = pserver_scope_->Var(splited_var_name);
auto handle = rpc_client->AsyncGetVar(endpoint, cpu_ctx_recv,
*pserver_scope_.get(), splited_var_name,
splited_var_name, splited_var_name);
handle->Wait();
auto *var_latest = recv_scope_->FindVar(varname); VLOG(1) << "Finish to RecvSparse receive var: " << splited_var_name;
auto *var_psrever = pserver_scope_->Var(varname);
auto &ctx = recv_varname_to_ctx_.at(varname); auto *var_latest = recv_scope_->FindVar(varname);
auto recv = distributed::ParameterRecv<float>();
recv(ctx, *pserver_scope_, true);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
var_psrever->IsInitialized(), true, var_psrever->IsInitialized(), true,
...@@ -657,7 +702,7 @@ void GeoCommunicator::RecvSparse(const std::string &varname) { ...@@ -657,7 +702,7 @@ void GeoCommunicator::RecvSparse(const std::string &varname) {
ids.assign(var_psrever->Get<framework::SelectedRows>().rows().begin(), ids.assign(var_psrever->Get<framework::SelectedRows>().rows().begin(),
var_psrever->Get<framework::SelectedRows>().rows().end()); var_psrever->Get<framework::SelectedRows>().rows().end());
VLOG(1) << "RecvSparse receive var: " << varname VLOG(1) << "RecvSparse receive var: " << splited_var_name
<< " ids Size: " << ids.size(); << " ids Size: " << ids.size();
auto t_psrever = var_psrever->Get<framework::SelectedRows>().value(); auto t_psrever = var_psrever->Get<framework::SelectedRows>().value();
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include <ThreadPool.h> #include <ThreadPool.h>
#include <atomic> #include <atomic>
#include <deque> #include <deque>
#include <map> #include <map>
...@@ -25,8 +26,8 @@ limitations under the License. */ ...@@ -25,8 +26,8 @@ limitations under the License. */
#include <unordered_set> #include <unordered_set>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "gflags/gflags.h"
#include "gflags/gflags.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/distributed/communicator_common.h" #include "paddle/fluid/operators/distributed/communicator_common.h"
...@@ -250,6 +251,8 @@ class Communicator { ...@@ -250,6 +251,8 @@ class Communicator {
std::unordered_map<std::string, std::string> envs; std::unordered_map<std::string, std::string> envs;
}; };
using SplitedSparseIds = std::vector<std::unordered_set<int64_t>>;
class AsyncCommunicator : public Communicator { class AsyncCommunicator : public Communicator {
public: public:
AsyncCommunicator() : Communicator() {} AsyncCommunicator() : Communicator() {}
...@@ -423,7 +426,7 @@ class GeoCommunicator : public AsyncCommunicator { ...@@ -423,7 +426,7 @@ class GeoCommunicator : public AsyncCommunicator {
void SendByCommunicator(int batches) override; void SendByCommunicator(int batches) override;
void SendSparse(const std::string &varname, int batches); void SendSparse(const std::string &varname, int ep_idx);
void SendDense(const std::string &varname); void SendDense(const std::string &varname);
...@@ -431,7 +434,7 @@ class GeoCommunicator : public AsyncCommunicator { ...@@ -431,7 +434,7 @@ class GeoCommunicator : public AsyncCommunicator {
void RecvByCommunicator() override; void RecvByCommunicator() override;
void RecvSparse(const std::string &varname); void RecvSparse(const std::string &varname, int ep_idx);
void RecvDense(const std::string &varname); void RecvDense(const std::string &varname);
...@@ -454,11 +457,13 @@ class GeoCommunicator : public AsyncCommunicator { ...@@ -454,11 +457,13 @@ class GeoCommunicator : public AsyncCommunicator {
// parameter on pserver // parameter on pserver
std::shared_ptr<Scope> pserver_scope_; std::shared_ptr<Scope> pserver_scope_;
std::unordered_map<std::string, std::unordered_map<
std::shared_ptr<BlockingQueue<std::vector<int64_t>>>> std::string,
std::shared_ptr<BlockingQueue<std::shared_ptr<SplitedSparseIds>>>>
send_ids_to_queue_; send_ids_to_queue_;
std::unordered_map<std::string, std::shared_ptr<SparseValue>> old_sparses_; std::unordered_map<std::string, std::shared_ptr<SparseValue>> old_sparses_;
std::vector<SplitedSparseIds> splited_ids_vec_;
}; };
} // namespace distributed } // namespace distributed
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册