diff --git a/paddle/fluid/operators/distributed/request_handler_impl.cc b/paddle/fluid/operators/distributed/request_handler_impl.cc
index 31159a02592a2aff75f7ecf5be924989f0f47071..849e412504eb9180b746db65fd4fa353ed0c05a1 100644
--- a/paddle/fluid/operators/distributed/request_handler_impl.cc
+++ b/paddle/fluid/operators/distributed/request_handler_impl.cc
@@ -67,24 +67,11 @@ bool RequestSendHandler::Handle(const std::string& varname,
         LOG(FATAL) << "sync: Can not find server side var: " << varname;
         return false;
       }
-
-      if (invar->IsType<framework::SelectedRows>()) {
-        std::unique_lock<std::mutex> lock(mutex_sparse_vars_);
-        sparse_vars_.push_back(invar);
-      }
     }
   }
   return true;
 }
 
-void RequestSendHandler::ResetSparseVarRecorder() {
-  std::unique_lock<std::mutex> lock(mutex_sparse_vars_);
-  for (auto* var : sparse_vars_) {
-    var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
-  }
-  sparse_vars_.clear();
-}
-
 bool RequestGetHandler::Handle(const std::string& varname,
                                framework::Scope* scope,
                                framework::Variable* invar,
diff --git a/paddle/fluid/operators/distributed/request_handler_impl.h b/paddle/fluid/operators/distributed/request_handler_impl.h
index 87185500f2ffc3a8578eea339cc7a1e2b0e46631..8be5b21bb89a580f4091de19186fd2d7e5802478 100644
--- a/paddle/fluid/operators/distributed/request_handler_impl.h
+++ b/paddle/fluid/operators/distributed/request_handler_impl.h
@@ -41,11 +41,6 @@ class RequestSendHandler final : public RequestHandler {
   bool Handle(const std::string& varname, framework::Scope* scope,
               framework::Variable* var, framework::Variable** outvar,
               const std::string& out_var_name = "") override;
-  void ResetSparseVarRecorder();
-
- private:
-  std::mutex mutex_sparse_vars_;
-  std::vector<framework::Variable*> sparse_vars_;
 };
 
 class RequestGetHandler final : public RequestHandler {
diff --git a/paddle/fluid/operators/distributed/rpc_server.cc b/paddle/fluid/operators/distributed/rpc_server.cc
index 406e7294c190172347d432fb155c2a81c43dda25..084480ae48b8b9267ade1a840f6a70519cb28e48 100644
--- a/paddle/fluid/operators/distributed/rpc_server.cc
+++ b/paddle/fluid/operators/distributed/rpc_server.cc
@@ -101,6 +101,8 @@ void RPCServer::Complete() {
   {
     std::unique_lock<std::mutex> lock(mutex_);
     client_num_--;
+    need_reset_all_vars_ = true;
+
     VLOG(4) << "decrease client_num to: " << client_num_;
     if (cur_cond_.load() == rpc_cond_map_[kRequestGet]) {
       barrier_counter_[kRequestGet]--;
@@ -109,6 +111,11 @@ void RPCServer::Complete() {
   barrier_cond_.notify_all();
 }
 
+bool RPCServer::NeedResetAllVars() {
+  std::unique_lock<std::mutex> lock(mutex_);
+  return need_reset_all_vars_;
+}
+
 int RPCServer::GetClientNum() {
   std::unique_lock<std::mutex> lock(mutex_);
   return client_num_;
@@ -120,6 +127,7 @@ void RPCServer::ResetBarrierCounter() {
   for (auto& t : barrier_counter_) {
     t.second = 0;
   }
+  need_reset_all_vars_ = false;
 }
 
 void RPCServer::RegisterRPC(const std::string& rpc_name,
diff --git a/paddle/fluid/operators/distributed/rpc_server.h b/paddle/fluid/operators/distributed/rpc_server.h
index d813ba03e2fbec6e808f59f814a9b2f4bfbcd77b..d88e8c640ffb5ea44e88318cc973c9a783862435 100644
--- a/paddle/fluid/operators/distributed/rpc_server.h
+++ b/paddle/fluid/operators/distributed/rpc_server.h
@@ -49,7 +49,8 @@ class RPCServer {
         bind_address_(address),
         exit_flag_(false),
         selected_port_(0),
-        client_num_(client_num) {}
+        client_num_(client_num),
+        need_reset_all_vars_(false) {}
 
   virtual ~RPCServer() {}
   virtual void StartServer() = 0;
@@ -86,6 +87,8 @@ class RPCServer {
   void ResetBarrierCounter();
   RPCServerProfiler& Profiler() { return profiler_; }
 
+  bool NeedResetAllVars();
+
  protected:
   virtual void ShutDownImpl() = 0;
 
@@ -104,6 +107,7 @@ class RPCServer {
   std::atomic<bool> exit_flag_;
   int selected_port_;
   int client_num_;
+  bool need_reset_all_vars_;
 
   std::unordered_map<std::string, RequestHandler*> rpc_call_map_;
   std::unordered_map<std::string, int> rpc_thread_num_;
diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc
index 4cc2159d9f22809a640f82ad19415f3e5a2d9999..1933e6a5d0596eb50dc967f71ab2962f3ede14d2 100644
--- a/paddle/fluid/operators/listen_and_serv_op.cc
+++ b/paddle/fluid/operators/listen_and_serv_op.cc
@@ -22,6 +22,7 @@ limitations under the License. */
 #include "gflags/gflags.h"
 
 #include "paddle/fluid/operators/detail/macros.h"
+#include "paddle/fluid/operators/math/math_function.h"
 
 #include "paddle/fluid/operators/distributed/request_handler_impl.h"
 #include "paddle/fluid/operators/listen_and_serv_op.h"
@@ -101,9 +102,10 @@ static int64_t GetTimestamp() {
 
 void ListenAndServOp::RunSyncLoop(
     framework::Executor *executor, framework::ProgramDesc *program,
-    framework::Scope *recv_scope,
+    framework::Scope *recv_scope, platform::DeviceContext *dev_ctx,
     const std::vector<int> &prefetch_block_id_list,
-    const int checkpoint_point_block_id) const {
+    const int checkpoint_point_block_id,
+    const std::vector<std::string> &recv_varnames) const {
   VLOG(2) << "RunSyncLoop";
   size_t num_blocks = program->Size();
   auto optimize_blocks =
@@ -166,8 +168,8 @@ void ListenAndServOp::RunSyncLoop(
     VLOG(2) << "run all blocks spent " << GetTimestamp() - ts << "(ms)";
 
     // reset received sparse vars to avoid reuse it in the next mini-batch
-    dynamic_cast<distributed::RequestSendHandler *>(request_send_handler_.get())
-        ->ResetSparseVarRecorder();
+    ResetReceivedVars(recv_varnames, recv_scope, dev_ctx,
+                      !rpc_service_->NeedResetAllVars());
 
     rpc_service_->SetCond(distributed::kRequestGet);
     rpc_service_->WaitBarrier(distributed::kRequestGet);
@@ -175,6 +177,33 @@
   }  // while(true)
 }
 
+void ListenAndServOp::ResetReceivedVars(
+    const std::vector<std::string> &recv_varnames, framework::Scope *recv_scope,
+    platform::DeviceContext *dev_ctx, bool only_sparse_vars) const {
+  for (auto &varname : recv_varnames) {
+    auto var = recv_scope->FindVar(varname);
+    if (var == nullptr) {
+      VLOG(2) << "can not find var " << varname << " in received scope";
+      continue;
+    }
+    if (var->IsType<framework::SelectedRows>()) {
+      var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
+    }
+    if (!only_sparse_vars) {
+      if (var->IsType<framework::LoDTensor>()) {
+        math::set_constant(*dev_ctx, var->GetMutable<framework::LoDTensor>(),
+                           static_cast<float>(0));
+      } else if (var->IsType<framework::Tensor>()) {
+        math::set_constant(*dev_ctx, var->GetMutable<framework::Tensor>(),
+                           static_cast<float>(0));
+      } else {
+        PADDLE_THROW(
+            "received var should be in [SelectedRows, LoDTensor, Tensor]");
+      }
+    }
+  }
+}
+
 void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
                                    framework::ProgramDesc *program,
                                    framework::Scope *recv_scope) const {
@@ -258,6 +287,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
 
   bool sync_mode = Attr<bool>("sync_mode");
   auto fan_in = Attr<int>("Fanin");
+  auto inputs = Inputs("X");
 
   PADDLE_ENFORCE(!rpc_service_);
   std::string endpoint = Attr<std::string>("endpoint");
@@ -351,8 +381,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
   // Write to a file of server selected port for python use.
   SavePort();
   if (sync_mode) {
-    RunSyncLoop(&executor, program, &recv_scope, prefetch_block_id_list,
-                checkpoint_block_id);
+    RunSyncLoop(&executor, program, &recv_scope, &dev_ctx,
+                prefetch_block_id_list, checkpoint_block_id, inputs);
   } else {
     RunAsyncLoop(&executor, program, &recv_scope);
   }
diff --git a/paddle/fluid/operators/listen_and_serv_op.h b/paddle/fluid/operators/listen_and_serv_op.h
index 978969cc515c7954b59f2bf7a4f2c0e1b13f9bc0..f84baa36eba2ee59e72c12e8592aa77b8d674ae8 100644
--- a/paddle/fluid/operators/listen_and_serv_op.h
+++ b/paddle/fluid/operators/listen_and_serv_op.h
@@ -26,6 +26,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/threadpool.h"
 #include "paddle/fluid/operators/distributed/request_handler.h"
 #include "paddle/fluid/operators/distributed/rpc_server.h"
+#include "paddle/fluid/platform/device_context.h"
 
 namespace paddle {
 namespace operators {
@@ -48,8 +49,10 @@ class ListenAndServOp : public framework::OperatorBase {
   void RunSyncLoop(framework::Executor* executor,
                    framework::ProgramDesc* program,
                    framework::Scope* recv_scope,
+                   platform::DeviceContext* dev_ctx,
                    const std::vector<int>& prefetch_block_id_list,
-                   const int checkpoint_point_block_id) const;
+                   const int checkpoint_point_block_id,
+                   const std::vector<std::string>& recv_varnames) const;
 
   void RunAsyncLoop(framework::Executor* executor,
                     framework::ProgramDesc* program,
@@ -64,6 +67,11 @@ class ListenAndServOp : public framework::OperatorBase {
   void RunImpl(const framework::Scope& scope,
                const platform::Place& dev_place) const override;
 
+  void ResetReceivedVars(const std::vector<std::string>& recv_varnames,
+                         framework::Scope* recv_scope,
+                         platform::DeviceContext* dev_ctx,
+                         bool only_sparse_vars = true) const;
+
  protected:
   mutable std::shared_ptr<distributed::RPCServer> rpc_service_;
   mutable std::shared_ptr<distributed::RequestHandler> request_send_handler_;