提交 0a828fef 编写于 作者: Q Qiao Longfei

add some flags for communicator

上级 63cd70a8
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/distributed/communicator.h" #include "paddle/fluid/operators/distributed/communicator.h"
#include <gflags/gflags.h>
#include <chrono> // NOLINT #include <chrono> // NOLINT
#include <thread> // NOLINT #include <thread> // NOLINT
...@@ -24,6 +25,13 @@ limitations under the License. */ ...@@ -24,6 +25,13 @@ limitations under the License. */
#include "paddle/fluid/operators/distributed/parameter_send.h" #include "paddle/fluid/operators/distributed/parameter_send.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/operators/math/selected_rows_functor.h"
// Runtime-tunable settings for the Communicator, defined as gflags.

// When true, Communicator::Start() launches a separate RecvThread; when
// false, SendThread calls RecvAll() itself after each send round.
DEFINE_bool(communicator_independent_recv_thread, true,
            "use an independent thread to recv vars from parameter server");
// Passed to the constructor of every per-variable send BlockingQueue.
DEFINE_int32(communicator_send_queue_size, 20,
             "queue size to recv gradient before send");
// Sleep interval, in milliseconds, between RecvAll() calls in RecvThread.
DEFINE_int32(communicator_recv_wait_ms, 200, "wait time between each recv");
// Passed to the constructors of both the send and the recv ::ThreadPool.
// The help text used to be a copy of the recv_wait_ms description.
DEFINE_int32(communicator_thread_pool_size, 5,
             "thread num of the send and recv thread pools");
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace distributed { namespace distributed {
...@@ -70,6 +78,38 @@ static inline void MergeVars(const std::string &var_name, ...@@ -70,6 +78,38 @@ static inline void MergeVars(const std::string &var_name,
std::unique_ptr<Communicator> Communicator::communicator_(nullptr); std::unique_ptr<Communicator> Communicator::communicator_(nullptr);
std::once_flag Communicator::init_flag_; std::once_flag Communicator::init_flag_;
// Constructs a Communicator from the send/recv RPC context maps and the scope
// that received variables are associated with.
//
// Logs the current values of the communicator_* flags at VLOG level 0, creates
// a fresh send scope, allocates one BlockingQueue per entry of the send
// context map, and creates the send and recv thread pools.
Communicator::Communicator(const RpcCtxMap &send_varname_to_ctx,
                           const RpcCtxMap &recv_varname_to_ctx,
                           Scope *recv_scope)
    : send_varname_to_ctx_(send_varname_to_ctx),
      recv_varname_to_ctx_(recv_varname_to_ctx),
      recv_scope_(recv_scope) {
  // Report the flag values in effect when this instance is built.
  VLOG(0) << "communicator_independent_recv_thread: "
          << FLAGS_communicator_independent_recv_thread;
  VLOG(0) << "communicator_send_queue_size: "
          << FLAGS_communicator_send_queue_size;
  VLOG(0) << "communicator_recv_wait_ms: " << FLAGS_communicator_recv_wait_ms;
  VLOG(0) << "communicator_thread_pool_size: "
          << FLAGS_communicator_thread_pool_size;

  send_scope_.reset(new Scope());

  // One queue per variable name in the send context map; each queue is
  // constructed with the configured send queue size.
  using SendQueue = BlockingQueue<std::shared_ptr<Variable>>;
  const auto queue_size = FLAGS_communicator_send_queue_size;
  for (const auto &name_and_ctx : send_varname_to_ctx_) {
    const auto &var_name = name_and_ctx.first;
    send_varname_to_queue_[var_name] = std::make_shared<SendQueue>(queue_size);
  }

  // Both pools are built from the same flag value.
  const auto pool_size = FLAGS_communicator_thread_pool_size;
  send_threadpool_.reset(new ::ThreadPool(pool_size));
  recv_threadpool_.reset(new ::ThreadPool(pool_size));
}
// Stops the communicator: clears the running flag polled by the worker loops,
// then joins whichever of the send/recv threads were actually created.
Communicator::~Communicator() {
  VLOG(3) << "~Communicator";
  running_ = false;

  // A thread pointer stays null when its thread was never started (the recv
  // thread is only created when the independent-recv flag is set), so each
  // pointer is checked before joining.
  const auto join_if_started = [](std::thread *worker) {
    if (worker != nullptr) {
      worker->join();
    }
  };
  join_if_started(send_thread_.get());
  join_if_started(recv_thread_.get());

  VLOG(3) << "~Communicator done";
}
void Communicator::SendThread() { void Communicator::SendThread() {
VLOG(3) << "SendThread start!"; VLOG(3) << "SendThread start!";
while (running_) { while (running_) {
...@@ -105,8 +145,10 @@ void Communicator::SendThread() { ...@@ -105,8 +145,10 @@ void Communicator::SendThread() {
task_f.wait(); task_f.wait();
} }
VLOG(3) << "run send graph done"; VLOG(3) << "run send graph done";
if (!FLAGS_communicator_independent_recv_thread) {
RecvAll(); RecvAll();
} }
}
} }
void Communicator::RecvAll() { void Communicator::RecvAll() {
...@@ -132,8 +174,8 @@ void Communicator::RecvThread() { ...@@ -132,8 +174,8 @@ void Communicator::RecvThread() {
VLOG(3) << "RecvThread start!"; VLOG(3) << "RecvThread start!";
while (running_) { while (running_) {
RecvAll(); RecvAll();
// TODO(qiao) need to be configuable std::this_thread::sleep_for(
std::this_thread::sleep_for(std::chrono::milliseconds(200)); std::chrono::milliseconds(FLAGS_communicator_recv_wait_ms));
} }
} }
...@@ -157,8 +199,10 @@ void Communicator::Start() { ...@@ -157,8 +199,10 @@ void Communicator::Start() {
// start send and recv thread // start send and recv thread
send_thread_.reset( send_thread_.reset(
new std::thread(std::bind(&Communicator::SendThread, this))); new std::thread(std::bind(&Communicator::SendThread, this)));
// recv_thread_.reset( if (FLAGS_communicator_independent_recv_thread) {
// new std::thread(std::bind(&Communicator::RecvThread, this))); recv_thread_.reset(
new std::thread(std::bind(&Communicator::RecvThread, this)));
}
} }
} // namespace distributed } // namespace distributed
......
...@@ -96,28 +96,9 @@ using RpcCtxMap = std::unordered_map<std::string, RpcContext>; ...@@ -96,28 +96,9 @@ using RpcCtxMap = std::unordered_map<std::string, RpcContext>;
class Communicator { class Communicator {
public: public:
Communicator(const RpcCtxMap& send_varname_to_ctx, Communicator(const RpcCtxMap& send_varname_to_ctx,
const RpcCtxMap& recv_varname_to_ctx, Scope* recv_scope) const RpcCtxMap& recv_varname_to_ctx, Scope* recv_scope);
: send_varname_to_ctx_(send_varname_to_ctx),
recv_varname_to_ctx_(recv_varname_to_ctx),
recv_scope_(recv_scope) {
// get all send information from graph, build vars_to_send
send_scope_.reset(new Scope());
for (auto& iter : send_varname_to_ctx_) {
send_varname_to_queue_[iter.first] =
std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>(10);
}
// TODO(qiao): default 5, need to config
send_threadpool_.reset(new ::ThreadPool(5));
recv_threadpool_.reset(new ::ThreadPool(5));
}
~Communicator() { ~Communicator();
VLOG(3) << "~Communicator";
running_ = false;
send_thread_->join();
recv_thread_->join();
VLOG(3) << "~Communicator done";
}
void Start(); void Start();
......
...@@ -150,6 +150,10 @@ def __bootstrap__(): ...@@ -150,6 +150,10 @@ def __bootstrap__():
read_env_flags.append('rpc_get_thread_num') read_env_flags.append('rpc_get_thread_num')
read_env_flags.append('rpc_prefetch_thread_num') read_env_flags.append('rpc_prefetch_thread_num')
read_env_flags.append('rpc_disable_reuse_port') read_env_flags.append('rpc_disable_reuse_port')
read_env_flags.append('communicator_independent_recv_thread')
read_env_flags.append('communicator_send_queue_size')
read_env_flags.append('communicator_recv_wait_ms')
read_env_flags.append('communicator_thread_pool_size')
if core.is_compiled_with_brpc(): if core.is_compiled_with_brpc():
read_env_flags.append('max_body_size') read_env_flags.append('max_body_size')
#set brpc max body size #set brpc max body size
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册