From 039d783db5ed14a5eabadb3177c800697afec39d Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Mon, 18 Mar 2019 13:35:37 +0800 Subject: [PATCH] change communicator_recv_wait_ms to communicator_max_send_grad_num_before_recv --- .../operators/distributed/communicator.cc | 23 ++++++++++++++----- .../operators/distributed/communicator.h | 2 ++ python/paddle/fluid/__init__.py | 2 +- 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/operators/distributed/communicator.cc b/paddle/fluid/operators/distributed/communicator.cc index 3661c2763d..eba18c6777 100644 --- a/paddle/fluid/operators/distributed/communicator.cc +++ b/paddle/fluid/operators/distributed/communicator.cc @@ -29,7 +29,8 @@ DEFINE_bool(communicator_independent_recv_thread, true, "use an independent to recv vars from parameter server"); DEFINE_int32(communicator_send_queue_size, 20, "queue size to recv gradient before send"); -DEFINE_int32(communicator_recv_wait_ms, 200, "wait time between each recv"); +DEFINE_int32(communicator_max_send_grad_num_before_recv, 20, + "max grad num to send before recv parameters"); DEFINE_int32(communicator_thread_pool_size, 5, "thread num to do send or recv"); DEFINE_int32(communicator_max_merge_var_num, 20, "max var num to merge and send"); @@ -60,7 +61,8 @@ Communicator::Communicator(const RpcCtxMap &send_varname_to_ctx, << FLAGS_communicator_independent_recv_thread; VLOG(0) << "communicator_send_queue_size: " << FLAGS_communicator_send_queue_size; - VLOG(0) << "communicator_recv_wait_ms: " << FLAGS_communicator_recv_wait_ms; + VLOG(0) << "communicator_max_send_grad_num_before_recv: " + << FLAGS_communicator_max_send_grad_num_before_recv; VLOG(0) << "communicator_thread_pool_size: " << FLAGS_communicator_thread_pool_size; VLOG(0) << "communicator_max_merge_var_num: " @@ -102,6 +104,10 @@ void Communicator::SendThread() { while (var_queue->Size() > 0 && merged_var_num < FLAGS_communicator_max_merge_var_num) { vars.push_back(var_queue->Pop()); + // only count the send number of the first var + if (var_name == send_varname_to_queue_.begin()->first) { + grad_num_.fetch_add(1, std::memory_order_relaxed); + } merged_var_num++; } auto before_merge = GetCurrentUS(); @@ -129,7 +135,7 @@ void Communicator::SendThread() { } auto after_run_send_graph = GetCurrentUS(); auto send_graph_use_time = after_run_send_graph - before_run_send_graph; - if (send_graph_use_time > 10) { + if (send_graph_use_time > 100) { VLOG(1) << "run send graph use time " << after_run_send_graph - before_run_send_graph; } @@ -165,9 +171,14 @@ void Communicator::RecvAll() { void Communicator::RecvThread() { VLOG(3) << "RecvThread start!"; while (running_) { - RecvAll(); - std::this_thread::sleep_for( - std::chrono::milliseconds(FLAGS_communicator_recv_wait_ms)); + auto grad_num = grad_num_.load(); + if (grad_num > FLAGS_communicator_max_send_grad_num_before_recv) { + VLOG(1) << "current grad num " << grad_num; + RecvAll(); + grad_num_.store(0); + } else { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } } } diff --git a/paddle/fluid/operators/distributed/communicator.h b/paddle/fluid/operators/distributed/communicator.h index 3fe2a21232..859c0a7f51 100644 --- a/paddle/fluid/operators/distributed/communicator.h +++ b/paddle/fluid/operators/distributed/communicator.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include #include #include @@ -184,6 +185,7 @@ class Communicator { std::unique_ptr send_scope_; // an independent scope std::unique_ptr<::ThreadPool> send_threadpool_{nullptr}; std::unique_ptr<::ThreadPool> recv_threadpool_{nullptr}; + std::atomic_uint grad_num_{0}; // the num of gradient sent since last recv // the following code is for initialize the commnunicator public: diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index c478c8ceee..97ac7fd97b 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -155,7 +155,7 @@ def __bootstrap__(): # env for communicator read_env_flags.append('communicator_independent_recv_thread') read_env_flags.append('communicator_send_queue_size') - read_env_flags.append('communicator_recv_wait_ms') + read_env_flags.append('communicator_max_send_grad_num_before_recv') read_env_flags.append('communicator_thread_pool_size') read_env_flags.append('communicator_max_merge_var_num') read_env_flags.append('communicator_fake_rpc') -- GitLab