提交 41701969 编写于 作者: T tangwei12

[wip] ckpt m2 develop

上级 431491a2
......@@ -36,6 +36,7 @@ namespace detail {
// Names under which request handlers are registered with the RPC service
// (each one is passed as the first argument of RegisterRPC, paired with the
// handler that serves that kind of request).
constexpr char kRequestSend[] = "RequestSend";
constexpr char kRequestGet[] = "RequestGet";
constexpr char kRequestPrefetch[] = "RequestPrefetch";
// Name used to register the checkpoint request handler.
constexpr char kRequestCheckpoint[] = "RequestCheckpoint";
// NOTE(review): these look like reserved message names used for control
// signalling (terminate the listener / batch barrier) -- confirm against the
// code that sends and receives them.
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
......
......@@ -66,6 +66,16 @@ class RequestPrefetchHandler final : public RequestHandler {
const std::string& out_var_name = "") override;
};
// Request handler that ListenAndServOp registers under kRequestCheckpoint.
// Only the declaration lives here; Handle() is defined out of line.
// NOTE(review): presumably Handle() triggers the pserver-side checkpoint save
// block -- confirm against the implementation.
class RequestCheckpointHandler final : public RequestHandler {
public:
// Forwards sync_mode unchanged to the RequestHandler base class.
explicit RequestCheckpointHandler(bool sync_mode)
: RequestHandler(sync_mode) {}
virtual ~RequestCheckpointHandler() {}
// Overrides RequestHandler::Handle with the same parameter list as the
// sibling handlers: the variable name, the scope, the variable itself, an
// output variable slot and an optional output variable name.
// NOTE(review): the meaning of the bool return value is not visible from this
// declaration -- see RequestHandler.
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const std::string& out_var_name = "") override;
};
} // namespace detail
} // namespace operators
} // namespace paddle
......@@ -253,11 +253,15 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
request_get_handler_.reset(new detail::RequestGetHandler(sync_mode));
request_prefetch_handler_.reset(
new detail::RequestPrefetchHandler(sync_mode));
request_checkpoint_handler_.reset(
new detail::RequestCheckpointHandler(sync_mode));
rpc_service_->RegisterRPC(detail::kRequestSend, request_send_handler_.get());
rpc_service_->RegisterRPC(detail::kRequestGet, request_get_handler_.get());
rpc_service_->RegisterRPC(detail::kRequestPrefetch,
request_prefetch_handler_.get());
rpc_service_->RegisterRPC(detail::kRequestCheckpoint,
request_checkpoint_handler_.get());
auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *program = optimize_block->Program();
......@@ -300,6 +304,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
f(request_send_handler_.get());
f(request_get_handler_.get());
f(request_prefetch_handler_.get());
f(request_checkpoint_handler_.get());
// start the server listening after all member initialized.
server_thread_.reset(new std::thread(RunServer, rpc_service_));
......@@ -344,6 +349,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault({});
AddAttr<int>("Fanin", "How many clients send to this server.")
.SetDefault(1);
// Integer attribute holding the id of the block that saves a checkpoint on
// the pserver. It defaults to -1.
// NOTE(review): -1 presumably means "no checkpoint block" -- confirm where
// the attribute is read.
AddAttr<int>(kCheckpointBlockId,
             "BlockID to run save checkpoint on pserver.")
    .SetDefault(-1);
}
};
......
......@@ -32,6 +32,7 @@ namespace operators {
// Attribute names of the listen_and_serv op.
// kOptimizeBlock is read in RunImpl as a framework::BlockDesc* attribute.
constexpr char kOptimizeBlock[] = "OptimizeBlock";
// Same key the Python transpiler writes as
// attrs['prefetch_var_name_to_block_id'].
constexpr char kPrefetchVarNameToBlockId[] = "prefetch_var_name_to_block_id";
// Integer attribute naming the checkpoint-save block.
// NOTE: the value is misspelled ("checkpint"), but it must stay byte-identical
// to the key the Python transpiler sets (attrs['checkpint_block_id']); rename
// both sides together or not at all.
constexpr char kCheckpointBlockId[] = "checkpint_block_id";
// Entry point of the server thread; it is handed to std::thread together with
// the RPC service instance.
void RunServer(std::shared_ptr<detail::RPCServer> service);
......@@ -66,6 +67,7 @@ class ListenAndServOp : public framework::OperatorBase {
mutable std::shared_ptr<detail::RequestHandler> request_send_handler_;
mutable std::shared_ptr<detail::RequestHandler> request_get_handler_;
mutable std::shared_ptr<detail::RequestHandler> request_prefetch_handler_;
mutable std::shared_ptr<detail::RequestHandler> request_checkpoint_handler_;
mutable std::shared_ptr<std::thread> server_thread_;
};
......
......@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
......@@ -78,26 +79,37 @@ class SaveOp : public framework::OperatorBase {
MkDirRecursively(DirName(filename).c_str());
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
std::ofstream fout(filename);
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
filename);
auto iname = Input("X");
auto *var = scope.FindVar(iname);
PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s for save_op",
iname);
PADDLE_ENFORCE(var->IsType<framework::LoDTensor>(),
"SaveOp only support LoDTensor, %s has wrong type", iname);
if (var->IsType<framework::LoDTensor>()) {
SaveLodTensor(filename, place, var);
} else if (var->IsType<framework::SelectedRows>()) {
SaveSelectedRows(filename, place, var);
} else {
PADDLE_ENFORCE(
false,
"SaveOp only support LoDTensor and SelectedRows, %s has wrong type",
iname);
}
}
SaveLodTensor(const string &filename, const platform::Place &place,
Variable *var) {
auto &tensor = var->Get<framework::LoDTensor>();
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
std::ofstream fout(filename);
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
filename);
auto in_dtype = framework::ToDataType(tensor.type());
auto out_dtype = save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
......@@ -112,17 +124,35 @@ class SaveOp : public framework::OperatorBase {
} else {
framework::SerializeToStream(fout, tensor, dev_ctx);
}
fout.close()
}
SaveSelectedRows(const string &filename, const platform::Place &place,
Variable *var) {
auto &selectedRows = var->Get<framework::SelectedRows>();
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
std::ofstream fout(filename);
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
filename);
framework::SerializeToStream(fout, selectedRows, dev_ctx);
fout.close()
}
};
class SaveOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor ) Input tensor to be saved");
AddInput("X", "(Tensor ) Input LoDTensor and SelectedRows to be saved");
AddComment(R"DOC(
Save operator
This operator will serialize and write a tensor variable to file on disk.
This operator will serialize and write a tensor/selected rows variable to file on disk.
)DOC");
AddAttr<bool>("overwrite",
"(boolean, default true)"
......
......@@ -522,6 +522,8 @@ class DistributeTranspiler:
pserver_index, pserver_program, pre_block_idx, grad_to_block_id)
prefetch_var_name_to_block_id = self._create_prefetch_block(
pserver_index, pserver_program, table_opt_block)
checkpoint_block_id = self._create_checkpoint_save_block(
pserver_program, table_opt_block.idx)
# NOTE: if has_distributed_lookup_table is False, then prefetch_block will
# not be executed, so it's safe to use optimize_block to hold the place
......@@ -540,6 +542,7 @@ class DistributeTranspiler:
if len(prefetch_var_name_to_block_id) > 0:
attrs['prefetch_var_name_to_block_id'] \
= prefetch_var_name_to_block_id
attrs['checkpint_block_id'] = checkpoint_block_id
# step5 append the listen_and_serv op
pserver_program.global_block().append_op(
......@@ -824,6 +827,23 @@ class DistributeTranspiler:
return table_opt_block
def _create_checkpoint_save_block(self, pserver_program, pre_block_idx):
"""
create a new block to handle save checkpoint.
"""
import os
checkpoint_save_block = pserver_program.create_block(pre_block_idx)
checkpoint_save_block.append_op(
type='save',
inputs={'X': [self.table_name]},
outputs={},
attrs={
'file_path': os.path.join("/tmp/pserver_ckpt/", self.table_name)
})
return checkpoint_save_block.idx
def _create_vars_from_blocklist(self,
program,
block_list,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册