diff --git a/paddle/fluid/operators/distributed/communicator.cc b/paddle/fluid/operators/distributed/communicator.cc
index 3a185667e7a70d315bc14ca018f181c3de6ca421..af277d69c18670e31cb8fd9991b33b915261778e 100644
--- a/paddle/fluid/operators/distributed/communicator.cc
+++ b/paddle/fluid/operators/distributed/communicator.cc
@@ -40,6 +40,8 @@ DEFINE_int32(communicator_max_merge_var_num, 20,
             "max var num to merge and send");
 DEFINE_bool(communicator_fake_rpc, false,
             "fake mode does not really send any thing");
+DEFINE_bool(communicator_merge_sparse_grad, true,
+            "merge sparse gradient before sending");
 
 namespace paddle {
 namespace operators {
@@ -73,6 +75,8 @@ Communicator::Communicator(const RpcCtxMap &send_varname_to_ctx,
   VLOG(0) << "communicator_max_merge_var_num: "
           << FLAGS_communicator_max_merge_var_num;
   VLOG(0) << "communicator_fake_rpc: " << FLAGS_communicator_fake_rpc;
+  VLOG(0) << "communicator_merge_sparse_grad: "
+          << FLAGS_communicator_merge_sparse_grad;
   send_scope_.reset(new Scope());
   for (auto &iter : send_varname_to_ctx_) {
     send_varname_to_queue_[iter.first] =
@@ -214,11 +218,20 @@ void Communicator::Send(const std::string &var_name,
   // push var into send queue by var_name
   auto *grad_var = scope.FindVar(var_name);
   PADDLE_ENFORCE(grad_var->IsInitialized(), "grad var should be inited");
-  auto tmp_grad_var = std::make_shared<framework::Variable>();
-  framework::CopyVariable(*grad_var, tmp_grad_var.get());
-  auto &queue = send_varname_to_queue_.at(var_name);
-  VLOG(3) << "send " << var_name << " queue size " << queue->Size();
-  queue->Push(tmp_grad_var);
+  if (grad_var->IsType<framework::SelectedRows>() &&
+      !FLAGS_communicator_merge_sparse_grad) {
+    auto send_functor = distributed::ParameterSend<float>();
+    auto &ctx = send_varname_to_ctx_.at(var_name);
+    if (!FLAGS_communicator_fake_rpc) {
+      send_functor(ctx, scope, true);
+    }
+  } else {
+    auto tmp_grad_var = std::make_shared<framework::Variable>();
+    framework::CopyVariable(*grad_var, tmp_grad_var.get());
+    auto &queue = send_varname_to_queue_.at(var_name);
+    VLOG(3) << "send " << var_name << " queue size " << queue->Size();
+    queue->Push(tmp_grad_var);
+  }
 }
 
 void Communicator::Init(const paddle::framework::ProgramDesc &program,
diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py
index 3b19e9f210d6179460a57464178f999d2d1afc42..304643ea9a10ab017dad14030e2f402aeeb4e8a9 100644
--- a/python/paddle/fluid/__init__.py
+++ b/python/paddle/fluid/__init__.py
@@ -192,6 +192,7 @@ def __bootstrap__():
         read_env_flags.append('communicator_max_merge_var_num')
         read_env_flags.append('communicator_fake_rpc')
         read_env_flags.append('communicator_send_wait_times')
+        read_env_flags.append('communicator_merge_sparse_grad')
 
     if core.is_compiled_with_brpc():
         read_env_flags.append('max_body_size')  #set brpc max body size
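
For reviewers, a minimal sketch of how the new flag would be toggled from user code. It assumes Paddle's usual convention that `__bootstrap__()` consumes `FLAGS_`-prefixed environment variables for each name in `read_env_flags`; the training-script context is hypothetical, not part of this patch:

```python
import os

# Opt out of merging: with the flag off, Communicator::Send() detects
# SelectedRows (sparse) gradients and pushes them over RPC immediately
# via ParameterSend, bypassing the merge queue entirely.
# Must be set before `import paddle.fluid`, since __bootstrap__() reads
# the FLAGS_* environment at import time.
os.environ['FLAGS_communicator_merge_sparse_grad'] = '0'

import paddle.fluid as fluid  # flag is picked up during this import
```

Defaulting the flag to `true` preserves the existing behavior (sparse gradients are copied into the send queue and merged before sending), so only jobs that explicitly opt out take the immediate-send path.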