diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index d3988ae16d7d4ceccaf01503c6200066f2fa4073..4c338c67d34fa229de17019ce97e8b8dc39ea737 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -195,7 +195,7 @@ if(WITH_DISTRIBUTE) endif() set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") - foreach(dist_op "prefetch_op" "listen_and_serv_op" "send_op" "recv_op" "send_barrier_op" "fetch_barrier_op") + foreach(dist_op "prefetch_op" "checkpoint_notify_op" "listen_and_serv_op" "send_op" "recv_op" "send_barrier_op" "fetch_barrier_op") op_library(${dist_op} DEPS ${DISTRIBUTE_DEPS}) set_source_files_properties(${dist_op}.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) endforeach() @@ -216,7 +216,7 @@ if(WITH_DISTRIBUTE) set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op) endif() else() - set(DEPS_OPS ${DEPS_OPS} prefetch_op recv_op listen_and_serv_op send_op send_barrier_op fetch_barrier_op gen_nccl_id_op) + set(DEPS_OPS ${DEPS_OPS} checkpoint_notify_op prefetch_op recv_op listen_and_serv_op send_op send_barrier_op fetch_barrier_op gen_nccl_id_op) endif() op_library(cross_entropy_op DEPS cross_entropy) diff --git a/paddle/fluid/operators/checkpoint_notify_op.cc b/paddle/fluid/operators/checkpoint_notify_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..c4219a429a53eb4869426a2674109555fb784b85 --- /dev/null +++ b/paddle/fluid/operators/checkpoint_notify_op.cc @@ -0,0 +1,88 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include // NOLINT +#include + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/detail/macros.h" +#include "paddle/fluid/operators/send_recv_util.h" +#include "paddle/fluid/string/printf.h" + +namespace paddle { +namespace operators { + +class CheckpointNotifyOp : public framework::OperatorBase { + public: + CheckpointNotifyOp(const std::string& type, + const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + void RunImpl(const framework::Scope& scope, + const platform::Place& place) const override { + std::vector epmap = Attr>("epmap"); + std::string dir = Attr("dir"); + std::string lookup_table_name = Attr("lookup_table"); + + distributed::RPCClient* rpc_client = + distributed::RPCClient::GetInstance(); + for (size_t i = 0; i < epmap.size(); i++) { + auto lookup_table_save_dir = + string::Sprintf("%s/%s_%d", dir, lookup_table_name, i); + rpc_client->AsyncCheckpointNotify(epmap[i], lookup_table_save_dir); + VLOG(3) << "checkpoint notify sending lookup table: " << lookup_table_name + << " and dir:" << dir << " to " << epmap[i]; + } + rpc_client->Wait(); + } +}; + +class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() { + AddAttr>("epmap", + "(string vector, default 127.0.0.1:6164)" + "Parameter Server endpoints in the order") + .SetDefault({"127.0.0.1:6164"}); + AddAttr( + "dir", "(string, default '') indicate the folder checkpoint will use"); + AddAttr("lookup_table", + "(string, default '') the lookup table name"); + AddComment(R"DOC( +CheckpointNotify operator + +This operator will send lookup table and it's checkpoint direcoty to listen_and_serve op at +the parameter server. +)DOC"); + } +}; + +class CheckpointNotifyOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext* ctx) const override {} +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(checkpoint_notify, ops::CheckpointNotifyOp, + paddle::framework::EmptyGradOpMaker, + ops::CheckpointNotifyOpMaker, + ops::CheckpointNotifyOpShapeInference); diff --git a/paddle/fluid/operators/distributed/grpc_client.cc b/paddle/fluid/operators/distributed/grpc_client.cc index cf10565d48a30b2d67e8171ef3199130992a3b81..8228a8c5a3eae73fe82551c8aad55290b0d54ef0 100644 --- a/paddle/fluid/operators/distributed/grpc_client.cc +++ b/paddle/fluid/operators/distributed/grpc_client.cc @@ -239,6 +239,23 @@ void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) { req_count_++; } +void GRPCClient::AsyncCheckpointNotify(const std::string& ep, + const std::string& dir, + int64_t time_out) { + const auto ch = GetChannel(ep); + + CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch); + s->Prepare(time_out); + + sendrecv::VariableMessage req; + req.set_varname(CHECKPOINT_SAVE_MESSAGE); + req.set_out_varname(dir); + + auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_); + rpc->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); + req_count_++; +} + void GRPCClient::Wait() { std::unique_lock lk(sync_mutex_); sync_cond_.wait(lk, [this] { return req_count_ == 0; }); diff --git a/paddle/fluid/operators/distributed/grpc_client.h b/paddle/fluid/operators/distributed/grpc_client.h index 940e9d879b3c1361035988245f07b1fa7afa603a..69705d72106488015b538c0758e6af48c448164c 100644 --- a/paddle/fluid/operators/distributed/grpc_client.h +++ b/paddle/fluid/operators/distributed/grpc_client.h @@ -171,6 +171,20 @@ class FetchBarrierProcessor : public BaseProcessor { std::unique_ptr stub_; }; +class CheckpointNotifyProcessor : public BaseProcessor { + public: + explicit CheckpointNotifyProcessor(std::shared_ptr ch) + : BaseProcessor(ch) { + stub_ = sendrecv::SendRecvService::NewStub(ch); + } + + virtual ~CheckpointNotifyProcessor() {} + + virtual void Process() {} + sendrecv::VoidMessage reply_; + std::unique_ptr stub_; +}; + class GRPCClient : public RPCClient { public: GRPCClient() {} @@ -197,6 +211,9 @@ class GRPCClient : public RPCClient { void AsyncSendFetchBarrier(const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override; + void AsyncCheckpointNotify(const std::string& ep, const std::string& dir, + int64_t time_out = FLAGS_grpc_deadline) override; + void Wait() override; void SendComplete() override; diff --git a/paddle/fluid/operators/distributed/grpc_server.cc b/paddle/fluid/operators/distributed/grpc_server.cc index 8ec29d0a90429cc05dae5b7a575254428881a175..f35e268f6ad36da02f17db2feb3fbf1fdf6c1e41 100644 --- a/paddle/fluid/operators/distributed/grpc_server.cc +++ b/paddle/fluid/operators/distributed/grpc_server.cc @@ -200,6 +200,45 @@ class RequestPrefetch final : public RequestBase { framework::Scope* local_scope_; }; +class RequestCheckpointNotify final : public RequestBase { + public: + explicit RequestCheckpointNotify(GrpcService::AsyncService* service, + ::grpc::ServerCompletionQueue* cq, + RequestHandler* request_handler, int req_id) + : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) { + request_.reset(new VariableResponse(request_handler->scope(), + request_handler->dev_ctx())); + int method_id = + static_cast(distributed::GrpcMethod::kCheckpointNotify); + service_->RequestAsyncUnary( + method_id, &ctx_, request_.get(), &responder_, cq_, cq_, + reinterpret_cast(static_cast(req_id))); + } + + virtual ~RequestCheckpointNotify() {} + + std::string GetReqName() override { return request_->Varname(); } + + void Process() override { + auto scope = request_->GetMutableLocalScope(); + + std::string checkpoint_notify = request_->Varname(); + std::string checkpoint_dir = request_->OutVarname(); + + VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify + << ", dir: " << checkpoint_dir; + + request_handler_->Handle(checkpoint_notify, scope, nullptr, nullptr, + checkpoint_dir); + Finish(reply_, &responder_); + } + + protected: + std::shared_ptr request_; + sendrecv::VoidMessage reply_; + ServerAsyncResponseWriter responder_; +}; + void AsyncGRPCServer::WaitServerReady() { VLOG(4) << "AsyncGRPCServer is wait server ready"; std::unique_lock lock(this->mutex_ready_); @@ -237,6 +276,7 @@ void AsyncGRPCServer::StartServer() { reqs.reserve(kRequestBufSize); for (int i = 0; i < kRequestBufSize; i++) { + VLOG(6) << "TryToRegisterNewOne on RPC NAME: " << rpc_name << " I: " << i; TryToRegisterNewOne(rpc_name, i); } @@ -289,8 +329,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name, return; } - VLOG(4) << "register send rpc_name:" << rpc_name - << ", handler:" << rpc_call_map_[kRequestSend]; + VLOG(4) << "TryToRegisterNewOne on RPC NAME: " << rpc_name + << " REQ ID: " << req_id; auto& reqs = rpc_reqs_[rpc_name]; auto& handler = rpc_call_map_[rpc_name]; @@ -303,6 +343,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name, b = new RequestGet(&service_, cq.get(), handler, req_id); } else if (rpc_name == kRequestPrefetch) { b = new RequestPrefetch(&service_, cq.get(), handler, req_id); + } else if (rpc_name == kRequestCheckpoint) { + b = new RequestCheckpointNotify(&service_, cq.get(), handler, req_id); } else { PADDLE_ENFORCE(false, "not supported rpc"); } @@ -321,7 +363,7 @@ void AsyncGRPCServer::HandleRequest( while (true) { VLOG(4) << "HandleRequest " << rpc_name << " wait next"; if (!cq->Next(&tag, &ok)) { - LOG(INFO) << "CompletionQueue " << rpc_name << " shutdown!"; + VLOG(3) << "CompletionQueue " << rpc_name << " shutdown!"; break; } diff --git a/paddle/fluid/operators/distributed/grpc_service.h b/paddle/fluid/operators/distributed/grpc_service.h index 141be3e68012743a32e4df5de148a55717f8e9a2..cdc4e7b79276d6aac55aeac8ac121ca28d2cc1f0 100644 --- a/paddle/fluid/operators/distributed/grpc_service.h +++ b/paddle/fluid/operators/distributed/grpc_service.h @@ -80,10 +80,11 @@ enum class GrpcMethod { kSendVariable, kGetVariable, kPrefetchVariable, + kCheckpointNotify, }; static const int kGrpcNumMethods = - static_cast(GrpcMethod::kPrefetchVariable) + 1; + static_cast(GrpcMethod::kCheckpointNotify) + 1; inline const char* GrpcMethodName(GrpcMethod id) { switch (id) { @@ -93,6 +94,8 @@ inline const char* GrpcMethodName(GrpcMethod id) { return "/sendrecv.SendRecvService/GetVariable"; case GrpcMethod::kPrefetchVariable: return "/sendrecv.SendRecvService/PrefetchVariable"; + case GrpcMethod::kCheckpointNotify: + return "/sendrecv.SendRecvService/CheckpointNotify"; } // Shouldn't be reached. diff --git a/paddle/fluid/operators/distributed/request_handler.h b/paddle/fluid/operators/distributed/request_handler.h index cf106656aa56c2130d8be8dbe7478c3397f9b9ad..90742a201ad46447d6fbbe2137aa40fabc2f9983 100644 --- a/paddle/fluid/operators/distributed/request_handler.h +++ b/paddle/fluid/operators/distributed/request_handler.h @@ -36,12 +36,16 @@ namespace distributed { constexpr char kRequestSend[] = "RequestSend"; constexpr char kRequestGet[] = "RequestGet"; constexpr char kRequestPrefetch[] = "RequestPrefetch"; +constexpr char kRequestCheckpoint[] = "RequestCheckpoint"; #define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV" #define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV" #define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV" #define COMPLETE_MESSAGE "COMPLETE@RECV" +#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY" +#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY" + class RPCServer; class RequestHandler { @@ -69,6 +73,11 @@ class RequestHandler { prefetch_var_name_to_prepared_ctx_ = g; } + void SetCheckpointNotifyPreparedCtx( + std::shared_ptr g) { + checkpoint_prepared_ctx_ = g; + } + // Used for async. void SetGradToPreparedCtx( std::unordered_map< @@ -115,6 +124,8 @@ class RequestHandler { std::unordered_map>* prefetch_var_name_to_prepared_ctx_; + // used for checkpoint notify + std::shared_ptr checkpoint_prepared_ctx_; // Used for async. std::unordered_mapFindVar(LOOKUP_TABLE_PATH)->GetMutable(); + lt_var->clear(); + lt_var->append(out_var_name); + VLOG(4) << "RequestCheckpointHandler update var kLookupTablePath to: " + << out_var_name; + executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope); + return true; +} + } // namespace distributed } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/distributed/request_handler_impl.h b/paddle/fluid/operators/distributed/request_handler_impl.h index abbe8778911a21ece3090bc9790d51a3cb31b6d7..87185500f2ffc3a8578eea339cc7a1e2b0e46631 100644 --- a/paddle/fluid/operators/distributed/request_handler_impl.h +++ b/paddle/fluid/operators/distributed/request_handler_impl.h @@ -66,6 +66,21 @@ class RequestPrefetchHandler final : public RequestHandler { const std::string& out_var_name = "") override; }; +class RequestCheckpointHandler final : public RequestHandler { + public: + explicit RequestCheckpointHandler(bool sync_mode, int checkpoint_notify_id) + : RequestHandler(sync_mode) { + this->checkpoint_notify_id = checkpoint_notify_id; + } + virtual ~RequestCheckpointHandler() {} + bool Handle(const std::string& varname, framework::Scope* scope, + framework::Variable* var, framework::Variable** outvar, + const std::string& out_var_name = "") override; + + private: + int checkpoint_notify_id; +}; + } // namespace distributed } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/distributed/rpc_client.h b/paddle/fluid/operators/distributed/rpc_client.h index 151b60d6f032906cc418cd6477403d3759d15d2a..cc3ae96787a1c29cbc3df72233aaef70d6d94e03 100644 --- a/paddle/fluid/operators/distributed/rpc_client.h +++ b/paddle/fluid/operators/distributed/rpc_client.h @@ -56,6 +56,10 @@ class RPCClient { virtual void AsyncSendFetchBarrier(const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) = 0; + virtual void AsyncCheckpointNotify( + const std::string& ep, const std::string& dir, + int64_t time_out = FLAGS_grpc_deadline) = 0; + // SendComplete tells all the server that current trainer have no more data // to train, so that the pserver can reduce it's barrier count, and continue // to train with other trainers. diff --git a/paddle/fluid/operators/distributed/send_recv.proto b/paddle/fluid/operators/distributed/send_recv.proto index 54cb93e04d18b3784be187c9c8885bbccc55488b..e0902320cff003797b12ed0204f7f99c44554b62 100644 --- a/paddle/fluid/operators/distributed/send_recv.proto +++ b/paddle/fluid/operators/distributed/send_recv.proto @@ -25,6 +25,8 @@ service SendRecvService { rpc GetVariable(VariableMessage) returns (VariableMessage) {} // pre-fetch variable by given variable name and Ids rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {} + + rpc CheckpointNotify(VariableMessage) returns (VoidMessage) {} } // VariableMessage is serialized paddle variable message. diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 27708ffbe44d38aaa6c504c766e6a72917cf45cc..56e39649b409f7eed108027f6df58c19dd3c8ab8 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -99,7 +99,8 @@ static int64_t GetTimestamp() { void ListenAndServOp::RunSyncLoop( framework::Executor *executor, framework::ProgramDesc *program, framework::Scope *recv_scope, - const std::vector &prefetch_block_id_list) const { + const std::vector &prefetch_block_id_list, + const int checkpoint_point_block_id) const { size_t num_blocks = program->Size(); auto optimize_blocks = Attr>(kOptimizeBlocks); @@ -208,7 +209,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, while (true) { if (rpc_service_->IsExit()) { - LOG(INFO) << "get exit!rpc_processor break!"; + VLOG(4) << "get exit!rpc_processor break!"; break; } @@ -223,6 +224,7 @@ static void FillRequestCtx( std::unordered_map> *prefetch_ctx, + std::shared_ptr checkpoint_ctx, distributed::RPCServer *rpc_server) { h->SetScope(scope); h->SetDevCtx(dev_ctx); @@ -230,6 +232,7 @@ static void FillRequestCtx( h->SetProgram(program); h->SetPrefetchPreparedCtx(prefetch_ctx); h->SetRPCServer(rpc_server); + h->SetCheckpointNotifyPreparedCtx(checkpoint_ctx); } void ListenAndServOp::RunImpl(const framework::Scope &scope, @@ -245,9 +248,11 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, PADDLE_ENFORCE(!rpc_service_); std::string endpoint = Attr("endpoint"); + int checkpoint_block_id = Attr(kCheckpointBlockId); - LOG(INFO) << "sync_mode:" << sync_mode << ", fan_in:" << fan_in - << ", end_point:" << endpoint; + VLOG(4) << "sync_mode:" << sync_mode << ", fan_in:" << fan_in + << ", end_point:" << endpoint + << ", checkpoint_block_id: " << checkpoint_block_id; rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in)); @@ -255,6 +260,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, request_get_handler_.reset(new distributed::RequestGetHandler(sync_mode)); request_prefetch_handler_.reset( new distributed::RequestPrefetchHandler(sync_mode)); + request_checkpoint_handler_.reset(new distributed::RequestCheckpointHandler( + sync_mode, checkpoint_block_id)); rpc_service_->RegisterRPC(distributed::kRequestSend, request_send_handler_.get()); @@ -262,6 +269,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, request_get_handler_.get()); rpc_service_->RegisterRPC(distributed::kRequestPrefetch, request_prefetch_handler_.get()); + rpc_service_->RegisterRPC(distributed::kRequestCheckpoint, + request_checkpoint_handler_.get()); auto optimize_blocks = Attr>(kOptimizeBlocks); @@ -270,6 +279,13 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, auto *program = optimize_blocks[0]->Program(); framework::Executor executor(dev_place); + std::shared_ptr ckpt_pre_context = nullptr; + if (checkpoint_block_id != -1) { + auto ctx = executor.Prepare(*program, checkpoint_block_id); + // see: https://stackoverflow.com/a/14856553 + ckpt_pre_context = std::move(ctx); + } + // prepare for prefetch std::vector prefetch_block_id_list; std::unordered_map block_id_to_prefetch_var_name; @@ -300,13 +316,15 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, prefetch_var_name_to_prepared_ctx[prefetch_var_name] = prefetch_prepared[i]; } - auto f = std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope, - &dev_ctx, &executor, program, - &prefetch_var_name_to_prepared_ctx, rpc_service_.get()); + auto f = + std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope, &dev_ctx, + &executor, program, &prefetch_var_name_to_prepared_ctx, + ckpt_pre_context, rpc_service_.get()); f(request_send_handler_.get()); f(request_get_handler_.get()); f(request_prefetch_handler_.get()); + f(request_checkpoint_handler_.get()); // start the server listening after all member initialized. server_thread_.reset(new std::thread(RunServer, rpc_service_)); @@ -320,7 +338,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, // Write to a file of server selected port for python use. SavePort(); if (sync_mode) { - RunSyncLoop(&executor, program, &recv_scope, prefetch_block_id_list); + RunSyncLoop(&executor, program, &recv_scope, prefetch_block_id_list, + checkpoint_block_id); } else { RunAsyncLoop(&executor, program, &recv_scope); } @@ -352,6 +371,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { .SetDefault({}); AddAttr("Fanin", "How many clients send to this server.") .SetDefault(1); + AddAttr(kCheckpointBlockId, + "BolckID to run save checkpoint on pserer.") + .SetDefault(-1); } }; diff --git a/paddle/fluid/operators/listen_and_serv_op.h b/paddle/fluid/operators/listen_and_serv_op.h index f856fdc376196c736593eb8f941c1197d043f7c5..978969cc515c7954b59f2bf7a4f2c0e1b13f9bc0 100644 --- a/paddle/fluid/operators/listen_and_serv_op.h +++ b/paddle/fluid/operators/listen_and_serv_op.h @@ -32,6 +32,7 @@ namespace operators { constexpr char kOptimizeBlocks[] = "optimize_blocks"; constexpr char kPrefetchVarNameToBlockId[] = "prefetch_var_name_to_block_id"; +constexpr char kCheckpointBlockId[] = "checkpint_block_id"; void RunServer(std::shared_ptr service); @@ -47,7 +48,8 @@ class ListenAndServOp : public framework::OperatorBase { void RunSyncLoop(framework::Executor* executor, framework::ProgramDesc* program, framework::Scope* recv_scope, - const std::vector& prefetch_block_id_list) const; + const std::vector& prefetch_block_id_list, + const int checkpoint_point_block_id) const; void RunAsyncLoop(framework::Executor* executor, framework::ProgramDesc* program, @@ -68,6 +70,8 @@ class ListenAndServOp : public framework::OperatorBase { mutable std::shared_ptr request_get_handler_; mutable std::shared_ptr request_prefetch_handler_; + mutable std::shared_ptr + request_checkpoint_handler_; mutable std::shared_ptr server_thread_; }; diff --git a/paddle/fluid/operators/load_op.cc b/paddle/fluid/operators/load_op.cc index 8f4b5049271c9592d2db268ea7ff2f5c8abc28b6..ac35cf0b89bfaa0c0f8e64445f18a3bbd478e70a 100644 --- a/paddle/fluid/operators/load_op.cc +++ b/paddle/fluid/operators/load_op.cc @@ -34,6 +34,8 @@ class LoadOp : public framework::OperatorBase { auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place); platform::RecordEvent record_event(Type(), dev_ctx); + // FIXME(yuyang18): We save variable to local file now, but we should change + // it to save an output stream. auto filename = Attr("file_path"); std::ifstream fin(filename); PADDLE_ENFORCE(static_cast(fin), "Cannot open file %s for load op", @@ -44,9 +46,25 @@ class LoadOp : public framework::OperatorBase { PADDLE_ENFORCE(out_var != nullptr, "Output variable %s cannot be found", out_var_name); - auto *tensor = out_var->GetMutable(); + if (out_var->IsType()) { + LoadLodTensor(fin, place, out_var); + } else if (out_var->IsType()) { + LoadSelectedRows(fin, place, out_var); + } else { + PADDLE_ENFORCE( + false, + "Load only support LoDTensor and SelectedRows, %s has wrong type", + out_var_name); + } + } - DeserializeFromStream(fin, tensor, *dev_ctx); + void LoadLodTensor(std::istream &fin, const platform::Place &place, + framework::Variable *var) const { + // get device context from pool + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); + auto *tensor = var->GetMutable(); + DeserializeFromStream(fin, tensor, dev_ctx); auto load_as_fp16 = Attr("load_as_fp16"); auto in_dtype = framework::ToDataType(tensor->type()); @@ -63,18 +81,27 @@ class LoadOp : public framework::OperatorBase { &fp16_tensor); // reset output tensor - out_var->Clear(); - tensor = out_var->GetMutable(); + var->Clear(); + tensor = var->GetMutable(); tensor->set_lod(fp16_tensor.lod()); tensor->ShareDataWith(fp16_tensor); } } + + void LoadSelectedRows(std::istream &fin, const platform::Place &place, + framework::Variable *var) const { + auto *selectedRows = var->GetMutable(); + // get device context from pool + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); + framework::DeserializeFromStream(fin, selectedRows, dev_ctx); + } }; class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { - AddOutput("Out", "The tensor need to be loaded"); + AddOutput("Out", "The LoDTensor / SelectedRows need to be loaded"); AddAttr( "load_as_fp16", "If true, the tensor will be first loaded and then " @@ -85,7 +112,9 @@ class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker { R"(Variable will be loaded from "file_path")") .AddCustomChecker( [](const std::string &path) { return !path.empty(); }); - AddComment("Load operator will load a tensor variable from disk file."); + AddComment( + "Load operator will load a LoDTensor / SelectedRows variable from disk " + "file."); } }; } // namespace operators diff --git a/paddle/fluid/operators/save_load_op_test.cc b/paddle/fluid/operators/save_load_op_test.cc index c4fcc61af4b75e6dc7d5c31e20c5fff358637af5..ccaea0eef2906953d922e097348b6c0a86dad6f1 100644 --- a/paddle/fluid/operators/save_load_op_test.cc +++ b/paddle/fluid/operators/save_load_op_test.cc @@ -139,6 +139,7 @@ TEST(LoadFP16Op, CPU) { save_op->Run(scope, place); auto load_var = scope.Var("out_var"); + load_var->GetMutable(); auto load_op = paddle::framework::OpRegistry::CreateOp( "load", {}, {{"Out", {"out_var"}}}, attrs); load_op->Run(scope, place); diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index e6d27e2dedd7668b93bd8ddc330a897d1c6fa732..201a51130d6b6f94104e2dabf9e7facffa672ae0 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -22,11 +22,17 @@ limitations under the License. */ #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/platform/device_context.h" namespace paddle { namespace operators { +// define LOOKUP_TABLE_PATH for checkpoint notify to save lookup table variables +// to directory specified. +constexpr char LOOKUP_TABLE_PATH[] = "kLookupTablePath"; + // TODO(yuyang18): If the functions below are needed by other files, move them // to paddle::filesystem namespace. constexpr char kSEP = '/'; @@ -67,9 +73,27 @@ class SaveOp : public framework::OperatorBase { private: void RunImpl(const framework::Scope &scope, const platform::Place &place) const override { + auto iname = Input("X"); + auto *var = scope.FindVar(iname); + PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s for save_op", + iname); + + if (var->IsType()) { + SaveLodTensor(place, var); + } else if (var->IsType()) { + SaveSelectedRows(scope, place, var); + } else { + PADDLE_ENFORCE( + false, + "SaveOp only support LoDTensor and SelectedRows, %s has wrong type", + iname); + } + } + + void SaveLodTensor(const platform::Place &place, + framework::Variable *var) const { auto filename = Attr("file_path"); auto overwrite = Attr("overwrite"); - auto save_as_fp16 = Attr("save_as_fp16"); if (FileExists(filename) && !overwrite) { PADDLE_THROW("%s is existed, cannot save to it when overwrite=false", @@ -78,26 +102,19 @@ class SaveOp : public framework::OperatorBase { MkDirRecursively(DirName(filename).c_str()); - // FIXME(yuyang18): We save variable to local file now, but we should change - // it to save an output stream. - std::ofstream fout(filename); - PADDLE_ENFORCE(static_cast(fout), "Cannot open %s to write", - filename); - - auto iname = Input("X"); - auto *var = scope.FindVar(iname); - PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s for save_op", - iname); - - PADDLE_ENFORCE(var->IsType(), - "SaveOp only support LoDTensor, %s has wrong type", iname); - auto &tensor = var->Get(); // get device context from pool platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(place); + // FIXME(yuyang18): We save variable to local file now, but we should change + // it to save an output stream. + std::ofstream fout(filename); + PADDLE_ENFORCE(static_cast(fout), "Cannot open %s to write", + filename); + + auto save_as_fp16 = Attr("save_as_fp16"); auto in_dtype = framework::ToDataType(tensor.type()); auto out_dtype = save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype; @@ -112,17 +129,43 @@ class SaveOp : public framework::OperatorBase { } else { framework::SerializeToStream(fout, tensor, dev_ctx); } + fout.close(); + } + + void SaveSelectedRows(const framework::Scope &scope, + const platform::Place &place, + framework::Variable *var) const { + auto *lt_var = scope.FindVar(LOOKUP_TABLE_PATH)->GetMutable(); + PADDLE_ENFORCE( + lt_var != nullptr, + "Can not find variable kLookupTablePath for SaveSelectedRows"); + std::string filename = lt_var->data(); + VLOG(4) << "SaveSelectedRows get File name: " << filename; + + auto &selectedRows = var->Get(); + + // get device context from pool + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); + + // FIXME(yuyang18): We save variable to local file now, but we should change + // it to save an output stream. + std::ofstream fout(filename); + PADDLE_ENFORCE(static_cast(fout), "Cannot open %s to write", + filename); + framework::SerializeToStream(fout, selectedRows, dev_ctx); + fout.close(); } }; class SaveOpProtoMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { - AddInput("X", "(Tensor ) Input tensor to be saved"); + AddInput("X", "(Tensor ) Input LoDTensor and SelectedRows to be saved"); AddComment(R"DOC( Save operator -This operator will serialize and write a tensor variable to file on disk. +This operator will serialize and write LoDTensor / SelectedRows variable to file on disk. )DOC"); AddAttr("overwrite", "(boolean, default true)" @@ -142,9 +185,26 @@ This operator will serialize and write a tensor variable to file on disk. } }; +class SaveOpVarTypeInference : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc &op_desc, + framework::BlockDesc *block) const override { + auto out_var_name = op_desc.Output(LOOKUP_TABLE_PATH).front(); + auto &out_var = block->FindRecursiveOrCreateVar(out_var_name); + auto var_type = framework::proto::VarType::RAW; + out_var.SetType(var_type); + } +}; + +class SaveOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override {} +}; } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(save, ops::SaveOp, ops::SaveOpProtoMaker); +REGISTER_OPERATOR(save, ops::SaveOp, paddle::framework::EmptyGradOpMaker, + ops::SaveOpProtoMaker, ops::SaveOpVarTypeInference, + ops::SaveOpShapeInference); diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index eee64fc050908d5ee6ac4c8e46be004377cb98e8..2b2462b771a3801bf220ad6e09ee0c44f7b367b2 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -454,7 +454,7 @@ class Operator(object): 'rnn_memory_helper_grad', 'conditional_block', 'while', 'send', 'recv', 'listen_and_serv', 'parallel_do', 'save_combine', 'load_combine', 'ncclInit', 'channel_create', 'channel_close', 'channel_send', - 'channel_recv', 'select', 'gen_nccl_id' + 'channel_recv', 'select', 'checkpoint_notify', 'gen_nccl_id' } def __init__(self, @@ -1214,6 +1214,9 @@ class Block(object): if var.type == core.VarDesc.VarType.STEP_SCOPES: ret_var = self.create_var( name=var.name, persistable=var.persistable, type=var.type) + elif var.type == core.VarDesc.VarType.RAW: + ret_var = self.create_var( + name=var.name, persistable=var.persistable, type=var.type) elif var.type == core.VarDesc.VarType.SELECTED_ROWS: ret_var = self.create_var( name=var.name, @@ -1923,7 +1926,7 @@ def get_var(name, program=None): Args: name(str): name of the variable program(Program|None): program object. - If None, default_global_program() will be used. + If None, default_global_program() will be used. Returns: Variable diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 6e527572f1ca77be9fe069654db00d16ad5c21ef..d94564e11f982575dd9c065deb20d29396203227 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -13,6 +13,7 @@ # limitations under the License. import os +import errno import time import shutil @@ -25,7 +26,8 @@ __all__ = [ 'load_persistables', 'save_inference_model', 'load_inference_model', 'get_inference_program', 'save_checkpoint', 'load_checkpoint', 'clean_checkpoint', 'load_persist_vars_without_grad', - 'save_persist_vars_without_grad', 'get_latest_checkpoint_serial' + 'load_lookup_table_vars', 'save_persist_vars_without_grad', + 'get_latest_checkpoint_serial' ] @@ -795,6 +797,7 @@ def get_parameter_value_by_name(name, executor, program=None): SUCCESS_MARK_FILENAME = "_SUCCESS" CHECKPOINT_PREFIX = "checkpoint" MODEL_DIR = "__model__" +LOOKUP_TABLE_DIR = "__lookup_table__" TRAINER_PREFIX = "trainer" CHECKPOINT_SEPARATOR = "_" @@ -804,7 +807,9 @@ def save_checkpoint(executor, trainer_id, trainer_args=None, main_program=None, - max_num_checkpoints=3): + max_num_checkpoints=3, + lookup_table=None, + ps_endpoint_list=None): """ This function filters out all checkpoint variables from the give main_program and then saves these variables to the `checkpoint_dir` @@ -836,6 +841,12 @@ def save_checkpoint(executor, max_num_checkpoints(int): The max number of total number of existing checkpoints. Default: 3 + lookup_table(string|None): the lookup table name, when use distribute + lookup table, we can get lookup table name by DistributeTranspiler. + table_name + ps_endpoint_list(list|None): the parameter server ip:port list. + when use distribute lookup table, we can get ps_endpoint_list by + distribute arguments. Returns: None @@ -852,30 +863,40 @@ def save_checkpoint(executor, prog = fluid.default_main_program() trainer_args = {"epoch_id": 200, "step_id": 20} # just an example + table_name = "share_w" + ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"] + fluid.io.save_checkpoint(executor=exe, checkpoint_dir=path, trainer_id=0, trainer_args=trainer_args, main_program=prog, - max_num_checkpoints=3) + max_num_checkpoints=3, + lookup_table=table_name, + ps_endpoint_list = ps_endpoints) """ if checkpoint_dir is None: raise ValueError("'checkpoint_dir' should not be None") + assert checkpoint_dir if trainer_args: assert isinstance(trainer_args, dict) - if not os.path.isdir(checkpoint_dir): - os.makedirs(checkpoint_dir) + is_chief = trainer_id == 0 + _make_chekcpoint_dirs(checkpoint_dir) serial = get_latest_checkpoint_serial(checkpoint_dir) + 1 cur_dir = _get_serial_dir(checkpoint_dir, serial) save_trainer_args(cur_dir, trainer_id, trainer_args) - if trainer_id == 0: + if is_chief: save_persist_vars_without_grad(executor, cur_dir, main_program) + if is_chief and lookup_table and ps_endpoint_list: + save_pserver_vars_by_notify(executor, cur_dir, lookup_table, + ps_endpoint_list) + _scroll_delete(checkpoint_dir, max_num_checkpoints) @@ -942,8 +963,9 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program): def clean_checkpoint(checkpoint_dir, delete_dir=False): """ - clean the checkpoint dir, when the train exits normally, the trainer will call clean_checkpoint to delete checkpoint directory saved before. - delete_dir only works when the directory is empty, otherwise, OSError is raised. + clean the checkpoint dir, when the train exits normally, + the trainer will call clean_checkpoint to delete checkpoint directory saved before. + delete_dir only works when the directory is empty, otherwise, OSError is raised. : param checkpoint_dir : param delete_dir @@ -1009,6 +1031,56 @@ def load_persist_vars_without_grad(executor, filename=None) +def load_lookup_table_vars(executor, dirname, program, pserver_id, table_name): + """ + The parameter server will load lookup table's local file in + selectedrows variable. + + Args: + executor(Executor): The executor to run for loading persistable variables + dirname(str): The directory path + main_program(Program): Find the variable named table_name in main_program + pserver_id(int): the serial number in pserver_endpoints list + table_name(str): lookup table name + + Returns: + None + + Examples: + .. code-block:: python + + exe = fluid.Executor(fluid.CPUPlace()) + dirname = "./checkpoints/checkpoint_9/__model__" + prog = fluid.default_main_program() + pserver_id = 1 + table_name = "share_w" + fluid.io.load_lookup_table_vars(executor=exe, + dirname=dirname, program=prog, pserver_id=pserver_id, + table_name=table_name) + """ + + for var in program.list_vars(): + if var.name == table_name: + lookup_table_var = var + break + + assert lookup_table_var is not None + + lookup_table_dir = os.path.join(dirname, LOOKUP_TABLE_DIR) + table_file = table_name + CHECKPOINT_SEPARATOR + str(pserver_id) + + load_prog = Program() + load_block = load_prog.global_block() + + load_block.append_op( + type='load', + inputs={}, + outputs={'Out': [lookup_table_var]}, + attrs={'file_path': os.path.join(lookup_table_dir, table_file)}) + + executor.run(load_prog) + + def save_persist_vars_without_grad(executor, dirname, program): """ This function filters out all checkpoint variables from the give @@ -1055,6 +1127,54 @@ def save_persist_vars_without_grad(executor, dirname, program): _write_success(cur_dir) +def save_pserver_vars_by_notify(executor, dirname, lookup_table, + ps_endpoint_list): + """ + This function will send checkpoint notify message from Trainer 0 + to all the pservers. + The checkpoint notify message contains lookup table name, + the absolute path on pserver to save lookup_table. + + Args: + executor(Executor): The executor to run for send checkpoint notify. + dirname(str): The folder where to save checkpoints. + lookup_table(string): the lookup table name, when use distribute + lookup table, we can get lookup table name by DistributeTranspiler. + table_name + ps_endpoint_list(list): the parameter server ip:port list. + when use distribute lookup table, we can get ps_endpoint_list by + distribute arguments. + Return: + None + + Examples: + .. code-block:: python + + exe = fluid.Executor(fluid.CPUPlace()) + param_path = "./my_paddle_model" + prog = fluid.default_main_program() + table_name = "share_w" + ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"] + + fluid.io.save_pserver_vars_by_notify(executor=exe, + dirname=param_path, lookup_table=table_name, + ps_endpoint_list=ps_endpoints) + """ + cur_dir = _get_lookuptable_dir(dirname) + + checkpoint_notify_program = Program() + checkpoint_notify_block = checkpoint_notify_program.global_block() + + attrs = {} + attrs['epmap'] = ps_endpoint_list + attrs['dir'] = cur_dir + attrs['lookup_table'] = lookup_table + + checkpoint_notify_block.append_op( + type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs) + executor.run(checkpoint_notify_program) + + def save_trainer_args(dirname, trainer_id, trainer_args): assert isinstance(trainer_args, dict) @@ -1068,6 +1188,29 @@ def save_trainer_args(dirname, trainer_id, trainer_args): def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args): + """ + trainer will load some args from it's independent directory, + such as epoch_id and step_id. + + Args: + checkpoint_dir(str): The folder where all checkpoints are. + serial(int): The serial of checkpoint you would like to load. + trainer_id(int): current trainer id. + trainer_args(list): list about load trainer args + Return: + None + + Examples: + .. code-block:: python + + param_path = "./checkpoint/" + serial = 7 + trainer_id = 2 + trainer_args = ["epoch_id", "step_id"] + + fluid.io.load_trainer_args(checkpoint_dir=param_path, serial=serial, + trainer_id=trainer_id, trainer_args=trainer_args) + """ assert isinstance(trainer_args, list) cur_dir = _get_serial_dir(checkpoint_dir, serial) @@ -1088,7 +1231,7 @@ def _is_checkpoint_var(var): the checkpoint will not save or load all the variables. var type is FEED_MINIBATCH/FETCH_LIST/RAW or var name ends with @GRAD are discarded. - : param var + : param var(Variable) """ if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \ var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \ @@ -1108,6 +1251,23 @@ def _is_checkpoint_var(var): return var.persistable +def _make_chekcpoint_dirs(dirs): + """ + _make_chekcpoint_dirs will makdir local directory directly, when the directory is exist, it will igore it. + """ + assert dirs is not None + + if os.path.isfile(dirs): + raise OSError(errno.ENOTDIR, "dirs path shoule be a Directory.", dirs) + + if not os.path.isdir(dirs): + try: + os.makedirs(dirs) + except OSError as err: + if err.errno != errno.EEXIST: + raise err + + def _get_dir_serial(dirname): _, serial = dirname.split(CHECKPOINT_SEPARATOR) @@ -1121,29 +1281,27 @@ def _get_dir_serial(dirname): def _get_serial_dir(dirname, serial): serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial) serial_dir = os.path.join(dirname, serial_folder) - - if not os.path.isdir(serial_dir): - os.makedirs(serial_dir) + _make_chekcpoint_dirs(serial_dir) return serial_dir def _get_model_dir(dirname): model_dir = os.path.join(dirname, MODEL_DIR) + _make_chekcpoint_dirs(model_dir) + return model_dir - if not os.path.isdir(model_dir): - os.makedirs(model_dir) - return model_dir +def _get_lookuptable_dir(dirname): + lookuptable_dir = os.path.join(dirname, LOOKUP_TABLE_DIR) + _make_chekcpoint_dirs(lookuptable_dir) + return lookuptable_dir def _get_trainer_dir(dirname, trainer_id): trainer_folder = TRAINER_PREFIX + CHECKPOINT_SEPARATOR + str(trainer_id) trainer_dir = os.path.join(dirname, trainer_folder) - - if not os.path.isdir(trainer_dir): - os.makedirs(trainer_dir) - + _make_chekcpoint_dirs(trainer_dir) return trainer_dir @@ -1162,7 +1320,11 @@ def _scroll_delete(dirname, max_num_checkpoints=3): serials = serials[max_num_checkpoints:] for serial in serials: cur_dir = _get_serial_dir(dirname, serial) - shutil.rmtree(cur_dir) + try: + shutil.rmtree(cur_dir) + except OSError as err: + if err.errno != errno.ENOENT: + raise err def _write_success(dirname): diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py index 45ab889beaa1355d0e1e2922aedf0340f70809ba..f191ef7df5caa04537e69ad9a0e018d161cd59ad 100644 --- a/python/paddle/fluid/trainer.py +++ b/python/paddle/fluid/trainer.py @@ -119,27 +119,20 @@ class CheckpointConfig(object): max_num_checkpoints=3, epoch_interval=1, step_interval=10): - if checkpoint_dir is None: - self.checkpoint_dir = os.getcwd() - else: - self.checkpoint_dir = checkpoint_dir - - self.max_num_checkpoints = max_num_checkpoints - - if epoch_interval < 1: - self.epoch_interval = 1 - else: - self.epoch_interval = epoch_interval - if step_interval < 1: - self.step_interval = 10 - else: - self.step_interval = step_interval + assert epoch_interval >= 1 + assert step_interval >= 1 + self.checkpoint_dir = checkpoint_dir \ + if checkpoint_dir is not None else os.getcwd() + self.max_num_checkpoints = max_num_checkpoints + self.epoch_interval = epoch_interval + self.step_interval = step_interval self.epoch_id = 0 self.step_id = 0 self.load_serial = None - self.is_pserver = False + self.pserver_id = None + self.lookup_table_name = None def check_and_get_place(place): @@ -290,13 +283,20 @@ class Trainer(object): self.checkpoint_cfg.load_serial, self.startup_program) - if not self.checkpoint_cfg.is_pserver: - epoch_id, step_id = io.load_trainer_args( - self.checkpoint_cfg.checkpoint_dir, - self.checkpoint_cfg.load_serial, self.trainer_id, - self._get_checkpoint_load_args()) - self.checkpoint_cfg.epoch_id = int(epoch_id) - self.checkpoint_cfg.step_id = int(step_id) + if not self.checkpoint_cfg.pserver_id: + epoch_id, step_id = io.load_trainer_args( + self.checkpoint_cfg.checkpoint_dir, + self.checkpoint_cfg.load_serial, self.trainer_id, + self._get_checkpoint_load_args()) + self.checkpoint_cfg.epoch_id = int(epoch_id) + self.checkpoint_cfg.step_id = int(step_id) + else: + if self.checkpoint_cfg.lookup_table_name: + io.load_lookup_table_vars( + exe, self.checkpoint_cfg.checkpoint_dir, + self.startup_program, + self.checkpoint_cfg.pserver_id, + self.checkpoint_cfg.lookup_table_name) if param_path and os.path.isdir(param_path): # load params from param_path into scope @@ -366,7 +366,10 @@ class Trainer(object): self.trainer_id, pservers=pserver_endpoints, trainers=trainers) if training_role == "PSERVER": if self.checkpoint_cfg: - self.is_pserver = True + pserver_id = eplist.index(current_endpoint) + self.checkpoint_cfg.pserver_id = pserver_id + if t.has_distributed_lookup_table: + self.checkpoint_cfg.lookup_table_name = t.table_name self.train_program = t.get_pserver_program(current_endpoint) self.startup_program = t.get_startup_program(current_endpoint, @@ -566,7 +569,8 @@ class Trainer(object): def _save_checkpoint(self, epoch_id, step_id): assert self.checkpoint_cfg - if epoch_id % self.checkpoint_cfg.epoch_interval == 0 and step_id % self.checkpoint_cfg.step_interval == 0: + if epoch_id % self.checkpoint_cfg.epoch_interval == 0 \ + and step_id % self.checkpoint_cfg.step_interval == 0: exe = executor.Executor(self.place) io.save_checkpoint( executor=exe, diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 68b0b1ff4c41efe8ae50254fe9f343be4ac4b8a2..05b7beec001d6a9b1b0b23a5e996db0179f21c03 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -471,6 +471,8 @@ class DistributeTranspiler(object): pserver_index, pserver_program, pre_block_idx, grad_to_block_id) prefetch_var_name_to_block_id = self._create_prefetch_block( pserver_index, pserver_program, table_opt_block) + checkpoint_block_id = self._create_checkpoint_save_block( + pserver_program, table_opt_block.idx) # NOTE: if has_distributed_lookup_table is False, then prefetch_block will # not be executed, so it's safe to use optimize_block to hold the place @@ -489,6 +491,7 @@ class DistributeTranspiler(object): if len(prefetch_var_name_to_block_id) > 0: attrs['prefetch_var_name_to_block_id'] \ = prefetch_var_name_to_block_id + attrs['checkpint_block_id'] = checkpoint_block_id # step5 append the listen_and_serv op pserver_program.global_block().append_op( @@ -910,6 +913,27 @@ class DistributeTranspiler(object): return table_opt_block + def _create_checkpoint_save_block(self, pserver_program, pre_block_idx): + """ + create a new block to handle save checkpoint. + """ + import os + + pserver_program.global_block().create_var( + name="kLookupTablePath", + persistable=True, + type=core.VarDesc.VarType.RAW) + + checkpoint_save_block = pserver_program.create_block(pre_block_idx) + # this 'file_path' do not be used in save lookup table variable + checkpoint_save_block.append_op( + type='save', + inputs={'X': [self.table_name]}, + outputs={}, + attrs={'file_path': "none"}) + + return checkpoint_save_block.idx + def _create_vars_from_blocklist(self, program, block_list,