提交 6edfae42 编写于 作者: Y Yancey1989

reset received vars on pserver

上级 f76f42c2
...@@ -67,24 +67,11 @@ bool RequestSendHandler::Handle(const std::string& varname, ...@@ -67,24 +67,11 @@ bool RequestSendHandler::Handle(const std::string& varname,
LOG(FATAL) << "sync: Can not find server side var: " << varname; LOG(FATAL) << "sync: Can not find server side var: " << varname;
return false; return false;
} }
if (invar->IsType<framework::SelectedRows>()) {
std::unique_lock<std::mutex> lock(mutex_sparse_vars_);
sparse_vars_.push_back(invar);
}
} }
} }
return true; return true;
} }
// Empties the row index list of every variable currently held in
// sparse_vars_, then drops all recorded pointers. The whole operation runs
// while holding mutex_sparse_vars_, the same mutex Handle() takes when it
// records a SelectedRows variable.
void RequestSendHandler::ResetSparseVarRecorder() {
  std::lock_guard<std::mutex> guard(mutex_sparse_vars_);
  for (auto* recorded : sparse_vars_) {
    auto* selected_rows = recorded->GetMutable<framework::SelectedRows>();
    selected_rows->mutable_rows()->clear();
  }
  sparse_vars_.clear();
}
bool RequestGetHandler::Handle(const std::string& varname, bool RequestGetHandler::Handle(const std::string& varname,
framework::Scope* scope, framework::Scope* scope,
framework::Variable* invar, framework::Variable* invar,
......
...@@ -41,11 +41,6 @@ class RequestSendHandler final : public RequestHandler { ...@@ -41,11 +41,6 @@ class RequestSendHandler final : public RequestHandler {
bool Handle(const std::string& varname, framework::Scope* scope, bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar, framework::Variable* var, framework::Variable** outvar,
const std::string& out_var_name = "") override; const std::string& out_var_name = "") override;
void ResetSparseVarRecorder();
private:
std::mutex mutex_sparse_vars_;
std::vector<framework::Variable*> sparse_vars_;
}; };
class RequestGetHandler final : public RequestHandler { class RequestGetHandler final : public RequestHandler {
......
...@@ -101,6 +101,8 @@ void RPCServer::Complete() { ...@@ -101,6 +101,8 @@ void RPCServer::Complete() {
{ {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
client_num_--; client_num_--;
need_reset_all_vars_ = true;
VLOG(4) << "decrease client_num to: " << client_num_; VLOG(4) << "decrease client_num to: " << client_num_;
if (cur_cond_.load() == rpc_cond_map_[kRequestGet]) { if (cur_cond_.load() == rpc_cond_map_[kRequestGet]) {
barrier_counter_[kRequestGet]--; barrier_counter_[kRequestGet]--;
...@@ -109,6 +111,11 @@ void RPCServer::Complete() { ...@@ -109,6 +111,11 @@ void RPCServer::Complete() {
barrier_cond_.notify_all(); barrier_cond_.notify_all();
} }
// Returns the current value of need_reset_all_vars_, read under mutex_.
// The flag is raised in Complete() and lowered again in
// ResetBarrierCounter().
bool RPCServer::NeedResetAllVars() {
  std::lock_guard<std::mutex> guard(mutex_);
  const bool reset_all = need_reset_all_vars_;
  return reset_all;
}
int RPCServer::GetClientNum() { int RPCServer::GetClientNum() {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
return client_num_; return client_num_;
...@@ -120,6 +127,7 @@ void RPCServer::ResetBarrierCounter() { ...@@ -120,6 +127,7 @@ void RPCServer::ResetBarrierCounter() {
for (auto& t : barrier_counter_) { for (auto& t : barrier_counter_) {
t.second = 0; t.second = 0;
} }
need_reset_all_vars_ = false;
} }
void RPCServer::RegisterRPC(const std::string& rpc_name, void RPCServer::RegisterRPC(const std::string& rpc_name,
......
...@@ -49,7 +49,8 @@ class RPCServer { ...@@ -49,7 +49,8 @@ class RPCServer {
bind_address_(address), bind_address_(address),
exit_flag_(false), exit_flag_(false),
selected_port_(0), selected_port_(0),
client_num_(client_num) {} client_num_(client_num),
need_reset_all_vars_(false) {}
virtual ~RPCServer() {} virtual ~RPCServer() {}
virtual void StartServer() = 0; virtual void StartServer() = 0;
...@@ -86,6 +87,8 @@ class RPCServer { ...@@ -86,6 +87,8 @@ class RPCServer {
void ResetBarrierCounter(); void ResetBarrierCounter();
RPCServerProfiler& Profiler() { return profiler_; } RPCServerProfiler& Profiler() { return profiler_; }
bool NeedResetAllVars();
protected: protected:
virtual void ShutDownImpl() = 0; virtual void ShutDownImpl() = 0;
...@@ -104,6 +107,7 @@ class RPCServer { ...@@ -104,6 +107,7 @@ class RPCServer {
std::atomic<int> exit_flag_; std::atomic<int> exit_flag_;
int selected_port_; int selected_port_;
int client_num_; int client_num_;
bool need_reset_all_vars_;
std::unordered_map<std::string, RequestHandler*> rpc_call_map_; std::unordered_map<std::string, RequestHandler*> rpc_call_map_;
std::unordered_map<std::string, int> rpc_thread_num_; std::unordered_map<std::string, int> rpc_thread_num_;
......
...@@ -22,6 +22,7 @@ limitations under the License. */ ...@@ -22,6 +22,7 @@ limitations under the License. */
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "paddle/fluid/operators/detail/macros.h" #include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/listen_and_serv_op.h" #include "paddle/fluid/operators/listen_and_serv_op.h"
...@@ -101,9 +102,10 @@ static int64_t GetTimestamp() { ...@@ -101,9 +102,10 @@ static int64_t GetTimestamp() {
void ListenAndServOp::RunSyncLoop( void ListenAndServOp::RunSyncLoop(
framework::Executor *executor, framework::ProgramDesc *program, framework::Executor *executor, framework::ProgramDesc *program,
framework::Scope *recv_scope, framework::Scope *recv_scope, platform::DeviceContext *dev_ctx,
const std::vector<int> &prefetch_block_id_list, const std::vector<int> &prefetch_block_id_list,
const int checkpoint_point_block_id) const { const int checkpoint_point_block_id,
const std::vector<std::string> &recv_varnames) const {
VLOG(2) << "RunSyncLoop"; VLOG(2) << "RunSyncLoop";
size_t num_blocks = program->Size(); size_t num_blocks = program->Size();
auto optimize_blocks = auto optimize_blocks =
...@@ -166,8 +168,8 @@ void ListenAndServOp::RunSyncLoop( ...@@ -166,8 +168,8 @@ void ListenAndServOp::RunSyncLoop(
VLOG(2) << "run all blocks spent " << GetTimestamp() - ts << "(ms)"; VLOG(2) << "run all blocks spent " << GetTimestamp() - ts << "(ms)";
// reset received sparse vars to avoid reuse it in the next mini-batch // reset received sparse vars to avoid reuse it in the next mini-batch
dynamic_cast<distributed::RequestSendHandler *>(request_send_handler_.get()) ResetReceivedVars(recv_varnames, recv_scope, dev_ctx,
->ResetSparseVarRecorder(); !rpc_service_->NeedResetAllVars());
rpc_service_->SetCond(distributed::kRequestGet); rpc_service_->SetCond(distributed::kRequestGet);
rpc_service_->WaitBarrier(distributed::kRequestGet); rpc_service_->WaitBarrier(distributed::kRequestGet);
...@@ -175,6 +177,33 @@ void ListenAndServOp::RunSyncLoop( ...@@ -175,6 +177,33 @@ void ListenAndServOp::RunSyncLoop(
} // while(true) } // while(true)
} }
// Resets the variables named in recv_varnames inside recv_scope.
//
// For every name that resolves in recv_scope:
//   - SelectedRows: the row index list is cleared (always).
//   - LoDTensor / Tensor: filled with zero through math::set_constant, but
//     only when only_sparse_vars is false.
// Names that cannot be found in the scope are logged at VLOG(2) and skipped.
//
// Throws via PADDLE_THROW when only_sparse_vars is false and a variable is
// none of SelectedRows, LoDTensor or Tensor.
void ListenAndServOp::ResetReceivedVars(
    const std::vector<std::string> &recv_varnames, framework::Scope *recv_scope,
    platform::DeviceContext *dev_ctx, bool only_sparse_vars) const {
  for (const auto &varname : recv_varnames) {
    auto *var = recv_scope->FindVar(varname);
    if (var == nullptr) {
      VLOG(2) << "can not find var " << varname << " in received scope";
      continue;
    }
    if (var->IsType<framework::SelectedRows>()) {
      var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
      // A sparse var is fully handled here. Without this early exit it would
      // fall into the dense type check below and be rejected as an
      // unsupported type whenever only_sparse_vars is false.
      continue;
    }
    if (only_sparse_vars) {
      continue;
    }
    if (var->IsType<framework::LoDTensor>()) {
      math::set_constant(*dev_ctx, var->GetMutable<framework::LoDTensor>(),
                         static_cast<float>(0));
    } else if (var->IsType<framework::Tensor>()) {
      math::set_constant(*dev_ctx, var->GetMutable<framework::Tensor>(),
                         static_cast<float>(0));
    } else {
      PADDLE_THROW(
          "received var should be in [SelectedRows, LoDTensor, Tensor]");
    }
  }
}
void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
framework::ProgramDesc *program, framework::ProgramDesc *program,
framework::Scope *recv_scope) const { framework::Scope *recv_scope) const {
...@@ -258,6 +287,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -258,6 +287,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
bool sync_mode = Attr<bool>("sync_mode"); bool sync_mode = Attr<bool>("sync_mode");
auto fan_in = Attr<int>("Fanin"); auto fan_in = Attr<int>("Fanin");
auto inputs = Inputs("X");
PADDLE_ENFORCE(!rpc_service_); PADDLE_ENFORCE(!rpc_service_);
std::string endpoint = Attr<std::string>("endpoint"); std::string endpoint = Attr<std::string>("endpoint");
...@@ -351,8 +381,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -351,8 +381,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
// Write to a file of server selected port for python use. // Write to a file of server selected port for python use.
SavePort(); SavePort();
if (sync_mode) { if (sync_mode) {
RunSyncLoop(&executor, program, &recv_scope, prefetch_block_id_list, RunSyncLoop(&executor, program, &recv_scope, &dev_ctx,
checkpoint_block_id); prefetch_block_id_list, checkpoint_block_id, inputs);
} else { } else {
RunAsyncLoop(&executor, program, &recv_scope); RunAsyncLoop(&executor, program, &recv_scope);
} }
......
...@@ -26,6 +26,7 @@ limitations under the License. */ ...@@ -26,6 +26,7 @@ limitations under the License. */
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/request_handler.h" #include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/distributed/rpc_server.h" #include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -48,8 +49,10 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -48,8 +49,10 @@ class ListenAndServOp : public framework::OperatorBase {
void RunSyncLoop(framework::Executor* executor, void RunSyncLoop(framework::Executor* executor,
framework::ProgramDesc* program, framework::ProgramDesc* program,
framework::Scope* recv_scope, framework::Scope* recv_scope,
platform::DeviceContext* dev_ctx,
const std::vector<int>& prefetch_block_id_list, const std::vector<int>& prefetch_block_id_list,
const int checkpoint_point_block_id) const; const int checkpoint_point_block_id,
const std::vector<std::string>& recv_varnames) const;
void RunAsyncLoop(framework::Executor* executor, void RunAsyncLoop(framework::Executor* executor,
framework::ProgramDesc* program, framework::ProgramDesc* program,
...@@ -64,6 +67,11 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -64,6 +67,11 @@ class ListenAndServOp : public framework::OperatorBase {
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override; const platform::Place& dev_place) const override;
void ResetReceivedVars(const std::vector<std::string>& recv_varnames,
framework::Scope* recv_scope,
platform::DeviceContext* dev_ctx,
bool only_sparse_vars = true) const;
protected: protected:
mutable std::shared_ptr<distributed::RPCServer> rpc_service_; mutable std::shared_ptr<distributed::RPCServer> rpc_service_;
mutable std::shared_ptr<distributed::RequestHandler> request_send_handler_; mutable std::shared_ptr<distributed::RequestHandler> request_send_handler_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册