提交 50601501 编写于 作者: Q Qiao Longfei

improve communicator

上级 c2cce6ba
......@@ -54,7 +54,7 @@ cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler scope)
cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)
cc_library(parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory)
cc_library(parameter_recv SRCS parameter_recv.cc DEPS sendrecvop_rpc memory)
cc_library(communicator SRCS communicator.cc DEPS scope selected_rows tensor variable_helper selected_rows_functor)
cc_library(communicator SRCS communicator.cc DEPS scope selected_rows tensor variable_helper selected_rows_functor simple_threadpool)
if(WITH_GPU)
cc_test(collective_server_test SRCS collective_server_test.cc
DEPS sendrecvop_rpc executor ${RPC_DEPS}
......
......@@ -25,9 +25,9 @@ namespace paddle {
namespace operators {
namespace distributed {
static void MergeVars(const std::string &var_name,
const std::vector<std::shared_ptr<Variable>> &vars,
Scope *scope) {
static inline void MergeVars(const std::string &var_name,
const std::vector<std::shared_ptr<Variable>> &vars,
Scope *scope) {
PADDLE_ENFORCE(!vars.empty(), "should have value to merge!");
auto cpu_place = platform::CPUPlace();
auto &var0 = vars[0];
......@@ -62,31 +62,53 @@ static void MergeVars(const std::string &var_name,
}
void Communicator::SendThread() {
for (auto &iter : send_varname_to_queue_) {
auto &var_name = iter.first;
VLOG(3) << "merge var " << var_name << " and send";
auto &var_queue = iter.second;
std::vector<std::shared_ptr<Variable>> vars;
const size_t max_merge_var_num = 20;
size_t merged_var_num = 0;
while (var_queue->Size() > 0 && merged_var_num < max_merge_var_num) {
vars.push_back(var_queue->Pop());
merged_var_num++;
while (running_) {
std::vector<std::future<void>> task_futures;
task_futures.reserve(send_varname_to_ctx_.size());
for (auto &iter : send_varname_to_queue_) {
auto send_task = [this, &iter] {
auto &var_name = iter.first;
VLOG(3) << "merge var " << var_name << " and send";
auto &var_queue = iter.second;
std::vector<std::shared_ptr<Variable>> vars;
const size_t max_merge_var_num = 20;
size_t merged_var_num = 0;
while (var_queue->Size() > 0 && merged_var_num < max_merge_var_num) {
vars.push_back(var_queue->Pop());
merged_var_num++;
}
MergeVars(var_name, vars, send_scope_.get());
auto send_functor = distributed::ParameterSend<float>();
auto &ctx = send_varname_to_ctx_.at(var_name);
send_functor(ctx, *send_scope_, true);
};
task_futures.emplace_back(
send_threadpool_->enqueue(std::move(send_task)));
}
for (auto &task_f : task_futures) {
task_f.wait();
}
MergeVars(var_name, vars, send_scope_.get());
// auto send_functor = distributed::ParameterSend<float>();
// send_functor(var_name, send_varname_to_ctx_[var_name], exe_ctx,
// send_scope_, true);
}
}
void Communicator::RecvThread() {
// parallel run recv graph
for (auto &iter : recv_varname_to_ctx_) {
auto &var_name = iter.first;
VLOG(3) << "recv var " << iter.first;
// auto recv_functor = distributed::ParameterRecv<float>();
// recv_functor(var_name, iter.second, exe_ctx, recv_scope_);
while (running_) {
// parallel run recv graph
std::vector<std::future<void>> task_futures;
task_futures.reserve(recv_varname_to_ctx_.size());
for (auto &iter : recv_varname_to_ctx_) {
auto recv_task = [this, &iter] {
auto &var_name = iter.first;
VLOG(3) << "recv var " << var_name;
auto recv_functor = distributed::ParameterRecv<float>();
recv_functor(iter.second, *recv_scope_);
};
task_futures.emplace_back(
recv_threadpool_->enqueue(std::move(recv_task)));
}
for (auto &task : task_futures) {
task.wait();
}
}
}
......@@ -101,6 +123,7 @@ void Communicator::Send(const std::string &var_name,
}
void Communicator::Start() {
running_ = true;
// start send and recv thread
send_thread_.reset(
new std::thread(std::bind(&Communicator::SendThread, this)));
......
......@@ -19,6 +19,8 @@ limitations under the License. */
#include <string>
#include <vector>
#include <ThreadPool.h>
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/distributed/rpc_common.h"
......@@ -100,9 +102,18 @@ class Communicator {
send_varname_to_queue_[iter.first] =
std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>(10);
}
// TODO(qiao): default 5, need to config
send_threadpool_.reset(new ::ThreadPool(5));
recv_threadpool_.reset(new ::ThreadPool(5));
}
~Communicator() {}
~Communicator() {
VLOG(3) << "~Communicator";
running_ = false;
send_thread_->join();
recv_thread_->join();
VLOG(3) << "~Communicator done";
}
void Start();
......@@ -113,6 +124,7 @@ class Communicator {
void SendThread();
void RecvThread();
bool running_ = false;
std::unordered_map<std::string,
std::shared_ptr<BlockingQueue<std::shared_ptr<Variable>>>>
send_varname_to_queue_;
......@@ -122,6 +134,8 @@ class Communicator {
std::unique_ptr<std::thread> recv_thread_;
Scope* recv_scope_; // should be global scope
std::unique_ptr<Scope> send_scope_; // an independent scope
std::unique_ptr<::ThreadPool> send_threadpool_{nullptr};
std::unique_ptr<::ThreadPool> recv_threadpool_{nullptr};
};
} // namespace distributed
......
......@@ -29,6 +29,14 @@ struct RpcContext {
splited_var_names(names),
epmap(emap),
height_sections(sections) {}
RpcContext(const RpcContext& ctx) {
var_name = ctx.var_name;
splited_var_names = ctx.splited_var_names;
epmap = ctx.epmap;
height_sections = ctx.height_sections;
}
std::string var_name;
std::vector<std::string> splited_var_names;
std::vector<std::string> epmap;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册