Commit 50601501 authored by Qiao Longfei

improve communicator

Parent: c2cce6ba
paddle/fluid/operators/distributed/CMakeLists.txt:

```diff
@@ -54,7 +54,7 @@ cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler scope)
 cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)
 cc_library(parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory)
 cc_library(parameter_recv SRCS parameter_recv.cc DEPS sendrecvop_rpc memory)
-cc_library(communicator SRCS communicator.cc DEPS scope selected_rows tensor variable_helper selected_rows_functor)
+cc_library(communicator SRCS communicator.cc DEPS scope selected_rows tensor variable_helper selected_rows_functor simple_threadpool)
 if(WITH_GPU)
   cc_test(collective_server_test SRCS collective_server_test.cc
       DEPS sendrecvop_rpc executor ${RPC_DEPS}
```
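The only change here is the new `simple_threadpool` dependency, which provides the header-only `<ThreadPool.h>` pool used by the communicator below. A minimal usage sketch of that pool's enqueue/future API, assuming the common progschj-style interface this target wraps:

```cpp
// Sketch of the ThreadPool API supplied by the simple_threadpool dependency.
// Assumes the progschj-style header-only pool behind <ThreadPool.h>.
#include <ThreadPool.h>

#include <future>
#include <iostream>
#include <vector>

int main() {
  ::ThreadPool pool(4);  // four worker threads
  std::vector<std::future<int>> futures;
  for (int i = 0; i < 8; ++i) {
    // enqueue() returns a std::future for the task's result.
    futures.emplace_back(pool.enqueue([i] { return i * i; }));
  }
  for (auto &f : futures) {
    std::cout << f.get() << " ";  // get() blocks until the task finishes
  }
  std::cout << std::endl;
  return 0;
}
```

Each `enqueue` call returns a `std::future`, which is exactly how `SendThread` and `RecvThread` below collect and join their per-variable tasks.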
paddle/fluid/operators/distributed/communicator.cc:

```diff
@@ -25,9 +25,9 @@ namespace paddle {
 namespace operators {
 namespace distributed {
 
-static void MergeVars(const std::string &var_name,
-                      const std::vector<std::shared_ptr<Variable>> &vars,
-                      Scope *scope) {
+static inline void MergeVars(const std::string &var_name,
+                             const std::vector<std::shared_ptr<Variable>> &vars,
+                             Scope *scope) {
   PADDLE_ENFORCE(!vars.empty(), "should have value to merge!");
   auto cpu_place = platform::CPUPlace();
   auto &var0 = vars[0];
```
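The body of `MergeVars` is elided by this hunk. Judging from the `selected_rows` and `selected_rows_functor` dependencies above, it combines the queued copies of a variable into one value before sending. The sketch below illustrates that merge-by-summation idea on a plain buffer; `GradBuffer` and `MergeGrads` are illustrative stand-ins, not Paddle's actual types:

```cpp
// Illustrative sketch only: merge N queued gradient buffers by element-wise
// summation, the way MergeVars presumably combines queued dense tensors.
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

using GradBuffer = std::vector<float>;  // stand-in for a dense tensor

GradBuffer MergeGrads(const std::vector<std::shared_ptr<GradBuffer>> &vars) {
  assert(!vars.empty() && "should have value to merge!");
  GradBuffer merged(vars[0]->size(), 0.0f);
  for (const auto &var : vars) {
    assert(var->size() == merged.size());
    for (std::size_t i = 0; i < var->size(); ++i) {
      merged[i] += (*var)[i];
    }
  }
  return merged;
}

int main() {
  auto a = std::make_shared<GradBuffer>(GradBuffer{1.f, 2.f});
  auto b = std::make_shared<GradBuffer>(GradBuffer{3.f, 4.f});
  auto merged = MergeGrads({a, b});  // expect {4, 6}
  return (merged[0] == 4.f && merged[1] == 6.f) ? 0 : 1;
}
```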
```diff
@@ -62,31 +62,53 @@ static void MergeVars(const std::string &var_name,
 }
 
 void Communicator::SendThread() {
-  for (auto &iter : send_varname_to_queue_) {
-    auto &var_name = iter.first;
-    VLOG(3) << "merge var " << var_name << " and send";
-    auto &var_queue = iter.second;
-    std::vector<std::shared_ptr<Variable>> vars;
-    const size_t max_merge_var_num = 20;
-    size_t merged_var_num = 0;
-    while (var_queue->Size() > 0 && merged_var_num < max_merge_var_num) {
-      vars.push_back(var_queue->Pop());
-      merged_var_num++;
-    }
-    MergeVars(var_name, vars, send_scope_.get());
-    // auto send_functor = distributed::ParameterSend<float>();
-    // send_functor(var_name, send_varname_to_ctx_[var_name], exe_ctx,
-    // send_scope_, true);
-  }
+  while (running_) {
+    std::vector<std::future<void>> task_futures;
+    task_futures.reserve(send_varname_to_ctx_.size());
+    for (auto &iter : send_varname_to_queue_) {
+      auto send_task = [this, &iter] {
+        auto &var_name = iter.first;
+        VLOG(3) << "merge var " << var_name << " and send";
+        auto &var_queue = iter.second;
+        std::vector<std::shared_ptr<Variable>> vars;
+        const size_t max_merge_var_num = 20;
+        size_t merged_var_num = 0;
+        while (var_queue->Size() > 0 && merged_var_num < max_merge_var_num) {
+          vars.push_back(var_queue->Pop());
+          merged_var_num++;
+        }
+        MergeVars(var_name, vars, send_scope_.get());
+        auto send_functor = distributed::ParameterSend<float>();
+        auto &ctx = send_varname_to_ctx_.at(var_name);
+        send_functor(ctx, *send_scope_, true);
+      };
+      task_futures.emplace_back(
+          send_threadpool_->enqueue(std::move(send_task)));
+    }
+    for (auto &task_f : task_futures) {
+      task_f.wait();
+    }
+  }
 }
 
 void Communicator::RecvThread() {
-  // parallel run recv graph
-  for (auto &iter : recv_varname_to_ctx_) {
-    auto &var_name = iter.first;
-    VLOG(3) << "recv var " << iter.first;
-    // auto recv_functor = distributed::ParameterRecv<float>();
-    // recv_functor(var_name, iter.second, exe_ctx, recv_scope_);
-  }
+  while (running_) {
+    // parallel run recv graph
+    std::vector<std::future<void>> task_futures;
+    task_futures.reserve(recv_varname_to_ctx_.size());
+    for (auto &iter : recv_varname_to_ctx_) {
+      auto recv_task = [this, &iter] {
+        auto &var_name = iter.first;
+        VLOG(3) << "recv var " << var_name;
+        auto recv_functor = distributed::ParameterRecv<float>();
+        recv_functor(iter.second, *recv_scope_);
+      };
+      task_futures.emplace_back(
+          recv_threadpool_->enqueue(std::move(recv_task)));
+    }
+    for (auto &task : task_futures) {
+      task.wait();
+    }
+  }
 }
```
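One detail worth noting in `SendThread`: each task drains at most `max_merge_var_num` (20) queued gradients without blocking, merges them, and sends once, which bounds both the staleness of a send and the size of a merge batch. A standalone sketch of that bounded, non-blocking drain; the queue type here is a stand-in, not Paddle's `BlockingQueue`:

```cpp
// Illustrative sketch of the bounded-batch drain used by SendThread.
#include <cstddef>
#include <deque>
#include <mutex>
#include <utility>
#include <vector>

template <typename T>
class SimpleQueue {
 public:
  void Push(T v) {
    std::lock_guard<std::mutex> lock(mu_);
    q_.push_back(std::move(v));
  }
  bool TryPop(T *out) {  // non-blocking: false when empty
    std::lock_guard<std::mutex> lock(mu_);
    if (q_.empty()) return false;
    *out = std::move(q_.front());
    q_.pop_front();
    return true;
  }

 private:
  std::mutex mu_;
  std::deque<T> q_;
};

template <typename T>
std::vector<T> DrainBatch(SimpleQueue<T> *queue, std::size_t max_batch) {
  std::vector<T> batch;
  T item;
  while (batch.size() < max_batch && queue->TryPop(&item)) {
    batch.push_back(std::move(item));
  }
  return batch;
}

int main() {
  SimpleQueue<int> q;
  for (int i = 0; i < 50; ++i) q.Push(i);
  auto batch = DrainBatch(&q, 20);  // mirrors max_merge_var_num = 20
  return static_cast<int>(batch.size()) == 20 ? 0 : 1;
}
```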
```diff
@@ -101,6 +123,7 @@ void Communicator::Send(const std::string &var_name,
 }
 
 void Communicator::Start() {
+  running_ = true;
   // start send and recv thread
   send_thread_.reset(
       new std::thread(std::bind(&Communicator::SendThread, this)));
```
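`Start()` now raises the `running_` flag before spawning the loop threads, and the destructor (in the header below) lowers it and joins them. A self-contained sketch of this start/loop/join lifecycle; note that the sketch uses `std::atomic<bool>` for the cross-thread flag, whereas the patch itself uses a plain `bool` member:

```cpp
// Minimal sketch of the start/loop/join lifecycle added by this commit.
#include <atomic>
#include <chrono>
#include <iostream>
#include <memory>
#include <thread>

class Worker {
 public:
  void Start() {
    running_ = true;  // raise the flag before the loop starts
    loop_thread_.reset(new std::thread([this] {
      while (running_) {  // same shape as SendThread/RecvThread
        std::cout << "one communication round" << std::endl;
        std::this_thread::sleep_for(std::chrono::milliseconds(100));
      }
    }));
  }
  ~Worker() {
    running_ = false;  // ask the loop to exit...
    if (loop_thread_) loop_thread_->join();  // ...then wait for it
  }

 private:
  std::atomic<bool> running_{false};
  std::unique_ptr<std::thread> loop_thread_;
};

int main() {
  Worker w;
  w.Start();
  std::this_thread::sleep_for(std::chrono::milliseconds(350));
  return 0;  // ~Worker stops and joins the loop thread
}
```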
paddle/fluid/operators/distributed/communicator.h:

```diff
@@ -19,6 +19,8 @@ limitations under the License. */
 #include <string>
 #include <vector>
 
+#include <ThreadPool.h>
+
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/framework/variable.h"
 #include "paddle/fluid/operators/distributed/rpc_common.h"
```
```diff
@@ -100,9 +102,18 @@ class Communicator {
       send_varname_to_queue_[iter.first] =
           std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>(10);
     }
+    // TODO(qiao): default 5, need to config
+    send_threadpool_.reset(new ::ThreadPool(5));
+    recv_threadpool_.reset(new ::ThreadPool(5));
   }
 
-  ~Communicator() {}
+  ~Communicator() {
+    VLOG(3) << "~Communicator";
+    running_ = false;
+    send_thread_->join();
+    recv_thread_->join();
+    VLOG(3) << "~Communicator done";
+  }
 
   void Start();
```
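Each send variable gets a `BlockingQueue` with capacity 10, and the two pools get a hard-coded five workers each (flagged by the TODO as needing a config option). A minimal bounded blocking queue sketch with the same `Push`/`Pop`/`Size` surface the communicator relies on; illustrative only, not Paddle's implementation:

```cpp
// Bounded blocking queue sketch: Push blocks when full, Pop blocks when empty.
#include <condition_variable>
#include <cstddef>
#include <deque>
#include <mutex>
#include <utility>

template <typename T>
class BoundedQueue {
 public:
  explicit BoundedQueue(std::size_t capacity) : capacity_(capacity) {}

  void Push(T v) {  // blocks while the queue is full
    std::unique_lock<std::mutex> lock(mu_);
    not_full_.wait(lock, [this] { return q_.size() < capacity_; });
    q_.push_back(std::move(v));
    not_empty_.notify_one();
  }

  T Pop() {  // blocks while the queue is empty
    std::unique_lock<std::mutex> lock(mu_);
    not_empty_.wait(lock, [this] { return !q_.empty(); });
    T v = std::move(q_.front());
    q_.pop_front();
    not_full_.notify_one();
    return v;
  }

  std::size_t Size() {
    std::lock_guard<std::mutex> lock(mu_);
    return q_.size();
  }

 private:
  const std::size_t capacity_;
  std::mutex mu_;
  std::deque<T> q_;
  std::condition_variable not_full_, not_empty_;
};

int main() {
  BoundedQueue<int> q(10);  // capacity 10, as in the communicator
  q.Push(1);
  return q.Pop() == 1 ? 0 : 1;
}
```

The bounded capacity gives back-pressure: a trainer that produces gradients faster than the send loop drains them will block in `Push` rather than grow memory without bound.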
```diff
@@ -113,6 +124,7 @@ class Communicator {
   void SendThread();
   void RecvThread();
 
+  bool running_ = false;
   std::unordered_map<std::string,
                      std::shared_ptr<BlockingQueue<std::shared_ptr<Variable>>>>
       send_varname_to_queue_;
@@ -122,6 +134,8 @@ class Communicator {
   std::unique_ptr<std::thread> recv_thread_;
   Scope* recv_scope_;      // should be global scope
   std::unique_ptr<Scope> send_scope_;  // an independent scope
+  std::unique_ptr<::ThreadPool> send_threadpool_{nullptr};
+  std::unique_ptr<::ThreadPool> recv_threadpool_{nullptr};
 };
 
 }  // namespace distributed
```
paddle/fluid/operators/distributed/rpc_common.h:

```diff
@@ -29,6 +29,14 @@ struct RpcContext {
         splited_var_names(names),
         epmap(emap),
         height_sections(sections) {}
 
+  RpcContext(const RpcContext& ctx) {
+    var_name = ctx.var_name;
+    splited_var_names = ctx.splited_var_names;
+    epmap = ctx.epmap;
+    height_sections = ctx.height_sections;
+  }
+
   std::string var_name;
   std::vector<std::string> splited_var_names;
   std::vector<std::string> epmap;
```
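The new copy constructor is member-wise, so it behaves the same as the compiler-generated default; a plausible reason for spelling out copyability is that contexts are captured by thread-pool tasks that may outlive the loop iteration creating them. A small sketch of that capture-by-value pattern with stand-in types (`Ctx` and its example values are hypothetical; the fields mirror `RpcContext`):

```cpp
// Sketch: a copyable context can be captured by value in a task that may
// outlive the scope that created it. Field types mirror RpcContext.
#include <cstdint>
#include <string>
#include <vector>

struct Ctx {
  std::string var_name;
  std::vector<std::string> splited_var_names;
  std::vector<std::string> epmap;
  std::vector<std::int64_t> height_sections;
};

int main() {
  Ctx ctx{"w@GRAD",
          {"w@GRAD.block0", "w@GRAD.block1"},
          {"127.0.0.1:6174", "127.0.0.1:6175"},
          {512, 512}};
  auto task = [ctx] {  // captured by value: safe after `ctx` goes away
    return ctx.var_name + " -> " + ctx.epmap[0];
  };
  return task().empty() ? 1 : 0;
}
```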