提交 e684575f 编写于 作者: T tangwei12

checkpoint feature optimized

上级 2229db52
...@@ -42,10 +42,11 @@ class CheckpointNotifyOp : public framework::OperatorBase { ...@@ -42,10 +42,11 @@ class CheckpointNotifyOp : public framework::OperatorBase {
distributed::RPCClient* rpc_client = distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(); distributed::RPCClient::GetInstance<RPCCLIENT_T>();
for (size_t i = 0; i < epmap.size(); i++) { for (size_t i = 0; i < epmap.size(); i++) {
VLOG(3) << "checkpoint notify sending " << dir << " to " << epmap[i]; auto lookup_table_save_dir =
auto serial_looku_table =
string::Sprintf("%s/%s_%d", dir, lookup_table_name, i); string::Sprintf("%s/%s_%d", dir, lookup_table_name, i);
rpc_client->AsyncCheckpointNotify(epmap[i], serial_looku_table); rpc_client->AsyncCheckpointNotify(epmap[i], lookup_table_save_dir);
VLOG(3) << "checkpoint notify sending lookup table: " << lookup_table_name
<< " and dir:" << dir << " to " << epmap[i];
} }
rpc_client->Wait(); rpc_client->Wait();
} }
...@@ -64,10 +65,10 @@ class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -64,10 +65,10 @@ class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<std::string>("lookup_table", AddAttr<std::string>("lookup_table",
"(string, default '') the lookup table name"); "(string, default '') the lookup table name");
AddComment(R"DOC( AddComment(R"DOC(
Prefetch operator CheckpointNotify operator
This operator will send Ids variables to listen_and_serve op at This operator will send lookup table and it's checkpoint direcoty to listen_and_serve op at
the parameter server and fetch result back. the parameter server.
)DOC"); )DOC");
} }
}; };
......
...@@ -25,3 +25,7 @@ ...@@ -25,3 +25,7 @@
#define RPCSERVER_T distributed::AsyncBRPCServer #define RPCSERVER_T distributed::AsyncBRPCServer
#define RPCCLIENT_T distributed::BRPCClient #define RPCCLIENT_T distributed::BRPCClient
#endif #endif
// define LOOKUP_TABLE_PATH for checkpoint notify to save lookup table variables
// to directory specified.
constexpr char LOOKUP_TABLE_PATH[] = "lookup_table_path";
...@@ -194,7 +194,7 @@ class RequestCheckpointNotify final : public RequestBase { ...@@ -194,7 +194,7 @@ class RequestCheckpointNotify final : public RequestBase {
RequestHandler* request_handler, int req_id) RequestHandler* request_handler, int req_id)
: RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) { : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
request_.reset(new VariableResponse(request_handler->scope(), request_.reset(new VariableResponse(request_handler->scope(),
request_handler->dev_ctx(), true)); request_handler->dev_ctx()));
int method_id = int method_id =
static_cast<int>(distributed::GrpcMethod::kCheckpointNotify); static_cast<int>(distributed::GrpcMethod::kCheckpointNotify);
service_->RequestAsyncUnary( service_->RequestAsyncUnary(
...@@ -212,13 +212,10 @@ class RequestCheckpointNotify final : public RequestBase { ...@@ -212,13 +212,10 @@ class RequestCheckpointNotify final : public RequestBase {
std::string checkpoint_notify = request_->Varname(); std::string checkpoint_notify = request_->Varname();
std::string checkpoint_dir = request_->OutVarname(); std::string checkpoint_dir = request_->OutVarname();
framework::Variable* invar = nullptr;
framework::Variable* outvar = nullptr;
VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify
<< ", dir: " << checkpoint_dir; << ", dir: " << checkpoint_dir;
request_handler_->Handle(checkpoint_notify, scope, invar, &outvar, request_handler_->Handle(checkpoint_notify, scope, nullptr, nullptr,
checkpoint_dir); checkpoint_dir);
Finish(reply_, &responder_); Finish(reply_, &responder_);
} }
...@@ -320,8 +317,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name, ...@@ -320,8 +317,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
return; return;
} }
LOG(INFO) << "TryToRegisterNewOne on RPC NAME: " << rpc_name VLOG(4) << "TryToRegisterNewOne on RPC NAME: " << rpc_name
<< " REQ ID: " << req_id; << " REQ ID: " << req_id;
auto& reqs = rpc_reqs_[rpc_name]; auto& reqs = rpc_reqs_[rpc_name];
auto& handler = rpc_call_map_[rpc_name]; auto& handler = rpc_call_map_[rpc_name];
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed/rpc_server.h" #include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/printf.h"
...@@ -129,10 +130,10 @@ bool RequestCheckpointHandler::Handle(const std::string& varname, ...@@ -129,10 +130,10 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
checkpoint_notify_id != -1, checkpoint_notify_id != -1,
"when checkpoint_notify_id = -1, there should be no RPC invoke."); "when checkpoint_notify_id = -1, there should be no RPC invoke.");
auto* lt_var = scope->FindVar("loopup_table_path")->GetMutable<std::string>(); auto* lt_var = scope->FindVar(LOOKUP_TABLE_PATH)->GetMutable<std::string>();
lt_var->clear(); lt_var->clear();
lt_var->append(out_var_name); lt_var->append(out_var_name);
VLOG(4) << "RequestCheckpointHandler update loopup_table_path to: " VLOG(4) << "RequestCheckpointHandler update var lookup_table_path to: "
<< out_var_name; << out_var_name;
executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope); executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope);
return true; return true;
......
...@@ -247,11 +247,11 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -247,11 +247,11 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
PADDLE_ENFORCE(!rpc_service_); PADDLE_ENFORCE(!rpc_service_);
std::string endpoint = Attr<std::string>("endpoint"); std::string endpoint = Attr<std::string>("endpoint");
int checkpoint_notify_id = Attr<int>(kCheckpointBlockId); int checkpoint_notify_block_id = Attr<int>(kCheckpointBlockId);
LOG(INFO) << "sync_mode:" << sync_mode << ", fan_in:" << fan_in LOG(INFO) << "sync_mode:" << sync_mode << ", fan_in:" << fan_in
<< ", end_point:" << endpoint << ", end_point:" << endpoint
<< ", CheckpointNotify Id: " << checkpoint_notify_id; << ", CheckpointNotify Id: " << checkpoint_notify_block_id;
rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in)); rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in));
...@@ -260,7 +260,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -260,7 +260,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
request_prefetch_handler_.reset( request_prefetch_handler_.reset(
new distributed::RequestPrefetchHandler(sync_mode)); new distributed::RequestPrefetchHandler(sync_mode));
request_checkpoint_handler_.reset(new distributed::RequestCheckpointHandler( request_checkpoint_handler_.reset(new distributed::RequestCheckpointHandler(
sync_mode, checkpoint_notify_id)); sync_mode, checkpoint_notify_block_id));
rpc_service_->RegisterRPC(distributed::kRequestSend, rpc_service_->RegisterRPC(distributed::kRequestSend,
request_send_handler_.get()); request_send_handler_.get());
...@@ -276,8 +276,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -276,8 +276,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
framework::Executor executor(dev_place); framework::Executor executor(dev_place);
std::shared_ptr<framework::ExecutorPrepareContext> ckpt_pre_context = nullptr; std::shared_ptr<framework::ExecutorPrepareContext> ckpt_pre_context = nullptr;
if (checkpoint_notify_id != -1) { if (checkpoint_notify_block_id != -1) {
auto ctx = executor.Prepare(*program, checkpoint_notify_id); auto ctx = executor.Prepare(*program, checkpoint_notify_block_id);
ckpt_pre_context = std::move(ctx); ckpt_pre_context = std::move(ctx);
} }
...@@ -334,7 +334,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -334,7 +334,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
SavePort(); SavePort();
if (sync_mode) { if (sync_mode) {
RunSyncLoop(&executor, program, &recv_scope, prefetch_block_id_list, RunSyncLoop(&executor, program, &recv_scope, prefetch_block_id_list,
checkpoint_notify_id); checkpoint_notify_block_id);
} else { } else {
RunAsyncLoop(&executor, program); RunAsyncLoop(&executor, program);
} }
......
...@@ -101,7 +101,7 @@ class LoadOp : public framework::OperatorBase { ...@@ -101,7 +101,7 @@ class LoadOp : public framework::OperatorBase {
class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker { class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddOutput("Out", "The tensor need to be loaded"); AddOutput("Out", "The LoDTensor / SelectedRows need to be loaded");
AddAttr<bool>( AddAttr<bool>(
"load_as_fp16", "load_as_fp16",
"If true, the tensor will be first loaded and then " "If true, the tensor will be first loaded and then "
...@@ -112,7 +112,9 @@ class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker { ...@@ -112,7 +112,9 @@ class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
R"(Variable will be loaded from "file_path")") R"(Variable will be loaded from "file_path")")
.AddCustomChecker( .AddCustomChecker(
[](const std::string &path) { return !path.empty(); }); [](const std::string &path) { return !path.empty(); });
AddComment("Load operator will load a tensor variable from disk file."); AddComment(
"Load operator will load a LoDTensor / SelectedRows variable from disk "
"file.");
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -24,6 +24,7 @@ limitations under the License. */ ...@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
namespace paddle { namespace paddle {
...@@ -131,11 +132,10 @@ class SaveOp : public framework::OperatorBase { ...@@ -131,11 +132,10 @@ class SaveOp : public framework::OperatorBase {
void SaveSelectedRows(const framework::Scope &scope, void SaveSelectedRows(const framework::Scope &scope,
const platform::Place &place, const platform::Place &place,
framework::Variable *var) const { framework::Variable *var) const {
auto *lt_var = auto *lt_var = scope.FindVar(LOOKUP_TABLE_PATH)->GetMutable<std::string>();
scope.FindVar("loopup_table_path")->GetMutable<std::string>();
PADDLE_ENFORCE( PADDLE_ENFORCE(
lt_var != nullptr, lt_var != nullptr,
"Can not find variable loopup_table_path for SaveSelectedRows"); "Can not find variable lookup_table_path for SaveSelectedRows");
std::string filename = lt_var->data(); std::string filename = lt_var->data();
VLOG(4) << "SaveSelectedRows get File name: " << filename; VLOG(4) << "SaveSelectedRows get File name: " << filename;
...@@ -162,7 +162,7 @@ class SaveOpProtoMaker : public framework::OpProtoAndCheckerMaker { ...@@ -162,7 +162,7 @@ class SaveOpProtoMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC( AddComment(R"DOC(
Save operator Save operator
This operator will serialize and write a tensor/selected rows variable to file on disk. This operator will serialize and write LoDTensor / SelectedRows variable to file on disk.
)DOC"); )DOC");
AddAttr<bool>("overwrite", AddAttr<bool>("overwrite",
"(boolean, default true)" "(boolean, default true)"
...@@ -186,7 +186,7 @@ class SaveOpVarTypeInference : public framework::VarTypeInference { ...@@ -186,7 +186,7 @@ class SaveOpVarTypeInference : public framework::VarTypeInference {
public: public:
void operator()(const framework::OpDesc &op_desc, void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override { framework::BlockDesc *block) const override {
auto out_var_name = op_desc.Output("loopup_table_path").front(); auto out_var_name = op_desc.Output(LOOKUP_TABLE_PATH).front();
auto &out_var = block->FindRecursiveOrCreateVar(out_var_name); auto &out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarType::RAW; auto var_type = framework::proto::VarType::RAW;
out_var.SetType(var_type); out_var.SetType(var_type);
......
...@@ -1042,6 +1042,7 @@ def load_lookup_table_vars(executor, dirname, program, pserver_id, table_name): ...@@ -1042,6 +1042,7 @@ def load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
main_program(Program): Find the variable named table_name in main_program main_program(Program): Find the variable named table_name in main_program
pserver_id(int): the serial number in pserver_endpoints list pserver_id(int): the serial number in pserver_endpoints list
table_name(str): lookup table name table_name(str): lookup table name
Returns: Returns:
None None
...@@ -1188,7 +1189,7 @@ def save_trainer_args(dirname, trainer_id, trainer_args): ...@@ -1188,7 +1189,7 @@ def save_trainer_args(dirname, trainer_id, trainer_args):
def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args): def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
""" """
trainer will load some args from it's independent directory, trainer will load some args from it's independent directory,
such as epoch_id and step_id. such as epoch_id and step_id.
Args: Args:
......
...@@ -914,7 +914,7 @@ class DistributeTranspiler(object): ...@@ -914,7 +914,7 @@ class DistributeTranspiler(object):
import os import os
pserver_program.global_block().create_var( pserver_program.global_block().create_var(
name="loopup_table_path", name="lookup_table_path",
persistable=True, persistable=True,
type=core.VarDesc.VarType.RAW) type=core.VarDesc.VarType.RAW)
...@@ -923,7 +923,10 @@ class DistributeTranspiler(object): ...@@ -923,7 +923,10 @@ class DistributeTranspiler(object):
type='save', type='save',
inputs={'X': [self.table_name]}, inputs={'X': [self.table_name]},
outputs={}, outputs={},
attrs={'file_path': self.table_name}) attrs={
'file_path':
"this 'file_path' do not be used in save lookup table variable"
})
return checkpoint_save_block.idx return checkpoint_save_block.idx
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册