diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 1933e6a5d0596eb50dc967f71ab2962f3ede14d2..abbb3d06d18ce0d27232adde40531ee3929d5019 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -167,9 +167,8 @@ void ListenAndServOp::RunSyncLoop( recv_scope); VLOG(2) << "run all blocks spent " << GetTimestamp() - ts << "(ms)"; - // reset received sparse vars to avoid reuse it in the next mini-batch ResetReceivedVars(recv_varnames, recv_scope, dev_ctx, - !rpc_service_->NeedResetAllVars()); + rpc_service_->NeedResetAllVars()); rpc_service_->SetCond(distributed::kRequestGet); rpc_service_->WaitBarrier(distributed::kRequestGet); @@ -179,7 +178,7 @@ void ListenAndServOp::RunSyncLoop( void ListenAndServOp::ResetReceivedVars( const std::vector &recv_varnames, framework::Scope *recv_scope, - platform::DeviceContext *dev_ctx, bool only_sparse_vars) const { + platform::DeviceContext *dev_ctx, bool reset_all) const { for (auto &varname : recv_varnames) { auto var = recv_scope->FindVar(varname); if (var == nullptr) { @@ -187,9 +186,11 @@ void ListenAndServOp::ResetReceivedVars( continue; } if (var->IsType()) { + VLOG(3) << "reset sparse var: " << varname; var->GetMutable()->mutable_rows()->clear(); } - if (!only_sparse_vars) { + if (UNLIKELY(reset_all)) { + VLOG(3) << "reset dense var: " << varname; if (var->IsType()) { math::set_constant(*dev_ctx, var->GetMutable(), static_cast(0)); diff --git a/paddle/fluid/operators/listen_and_serv_op.h b/paddle/fluid/operators/listen_and_serv_op.h index f84baa36eba2ee59e72c12e8592aa77b8d674ae8..5102c963b9aa3467fbd48d9d1931e9bd387fc85f 100644 --- a/paddle/fluid/operators/listen_and_serv_op.h +++ b/paddle/fluid/operators/listen_and_serv_op.h @@ -70,7 +70,7 @@ class ListenAndServOp : public framework::OperatorBase { void ResetReceivedVars(const std::vector& recv_varnames, framework::Scope* recv_scope, platform::DeviceContext* dev_ctx, - bool only_sparse_vars = true) const; + bool reset_all = false) const; protected: mutable std::shared_ptr rpc_service_;