diff --git a/paddle/fluid/operators/detail/request_handler_impl.cc b/paddle/fluid/operators/detail/request_handler_impl.cc index 145ee53107a89def3d017d076e2d4c005665a1fc..b5ee3ab51ec5e685b41057ba60d6701d61cbc09c 100644 --- a/paddle/fluid/operators/detail/request_handler_impl.cc +++ b/paddle/fluid/operators/detail/request_handler_impl.cc @@ -63,6 +63,9 @@ bool RequestSendHandler::Handle(const std::string& varname, PADDLE_THROW("sync: Can not find server side var"); return false; } + if (invar->IsType()) { + rpc_server_->RecordSparseVar(invar); + } } return true; diff --git a/paddle/fluid/operators/detail/rpc_server.cc b/paddle/fluid/operators/detail/rpc_server.cc index 448763372a8c224cc68319a4a444915896b68234..7feddbeca89ee97769f5598b0111e50a3f572ce8 100644 --- a/paddle/fluid/operators/detail/rpc_server.cc +++ b/paddle/fluid/operators/detail/rpc_server.cc @@ -73,6 +73,19 @@ void RPCServer::ResetBarrierCounter() { t.second = 0; } } +void RPCServer::RecordSparseVar(framework::Variable* sparse_var) { + std::unique_lock lock(mutex_sparse_var_recorder_); + sparse_vars_.push_back(sparse_var); +} + +void RPCServer::ResetSparseVarsRecorder() { + VLOG(3) << "RPCServer reset sparse vars recorder."; + std::unique_lock lock(mutex_sparse_var_recorder_); + for (auto* var : sparse_vars_) { + var->GetMutable()->mutable_rows()->clear(); + } + sparse_vars_.clear(); +} void RPCServer::RegisterRPC(const std::string& rpc_name, RequestHandler* handler, int thread_num) { diff --git a/paddle/fluid/operators/detail/rpc_server.h b/paddle/fluid/operators/detail/rpc_server.h index c2e7ae706c9dc6776e09b25e424b30f110c3855d..94a21ef8d04e0fc8f4e22bf00cf89dc5f41f294b 100644 --- a/paddle/fluid/operators/detail/rpc_server.h +++ b/paddle/fluid/operators/detail/rpc_server.h @@ -60,7 +60,10 @@ class RPCServer { void SetCond(const std::string& rpc_name); void WaitCond(const std::string& rpc_name); void IncreaseBatchBarrier(const std::string rpc_name); + void ResetBarrierCounter(); + void RecordSparseVar(framework::Variable* sparse_var); + void ResetSparseVarsRecorder(); protected: virtual void ShutDownImpl() = 0; @@ -74,6 +77,9 @@ class RPCServer { std::atomic cur_cond_; std::condition_variable rpc_cond_; + std::vector sparse_vars_; + std::mutex mutex_sparse_var_recorder_; + protected: std::string bind_address_; std::atomic exit_flag_; diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 0c9d2b5a74c0078cb703615d8e0d4cb582dc04f9..ee7b01a54ca96e7685b96f58ebfd1a454b19a976 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -146,6 +146,7 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor, rpc_service_->SetCond(detail::kRequestGet); rpc_service_->WaitBarrier(detail::kRequestGet); rpc_service_->ResetBarrierCounter(); + rpc_service_->ResetSparseVarsRecorder(); } // while(true) }