diff --git a/paddle/fluid/operators/detail/request_handler.h b/paddle/fluid/operators/detail/request_handler.h index 4bc5e7f10ee2a8939d230fe96517bd9f56c13933..d74206aaba6a79ee06475985e642221bd84d9382 100644 --- a/paddle/fluid/operators/detail/request_handler.h +++ b/paddle/fluid/operators/detail/request_handler.h @@ -80,7 +80,6 @@ class RequestHandler { } framework::ProgramDesc* program() { return program_; } framework::Executor* executor() { return executor_; } - std::vector& sparse_vars() { return sparse_vars_; } // This function processes user's rpc request. // The implemention is in request_handler_impl. @@ -113,13 +112,7 @@ class RequestHandler { std::unordered_map>* grad_to_prepared_ctx_; - - // Record received sparse variables, so that - // we could reset those after execute optimize program - std::vector sparse_vars_; RPCServer* rpc_server_; - - std::mutex sparse_var_mutex_; }; } // namespace detail diff --git a/paddle/fluid/operators/detail/request_handler_impl.cc b/paddle/fluid/operators/detail/request_handler_impl.cc index f16c06d52f4fb86d51083a8b3b98d05a64c1af74..b5ee3ab51ec5e685b41057ba60d6701d61cbc09c 100644 --- a/paddle/fluid/operators/detail/request_handler_impl.cc +++ b/paddle/fluid/operators/detail/request_handler_impl.cc @@ -63,10 +63,8 @@ bool RequestSendHandler::Handle(const std::string& varname, PADDLE_THROW("sync: Can not find server side var"); return false; } - if (invar->IsType()) { - std::unique_lock lock(sparse_var_mutex_); - sparse_vars_.push_back(invar); + rpc_server_->RecordSparseVar(invar); } } diff --git a/paddle/fluid/operators/detail/rpc_server.cc b/paddle/fluid/operators/detail/rpc_server.cc index 448763372a8c224cc68319a4a444915896b68234..7feddbeca89ee97769f5598b0111e50a3f572ce8 100644 --- a/paddle/fluid/operators/detail/rpc_server.cc +++ b/paddle/fluid/operators/detail/rpc_server.cc @@ -73,6 +73,19 @@ void RPCServer::ResetBarrierCounter() { t.second = 0; } } +void RPCServer::RecordSparseVar(framework::Variable* sparse_var) { + std::unique_lock lock(mutex_sparse_var_recorder_); + sparse_vars_.push_back(sparse_var); +} + +void RPCServer::ResetSparseVarsRecorder() { + VLOG(3) << "RPCServer reset sparse vars recorder."; + std::unique_lock lock(mutex_sparse_var_recorder_); + for (auto* var : sparse_vars_) { + var->GetMutable()->mutable_rows()->clear(); + } + sparse_vars_.clear(); +} void RPCServer::RegisterRPC(const std::string& rpc_name, RequestHandler* handler, int thread_num) { diff --git a/paddle/fluid/operators/detail/rpc_server.h b/paddle/fluid/operators/detail/rpc_server.h index c2e7ae706c9dc6776e09b25e424b30f110c3855d..94a21ef8d04e0fc8f4e22bf00cf89dc5f41f294b 100644 --- a/paddle/fluid/operators/detail/rpc_server.h +++ b/paddle/fluid/operators/detail/rpc_server.h @@ -60,7 +60,10 @@ class RPCServer { void SetCond(const std::string& rpc_name); void WaitCond(const std::string& rpc_name); void IncreaseBatchBarrier(const std::string rpc_name); + void ResetBarrierCounter(); + void RecordSparseVar(framework::Variable* sparse_var); + void ResetSparseVarsRecorder(); protected: virtual void ShutDownImpl() = 0; @@ -74,6 +77,9 @@ class RPCServer { std::atomic cur_cond_; std::condition_variable rpc_cond_; + std::vector sparse_vars_; + std::mutex mutex_sparse_var_recorder_; + protected: std::string bind_address_; std::atomic exit_flag_; diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 66a0f87b46c6447bac7e42f0f61e3170cb1f2fdb..ee7b01a54ca96e7685b96f58ebfd1a454b19a976 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -108,9 +108,6 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor, std::shared_ptr(nullptr)); rpc_service_->ResetBarrierCounter(); - // Record received sparse variables, so that - // we could reset those after execute optimize program - std::vector sparse_vars; while (true) { // Get from multiple trainers, we don't care about the order in which // the gradients arrives, just add suffix 0~n and merge the gradient. @@ -146,18 +143,10 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor, recv_scope); VLOG(2) << "run all blocks spent " << detail::GetTimestamp() - ts << "(ms)"; - // Reset the received sparse variables, the sum operator would not - // sum the input sparse variables which rows is empty at the next - // mini-batch. - // TODO(Yancey1989): move the reset action into an operator, we couldn't - // have any hide logic in the operator. - for (framework::Variable *var : sparse_vars) { - var->GetMutable()->mutable_rows()->clear(); - } - rpc_service_->SetCond(detail::kRequestGet); rpc_service_->WaitBarrier(detail::kRequestGet); rpc_service_->ResetBarrierCounter(); + rpc_service_->ResetSparseVarsRecorder(); } // while(true) }