提交 50601501 编写于 作者: Q Qiao Longfei

improve communicator

上级 c2cce6ba
...@@ -54,7 +54,7 @@ cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler scope) ...@@ -54,7 +54,7 @@ cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler scope)
cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory) cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)
cc_library(parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory) cc_library(parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory)
cc_library(parameter_recv SRCS parameter_recv.cc DEPS sendrecvop_rpc memory) cc_library(parameter_recv SRCS parameter_recv.cc DEPS sendrecvop_rpc memory)
cc_library(communicator SRCS communicator.cc DEPS scope selected_rows tensor variable_helper selected_rows_functor) cc_library(communicator SRCS communicator.cc DEPS scope selected_rows tensor variable_helper selected_rows_functor simple_threadpool)
if(WITH_GPU) if(WITH_GPU)
cc_test(collective_server_test SRCS collective_server_test.cc cc_test(collective_server_test SRCS collective_server_test.cc
DEPS sendrecvop_rpc executor ${RPC_DEPS} DEPS sendrecvop_rpc executor ${RPC_DEPS}
......
...@@ -25,7 +25,7 @@ namespace paddle { ...@@ -25,7 +25,7 @@ namespace paddle {
namespace operators { namespace operators {
namespace distributed { namespace distributed {
static void MergeVars(const std::string &var_name, static inline void MergeVars(const std::string &var_name,
const std::vector<std::shared_ptr<Variable>> &vars, const std::vector<std::shared_ptr<Variable>> &vars,
Scope *scope) { Scope *scope) {
PADDLE_ENFORCE(!vars.empty(), "should have value to merge!"); PADDLE_ENFORCE(!vars.empty(), "should have value to merge!");
...@@ -62,7 +62,11 @@ static void MergeVars(const std::string &var_name, ...@@ -62,7 +62,11 @@ static void MergeVars(const std::string &var_name,
} }
void Communicator::SendThread() { void Communicator::SendThread() {
while (running_) {
std::vector<std::future<void>> task_futures;
task_futures.reserve(send_varname_to_ctx_.size());
for (auto &iter : send_varname_to_queue_) { for (auto &iter : send_varname_to_queue_) {
auto send_task = [this, &iter] {
auto &var_name = iter.first; auto &var_name = iter.first;
VLOG(3) << "merge var " << var_name << " and send"; VLOG(3) << "merge var " << var_name << " and send";
auto &var_queue = iter.second; auto &var_queue = iter.second;
...@@ -74,19 +78,37 @@ void Communicator::SendThread() { ...@@ -74,19 +78,37 @@ void Communicator::SendThread() {
merged_var_num++; merged_var_num++;
} }
MergeVars(var_name, vars, send_scope_.get()); MergeVars(var_name, vars, send_scope_.get());
// auto send_functor = distributed::ParameterSend<float>(); auto send_functor = distributed::ParameterSend<float>();
// send_functor(var_name, send_varname_to_ctx_[var_name], exe_ctx, auto &ctx = send_varname_to_ctx_.at(var_name);
// send_scope_, true); send_functor(ctx, *send_scope_, true);
};
task_futures.emplace_back(
send_threadpool_->enqueue(std::move(send_task)));
}
for (auto &task_f : task_futures) {
task_f.wait();
}
} }
} }
void Communicator::RecvThread() { void Communicator::RecvThread() {
while (running_) {
// parallel run recv graph // parallel run recv graph
std::vector<std::future<void>> task_futures;
task_futures.reserve(recv_varname_to_ctx_.size());
for (auto &iter : recv_varname_to_ctx_) { for (auto &iter : recv_varname_to_ctx_) {
auto recv_task = [this, &iter] {
auto &var_name = iter.first; auto &var_name = iter.first;
VLOG(3) << "recv var " << iter.first; VLOG(3) << "recv var " << var_name;
// auto recv_functor = distributed::ParameterRecv<float>(); auto recv_functor = distributed::ParameterRecv<float>();
// recv_functor(var_name, iter.second, exe_ctx, recv_scope_); recv_functor(iter.second, *recv_scope_);
};
task_futures.emplace_back(
recv_threadpool_->enqueue(std::move(recv_task)));
}
for (auto &task : task_futures) {
task.wait();
}
} }
} }
...@@ -101,6 +123,7 @@ void Communicator::Send(const std::string &var_name, ...@@ -101,6 +123,7 @@ void Communicator::Send(const std::string &var_name,
} }
void Communicator::Start() { void Communicator::Start() {
running_ = true;
// start send and recv thread // start send and recv thread
send_thread_.reset( send_thread_.reset(
new std::thread(std::bind(&Communicator::SendThread, this))); new std::thread(std::bind(&Communicator::SendThread, this)));
......
...@@ -19,6 +19,8 @@ limitations under the License. */ ...@@ -19,6 +19,8 @@ limitations under the License. */
#include <string> #include <string>
#include <vector> #include <vector>
#include <ThreadPool.h>
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/distributed/rpc_common.h" #include "paddle/fluid/operators/distributed/rpc_common.h"
...@@ -100,9 +102,18 @@ class Communicator { ...@@ -100,9 +102,18 @@ class Communicator {
send_varname_to_queue_[iter.first] = send_varname_to_queue_[iter.first] =
std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>(10); std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>(10);
} }
// TODO(qiao): default 5, need to config
send_threadpool_.reset(new ::ThreadPool(5));
recv_threadpool_.reset(new ::ThreadPool(5));
} }
// Stops the worker loops and waits for the send/recv threads to exit.
//
// Clears running_ so the `while (running_)` loops in SendThread/RecvThread
// terminate, then joins whichever thread handles actually exist.
//
// NOTE(review): running_ is a plain bool that is written here and read by
// the worker threads; consider std::atomic<bool> so the store is guaranteed
// to become visible to them.
~Communicator() {
  VLOG(3) << "~Communicator";
  running_ = false;
  // The thread handles are unique_ptrs that are filled in by Start(); if
  // Start() was never called they are presumably still empty, so guard the
  // joins instead of dereferencing a null pointer.
  if (send_thread_ && send_thread_->joinable()) {
    send_thread_->join();
  }
  if (recv_thread_ && recv_thread_->joinable()) {
    recv_thread_->join();
  }
  VLOG(3) << "~Communicator done";
}
void Start(); void Start();
...@@ -113,6 +124,7 @@ class Communicator { ...@@ -113,6 +124,7 @@ class Communicator {
void SendThread(); void SendThread();
void RecvThread(); void RecvThread();
bool running_ = false;
std::unordered_map<std::string, std::unordered_map<std::string,
std::shared_ptr<BlockingQueue<std::shared_ptr<Variable>>>> std::shared_ptr<BlockingQueue<std::shared_ptr<Variable>>>>
send_varname_to_queue_; send_varname_to_queue_;
...@@ -122,6 +134,8 @@ class Communicator { ...@@ -122,6 +134,8 @@ class Communicator {
std::unique_ptr<std::thread> recv_thread_; std::unique_ptr<std::thread> recv_thread_;
Scope* recv_scope_; // should be global scope Scope* recv_scope_; // should be global scope
std::unique_ptr<Scope> send_scope_; // an independent scope std::unique_ptr<Scope> send_scope_; // an independent scope
std::unique_ptr<::ThreadPool> send_threadpool_{nullptr};
std::unique_ptr<::ThreadPool> recv_threadpool_{nullptr};
}; };
} // namespace distributed } // namespace distributed
......
...@@ -29,6 +29,14 @@ struct RpcContext { ...@@ -29,6 +29,14 @@ struct RpcContext {
splited_var_names(names), splited_var_names(names),
epmap(emap), epmap(emap),
height_sections(sections) {} height_sections(sections) {}
RpcContext(const RpcContext& ctx) {
var_name = ctx.var_name;
splited_var_names = ctx.splited_var_names;
epmap = ctx.epmap;
height_sections = ctx.height_sections;
}
std::string var_name; std::string var_name;
std::vector<std::string> splited_var_names; std::vector<std::string> splited_var_names;
std::vector<std::string> epmap; std::vector<std::string> epmap;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册