提交 e684575f 编写于 作者: T tangwei12

checkpoint feature optimized

上级 2229db52
......@@ -42,10 +42,11 @@ class CheckpointNotifyOp : public framework::OperatorBase {
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
for (size_t i = 0; i < epmap.size(); i++) {
VLOG(3) << "checkpoint notify sending " << dir << " to " << epmap[i];
auto serial_looku_table =
auto lookup_table_save_dir =
string::Sprintf("%s/%s_%d", dir, lookup_table_name, i);
rpc_client->AsyncCheckpointNotify(epmap[i], serial_looku_table);
rpc_client->AsyncCheckpointNotify(epmap[i], lookup_table_save_dir);
VLOG(3) << "checkpoint notify sending lookup table: " << lookup_table_name
<< " and dir:" << dir << " to " << epmap[i];
}
rpc_client->Wait();
}
......@@ -64,10 +65,10 @@ class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<std::string>("lookup_table",
"(string, default '') the lookup table name");
AddComment(R"DOC(
Prefetch operator
CheckpointNotify operator
This operator will send Ids variables to listen_and_serve op at
the parameter server and fetch result back.
This operator will send lookup table and it's checkpoint direcoty to listen_and_serve op at
the parameter server.
)DOC");
}
};
......
......@@ -25,3 +25,7 @@
#define RPCSERVER_T distributed::AsyncBRPCServer
#define RPCCLIENT_T distributed::BRPCClient
#endif
// define LOOKUP_TABLE_PATH for checkpoint notify to save lookup table variables
// to directory specified.
constexpr char LOOKUP_TABLE_PATH[] = "lookup_table_path";
......@@ -194,7 +194,7 @@ class RequestCheckpointNotify final : public RequestBase {
RequestHandler* request_handler, int req_id)
: RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
request_.reset(new VariableResponse(request_handler->scope(),
request_handler->dev_ctx(), true));
request_handler->dev_ctx()));
int method_id =
static_cast<int>(distributed::GrpcMethod::kCheckpointNotify);
service_->RequestAsyncUnary(
......@@ -212,13 +212,10 @@ class RequestCheckpointNotify final : public RequestBase {
std::string checkpoint_notify = request_->Varname();
std::string checkpoint_dir = request_->OutVarname();
framework::Variable* invar = nullptr;
framework::Variable* outvar = nullptr;
VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify
<< ", dir: " << checkpoint_dir;
request_handler_->Handle(checkpoint_notify, scope, invar, &outvar,
request_handler_->Handle(checkpoint_notify, scope, nullptr, nullptr,
checkpoint_dir);
Finish(reply_, &responder_);
}
......@@ -320,7 +317,7 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
return;
}
LOG(INFO) << "TryToRegisterNewOne on RPC NAME: " << rpc_name
VLOG(4) << "TryToRegisterNewOne on RPC NAME: " << rpc_name
<< " REQ ID: " << req_id;
auto& reqs = rpc_reqs_[rpc_name];
......
......@@ -20,6 +20,7 @@
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/string/printf.h"
......@@ -129,10 +130,10 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
checkpoint_notify_id != -1,
"when checkpoint_notify_id = -1, there should be no RPC invoke.");
auto* lt_var = scope->FindVar("loopup_table_path")->GetMutable<std::string>();
auto* lt_var = scope->FindVar(LOOKUP_TABLE_PATH)->GetMutable<std::string>();
lt_var->clear();
lt_var->append(out_var_name);
VLOG(4) << "RequestCheckpointHandler update loopup_table_path to: "
VLOG(4) << "RequestCheckpointHandler update var lookup_table_path to: "
<< out_var_name;
executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope);
return true;
......
......@@ -247,11 +247,11 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
PADDLE_ENFORCE(!rpc_service_);
std::string endpoint = Attr<std::string>("endpoint");
int checkpoint_notify_id = Attr<int>(kCheckpointBlockId);
int checkpoint_notify_block_id = Attr<int>(kCheckpointBlockId);
LOG(INFO) << "sync_mode:" << sync_mode << ", fan_in:" << fan_in
<< ", end_point:" << endpoint
<< ", CheckpointNotify Id: " << checkpoint_notify_id;
<< ", CheckpointNotify Id: " << checkpoint_notify_block_id;
rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in));
......@@ -260,7 +260,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
request_prefetch_handler_.reset(
new distributed::RequestPrefetchHandler(sync_mode));
request_checkpoint_handler_.reset(new distributed::RequestCheckpointHandler(
sync_mode, checkpoint_notify_id));
sync_mode, checkpoint_notify_block_id));
rpc_service_->RegisterRPC(distributed::kRequestSend,
request_send_handler_.get());
......@@ -276,8 +276,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
framework::Executor executor(dev_place);
std::shared_ptr<framework::ExecutorPrepareContext> ckpt_pre_context = nullptr;
if (checkpoint_notify_id != -1) {
auto ctx = executor.Prepare(*program, checkpoint_notify_id);
if (checkpoint_notify_block_id != -1) {
auto ctx = executor.Prepare(*program, checkpoint_notify_block_id);
ckpt_pre_context = std::move(ctx);
}
......@@ -334,7 +334,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
SavePort();
if (sync_mode) {
RunSyncLoop(&executor, program, &recv_scope, prefetch_block_id_list,
checkpoint_notify_id);
checkpoint_notify_block_id);
} else {
RunAsyncLoop(&executor, program);
}
......
......@@ -101,7 +101,7 @@ class LoadOp : public framework::OperatorBase {
class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddOutput("Out", "The tensor need to be loaded");
AddOutput("Out", "The LoDTensor / SelectedRows need to be loaded");
AddAttr<bool>(
"load_as_fp16",
"If true, the tensor will be first loaded and then "
......@@ -112,7 +112,9 @@ class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
R"(Variable will be loaded from "file_path")")
.AddCustomChecker(
[](const std::string &path) { return !path.empty(); });
AddComment("Load operator will load a tensor variable from disk file.");
AddComment(
"Load operator will load a LoDTensor / SelectedRows variable from disk "
"file.");
}
};
} // namespace operators
......
......@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
......@@ -131,11 +132,10 @@ class SaveOp : public framework::OperatorBase {
void SaveSelectedRows(const framework::Scope &scope,
const platform::Place &place,
framework::Variable *var) const {
auto *lt_var =
scope.FindVar("loopup_table_path")->GetMutable<std::string>();
auto *lt_var = scope.FindVar(LOOKUP_TABLE_PATH)->GetMutable<std::string>();
PADDLE_ENFORCE(
lt_var != nullptr,
"Can not find variable loopup_table_path for SaveSelectedRows");
"Can not find variable lookup_table_path for SaveSelectedRows");
std::string filename = lt_var->data();
VLOG(4) << "SaveSelectedRows get File name: " << filename;
......@@ -162,7 +162,7 @@ class SaveOpProtoMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
Save operator
This operator will serialize and write a tensor/selected rows variable to file on disk.
This operator will serialize and write LoDTensor / SelectedRows variable to file on disk.
)DOC");
AddAttr<bool>("overwrite",
"(boolean, default true)"
......@@ -186,7 +186,7 @@ class SaveOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto out_var_name = op_desc.Output("loopup_table_path").front();
auto out_var_name = op_desc.Output(LOOKUP_TABLE_PATH).front();
auto &out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarType::RAW;
out_var.SetType(var_type);
......
......@@ -1042,6 +1042,7 @@ def load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
main_program(Program): Find the variable named table_name in main_program
pserver_id(int): the serial number in pserver_endpoints list
table_name(str): lookup table name
Returns:
None
......
......@@ -914,7 +914,7 @@ class DistributeTranspiler(object):
import os
pserver_program.global_block().create_var(
name="loopup_table_path",
name="lookup_table_path",
persistable=True,
type=core.VarDesc.VarType.RAW)
......@@ -923,7 +923,10 @@ class DistributeTranspiler(object):
type='save',
inputs={'X': [self.table_name]},
outputs={},
attrs={'file_path': self.table_name})
attrs={
'file_path':
"this 'file_path' do not be used in save lookup table variable"
})
return checkpoint_save_block.idx
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册