提交 32b94a7d 编写于 作者: Y Yancey1989

cache var types

上级 e5a93539
...@@ -104,8 +104,7 @@ void ListenAndServOp::RunSyncLoop( ...@@ -104,8 +104,7 @@ void ListenAndServOp::RunSyncLoop(
framework::Executor *executor, framework::ProgramDesc *program, framework::Executor *executor, framework::ProgramDesc *program,
framework::Scope *recv_scope, platform::DeviceContext *dev_ctx, framework::Scope *recv_scope, platform::DeviceContext *dev_ctx,
const std::vector<int> &prefetch_block_id_list, const std::vector<int> &prefetch_block_id_list,
const int checkpoint_point_block_id, const int checkpoint_point_block_id) const {
const std::vector<std::string> &recv_varnames) const {
VLOG(2) << "RunSyncLoop"; VLOG(2) << "RunSyncLoop";
size_t num_blocks = program->Size(); size_t num_blocks = program->Size();
auto optimize_blocks = auto optimize_blocks =
...@@ -130,6 +129,7 @@ void ListenAndServOp::RunSyncLoop( ...@@ -130,6 +129,7 @@ void ListenAndServOp::RunSyncLoop(
rpc_service_->SetCond(distributed::kRequestGet); rpc_service_->SetCond(distributed::kRequestGet);
rpc_service_->WaitBarrier(distributed::kRequestGet); rpc_service_->WaitBarrier(distributed::kRequestGet);
rpc_service_->ResetBarrierCounter(); rpc_service_->ResetBarrierCounter();
while (true) { while (true) {
rpc_service_->Profiler().OneStep(); rpc_service_->Profiler().OneStep();
// Get from multiple trainers, we don't care about the order in which // Get from multiple trainers, we don't care about the order in which
...@@ -167,8 +167,7 @@ void ListenAndServOp::RunSyncLoop( ...@@ -167,8 +167,7 @@ void ListenAndServOp::RunSyncLoop(
recv_scope); recv_scope);
VLOG(2) << "run all blocks spent " << GetTimestamp() - ts << "(ms)"; VLOG(2) << "run all blocks spent " << GetTimestamp() - ts << "(ms)";
ResetReceivedVars(recv_varnames, recv_scope, dev_ctx, ResetReceivedVars(recv_scope, dev_ctx, rpc_service_->NeedResetAllVars());
rpc_service_->NeedResetAllVars());
rpc_service_->SetCond(distributed::kRequestGet); rpc_service_->SetCond(distributed::kRequestGet);
rpc_service_->WaitBarrier(distributed::kRequestGet); rpc_service_->WaitBarrier(distributed::kRequestGet);
...@@ -176,10 +175,10 @@ void ListenAndServOp::RunSyncLoop( ...@@ -176,10 +175,10 @@ void ListenAndServOp::RunSyncLoop(
} // while(true) } // while(true)
} }
void ListenAndServOp::ResetReceivedVars( void ListenAndServOp::ResetReceivedVars(framework::Scope *recv_scope,
const std::vector<std::string> &recv_varnames, framework::Scope *recv_scope, platform::DeviceContext *dev_ctx,
platform::DeviceContext *dev_ctx, bool reset_all) const { bool reset_all) const {
for (auto &varname : recv_varnames) { for (auto &varname : sparse_vars_) {
auto var = recv_scope->FindVar(varname); auto var = recv_scope->FindVar(varname);
if (var == nullptr) { if (var == nullptr) {
VLOG(2) << "can not find var " << varname << " in received scope"; VLOG(2) << "can not find var " << varname << " in received scope";
...@@ -188,9 +187,17 @@ void ListenAndServOp::ResetReceivedVars( ...@@ -188,9 +187,17 @@ void ListenAndServOp::ResetReceivedVars(
if (var->IsType<framework::SelectedRows>()) { if (var->IsType<framework::SelectedRows>()) {
VLOG(3) << "reset sparse var: " << varname; VLOG(3) << "reset sparse var: " << varname;
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear(); var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
} else {
PADDLE_THROW("The type of sparse var should be SelectedRows");
}
} }
if (UNLIKELY(reset_all)) { if (UNLIKELY(reset_all)) {
VLOG(3) << "reset dense var: " << varname; for (auto &varname : dense_vars_) {
auto var = recv_scope->FindVar(varname);
if (var == nullptr) {
VLOG(2) << "can not find var " << varname << " in received scope";
continue;
}
if (var->IsType<framework::LoDTensor>()) { if (var->IsType<framework::LoDTensor>()) {
math::set_constant(*dev_ctx, var->GetMutable<framework::LoDTensor>(), math::set_constant(*dev_ctx, var->GetMutable<framework::LoDTensor>(),
static_cast<float>(0)); static_cast<float>(0));
...@@ -198,8 +205,7 @@ void ListenAndServOp::ResetReceivedVars( ...@@ -198,8 +205,7 @@ void ListenAndServOp::ResetReceivedVars(
math::set_constant(*dev_ctx, var->GetMutable<framework::Tensor>(), math::set_constant(*dev_ctx, var->GetMutable<framework::Tensor>(),
static_cast<float>(0)); static_cast<float>(0));
} else { } else {
PADDLE_THROW( PADDLE_THROW("The type of dense var should be in [LoDTensor, Tensor]");
"received var should be in [SelectedRows, LoDTensor, Tensor]");
} }
} }
} }
...@@ -278,6 +284,25 @@ static void FillRequestCtx( ...@@ -278,6 +284,25 @@ static void FillRequestCtx(
h->SetCheckpointNotifyPreparedCtx(checkpoint_ctx); h->SetCheckpointNotifyPreparedCtx(checkpoint_ctx);
} }
void ListenAndServOp::CacheVarsType(const std::vector<std::string> &varnames,
const framework::Scope &scope) const {
for (const auto &varname : varnames) {
auto var = scope.FindVar(varname);
PADDLE_ENFORCE(var != nullptr,
"Received var should be initialized in the received scope.");
if (var->IsType<framework::SelectedRows>()) {
sparse_vars_.push_back(varname);
} else if (var->IsType<framework::LoDTensor>() ||
var->IsType<framework::Tensor>()) {
dense_vars_.push_back(varname);
} else {
PADDLE_THROW(
"The type of received var should be in [SelectedRows, LoDTensor, "
"Tensor].");
}
}
}
void ListenAndServOp::RunImpl(const framework::Scope &scope, void ListenAndServOp::RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const { const platform::Place &dev_place) const {
// Mark this as PS that it should decide profiling by listening from trainer. // Mark this as PS that it should decide profiling by listening from trainer.
...@@ -379,11 +404,16 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -379,11 +404,16 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
signal(SIGINT, SignalHandler::StopAndExit); signal(SIGINT, SignalHandler::StopAndExit);
signal(SIGTERM, SignalHandler::StopAndExit); signal(SIGTERM, SignalHandler::StopAndExit);
// Cache the type of the received vars as `sparse_vars_` and `dense_vars_`
// so that we can reset them at the end of each iteration.
// NOTE: only used in sync update
CacheVarsType(inputs, recv_scope);
// Write to a file of server selected port for python use. // Write to a file of server selected port for python use.
SavePort(); SavePort();
if (sync_mode) { if (sync_mode) {
RunSyncLoop(&executor, program, &recv_scope, &dev_ctx, RunSyncLoop(&executor, program, &recv_scope, &dev_ctx,
prefetch_block_id_list, checkpoint_block_id, inputs); prefetch_block_id_list, checkpoint_block_id);
} else { } else {
RunAsyncLoop(&executor, program, &recv_scope); RunAsyncLoop(&executor, program, &recv_scope);
} }
......
...@@ -51,8 +51,7 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -51,8 +51,7 @@ class ListenAndServOp : public framework::OperatorBase {
framework::Scope* recv_scope, framework::Scope* recv_scope,
platform::DeviceContext* dev_ctx, platform::DeviceContext* dev_ctx,
const std::vector<int>& prefetch_block_id_list, const std::vector<int>& prefetch_block_id_list,
const int checkpoint_point_block_id, const int checkpoint_point_block_id) const;
const std::vector<std::string>& recv_varnames) const;
void RunAsyncLoop(framework::Executor* executor, void RunAsyncLoop(framework::Executor* executor,
framework::ProgramDesc* program, framework::ProgramDesc* program,
...@@ -67,11 +66,13 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -67,11 +66,13 @@ class ListenAndServOp : public framework::OperatorBase {
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override; const platform::Place& dev_place) const override;
void ResetReceivedVars(const std::vector<std::string>& recv_varnames, void ResetReceivedVars(framework::Scope* recv_scope,
framework::Scope* recv_scope,
platform::DeviceContext* dev_ctx, platform::DeviceContext* dev_ctx,
bool reset_all = false) const; bool reset_all = false) const;
void CacheVarsType(const std::vector<std::string>& varnames,
const framework::Scope& scope) const;
protected: protected:
mutable std::shared_ptr<distributed::RPCServer> rpc_service_; mutable std::shared_ptr<distributed::RPCServer> rpc_service_;
mutable std::shared_ptr<distributed::RequestHandler> request_send_handler_; mutable std::shared_ptr<distributed::RequestHandler> request_send_handler_;
...@@ -82,6 +83,8 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -82,6 +83,8 @@ class ListenAndServOp : public framework::OperatorBase {
request_checkpoint_handler_; request_checkpoint_handler_;
mutable std::shared_ptr<std::thread> server_thread_; mutable std::shared_ptr<std::thread> server_thread_;
mutable std::vector<std::string> sparse_vars_;
mutable std::vector<std::string> dense_vars_;
}; };
class SignalHandler { class SignalHandler {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册