Unverified commit 0e08e91c, authored by Qiao Longfei, committed via GitHub

optimize communicator merge sparse gradient test=develop (#18159)

* optimize communicator merge sparse gradient test=develop

* revert multithread selected rows merge add test=develop

* follow comment test=develop
Parent commit: 172c2fac
@@ -40,6 +40,8 @@ DEFINE_int32(communicator_max_merge_var_num, 20,
             "max var num to merge and send");
DEFINE_bool(communicator_fake_rpc, false,
            "fake mode does not really send any thing");
DEFINE_bool(communicator_merge_sparse_grad, true,
"merge sparse gradient before sending");
namespace paddle {
namespace operators {
@@ -73,6 +75,8 @@ Communicator::Communicator(const RpcCtxMap &send_varname_to_ctx,
  VLOG(0) << "communicator_max_merge_var_num: "
          << FLAGS_communicator_max_merge_var_num;
  VLOG(0) << "communicator_fake_rpc: " << FLAGS_communicator_fake_rpc;
VLOG(0) << "communicator_merge_sparse_grad: "
<< FLAGS_communicator_merge_sparse_grad;
  send_scope_.reset(new Scope());
  for (auto &iter : send_varname_to_ctx_) {
    send_varname_to_queue_[iter.first] =
@@ -214,11 +218,20 @@ void Communicator::Send(const std::string &var_name,
  // push var into send queue by var_name
  auto *grad_var = scope.FindVar(var_name);
  PADDLE_ENFORCE(grad_var->IsInitialized(), "grad var should be inited");
if (grad_var->IsType<framework::SelectedRows>() &&
!FLAGS_communicator_merge_sparse_grad) {
auto send_functor = distributed::ParameterSend<float>();
auto &ctx = send_varname_to_ctx_.at(var_name);
if (!FLAGS_communicator_fake_rpc) {
send_functor(ctx, scope, true);
}
} else {
    auto tmp_grad_var = std::make_shared<Variable>();
    framework::CopyVariable(*grad_var, tmp_grad_var.get());
    auto &queue = send_varname_to_queue_.at(var_name);
    VLOG(3) << "send " << var_name << " queue size " << queue->Size();
    queue->Push(tmp_grad_var);
}
}

void Communicator::Init(const paddle::framework::ProgramDesc &program,
......
@@ -192,6 +192,7 @@ def __bootstrap__():
read_env_flags.append('communicator_max_merge_var_num')
read_env_flags.append('communicator_fake_rpc')
read_env_flags.append('communicator_send_wait_times')
read_env_flags.append('communicator_merge_sparse_grad')
if core.is_compiled_with_brpc():
read_env_flags.append('max_body_size')
# set brpc max body size
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register to post a comment