From 41701969a9e73fe85bbcbf99265cb84ecf512f4d Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 13 Jun 2018 22:34:00 +0800 Subject: [PATCH] [wip] ckpt m2 develop --- .../fluid/operators/detail/request_handler.h | 1 + .../operators/detail/request_handler_impl.h | 10 ++++ paddle/fluid/operators/listen_and_serv_op.cc | 8 +++ paddle/fluid/operators/listen_and_serv_op.h | 2 + paddle/fluid/operators/save_op.cc | 50 +++++++++++++++---- .../fluid/transpiler/distribute_transpiler.py | 20 ++++++++ 6 files changed, 81 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/operators/detail/request_handler.h b/paddle/fluid/operators/detail/request_handler.h index a2d08747d59..cb480accb4e 100644 --- a/paddle/fluid/operators/detail/request_handler.h +++ b/paddle/fluid/operators/detail/request_handler.h @@ -36,6 +36,7 @@ namespace detail { constexpr char kRequestSend[] = "RequestSend"; constexpr char kRequestGet[] = "RequestGet"; constexpr char kRequestPrefetch[] = "RequestPrefetch"; +constexpr char kRequestCheckpoint[] = "RequestCheckpoint"; #define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV" #define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV" diff --git a/paddle/fluid/operators/detail/request_handler_impl.h b/paddle/fluid/operators/detail/request_handler_impl.h index 3f77c09a959..643eae4d314 100644 --- a/paddle/fluid/operators/detail/request_handler_impl.h +++ b/paddle/fluid/operators/detail/request_handler_impl.h @@ -66,6 +66,16 @@ class RequestPrefetchHandler final : public RequestHandler { const std::string& out_var_name = "") override; }; +class RequestCheckpointHandler final : public RequestHandler { + public: + explicit RequestCheckpointHandler(bool sync_mode) + : RequestHandler(sync_mode) {} + virtual ~RequestCheckpointHandler() {} + bool Handle(const std::string& varname, framework::Scope* scope, + framework::Variable* var, framework::Variable** outvar, + const std::string& out_var_name = "") override; +}; + } // namespace detail } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 4d12278799f..0804a266d0f 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -253,11 +253,15 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, request_get_handler_.reset(new detail::RequestGetHandler(sync_mode)); request_prefetch_handler_.reset( new detail::RequestPrefetchHandler(sync_mode)); + request_checkpoint_handler_.reset( + new detail::RequestCheckpointHandler(sync_mode)); rpc_service_->RegisterRPC(detail::kRequestSend, request_send_handler_.get()); rpc_service_->RegisterRPC(detail::kRequestGet, request_get_handler_.get()); rpc_service_->RegisterRPC(detail::kRequestPrefetch, request_prefetch_handler_.get()); + rpc_service_->RegisterRPC(detail::kRequestCheckpoint, + request_checkpoint_handler_.get()); auto *optimize_block = Attr(kOptimizeBlock); auto *program = optimize_block->Program(); @@ -300,6 +304,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, f(request_send_handler_.get()); f(request_get_handler_.get()); f(request_prefetch_handler_.get()); + f(request_checkpoint_handler_.get()); // start the server listening after all member initialized. server_thread_.reset(new std::thread(RunServer, rpc_service_)); @@ -344,6 +349,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { .SetDefault({}); AddAttr("Fanin", "How many clients send to this server.") .SetDefault(1); + AddAttr(kCheckpointBlockId, + "BolckID to run save checkpoint on pserer.") + .SetDefault(-1); } }; diff --git a/paddle/fluid/operators/listen_and_serv_op.h b/paddle/fluid/operators/listen_and_serv_op.h index 46c3a19e20b..b00ad195e9e 100644 --- a/paddle/fluid/operators/listen_and_serv_op.h +++ b/paddle/fluid/operators/listen_and_serv_op.h @@ -32,6 +32,7 @@ namespace operators { constexpr char kOptimizeBlock[] = "OptimizeBlock"; constexpr char kPrefetchVarNameToBlockId[] = "prefetch_var_name_to_block_id"; +constexpr char kCheckpointBlockId[] = "checkpint_block_id"; void RunServer(std::shared_ptr service); @@ -66,6 +67,7 @@ class ListenAndServOp : public framework::OperatorBase { mutable std::shared_ptr request_send_handler_; mutable std::shared_ptr request_get_handler_; mutable std::shared_ptr request_prefetch_handler_; + mutable std::shared_ptr request_checkpoint_handler_; mutable std::shared_ptr server_thread_; }; diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index e6d27e2dedd..410796eeb6c 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -22,6 +22,7 @@ limitations under the License. */ #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/platform/device_context.h" namespace paddle { @@ -78,26 +79,37 @@ class SaveOp : public framework::OperatorBase { MkDirRecursively(DirName(filename).c_str()); - // FIXME(yuyang18): We save variable to local file now, but we should change - // it to save an output stream. - std::ofstream fout(filename); - PADDLE_ENFORCE(static_cast(fout), "Cannot open %s to write", - filename); - auto iname = Input("X"); auto *var = scope.FindVar(iname); PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s for save_op", iname); - PADDLE_ENFORCE(var->IsType(), - "SaveOp only support LoDTensor, %s has wrong type", iname); + if (var->IsType()) { + SaveLodTensor(filename, place, var); + } else if (var->IsType()) { + SaveSelectedRows(filename, place, var); + } else { + PADDLE_ENFORCE( + false, + "SaveOp only support LoDTensor and SelectedRows, %s has wrong type", + iname); + } + } + SaveLodTensor(const string &filename, const platform::Place &place, + Variable *var) { auto &tensor = var->Get(); // get device context from pool platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(place); + // FIXME(yuyang18): We save variable to local file now, but we should change + // it to save an output stream. + std::ofstream fout(filename); + PADDLE_ENFORCE(static_cast(fout), "Cannot open %s to write", + filename); + auto in_dtype = framework::ToDataType(tensor.type()); auto out_dtype = save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype; @@ -112,17 +124,35 @@ class SaveOp : public framework::OperatorBase { } else { framework::SerializeToStream(fout, tensor, dev_ctx); } + fout.close() + } + + SaveSelectedRows(const string &filename, const platform::Place &place, + Variable *var) { + auto &selectedRows = var->Get(); + + // get device context from pool + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); + + // FIXME(yuyang18): We save variable to local file now, but we should change + // it to save an output stream. + std::ofstream fout(filename); + PADDLE_ENFORCE(static_cast(fout), "Cannot open %s to write", + filename); + framework::SerializeToStream(fout, selectedRows, dev_ctx); + fout.close() } }; class SaveOpProtoMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { - AddInput("X", "(Tensor ) Input tensor to be saved"); + AddInput("X", "(Tensor ) Input LoDTensor and SelectedRows to be saved"); AddComment(R"DOC( Save operator -This operator will serialize and write a tensor variable to file on disk. +This operator will serialize and write a tensor/selected rows variable to file on disk. )DOC"); AddAttr("overwrite", "(boolean, default true)" diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 2480d4e76a1..caad745b1fb 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -522,6 +522,8 @@ class DistributeTranspiler: pserver_index, pserver_program, pre_block_idx, grad_to_block_id) prefetch_var_name_to_block_id = self._create_prefetch_block( pserver_index, pserver_program, table_opt_block) + checkpoint_block_id = self._create_checkpoint_save_block( + pserver_program, table_opt_block.idx) # NOTE: if has_distributed_lookup_table is False, then prefetch_block will # not be executed, so it's safe to use optimize_block to hold the place @@ -540,6 +542,7 @@ class DistributeTranspiler: if len(prefetch_var_name_to_block_id) > 0: attrs['prefetch_var_name_to_block_id'] \ = prefetch_var_name_to_block_id + attrs['checkpint_block_id'] = checkpoint_block_id # step5 append the listen_and_serv op pserver_program.global_block().append_op( @@ -824,6 +827,23 @@ class DistributeTranspiler: return table_opt_block + def _create_checkpoint_save_block(self, pserver_program, pre_block_idx): + """ + create a new block to handle save checkpoint. + """ + import os + + checkpoint_save_block = pserver_program.create_block(pre_block_idx) + checkpoint_save_block.append_op( + type='save', + inputs={'X': [self.table_name]}, + outputs={}, + attrs={ + 'file_path': os.path.join("/tmp/pserver_ckpt/", self.table_name) + }) + + return checkpoint_save_block.idx + def _create_vars_from_blocklist(self, program, block_list, -- GitLab