提交 77c6b71e 编写于 作者: T tangwei12

Add checkpoint (ckpt) to the sync loop

上级 1fabbbad
...@@ -101,6 +101,7 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor, ...@@ -101,6 +101,7 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
framework::Scope *recv_scope, framework::Scope *recv_scope,
framework::BlockDesc *prefetch_block) const { framework::BlockDesc *prefetch_block) const {
auto fan_in = Attr<int>("Fanin"); auto fan_in = Attr<int>("Fanin");
auto checkpoint = Attr<std::string>("Checkpoint");
size_t num_blocks = program->Size(); size_t num_blocks = program->Size();
PADDLE_ENFORCE_GE(num_blocks, 2, PADDLE_ENFORCE_GE(num_blocks, 2,
...@@ -188,6 +189,18 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor, ...@@ -188,6 +189,18 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
for (auto &var : sparse_vars) { for (auto &var : sparse_vars) {
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear(); var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
} }
/******************** CHECK POINT ***********************/
std::vector<std::string> all_vars = recv_scope.LocalVarNames();
std::vector<std::string>::iterator it;
for (it = all_vars.begin(); it != all_vars.end(); it++) {
VLOG(2) << "Checkpoint Var: " << *it;
break;
}
/******************** CHECK POINT ***********************/
rpc_service_->SetCond(1); rpc_service_->SetCond(1);
// FIXME(typhoonzero): use another condition to sync wait clients get. // FIXME(typhoonzero): use another condition to sync wait clients get.
rpc_service_->WaitClientGet(fan_in); rpc_service_->WaitClientGet(fan_in);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册