diff --git a/paddle/fluid/operators/detail/request_handler.h b/paddle/fluid/operators/detail/request_handler.h index a2d08747d59220d30a5b8fd56074fd2739ae3bab..cb480accb4ea234cd4ec2edd1ddcc0861984a248 100644 --- a/paddle/fluid/operators/detail/request_handler.h +++ b/paddle/fluid/operators/detail/request_handler.h @@ -36,6 +36,7 @@ namespace detail { constexpr char kRequestSend[] = "RequestSend"; constexpr char kRequestGet[] = "RequestGet"; constexpr char kRequestPrefetch[] = "RequestPrefetch"; +constexpr char kRequestCheckpoint[] = "RequestCheckpoint"; #define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV" #define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV" diff --git a/paddle/fluid/operators/detail/request_handler_impl.h b/paddle/fluid/operators/detail/request_handler_impl.h index 3f77c09a9598b431d747f1b824615e49d939098e..643eae4d314383d525c6c8bbfaa71d2da2c524f2 100644 --- a/paddle/fluid/operators/detail/request_handler_impl.h +++ b/paddle/fluid/operators/detail/request_handler_impl.h @@ -66,6 +66,16 @@ class RequestPrefetchHandler final : public RequestHandler { const std::string& out_var_name = "") override; }; +class RequestCheckpointHandler final : public RequestHandler { + public: + explicit RequestCheckpointHandler(bool sync_mode) + : RequestHandler(sync_mode) {} + virtual ~RequestCheckpointHandler() {} + bool Handle(const std::string& varname, framework::Scope* scope, + framework::Variable* var, framework::Variable** outvar, + const std::string& out_var_name = "") override; +}; + } // namespace detail } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 4d12278799f66f2fb92b7580ba0c43e845aa4d3a..0804a266d0f1e15fe8ccccc37f3981805d3926e2 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -253,11 +253,15 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, request_get_handler_.reset(new detail::RequestGetHandler(sync_mode)); request_prefetch_handler_.reset( new detail::RequestPrefetchHandler(sync_mode)); + request_checkpoint_handler_.reset( + new detail::RequestCheckpointHandler(sync_mode)); rpc_service_->RegisterRPC(detail::kRequestSend, request_send_handler_.get()); rpc_service_->RegisterRPC(detail::kRequestGet, request_get_handler_.get()); rpc_service_->RegisterRPC(detail::kRequestPrefetch, request_prefetch_handler_.get()); + rpc_service_->RegisterRPC(detail::kRequestCheckpoint, + request_checkpoint_handler_.get()); auto *optimize_block = Attr(kOptimizeBlock); auto *program = optimize_block->Program(); @@ -300,6 +304,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, f(request_send_handler_.get()); f(request_get_handler_.get()); f(request_prefetch_handler_.get()); + f(request_checkpoint_handler_.get()); // start the server listening after all member initialized. server_thread_.reset(new std::thread(RunServer, rpc_service_)); @@ -344,6 +349,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { .SetDefault({}); AddAttr("Fanin", "How many clients send to this server.") .SetDefault(1); + AddAttr(kCheckpointBlockId, + "BolckID to run save checkpoint on pserer.") + .SetDefault(-1); } }; diff --git a/paddle/fluid/operators/listen_and_serv_op.h b/paddle/fluid/operators/listen_and_serv_op.h index 46c3a19e20b3f2dd970a672bb99f98e83d3e25bf..b00ad195e9e162a4c911af64dd19bc7cc0ef3775 100644 --- a/paddle/fluid/operators/listen_and_serv_op.h +++ b/paddle/fluid/operators/listen_and_serv_op.h @@ -32,6 +32,7 @@ namespace operators { constexpr char kOptimizeBlock[] = "OptimizeBlock"; constexpr char kPrefetchVarNameToBlockId[] = "prefetch_var_name_to_block_id"; +constexpr char kCheckpointBlockId[] = "checkpint_block_id"; void RunServer(std::shared_ptr service); @@ -66,6 +67,7 @@ class ListenAndServOp : public framework::OperatorBase { mutable std::shared_ptr request_send_handler_; mutable std::shared_ptr request_get_handler_; mutable std::shared_ptr request_prefetch_handler_; + mutable std::shared_ptr request_checkpoint_handler_; mutable std::shared_ptr server_thread_; }; diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index e6d27e2dedd7668b93bd8ddc330a897d1c6fa732..410796eeb6cd1e13de2e2699f639033d8525f9ed 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -22,6 +22,7 @@ limitations under the License. */ #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/platform/device_context.h" namespace paddle { @@ -78,26 +79,37 @@ class SaveOp : public framework::OperatorBase { MkDirRecursively(DirName(filename).c_str()); - // FIXME(yuyang18): We save variable to local file now, but we should change - // it to save an output stream. - std::ofstream fout(filename); - PADDLE_ENFORCE(static_cast(fout), "Cannot open %s to write", - filename); - auto iname = Input("X"); auto *var = scope.FindVar(iname); PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s for save_op", iname); - PADDLE_ENFORCE(var->IsType(), - "SaveOp only support LoDTensor, %s has wrong type", iname); + if (var->IsType()) { + SaveLodTensor(filename, place, var); + } else if (var->IsType()) { + SaveSelectedRows(filename, place, var); + } else { + PADDLE_ENFORCE( + false, + "SaveOp only support LoDTensor and SelectedRows, %s has wrong type", + iname); + } + } + SaveLodTensor(const string &filename, const platform::Place &place, + Variable *var) { auto &tensor = var->Get(); // get device context from pool platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(place); + // FIXME(yuyang18): We save variable to local file now, but we should change + // it to save an output stream. + std::ofstream fout(filename); + PADDLE_ENFORCE(static_cast(fout), "Cannot open %s to write", + filename); + auto in_dtype = framework::ToDataType(tensor.type()); auto out_dtype = save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype; @@ -112,17 +124,35 @@ class SaveOp : public framework::OperatorBase { } else { framework::SerializeToStream(fout, tensor, dev_ctx); } + fout.close() + } + + SaveSelectedRows(const string &filename, const platform::Place &place, + Variable *var) { + auto &selectedRows = var->Get(); + + // get device context from pool + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); + + // FIXME(yuyang18): We save variable to local file now, but we should change + // it to save an output stream. + std::ofstream fout(filename); + PADDLE_ENFORCE(static_cast(fout), "Cannot open %s to write", + filename); + framework::SerializeToStream(fout, selectedRows, dev_ctx); + fout.close() } }; class SaveOpProtoMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { - AddInput("X", "(Tensor ) Input tensor to be saved"); + AddInput("X", "(Tensor ) Input LoDTensor and SelectedRows to be saved"); AddComment(R"DOC( Save operator -This operator will serialize and write a tensor variable to file on disk. +This operator will serialize and write a tensor/selected rows variable to file on disk. )DOC"); AddAttr("overwrite", "(boolean, default true)" diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 2480d4e76a1b5fd76b7dc8299c2f8fcae967145e..caad745b1fb7392fcb7164a8a3b03b166767db5f 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -522,6 +522,8 @@ class DistributeTranspiler: pserver_index, pserver_program, pre_block_idx, grad_to_block_id) prefetch_var_name_to_block_id = self._create_prefetch_block( pserver_index, pserver_program, table_opt_block) + checkpoint_block_id = self._create_checkpoint_save_block( + pserver_program, table_opt_block.idx) # NOTE: if has_distributed_lookup_table is False, then prefetch_block will # not be executed, so it's safe to use optimize_block to hold the place @@ -540,6 +542,7 @@ class DistributeTranspiler: if len(prefetch_var_name_to_block_id) > 0: attrs['prefetch_var_name_to_block_id'] \ = prefetch_var_name_to_block_id + attrs['checkpint_block_id'] = checkpoint_block_id # step5 append the listen_and_serv op pserver_program.global_block().append_op( @@ -824,6 +827,23 @@ class DistributeTranspiler: return table_opt_block + def _create_checkpoint_save_block(self, pserver_program, pre_block_idx): + """ + create a new block to handle save checkpoint. + """ + import os + + checkpoint_save_block = pserver_program.create_block(pre_block_idx) + checkpoint_save_block.append_op( + type='save', + inputs={'X': [self.table_name]}, + outputs={}, + attrs={ + 'file_path': os.path.join("/tmp/pserver_ckpt/", self.table_name) + }) + + return checkpoint_save_block.idx + def _create_vars_from_blocklist(self, program, block_list,