diff --git a/paddle/fluid/operators/detail/request_handler_impl.cc b/paddle/fluid/operators/detail/request_handler_impl.cc index b5ee3ab51ec5e685b41057ba60d6701d61cbc09c..9473dce55029f2a4e0987ab8f6f5e7205d7fff47 100644 --- a/paddle/fluid/operators/detail/request_handler_impl.cc +++ b/paddle/fluid/operators/detail/request_handler_impl.cc @@ -64,13 +64,21 @@ bool RequestSendHandler::Handle(const std::string& varname, return false; } if (invar->IsType()) { - rpc_server_->RecordSparseVar(invar); + std::unique_lock lock(mutex_sparse_vars_); + sparse_vars_.push_back(invar); } } - return true; } +void RequestSendHandler::ResetSparseVarRecorder() { + std::unique_lock lock(mutex_sparse_vars_); + for (auto* var : sparse_vars_) { + var->GetMutable()->mutable_rows()->clear(); + } + sparse_vars_.clear(); +} + bool RequestGetHandler::Handle(const std::string& varname, framework::Scope* scope, framework::Variable* invar, diff --git a/paddle/fluid/operators/detail/request_handler_impl.h b/paddle/fluid/operators/detail/request_handler_impl.h index 8d0c62232b68ad6c05e751c25103802ee12db57e..443d951914dd0f40e8831abc637848363d9fef16 100644 --- a/paddle/fluid/operators/detail/request_handler_impl.h +++ b/paddle/fluid/operators/detail/request_handler_impl.h @@ -41,6 +41,11 @@ class RequestSendHandler final : public RequestHandler { virtual ~RequestSendHandler() {} bool Handle(const std::string& varname, framework::Scope* scope, framework::Variable* var, framework::Variable** outvar) override; + void ResetSparseVarRecorder(); + + private: + std::mutex mutex_sparse_vars_; + std::vector sparse_vars_; }; class RequestGetHandler final : public RequestHandler { diff --git a/paddle/fluid/operators/detail/rpc_server.cc b/paddle/fluid/operators/detail/rpc_server.cc index 7feddbeca89ee97769f5598b0111e50a3f572ce8..448763372a8c224cc68319a4a444915896b68234 100644 --- a/paddle/fluid/operators/detail/rpc_server.cc +++ b/paddle/fluid/operators/detail/rpc_server.cc @@ -73,19 +73,6 @@ void RPCServer::ResetBarrierCounter() { t.second = 0; } } -void RPCServer::RecordSparseVar(framework::Variable* sparse_var) { - std::unique_lock lock(mutex_sparse_var_recorder_); - sparse_vars_.push_back(sparse_var); -} - -void RPCServer::ResetSparseVarsRecorder() { - VLOG(3) << "RPCServer reset sparse vars recorder."; - std::unique_lock lock(mutex_sparse_var_recorder_); - for (auto* var : sparse_vars_) { - var->GetMutable()->mutable_rows()->clear(); - } - sparse_vars_.clear(); -} void RPCServer::RegisterRPC(const std::string& rpc_name, RequestHandler* handler, int thread_num) { diff --git a/paddle/fluid/operators/detail/rpc_server.h b/paddle/fluid/operators/detail/rpc_server.h index 94a21ef8d04e0fc8f4e22bf00cf89dc5f41f294b..f809c13c726ac2f1c60e8cf84848c4138f631b44 100644 --- a/paddle/fluid/operators/detail/rpc_server.h +++ b/paddle/fluid/operators/detail/rpc_server.h @@ -62,8 +62,6 @@ class RPCServer { void IncreaseBatchBarrier(const std::string rpc_name); void ResetBarrierCounter(); - void RecordSparseVar(framework::Variable* sparse_var); - void ResetSparseVarsRecorder(); protected: virtual void ShutDownImpl() = 0; @@ -77,9 +75,6 @@ class RPCServer { std::atomic cur_cond_; std::condition_variable rpc_cond_; - std::vector sparse_vars_; - std::mutex mutex_sparse_var_recorder_; - protected: std::string bind_address_; std::atomic exit_flag_; diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index ee7b01a54ca96e7685b96f58ebfd1a454b19a976..66d31c88951926a6dd9b7262942a69bb1564a416 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -146,7 +146,9 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor, rpc_service_->SetCond(detail::kRequestGet); rpc_service_->WaitBarrier(detail::kRequestGet); rpc_service_->ResetBarrierCounter(); - rpc_service_->ResetSparseVarsRecorder(); + // reset received sparse vars to avoid reuse it in the next mini-batch + dynamic_cast(request_send_handler_.get()) + ->ResetSparseVarRecorder(); } // while(true) }