From ad5a2b3edfb437a225d7f42ab5c35b65a3b9d49e Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Mon, 11 Mar 2019 11:02:54 +0800 Subject: [PATCH] add some debug flags for communicator --- .../operators/distributed/communicator.cc | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/distributed/communicator.cc b/paddle/fluid/operators/distributed/communicator.cc index 73b9800d437..06f7859f4f8 100644 --- a/paddle/fluid/operators/distributed/communicator.cc +++ b/paddle/fluid/operators/distributed/communicator.cc @@ -30,7 +30,11 @@ DEFINE_bool(communicator_independent_recv_thread, true, DEFINE_int32(communicator_send_queue_size, 20, "queue size to recv gradient before send"); DEFINE_int32(communicator_recv_wait_ms, 200, "wait time between each recv"); -DEFINE_int32(communicator_thread_pool_size, 5, "wait time between each recv"); +DEFINE_int32(communicator_thread_pool_size, 5, "thread num to do send or recv"); +DEFINE_int32(communicator_max_merge_var_num, 20, + "max var num to merge and send"); +DEFINE_bool(communicator_fake_rpc, false, + "fake mode does not really send any thing"); namespace paddle { namespace operators { @@ -92,6 +96,9 @@ Communicator::Communicator(const RpcCtxMap &send_varname_to_ctx, VLOG(0) << "communicator_recv_wait_ms: " << FLAGS_communicator_recv_wait_ms; VLOG(0) << "communicator_thread_pool_size: " << FLAGS_communicator_thread_pool_size; + VLOG(0) << "communicator_max_merge_var_num" + << FLAGS_communicator_max_merge_var_num; + VLOG(0) << "communicator_fake_rpc: " << FLAGS_communicator_fake_rpc; send_scope_.reset(new Scope()); for (auto &iter : send_varname_to_ctx_) { send_varname_to_queue_[iter.first] = @@ -123,17 +130,18 @@ void Communicator::SendThread() { auto send_task = [this, &var_name, &var_queue] { VLOG(3) << "merge var " << var_name << " and send"; std::vector> vars; - // TODO(qiao): need to be configurable - const size_t max_merge_var_num = 20; size_t merged_var_num = 0; - while (var_queue->Size() > 0 && merged_var_num < max_merge_var_num) { + while (var_queue->Size() > 0 && + merged_var_num < FLAGS_communicator_max_merge_var_num) { vars.push_back(var_queue->Pop()); merged_var_num++; } MergeVars(var_name, vars, send_scope_.get()); auto send_functor = distributed::ParameterSend(); auto &ctx = send_varname_to_ctx_.at(var_name); - send_functor(ctx, *send_scope_, true); + if (!FLAGS_communicator_fake_rpc) { + send_functor(ctx, *send_scope_, true); + } }; task_futures.emplace_back( send_threadpool_->enqueue(std::move(send_task))); @@ -160,7 +168,9 @@ void Communicator::RecvAll() { auto &var_name = iter.first; VLOG(3) << "recv var " << var_name; auto recv_functor = distributed::ParameterRecv(); - recv_functor(iter.second, *recv_scope_); + if (!FLAGS_communicator_fake_rpc) { + recv_functor(iter.second, *recv_scope_); + } }; task_futures.emplace_back(recv_threadpool_->enqueue(std::move(recv_task))); } -- GitLab