From e684575f662c13fd0f8c732671c77420c2aedefe Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Fri, 22 Jun 2018 14:55:16 +0800 Subject: [PATCH] checkpoint feature optimized --- paddle/fluid/operators/checkpoint_notify_op.cc | 13 +++++++------ paddle/fluid/operators/detail/macros.h | 4 ++++ paddle/fluid/operators/distributed/grpc_server.cc | 11 ++++------- .../operators/distributed/request_handler_impl.cc | 5 +++-- paddle/fluid/operators/listen_and_serv_op.cc | 12 ++++++------ paddle/fluid/operators/load_op.cc | 6 ++++-- paddle/fluid/operators/save_op.cc | 10 +++++----- python/paddle/fluid/io.py | 3 ++- .../fluid/transpiler/distribute_transpiler.py | 7 +++++-- 9 files changed, 40 insertions(+), 31 deletions(-) diff --git a/paddle/fluid/operators/checkpoint_notify_op.cc b/paddle/fluid/operators/checkpoint_notify_op.cc index e7a65b76a49..7fc5b5e6221 100644 --- a/paddle/fluid/operators/checkpoint_notify_op.cc +++ b/paddle/fluid/operators/checkpoint_notify_op.cc @@ -42,10 +42,11 @@ class CheckpointNotifyOp : public framework::OperatorBase { distributed::RPCClient* rpc_client = distributed::RPCClient::GetInstance(); for (size_t i = 0; i < epmap.size(); i++) { - VLOG(3) << "checkpoint notify sending " << dir << " to " << epmap[i]; - auto serial_looku_table = + auto lookup_table_save_dir = string::Sprintf("%s/%s_%d", dir, lookup_table_name, i); - rpc_client->AsyncCheckpointNotify(epmap[i], serial_looku_table); + rpc_client->AsyncCheckpointNotify(epmap[i], lookup_table_save_dir); + VLOG(3) << "checkpoint notify sending lookup table: " << lookup_table_name + << " and dir:" << dir << " to " << epmap[i]; } rpc_client->Wait(); } @@ -64,10 +65,10 @@ class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("lookup_table", "(string, default '') the lookup table name"); AddComment(R"DOC( -Prefetch operator +CheckpointNotify operator -This operator will send Ids variables to listen_and_serve op at -the parameter server and fetch result back. +This operator will send lookup table and it's checkpoint direcoty to listen_and_serve op at +the parameter server. )DOC"); } }; diff --git a/paddle/fluid/operators/detail/macros.h b/paddle/fluid/operators/detail/macros.h index b9e385994ef..6e9f7beb93b 100644 --- a/paddle/fluid/operators/detail/macros.h +++ b/paddle/fluid/operators/detail/macros.h @@ -25,3 +25,7 @@ #define RPCSERVER_T distributed::AsyncBRPCServer #define RPCCLIENT_T distributed::BRPCClient #endif + +// define LOOKUP_TABLE_PATH for checkpoint notify to save lookup table variables +// to directory specified. +constexpr char LOOKUP_TABLE_PATH[] = "lookup_table_path"; diff --git a/paddle/fluid/operators/distributed/grpc_server.cc b/paddle/fluid/operators/distributed/grpc_server.cc index 218a1f85625..363614df4f9 100644 --- a/paddle/fluid/operators/distributed/grpc_server.cc +++ b/paddle/fluid/operators/distributed/grpc_server.cc @@ -194,7 +194,7 @@ class RequestCheckpointNotify final : public RequestBase { RequestHandler* request_handler, int req_id) : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) { request_.reset(new VariableResponse(request_handler->scope(), - request_handler->dev_ctx(), true)); + request_handler->dev_ctx())); int method_id = static_cast(distributed::GrpcMethod::kCheckpointNotify); service_->RequestAsyncUnary( @@ -212,13 +212,10 @@ class RequestCheckpointNotify final : public RequestBase { std::string checkpoint_notify = request_->Varname(); std::string checkpoint_dir = request_->OutVarname(); - framework::Variable* invar = nullptr; - framework::Variable* outvar = nullptr; - VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify << ", dir: " << checkpoint_dir; - request_handler_->Handle(checkpoint_notify, scope, invar, &outvar, + request_handler_->Handle(checkpoint_notify, scope, nullptr, nullptr, checkpoint_dir); Finish(reply_, &responder_); } @@ -320,8 +317,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name, return; } - LOG(INFO) << "TryToRegisterNewOne on RPC NAME: " << rpc_name - << " REQ ID: " << req_id; + VLOG(4) << "TryToRegisterNewOne on RPC NAME: " << rpc_name + << " REQ ID: " << req_id; auto& reqs = rpc_reqs_[rpc_name]; auto& handler = rpc_call_map_[rpc_name]; diff --git a/paddle/fluid/operators/distributed/request_handler_impl.cc b/paddle/fluid/operators/distributed/request_handler_impl.cc index b6e4e156080..cd8059a96d3 100644 --- a/paddle/fluid/operators/distributed/request_handler_impl.cc +++ b/paddle/fluid/operators/distributed/request_handler_impl.cc @@ -20,6 +20,7 @@ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/operators/detail/macros.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h" #include "paddle/fluid/operators/distributed/rpc_server.h" #include "paddle/fluid/string/printf.h" @@ -129,10 +130,10 @@ bool RequestCheckpointHandler::Handle(const std::string& varname, checkpoint_notify_id != -1, "when checkpoint_notify_id = -1, there should be no RPC invoke."); - auto* lt_var = scope->FindVar("loopup_table_path")->GetMutable(); + auto* lt_var = scope->FindVar(LOOKUP_TABLE_PATH)->GetMutable(); lt_var->clear(); lt_var->append(out_var_name); - VLOG(4) << "RequestCheckpointHandler update loopup_table_path to: " + VLOG(4) << "RequestCheckpointHandler update var lookup_table_path to: " << out_var_name; executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope); return true; diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index df9cdae97df..87a501eaa25 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -247,11 +247,11 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, PADDLE_ENFORCE(!rpc_service_); std::string endpoint = Attr("endpoint"); - int checkpoint_notify_id = Attr(kCheckpointBlockId); + int checkpoint_notify_block_id = Attr(kCheckpointBlockId); LOG(INFO) << "sync_mode:" << sync_mode << ", fan_in:" << fan_in << ", end_point:" << endpoint - << ", CheckpointNotify Id: " << checkpoint_notify_id; + << ", CheckpointNotify Id: " << checkpoint_notify_block_id; rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in)); @@ -260,7 +260,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, request_prefetch_handler_.reset( new distributed::RequestPrefetchHandler(sync_mode)); request_checkpoint_handler_.reset(new distributed::RequestCheckpointHandler( - sync_mode, checkpoint_notify_id)); + sync_mode, checkpoint_notify_block_id)); rpc_service_->RegisterRPC(distributed::kRequestSend, request_send_handler_.get()); @@ -276,8 +276,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, framework::Executor executor(dev_place); std::shared_ptr ckpt_pre_context = nullptr; - if (checkpoint_notify_id != -1) { - auto ctx = executor.Prepare(*program, checkpoint_notify_id); + if (checkpoint_notify_block_id != -1) { + auto ctx = executor.Prepare(*program, checkpoint_notify_block_id); ckpt_pre_context = std::move(ctx); } @@ -334,7 +334,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, SavePort(); if (sync_mode) { RunSyncLoop(&executor, program, &recv_scope, prefetch_block_id_list, - checkpoint_notify_id); + checkpoint_notify_block_id); } else { RunAsyncLoop(&executor, program); } diff --git a/paddle/fluid/operators/load_op.cc b/paddle/fluid/operators/load_op.cc index 764e3428ec4..ac35cf0b89b 100644 --- a/paddle/fluid/operators/load_op.cc +++ b/paddle/fluid/operators/load_op.cc @@ -101,7 +101,7 @@ class LoadOp : public framework::OperatorBase { class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { - AddOutput("Out", "The tensor need to be loaded"); + AddOutput("Out", "The LoDTensor / SelectedRows need to be loaded"); AddAttr( "load_as_fp16", "If true, the tensor will be first loaded and then " @@ -112,7 +112,9 @@ class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker { R"(Variable will be loaded from "file_path")") .AddCustomChecker( [](const std::string &path) { return !path.empty(); }); - AddComment("Load operator will load a tensor variable from disk file."); + AddComment( + "Load operator will load a LoDTensor / SelectedRows variable from disk " + "file."); } }; } // namespace operators diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index 941bca10477..bf8553ed557 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -24,6 +24,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/variable.h" +#include "paddle/fluid/operators/detail/macros.h" #include "paddle/fluid/platform/device_context.h" namespace paddle { @@ -131,11 +132,10 @@ class SaveOp : public framework::OperatorBase { void SaveSelectedRows(const framework::Scope &scope, const platform::Place &place, framework::Variable *var) const { - auto *lt_var = - scope.FindVar("loopup_table_path")->GetMutable(); + auto *lt_var = scope.FindVar(LOOKUP_TABLE_PATH)->GetMutable(); PADDLE_ENFORCE( lt_var != nullptr, - "Can not find variable loopup_table_path for SaveSelectedRows"); + "Can not find variable lookup_table_path for SaveSelectedRows"); std::string filename = lt_var->data(); VLOG(4) << "SaveSelectedRows get File name: " << filename; @@ -162,7 +162,7 @@ class SaveOpProtoMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( Save operator -This operator will serialize and write a tensor/selected rows variable to file on disk. +This operator will serialize and write LoDTensor / SelectedRows variable to file on disk. )DOC"); AddAttr("overwrite", "(boolean, default true)" @@ -186,7 +186,7 @@ class SaveOpVarTypeInference : public framework::VarTypeInference { public: void operator()(const framework::OpDesc &op_desc, framework::BlockDesc *block) const override { - auto out_var_name = op_desc.Output("loopup_table_path").front(); + auto out_var_name = op_desc.Output(LOOKUP_TABLE_PATH).front(); auto &out_var = block->FindRecursiveOrCreateVar(out_var_name); auto var_type = framework::proto::VarType::RAW; out_var.SetType(var_type); diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 8cc25e86237..d7b42ef3515 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -1042,6 +1042,7 @@ def load_lookup_table_vars(executor, dirname, program, pserver_id, table_name): main_program(Program): Find the variable named table_name in main_program pserver_id(int): the serial number in pserver_endpoints list table_name(str): lookup table name + Returns: None @@ -1188,7 +1189,7 @@ def save_trainer_args(dirname, trainer_id, trainer_args): def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args): """ - trainer will load some args from it's independent directory, + trainer will load some args from it's independent directory, such as epoch_id and step_id. Args: diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index a3f0a4ffe28..d9578af2a93 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -914,7 +914,7 @@ class DistributeTranspiler(object): import os pserver_program.global_block().create_var( - name="loopup_table_path", + name="lookup_table_path", persistable=True, type=core.VarDesc.VarType.RAW) @@ -923,7 +923,10 @@ class DistributeTranspiler(object): type='save', inputs={'X': [self.table_name]}, outputs={}, - attrs={'file_path': self.table_name}) + attrs={ + 'file_path': + "this 'file_path' do not be used in save lookup table variable" + }) return checkpoint_save_block.idx -- GitLab