From 32b94a7d13233aba6f077dac43071e54f43fd489 Mon Sep 17 00:00:00 2001
From: Yancey1989
Date: Mon, 10 Sep 2018 15:09:47 +0800
Subject: [PATCH] cache var types

---
 paddle/fluid/operators/listen_and_serv_op.cc | 56 +++++++++++++++-----
 paddle/fluid/operators/listen_and_serv_op.h  | 11 ++--
 2 files changed, 50 insertions(+), 17 deletions(-)

diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc
index abbb3d06d18..966d78b8413 100644
--- a/paddle/fluid/operators/listen_and_serv_op.cc
+++ b/paddle/fluid/operators/listen_and_serv_op.cc
@@ -104,8 +104,7 @@ void ListenAndServOp::RunSyncLoop(
     framework::Executor *executor, framework::ProgramDesc *program,
     framework::Scope *recv_scope, platform::DeviceContext *dev_ctx,
     const std::vector<int> &prefetch_block_id_list,
-    const int checkpoint_point_block_id,
-    const std::vector<std::string> &recv_varnames) const {
+    const int checkpoint_point_block_id) const {
   VLOG(2) << "RunSyncLoop";
   size_t num_blocks = program->Size();
   auto optimize_blocks =
@@ -130,6 +129,7 @@ void ListenAndServOp::RunSyncLoop(
   rpc_service_->SetCond(distributed::kRequestGet);
   rpc_service_->WaitBarrier(distributed::kRequestGet);
   rpc_service_->ResetBarrierCounter();
+
   while (true) {
     rpc_service_->Profiler().OneStep();
     // Get from multiple trainers, we don't care about the order in which
@@ -167,8 +167,7 @@ void ListenAndServOp::RunSyncLoop(
                                    recv_scope);
     VLOG(2) << "run all blocks spent " << GetTimestamp() - ts << "(ms)";
 
-    ResetReceivedVars(recv_varnames, recv_scope, dev_ctx,
-                      rpc_service_->NeedResetAllVars());
+    ResetReceivedVars(recv_scope, dev_ctx, rpc_service_->NeedResetAllVars());
 
     rpc_service_->SetCond(distributed::kRequestGet);
     rpc_service_->WaitBarrier(distributed::kRequestGet);
@@ -176,10 +175,10 @@ void ListenAndServOp::RunSyncLoop(
   }  // while(true)
 }
 
-void ListenAndServOp::ResetReceivedVars(
-    const std::vector<std::string> &recv_varnames, framework::Scope *recv_scope,
-    platform::DeviceContext *dev_ctx, bool reset_all) const {
-  
for (auto &varname : recv_varnames) {
+void ListenAndServOp::ResetReceivedVars(framework::Scope *recv_scope,
+                                        platform::DeviceContext *dev_ctx,
+                                        bool reset_all) const {
+  for (auto &varname : sparse_vars_) {
     auto var = recv_scope->FindVar(varname);
     if (var == nullptr) {
       VLOG(2) << "can not find var " << varname << " in received scope";
@@ -188,9 +187,17 @@ void ListenAndServOp::ResetReceivedVars(
     if (var->IsType<framework::SelectedRows>()) {
       VLOG(3) << "reset sparse var: " << varname;
       var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
+    } else {
+      PADDLE_THROW("The type of sparse var should be SelectedRows");
     }
-    if (UNLIKELY(reset_all)) {
-      VLOG(3) << "reset dense var: " << varname;
+  }
+  if (UNLIKELY(reset_all)) {
+    for (auto &varname : dense_vars_) {
+      auto var = recv_scope->FindVar(varname);
+      if (var == nullptr) {
+        VLOG(2) << "can not find var " << varname << " in received scope";
+        continue;
+      }
       if (var->IsType<framework::LoDTensor>()) {
         math::set_constant(*dev_ctx, var->GetMutable<framework::LoDTensor>(),
                            static_cast<float>(0));
@@ -198,8 +205,7 @@ void ListenAndServOp::ResetReceivedVars(
         math::set_constant(*dev_ctx, var->GetMutable<framework::Tensor>(),
                            static_cast<float>(0));
       } else {
-        PADDLE_THROW(
-            "received var should be in [SelectedRows, LoDTensor, Tensor]");
+        PADDLE_THROW("The type of dense var should be in [LoDTensor, Tensor]");
       }
     }
   }
@@ -278,6 +284,25 @@ static void FillRequestCtx(
   h->SetCheckpointNotifyPreparedCtx(checkpoint_ctx);
 }
 
+void ListenAndServOp::CacheVarsType(const std::vector<std::string> &varnames,
+                                    const framework::Scope &scope) const {
+  for (const auto &varname : varnames) {
+    auto var = scope.FindVar(varname);
+    PADDLE_ENFORCE(var != nullptr,
+                   "Received var should be initialized in the received scope.");
+    if (var->IsType<framework::SelectedRows>()) {
+      sparse_vars_.push_back(varname);
+    } else if (var->IsType<framework::LoDTensor>() ||
+               var->IsType<framework::Tensor>()) {
+      dense_vars_.push_back(varname);
+    } else {
+      PADDLE_THROW(
+          "The type of received var should be in [SelectedRows, LoDTensor, "
+          "Tensor].");
+    }
+  }
+}
+
 void ListenAndServOp::RunImpl(const framework::Scope &scope,
                               const 
platform::Place &dev_place) const {
   // Mark this as PS that it should decide profiling by listening from trainer.
@@ -379,11 +404,16 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
   signal(SIGINT, SignalHandler::StopAndExit);
   signal(SIGTERM, SignalHandler::StopAndExit);
 
+  // Cache the type of the received vars as `sparse_vars_` and `dense_vars_`
+  // so that we can reset them at the end of each iteration.
+  // NOTE: only used in sync update
+  CacheVarsType(inputs, recv_scope);
+
   // Write to a file of server selected port for python use.
   SavePort();
   if (sync_mode) {
     RunSyncLoop(&executor, program, &recv_scope, &dev_ctx,
-                prefetch_block_id_list, checkpoint_block_id, inputs);
+                prefetch_block_id_list, checkpoint_block_id);
   } else {
     RunAsyncLoop(&executor, program, &recv_scope);
   }
diff --git a/paddle/fluid/operators/listen_and_serv_op.h b/paddle/fluid/operators/listen_and_serv_op.h
index 5102c963b9a..5f889793ab1 100644
--- a/paddle/fluid/operators/listen_and_serv_op.h
+++ b/paddle/fluid/operators/listen_and_serv_op.h
@@ -51,8 +51,7 @@ class ListenAndServOp : public framework::OperatorBase {
                    framework::Scope* recv_scope,
                    platform::DeviceContext* dev_ctx,
                    const std::vector<int>& prefetch_block_id_list,
-                   const int checkpoint_point_block_id,
-                   const std::vector<std::string>& recv_varnames) const;
+                   const int checkpoint_point_block_id) const;
 
   void RunAsyncLoop(framework::Executor* executor,
                     framework::ProgramDesc* program,
@@ -67,11 +66,13 @@ class ListenAndServOp : public framework::OperatorBase {
   void RunImpl(const framework::Scope& scope,
                const platform::Place& dev_place) const override;
 
-  void ResetReceivedVars(const std::vector<std::string>& recv_varnames,
-                         framework::Scope* recv_scope,
+  void ResetReceivedVars(framework::Scope* recv_scope,
                          platform::DeviceContext* dev_ctx,
                          bool reset_all = false) const;
 
+  void CacheVarsType(const std::vector<std::string>& varnames,
+                     const framework::Scope& scope) const;
+
  protected:
   mutable std::shared_ptr<distributed::RPCServer> rpc_service_;
   mutable std::shared_ptr<distributed::RequestHandler> 
request_send_handler_;
@@ -82,6 +83,8 @@ class ListenAndServOp : public framework::OperatorBase {
       request_checkpoint_handler_;
 
   mutable std::shared_ptr<std::thread> server_thread_;
+  mutable std::vector<std::string> sparse_vars_;
+  mutable std::vector<std::string> dense_vars_;
 };
 
 class SignalHandler {
-- 
GitLab