未验证 提交 d117bbc3 编写于 作者: Y Yan Xu 提交者: GitHub

Merge pull request #13291 from Yancey1989/reset_vars_on_pserver

reset received vars on pserver
...@@ -67,24 +67,11 @@ bool RequestSendHandler::Handle(const std::string& varname, ...@@ -67,24 +67,11 @@ bool RequestSendHandler::Handle(const std::string& varname,
LOG(FATAL) << "sync: Can not find server side var: " << varname; LOG(FATAL) << "sync: Can not find server side var: " << varname;
return false; return false;
} }
if (invar->IsType<framework::SelectedRows>()) {
std::unique_lock<std::mutex> lock(mutex_sparse_vars_);
sparse_vars_.push_back(invar);
}
} }
} }
return true; return true;
} }
void RequestSendHandler::ResetSparseVarRecorder() {
std::unique_lock<std::mutex> lock(mutex_sparse_vars_);
for (auto* var : sparse_vars_) {
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
}
sparse_vars_.clear();
}
bool RequestGetHandler::Handle(const std::string& varname, bool RequestGetHandler::Handle(const std::string& varname,
framework::Scope* scope, framework::Scope* scope,
framework::Variable* invar, framework::Variable* invar,
......
...@@ -41,11 +41,6 @@ class RequestSendHandler final : public RequestHandler { ...@@ -41,11 +41,6 @@ class RequestSendHandler final : public RequestHandler {
bool Handle(const std::string& varname, framework::Scope* scope, bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar, framework::Variable* var, framework::Variable** outvar,
const std::string& out_var_name = "") override; const std::string& out_var_name = "") override;
void ResetSparseVarRecorder();
private:
std::mutex mutex_sparse_vars_;
std::vector<framework::Variable*> sparse_vars_;
}; };
class RequestGetHandler final : public RequestHandler { class RequestGetHandler final : public RequestHandler {
......
...@@ -101,6 +101,8 @@ void RPCServer::Complete() { ...@@ -101,6 +101,8 @@ void RPCServer::Complete() {
{ {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
client_num_--; client_num_--;
need_reset_all_vars_ = true;
VLOG(4) << "decrease client_num to: " << client_num_; VLOG(4) << "decrease client_num to: " << client_num_;
if (cur_cond_.load() == rpc_cond_map_[kRequestGet]) { if (cur_cond_.load() == rpc_cond_map_[kRequestGet]) {
barrier_counter_[kRequestGet]--; barrier_counter_[kRequestGet]--;
...@@ -109,6 +111,11 @@ void RPCServer::Complete() { ...@@ -109,6 +111,11 @@ void RPCServer::Complete() {
barrier_cond_.notify_all(); barrier_cond_.notify_all();
} }
// Returns true when at least one trainer has completed (called Complete()),
// meaning all received variables must be reset before the next iteration.
// The flag is read under mutex_ because Complete() mutates it concurrently.
bool RPCServer::NeedResetAllVars() {
  std::lock_guard<std::mutex> guard(mutex_);
  return need_reset_all_vars_;
}
int RPCServer::GetClientNum() { int RPCServer::GetClientNum() {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
return client_num_; return client_num_;
...@@ -120,6 +127,7 @@ void RPCServer::ResetBarrierCounter() { ...@@ -120,6 +127,7 @@ void RPCServer::ResetBarrierCounter() {
for (auto& t : barrier_counter_) { for (auto& t : barrier_counter_) {
t.second = 0; t.second = 0;
} }
need_reset_all_vars_ = false;
} }
void RPCServer::RegisterRPC(const std::string& rpc_name, void RPCServer::RegisterRPC(const std::string& rpc_name,
......
...@@ -49,7 +49,8 @@ class RPCServer { ...@@ -49,7 +49,8 @@ class RPCServer {
bind_address_(address), bind_address_(address),
exit_flag_(false), exit_flag_(false),
selected_port_(0), selected_port_(0),
client_num_(client_num) {} client_num_(client_num),
need_reset_all_vars_(false) {}
virtual ~RPCServer() {} virtual ~RPCServer() {}
virtual void StartServer() = 0; virtual void StartServer() = 0;
...@@ -86,6 +87,8 @@ class RPCServer { ...@@ -86,6 +87,8 @@ class RPCServer {
void ResetBarrierCounter(); void ResetBarrierCounter();
RPCServerProfiler& Profiler() { return profiler_; } RPCServerProfiler& Profiler() { return profiler_; }
bool NeedResetAllVars();
protected: protected:
virtual void ShutDownImpl() = 0; virtual void ShutDownImpl() = 0;
...@@ -104,6 +107,7 @@ class RPCServer { ...@@ -104,6 +107,7 @@ class RPCServer {
std::atomic<int> exit_flag_; std::atomic<int> exit_flag_;
int selected_port_; int selected_port_;
int client_num_; int client_num_;
bool need_reset_all_vars_;
std::unordered_map<std::string, RequestHandler*> rpc_call_map_; std::unordered_map<std::string, RequestHandler*> rpc_call_map_;
std::unordered_map<std::string, int> rpc_thread_num_; std::unordered_map<std::string, int> rpc_thread_num_;
......
...@@ -22,6 +22,7 @@ limitations under the License. */ ...@@ -22,6 +22,7 @@ limitations under the License. */
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "paddle/fluid/operators/detail/macros.h" #include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/listen_and_serv_op.h" #include "paddle/fluid/operators/listen_and_serv_op.h"
...@@ -101,7 +102,7 @@ static int64_t GetTimestamp() { ...@@ -101,7 +102,7 @@ static int64_t GetTimestamp() {
void ListenAndServOp::RunSyncLoop( void ListenAndServOp::RunSyncLoop(
framework::Executor *executor, framework::ProgramDesc *program, framework::Executor *executor, framework::ProgramDesc *program,
framework::Scope *recv_scope, framework::Scope *recv_scope, platform::DeviceContext *dev_ctx,
const std::vector<int> &prefetch_block_id_list, const std::vector<int> &prefetch_block_id_list,
const int checkpoint_point_block_id) const { const int checkpoint_point_block_id) const {
VLOG(2) << "RunSyncLoop"; VLOG(2) << "RunSyncLoop";
...@@ -128,6 +129,7 @@ void ListenAndServOp::RunSyncLoop( ...@@ -128,6 +129,7 @@ void ListenAndServOp::RunSyncLoop(
rpc_service_->SetCond(distributed::kRequestGet); rpc_service_->SetCond(distributed::kRequestGet);
rpc_service_->WaitBarrier(distributed::kRequestGet); rpc_service_->WaitBarrier(distributed::kRequestGet);
rpc_service_->ResetBarrierCounter(); rpc_service_->ResetBarrierCounter();
while (true) { while (true) {
rpc_service_->Profiler().OneStep(); rpc_service_->Profiler().OneStep();
// Get from multiple trainers, we don't care about the order in which // Get from multiple trainers, we don't care about the order in which
...@@ -165,9 +167,7 @@ void ListenAndServOp::RunSyncLoop( ...@@ -165,9 +167,7 @@ void ListenAndServOp::RunSyncLoop(
recv_scope); recv_scope);
VLOG(2) << "run all blocks spent " << GetTimestamp() - ts << "(ms)"; VLOG(2) << "run all blocks spent " << GetTimestamp() - ts << "(ms)";
// reset received sparse vars to avoid reuse it in the next mini-batch ResetReceivedVars(recv_scope, dev_ctx, rpc_service_->NeedResetAllVars());
dynamic_cast<distributed::RequestSendHandler *>(request_send_handler_.get())
->ResetSparseVarRecorder();
rpc_service_->SetCond(distributed::kRequestGet); rpc_service_->SetCond(distributed::kRequestGet);
rpc_service_->WaitBarrier(distributed::kRequestGet); rpc_service_->WaitBarrier(distributed::kRequestGet);
...@@ -175,6 +175,42 @@ void ListenAndServOp::RunSyncLoop( ...@@ -175,6 +175,42 @@ void ListenAndServOp::RunSyncLoop(
} // while(true) } // while(true)
} }
// Reset the variables received from trainers at the end of a sync iteration.
// Sparse vars (SelectedRows) always get their recorded rows cleared so the
// next mini-batch starts empty; dense vars are zero-filled only when
// reset_all is set (i.e. some trainer exited and its stale values must not
// leak into the next run). Vars cached in sparse_vars_/dense_vars_ that are
// absent from recv_scope are skipped with a log line.
void ListenAndServOp::ResetReceivedVars(framework::Scope *recv_scope,
                                        platform::DeviceContext *dev_ctx,
                                        bool reset_all) const {
  // Shared lookup: returns nullptr (after logging) when the var is missing.
  auto find_in_scope = [&](const std::string &name) -> framework::Variable * {
    auto *found = recv_scope->FindVar(name);
    if (found == nullptr) {
      VLOG(2) << "can not find var " << name << " in received scope";
    }
    return found;
  };

  for (auto &name : sparse_vars_) {
    auto *var = find_in_scope(name);
    if (var == nullptr) continue;
    if (!var->IsType<framework::SelectedRows>()) {
      PADDLE_THROW("The type of sparse var should be SelectedRows");
    }
    VLOG(3) << "reset sparse var: " << name;
    var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
  }

  if (UNLIKELY(reset_all)) {
    for (auto &name : dense_vars_) {
      auto *var = find_in_scope(name);
      if (var == nullptr) continue;
      if (var->IsType<framework::LoDTensor>()) {
        math::set_constant(*dev_ctx, var->GetMutable<framework::LoDTensor>(),
                           static_cast<float>(0));
      } else if (var->IsType<framework::Tensor>()) {
        math::set_constant(*dev_ctx, var->GetMutable<framework::Tensor>(),
                           static_cast<float>(0));
      } else {
        PADDLE_THROW("The type of dense var should be in [LoDTensor, Tensor]");
      }
    }
  }
}
void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
framework::ProgramDesc *program, framework::ProgramDesc *program,
framework::Scope *recv_scope) const { framework::Scope *recv_scope) const {
...@@ -248,6 +284,25 @@ static void FillRequestCtx( ...@@ -248,6 +284,25 @@ static void FillRequestCtx(
h->SetCheckpointNotifyPreparedCtx(checkpoint_ctx); h->SetCheckpointNotifyPreparedCtx(checkpoint_ctx);
} }
// Partition the received variable names by runtime type into the mutable
// caches sparse_vars_ (SelectedRows) and dense_vars_ (LoDTensor/Tensor), so
// ResetReceivedVars() can reset them without re-inspecting the scope each
// iteration. Every name must already be initialized in `scope`; any other
// variable type is rejected.
void ListenAndServOp::CacheVarsType(const std::vector<std::string> &varnames,
                                    const framework::Scope &scope) const {
  for (const auto &name : varnames) {
    auto *received = scope.FindVar(name);
    PADDLE_ENFORCE(received != nullptr,
                   "Received var should be initialized in the received scope.");
    const bool is_dense = received->IsType<framework::LoDTensor>() ||
                          received->IsType<framework::Tensor>();
    if (received->IsType<framework::SelectedRows>()) {
      sparse_vars_.push_back(name);
    } else if (is_dense) {
      dense_vars_.push_back(name);
    } else {
      PADDLE_THROW(
          "The type of received var should be in [SelectedRows, LoDTensor, "
          "Tensor].");
    }
  }
}
void ListenAndServOp::RunImpl(const framework::Scope &scope, void ListenAndServOp::RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const { const platform::Place &dev_place) const {
// Mark this as PS that it should decide profiling by listening from trainer. // Mark this as PS that it should decide profiling by listening from trainer.
...@@ -258,6 +313,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -258,6 +313,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
bool sync_mode = Attr<bool>("sync_mode"); bool sync_mode = Attr<bool>("sync_mode");
auto fan_in = Attr<int>("Fanin"); auto fan_in = Attr<int>("Fanin");
auto inputs = Inputs("X");
PADDLE_ENFORCE(!rpc_service_); PADDLE_ENFORCE(!rpc_service_);
std::string endpoint = Attr<std::string>("endpoint"); std::string endpoint = Attr<std::string>("endpoint");
...@@ -348,11 +404,16 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -348,11 +404,16 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
signal(SIGINT, SignalHandler::StopAndExit); signal(SIGINT, SignalHandler::StopAndExit);
signal(SIGTERM, SignalHandler::StopAndExit); signal(SIGTERM, SignalHandler::StopAndExit);
// Cache the type of the received vars as `sparse_vars_` and `dense_vars_`
// so that we can reset them at the end of each iteration.
// NOTE: only used in sync update
CacheVarsType(inputs, recv_scope);
// Write to a file of server selected port for python use. // Write to a file of server selected port for python use.
SavePort(); SavePort();
if (sync_mode) { if (sync_mode) {
RunSyncLoop(&executor, program, &recv_scope, prefetch_block_id_list, RunSyncLoop(&executor, program, &recv_scope, &dev_ctx,
checkpoint_block_id); prefetch_block_id_list, checkpoint_block_id);
} else { } else {
RunAsyncLoop(&executor, program, &recv_scope); RunAsyncLoop(&executor, program, &recv_scope);
} }
......
...@@ -26,6 +26,7 @@ limitations under the License. */ ...@@ -26,6 +26,7 @@ limitations under the License. */
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/request_handler.h" #include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/distributed/rpc_server.h" #include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -48,6 +49,7 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -48,6 +49,7 @@ class ListenAndServOp : public framework::OperatorBase {
void RunSyncLoop(framework::Executor* executor, void RunSyncLoop(framework::Executor* executor,
framework::ProgramDesc* program, framework::ProgramDesc* program,
framework::Scope* recv_scope, framework::Scope* recv_scope,
platform::DeviceContext* dev_ctx,
const std::vector<int>& prefetch_block_id_list, const std::vector<int>& prefetch_block_id_list,
const int checkpoint_point_block_id) const; const int checkpoint_point_block_id) const;
...@@ -64,6 +66,13 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -64,6 +66,13 @@ class ListenAndServOp : public framework::OperatorBase {
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override; const platform::Place& dev_place) const override;
void ResetReceivedVars(framework::Scope* recv_scope,
platform::DeviceContext* dev_ctx,
bool reset_all = false) const;
void CacheVarsType(const std::vector<std::string>& varnames,
const framework::Scope& scope) const;
protected: protected:
mutable std::shared_ptr<distributed::RPCServer> rpc_service_; mutable std::shared_ptr<distributed::RPCServer> rpc_service_;
mutable std::shared_ptr<distributed::RequestHandler> request_send_handler_; mutable std::shared_ptr<distributed::RequestHandler> request_send_handler_;
...@@ -74,6 +83,8 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -74,6 +83,8 @@ class ListenAndServOp : public framework::OperatorBase {
request_checkpoint_handler_; request_checkpoint_handler_;
mutable std::shared_ptr<std::thread> server_thread_; mutable std::shared_ptr<std::thread> server_thread_;
mutable std::vector<std::string> sparse_vars_;
mutable std::vector<std::string> dense_vars_;
}; };
class SignalHandler { class SignalHandler {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册