未验证 提交 d117bbc3 编写于 作者: Y Yan Xu 提交者: GitHub

Merge pull request #13291 from Yancey1989/reset_vars_on_pserver

reset received vars on pserver
......@@ -67,24 +67,11 @@ bool RequestSendHandler::Handle(const std::string& varname,
LOG(FATAL) << "sync: Can not find server side var: " << varname;
return false;
}
if (invar->IsType<framework::SelectedRows>()) {
std::unique_lock<std::mutex> lock(mutex_sparse_vars_);
sparse_vars_.push_back(invar);
}
}
}
return true;
}
// Clears the row index list of every SelectedRows variable that has been
// recorded in sparse_vars_, then empties the record itself. The whole
// operation runs while holding mutex_sparse_vars_.
void RequestSendHandler::ResetSparseVarRecorder() {
  std::unique_lock<std::mutex> guard(mutex_sparse_vars_);
  for (auto it = sparse_vars_.begin(); it != sparse_vars_.end(); ++it) {
    auto* selected_rows = (*it)->GetMutable<framework::SelectedRows>();
    selected_rows->mutable_rows()->clear();
  }
  sparse_vars_.clear();
}
bool RequestGetHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
......
......@@ -41,11 +41,6 @@ class RequestSendHandler final : public RequestHandler {
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const std::string& out_var_name = "") override;
void ResetSparseVarRecorder();
private:
std::mutex mutex_sparse_vars_;
std::vector<framework::Variable*> sparse_vars_;
};
class RequestGetHandler final : public RequestHandler {
......
......@@ -101,6 +101,8 @@ void RPCServer::Complete() {
{
std::unique_lock<std::mutex> lock(mutex_);
client_num_--;
need_reset_all_vars_ = true;
VLOG(4) << "decrease client_num to: " << client_num_;
if (cur_cond_.load() == rpc_cond_map_[kRequestGet]) {
barrier_counter_[kRequestGet]--;
......@@ -109,6 +111,11 @@ void RPCServer::Complete() {
barrier_cond_.notify_all();
}
// Returns the current value of need_reset_all_vars_. The flag is read while
// holding mutex_.
bool RPCServer::NeedResetAllVars() {
  std::unique_lock<std::mutex> guard(mutex_);
  const bool reset_needed = need_reset_all_vars_;
  return reset_needed;
}
int RPCServer::GetClientNum() {
std::unique_lock<std::mutex> lock(mutex_);
return client_num_;
......@@ -120,6 +127,7 @@ void RPCServer::ResetBarrierCounter() {
for (auto& t : barrier_counter_) {
t.second = 0;
}
need_reset_all_vars_ = false;
}
void RPCServer::RegisterRPC(const std::string& rpc_name,
......
......@@ -49,7 +49,8 @@ class RPCServer {
bind_address_(address),
exit_flag_(false),
selected_port_(0),
client_num_(client_num) {}
client_num_(client_num),
need_reset_all_vars_(false) {}
virtual ~RPCServer() {}
virtual void StartServer() = 0;
......@@ -86,6 +87,8 @@ class RPCServer {
void ResetBarrierCounter();
RPCServerProfiler& Profiler() { return profiler_; }
bool NeedResetAllVars();
protected:
virtual void ShutDownImpl() = 0;
......@@ -104,6 +107,7 @@ class RPCServer {
std::atomic<int> exit_flag_;
int selected_port_;
int client_num_;
bool need_reset_all_vars_;
std::unordered_map<std::string, RequestHandler*> rpc_call_map_;
std::unordered_map<std::string, int> rpc_thread_num_;
......
......@@ -22,6 +22,7 @@ limitations under the License. */
#include "gflags/gflags.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/listen_and_serv_op.h"
......@@ -101,7 +102,7 @@ static int64_t GetTimestamp() {
void ListenAndServOp::RunSyncLoop(
framework::Executor *executor, framework::ProgramDesc *program,
framework::Scope *recv_scope,
framework::Scope *recv_scope, platform::DeviceContext *dev_ctx,
const std::vector<int> &prefetch_block_id_list,
const int checkpoint_point_block_id) const {
VLOG(2) << "RunSyncLoop";
......@@ -128,6 +129,7 @@ void ListenAndServOp::RunSyncLoop(
rpc_service_->SetCond(distributed::kRequestGet);
rpc_service_->WaitBarrier(distributed::kRequestGet);
rpc_service_->ResetBarrierCounter();
while (true) {
rpc_service_->Profiler().OneStep();
// Get from multiple trainers, we don't care about the order in which
......@@ -165,9 +167,7 @@ void ListenAndServOp::RunSyncLoop(
recv_scope);
VLOG(2) << "run all blocks spent " << GetTimestamp() - ts << "(ms)";
// reset received sparse vars to avoid reusing them in the next mini-batch
dynamic_cast<distributed::RequestSendHandler *>(request_send_handler_.get())
->ResetSparseVarRecorder();
ResetReceivedVars(recv_scope, dev_ctx, rpc_service_->NeedResetAllVars());
rpc_service_->SetCond(distributed::kRequestGet);
rpc_service_->WaitBarrier(distributed::kRequestGet);
......@@ -175,6 +175,42 @@ void ListenAndServOp::RunSyncLoop(
} // while(true)
}
// Resets the variables cached in sparse_vars_ / dense_vars_ inside
// `recv_scope`.
//
// - Every name in sparse_vars_ is looked up in `recv_scope`; a found variable
//   must hold a SelectedRows, whose row list is cleared. Any other type
//   raises via PADDLE_THROW.
// - When `reset_all` is true, every name in dense_vars_ is looked up as well
//   and its LoDTensor / Tensor is filled with 0 through math::set_constant on
//   `*dev_ctx`. Any other type raises via PADDLE_THROW.
//
// Names that cannot be found in `recv_scope` are logged and skipped.
void ListenAndServOp::ResetReceivedVars(framework::Scope *recv_scope,
                                        platform::DeviceContext *dev_ctx,
                                        bool reset_all) const {
  for (const auto &sparse_name : sparse_vars_) {
    auto *sparse_var = recv_scope->FindVar(sparse_name);
    if (sparse_var == nullptr) {
      VLOG(2) << "can not find var " << sparse_name << " in received scope";
      continue;
    }
    // Anything other than SelectedRows is an error for a sparse var.
    if (!sparse_var->IsType<framework::SelectedRows>()) {
      PADDLE_THROW("The type of sparse var should be SelectedRows");
    }
    VLOG(3) << "reset sparse var: " << sparse_name;
    auto *selected_rows = sparse_var->GetMutable<framework::SelectedRows>();
    selected_rows->mutable_rows()->clear();
  }

  // Dense vars are only touched when the caller asks for a full reset.
  if (UNLIKELY(reset_all)) {
    for (const auto &dense_name : dense_vars_) {
      auto *dense_var = recv_scope->FindVar(dense_name);
      if (dense_var == nullptr) {
        VLOG(2) << "can not find var " << dense_name << " in received scope";
        continue;
      }
      if (dense_var->IsType<framework::LoDTensor>()) {
        auto *lod_tensor = dense_var->GetMutable<framework::LoDTensor>();
        math::set_constant(*dev_ctx, lod_tensor, 0.0f);
        continue;
      }
      if (dense_var->IsType<framework::Tensor>()) {
        auto *tensor = dense_var->GetMutable<framework::Tensor>();
        math::set_constant(*dev_ctx, tensor, 0.0f);
        continue;
      }
      PADDLE_THROW("The type of dense var should be in [LoDTensor, Tensor]");
    }
  }
}
void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
framework::ProgramDesc *program,
framework::Scope *recv_scope) const {
......@@ -248,6 +284,25 @@ static void FillRequestCtx(
h->SetCheckpointNotifyPreparedCtx(checkpoint_ctx);
}
// Sorts each name in `varnames` into sparse_vars_ or dense_vars_ according to
// the type of the variable found under that name in `scope`:
//   - SelectedRows        -> appended to sparse_vars_
//   - LoDTensor or Tensor -> appended to dense_vars_
// A name missing from `scope` fails the PADDLE_ENFORCE check; a variable of
// any other type raises via PADDLE_THROW.
void ListenAndServOp::CacheVarsType(const std::vector<std::string> &varnames,
                                    const framework::Scope &scope) const {
  for (const auto &name : varnames) {
    auto *received = scope.FindVar(name);
    PADDLE_ENFORCE(received != nullptr,
                   "Received var should be initialized in the received scope.");

    if (received->IsType<framework::SelectedRows>()) {
      sparse_vars_.push_back(name);
      continue;
    }
    if (received->IsType<framework::LoDTensor>() ||
        received->IsType<framework::Tensor>()) {
      dense_vars_.push_back(name);
      continue;
    }
    // Neither sparse nor dense: unsupported variable type.
    PADDLE_THROW(
        "The type of received var should be in [SelectedRows, LoDTensor, "
        "Tensor].");
  }
}
void ListenAndServOp::RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const {
// Mark this as PS that it should decide profiling by listening from trainer.
......@@ -258,6 +313,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
bool sync_mode = Attr<bool>("sync_mode");
auto fan_in = Attr<int>("Fanin");
auto inputs = Inputs("X");
PADDLE_ENFORCE(!rpc_service_);
std::string endpoint = Attr<std::string>("endpoint");
......@@ -348,11 +404,16 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
signal(SIGINT, SignalHandler::StopAndExit);
signal(SIGTERM, SignalHandler::StopAndExit);
// Cache the type of the received vars as `sparse_vars_` and `dense_vars_`
// so that we can reset them at the end of each iteration.
// NOTE: only used in sync update
CacheVarsType(inputs, recv_scope);
// Write the server's selected port to a file for Python to read.
SavePort();
if (sync_mode) {
RunSyncLoop(&executor, program, &recv_scope, prefetch_block_id_list,
checkpoint_block_id);
RunSyncLoop(&executor, program, &recv_scope, &dev_ctx,
prefetch_block_id_list, checkpoint_block_id);
} else {
RunAsyncLoop(&executor, program, &recv_scope);
}
......
......@@ -26,6 +26,7 @@ limitations under the License. */
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace operators {
......@@ -48,6 +49,7 @@ class ListenAndServOp : public framework::OperatorBase {
void RunSyncLoop(framework::Executor* executor,
framework::ProgramDesc* program,
framework::Scope* recv_scope,
platform::DeviceContext* dev_ctx,
const std::vector<int>& prefetch_block_id_list,
const int checkpoint_point_block_id) const;
......@@ -64,6 +66,13 @@ class ListenAndServOp : public framework::OperatorBase {
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override;
void ResetReceivedVars(framework::Scope* recv_scope,
platform::DeviceContext* dev_ctx,
bool reset_all = false) const;
void CacheVarsType(const std::vector<std::string>& varnames,
const framework::Scope& scope) const;
protected:
mutable std::shared_ptr<distributed::RPCServer> rpc_service_;
mutable std::shared_ptr<distributed::RequestHandler> request_send_handler_;
......@@ -74,6 +83,8 @@ class ListenAndServOp : public framework::OperatorBase {
request_checkpoint_handler_;
mutable std::shared_ptr<std::thread> server_thread_;
mutable std::vector<std::string> sparse_vars_;
mutable std::vector<std::string> dense_vars_;
};
class SignalHandler {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册