Commit 41701969 authored by tangwei12

[wip] ckpt m2 develop

Parent 431491a2
@@ -36,6 +36,7 @@ namespace detail {
 constexpr char kRequestSend[] = "RequestSend";
 constexpr char kRequestGet[] = "RequestGet";
 constexpr char kRequestPrefetch[] = "RequestPrefetch";
+constexpr char kRequestCheckpoint[] = "RequestCheckpoint";
 #define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
 #define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
...
@@ -66,6 +66,16 @@ class RequestPrefetchHandler final : public RequestHandler {
                       const std::string& out_var_name = "") override;
 };
+class RequestCheckpointHandler final : public RequestHandler {
+ public:
+  explicit RequestCheckpointHandler(bool sync_mode)
+      : RequestHandler(sync_mode) {}
+  virtual ~RequestCheckpointHandler() {}
+  bool Handle(const std::string& varname, framework::Scope* scope,
+              framework::Variable* var, framework::Variable** outvar,
+              const std::string& out_var_name = "") override;
+};
 }  // namespace detail
 }  // namespace operators
 }  // namespace paddle
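Note: this WIP commit declares `RequestCheckpointHandler` but does not yet include its `Handle` definition. Below is a minimal, self-contained sketch of the shape such a handler could take; the `sketch` stub types and the body are assumptions for illustration, not Paddle's actual implementation.

```cpp
// Illustrative sketch only: stub types stand in for Paddle's framework
// classes; the real Handle body is not part of this WIP commit.
#include <iostream>
#include <string>

namespace sketch {
struct Scope {};     // stands in for framework::Scope
struct Variable {};  // stands in for framework::Variable

class RequestHandler {
 public:
  explicit RequestHandler(bool sync_mode) : sync_mode_(sync_mode) {}
  virtual ~RequestHandler() = default;
  virtual bool Handle(const std::string& varname, Scope* scope, Variable* var,
                      Variable** outvar,
                      const std::string& out_var_name = "") = 0;

 protected:
  bool sync_mode_;
};

class RequestCheckpointHandler final : public RequestHandler {
 public:
  explicit RequestCheckpointHandler(bool sync_mode)
      : RequestHandler(sync_mode) {}
  bool Handle(const std::string& varname, Scope* scope, Variable* var,
              Variable** outvar,
              const std::string& out_var_name = "") override {
    // Plausible body: run the pre-built checkpoint block (the `save` op
    // appended by the transpiler) against the server-side scope.
    std::cout << "checkpoint requested for " << varname << "\n";
    return true;
  }
};
}  // namespace sketch

int main() {
  sketch::RequestCheckpointHandler handler(/*sync_mode=*/true);
  sketch::Scope scope;
  return handler.Handle("pserver_ckpt", &scope, nullptr, nullptr) ? 0 : 1;
}
```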
@@ -253,11 +253,15 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
   request_get_handler_.reset(new detail::RequestGetHandler(sync_mode));
   request_prefetch_handler_.reset(
       new detail::RequestPrefetchHandler(sync_mode));
+  request_checkpoint_handler_.reset(
+      new detail::RequestCheckpointHandler(sync_mode));
   rpc_service_->RegisterRPC(detail::kRequestSend, request_send_handler_.get());
   rpc_service_->RegisterRPC(detail::kRequestGet, request_get_handler_.get());
   rpc_service_->RegisterRPC(detail::kRequestPrefetch,
                             request_prefetch_handler_.get());
+  rpc_service_->RegisterRPC(detail::kRequestCheckpoint,
+                            request_checkpoint_handler_.get());
   auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
   auto *program = optimize_block->Program();
@@ -300,6 +304,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
   f(request_send_handler_.get());
   f(request_get_handler_.get());
   f(request_prefetch_handler_.get());
+  f(request_checkpoint_handler_.get());
   // start the server listening after all member initialized.
   server_thread_.reset(new std::thread(RunServer, rpc_service_));
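For context, this style of `RegisterRPC` registration generally boils down to a string-keyed handler map. A standalone sketch of the dispatch pattern follows; the stub `Handler` and `RpcServer` types are assumptions, not Paddle's `detail::RPCServer`.

```cpp
// Standalone sketch of the name -> handler registry pattern used by
// RegisterRPC above (stub types, not Paddle's actual RPC server).
#include <cassert>
#include <string>
#include <unordered_map>

struct Handler { virtual ~Handler() = default; };
struct CheckpointHandler : Handler {};

class RpcServer {
 public:
  void RegisterRPC(const std::string& name, Handler* h) { handlers_[name] = h; }
  Handler* Find(const std::string& name) {
    auto it = handlers_.find(name);
    return it == handlers_.end() ? nullptr : it->second;
  }

 private:
  std::unordered_map<std::string, Handler*> handlers_;
};

int main() {
  constexpr char kRequestCheckpoint[] = "RequestCheckpoint";
  CheckpointHandler ckpt;
  RpcServer server;
  server.RegisterRPC(kRequestCheckpoint, &ckpt);
  // Each incoming request is dispatched by its registered name.
  assert(server.Find(kRequestCheckpoint) == &ckpt);
  return 0;
}
```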
@@ -344,6 +349,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
         .SetDefault({});
     AddAttr<int>("Fanin", "How many clients send to this server.")
         .SetDefault(1);
+    AddAttr<int>(kCheckpointBlockId,
+                 "BlockID to run save checkpoint on pserver.")
+        .SetDefault(-1);
   }
 };
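The attribute defaults to `-1`, which reads as a "no checkpoint block attached" sentinel. The WIP diff does not yet show where `kCheckpointBlockId` is consumed, so the gating sketch below is an assumption about the intended use.

```cpp
#include <iostream>

// Hypothetical gate: -1 (the attribute's default) would mean the transpiler
// did not attach a checkpoint block, so checkpointing is skipped entirely.
void MaybeRunCheckpointBlock(int checkpoint_block_id) {
  if (checkpoint_block_id == -1) return;  // feature disabled by default
  std::cout << "would execute block " << checkpoint_block_id
            << " to save a checkpoint\n";
}

int main() {
  MaybeRunCheckpointBlock(-1);  // default: no checkpoint block
  MaybeRunCheckpointBlock(3);   // id appended by the transpiler
  return 0;
}
```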
...
@@ -32,6 +32,7 @@ namespace operators {
 constexpr char kOptimizeBlock[] = "OptimizeBlock";
 constexpr char kPrefetchVarNameToBlockId[] = "prefetch_var_name_to_block_id";
+constexpr char kCheckpointBlockId[] = "checkpoint_block_id";
 void RunServer(std::shared_ptr<detail::RPCServer> service);
@@ -66,6 +67,7 @@ class ListenAndServOp : public framework::OperatorBase {
   mutable std::shared_ptr<detail::RequestHandler> request_send_handler_;
   mutable std::shared_ptr<detail::RequestHandler> request_get_handler_;
   mutable std::shared_ptr<detail::RequestHandler> request_prefetch_handler_;
+  mutable std::shared_ptr<detail::RequestHandler> request_checkpoint_handler_;
   mutable std::shared_ptr<std::thread> server_thread_;
 };
...
@@ -22,6 +22,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/selected_rows.h"
 #include "paddle/fluid/platform/device_context.h"
 namespace paddle {
@@ -78,26 +79,37 @@ class SaveOp : public framework::OperatorBase {
     MkDirRecursively(DirName(filename).c_str());
-    // FIXME(yuyang18): We save variable to local file now, but we should change
-    // it to save an output stream.
-    std::ofstream fout(filename);
-    PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
-                   filename);
     auto iname = Input("X");
     auto *var = scope.FindVar(iname);
     PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s for save_op",
                    iname);
-    PADDLE_ENFORCE(var->IsType<framework::LoDTensor>(),
-                   "SaveOp only support LoDTensor, %s has wrong type", iname);
+    if (var->IsType<framework::LoDTensor>()) {
+      SaveLodTensor(filename, place, var);
+    } else if (var->IsType<framework::SelectedRows>()) {
+      SaveSelectedRows(filename, place, var);
+    } else {
+      PADDLE_ENFORCE(
+          false,
+          "SaveOp only supports LoDTensor and SelectedRows, %s has wrong type",
+          iname);
+    }
+  }
+
+  void SaveLodTensor(const std::string &filename,
+                     const platform::Place &place,
+                     framework::Variable *var) const {
     auto &tensor = var->Get<framework::LoDTensor>();
     // get device context from pool
     platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
     auto &dev_ctx = *pool.Get(place);
+    // FIXME(yuyang18): We save variable to local file now, but we should change
+    // it to save an output stream.
+    std::ofstream fout(filename);
+    PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
+                   filename);
+    auto save_as_fp16 = Attr<bool>("save_as_fp16");
     auto in_dtype = framework::ToDataType(tensor.type());
     auto out_dtype = save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
@@ -112,17 +124,35 @@ class SaveOp : public framework::OperatorBase {
     } else {
       framework::SerializeToStream(fout, tensor, dev_ctx);
     }
+    fout.close();
+  }
+
+  void SaveSelectedRows(const std::string &filename,
+                        const platform::Place &place,
+                        framework::Variable *var) const {
+    auto &selectedRows = var->Get<framework::SelectedRows>();
+    // get device context from pool
+    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+    auto &dev_ctx = *pool.Get(place);
+    // FIXME(yuyang18): We save variable to local file now, but we should change
+    // it to save an output stream.
+    std::ofstream fout(filename);
+    PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
+                   filename);
+    framework::SerializeToStream(fout, selectedRows, dev_ctx);
+    fout.close();
   }
 };
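The refactor's structure — branch on the variable's runtime type, open the output stream inside each branch, then serialize — can be shown with a standalone analogue. In the sketch below, `std::variant` stands in for `framework::Variable`'s type tagging; the types and the byte layout are illustrative assumptions, not Paddle's serialization format.

```cpp
// Dispatch-then-serialize sketch mirroring SaveOp's new shape.
#include <cstdint>
#include <fstream>
#include <iostream>
#include <string>
#include <variant>
#include <vector>

struct DenseTensor { std::vector<float> data; };          // ~ LoDTensor
struct SparseRows  { std::vector<int64_t> rows;           // ~ SelectedRows
                     std::vector<float> data; };
using Variable = std::variant<DenseTensor, SparseRows>;

void Save(const std::string& filename, const Variable& var) {
  // Open the stream per-branch target; bail out if the file is unwritable.
  std::ofstream fout(filename, std::ios::binary);
  if (!fout) { std::cerr << "Cannot open " << filename << "\n"; return; }
  if (const auto* t = std::get_if<DenseTensor>(&var)) {
    fout.write(reinterpret_cast<const char*>(t->data.data()),
               static_cast<std::streamsize>(t->data.size() * sizeof(float)));
  } else if (const auto* s = std::get_if<SparseRows>(&var)) {
    fout.write(reinterpret_cast<const char*>(s->rows.data()),
               static_cast<std::streamsize>(s->rows.size() * sizeof(int64_t)));
    fout.write(reinterpret_cast<const char*>(s->data.data()),
               static_cast<std::streamsize>(s->data.size() * sizeof(float)));
  }
  fout.close();
}

int main() {
  Save("/tmp/dense.bin", DenseTensor{{1.f, 2.f}});
  Save("/tmp/sparse.bin", SparseRows{{0, 7}, {3.f, 4.f}});
  return 0;
}
```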
 class SaveOpProtoMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
-    AddInput("X", "(Tensor ) Input tensor to be saved");
+    AddInput("X", "(Tensor) Input LoDTensor or SelectedRows to be saved");
     AddComment(R"DOC(
 Save operator
-This operator will serialize and write a tensor variable to file on disk.
+This operator will serialize and write a LoDTensor / SelectedRows variable to file on disk.
 )DOC");
     AddAttr<bool>("overwrite",
                   "(boolean, default true)"
...
@@ -522,6 +522,8 @@ class DistributeTranspiler:
             pserver_index, pserver_program, pre_block_idx, grad_to_block_id)
         prefetch_var_name_to_block_id = self._create_prefetch_block(
             pserver_index, pserver_program, table_opt_block)
+        checkpoint_block_id = self._create_checkpoint_save_block(
+            pserver_program, table_opt_block.idx)
         # NOTE: if has_distributed_lookup_table is False, then prefetch_block will
         # not be executed, so it's safe to use optimize_block to hold the place
@@ -540,6 +542,7 @@ class DistributeTranspiler:
         if len(prefetch_var_name_to_block_id) > 0:
             attrs['prefetch_var_name_to_block_id'] \
                 = prefetch_var_name_to_block_id
+        attrs['checkpoint_block_id'] = checkpoint_block_id
         # step5 append the listen_and_serv op
         pserver_program.global_block().append_op(
@@ -824,6 +827,23 @@ class DistributeTranspiler:
         return table_opt_block
+    def _create_checkpoint_save_block(self, pserver_program, pre_block_idx):
+        """
+        Create a new block to handle checkpoint saving.
+        """
+        import os
+        checkpoint_save_block = pserver_program.create_block(pre_block_idx)
+        checkpoint_save_block.append_op(
+            type='save',
+            inputs={'X': [self.table_name]},
+            outputs={},
+            attrs={
+                'file_path': os.path.join("/tmp/pserver_ckpt/",
+                                          self.table_name)
+            })
+        return checkpoint_save_block.idx
     def _create_vars_from_blocklist(self,
                                     program,
                                     block_list,
...