提交 d93dc81c 编写于 作者: T tangwei12

add handle when checkpoint_notify_id = -1

上级 1571c25a
......@@ -125,11 +125,15 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
framework::Variable* invar,
framework::Variable** outvar,
const std::string& out_var_name) {
PADDLE_ENFORCE(
checkpoint_notify_id != -1,
"when checkpoint_notify_id = -1, there should be no RPC invoke.");
auto *lt_var = scope->FindVar("loopup_table_path")->GetMutable<std::string>();
auto* lt_var = scope->FindVar("loopup_table_path")->GetMutable<std::string>();
lt_var->clear();
lt_var->append(out_var_name);
VLOG(4) << "RequestCheckpointHandler update loopup_table_path to: " << out_var_name;
VLOG(4) << "RequestCheckpointHandler update loopup_table_path to: "
<< out_var_name;
executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope);
return true;
}
......
......@@ -68,12 +68,17 @@ class RequestPrefetchHandler final : public RequestHandler {
class RequestCheckpointHandler final : public RequestHandler {
public:
explicit RequestCheckpointHandler(bool sync_mode)
: RequestHandler(sync_mode) {}
explicit RequestCheckpointHandler(bool sync_mode, int checkpoint_notify_id)
: RequestHandler(sync_mode) {
this.checkpoint_notify_id = checkpoint_notify_id;
}
virtual ~RequestCheckpointHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const std::string& out_var_name = "") override;
private:
int checkpoint_notify_id;
};
} // namespace detail
......
......@@ -247,9 +247,11 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
PADDLE_ENFORCE(!rpc_service_);
std::string endpoint = Attr<std::string>("endpoint");
int checkpoint_point_block_id = Attr<int>(kCheckpointBlockId);
LOG(INFO) << "sync_mode:" << sync_mode << ", fan_in:" << fan_in
<< ", end_point:" << endpoint;
<< ", end_point:" << endpoint
<< ", CheckpointNotify Id: " << checkpoint_notify_id;
rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in));
......@@ -258,7 +260,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
request_prefetch_handler_.reset(
new detail::RequestPrefetchHandler(sync_mode));
request_checkpoint_handler_.reset(
new detail::RequestCheckpointHandler(sync_mode));
new detail::RequestCheckpointHandler(sync_mode, checkpoint_notify_id));
rpc_service_->RegisterRPC(detail::kRequestSend, request_send_handler_.get());
rpc_service_->RegisterRPC(detail::kRequestGet, request_get_handler_.get());
......@@ -267,6 +269,12 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
rpc_service_->RegisterRPC(detail::kRequestCheckpoint,
request_checkpoint_handler_.get());
std::shared_ptr<framework::ExecutorPrepareContext> ckpt_pre_context = nullptr;
if (checkpoint_notify_id != -1) {
auto ctx = executor.Prepare(*program, checkpoint_point_block_id);
ckpt_pre_context = std::move(ctx);
}
auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *program = optimize_block->Program();
framework::Executor executor(dev_place);
......@@ -301,12 +309,6 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
prefetch_var_name_to_prepared_ctx[prefetch_var_name] = prefetch_prepared[i];
}
int checkpoint_point_block_id = Attr<int>(kCheckpointBlockId);
auto ctx = executor.Prepare(*program, checkpoint_point_block_id);
std::shared_ptr<framework::ExecutorPrepareContext> ckpt_pre_context =
std::move(ctx);
auto f =
std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope, &dev_ctx,
&executor, program, &prefetch_var_name_to_prepared_ctx,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册