提交 32b94a7d 编写于 作者: Y Yancey1989

cache var types

上级 e5a93539
......@@ -104,8 +104,7 @@ void ListenAndServOp::RunSyncLoop(
framework::Executor *executor, framework::ProgramDesc *program,
framework::Scope *recv_scope, platform::DeviceContext *dev_ctx,
const std::vector<int> &prefetch_block_id_list,
const int checkpoint_point_block_id,
const std::vector<std::string> &recv_varnames) const {
const int checkpoint_point_block_id) const {
VLOG(2) << "RunSyncLoop";
size_t num_blocks = program->Size();
auto optimize_blocks =
......@@ -130,6 +129,7 @@ void ListenAndServOp::RunSyncLoop(
rpc_service_->SetCond(distributed::kRequestGet);
rpc_service_->WaitBarrier(distributed::kRequestGet);
rpc_service_->ResetBarrierCounter();
while (true) {
rpc_service_->Profiler().OneStep();
// Get from multiple trainers, we don't care about the order in which
......@@ -167,8 +167,7 @@ void ListenAndServOp::RunSyncLoop(
recv_scope);
VLOG(2) << "run all blocks spent " << GetTimestamp() - ts << "(ms)";
ResetReceivedVars(recv_varnames, recv_scope, dev_ctx,
rpc_service_->NeedResetAllVars());
ResetReceivedVars(recv_scope, dev_ctx, rpc_service_->NeedResetAllVars());
rpc_service_->SetCond(distributed::kRequestGet);
rpc_service_->WaitBarrier(distributed::kRequestGet);
......@@ -176,10 +175,10 @@ void ListenAndServOp::RunSyncLoop(
} // while(true)
}
void ListenAndServOp::ResetReceivedVars(
const std::vector<std::string> &recv_varnames, framework::Scope *recv_scope,
platform::DeviceContext *dev_ctx, bool reset_all) const {
for (auto &varname : recv_varnames) {
void ListenAndServOp::ResetReceivedVars(framework::Scope *recv_scope,
platform::DeviceContext *dev_ctx,
bool reset_all) const {
for (auto &varname : sparse_vars_) {
auto var = recv_scope->FindVar(varname);
if (var == nullptr) {
VLOG(2) << "can not find var " << varname << " in received scope";
......@@ -188,9 +187,17 @@ void ListenAndServOp::ResetReceivedVars(
if (var->IsType<framework::SelectedRows>()) {
VLOG(3) << "reset sparse var: " << varname;
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
} else {
PADDLE_THROW("The type of sparse var should be SelectedRows");
}
}
if (UNLIKELY(reset_all)) {
VLOG(3) << "reset dense var: " << varname;
for (auto &varname : dense_vars_) {
auto var = recv_scope->FindVar(varname);
if (var == nullptr) {
VLOG(2) << "can not find var " << varname << " in received scope";
continue;
}
if (var->IsType<framework::LoDTensor>()) {
math::set_constant(*dev_ctx, var->GetMutable<framework::LoDTensor>(),
static_cast<float>(0));
......@@ -198,8 +205,7 @@ void ListenAndServOp::ResetReceivedVars(
math::set_constant(*dev_ctx, var->GetMutable<framework::Tensor>(),
static_cast<float>(0));
} else {
PADDLE_THROW(
"received var should be in [SelectedRows, LoDTensor, Tensor]");
PADDLE_THROW("The type of dense var should be in [LoDTensor, Tensor]");
}
}
}
......@@ -278,6 +284,25 @@ static void FillRequestCtx(
h->SetCheckpointNotifyPreparedCtx(checkpoint_ctx);
}
void ListenAndServOp::CacheVarsType(const std::vector<std::string> &varnames,
const framework::Scope &scope) const {
for (const auto &varname : varnames) {
auto var = scope.FindVar(varname);
PADDLE_ENFORCE(var != nullptr,
"Received var should be initialized in the received scope.");
if (var->IsType<framework::SelectedRows>()) {
sparse_vars_.push_back(varname);
} else if (var->IsType<framework::LoDTensor>() ||
var->IsType<framework::Tensor>()) {
dense_vars_.push_back(varname);
} else {
PADDLE_THROW(
"The type of received var should be in [SelectedRows, LoDTensor, "
"Tensor].");
}
}
}
void ListenAndServOp::RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const {
// Mark this as PS that it should decide profiling by listening from trainer.
......@@ -379,11 +404,16 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
signal(SIGINT, SignalHandler::StopAndExit);
signal(SIGTERM, SignalHandler::StopAndExit);
// Cache the type of the received vars as `sparse_vars_` and `dense_vars_`
// so that we can reset them at the end of each iteration.
// NOTE: only used in sync update
CacheVarsType(inputs, recv_scope);
// Write to a file of server selected port for python use.
SavePort();
if (sync_mode) {
RunSyncLoop(&executor, program, &recv_scope, &dev_ctx,
prefetch_block_id_list, checkpoint_block_id, inputs);
prefetch_block_id_list, checkpoint_block_id);
} else {
RunAsyncLoop(&executor, program, &recv_scope);
}
......
......@@ -51,8 +51,7 @@ class ListenAndServOp : public framework::OperatorBase {
framework::Scope* recv_scope,
platform::DeviceContext* dev_ctx,
const std::vector<int>& prefetch_block_id_list,
const int checkpoint_point_block_id,
const std::vector<std::string>& recv_varnames) const;
const int checkpoint_point_block_id) const;
void RunAsyncLoop(framework::Executor* executor,
framework::ProgramDesc* program,
......@@ -67,11 +66,13 @@ class ListenAndServOp : public framework::OperatorBase {
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override;
void ResetReceivedVars(const std::vector<std::string>& recv_varnames,
framework::Scope* recv_scope,
void ResetReceivedVars(framework::Scope* recv_scope,
platform::DeviceContext* dev_ctx,
bool reset_all = false) const;
void CacheVarsType(const std::vector<std::string>& varnames,
const framework::Scope& scope) const;
protected:
mutable std::shared_ptr<distributed::RPCServer> rpc_service_;
mutable std::shared_ptr<distributed::RequestHandler> request_send_handler_;
......@@ -82,6 +83,8 @@ class ListenAndServOp : public framework::OperatorBase {
request_checkpoint_handler_;
mutable std::shared_ptr<std::thread> server_thread_;
mutable std::vector<std::string> sparse_vars_;
mutable std::vector<std::string> dense_vars_;
};
class SignalHandler {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册