提交 ad5a2b3e 编写于 作者: Q Qiao Longfei

add some debug flags for communicator

上级 eb6af305
......@@ -30,7 +30,11 @@ DEFINE_bool(communicator_independent_recv_thread, true,
DEFINE_int32(communicator_send_queue_size, 20,
"queue size to recv gradient before send");
DEFINE_int32(communicator_recv_wait_ms, 200, "wait time between each recv");
DEFINE_int32(communicator_thread_pool_size, 5, "wait time between each recv");
DEFINE_int32(communicator_thread_pool_size, 5, "thread num to do send or recv");
DEFINE_int32(communicator_max_merge_var_num, 20,
"max var num to merge and send");
DEFINE_bool(communicator_fake_rpc, false,
"fake mode does not really send any thing");
namespace paddle {
namespace operators {
......@@ -92,6 +96,9 @@ Communicator::Communicator(const RpcCtxMap &send_varname_to_ctx,
VLOG(0) << "communicator_recv_wait_ms: " << FLAGS_communicator_recv_wait_ms;
VLOG(0) << "communicator_thread_pool_size: "
<< FLAGS_communicator_thread_pool_size;
VLOG(0) << "communicator_max_merge_var_num"
<< FLAGS_communicator_max_merge_var_num;
VLOG(0) << "communicator_fake_rpc: " << FLAGS_communicator_fake_rpc;
send_scope_.reset(new Scope());
for (auto &iter : send_varname_to_ctx_) {
send_varname_to_queue_[iter.first] =
......@@ -123,17 +130,18 @@ void Communicator::SendThread() {
auto send_task = [this, &var_name, &var_queue] {
VLOG(3) << "merge var " << var_name << " and send";
std::vector<std::shared_ptr<Variable>> vars;
// TODO(qiao): need to be configurable
const size_t max_merge_var_num = 20;
size_t merged_var_num = 0;
while (var_queue->Size() > 0 && merged_var_num < max_merge_var_num) {
while (var_queue->Size() > 0 &&
merged_var_num < FLAGS_communicator_max_merge_var_num) {
vars.push_back(var_queue->Pop());
merged_var_num++;
}
MergeVars(var_name, vars, send_scope_.get());
auto send_functor = distributed::ParameterSend<float>();
auto &ctx = send_varname_to_ctx_.at(var_name);
if (!FLAGS_communicator_fake_rpc) {
send_functor(ctx, *send_scope_, true);
}
};
task_futures.emplace_back(
send_threadpool_->enqueue(std::move(send_task)));
......@@ -160,7 +168,9 @@ void Communicator::RecvAll() {
auto &var_name = iter.first;
VLOG(3) << "recv var " << var_name;
auto recv_functor = distributed::ParameterRecv<float>();
if (!FLAGS_communicator_fake_rpc) {
recv_functor(iter.second, *recv_scope_);
}
};
task_futures.emplace_back(recv_threadpool_->enqueue(std::move(recv_task)));
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册