Commit 63cd70a8 authored by Qiao Longfei

fix blocking problem

Parent c0e5941e
......@@ -75,10 +75,11 @@ void Communicator::SendThread() {
while (running_) {
std::vector<std::future<void>> task_futures;
task_futures.reserve(send_varname_to_ctx_.size());
VLOG(3) << "run send graph";
for (auto &iter : send_varname_to_queue_) {
auto &var_name = iter.first;
auto &var_queue = iter.second;
if (var_queue->NotEmpty()) { // will block if queue is empty
if (var_queue->Size() > 0) {
auto send_task = [this, &var_name, &var_queue] {
VLOG(3) << "merge var " << var_name << " and send";
std::vector<std::shared_ptr<Variable>> vars;
......@@ -96,33 +97,41 @@ void Communicator::SendThread() {
};
task_futures.emplace_back(
send_threadpool_->enqueue(std::move(send_task)));
} else {
VLOG(3) << var_name << " queue empty";
}
}
for (auto &task_f : task_futures) {
task_f.wait();
}
VLOG(3) << "run send graph done";
RecvAll();
}
}
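The core of the blocking fix is in the loop above: `NotEmpty()` waits on a condition variable, so the send thread parked whenever one queue was empty, while `Size() > 0` only inspects the current length and lets the loop move on to the next variable. The diff does not show `BlockingQueue::Size()` itself; a minimal sketch of such a non-blocking accessor, assuming the `queue_`/`mutex_` members shown in blocking_queue.h below, would be:

```cpp
// Hypothetical sketch of a non-blocking size query for BlockingQueue.
// It only locks the mutex to read the deque length and never waits on cv_.
size_t Size() {
  std::lock_guard<std::mutex> lock(mutex_);
  return queue_.size();
}
```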
void Communicator::RecvAll() {
VLOG(3) << "parallel run recv graph";
std::vector<std::future<void>> task_futures;
task_futures.reserve(recv_varname_to_ctx_.size());
for (auto &iter : recv_varname_to_ctx_) {
auto recv_task = [this, &iter] {
auto &var_name = iter.first;
VLOG(3) << "recv var " << var_name;
auto recv_functor = distributed::ParameterRecv<float>();
recv_functor(iter.second, *recv_scope_);
};
task_futures.emplace_back(recv_threadpool_->enqueue(std::move(recv_task)));
}
for (auto &task : task_futures) {
task.wait();
}
VLOG(3) << "run recv graph done";
}
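RecvAll() follows a simple fan-out/join pattern: one task per parameter is enqueued on the recv thread pool, and the loop over `task_futures` joins them all before returning. Since RecvAll() is now invoked at the end of each SendThread pass (and the standalone recv thread is left commented out in Start() below), the next send pass only starts after every parameter has been pulled. A self-contained illustration of the same fan-out/join pattern, with `std::async` standing in for the framework thread pool and made-up variable names:

```cpp
#include <future>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Fan-out/join: launch one task per entry, then wait for all of them.
int main() {
  std::map<std::string, int> recv_ctx = {{"w0", 0}, {"w1", 1}, {"w2", 2}};
  std::vector<std::future<void>> task_futures;
  task_futures.reserve(recv_ctx.size());
  for (auto &iter : recv_ctx) {
    const std::string name = iter.first;
    task_futures.emplace_back(std::async(std::launch::async, [name] {
      // stands in for ParameterRecv<float>()(ctx, scope)
      std::cout << "recv " << name << "\n";
    }));
  }
  for (auto &f : task_futures) {
    f.wait();  // join: every recv has finished before we return
  }
  return 0;
}
```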
void Communicator::RecvThread() {
VLOG(3) << "RecvThread start!";
while (running_) {
// parallel run recv graph
std::vector<std::future<void>> task_futures;
task_futures.reserve(recv_varname_to_ctx_.size());
for (auto &iter : recv_varname_to_ctx_) {
auto recv_task = [this, &iter] {
auto &var_name = iter.first;
VLOG(3) << "recv var " << var_name;
auto recv_functor = distributed::ParameterRecv<float>();
recv_functor(iter.second, *recv_scope_);
};
task_futures.emplace_back(
recv_threadpool_->enqueue(std::move(recv_task)));
}
for (auto &task : task_futures) {
task.wait();
}
RecvAll();
// TODO(qiao) needs to be configurable
std::this_thread::sleep_for(std::chrono::milliseconds(200));
}
......@@ -136,7 +145,9 @@ void Communicator::Send(const std::string &var_name,
PADDLE_ENFORCE(grad_var->IsInitialized(), "grad var should be inited");
auto tmp_grad_var = std::make_shared<Variable>();
framework::CopyVariable(*grad_var, tmp_grad_var.get());
send_varname_to_queue_[var_name]->Push(tmp_grad_var);
auto &queue = send_varname_to_queue_.at(var_name);
VLOG(3) << "send " << var_name << " queue size " << queue->Size();
queue->Push(tmp_grad_var);
}
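Switching to `send_varname_to_queue_.at(var_name)` instead of `operator[]` matters here: `operator[]` on an unordered_map would silently default-construct a null `shared_ptr` for an unregistered gradient name and then crash on `Push`, while `.at()` throws and surfaces the bad lookup. A small standalone illustration of the difference, using toy types rather than the Paddle classes:

```cpp
#include <deque>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>

int main() {
  std::unordered_map<std::string, std::shared_ptr<std::deque<int>>> queues;
  queues["w@GRAD"] = std::make_shared<std::deque<int>>();

  queues.at("w@GRAD")->push_back(1);  // registered name: fine

  try {
    queues.at("missing@GRAD")->push_back(2);  // .at() throws on unknown names
  } catch (const std::out_of_range &e) {
    std::cout << "unregistered grad var: " << e.what() << "\n";
  }

  // queues["missing@GRAD"]->push_back(2);  // operator[] would insert a null
  //                                        // shared_ptr and dereference it: UB
  return 0;
}
```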
Communicator *Communicator::GetInstance() { return communicator_.get(); }
......@@ -146,8 +157,8 @@ void Communicator::Start() {
// start send and recv thread
send_thread_.reset(
new std::thread(std::bind(&Communicator::SendThread, this)));
recv_thread_.reset(
new std::thread(std::bind(&Communicator::RecvThread, this)));
// recv_thread_.reset(
// new std::thread(std::bind(&Communicator::RecvThread, this)));
}
} // namespace distributed
......
......@@ -43,37 +43,36 @@ class BlockingQueue {
}
bool Push(const T& elem) {
std::unique_lock<std::mutex> lock(mutex_);
send_cv_.wait(lock, [&] { return queue_.size() < capacity_; });
PADDLE_ENFORCE_LT(queue_.size(), capacity_);
queue_.push_back(elem);
recv_cv_.notify_one();
{
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [&] { return queue_.size() < capacity_; });
PADDLE_ENFORCE_LT(queue_.size(), capacity_);
queue_.push_back(elem);
}
cv_.notify_one();
return true;
}
bool Push(T&& elem) {
std::unique_lock<std::mutex> lock(mutex_);
send_cv_.wait(lock, [&] { return queue_.size() < capacity_; });
PADDLE_ENFORCE_LT(queue_.size(), capacity_);
queue_.emplace_back(std::move(elem));
recv_cv_.notify_one();
{
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [&] { return queue_.size() < capacity_; });
PADDLE_ENFORCE_LT(queue_.size(), capacity_);
queue_.emplace_back(std::move(elem));
}
cv_.notify_one();
return true;
}
T Pop() {
std::unique_lock<std::mutex> lock(mutex_);
recv_cv_.wait(lock, [=] { return !queue_.empty(); });
cv_.wait(lock, [=] { return !queue_.empty(); });
T rc(std::move(queue_.front()));
queue_.pop_front();
cv_.notify_one();
return rc;
}
bool NotEmpty() {
std::unique_lock<std::mutex> lock(mutex_);
recv_cv_.wait(lock, [=] { return !queue_.empty(); });
return true;
}
size_t Cap() const {
std::lock_guard<std::mutex> lock(mutex_);
return capacity_;
......@@ -89,8 +88,7 @@ class BlockingQueue {
std::deque<T> queue_;
mutable std::mutex mutex_;
std::condition_variable recv_cv_;
std::condition_variable send_cv_;
std::condition_variable cv_;
};
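Two things change in BlockingQueue: the separate `send_cv_`/`recv_cv_` pair collapses into a single `cv_` that both `Push` and `Pop` wait on and notify, and `Push` now releases the mutex before calling `notify_one()`, so the woken thread does not immediately block on a still-held lock. `NotEmpty()` is removed because a caller that waits inside the queue is exactly the blocking behaviour this commit gets rid of. Below is a compilable miniature of the same single-cv bounded queue, an illustration only that assumes nothing from Paddle:

```cpp
#include <condition_variable>
#include <deque>
#include <iostream>
#include <mutex>
#include <thread>

// Minimal bounded queue with one condition variable for both directions.
class TinyQueue {
 public:
  explicit TinyQueue(size_t cap) : capacity_(cap) {}

  void Push(int v) {
    {
      std::unique_lock<std::mutex> lock(mutex_);
      cv_.wait(lock, [&] { return queue_.size() < capacity_; });
      queue_.push_back(v);
    }                   // release the lock first ...
    cv_.notify_one();   // ... then wake a waiter so it can take the mutex at once
  }

  int Pop() {
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [&] { return !queue_.empty(); });
    int v = queue_.front();
    queue_.pop_front();
    cv_.notify_one();   // a blocked producer may now have room
    return v;
  }

 private:
  size_t capacity_;
  std::deque<int> queue_;
  std::mutex mutex_;
  std::condition_variable cv_;
};

int main() {
  TinyQueue q(2);
  std::thread producer([&] { for (int i = 0; i < 5; ++i) q.Push(i); });
  for (int i = 0; i < 5; ++i) std::cout << q.Pop() << " ";
  producer.join();
  std::cout << "\n";
  return 0;
}
```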
using RpcCtxMap = std::unordered_map<std::string, RpcContext>;
......@@ -127,6 +125,8 @@ class Communicator {
void Send(const std::string& var_name, const framework::Scope& scope);
private:
// recv all parameter
void RecvAll();
void SendThread();
void RecvThread();
......
......@@ -41,6 +41,7 @@ using DDim = framework::DDim;
template <typename T>
void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
const framework::Scope &scope) {
VLOG(3) << "ParameterRecv in";
framework::Scope *local_scope = scope.NewTmpScope();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
......@@ -90,6 +91,7 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
}
delete local_scope;
VLOG(3) << "ParameterRecv out";
}
template struct ParameterRecv<float>;
......
......@@ -48,12 +48,15 @@ class SendOp : public framework::OperatorBase {
if (send_varnames.size() > 0) {
PADDLE_ENFORCE_EQ(ins.size(), 1, "");
// auto send_functor = distributed::ParameterSend<float>();
// auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames,
// epmap,
// height_sections);
// send_functor(rpc_ctx, scope, static_cast<bool>(sync_send));
/*
auto send_functor = distributed::ParameterSend<float>();
auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames, epmap,
height_sections);
send_functor(rpc_ctx, scope, static_cast<bool>(sync_send));
*/
VLOG(3) << "send " << ins[0];
distributed::Communicator::GetInstance()->Send(ins[0], scope);
VLOG(3) << "send " << ins[0] << " done";
} else {
platform::DeviceContextPool& pool =
platform::DeviceContextPool::Instance();
......
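With the send_varnames path rewritten above, SendOp no longer performs the RPC itself: the commented-out `ParameterSend` block is the synchronous path it replaces, and the op now only hands the gradient name and scope to the process-wide Communicator, whose SendThread later merges the queued copies and issues the actual send. A toy stand-in for that singleton hand-off, with illustrative names that are not Paddle's:

```cpp
#include <iostream>
#include <string>

// The op thread does a cheap enqueue and returns; a background thread owned
// by the communicator does the real network work later.
class ToyCommunicator {
 public:
  static ToyCommunicator *GetInstance() {
    static ToyCommunicator inst;
    return &inst;
  }
  void Send(const std::string &var_name) {
    // In the real Communicator this copies the variable and pushes it onto
    // send_varname_to_queue_[var_name]; SendThread drains the queue later.
    std::cout << "queued " << var_name << " for async send\n";
  }
};

int main() {
  ToyCommunicator::GetInstance()->Send("w@GRAD");
  return 0;
}
```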