提交 039d783d 编写于 作者: Q Qiao Longfei

change communicator_recv_wait_ms to communicator_max_send_grad_num_before_recv

上级 ea0df4e8
...@@ -29,7 +29,8 @@ DEFINE_bool(communicator_independent_recv_thread, true, ...@@ -29,7 +29,8 @@ DEFINE_bool(communicator_independent_recv_thread, true,
"use an independent to recv vars from parameter server"); "use an independent to recv vars from parameter server");
DEFINE_int32(communicator_send_queue_size, 20, DEFINE_int32(communicator_send_queue_size, 20,
"queue size to recv gradient before send"); "queue size to recv gradient before send");
DEFINE_int32(communicator_recv_wait_ms, 200, "wait time between each recv"); DEFINE_int32(communicator_max_send_grad_num_before_recv, 20,
"max grad num to send before recv parameters");
DEFINE_int32(communicator_thread_pool_size, 5, "thread num to do send or recv"); DEFINE_int32(communicator_thread_pool_size, 5, "thread num to do send or recv");
DEFINE_int32(communicator_max_merge_var_num, 20, DEFINE_int32(communicator_max_merge_var_num, 20,
"max var num to merge and send"); "max var num to merge and send");
...@@ -60,7 +61,8 @@ Communicator::Communicator(const RpcCtxMap &send_varname_to_ctx, ...@@ -60,7 +61,8 @@ Communicator::Communicator(const RpcCtxMap &send_varname_to_ctx,
<< FLAGS_communicator_independent_recv_thread; << FLAGS_communicator_independent_recv_thread;
VLOG(0) << "communicator_send_queue_size: " VLOG(0) << "communicator_send_queue_size: "
<< FLAGS_communicator_send_queue_size; << FLAGS_communicator_send_queue_size;
VLOG(0) << "communicator_recv_wait_ms: " << FLAGS_communicator_recv_wait_ms; VLOG(0) << "communicator_max_send_grad_num_before_recv: "
<< FLAGS_communicator_max_send_grad_num_before_recv;
VLOG(0) << "communicator_thread_pool_size: " VLOG(0) << "communicator_thread_pool_size: "
<< FLAGS_communicator_thread_pool_size; << FLAGS_communicator_thread_pool_size;
VLOG(0) << "communicator_max_merge_var_num: " VLOG(0) << "communicator_max_merge_var_num: "
...@@ -102,6 +104,10 @@ void Communicator::SendThread() { ...@@ -102,6 +104,10 @@ void Communicator::SendThread() {
while (var_queue->Size() > 0 && while (var_queue->Size() > 0 &&
merged_var_num < FLAGS_communicator_max_merge_var_num) { merged_var_num < FLAGS_communicator_max_merge_var_num) {
vars.push_back(var_queue->Pop()); vars.push_back(var_queue->Pop());
// only count the send number of the first var
if (var_name == send_varname_to_queue_.begin()->first) {
grad_num_.fetch_add(1, std::memory_order_relaxed);
}
merged_var_num++; merged_var_num++;
} }
auto before_merge = GetCurrentUS(); auto before_merge = GetCurrentUS();
...@@ -129,7 +135,7 @@ void Communicator::SendThread() { ...@@ -129,7 +135,7 @@ void Communicator::SendThread() {
} }
auto after_run_send_graph = GetCurrentUS(); auto after_run_send_graph = GetCurrentUS();
auto send_graph_use_time = after_run_send_graph - before_run_send_graph; auto send_graph_use_time = after_run_send_graph - before_run_send_graph;
if (send_graph_use_time > 10) { if (send_graph_use_time > 100) {
VLOG(1) << "run send graph use time " VLOG(1) << "run send graph use time "
<< after_run_send_graph - before_run_send_graph; << after_run_send_graph - before_run_send_graph;
} }
...@@ -165,9 +171,14 @@ void Communicator::RecvAll() { ...@@ -165,9 +171,14 @@ void Communicator::RecvAll() {
void Communicator::RecvThread() { void Communicator::RecvThread() {
VLOG(3) << "RecvThread start!"; VLOG(3) << "RecvThread start!";
while (running_) { while (running_) {
auto grad_num = grad_num_.load();
if (grad_num > FLAGS_communicator_max_send_grad_num_before_recv) {
VLOG(1) << "current grad num " << grad_num;
RecvAll(); RecvAll();
std::this_thread::sleep_for( grad_num_.store(0);
std::chrono::milliseconds(FLAGS_communicator_recv_wait_ms)); } else {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
} }
} }
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <atomic>
#include <deque> #include <deque>
#include <memory> #include <memory>
#include <string> #include <string>
...@@ -184,6 +185,7 @@ class Communicator { ...@@ -184,6 +185,7 @@ class Communicator {
std::unique_ptr<Scope> send_scope_; // an independent scope std::unique_ptr<Scope> send_scope_; // an independent scope
std::unique_ptr<::ThreadPool> send_threadpool_{nullptr}; std::unique_ptr<::ThreadPool> send_threadpool_{nullptr};
std::unique_ptr<::ThreadPool> recv_threadpool_{nullptr}; std::unique_ptr<::ThreadPool> recv_threadpool_{nullptr};
std::atomic_uint grad_num_{0}; // the num of gradient sent since last recv
// the following code is for initializing the communicator // the following code is for initializing the communicator
public: public:
......
...@@ -155,7 +155,7 @@ def __bootstrap__(): ...@@ -155,7 +155,7 @@ def __bootstrap__():
# env for communicator # env for communicator
read_env_flags.append('communicator_independent_recv_thread') read_env_flags.append('communicator_independent_recv_thread')
read_env_flags.append('communicator_send_queue_size') read_env_flags.append('communicator_send_queue_size')
read_env_flags.append('communicator_recv_wait_ms') read_env_flags.append('communicator_max_send_grad_num_before_recv')
read_env_flags.append('communicator_thread_pool_size') read_env_flags.append('communicator_thread_pool_size')
read_env_flags.append('communicator_max_merge_var_num') read_env_flags.append('communicator_max_merge_var_num')
read_env_flags.append('communicator_fake_rpc') read_env_flags.append('communicator_fake_rpc')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册