Unverified commit 8b50ad80, authored by tangwei12, committed by GitHub

Checkpoint support for distributed training (#14854)

Adds checkpoint (save/load of persistables) support for distributed training.
Parent 07dc5a15
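At a high level, this change makes fluid.io.save_persistables / load_persistables detect a transpiled (distributed) program and route to the new _save_distributed_persistables / _load_distributed_persistables paths. Below is a minimal usage sketch pieced together from the docstrings and tests in this diff; the transpile(...) arguments are elided exactly as in those docstrings, the endpoint is a placeholder, and fluid.DistributeTranspiler refers to the same class the docstrings call distribute_transpiler.DistributeTranspiler():

import paddle.fluid as fluid

exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
current_endpoint = "127.0.0.1:6174"  # placeholder pserver endpoint

t = fluid.DistributeTranspiler()
t.transpile(...)  # trainer_id, program, pservers, trainers, current_endpoint, ...

# Trainer side: save_persistables sees main_program._is_distributed and calls
# _save_distributed_persistables (local vars are saved directly; pserver-side
# optimizer/prefetch vars are pulled back through recv ops with with_barrier=False).
train_program = t.get_trainer_program()
fluid.io.save_persistables(executor=exe, dirname=param_path,
                           main_program=train_program)

# Pserver side: load_persistables calls _load_distributed_persistables, which
# reloads this server's own parameter slices (requires current_endpoint to have
# been passed to transpile()).
pserver_program = t.get_pserver_program(current_endpoint)
fluid.io.load_persistables(executor=exe, dirname=param_path,
                           main_program=pserver_program)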
...@@ -74,7 +74,7 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
const framework::Scope* p_scope = &scope; const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val); const auto ch = GetChannel(ep_val);
SendProcessor* s = new SendProcessor(ch); SendProcessor* s = new SendProcessor(ch);
const std::string method = "SendRPC"; const std::string method = kSendRPC;
VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope)); VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out); s->Prepare(h, time_out);
...@@ -107,7 +107,7 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
void ProcGetResponse(const VarHandle& var_h, void ProcGetResponse(const VarHandle& var_h,
const ::grpc::ByteBuffer& ret_msg) { const ::grpc::ByteBuffer& ret_msg) {
VLOG(100) << "ProcGetResponse"; VLOG(4) << "ProcGetResponse";
framework::Variable* outvar = nullptr; framework::Variable* outvar = nullptr;
// get response's trainer_id is not used // get response's trainer_id is not used
int trainer_id; int trainer_id;
...@@ -127,59 +127,74 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope,
const std::string& var_name, const std::string& var_name,
const std::string& out_varname,
int64_t time_out) { int64_t time_out) {
return _AsyncGetVar(ep, ctx, scope, var_name, return _AsyncGetVar(ep, ctx, scope, kGetRPC, var_name, out_varname,
"/sendrecv.SendRecvService/GetVariable", time_out); "/sendrecv.SendRecvService/GetVariable", time_out);
} }
VarHandlePtr GRPCClient::AsyncGetVarNoBarrier(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
const std::string& out_varname, int64_t time_out) {
std::string var_name_no_barrier =
string::Sprintf("%s%s", var_name, WITHOUT_BARRIER_MESSAGE);
return _AsyncGetVar(
ep, ctx, scope, kGetNoBarrierRPC, var_name_no_barrier, out_varname,
"/sendrecv.SendRecvService/GetVariableNoBarrier", time_out);
}
VarHandlePtr GRPCClient::AsyncGetMonomerVariable( VarHandlePtr GRPCClient::AsyncGetMonomerVariable(
const std::string& ep, const platform::DeviceContext& ctx, const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name, const framework::Scope& scope, const std::string& var_name,
int64_t time_out) { int64_t time_out) {
return _AsyncGetVar(ep, ctx, scope, var_name, return _AsyncGetVar(ep, ctx, scope, kGetMonomerRPC, var_name, var_name,
"/sendrecv.SendRecvService/GetMonomerVariable", time_out); "/sendrecv.SendRecvService/GetMonomerVariable", time_out);
} }
VarHandlePtr GRPCClient::_AsyncGetVar(const std::string& ep, VarHandlePtr GRPCClient::_AsyncGetVar(
const platform::DeviceContext& ctx, const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope, const std::string& method,
const std::string& var_name, const std::string& var_name, const std::string& out_varname,
const std::string& rpc_path, const std::string& rpc_path, int64_t time_out) {
int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx; const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep; const std::string ep_val = ep;
const std::string var_name_val = var_name; const std::string var_name_val = var_name;
const std::string out_varname_val = out_varname;
const framework::Scope* p_scope = &scope; const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val); const auto ch = GetChannel(ep_val);
GetProcessor* s = new GetProcessor(ch); GetProcessor* s = new GetProcessor(ch);
const std::string method = "GetRPC";
VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope)); VarHandlePtr h(new VarHandle(ep, method, out_varname_val, p_ctx, p_scope));
s->Prepare(h, time_out); s->Prepare(h, time_out);
framework::AsyncIO([var_name_val, s, method, p_ctx, h, rpc_path, this] { framework::AsyncIO(
// prepare input [var_name_val, out_varname_val, s, method, p_ctx, h, rpc_path, this] {
sendrecv::VariableMessage req; // prepare input
req.set_varname(var_name_val); sendrecv::VariableMessage req;
req.set_trainer_id(trainer_id_); req.set_varname(var_name_val);
::grpc::ByteBuffer buf; req.set_out_varname(out_varname_val);
RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf); req.set_trainer_id(trainer_id_);
::grpc::ByteBuffer buf;
RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);
VLOG(3) << s->GetVarHandlePtr()->String() << " begin"; VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
// stub context // stub context
s->response_call_back_ = ProcGetResponse; s->response_call_back_ = ProcGetResponse;
platform::RecordRPCEvent record_event(method, p_ctx); platform::RecordRPCEvent record_event(method, p_ctx);
auto call = auto call =
s->stub_g_.PrepareUnaryCall(s->context_.get(), rpc_path, buf, &cq_); s->stub_g_.PrepareUnaryCall(s->context_.get(), rpc_path, buf, &cq_);
call->StartCall(); call->StartCall();
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s)); call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
if (UNLIKELY(platform::IsProfileEnabled())) { if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait(); h->Wait();
} }
}); });
req_count_++; req_count_++;
...@@ -202,7 +217,7 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
const auto ch = GetChannel(ep_val); const auto ch = GetChannel(ep_val);
GetProcessor* s = new GetProcessor(ch); GetProcessor* s = new GetProcessor(ch);
const std::string method = "PrefetchRPC"; const std::string method = kPrefetchRPC;
VarHandlePtr h(new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope)); VarHandlePtr h(new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out); s->Prepare(h, time_out);
...@@ -242,7 +257,7 @@ VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
const auto ch = GetChannel(ep); const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
const std::string method = "BatchBarrierRPC"; const std::string method = kBatchBarrierRPC;
VarHandlePtr h( VarHandlePtr h(
new VarHandle(ep, method, BATCH_BARRIER_MESSAGE, nullptr, nullptr)); new VarHandle(ep, method, BATCH_BARRIER_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out); s->Prepare(h, time_out);
...@@ -267,7 +282,7 @@ VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out) { int64_t time_out) {
const auto ch = GetChannel(ep); const auto ch = GetChannel(ep);
FetchBarrierProcessor* s = new FetchBarrierProcessor(ch); FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
const std::string method = "FetchBarrierRPC"; const std::string method = kFetchBarrierRPC;
VarHandlePtr h( VarHandlePtr h(
new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr)); new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out); s->Prepare(h, time_out);
...@@ -293,7 +308,7 @@ VarHandlePtr GRPCClient::AsyncGetMonomerBarrier(const std::string& ep,
int64_t time_out) { int64_t time_out) {
const auto ch = GetChannel(ep); const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
const std::string method = "SendMonomerFetchBarrierRPC"; const std::string method = kSendMonomerFetchBarrierRPC;
VarHandlePtr h(new VarHandle(ep, method, var_name, nullptr, nullptr)); VarHandlePtr h(new VarHandle(ep, method, var_name, nullptr, nullptr));
s->Prepare(h, time_out); s->Prepare(h, time_out);
...@@ -320,7 +335,7 @@ VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
const auto ch = GetChannel(ep); const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
const std::string method = "SendCompleteRPC"; const std::string method = kSendCompleteRPC;
VarHandlePtr h(new VarHandle(ep, method, COMPLETE_MESSAGE, nullptr, nullptr)); VarHandlePtr h(new VarHandle(ep, method, COMPLETE_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out); s->Prepare(h, time_out);
...@@ -347,7 +362,7 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch); CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch);
const std::string method = "CheckPointNotifyRPC"; const std::string method = kCheckPointNotifyRPC;
VarHandlePtr h( VarHandlePtr h(
new VarHandle(ep, method, CHECKPOINT_SAVE_MESSAGE, nullptr, nullptr)); new VarHandle(ep, method, CHECKPOINT_SAVE_MESSAGE, nullptr, nullptr));
......
...@@ -186,8 +186,15 @@ class GRPCClient : public RPCClient {
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope,
const std::string& var_name, const std::string& var_name,
const std::string& out_varname,
int64_t time_out = FLAGS_rpc_deadline) override; int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncGetVarNoBarrier(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
const std::string& out_varname,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncGetMonomerVariable( VarHandlePtr AsyncGetMonomerVariable(
const std::string& ep, const platform::DeviceContext& ctx, const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name, const framework::Scope& scope, const std::string& var_name,
...@@ -228,11 +235,11 @@ class GRPCClient : public RPCClient {
void Proceed(); void Proceed();
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep); std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
VarHandlePtr _AsyncGetVar(const std::string& ep, VarHandlePtr _AsyncGetVar(
const platform::DeviceContext& ctx, const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope, const std::string& method,
const std::string& var_name, const std::string& rpc, const std::string& var_name, const std::string& out_varname,
int64_t time_out); const std::string& rpc_path, int64_t time_out = FLAGS_rpc_deadline);
private: private:
grpc::CompletionQueue cq_; grpc::CompletionQueue cq_;
......
...@@ -136,17 +136,65 @@ class RequestGet final : public RequestBase {
void Process() override { void Process() override {
// proc request. // proc request.
std::string varname = request_.varname(); std::string varname = request_.varname();
std::string out_varname = request_.out_varname();
int trainer_id = request_.trainer_id(); int trainer_id = request_.trainer_id();
VLOG(4) << "RequestGet " << varname;
VLOG(4) << "RequestGet " << out_varname << " from " << varname;
auto scope = request_handler_->scope(); auto scope = request_handler_->scope();
auto invar = scope->FindVar(varname); framework::Variable* invar = nullptr;
framework::Variable* outvar = nullptr; framework::Variable* outvar = nullptr;
request_handler_->Handle(varname, scope, invar, &outvar, trainer_id); request_handler_->Handle(varname, scope, invar, &outvar, trainer_id,
out_varname);
if (outvar) { if (outvar) {
SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(), SerializeToByteBuffer(out_varname, outvar, *request_handler_->dev_ctx(),
&reply_);
}
Finish(reply_, &responder_);
}
protected:
sendrecv::VariableMessage request_;
::grpc::ByteBuffer reply_;
ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
};
class RequestGetNoBarrier final : public RequestBase {
public:
explicit RequestGetNoBarrier(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq,
RequestHandler* request_handler, int req_id)
: RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
auto method_id =
static_cast<int>(distributed::GrpcMethod::kGetVariableNoBarrier);
service_->RequestAsyncUnary(
method_id, &ctx_, &request_, &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
}
virtual ~RequestGetNoBarrier() {}
std::string GetReqName() override { return request_.varname(); }
void Process() override {
// proc request.
std::string varname = request_.varname();
std::string out_varname = request_.out_varname();
int trainer_id = request_.trainer_id();
VLOG(4) << "RequestGetNoBarrier " << out_varname << " from " << varname;
auto scope = request_handler_->scope();
framework::Variable* invar = nullptr;
framework::Variable* outvar = nullptr;
request_handler_->Handle(varname, scope, invar, &outvar, trainer_id,
out_varname);
if (outvar) {
SerializeToByteBuffer(out_varname, outvar, *request_handler_->dev_ctx(),
&reply_); &reply_);
} }
Finish(reply_, &responder_); Finish(reply_, &responder_);
...@@ -460,6 +508,9 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
b = new RequestSend(&service_, cq.get(), handler, req_id); b = new RequestSend(&service_, cq.get(), handler, req_id);
} else if (rpc_name == kRequestGet) { } else if (rpc_name == kRequestGet) {
b = new RequestGet(&service_, cq.get(), handler, req_id); b = new RequestGet(&service_, cq.get(), handler, req_id);
} else if (rpc_name == kRequestGetNoBarrier) {
b = new RequestGetNoBarrier(&service_, cq.get(), handler, req_id);
} else if (rpc_name == kRequestGetMonomerVariable) { } else if (rpc_name == kRequestGetMonomerVariable) {
b = new RequestGetMonomerVariable(&service_, cq.get(), handler, req_id, b = new RequestGetMonomerVariable(&service_, cq.get(), handler, req_id,
this); this);
......
...@@ -81,6 +81,7 @@ enum class GrpcMethod {
kGetVariable, kGetVariable,
kPrefetchVariable, kPrefetchVariable,
kCheckpointNotify, kCheckpointNotify,
kGetVariableNoBarrier,
kGetMonomerVariable, kGetMonomerVariable,
kGetMonomerBarrier, kGetMonomerBarrier,
}; };
...@@ -94,6 +95,8 @@ inline const char* GrpcMethodName(GrpcMethod id) {
return "/sendrecv.SendRecvService/SendVariable"; return "/sendrecv.SendRecvService/SendVariable";
case GrpcMethod::kGetVariable: case GrpcMethod::kGetVariable:
return "/sendrecv.SendRecvService/GetVariable"; return "/sendrecv.SendRecvService/GetVariable";
case GrpcMethod::kGetVariableNoBarrier:
return "/sendrecv.SendRecvService/GetVariableNoBarrier";
case GrpcMethod::kGetMonomerVariable: case GrpcMethod::kGetMonomerVariable:
return "/sendrecv.SendRecvService/GetMonomerVariable"; return "/sendrecv.SendRecvService/GetMonomerVariable";
case GrpcMethod::kGetMonomerBarrier: case GrpcMethod::kGetMonomerBarrier:
......
...@@ -42,11 +42,24 @@ constexpr char kRequestGetMonomerBarrier[] = "RequestGetMonomerBarrier";
constexpr char kRequestPrefetch[] = "RequestPrefetch"; constexpr char kRequestPrefetch[] = "RequestPrefetch";
constexpr char kRequestCheckpoint[] = "RequestCheckpoint"; constexpr char kRequestCheckpoint[] = "RequestCheckpoint";
constexpr char kRequestPassBarrier[] = "RequestPassBarrier"; constexpr char kRequestPassBarrier[] = "RequestPassBarrier";
constexpr char kRequestGetNoBarrier[] = "GetVariableNoBarrier";
constexpr char kSendRPC[] = "SendRPC";
constexpr char kGetRPC[] = "GetRPC";
constexpr char kGetNoBarrierRPC[] = "GetNoBarrierRPC";
constexpr char kGetMonomerRPC[] = "GetMonomerRPC";
constexpr char kPrefetchRPC[] = "PrefetchRPC";
constexpr char kBatchBarrierRPC[] = "BatchBarrierRPC";
constexpr char kFetchBarrierRPC[] = "FetchBarrierRPC";
constexpr char kSendMonomerFetchBarrierRPC[] = "SendMonomerFetchBarrierRPC";
constexpr char kSendCompleteRPC[] = "SendCompleteRPC";
constexpr char kCheckPointNotifyRPC[] = "CheckPointNotifyRPC";
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV" #define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV" #define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV" #define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
#define COMPLETE_MESSAGE "COMPLETE@RECV" #define COMPLETE_MESSAGE "COMPLETE@RECV"
#define WITHOUT_BARRIER_MESSAGE "@WITHOUT_BARRIER@RECV"
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY" #define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY" #define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"
......
...@@ -23,6 +23,7 @@
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/distributed/rpc_server.h" #include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/string/piece.h"
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/printf.h"
namespace paddle { namespace paddle {
...@@ -81,7 +82,8 @@ bool RequestGetHandler::Handle(const std::string& varname,
const int trainer_id, const int trainer_id,
const std::string& out_var_name, const std::string& out_var_name,
const std::string& table_name) { const std::string& table_name) {
VLOG(4) << "RequestGetHandler:" << varname; VLOG(4) << "RequestGetHandler:" << varname
<< " out_var_name: " << out_var_name;
if (sync_mode_) { if (sync_mode_) {
if (varname == FETCH_BARRIER_MESSAGE) { if (varname == FETCH_BARRIER_MESSAGE) {
...@@ -112,6 +114,32 @@ bool RequestGetHandler::Handle(const std::string& varname,
return true; return true;
} }
bool RequestGetNoBarrierHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name,
const std::string& table_name) {
VLOG(4) << "RequestGetNoBarrierHandler:" << varname
<< " out_var_name: " << out_var_name;
// get var from pserver immediately without barriers
string::Piece without_barrier_piece(WITHOUT_BARRIER_MESSAGE);
string::Piece var_name_piece = string::Piece(varname);
if (string::Contains(var_name_piece, without_barrier_piece)) {
var_name_piece = string::TrimSuffix(var_name_piece, without_barrier_piece);
VLOG(4) << "Get var " << var_name_piece << " with "
<< WITHOUT_BARRIER_MESSAGE;
*outvar = scope_->FindVar(var_name_piece.ToString());
return true;
} else {
PADDLE_THROW("GetNoBarrier must contain %s", WITHOUT_BARRIER_MESSAGE);
}
return true;
}
bool RequestPrefetchHandler::Handle(const std::string& varname, bool RequestPrefetchHandler::Handle(const std::string& varname,
framework::Scope* scope, framework::Scope* scope,
framework::Variable* invar, framework::Variable* invar,
......
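A small Python illustration (not part of the patch; the helper names are hypothetical) of the naming convention the no-barrier path relies on: GRPCClient::AsyncGetVarNoBarrier appends WITHOUT_BARRIER_MESSAGE to the requested name, and RequestGetNoBarrierHandler trims it again before looking the variable up in the pserver scope. The constant value is taken from the macro added in this diff.

WITHOUT_BARRIER_MESSAGE = "@WITHOUT_BARRIER@RECV"

def to_wire_name(var_name):
    # what the client sends as `varname` for a no-barrier get
    return var_name + WITHOUT_BARRIER_MESSAGE

def to_scope_name(wire_name):
    # what the handler looks up after trimming the suffix
    assert wire_name.endswith(WITHOUT_BARRIER_MESSAGE), \
        "GetNoBarrier must contain " + WITHOUT_BARRIER_MESSAGE
    return wire_name[:-len(WITHOUT_BARRIER_MESSAGE)]

assert to_scope_name(to_wire_name("moment_1")) == "moment_1"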
...@@ -67,6 +67,16 @@ class RequestGetHandler final : public RequestHandler {
bool enable_dc_asgd_; bool enable_dc_asgd_;
}; };
class RequestGetNoBarrierHandler final : public RequestHandler {
public:
RequestGetNoBarrierHandler() : RequestHandler(false) {}
virtual ~RequestGetNoBarrierHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const int trainer_id, const std::string& out_var_name = "",
const std::string& table_name = "") override;
};
static inline void BuildVar(const std::string& param_name, static inline void BuildVar(const std::string& param_name,
std::initializer_list<const char*> arguments, std::initializer_list<const char*> arguments,
paddle::framework::proto::OpDesc::Var* var) { paddle::framework::proto::OpDesc::Var* var) {
......
...@@ -43,8 +43,15 @@ class RPCClient {
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope,
const std::string& var_name, const std::string& var_name,
const std::string& out_varname,
int64_t time_out = FLAGS_rpc_deadline) = 0; int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncGetVarNoBarrier(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
const std::string& out_varname,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncGetMonomerVariable( virtual VarHandlePtr AsyncGetMonomerVariable(
const std::string& ep, const platform::DeviceContext& ctx, const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name, const framework::Scope& scope, const std::string& var_name,
......
...@@ -17,8 +17,14 @@ package sendrecv;
option cc_generic_services = @cc_generic_services@; option cc_generic_services = @cc_generic_services@;
service SendRecvService { service SendRecvService {
// For parameter server round-robin like hashing, do not split tensors.
// Send and recv only one tensor
// TODO(typhoonzero): add streaming API
rpc SendVariable(VariableMessage) returns (VoidMessage) {} rpc SendVariable(VariableMessage) returns (VoidMessage) {}
// Argument VariableMessage for GetVariable should only contain varname.
rpc GetVariable(VariableMessage) returns (VariableMessage) {} rpc GetVariable(VariableMessage) returns (VariableMessage) {}
rpc GetVariableNoBarrier(VariableMessage) returns (VariableMessage) {}
// pre-fetch variable by given variable name and Ids
rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {} rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {}
rpc CheckpointNotify(VariableMessage) returns (VoidMessage) {} rpc CheckpointNotify(VariableMessage) returns (VoidMessage) {}
...@@ -27,12 +33,17 @@ service SendRecvService {
rpc GetMonomerBarrier(VariableMessage) returns (VoidMessage) {} rpc GetMonomerBarrier(VariableMessage) returns (VoidMessage) {}
} }
// It can be: LoDTensor, SelectedRows or NCCL_ID
enum VarType { enum VarType {
LOD_TENSOR = 0; LOD_TENSOR = 0;
SELECTED_ROWS = 1; SELECTED_ROWS = 1;
NCCL_ID = 2; NCCL_ID = 2;
} }
// VariableMessage is serialized paddle variable message.
// NOTICE(gongwb): don't modify this proto if you are not
// familiar with how we serialize it in sendrecvop_utils.h
// and deserialize it in variable_response.h.
message VariableMessage { message VariableMessage {
enum Type { enum Type {
// Pod Types // Pod Types
...@@ -49,14 +60,21 @@ message VariableMessage {
string varname = 1; string varname = 1;
// TODO(Yancey1989): reference framework::proto::VarDesc::VarType // TODO(Yancey1989): reference framework::proto::VarDesc::VarType
VarType type = 2; VarType type = 2;
// bool persistable is not needed for sending.
// tensor info:
Type data_type = 3; Type data_type = 3;
repeated int64 dims = 4; repeated int64 dims = 4;
// lod details:
int64 lod_level = 5; int64 lod_level = 5;
repeated LodData lod = 6; repeated LodData lod = 6;
// selected_rows height, aka. original dim0
int64 slr_height = 7; int64 slr_height = 7;
// tensor data
bytes serialized = 8; bytes serialized = 8;
// selected_rows data
bytes rows = 9; bytes rows = 9;
// Look up table block execution output variable name.
string out_varname = 10; string out_varname = 10;
// If 1, the ps server will start profiling, the ps // If 1, the ps server will start profiling, the ps
// server stops profiling and generates a profile to /tmp/profile_ps_* // server stops profiling and generates a profile to /tmp/profile_ps_*
......
...@@ -347,6 +347,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
new distributed::RequestPrefetchHandler(sync_mode)); new distributed::RequestPrefetchHandler(sync_mode));
request_checkpoint_handler_.reset(new distributed::RequestCheckpointHandler( request_checkpoint_handler_.reset(new distributed::RequestCheckpointHandler(
sync_mode, checkpoint_block_id)); sync_mode, checkpoint_block_id));
request_get_no_barrier_handler_.reset(
new distributed::RequestGetNoBarrierHandler());
rpc_service_->RegisterRPC(distributed::kRequestSend, rpc_service_->RegisterRPC(distributed::kRequestSend,
request_send_handler_.get(), request_send_handler_.get(),
...@@ -359,6 +361,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
FLAGS_rpc_prefetch_thread_num); FLAGS_rpc_prefetch_thread_num);
rpc_service_->RegisterRPC(distributed::kRequestCheckpoint, rpc_service_->RegisterRPC(distributed::kRequestCheckpoint,
request_checkpoint_handler_.get()); request_checkpoint_handler_.get());
rpc_service_->RegisterRPC(distributed::kRequestGetNoBarrier,
request_get_no_barrier_handler_.get());
auto optimize_blocks = auto optimize_blocks =
Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks); Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
...@@ -413,6 +417,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
f(request_get_handler_.get()); f(request_get_handler_.get());
f(request_prefetch_handler_.get()); f(request_prefetch_handler_.get());
f(request_checkpoint_handler_.get()); f(request_checkpoint_handler_.get());
f(request_get_no_barrier_handler_.get());
// start the server listening after all member initialized. // start the server listening after all member initialized.
server_thread_.reset(new std::thread(RunServer, rpc_service_)); server_thread_.reset(new std::thread(RunServer, rpc_service_));
......
...@@ -55,7 +55,6 @@ class ListenAndServOp : public framework::OperatorBase {
const framework::VariableNameMap& inputs, const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs, const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs); const framework::AttributeMap& attrs);
virtual ~ListenAndServOp(); virtual ~ListenAndServOp();
void RunSyncLoop(framework::Executor* executor, void RunSyncLoop(framework::Executor* executor,
...@@ -89,6 +88,8 @@ class ListenAndServOp : public framework::OperatorBase {
mutable std::shared_ptr<distributed::RPCServer> rpc_service_; mutable std::shared_ptr<distributed::RPCServer> rpc_service_;
mutable std::shared_ptr<distributed::RequestHandler> request_send_handler_; mutable std::shared_ptr<distributed::RequestHandler> request_send_handler_;
mutable std::shared_ptr<distributed::RequestHandler> request_get_handler_; mutable std::shared_ptr<distributed::RequestHandler> request_get_handler_;
mutable std::shared_ptr<distributed::RequestHandler>
request_get_no_barrier_handler_;
mutable std::shared_ptr<distributed::RequestHandler> mutable std::shared_ptr<distributed::RequestHandler>
request_prefetch_handler_; request_prefetch_handler_;
mutable std::shared_ptr<distributed::RequestHandler> mutable std::shared_ptr<distributed::RequestHandler>
......
...@@ -27,30 +27,50 @@ namespace operators {
class RecvOp : public framework::OperatorBase { class RecvOp : public framework::OperatorBase {
public: public:
RecvOp(const std::string& type, const framework::VariableNameMap& inputs, RecvOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap& outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap& attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope &scope,
const platform::Place& place) const override { const platform::Place &place) const override {
auto outs = Outputs("Out");
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap"); std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
std::vector<std::string> varnames =
Attr<std::vector<std::string>>("varnames");
int sync_mode = Attr<int>("sync_mode"); int sync_mode = Attr<int>("sync_mode");
auto outs = Outputs("Out");
bool with_barrier = Attr<bool>("with_barrier");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place); auto &ctx = *pool.Get(place);
distributed::RPCClient* rpc_client = distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>( distributed::RPCClient::GetInstance<RPCCLIENT_T>(
Attr<int>("trainer_id")); Attr<int>("trainer_id"));
std::vector<distributed::VarHandlePtr> rets; if (with_barrier) {
for (size_t i = 0; i < outs.size(); i++) { std::vector<distributed::VarHandlePtr> rets;
VLOG(3) << "getting " << outs[i] << " from " << epmap[i]; for (size_t i = 0; i < outs.size(); i++) {
rets.push_back(rpc_client->AsyncGetVar(epmap[i], ctx, scope, outs[i])); std::string varname = varnames.size() == 0 ? outs[i] : varnames[i];
} VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with "
if (sync_mode) { << varname << " and with AsyncGetVar";
rets.push_back(
rpc_client->AsyncGetVar(epmap[i], ctx, scope, varname, outs[i]));
}
if (sync_mode) {
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
}
}
} else {
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < outs.size(); i++) {
std::string varname = varnames.size() == 0 ? outs[i] : varnames[i];
VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with "
<< varname << " and with AsyncGetVarNoBarrier";
rets.push_back(rpc_client->AsyncGetVarNoBarrier(epmap[i], ctx, scope,
varname, outs[i]));
}
for (size_t i = 0; i < rets.size(); i++) { for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient"); PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
} }
...@@ -79,12 +99,23 @@ This operator can get variables from server side.
"(int, default 0)" "(int, default 0)"
"sync recv or async recv.") "sync recv or async recv.")
.SetDefault(0); .SetDefault(0);
AddAttr<bool>("with_barrier",
"(bool, default True) if with_barrier=False, will use "
"AsyncGetVarNoBarrier get variable from pserver immediately")
.SetDefault(true);
AddAttr<std::vector<std::string>>(
"varnames",
"(string vector, default {}) "
"sometimes we need to put received var in another name "
"for example: we need var named 'moment_1@127.0.0.1:1001', "
"and it real name on parameter server is 'moment_1'. ")
.SetDefault({});
} }
}; };
class RecvOpShapeInference : public framework::InferShapeBase { class RecvOpShapeInference : public framework::InferShapeBase {
public: public:
void operator()(framework::InferShapeContext* ctx) const override {} void operator()(framework::InferShapeContext *ctx) const override {}
}; };
} // namespace operators } // namespace operators
......
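The recv op above gains with_barrier and varnames attributes. A minimal sketch of how they are meant to be used together, mirroring the __save_remote_params helper added to fluid/io.py later in this diff; the program, variable shape and endpoint below are placeholders, not values from the patch.

import paddle.fluid as fluid

prog = fluid.Program()
block = prog.global_block()
origin_var = block.create_var(
    name="moment_1", shape=[10, 4], dtype="float32", persistable=True)

block.append_op(
    type="recv",
    inputs={"X": []},
    outputs={"Out": [origin_var]},
    attrs={
        "epmap": ["127.0.0.1:6174"],  # where the remote variable lives
        "varnames": ["moment_1"],     # its name on the parameter server
        "with_barrier": False,        # use AsyncGetVarNoBarrier
        "sync_mode": True,            # wait for the RPC to finish
    })
# fluid.Executor(fluid.CPUPlace()).run(prog)  # would issue the RPC against a live pserver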
...@@ -365,7 +365,7 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
mem_fmt.ndims = axis.size(); mem_fmt.ndims = axis.size();
for (unsigned int i = 0; i < nchw_tz.size(); ++i) { for (unsigned int i = 0; i < nchw_tz.size(); ++i) {
mem_fmt.dims[i] = nchw_tz[i]; // logical dimensions (nchw format, mem_fmt.dims[i] = nchw_tz[i]; // logical dimensions (nchw format,
// regardless physical layout) // regardless physical layout)
} }
mem_fmt.data_type = mkldnn_f32; mem_fmt.data_type = mkldnn_f32;
mem_fmt.format = mkldnn_blocked; mem_fmt.format = mkldnn_blocked;
...@@ -374,7 +374,7 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
for (int i = nchw_tz.size() - 1; i >= 0; --i) { for (int i = nchw_tz.size() - 1; i >= 0; --i) {
mem_fmt.layout_desc.blocking.padding_dims[i] = mem_fmt.layout_desc.blocking.padding_dims[i] =
nchw_tz[i]; // logical dimensions (nchw format, regardless physical nchw_tz[i]; // logical dimensions (nchw format, regardless physical
// layout) // layout)
mem_fmt.layout_desc.blocking.block_dims[i] = 1; mem_fmt.layout_desc.blocking.block_dims[i] = 1;
mem_fmt.layout_desc.blocking.offset_padding_to_data[i] = 0; // no offset mem_fmt.layout_desc.blocking.offset_padding_to_data[i] = 0; // no offset
mem_fmt.layout_desc.blocking.strides[0][axis[i]] = total_stride; mem_fmt.layout_desc.blocking.strides[0][axis[i]] = total_stride;
......
...@@ -1696,12 +1696,20 @@ class Program(object):
self._current_role = core.op_proto_and_checker_maker.OpRole.Forward self._current_role = core.op_proto_and_checker_maker.OpRole.Forward
self._op_role_var = [] self._op_role_var = []
# for distribute # for distribute training
# _is_distributed = True if under distributed training
self._is_distributed = False self._is_distributed = False
# _is_chief = True if the trainer is the first one, usually No.0
self._is_chief = False self._is_chief = False
self._slice_vars_and_attrs = [] # _parameters_on_pservers records all the parameters distributed on parameter servers.
self._parameters_on_pservers = None
# _endpoints is a list of parameter server endpoints, e.g. ["ip:port", "ip:port"]
self._endpoints = [] self._endpoints = []
# if current role is parameter server, the _ps_endpoint is its "ip:port"
self._ps_endpoint = None
# _trainers_endpoints lists the endpoints of all trainers; it is used for distributed training.
self._trainers_endpoints = [] self._trainers_endpoints = []
# the distributed lookup table names
self._distributed_lookup_table = None self._distributed_lookup_table = None
@property @property
...@@ -2232,8 +2240,9 @@ class Program(object):
"Program") "Program")
self._is_distributed = other._is_distributed self._is_distributed = other._is_distributed
self._is_chief = other._is_chief self._is_chief = other._is_chief
self._slice_vars_and_attrs = other._slice_vars_and_attrs self._parameters_on_pservers = other._parameters_on_pservers
self._endpoints = other._endpoints self._endpoints = other._endpoints
self._ps_endpoint = other._ps_endpoint
self._distributed_lookup_table = other._distributed_lookup_table self._distributed_lookup_table = other._distributed_lookup_table
def _copy_data_info_from(self, other): def _copy_data_info_from(self, other):
......
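For orientation, the new Program attributes above form the metadata contract that fluid.io relies on later in this diff; they are filled in by DistributeTranspiler.transpile. A hypothetical helper follows (describe_distributed_program is not part of the patch, and the example values in the comments are illustrative only):

import paddle.fluid as fluid

def describe_distributed_program(prog):
    print("is_distributed:", prog._is_distributed)         # True for transpiled programs
    print("is_chief:", prog._is_chief)                      # True only on trainer No.0
    print("endpoints:", prog._endpoints)                    # e.g. ["127.0.0.1:6174", "127.0.0.1:6175"]
    print("ps_endpoint:", prog._ps_endpoint)                # this pserver's own "ip:port"
    print("lookup_table:", prog._distributed_lookup_table)  # lookup-table name(s) or None
    # prog._parameters_on_pservers additionally offers get_distributed_vars_by_vtypes()
    # and get_distributed_vars_by_ep(), as used by fluid.io below.

describe_distributed_program(fluid.default_main_program())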
...@@ -19,6 +19,7 @@ import errno
import time import time
import shutil import shutil
import six import six
from functools import reduce
from paddle.fluid.executor import Executor from paddle.fluid.executor import Executor
from paddle.fluid.evaluator import Evaluator from paddle.fluid.evaluator import Evaluator
...@@ -183,8 +184,6 @@ def save_vars(executor,
# NOTE: don't save the variable which type is RAW # NOTE: don't save the variable which type is RAW
if each_var.type == core.VarDesc.VarType.RAW: if each_var.type == core.VarDesc.VarType.RAW:
continue continue
if each_var.name == main_program._distributed_lookup_table:
continue
new_var = _clone_var_in_block_(save_block, each_var) new_var = _clone_var_in_block_(save_block, each_var)
if filename is None: if filename is None:
save_block.append_op( save_block.append_op(
...@@ -206,16 +205,6 @@ def save_vars(executor,
outputs={}, outputs={},
attrs={'file_path': os.path.join(dirname, filename)}) attrs={'file_path': os.path.join(dirname, filename)})
# if there is lookup table, the trainer 0 will notify all pserver to save.
if main_program._is_distributed and main_program._is_chief and main_program._distributed_lookup_table:
lookup_table_filename = os.path.join(dirname, "__lookup_table__")
attrs = {}
attrs['epmap'] = main_program._endpoints
attrs['dir'] = lookup_table_filename
attrs['lookup_table'] = main_program._distributed_lookup_table
save_block.append_op(
type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs)
executor.run(save_program) executor.run(save_program)
...@@ -267,6 +256,186 @@ def save_params(executor, dirname, main_program=None, filename=None):
filename=filename) filename=filename)
def _save_distributed_persistables(executor, dirname, main_program):
"""
save_persistables for distributed training.
The method does the following:
1. Save the persistable variables that live on the trainer.
2. Receive "remote prefetch variables" from the parameter servers and merge them.
3. Save the "distributed lookup table" on the parameter servers.
4. Receive "optimizer variables" from the parameter servers and merge them.
Args:
executor(Executor): The executor to run for saving parameters.
dirname(str): The saving directory path.
main_program(Program): The program whose parameters will be
saved. The main_program must be the trainer_program
obtained after transpiling.
Returns:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
t = distribute_transpiler.DistributeTranspiler()
t.transpile(...)
train_program = t.get_trainer_program()
_save_distributed_persistables(executor=exe, dirname=param_path, main_program=train_program)
"""
def __save_remote_params(executor, dirname, remote_params_map):
"""
receive params from the pservers through RPC.
if a param has been sliced, concat the slices into one variable, then save it.
"""
if not remote_params_map:
return
prog = Program()
block = prog.global_block()
# recv optimize vars from pserver
for name, remote_params in remote_params_map.items():
origin_var = None
is_slice = False
slice_vars = [0] * len(remote_params)
slice_var_names = [""] * len(remote_params)
endpoints = [""] * len(remote_params)
for idx, optimizer in enumerate(remote_params):
origin = optimizer.origin
slice = optimizer.slice
is_slice = optimizer.is_slice
block_id = optimizer.block_id
endpoint = optimizer.endpoint
if idx == 0:
origin_var = block.create_var(
name=origin.name,
type=origin.type,
shape=origin.shape,
dtype=origin.dtype,
persistable=True)
slice_var = block.create_var(
name="{}.slice.{}".format(slice.name, idx),
type=slice.type,
shape=slice.shape,
dtype=slice.dtype,
persistable=True)
index = block_id if is_slice else idx
slice_vars[index] = slice_var
slice_var_names[index] = slice.name
endpoints[index] = endpoint
if is_slice:
block.append_op(
type='recv',
inputs={"X": []},
outputs={"Out": slice_vars},
attrs={
"epmap": endpoints,
"with_barrier": False,
"varnames": slice_var_names,
"sync_mode": True
})
block.append_op(
type='concat',
inputs={'X': slice_vars},
outputs={'Out': origin_var},
attrs={})
else:
block.append_op(
type='recv',
inputs={"X": []},
outputs={"Out": [origin_var]},
attrs={
"epmap": endpoints[:1],
"with_barrier": False,
"varnames": slice_var_names,
"sync_mode": True
})
block.append_op(
type='save',
inputs={'X': [origin_var]},
outputs={},
attrs={'file_path': os.path.join(dirname, origin_var.name)})
block.append_op(type='delete_var', inputs={'X': slice_vars})
executor.run(prog)
def __save_distributed_lookup_tables(executor, dirname,
distributed_lookup_table, endpoints):
"""
because the distributed lookup table may be too huge to merge and save in one place,
it is saved on each parameter server independently.
the save directory is dirname/"__lookup_table__".
"""
prog = Program()
block = prog.global_block()
# if there is lookup table, the trainer 0 will notify all pserver to save.
lookup_table_filename = os.path.join(dirname, "__lookup_table__")
attrs = {}
attrs['epmap'] = endpoints
attrs['dir'] = lookup_table_filename
attrs['lookup_table'] = distributed_lookup_table
block.append_op(
type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs)
executor.run(prog)
def __exclude_vars(exclude_var_names=[]):
def is_valid(var):
if var.name in exclude_var_names:
return False
if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
var.desc.type() == core.VarDesc.VarType.READER:
return False
return var.persistable
return is_valid
if not isinstance(main_program, Program):
raise ValueError("'main_program' should be an instance of Program.")
if not main_program._is_distributed:
raise ValueError(
"'_save_distributed_persistables' just be designed for distributed training."
)
remote_params_map = main_program._parameters_on_pservers.get_distributed_vars_by_vtypes(
["Optimizer", "RemotePrefetch"], groupby=True)
exclude_var_names = []
if remote_params_map:
exclude_var_names.extend(remote_params_map.keys())
if main_program._distributed_lookup_table:
if isinstance(main_program._distributed_lookup_table, list):
exclude_var_names.extend(main_program._distributed_lookup_table)
else:
exclude_var_names.append(main_program._distributed_lookup_table)
local_vars = list(
filter(__exclude_vars(exclude_var_names), main_program.list_vars()))
save_vars(
executor, main_program=main_program, dirname=dirname, vars=local_vars)
if main_program._is_chief:
if remote_params_map:
__save_remote_params(executor, dirname, remote_params_map)
if main_program._distributed_lookup_table:
__save_distributed_lookup_tables(
executor, dirname, main_program._distributed_lookup_table,
main_program._endpoints)
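For readers following __save_remote_params above: the records it iterates over come from _parameters_on_pservers.get_distributed_vars_by_vtypes(...), which is defined in the transpiler and not shown in this diff. Below is a hypothetical sketch of the shape those records appear to have, with field names inferred from the attribute accesses above (plus offset, which the load path below also reads); RemoteVar itself is an invented name.

from collections import namedtuple

# Hypothetical record type; the real one lives in the DistributeTranspiler.
RemoteVar = namedtuple(
    "RemoteVar",
    ["origin", "slice", "is_slice", "block_id", "offset", "endpoint"])

# remote_params_map groups such records by original parameter name, e.g. for a
# (10, 4) parameter split into two (5, 4) blocks on two pservers:
#
#   "moment_1": [
#       RemoteVar(origin=moment_1, slice=moment_1.block0, is_slice=True,
#                 block_id=0, offset=0,  endpoint="127.0.0.1:6174"),
#       RemoteVar(origin=moment_1, slice=moment_1.block1, is_slice=True,
#                 block_id=1, offset=20, endpoint="127.0.0.1:6175"),
#   ]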
def save_persistables(executor, dirname, main_program=None, filename=None): def save_persistables(executor, dirname, main_program=None, filename=None):
""" """
This function filters out all variables with `persistable==True` from the This function filters out all variables with `persistable==True` from the
...@@ -301,13 +470,19 @@ def save_persistables(executor, dirname, main_program=None, filename=None):
fluid.io.save_persistables(executor=exe, dirname=param_path, fluid.io.save_persistables(executor=exe, dirname=param_path,
main_program=None) main_program=None)
""" """
save_vars(
executor, if main_program and main_program._is_distributed:
dirname=dirname, _save_distributed_persistables(
main_program=main_program, executor, dirname=dirname, main_program=main_program)
vars=None,
predicate=is_persistable, else:
filename=filename) save_vars(
executor,
dirname=dirname,
main_program=main_program,
vars=None,
predicate=is_persistable,
filename=filename)
def load_vars(executor, def load_vars(executor,
...@@ -402,17 +577,11 @@ def load_vars(executor,
if not isinstance(main_program, Program): if not isinstance(main_program, Program):
raise TypeError("program should be as Program type or None") raise TypeError("program should be as Program type or None")
load_slice_vars = []
for each_var in main_program._slice_vars_and_attrs:
load_slice_vars.append(each_var[2].name)
load_var_map = {} load_var_map = {}
for each_var in vars: for each_var in vars:
assert isinstance(each_var, Variable) assert isinstance(each_var, Variable)
if each_var.type == core.VarDesc.VarType.RAW: if each_var.type == core.VarDesc.VarType.RAW:
continue continue
if each_var.name in load_slice_vars:
continue
new_var = _clone_var_in_block_(load_block, each_var) new_var = _clone_var_in_block_(load_block, each_var)
if filename is None: if filename is None:
load_block.append_op( load_block.append_op(
...@@ -435,10 +604,6 @@ def load_vars(executor,
attrs={'file_path': os.path.join(dirname, filename)}) attrs={'file_path': os.path.join(dirname, filename)})
executor.run(load_prog) executor.run(load_prog)
# load slice vars on pserver, if have it.
_load_slice_up_vars(executor, dirname,
main_program._slice_vars_and_attrs)
def load_params(executor, dirname, main_program=None, filename=None): def load_params(executor, dirname, main_program=None, filename=None):
""" """
...@@ -521,12 +686,134 @@ def load_persistables(executor, dirname, main_program=None, filename=None):
fluid.io.load_persistables(executor=exe, dirname=param_path, fluid.io.load_persistables(executor=exe, dirname=param_path,
main_program=None) main_program=None)
""" """
load_vars(
executor, if main_program and main_program._is_distributed:
dirname=dirname, _load_distributed_persistables(
main_program=main_program, executor, dirname=dirname, main_program=main_program)
predicate=is_persistable, else:
filename=filename) load_vars(
executor,
dirname=dirname,
main_program=main_program,
predicate=is_persistable,
filename=filename)
def _load_distributed_persistables(executor, dirname, main_program=None):
"""
customized load_persistables for distributed training.
it should be used on the parameter server side.
Args:
executor(Executor): The executor to run for saving parameters.
dirname(str): The load directory path.
main_program(Program): The program whose parameters will be
loaded. The main_program must be the pserver_program
obtained after transpiling.
Returns:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
t = distribute_transpiler.DistributeTranspiler()
t.transpile(...)
pserver_prog = t.get_pserver_program(...)
_load_distributed_persistables(executor=exe, dirname=param_path, main_program=pserver_prog)
"""
def __is_distributed_part_var(varname):
trainer_idx = varname.find(".trainer_")
block_idx = varname.find(".block")
# str.find returns -1 when the pattern is absent, so test the indices explicitly
return trainer_idx >= 0 or block_idx >= 0
def __load_persistable_vars(executor, dirname, need_load_vars):
load_prog = Program()
load_block = load_prog.global_block()
need_delete_vars = []
for param in need_load_vars:
origin_var = param.origin
slice_var = param.slice
is_slice = param.is_slice
offset = param.offset
if is_slice:
origin = load_block.create_var(
name="{}.load".format(origin_var.name),
type=origin_var.type,
shape=origin_var.shape,
dtype=origin_var.dtype,
persistable=True)
load_block.append_op(
type='load',
inputs={},
outputs={'Out': [origin]},
attrs={
'file_path': os.path.join(dirname, origin_var.name)
})
slice = load_block.create_var(
name=slice_var.name,
type=slice_var.type,
shape=slice_var.shape,
dtype=slice_var.dtype,
persistable=True)
dim1_flatten = reduce(lambda x, y: x * y, slice.shape[1:])
start = int(offset / dim1_flatten)
end = int(offset / dim1_flatten + slice.shape[0])
load_block.append_op(
type="slice",
inputs={'Input': origin},
outputs={'Out': slice},
attrs={'axes': [0],
'starts': [start],
'ends': [end]})
need_delete_vars.append(origin)
else:
origin = load_block.create_var(
name="{}".format(origin_var.name),
type=origin_var.type,
shape=origin_var.shape,
dtype=origin_var.dtype,
persistable=True)
load_block.append_op(
type='load',
inputs={},
outputs={'Out': [origin]},
attrs={
'file_path': os.path.join(dirname, origin_var.name)
})
load_block.append_op(
type='delete_var',
inputs={'X': need_delete_vars}, )
executor.run(load_prog)
if not isinstance(main_program, Program):
raise ValueError("'main_program' should be an instance of Program.")
if not main_program._is_distributed:
raise ValueError(
"'_load_distributed_persistables' just be designed for distributed training."
)
if not main_program._ps_endpoint:
raise ValueError(
"'_load_distributed_persistables' need current_endpoint set in DistributeTranspiler.transpile"
)
need_load_vars = main_program._parameters_on_pservers.get_distributed_vars_by_ep(
main_program._ps_endpoint)
__load_persistable_vars(executor, dirname, need_load_vars)
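The start/end computation in __load_persistable_vars converts a flat element offset into a row range of the origin tensor. A small worked example under assumed shapes (a (10, 4) origin split into two (5, 4) slices):

from functools import reduce

slice_shape = (5, 4)
offset = 20  # flat element offset of the second slice

dim1_flatten = reduce(lambda x, y: x * y, slice_shape[1:])  # 4 elements per row
start = int(offset / dim1_flatten)                          # -> 5
end = int(offset / dim1_flatten + slice_shape[0])           # -> 10

assert (start, end) == (5, 10)
# The appended "slice" op then copies rows [start:end) of the loaded origin
# tensor into this pserver's slice variable.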
def prepend_feed_ops(inference_program, def prepend_feed_ops(inference_program,
...@@ -795,52 +1082,6 @@ def load_inference_model(dirname,
return [program, feed_target_names, fetch_targets] return [program, feed_target_names, fetch_targets]
def _save_lookup_tables_by_notify(executor, dirname, lookup_table,
pserver_endpoints):
"""
This function will send checkpoint notify message from Trainer 0
to all the pservers.
The checkpoint notify message contains lookup table name,
the absolute path on pserver to save lookup_table.
Args:
executor(Executor): The executor to run for send checkpoint notify.
dirname(str): The folder where to save.
lookup_table(string): the lookup table name, when use distribute
lookup table, we can get lookup table name by DistributeTranspiler.
table_name
ps_endpoint_list(list): the parameter server ip:port list.
when use distribute lookup table, we can get ps_endpoint_list by
distribute arguments.
Return:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
table_name = "share_w"
ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"]
_save_pserver_vars_by_notify(executor=exe,
dirname=param_path, lookup_table=table_name,
pserver_endpoints=ps_endpoints)
"""
pserver_notify_program = Program()
pserver_notify_block = pserver_notify_program.global_block()
attrs = {}
attrs['epmap'] = pserver_endpoints
attrs['dir'] = dirname
attrs['lookup_table'] = lookup_table
pserver_notify_block.append_op(
type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs)
executor.run(pserver_notify_program)
def _endpoints_replacement(program, endpoints): def _endpoints_replacement(program, endpoints):
ENDPOINT_MAP = "epmap" ENDPOINT_MAP = "epmap"
for op in program.global_block().ops: for op in program.global_block().ops:
...@@ -911,54 +1152,3 @@ def get_parameter_value_by_name(name, executor, program=None):
program = default_main_program() program = default_main_program()
var = program.global_block().var(name) var = program.global_block().var(name)
return get_parameter_value(var, executor) return get_parameter_value(var, executor)
def _load_slice_up_vars(executor, dirname, slice_vars_and_attrs):
if not slice_vars_and_attrs:
return
load_prog = Program()
load_block = load_prog.global_block()
need_delete_vars = []
for var_tuple in slice_vars_and_attrs:
orig_var = var_tuple[0]
start = var_tuple[1]
slice_var = var_tuple[2]
end = start + slice_var.shape[0]
orig_var_name = orig_var.name
orig_var.name = "{}.origin".format(orig_var_name)
clone_orig_var = load_block.create_var(
name=orig_var.name,
type=orig_var.type,
shape=orig_var.shape,
dtype=orig_var.dtype,
persistable=True)
clone_slice_var = load_block.create_var(
name=slice_var.name,
type=slice_var.type,
shape=slice_var.shape,
dtype=slice_var.dtype,
persistable=True)
load_block.append_op(
type='load',
inputs={},
outputs={'Out': [clone_orig_var]},
attrs={'file_path': os.path.join(dirname, orig_var_name)})
load_block.append_op(
type="slice",
inputs={'Input': clone_orig_var},
outputs={'Out': clone_slice_var},
attrs={'axes': [0],
'starts': [start],
'ends': [end]})
need_delete_vars.append(clone_orig_var)
load_block.append_op(
type='delete_var',
inputs={'X': need_delete_vars}, )
executor.run(load_prog)
...@@ -80,7 +80,8 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2):
# NOTE: pserver should not call memory optimize # NOTE: pserver should not call memory optimize
t = self.get_transpiler(args.trainer_id, t = self.get_transpiler(args.trainer_id,
fluid.default_main_program(), args.endpoints, fluid.default_main_program(), args.endpoints,
args.trainers, args.sync_mode) args.trainers, args.sync_mode, False,
args.current_endpoint)
pserver_prog = t.get_pserver_program(args.current_endpoint) pserver_prog = t.get_pserver_program(args.current_endpoint)
startup_prog = t.get_startup_program(args.current_endpoint, startup_prog = t.get_startup_program(args.current_endpoint,
pserver_prog) pserver_prog)
...@@ -93,7 +94,8 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2):
exe.run(startup_prog) exe.run(startup_prog)
if need_load and model_dir: if need_load and model_dir:
self._load_persistable_vars(exe, model_dir, startup_prog) fluid.io.load_persistables(exe, model_dir, pserver_prog)
exe.run(pserver_prog) exe.run(pserver_prog)
def run_trainer(self, args): def run_trainer(self, args):
...@@ -158,19 +160,46 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2):
need_save = bool(int(os.getenv("SAVE", "0"))) need_save = bool(int(os.getenv("SAVE", "0")))
model_dir = os.getenv("MODEL_DIR", "") model_dir = os.getenv("MODEL_DIR", "")
save_mode = os.getenv("SAVE_MODE", "")
if need_save:
for _ in six.moves.xrange(RUN_STEP): if save_mode == "LOCAL":
loss, = exe.run(fetch_list=[avg_cost.name], if need_save:
feed=feeder.feed(get_data())) for _ in six.moves.xrange(RUN_STEP):
if need_save and model_dir: loss, = exe.run(fetch_list=[avg_cost.name],
io.save_persistables(startup_exe, model_dir, trainer_prog) feed=feeder.feed(get_data()))
if need_save and model_dir:
var = np.array(fluid.global_scope().find_var('__fc_b__').get_tensor()) io.save_persistables(startup_exe, model_dir, trainer_prog)
if six.PY2:
print(pickle.dumps(np.ravel(var).tolist())) var = np.array(fluid.global_scope().find_var('__fc_b__').get_tensor(
))
if six.PY2:
print(pickle.dumps(np.ravel(var).tolist()))
else:
sys.stdout.buffer.write(pickle.dumps(np.ravel(var).tolist()))
elif save_mode == "DIST":
skip_steps = int(os.getenv("SKIP_STEPS"))
loss = None
if need_save:
for idx in six.moves.xrange(8):
loss, = exe.run(fetch_list=[avg_cost.name],
feed=feeder.feed(get_data()))
if need_save and model_dir and idx == skip_steps and args.trainer_id == 0:
io.save_persistables(startup_exe, model_dir,
trainer_prog)
else:
for idx in six.moves.xrange(8):
data = get_data()
if idx <= skip_steps:
continue
loss, = exe.run(fetch_list=[avg_cost.name],
feed=feeder.feed(data))
if six.PY2:
print(pickle.dumps(loss.tolist()))
else:
sys.stdout.buffer.write(pickle.dumps(loss.tolist()))
else: else:
sys.stdout.buffer.write(pickle.dumps(np.ravel(var).tolist())) raise Exception("save_mode must be LOCAL or DIST")
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -75,9 +75,13 @@ def get_loss(cos_q_pt, cos_q_nt):
return avg_cost return avg_cost
def get_optimizer(): def get_optimizer(op="sgd"):
# SGD optimizer if op.upper() == "sgd".upper():
optimizer = fluid.optimizer.SGD(learning_rate=base_lr) optimizer = fluid.optimizer.SGD(learning_rate=base_lr)
elif op.upper() == "adam".upper():
optimizer = fluid.optimizer.Adam(learning_rate=base_lr)
else:
optimizer = fluid.optimizer.SGD(learning_rate=base_lr)
return optimizer return optimizer
...@@ -237,7 +241,8 @@ class TestDistSimnetBow2x2(TestDistRunnerBase):
inference_program = fluid.default_main_program().clone() inference_program = fluid.default_main_program().clone()
# Optimization # Optimization
opt = get_optimizer() opt = os.getenv('OPTIMIZER', 'sgd')
opt = get_optimizer(opt)
opt.minimize(avg_cost) opt.minimize(avg_cost)
# Reader # Reader
......
...@@ -43,7 +43,8 @@ class TestDistRunnerBase(object):
pserver_endpoints, pserver_endpoints,
trainers, trainers,
sync_mode, sync_mode,
dc_asgd=False): dc_asgd=False,
current_endpoint=None):
# NOTE: import fluid until runtime, or else forking processes will cause error. # NOTE: import fluid until runtime, or else forking processes will cause error.
config = fluid.DistributeTranspilerConfig() config = fluid.DistributeTranspilerConfig()
config.enable_dc_asgd = dc_asgd config.enable_dc_asgd = dc_asgd
...@@ -53,7 +54,8 @@ class TestDistRunnerBase(object):
program=main_program, program=main_program,
pservers=pserver_endpoints, pservers=pserver_endpoints,
trainers=trainers, trainers=trainers,
sync_mode=sync_mode) sync_mode=sync_mode,
current_endpoint=current_endpoint)
return t return t
def run_pserver(self, args): def run_pserver(self, args):
......
...@@ -33,7 +33,6 @@ class TestDistSaveLoadDense2x2(TestDistBase):
delta=1e-3, delta=1e-3,
check_error_log=False, check_error_log=False,
need_envs={}): need_envs={}):
required_envs = { required_envs = {
"PATH": os.getenv("PATH", ""), "PATH": os.getenv("PATH", ""),
"PYTHONPATH": os.getenv("PYTHONPATH", ""), "PYTHONPATH": os.getenv("PYTHONPATH", ""),
...@@ -77,7 +76,77 @@ class TestDistSaveLoadDense2x2(TestDistBase):
need_envs = { need_envs = {
"IS_DISTRIBUTED": '0', "IS_DISTRIBUTED": '0',
"IS_SPARSE": '0', "IS_SPARSE": '0',
'IS_SELF_CONTAINED_LR': '1' 'IS_SELF_CONTAINED_LR': '1',
'SAVE_MODE': 'LOCAL',
}
self.check_with_place(
"dist_save_load.py",
delta=0,
check_error_log=False,
need_envs=need_envs)
class TestDistSaveLoadWithPServerStateDense2x2(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._enforce_place = "CPU"
def check_with_place(self,
model_file,
delta=1e-3,
check_error_log=False,
need_envs={}):
required_envs = {
"PATH": os.getenv("PATH", ""),
"PYTHONPATH": os.getenv("PYTHONPATH", ""),
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"http_proxy": ""
}
required_envs.update(need_envs)
if check_error_log:
required_envs["GLOG_v"] = "3"
required_envs["GLOG_logtostderr"] = "1"
model_dir = tempfile.mkdtemp()
save_env = {}
save_env["SAVE_MODE"] = "DIST"
save_env["SAVE"] = "1"
save_env["MODEL_DIR"] = model_dir
save_env.update(required_envs)
tr0_var_1, tr1_var_1 = self._run_cluster(model_file, save_env,
check_error_log)
load_env = {}
load_env["LOAD"] = "1"
load_env["MODEL_DIR"] = model_dir
load_env.update(required_envs)
tr0_var_2, tr1_var_2 = self._run_cluster(model_file, load_env,
check_error_log)
shutil.rmtree(model_dir)
train0_1_np = np.array(tr0_var_1)
train1_1_np = np.array(tr1_var_1)
train0_2_np = np.array(tr0_var_2)
train1_2_np = np.array(tr1_var_2)
self.assertAlmostEqual(
train0_1_np.all(), train0_2_np.all(), delta=delta)
self.assertAlmostEqual(
train1_1_np.all(), train1_2_np.all(), delta=delta)
def test_dist(self):
need_envs = {
"IS_DISTRIBUTED": '0',
"IS_SPARSE": '0',
'IS_SELF_CONTAINED_LR': '1',
'SAVE_MODE': 'DIST',
'OPTIMIZER': 'ADAM',
'SKIP_STEPS': str(np.random.randint(2, 6))
} }
self.check_with_place( self.check_with_place(
"dist_save_load.py", "dist_save_load.py",
......
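To summarize the new DIST flow exercised above (a condensed sketch, not additional test code; model_dir and the SKIP_STEPS value are placeholders): the same runner script is launched twice with different environment variables, the first run saving pserver state mid-training and the second reloading it and replaying only the remaining steps, after which the pickled losses of both runs are compared.

model_dir = "/tmp/dist_save_load_model"  # the real test uses tempfile.mkdtemp()

common = {"SAVE_MODE": "DIST", "OPTIMIZER": "ADAM", "SKIP_STEPS": "3"}

# Run 1: train 8 steps; trainer 0 saves persistables at step SKIP_STEPS.
save_env = dict(common, SAVE="1", MODEL_DIR=model_dir)

# Run 2: pservers reload the checkpoint via fluid.io.load_persistables and the
# trainers skip steps <= SKIP_STEPS, so the remaining losses should match run 1.
load_env = dict(common, LOAD="1", MODEL_DIR=model_dir)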
...@@ -741,21 +741,40 @@ class TestLoadSliceVar(TranspilerTest):
pserver, _ = self.get_pserver(self.pserver1_ep) pserver, _ = self.get_pserver(self.pserver1_ep)
pserver2, _ = self.get_pserver(self.pserver2_ep) pserver2, _ = self.get_pserver(self.pserver2_ep)
self.assertTrue(pserver._slice_vars_and_attrs) vars_ps1 = pserver._parameters_on_pservers.get_distributed_vars_by_ep(
self.assertTrue(pserver2._slice_vars_and_attrs) self.pserver1_ep)
vars_ps2 = pserver._parameters_on_pservers.get_distributed_vars_by_ep(
for idx in six.moves.xrange(len(pserver._slice_vars_and_attrs)): self.pserver2_ep)
self.assertEqual(pserver._slice_vars_and_attrs[idx][0],
pserver2._slice_vars_and_attrs[idx][0]) self.assertTrue(vars_ps1)
self.assertTrue(vars_ps2)
total_numel = six.moves.reduce(
lambda x, y: x * y, pserver._slice_vars_and_attrs[idx][0].shape) for idx in six.moves.xrange(len(vars_ps1)):
self.assertEqual( total_numel = 0
total_numel, ps1_numel, ps2_numel = 0, 0
six.moves.reduce(lambda x, y: x * y,
pserver._slice_vars_and_attrs[idx][2].shape) + ps1_var = vars_ps1[idx]
six.moves.reduce(lambda x, y: x * y,
pserver2._slice_vars_and_attrs[idx][2].shape)) if not ps1_var.is_slice:
total_numel = six.moves.reduce(lambda x, y: x * y,
vars_ps1[idx].origin.shape)
ps1_numel = six.moves.reduce(lambda x, y: x * y,
vars_ps1[idx].slice.shape)
else:
ps2_var = None
for var in vars_ps2:
if var.origin.name == ps1_var.origin.name:
ps2_var = var
break
total_numel = six.moves.reduce(lambda x, y: x * y,
ps1_var.origin.shape)
ps1_numel = six.moves.reduce(lambda x, y: x * y,
ps1_var.slice.shape)
ps2_numel = six.moves.reduce(lambda x, y: x * y,
ps2_var.slice.shape)
self.assertEqual(total_numel, ps1_numel + ps2_numel)
class TestNCCL2Transpile(TranspilerTest): class TestNCCL2Transpile(TranspilerTest):
......