From 8b50ad80ff6934512d3959947ac1e71ea3fb9ea3 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 23 Jan 2019 15:13:22 +0800 Subject: [PATCH] checkpoint at distributed training (#14854) checkpoint for distributed training. --- .../operators/distributed/grpc/grpc_client.cc | 89 ++-- .../operators/distributed/grpc/grpc_client.h | 17 +- .../operators/distributed/grpc/grpc_server.cc | 59 ++- .../operators/distributed/grpc/grpc_service.h | 3 + .../operators/distributed/request_handler.h | 13 + .../distributed/request_handler_impl.cc | 30 +- .../distributed/request_handler_impl.h | 10 + .../fluid/operators/distributed/rpc_client.h | 7 + .../operators/distributed/send_recv.proto.in | 18 + .../distributed_ops/listen_and_serv_op.cc | 5 + .../distributed_ops/listen_and_serv_op.h | 3 +- .../operators/distributed_ops/recv_op.cc | 63 ++- paddle/fluid/platform/mkldnn_reuse.h | 4 +- python/paddle/fluid/framework.py | 15 +- python/paddle/fluid/io.py | 454 +++++++++++++----- .../fluid/tests/unittests/dist_save_load.py | 57 ++- .../fluid/tests/unittests/dist_simnet_bow.py | 13 +- .../fluid/tests/unittests/test_dist_base.py | 6 +- .../tests/unittests/test_dist_save_load.py | 73 ++- .../tests/unittests/test_dist_transpiler.py | 49 +- .../fluid/transpiler/distribute_transpiler.py | 414 ++++++++++++++-- 21 files changed, 1122 insertions(+), 280 deletions(-) diff --git a/paddle/fluid/operators/distributed/grpc/grpc_client.cc b/paddle/fluid/operators/distributed/grpc/grpc_client.cc index 7875c16c3cf..52310f8d04d 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_client.cc +++ b/paddle/fluid/operators/distributed/grpc/grpc_client.cc @@ -74,7 +74,7 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep, const framework::Scope* p_scope = &scope; const auto ch = GetChannel(ep_val); SendProcessor* s = new SendProcessor(ch); - const std::string method = "SendRPC"; + const std::string method = kSendRPC; VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope)); s->Prepare(h, time_out); @@ -107,7 +107,7 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep, void ProcGetResponse(const VarHandle& var_h, const ::grpc::ByteBuffer& ret_msg) { - VLOG(100) << "ProcGetResponse"; + VLOG(4) << "ProcGetResponse"; framework::Variable* outvar = nullptr; // get response's trainer_id is not used int trainer_id; @@ -127,59 +127,74 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep, const platform::DeviceContext& ctx, const framework::Scope& scope, const std::string& var_name, + const std::string& out_varname, int64_t time_out) { - return _AsyncGetVar(ep, ctx, scope, var_name, + return _AsyncGetVar(ep, ctx, scope, kGetRPC, var_name, out_varname, "/sendrecv.SendRecvService/GetVariable", time_out); } +VarHandlePtr GRPCClient::AsyncGetVarNoBarrier( + const std::string& ep, const platform::DeviceContext& ctx, + const framework::Scope& scope, const std::string& var_name, + const std::string& out_varname, int64_t time_out) { + std::string var_name_no_barrier = + string::Sprintf("%s%s", var_name, WITHOUT_BARRIER_MESSAGE); + + return _AsyncGetVar( + ep, ctx, scope, kGetNoBarrierRPC, var_name_no_barrier, out_varname, + "/sendrecv.SendRecvService/GetVariableNoBarrier", time_out); +} + VarHandlePtr GRPCClient::AsyncGetMonomerVariable( const std::string& ep, const platform::DeviceContext& ctx, const framework::Scope& scope, const std::string& var_name, int64_t time_out) { - return _AsyncGetVar(ep, ctx, scope, var_name, + return _AsyncGetVar(ep, ctx, scope, kGetMonomerRPC, var_name, var_name, 
"/sendrecv.SendRecvService/GetMonomerVariable", time_out); } -VarHandlePtr GRPCClient::_AsyncGetVar(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& var_name, - const std::string& rpc_path, - int64_t time_out) { +VarHandlePtr GRPCClient::_AsyncGetVar( + const std::string& ep, const platform::DeviceContext& ctx, + const framework::Scope& scope, const std::string& method, + const std::string& var_name, const std::string& out_varname, + const std::string& rpc_path, int64_t time_out) { const platform::DeviceContext* p_ctx = &ctx; const std::string ep_val = ep; const std::string var_name_val = var_name; + const std::string out_varname_val = out_varname; const framework::Scope* p_scope = &scope; const auto ch = GetChannel(ep_val); GetProcessor* s = new GetProcessor(ch); - const std::string method = "GetRPC"; - VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope)); + + VarHandlePtr h(new VarHandle(ep, method, out_varname_val, p_ctx, p_scope)); s->Prepare(h, time_out); - framework::AsyncIO([var_name_val, s, method, p_ctx, h, rpc_path, this] { - // prepare input - sendrecv::VariableMessage req; - req.set_varname(var_name_val); - req.set_trainer_id(trainer_id_); - ::grpc::ByteBuffer buf; - RequestToByteBuffer(req, &buf); + framework::AsyncIO( + [var_name_val, out_varname_val, s, method, p_ctx, h, rpc_path, this] { + // prepare input + sendrecv::VariableMessage req; + req.set_varname(var_name_val); + req.set_out_varname(out_varname_val); + req.set_trainer_id(trainer_id_); + ::grpc::ByteBuffer buf; + RequestToByteBuffer(req, &buf); - VLOG(3) << s->GetVarHandlePtr()->String() << " begin"; + VLOG(3) << s->GetVarHandlePtr()->String() << " begin"; - // stub context - s->response_call_back_ = ProcGetResponse; + // stub context + s->response_call_back_ = ProcGetResponse; - platform::RecordRPCEvent record_event(method, p_ctx); + platform::RecordRPCEvent record_event(method, p_ctx); - auto call = - s->stub_g_.PrepareUnaryCall(s->context_.get(), rpc_path, buf, &cq_); - call->StartCall(); - call->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); + auto call = + s->stub_g_.PrepareUnaryCall(s->context_.get(), rpc_path, buf, &cq_); + call->StartCall(); + call->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); - if (UNLIKELY(platform::IsProfileEnabled())) { - h->Wait(); - } - }); + if (UNLIKELY(platform::IsProfileEnabled())) { + h->Wait(); + } + }); req_count_++; @@ -202,7 +217,7 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep, const auto ch = GetChannel(ep_val); GetProcessor* s = new GetProcessor(ch); - const std::string method = "PrefetchRPC"; + const std::string method = kPrefetchRPC; VarHandlePtr h(new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope)); s->Prepare(h, time_out); @@ -242,7 +257,7 @@ VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep, const auto ch = GetChannel(ep); BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); - const std::string method = "BatchBarrierRPC"; + const std::string method = kBatchBarrierRPC; VarHandlePtr h( new VarHandle(ep, method, BATCH_BARRIER_MESSAGE, nullptr, nullptr)); s->Prepare(h, time_out); @@ -267,7 +282,7 @@ VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) { const auto ch = GetChannel(ep); FetchBarrierProcessor* s = new FetchBarrierProcessor(ch); - const std::string method = "FetchBarrierRPC"; + const std::string method = kFetchBarrierRPC; VarHandlePtr h( new VarHandle(ep, method, 
FETCH_BARRIER_MESSAGE, nullptr, nullptr)); s->Prepare(h, time_out); @@ -293,7 +308,7 @@ VarHandlePtr GRPCClient::AsyncGetMonomerBarrier(const std::string& ep, int64_t time_out) { const auto ch = GetChannel(ep); BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); - const std::string method = "SendMonomerFetchBarrierRPC"; + const std::string method = kSendMonomerFetchBarrierRPC; VarHandlePtr h(new VarHandle(ep, method, var_name, nullptr, nullptr)); s->Prepare(h, time_out); @@ -320,7 +335,7 @@ VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep, const auto ch = GetChannel(ep); BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); - const std::string method = "SendCompleteRPC"; + const std::string method = kSendCompleteRPC; VarHandlePtr h(new VarHandle(ep, method, COMPLETE_MESSAGE, nullptr, nullptr)); s->Prepare(h, time_out); @@ -347,7 +362,7 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep, CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch); - const std::string method = "CheckPointNotifyRPC"; + const std::string method = kCheckPointNotifyRPC; VarHandlePtr h( new VarHandle(ep, method, CHECKPOINT_SAVE_MESSAGE, nullptr, nullptr)); diff --git a/paddle/fluid/operators/distributed/grpc/grpc_client.h b/paddle/fluid/operators/distributed/grpc/grpc_client.h index fa77d212576..ce0d2152aa2 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_client.h +++ b/paddle/fluid/operators/distributed/grpc/grpc_client.h @@ -186,8 +186,15 @@ class GRPCClient : public RPCClient { const platform::DeviceContext& ctx, const framework::Scope& scope, const std::string& var_name, + const std::string& out_varname, int64_t time_out = FLAGS_rpc_deadline) override; + VarHandlePtr AsyncGetVarNoBarrier( + const std::string& ep, const platform::DeviceContext& ctx, + const framework::Scope& scope, const std::string& var_name, + const std::string& out_varname, + int64_t time_out = FLAGS_rpc_deadline) override; + VarHandlePtr AsyncGetMonomerVariable( const std::string& ep, const platform::DeviceContext& ctx, const framework::Scope& scope, const std::string& var_name, @@ -228,11 +235,11 @@ class GRPCClient : public RPCClient { void Proceed(); std::shared_ptr GetChannel(const std::string& ep); - VarHandlePtr _AsyncGetVar(const std::string& ep, - const platform::DeviceContext& ctx, - const framework::Scope& scope, - const std::string& var_name, const std::string& rpc, - int64_t time_out); + VarHandlePtr _AsyncGetVar( + const std::string& ep, const platform::DeviceContext& ctx, + const framework::Scope& scope, const std::string& method, + const std::string& var_name, const std::string& out_varname, + const std::string& rpc_path, int64_t time_out = FLAGS_rpc_deadline); private: grpc::CompletionQueue cq_; diff --git a/paddle/fluid/operators/distributed/grpc/grpc_server.cc b/paddle/fluid/operators/distributed/grpc/grpc_server.cc index 08f777e279e..4a9c158cb0a 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_server.cc +++ b/paddle/fluid/operators/distributed/grpc/grpc_server.cc @@ -136,17 +136,65 @@ class RequestGet final : public RequestBase { void Process() override { // proc request. 
std::string varname = request_.varname(); + std::string out_varname = request_.out_varname(); int trainer_id = request_.trainer_id(); - VLOG(4) << "RequestGet " << varname; + + VLOG(4) << "RequestGet " << out_varname << " from " << varname; auto scope = request_handler_->scope(); - auto invar = scope->FindVar(varname); + framework::Variable* invar = nullptr; framework::Variable* outvar = nullptr; - request_handler_->Handle(varname, scope, invar, &outvar, trainer_id); + request_handler_->Handle(varname, scope, invar, &outvar, trainer_id, + out_varname); if (outvar) { - SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(), + SerializeToByteBuffer(out_varname, outvar, *request_handler_->dev_ctx(), + &reply_); + } + Finish(reply_, &responder_); + } + + protected: + sendrecv::VariableMessage request_; + ::grpc::ByteBuffer reply_; + ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_; +}; + +class RequestGetNoBarrier final : public RequestBase { + public: + explicit RequestGetNoBarrier(GrpcService::AsyncService* service, + ::grpc::ServerCompletionQueue* cq, + RequestHandler* request_handler, int req_id) + : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) { + auto method_id = + static_cast(distributed::GrpcMethod::kGetVariableNoBarrier); + service_->RequestAsyncUnary( + method_id, &ctx_, &request_, &responder_, cq_, cq_, + reinterpret_cast(static_cast(req_id))); + } + + virtual ~RequestGetNoBarrier() {} + + std::string GetReqName() override { return request_.varname(); } + + void Process() override { + // proc request. + std::string varname = request_.varname(); + std::string out_varname = request_.out_varname(); + int trainer_id = request_.trainer_id(); + + VLOG(4) << "RequestGetNoBarrier " << out_varname << " from " << varname; + + auto scope = request_handler_->scope(); + framework::Variable* invar = nullptr; + framework::Variable* outvar = nullptr; + + request_handler_->Handle(varname, scope, invar, &outvar, trainer_id, + out_varname); + + if (outvar) { + SerializeToByteBuffer(out_varname, outvar, *request_handler_->dev_ctx(), &reply_); } Finish(reply_, &responder_); @@ -460,6 +508,9 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name, b = new RequestSend(&service_, cq.get(), handler, req_id); } else if (rpc_name == kRequestGet) { b = new RequestGet(&service_, cq.get(), handler, req_id); + + } else if (rpc_name == kRequestGetNoBarrier) { + b = new RequestGetNoBarrier(&service_, cq.get(), handler, req_id); } else if (rpc_name == kRequestGetMonomerVariable) { b = new RequestGetMonomerVariable(&service_, cq.get(), handler, req_id, this); diff --git a/paddle/fluid/operators/distributed/grpc/grpc_service.h b/paddle/fluid/operators/distributed/grpc/grpc_service.h index 0b5c5151e63..2965fe4490b 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_service.h +++ b/paddle/fluid/operators/distributed/grpc/grpc_service.h @@ -81,6 +81,7 @@ enum class GrpcMethod { kGetVariable, kPrefetchVariable, kCheckpointNotify, + kGetVariableNoBarrier, kGetMonomerVariable, kGetMonomerBarrier, }; @@ -94,6 +95,8 @@ inline const char* GrpcMethodName(GrpcMethod id) { return "/sendrecv.SendRecvService/SendVariable"; case GrpcMethod::kGetVariable: return "/sendrecv.SendRecvService/GetVariable"; + case GrpcMethod::kGetVariableNoBarrier: + return "/sendrecv.SendRecvService/GetVariableNoBarrier"; case GrpcMethod::kGetMonomerVariable: return "/sendrecv.SendRecvService/GetMonomerVariable"; case GrpcMethod::kGetMonomerBarrier: diff --git 
a/paddle/fluid/operators/distributed/request_handler.h b/paddle/fluid/operators/distributed/request_handler.h index 62b24f150b4..991158ac720 100644 --- a/paddle/fluid/operators/distributed/request_handler.h +++ b/paddle/fluid/operators/distributed/request_handler.h @@ -42,11 +42,24 @@ constexpr char kRequestGetMonomerBarrier[] = "RequestGetMonomerBarrier"; constexpr char kRequestPrefetch[] = "RequestPrefetch"; constexpr char kRequestCheckpoint[] = "RequestCheckpoint"; constexpr char kRequestPassBarrier[] = "RequestPassBarrier"; +constexpr char kRequestGetNoBarrier[] = "GetVariableNoBarrier"; + +constexpr char kSendRPC[] = "SendRPC"; +constexpr char kGetRPC[] = "GetRPC"; +constexpr char kGetNoBarrierRPC[] = "GetNoBarrierRPC"; +constexpr char kGetMonomerRPC[] = "GetMonomerRPC"; +constexpr char kPrefetchRPC[] = "PrefetchRPC"; +constexpr char kBatchBarrierRPC[] = "BatchBarrierRPC"; +constexpr char kFetchBarrierRPC[] = "FetchBarrierRPC"; +constexpr char kSendMonomerFetchBarrierRPC[] = "SendMonomerFetchBarrierRPC"; +constexpr char kSendCompleteRPC[] = "SendCompleteRPC"; +constexpr char kCheckPointNotifyRPC[] = "CheckPointNotifyRPC"; #define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV" #define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV" #define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV" #define COMPLETE_MESSAGE "COMPLETE@RECV" +#define WITHOUT_BARRIER_MESSAGE "@WITHOUT_BARRIER@RECV" #define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY" #define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY" diff --git a/paddle/fluid/operators/distributed/request_handler_impl.cc b/paddle/fluid/operators/distributed/request_handler_impl.cc index 9722f8c96e9..913ae76b38d 100644 --- a/paddle/fluid/operators/distributed/request_handler_impl.cc +++ b/paddle/fluid/operators/distributed/request_handler_impl.cc @@ -23,6 +23,7 @@ #include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/operators/distributed/rpc_server.h" +#include "paddle/fluid/string/piece.h" #include "paddle/fluid/string/printf.h" namespace paddle { @@ -81,7 +82,8 @@ bool RequestGetHandler::Handle(const std::string& varname, const int trainer_id, const std::string& out_var_name, const std::string& table_name) { - VLOG(4) << "RequestGetHandler:" << varname; + VLOG(4) << "RequestGetHandler:" << varname + << " out_var_name: " << out_var_name; if (sync_mode_) { if (varname == FETCH_BARRIER_MESSAGE) { @@ -112,6 +114,32 @@ bool RequestGetHandler::Handle(const std::string& varname, return true; } +bool RequestGetNoBarrierHandler::Handle(const std::string& varname, + framework::Scope* scope, + framework::Variable* invar, + framework::Variable** outvar, + const int trainer_id, + const std::string& out_var_name, + const std::string& table_name) { + VLOG(4) << "RequestGetNoBarrierHandler:" << varname + << " out_var_name: " << out_var_name; + + // get var from pserver immediately without barriers + string::Piece without_barrier_piece(WITHOUT_BARRIER_MESSAGE); + string::Piece var_name_piece = string::Piece(varname); + + if (string::Contains(var_name_piece, without_barrier_piece)) { + var_name_piece = string::TrimSuffix(var_name_piece, without_barrier_piece); + VLOG(4) << "Get var " << var_name_piece << " with " + << WITHOUT_BARRIER_MESSAGE; + *outvar = scope_->FindVar(var_name_piece.ToString()); + return true; + } else { + PADDLE_THROW("GetNoBarrier must contain %s", WITHOUT_BARRIER_MESSAGE); + } + return true; +} + bool RequestPrefetchHandler::Handle(const std::string& varname, framework::Scope* 
scope, framework::Variable* invar, diff --git a/paddle/fluid/operators/distributed/request_handler_impl.h b/paddle/fluid/operators/distributed/request_handler_impl.h index 5e0b25c5c2c..f3c1b24526b 100644 --- a/paddle/fluid/operators/distributed/request_handler_impl.h +++ b/paddle/fluid/operators/distributed/request_handler_impl.h @@ -67,6 +67,16 @@ class RequestGetHandler final : public RequestHandler { bool enable_dc_asgd_; }; +class RequestGetNoBarrierHandler final : public RequestHandler { + public: + RequestGetNoBarrierHandler() : RequestHandler(false) {} + virtual ~RequestGetNoBarrierHandler() {} + bool Handle(const std::string& varname, framework::Scope* scope, + framework::Variable* var, framework::Variable** outvar, + const int trainer_id, const std::string& out_var_name = "", + const std::string& table_name = "") override; +}; + static inline void BuildVar(const std::string& param_name, std::initializer_list arguments, paddle::framework::proto::OpDesc::Var* var) { diff --git a/paddle/fluid/operators/distributed/rpc_client.h b/paddle/fluid/operators/distributed/rpc_client.h index b668d869787..ea54e0c2951 100644 --- a/paddle/fluid/operators/distributed/rpc_client.h +++ b/paddle/fluid/operators/distributed/rpc_client.h @@ -43,8 +43,15 @@ class RPCClient { const platform::DeviceContext& ctx, const framework::Scope& scope, const std::string& var_name, + const std::string& out_varname, int64_t time_out = FLAGS_rpc_deadline) = 0; + virtual VarHandlePtr AsyncGetVarNoBarrier( + const std::string& ep, const platform::DeviceContext& ctx, + const framework::Scope& scope, const std::string& var_name, + const std::string& out_varname, + int64_t time_out = FLAGS_rpc_deadline) = 0; + virtual VarHandlePtr AsyncGetMonomerVariable( const std::string& ep, const platform::DeviceContext& ctx, const framework::Scope& scope, const std::string& var_name, diff --git a/paddle/fluid/operators/distributed/send_recv.proto.in b/paddle/fluid/operators/distributed/send_recv.proto.in index b39eef04d8d..63036678843 100644 --- a/paddle/fluid/operators/distributed/send_recv.proto.in +++ b/paddle/fluid/operators/distributed/send_recv.proto.in @@ -17,8 +17,14 @@ package sendrecv; option cc_generic_services = @cc_generic_services@; service SendRecvService { + // For parameter server round-robin like hashing, do not split tensors. + // Send and recv only one tensor + // TODO(typhoonzero): add streaming API rpc SendVariable(VariableMessage) returns (VoidMessage) {} + // Argument VariableMessage for GetVariable should only contain varname. rpc GetVariable(VariableMessage) returns (VariableMessage) {} + rpc GetVariableNoBarrier(VariableMessage) returns (VariableMessage) {} + // pre-fetch variable by given variable name and Ids rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {} rpc CheckpointNotify(VariableMessage) returns (VoidMessage) {} @@ -27,12 +33,17 @@ service SendRecvService { rpc GetMonomerBarrier(VariableMessage) returns (VoidMessage) {} } +// It can be: LoDTensor态SelectedRows or NCCL_ID enum VarType { LOD_TENSOR = 0; SELECTED_ROWS = 1; NCCL_ID = 2; } +// VariableMessage is serialized paddle variable message. +// NOTICE(gongwb):don't modify this proto if you are not +// not familar with how we serialize in sendrecvop_utils.h +// and deserilize it in variable_response.h. 
message VariableMessage { enum Type { // Pod Types @@ -49,14 +60,21 @@ message VariableMessage { string varname = 1; // TODO(Yancey1989): reference framework::proto::VarDesc::VarType VarType type = 2; + // bool persistable is not needed for sending. + // tensor info: Type data_type = 3; repeated int64 dims = 4; + // lod details: int64 lod_level = 5; repeated LodData lod = 6; + // selected_rows height, aka. original dim0 int64 slr_height = 7; + // tensor data bytes serialized = 8; + // selected_rows data bytes rows = 9; + // Look up table block execution output variable name. string out_varname = 10; // If 1, the ps server will start profiling, the ps // server stops profiling and generates a profile to /tmp/profile_ps_* diff --git a/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc b/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc index 629f364d712..53968831ea0 100644 --- a/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc +++ b/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc @@ -347,6 +347,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, new distributed::RequestPrefetchHandler(sync_mode)); request_checkpoint_handler_.reset(new distributed::RequestCheckpointHandler( sync_mode, checkpoint_block_id)); + request_get_no_barrier_handler_.reset( + new distributed::RequestGetNoBarrierHandler()); rpc_service_->RegisterRPC(distributed::kRequestSend, request_send_handler_.get(), @@ -359,6 +361,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, FLAGS_rpc_prefetch_thread_num); rpc_service_->RegisterRPC(distributed::kRequestCheckpoint, request_checkpoint_handler_.get()); + rpc_service_->RegisterRPC(distributed::kRequestGetNoBarrier, + request_get_no_barrier_handler_.get()); auto optimize_blocks = Attr>(kOptimizeBlocks); @@ -413,6 +417,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, f(request_get_handler_.get()); f(request_prefetch_handler_.get()); f(request_checkpoint_handler_.get()); + f(request_get_no_barrier_handler_.get()); // start the server listening after all member initialized. 
server_thread_.reset(new std::thread(RunServer, rpc_service_)); diff --git a/paddle/fluid/operators/distributed_ops/listen_and_serv_op.h b/paddle/fluid/operators/distributed_ops/listen_and_serv_op.h index 9431978df83..f20442bad7c 100644 --- a/paddle/fluid/operators/distributed_ops/listen_and_serv_op.h +++ b/paddle/fluid/operators/distributed_ops/listen_and_serv_op.h @@ -55,7 +55,6 @@ class ListenAndServOp : public framework::OperatorBase { const framework::VariableNameMap& inputs, const framework::VariableNameMap& outputs, const framework::AttributeMap& attrs); - virtual ~ListenAndServOp(); void RunSyncLoop(framework::Executor* executor, @@ -89,6 +88,8 @@ class ListenAndServOp : public framework::OperatorBase { mutable std::shared_ptr rpc_service_; mutable std::shared_ptr request_send_handler_; mutable std::shared_ptr request_get_handler_; + mutable std::shared_ptr + request_get_no_barrier_handler_; mutable std::shared_ptr request_prefetch_handler_; mutable std::shared_ptr diff --git a/paddle/fluid/operators/distributed_ops/recv_op.cc b/paddle/fluid/operators/distributed_ops/recv_op.cc index 48065437e38..120c65f2969 100644 --- a/paddle/fluid/operators/distributed_ops/recv_op.cc +++ b/paddle/fluid/operators/distributed_ops/recv_op.cc @@ -27,30 +27,50 @@ namespace operators { class RecvOp : public framework::OperatorBase { public: - RecvOp(const std::string& type, const framework::VariableNameMap& inputs, - const framework::VariableNameMap& outputs, - const framework::AttributeMap& attrs) + RecvOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} - void RunImpl(const framework::Scope& scope, - const platform::Place& place) const override { - auto outs = Outputs("Out"); + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { std::vector epmap = Attr>("epmap"); + std::vector varnames = + Attr>("varnames"); int sync_mode = Attr("sync_mode"); + auto outs = Outputs("Out"); + bool with_barrier = Attr("with_barrier"); - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - auto& ctx = *pool.Get(place); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &ctx = *pool.Get(place); - distributed::RPCClient* rpc_client = + distributed::RPCClient *rpc_client = distributed::RPCClient::GetInstance( Attr("trainer_id")); - std::vector rets; - for (size_t i = 0; i < outs.size(); i++) { - VLOG(3) << "getting " << outs[i] << " from " << epmap[i]; - rets.push_back(rpc_client->AsyncGetVar(epmap[i], ctx, scope, outs[i])); - } - if (sync_mode) { + if (with_barrier) { + std::vector rets; + for (size_t i = 0; i < outs.size(); i++) { + std::string varname = varnames.size() == 0 ? outs[i] : varnames[i]; + VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with " + << varname << " and with AsyncGetVar"; + rets.push_back( + rpc_client->AsyncGetVar(epmap[i], ctx, scope, varname, outs[i])); + } + if (sync_mode) { + for (size_t i = 0; i < rets.size(); i++) { + PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient"); + } + } + } else { + std::vector rets; + for (size_t i = 0; i < outs.size(); i++) { + std::string varname = varnames.size() == 0 ? 
outs[i] : varnames[i]; + VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with " + << varname << " and with AsyncGetVarNoBarrier"; + rets.push_back(rpc_client->AsyncGetVarNoBarrier(epmap[i], ctx, scope, + varname, outs[i])); + } for (size_t i = 0; i < rets.size(); i++) { PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient"); } @@ -79,12 +99,23 @@ This operator can get variables from server side. "(int, default 0)" "sync recv or async recv.") .SetDefault(0); + AddAttr("with_barrier", + "(bool, default True) if with_barrier=False, will use " + "AsyncGetVarNoBarrier get variable from pserver immediately") + .SetDefault(true); + AddAttr>( + "varnames", + "(string vector, default {}) " + "sometimes we need to put received var in another name " + "for example: we need var named 'moment_1@127.0.0.1:1001', " + "and it real name on parameter server is 'moment_1'. ") + .SetDefault({}); } }; class RecvOpShapeInference : public framework::InferShapeBase { public: - void operator()(framework::InferShapeContext* ctx) const override {} + void operator()(framework::InferShapeContext *ctx) const override {} }; } // namespace operators diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index faac6a12c66..269280d604a 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -365,7 +365,7 @@ class TransposeMKLDNNHandler : public MKLDNNHandler { mem_fmt.ndims = axis.size(); for (unsigned int i = 0; i < nchw_tz.size(); ++i) { mem_fmt.dims[i] = nchw_tz[i]; // logical dimensions (nchw format, - // regardless physical layout) + // regardless physical layout) } mem_fmt.data_type = mkldnn_f32; mem_fmt.format = mkldnn_blocked; @@ -374,7 +374,7 @@ class TransposeMKLDNNHandler : public MKLDNNHandler { for (int i = nchw_tz.size() - 1; i >= 0; --i) { mem_fmt.layout_desc.blocking.padding_dims[i] = nchw_tz[i]; // logical dimensions (nchw format, regardless physical - // layout) + // layout) mem_fmt.layout_desc.blocking.block_dims[i] = 1; mem_fmt.layout_desc.blocking.offset_padding_to_data[i] = 0; // no offset mem_fmt.layout_desc.blocking.strides[0][axis[i]] = total_stride; diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index fc5e471ae30..22f505854e2 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1696,12 +1696,20 @@ class Program(object): self._current_role = core.op_proto_and_checker_maker.OpRole.Forward self._op_role_var = [] - # for distribute + # for distribute training + # _is_distributed = True if under distributed training self._is_distributed = False + # _is_chief = True if the trainer is the first one, usually No.0 self._is_chief = False - self._slice_vars_and_attrs = [] + # _parameters_on_pservers records all the parameters distributed on parameter servers. + self._parameters_on_pservers = None + # _endpoints is a list about parameter servers ip:port, such as ["ip:port","ip:port"] self._endpoints = [] + # if current role is parameter server, the _ps_endpoint is its "ip:port" + self._ps_endpoint = None + # trainers_endpoints, it is used for distribution. 
self._trainers_endpoints = [] + # the distributed lookup table names self._distributed_lookup_table = None @property @@ -2232,8 +2240,9 @@ class Program(object): "Program") self._is_distributed = other._is_distributed self._is_chief = other._is_chief - self._slice_vars_and_attrs = other._slice_vars_and_attrs + self._parameters_on_pservers = other._parameters_on_pservers self._endpoints = other._endpoints + self._ps_endpoint = other._ps_endpoint self._distributed_lookup_table = other._distributed_lookup_table def _copy_data_info_from(self, other): diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index e74a87fc68d..6b1d4cc34f3 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -19,6 +19,7 @@ import errno import time import shutil import six +from functools import reduce from paddle.fluid.executor import Executor from paddle.fluid.evaluator import Evaluator @@ -183,8 +184,6 @@ def save_vars(executor, # NOTE: don't save the variable which type is RAW if each_var.type == core.VarDesc.VarType.RAW: continue - if each_var.name == main_program._distributed_lookup_table: - continue new_var = _clone_var_in_block_(save_block, each_var) if filename is None: save_block.append_op( @@ -206,16 +205,6 @@ def save_vars(executor, outputs={}, attrs={'file_path': os.path.join(dirname, filename)}) - # if there is lookup table, the trainer 0 will notify all pserver to save. - if main_program._is_distributed and main_program._is_chief and main_program._distributed_lookup_table: - lookup_table_filename = os.path.join(dirname, "__lookup_table__") - attrs = {} - attrs['epmap'] = main_program._endpoints - attrs['dir'] = lookup_table_filename - attrs['lookup_table'] = main_program._distributed_lookup_table - save_block.append_op( - type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs) - executor.run(save_program) @@ -267,6 +256,186 @@ def save_params(executor, dirname, main_program=None, filename=None): filename=filename) +def _save_distributed_persistables(executor, dirname, main_program): + """ + save_persistables for distributed training. + the method will do things listed below: + 1.save part of persistable variables on trainer. + 2.receive "remote prefetch variables" from parameter servers and merge them. + 3.save "distributed lookup table" on parameter servers. + 4.receive "optimizer variables" from parameter servers and merge them. + + Args: + executor(Executor): The executor to run for saving parameters. + dirname(str): The saving directory path. + main_program(Program): The program whose parameters will be + saved. the main_program must be the trainer_program + get after transpiler. + + Returns: + None + + Examples: + .. code-block:: python + + exe = fluid.Executor(fluid.CPUPlace()) + param_path = "./my_paddle_model" + t = distribute_transpiler.DistributeTranspiler() + t.transpile(...) + train_program = t.get_trainer_program() + _save_distributed_persistables(executor=exe, dirname=param_path, main_program=train_program) + """ + + def __save_remote_params(executor, dirname, remote_params_map): + """ + recive params on pserver through rpc. + if the params are be sliced, will concat them to one, then save it. 
+ """ + if not remote_params_map: + return + + prog = Program() + block = prog.global_block() + + # recv optimize vars from pserver + for name, remote_params in remote_params_map.items(): + origin_var = None + is_slice = False + slice_vars = [0] * len(remote_params) + slice_var_names = [""] * len(remote_params) + endpoints = [""] * len(remote_params) + + for idx, optimizer in enumerate(remote_params): + origin = optimizer.origin + slice = optimizer.slice + is_slice = optimizer.is_slice + block_id = optimizer.block_id + endpoint = optimizer.endpoint + + if idx == 0: + origin_var = block.create_var( + name=origin.name, + type=origin.type, + shape=origin.shape, + dtype=origin.dtype, + persistable=True) + + slice_var = block.create_var( + name="{}.slice.{}".format(slice.name, idx), + type=slice.type, + shape=slice.shape, + dtype=slice.dtype, + persistable=True) + + index = block_id if is_slice else idx + slice_vars[index] = slice_var + slice_var_names[index] = slice.name + endpoints[index] = endpoint + + if is_slice: + block.append_op( + type='recv', + inputs={"X": []}, + outputs={"Out": slice_vars}, + attrs={ + "epmap": endpoints, + "with_barrier": False, + "varnames": slice_var_names, + "sync_mode": True + }) + block.append_op( + type='concat', + inputs={'X': slice_vars}, + outputs={'Out': origin_var}, + attrs={}) + else: + block.append_op( + type='recv', + inputs={"X": []}, + outputs={"Out": [origin_var]}, + attrs={ + "epmap": endpoints[:1], + "with_barrier": False, + "varnames": slice_var_names, + "sync_mode": True + }) + block.append_op( + type='save', + inputs={'X': [origin_var]}, + outputs={}, + attrs={'file_path': os.path.join(dirname, origin_var.name)}) + block.append_op(type='delete_var', inputs={'X': slice_vars}) + executor.run(prog) + + def __save_distributed_lookup_tables(executor, dirname, + distributed_lookup_table, endpoints): + """ + because the distributed lookup table may too huge to merge and save at one place, + it will be saved at parameter server independent respectively. + + the save directory is dirname/"__lookup_table__". + + """ + prog = Program() + block = prog.global_block() + + # if there is lookup table, the trainer 0 will notify all pserver to save. + lookup_table_filename = os.path.join(dirname, "__lookup_table__") + attrs = {} + attrs['epmap'] = endpoints + attrs['dir'] = lookup_table_filename + attrs['lookup_table'] = distributed_lookup_table + block.append_op( + type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs) + executor.run(prog) + + def __exclude_vars(exclude_var_names=[]): + def is_valid(var): + if var.name in exclude_var_names: + return False + if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \ + var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \ + var.desc.type() == core.VarDesc.VarType.READER: + return False + return var.persistable + + return is_valid + + if not isinstance(main_program, Program): + raise ValueError("'main_program' should be an instance of Program.") + + if not main_program._is_distributed: + raise ValueError( + "'_save_distributed_persistables' just be designed for distributed training." 
+ ) + + remote_params_map = main_program._parameters_on_pservers.get_distributed_vars_by_vtypes( + ["Optimizer", "RemotePrefetch"], groupby=True) + + exclude_var_names = [] + if remote_params_map: + exclude_var_names.extend(remote_params_map.keys()) + + if main_program._distributed_lookup_table: + if isinstance(main_program._distributed_lookup_table, list): + exclude_var_names.extend(main_program._distributed_lookup_table) + else: + exclude_var_names.append(main_program._distributed_lookup_table) + + local_vars = list( + filter(__exclude_vars(exclude_var_names), main_program.list_vars())) + save_vars( + executor, main_program=main_program, dirname=dirname, vars=local_vars) + + if main_program._is_chief: + if remote_params_map: + __save_remote_params(executor, dirname, remote_params_map) + if main_program._distributed_lookup_table: + __save_distributed_lookup_tables( + executor, dirname, main_program._distributed_lookup_table, + main_program._endpoints) + + def save_persistables(executor, dirname, main_program=None, filename=None): """ This function filters out all variables with `persistable==True` from the @@ -301,13 +470,19 @@ def save_persistables(executor, dirname, main_program=None, filename=None): fluid.io.save_persistables(executor=exe, dirname=param_path, main_program=None) """ - save_vars( - executor, - dirname=dirname, - main_program=main_program, - vars=None, - predicate=is_persistable, - filename=filename) + + if main_program and main_program._is_distributed: + _save_distributed_persistables( + executor, dirname=dirname, main_program=main_program) + + else: + save_vars( + executor, + dirname=dirname, + main_program=main_program, + vars=None, + predicate=is_persistable, + filename=filename) def load_vars(executor, @@ -402,17 +577,11 @@ def load_vars(executor, if not isinstance(main_program, Program): raise TypeError("program should be as Program type or None") - load_slice_vars = [] - for each_var in main_program._slice_vars_and_attrs: - load_slice_vars.append(each_var[2].name) - load_var_map = {} for each_var in vars: assert isinstance(each_var, Variable) if each_var.type == core.VarDesc.VarType.RAW: continue - if each_var.name in load_slice_vars: - continue new_var = _clone_var_in_block_(load_block, each_var) if filename is None: load_block.append_op( @@ -435,10 +604,6 @@ def load_vars(executor, attrs={'file_path': os.path.join(dirname, filename)}) executor.run(load_prog) - # load slice vars on pserver, if have it. - _load_slice_up_vars(executor, dirname, - main_program._slice_vars_and_attrs) - def load_params(executor, dirname, main_program=None, filename=None): """ @@ -521,12 +686,134 @@ def load_persistables(executor, dirname, main_program=None, filename=None): fluid.io.load_persistables(executor=exe, dirname=param_path, main_program=None) """ - load_vars( - executor, - dirname=dirname, - main_program=main_program, - predicate=is_persistable, - filename=filename) + + if main_program and main_program._is_distributed: + _load_distributed_persistables( + executor, dirname=dirname, main_program=main_program) + else: + load_vars( + executor, + dirname=dirname, + main_program=main_program, + predicate=is_persistable, + filename=filename) + + +def _load_distributed_persistables(executor, dirname, main_program=None): + """ + customized load_persistables for distributed training. + it should be used on parameter server, + + Args: + executor(Executor): The executor to run for saving parameters. + dirname(str): The load directory path. 
+ main_program(Program): The program whose parameters will be + loaded. the main_program must be the pserver_program + get after transpiler. + + Returns: + None + + Examples: + .. code-block:: python + + exe = fluid.Executor(fluid.CPUPlace()) + param_path = "./my_paddle_model" + t = distribute_transpiler.DistributeTranspiler() + t.transpile(...) + pserver_prog = t.get_pserver_program(...) + _load_distributed_persistables(executor=exe, dirname=param_path, main_program=pserver_prog) + """ + + def __is_distributed_part_var(varname): + trainer_idx = varname.find(".trainer_") + block_idx = varname.find(".block") + return trainer_idx or block_idx + + def __load_persistable_vars(executor, dirname, need_load_vars): + load_prog = Program() + load_block = load_prog.global_block() + need_delete_vars = [] + + for param in need_load_vars: + origin_var = param.origin + slice_var = param.slice + is_slice = param.is_slice + offset = param.offset + + if is_slice: + origin = load_block.create_var( + name="{}.load".format(origin_var.name), + type=origin_var.type, + shape=origin_var.shape, + dtype=origin_var.dtype, + persistable=True) + + load_block.append_op( + type='load', + inputs={}, + outputs={'Out': [origin]}, + attrs={ + 'file_path': os.path.join(dirname, origin_var.name) + }) + + slice = load_block.create_var( + name=slice_var.name, + type=slice_var.type, + shape=slice_var.shape, + dtype=slice_var.dtype, + persistable=True) + + dim1_flatten = reduce(lambda x, y: x * y, slice.shape[1:]) + start = int(offset / dim1_flatten) + end = int(offset / dim1_flatten + slice.shape[0]) + + load_block.append_op( + type="slice", + inputs={'Input': origin}, + outputs={'Out': slice}, + attrs={'axes': [0], + 'starts': [start], + 'ends': [end]}) + + need_delete_vars.append(origin) + else: + origin = load_block.create_var( + name="{}".format(origin_var.name), + type=origin_var.type, + shape=origin_var.shape, + dtype=origin_var.dtype, + persistable=True) + load_block.append_op( + type='load', + inputs={}, + outputs={'Out': [origin]}, + attrs={ + 'file_path': os.path.join(dirname, origin_var.name) + }) + + load_block.append_op( + type='delete_var', + inputs={'X': need_delete_vars}, ) + + executor.run(load_prog) + + if not isinstance(main_program, Program): + raise ValueError("'main_program' should be an instance of Program.") + + if not main_program._is_distributed: + raise ValueError( + "'_load_distributed_persistables' just be designed for distributed training." + ) + + if not main_program._ps_endpoint: + raise ValueError( + "'_load_distributed_persistables' need current_endpoint set in DistributeTranspiler.transpile" + ) + + need_load_vars = main_program._parameters_on_pservers.get_distributed_vars_by_ep( + main_program._ps_endpoint) + __load_persistable_vars(executor, dirname, need_load_vars) def prepend_feed_ops(inference_program, @@ -795,52 +1082,6 @@ def load_inference_model(dirname, return [program, feed_target_names, fetch_targets] -def _save_lookup_tables_by_notify(executor, dirname, lookup_table, - pserver_endpoints): - """ - This function will send checkpoint notify message from Trainer 0 - to all the pservers. - The checkpoint notify message contains lookup table name, - the absolute path on pserver to save lookup_table. - - Args: - executor(Executor): The executor to run for send checkpoint notify. - dirname(str): The folder where to save. - lookup_table(string): the lookup table name, when use distribute - lookup table, we can get lookup table name by DistributeTranspiler. 
- table_name - ps_endpoint_list(list): the parameter server ip:port list. - when use distribute lookup table, we can get ps_endpoint_list by - distribute arguments. - Return: - None - - Examples: - .. code-block:: python - - exe = fluid.Executor(fluid.CPUPlace()) - param_path = "./my_paddle_model" - table_name = "share_w" - ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"] - - _save_pserver_vars_by_notify(executor=exe, - dirname=param_path, lookup_table=table_name, - pserver_endpoints=ps_endpoints) - """ - - pserver_notify_program = Program() - pserver_notify_block = pserver_notify_program.global_block() - - attrs = {} - attrs['epmap'] = pserver_endpoints - attrs['dir'] = dirname - attrs['lookup_table'] = lookup_table - - pserver_notify_block.append_op( - type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs) - executor.run(pserver_notify_program) - - def _endpoints_replacement(program, endpoints): ENDPOINT_MAP = "epmap" for op in program.global_block().ops: @@ -911,54 +1152,3 @@ def get_parameter_value_by_name(name, executor, program=None): program = default_main_program() var = program.global_block().var(name) return get_parameter_value(var, executor) - - -def _load_slice_up_vars(executor, dirname, slice_vars_and_attrs): - if not slice_vars_and_attrs: - return - - load_prog = Program() - load_block = load_prog.global_block() - need_delete_vars = [] - - for var_tuple in slice_vars_and_attrs: - orig_var = var_tuple[0] - start = var_tuple[1] - slice_var = var_tuple[2] - end = start + slice_var.shape[0] - - orig_var_name = orig_var.name - orig_var.name = "{}.origin".format(orig_var_name) - - clone_orig_var = load_block.create_var( - name=orig_var.name, - type=orig_var.type, - shape=orig_var.shape, - dtype=orig_var.dtype, - persistable=True) - - clone_slice_var = load_block.create_var( - name=slice_var.name, - type=slice_var.type, - shape=slice_var.shape, - dtype=slice_var.dtype, - persistable=True) - - load_block.append_op( - type='load', - inputs={}, - outputs={'Out': [clone_orig_var]}, - attrs={'file_path': os.path.join(dirname, orig_var_name)}) - load_block.append_op( - type="slice", - inputs={'Input': clone_orig_var}, - outputs={'Out': clone_slice_var}, - attrs={'axes': [0], - 'starts': [start], - 'ends': [end]}) - need_delete_vars.append(clone_orig_var) - - load_block.append_op( - type='delete_var', - inputs={'X': need_delete_vars}, ) - executor.run(load_prog) diff --git a/python/paddle/fluid/tests/unittests/dist_save_load.py b/python/paddle/fluid/tests/unittests/dist_save_load.py index faec5350424..f0f13a9d49c 100644 --- a/python/paddle/fluid/tests/unittests/dist_save_load.py +++ b/python/paddle/fluid/tests/unittests/dist_save_load.py @@ -80,7 +80,8 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2): # NOTE: pserver should not call memory optimize t = self.get_transpiler(args.trainer_id, fluid.default_main_program(), args.endpoints, - args.trainers, args.sync_mode) + args.trainers, args.sync_mode, False, + args.current_endpoint) pserver_prog = t.get_pserver_program(args.current_endpoint) startup_prog = t.get_startup_program(args.current_endpoint, pserver_prog) @@ -93,7 +94,8 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2): exe.run(startup_prog) if need_load and model_dir: - self._load_persistable_vars(exe, model_dir, startup_prog) + fluid.io.load_persistables(exe, model_dir, pserver_prog) + exe.run(pserver_prog) def run_trainer(self, args): @@ -158,19 +160,46 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2): need_save = bool(int(os.getenv("SAVE", "0"))) model_dir = 
os.getenv("MODEL_DIR", "") - - if need_save: - for _ in six.moves.xrange(RUN_STEP): - loss, = exe.run(fetch_list=[avg_cost.name], - feed=feeder.feed(get_data())) - if need_save and model_dir: - io.save_persistables(startup_exe, model_dir, trainer_prog) - - var = np.array(fluid.global_scope().find_var('__fc_b__').get_tensor()) - if six.PY2: - print(pickle.dumps(np.ravel(var).tolist())) + save_mode = os.getenv("SAVE_MODE", "") + + if save_mode == "LOCAL": + if need_save: + for _ in six.moves.xrange(RUN_STEP): + loss, = exe.run(fetch_list=[avg_cost.name], + feed=feeder.feed(get_data())) + if need_save and model_dir: + io.save_persistables(startup_exe, model_dir, trainer_prog) + + var = np.array(fluid.global_scope().find_var('__fc_b__').get_tensor( + )) + if six.PY2: + print(pickle.dumps(np.ravel(var).tolist())) + else: + sys.stdout.buffer.write(pickle.dumps(np.ravel(var).tolist())) + + elif save_mode == "DIST": + skip_steps = int(os.getenv("SKIP_STEPS")) + loss = None + if need_save: + for idx in six.moves.xrange(8): + loss, = exe.run(fetch_list=[avg_cost.name], + feed=feeder.feed(get_data())) + if need_save and model_dir and idx == skip_steps and args.trainer_id == 0: + io.save_persistables(startup_exe, model_dir, + trainer_prog) + else: + for idx in six.moves.xrange(8): + data = get_data() + if idx <= skip_steps: + continue + loss, = exe.run(fetch_list=[avg_cost.name], + feed=feeder.feed(data)) + if six.PY2: + print(pickle.dumps(loss.tolist())) + else: + sys.stdout.buffer.write(pickle.dumps(loss.tolist())) else: - sys.stdout.buffer.write(pickle.dumps(np.ravel(var).tolist())) + raise Exception("save_mode must be LOCAL or DIST") if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/dist_simnet_bow.py b/python/paddle/fluid/tests/unittests/dist_simnet_bow.py index fac5e037a46..09afae6114e 100644 --- a/python/paddle/fluid/tests/unittests/dist_simnet_bow.py +++ b/python/paddle/fluid/tests/unittests/dist_simnet_bow.py @@ -75,9 +75,13 @@ def get_loss(cos_q_pt, cos_q_nt): return avg_cost -def get_optimizer(): - # SGD optimizer - optimizer = fluid.optimizer.SGD(learning_rate=base_lr) +def get_optimizer(op="sgd"): + if op.upper() == "sgd".upper(): + optimizer = fluid.optimizer.SGD(learning_rate=base_lr) + elif op.upper() == "adam".upper(): + optimizer = fluid.optimizer.Adam(learning_rate=base_lr) + else: + optimizer = fluid.optimizer.SGD(learning_rate=base_lr) return optimizer @@ -237,7 +241,8 @@ class TestDistSimnetBow2x2(TestDistRunnerBase): inference_program = fluid.default_main_program().clone() # Optimization - opt = get_optimizer() + opt = os.getenv('OPTIMIZER', 'sgd') + opt = get_optimizer(opt) opt.minimize(avg_cost) # Reader diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index 69a38618cde..e51ae1a944e 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -43,7 +43,8 @@ class TestDistRunnerBase(object): pserver_endpoints, trainers, sync_mode, - dc_asgd=False): + dc_asgd=False, + current_endpoint=None): # NOTE: import fluid until runtime, or else forking processes will cause error. 
config = fluid.DistributeTranspilerConfig() config.enable_dc_asgd = dc_asgd @@ -53,7 +54,8 @@ class TestDistRunnerBase(object): program=main_program, pservers=pserver_endpoints, trainers=trainers, - sync_mode=sync_mode) + sync_mode=sync_mode, + current_endpoint=current_endpoint) return t def run_pserver(self, args): diff --git a/python/paddle/fluid/tests/unittests/test_dist_save_load.py b/python/paddle/fluid/tests/unittests/test_dist_save_load.py index 4588ca7c17b..e795bc410ee 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_dist_save_load.py @@ -33,7 +33,6 @@ class TestDistSaveLoadDense2x2(TestDistBase): delta=1e-3, check_error_log=False, need_envs={}): - required_envs = { "PATH": os.getenv("PATH", ""), "PYTHONPATH": os.getenv("PYTHONPATH", ""), @@ -77,7 +76,77 @@ class TestDistSaveLoadDense2x2(TestDistBase): need_envs = { "IS_DISTRIBUTED": '0', "IS_SPARSE": '0', - 'IS_SELF_CONTAINED_LR': '1' + 'IS_SELF_CONTAINED_LR': '1', + 'SAVE_MODE': 'LOCAL', + } + self.check_with_place( + "dist_save_load.py", + delta=0, + check_error_log=False, + need_envs=need_envs) + + +class TestDistSaveLoadWithPServerStateDense2x2(TestDistBase): + def _setup_config(self): + self._sync_mode = True + self._enforce_place = "CPU" + + def check_with_place(self, + model_file, + delta=1e-3, + check_error_log=False, + need_envs={}): + required_envs = { + "PATH": os.getenv("PATH", ""), + "PYTHONPATH": os.getenv("PYTHONPATH", ""), + "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""), + "http_proxy": "" + } + + required_envs.update(need_envs) + + if check_error_log: + required_envs["GLOG_v"] = "3" + required_envs["GLOG_logtostderr"] = "1" + + model_dir = tempfile.mkdtemp() + + save_env = {} + save_env["SAVE_MODE"] = "DIST" + save_env["SAVE"] = "1" + save_env["MODEL_DIR"] = model_dir + save_env.update(required_envs) + + tr0_var_1, tr1_var_1 = self._run_cluster(model_file, save_env, + check_error_log) + + load_env = {} + load_env["LOAD"] = "1" + load_env["MODEL_DIR"] = model_dir + load_env.update(required_envs) + tr0_var_2, tr1_var_2 = self._run_cluster(model_file, load_env, + check_error_log) + + shutil.rmtree(model_dir) + + train0_1_np = np.array(tr0_var_1) + train1_1_np = np.array(tr1_var_1) + train0_2_np = np.array(tr0_var_2) + train1_2_np = np.array(tr1_var_2) + + self.assertAlmostEqual( + train0_1_np.all(), train0_2_np.all(), delta=delta) + self.assertAlmostEqual( + train1_1_np.all(), train1_2_np.all(), delta=delta) + + def test_dist(self): + need_envs = { + "IS_DISTRIBUTED": '0', + "IS_SPARSE": '0', + 'IS_SELF_CONTAINED_LR': '1', + 'SAVE_MODE': 'DIST', + 'OPTIMIZER': 'ADAM', + 'SKIP_STEPS': str(np.random.randint(2, 6)) } self.check_with_place( "dist_save_load.py", diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py index 3d1ce6b27c9..3566fed2152 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py +++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py @@ -741,21 +741,40 @@ class TestLoadSliceVar(TranspilerTest): pserver, _ = self.get_pserver(self.pserver1_ep) pserver2, _ = self.get_pserver(self.pserver2_ep) - self.assertTrue(pserver._slice_vars_and_attrs) - self.assertTrue(pserver2._slice_vars_and_attrs) - - for idx in six.moves.xrange(len(pserver._slice_vars_and_attrs)): - self.assertEqual(pserver._slice_vars_and_attrs[idx][0], - pserver2._slice_vars_and_attrs[idx][0]) - - total_numel = six.moves.reduce( - lambda x, y: x * y, 
pserver._slice_vars_and_attrs[idx][0].shape) - self.assertEqual( - total_numel, - six.moves.reduce(lambda x, y: x * y, - pserver._slice_vars_and_attrs[idx][2].shape) + - six.moves.reduce(lambda x, y: x * y, - pserver2._slice_vars_and_attrs[idx][2].shape)) + vars_ps1 = pserver._parameters_on_pservers.get_distributed_vars_by_ep( + self.pserver1_ep) + vars_ps2 = pserver._parameters_on_pservers.get_distributed_vars_by_ep( + self.pserver2_ep) + + self.assertTrue(vars_ps1) + self.assertTrue(vars_ps2) + + for idx in six.moves.xrange(len(vars_ps1)): + total_numel = 0 + ps1_numel, ps2_numel = 0, 0 + + ps1_var = vars_ps1[idx] + + if not ps1_var.is_slice: + total_numel = six.moves.reduce(lambda x, y: x * y, + vars_ps1[idx].origin.shape) + ps1_numel = six.moves.reduce(lambda x, y: x * y, + vars_ps1[idx].slice.shape) + else: + ps2_var = None + for var in vars_ps2: + if var.origin.name == ps1_var.origin.name: + ps2_var = var + break + + total_numel = six.moves.reduce(lambda x, y: x * y, + ps1_var.origin.shape) + ps1_numel = six.moves.reduce(lambda x, y: x * y, + ps1_var.slice.shape) + ps2_numel = six.moves.reduce(lambda x, y: x * y, + ps2_var.slice.shape) + + self.assertEqual(total_numel, ps1_numel + ps2_numel) class TestNCCL2Transpile(TranspilerTest): diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index ea5a4cf7cdb..c61cb54e1f2 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -39,7 +39,7 @@ from .ps_dispatcher import RoundRobin, PSDispatcher from .. import core, framework, unique_name from ..framework import Program, default_main_program, \ default_startup_program, Block, \ - Parameter, grad_var_name + Parameter, Variable, grad_var_name from .details import * from ..distribute_lookup_table import find_distributed_lookup_table from functools import reduce @@ -62,6 +62,260 @@ def log(*args): print(args) +class VarStruct(object): + """ + record part properties of a Variable in python. + """ + + def __init__(self, name, shape, dtype, type, lod_level, persistable): + self.name = name + self.shape = shape + self.dtype = dtype + self.type = type + self.lod_level = lod_level + self.persistable = persistable + + +class VarDistributed(object): + """ + a class to record the var distributed on parameter servers. + the class will record the relationship between origin var and slice var. + the slice var's properties, such as type/shape/offset/endpoint. + """ + + def __init__(self, + origin_var, + slice_var, + is_slice=None, + block_id=None, + offset=None, + vtype=None, + endpoint=None): + """ + Args: + origin_var(Variable|VarStruct): origin var properties + slice_var(Variable|VarStruct): slice var properties + is_slice(bool|None): slice or not, slice_var=True/False and its block size > 8192 are the judgement standard. + block_id(int|None): the number about the slice var. + offset(int|None): if the slice var is sliced, offset is the numel before the var. + vtype(str|None): a tag, such as Optimizer/Param/RemoteProfetch. 
+ endpoint(str|None): which parameter the slice var on, such as "127.0.0.1:1001" + """ + + if isinstance(origin_var, Variable): + self.origin = self.__create_var_struct(origin_var) + else: + self.origin = origin_var + + if isinstance(slice_var, Variable): + self.slice = self.__create_var_struct(slice_var) + else: + self.slice = slice_var + + if self.equal(self.origin, self.slice): + self.is_slice = False + self.block_id = 0 + self.offset = 0 + else: + self.is_slice = True + self.block_id = 0 + self.offset = 0 + + if is_slice is not None: + self.is_slice = is_slice + if block_id is not None: + self.block_id = block_id + if offset is not None: + self.offset = offset + + self.vtype = vtype + self.endpoint = endpoint + + @staticmethod + def __create_var_struct(var): + return VarStruct(var.name, var.shape, var.dtype, var.type, + var.lod_level, var.persistable) + + @staticmethod + def equal(var1, var2): + """ + the two var is equal or not. + Returns: + bool: equal will return True else False + """ + assert isinstance(var1, VarStruct) and isinstance(var2, VarStruct) + + return var1.name == var2.name and \ + var1.type == var2.type and \ + var1.shape == var2.shape and \ + var1.dtype == var2.dtype and \ + var1.lod_level == var2.lod_level and \ + var1.persistable == var2.persistable + + def __str__(self): + origin_var_str = "{name} : fluid.{type}.shape{shape}.astype({dtype})". \ + format(i="{", e="}", name=self.origin.name, type=self.origin.type, + shape=self.origin.shape, dtype=self.origin.dtype) + + slice_var_str = "{name} : fluid.{type}.shape{shape}.astype({dtype})" \ + ".slice({is_slice}).block({block_id}).offset({offset})". \ + format(i="{", e="}", name=self.slice.name, type=self.slice.type, + shape=self.slice.shape, dtype=self.slice.dtype, + is_slice=self.is_slice, block_id=self.block_id, offset=self.offset) + + return "var owned: {}, origin var: ( {} ), slice var: ( {} ), endpoint: {} ".format( + self.vtype, origin_var_str, slice_var_str, self.endpoint) + + +class VarsDistributed(object): + """ + a gather about VarDistributed with many methods to find distributed vars. + through the class, we can get overview about the distributed parameters on parameter servers. + this class may centralized and convenient for developer to manage and get variable's distribute. + other module can also use this to find variables such io.py. + """ + + def __init__(self): + self.distributed_vars = [] + + def add_distributed_var(self, + origin_var, + slice_var, + is_slice=None, + block_id=None, + offset=None, + vtype=None, + endpoint=None): + """ + add distributed var in this. + + Args: + origin_var(Variable|VarStruct): origin var properties + slice_var(Variable|VarStruct): slice var properties + is_slice(bool|None): slice or not, slice_var=True/False and its block size > 8192 are the judgement standard. + block_id(int|None): the number about the slice var. + offset(int|None): if the slice var is sliced, offset is the numel before the var. + vtype(str|None): a tag, such as Optimizer/Param/RemoteProfetch. + endpoint(str|None): which parameter the slice var on, such as "127.0.0.1:1001" + Returns: + None + """ + self.distributed_vars.append( + VarDistributed(origin_var, slice_var, is_slice, block_id, offset, + vtype, endpoint)) + + def get_distributed_var_by_slice(self, var_name): + """ + get distributed var by conditions. + + Args: + var_name(str): slice var name, such as "w.traier0.block1" + Returns: + VarDistributed: distributed var. 
+ """ + for dist_var in self.distributed_vars: + if dist_var.slice.name == var_name: + return dist_var + return None + + @staticmethod + def equal(var1, var2): + """ + the two var is equal or not. + Returns: + bool: equal will return True else False + """ + return var1.name == var2.name and \ + var1.type == var2.type and \ + var1.shape == var2.shape and \ + var1.dtype == var2.dtype and \ + var1.lod_level == var2.lod_level and \ + var1.persistable == var2.persistable + + def get_distributed_var_by_origin_and_ep(self, origin_var_name, endpoint): + """ + get distributed var by conditions. + + Args: + origin_var_name(str): + endpoint(str): the parameter endpoint, such as "127.0.0.1:1001" + Returns: + VarDistributed: distributed var. + """ + for dist_var in self.distributed_vars: + if dist_var.origin.name == origin_var_name and dist_var.endpoint == endpoint: + return dist_var + return None + + def get_distributed_vars_by_vtypes(self, vtypes, groupby=False): + """ + get distributed vars by conditions. + + Args: + vtype(str|None): distributed var's vtype, such as "Optimizer", "RemotePrefetch" + groupby(bool|False): group by origin var or not. + + Returns: + list: distributed var list. + dict: distributed var map when groupby=True + """ + vtype_vars = [] + for var in self.distributed_vars: + if var.vtype in vtypes: + vtype_vars.append(var) + if not groupby: + return vtype_vars + + params_map = {} + for var in vtype_vars: + origin_var_name = var.origin.name + + if origin_var_name in params_map.keys(): + optimizers = params_map.get(origin_var_name) + else: + optimizers = [] + optimizers.append(var) + params_map[origin_var_name] = optimizers + return params_map + + def get_distributed_vars_by_ep(self, endpoint, vtype=None): + """ + get distributed vars by conditions. + + Args: + endpoint(str): the parameter server endpoint, such as "127.0.0.1:2001" + vtype(str|None): distributed var's vtype, such as "Optimizer", "RemotePrefetch" + + Returns: + list: distributed var list. + """ + endpoint_vars = [] + for var in self.distributed_vars: + if var.endpoint == endpoint: + endpoint_vars.append(var) + if not vtype: + return endpoint_vars + + vtype_vars = [] + for var in endpoint_vars: + if var.vtype == vtype: + vtype_vars.append(var) + return vtype_vars + + def overview(self): + """ + get the overview string about all params on all parameter servers. + + Returns: + Str: overview string. 
+ + """ + vars_str = [] + for var in self.distributed_vars: + vars_str.append(str(var)) + return "\n".join(vars_str) + + class VarBlock: def __init__(self, varname, offset, size): self.varname = varname @@ -223,16 +477,13 @@ class DistributeTranspiler(object): trainer_id, trainers, current_endpoint, - startup_program=None, - wait_port=True): + startup_program=None): if not startup_program: startup_program = default_startup_program() if trainer_id >= 0: worker_endpoints = trainers.split(",") # send NCCL_ID to others or recv from trainer 0 worker_endpoints.remove(current_endpoint) - if trainer_id == 0 and wait_port: - wait_server_ready(worker_endpoints) nccl_id_var = startup_program.global_block().create_var( name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW) @@ -313,13 +564,11 @@ class DistributeTranspiler(object): if self.config.mode == "nccl2": assert (isinstance(trainers, str)) - self.origin_program._trainers_endpoints = trainers.split(",") self._transpile_nccl2( trainer_id, trainers, current_endpoint, - startup_program=startup_program, - wait_port=self.config.wait_port) + startup_program=startup_program) return self.trainer_num = trainers @@ -327,6 +576,7 @@ class DistributeTranspiler(object): self.trainer_id = trainer_id pserver_endpoints = pservers.split(",") self.pserver_endpoints = pserver_endpoints + self.vars_overview = VarsDistributed() self.optimize_ops, self.params_grads = self._get_optimize_pass() ps_dispatcher = self.config.split_method(self.pserver_endpoints) @@ -347,6 +597,7 @@ class DistributeTranspiler(object): # add distributed attrs to program self.origin_program._is_distributed = True self.origin_program._endpoints = self.pserver_endpoints + self.origin_program._ps_endpoint = current_endpoint self.origin_program._is_chief = self.trainer_id == 0 self.origin_program._distributed_lookup_table = self.table_name if self.table_name else None @@ -454,6 +705,10 @@ class DistributeTranspiler(object): self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i]) self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i]) + distributed_var = self.vars_overview.get_distributed_var_by_slice( + recv_vars[i].name) + distributed_var.endpoint = ep + # step4: Concat the parameters splits together after recv. all_recv_outputs = [] for param_varname, splited_var in six.iteritems(self.param_var_mapping): @@ -480,6 +735,12 @@ class DistributeTranspiler(object): recv_op_role_var_name = splited_trainer_grad[0].name if param_varname in self.sparse_param_to_height_sections: + + for table_name in table_names: + distributed_var = self.vars_overview.get_distributed_var_by_slice( + table_name) + distributed_var.vtype = "RemotePrefetch" + height_sections = self.sparse_param_to_height_sections[ param_varname] self._update_remote_sparse_update_op( @@ -532,6 +793,9 @@ class DistributeTranspiler(object): pserver_endpoints) self._split_table_grad_and_add_send_vars(program, pserver_endpoints) + self._get_distributed_optimizer_vars() + self.origin_program._parameters_on_pservers = self.vars_overview + def get_trainer_program(self, wait_port=True): """ Get transpiled trainer side program. @@ -541,6 +805,7 @@ class DistributeTranspiler(object): """ # remove optimize ops and add a send op to main_program # FIXME(typhoonzero): Also ops like clip_gradient, lrn_decay? 
+ lr_ops = self._get_lr_ops() delete_ops(self.origin_program.global_block(), self.optimize_ops) delete_ops(self.origin_program.global_block(), lr_ops) @@ -665,9 +930,14 @@ class DistributeTranspiler(object): # NOTE: assume blocks of the same variable is not distributed # on the same pserver, only change param/grad varnames for # trainers to fetch. + sys.stderr.write( + "get_pserver_program() is deprecated, call get_pserver_programs() to get pserver main and startup in a single call.\n" + ) # step1 pserver_program = Program() pserver_program.random_seed = self.origin_program.random_seed + pserver_program._copy_dist_param_info_from(self.origin_program) + # step2: Create vars to receive vars at parameter servers. recv_inputs = [] for v in self.param_grad_ep_mapping[endpoint]["params"]: @@ -703,9 +973,6 @@ class DistributeTranspiler(object): else: recv_inputs.append(single_trainer_var) - self._slice_params_and_optimizes = self._get_slice_vars_and_attrs( - endpoint) - # step 3 # Create a union-find data structure from optimize ops, # If two ops are connected, we could add these two ops @@ -882,10 +1149,6 @@ class DistributeTranspiler(object): outputs={}, attrs=attrs) - # add distributed attrs - pserver_program._slice_vars_and_attrs = list( - self._slice_params_and_optimizes.values()) - pserver_program._sync_with_cpp() # save pserver program to generate pserver side startup relatively. self.pserver_program = pserver_program @@ -984,30 +1247,88 @@ class DistributeTranspiler(object): inputs={"X": startup_param_var}, outputs={"Out": startup_tmpvar}) - # add slice vars - s_prog._slice_vars_and_attrs = pserver_program._slice_vars_and_attrs - return s_prog - def _get_slice_vars_and_attrs(self, endpoint): - slice_vars_and_attrs = {} + # ====================== private transpiler functions ===================== + def _get_slice_var_info(self, slice_var): block_suffix = "block" - for param in self.param_grad_ep_mapping[endpoint]["params"]: - orig_var_name, block_name, _ = self._get_varname_parts(param.name) - if not block_name: - continue + block_idx = 0 + offset = 0 + is_slice = False - block_idx = int(block_name.split(block_suffix)[1]) - orig_var = self.origin_program.global_block().vars[orig_var_name] + orig_var_name, block_name, _ = self._get_varname_parts(slice_var.name) - skip_dim0 = 0 - slice_vars = self.param_var_mapping[orig_var_name] - for slice_var in slice_vars[:block_idx]: - skip_dim0 += slice_var.shape[0] - slice_vars_and_attrs[param.name] = [orig_var, skip_dim0, param] - return slice_vars_and_attrs + if not block_name: + return is_slice, block_idx, offset - # ====================== private transpiler functions ===================== + block_idx = int(block_name.split(block_suffix)[1]) + skip_dim0 = 0 + slice_vars = self.param_var_mapping[orig_var_name] + + orig_dim1_flatten = reduce(lambda x, y: x * y, slice_vars[0].shape[1:]) + + for slice_var in slice_vars[:block_idx]: + skip_dim0 += slice_var.shape[0] + + offset = skip_dim0 * orig_dim1_flatten + is_slice = True + return is_slice, block_idx, offset + + def _get_distributed_optimizer_vars(self): + def _get_distributed_optimizer_var(endpoint): + opt_op_on_pserver = [] + for _, op in enumerate(self.optimize_ops): + if self._is_optimizer_op(op) and self._is_opt_op_on_pserver( + endpoint, op): + opt_op_on_pserver.append(op) + + for opt_op in opt_op_on_pserver: + dist_var = None + for key in opt_op.input_names: + if key == "Param": + param_name = opt_op.input(key)[0] + dist_var = self.vars_overview.get_distributed_var_by_origin_and_ep( + 
param_name, endpoint) + break + for key in opt_op.input_names: + if key in ["Param", "Grad", "LearningRate"]: + continue + origin_var = self.origin_program.global_block().vars[ + opt_op.input(key)[0]] + # update accumulator variable shape + new_shape = self._get_optimizer_input_shape( + opt_op.type, key, origin_var.shape, + dist_var.slice.shape) + + if new_shape == dist_var.slice.shape: + splited_var = VarStruct( + name=origin_var.name, + shape=new_shape, + dtype=origin_var.dtype, + type=origin_var.type, + lod_level=origin_var.lod_level, + persistable=origin_var.persistable) + + self.vars_overview.add_distributed_var( + origin_var=origin_var, + slice_var=splited_var, + is_slice=dist_var.is_slice, + block_id=dist_var.block_id, + offset=dist_var.offset, + vtype="Optimizer", + endpoint=endpoint) + else: + self.vars_overview.add_distributed_var( + origin_var=origin_var, + slice_var=origin_var, + is_slice=False, + block_id=0, + offset=0, + vtype="Optimizer", + endpoint=endpoint) + + for ep in self.pserver_endpoints: + _get_distributed_optimizer_var(ep) def _update_dist_lookup_table_vars(self, param_list, grad_list, params_grads): @@ -1093,6 +1414,22 @@ class DistributeTranspiler(object): # origin_param_name -> [splited_param_vars] self.param_var_mapping = self._create_vars_from_blocklist( self.origin_program, param_blocks) + + for orig_name, splited_vars in self.param_var_mapping.items(): + orig_var = self.origin_program.global_block().var(orig_name) + + for splited_var in splited_vars: + is_slice, block_id, offset = self._get_slice_var_info( + splited_var) + + self.vars_overview.add_distributed_var( + origin_var=orig_var, + slice_var=splited_var, + block_id=block_id, + offset=offset, + is_slice=is_slice, + vtype="Param") + # origin_grad_name -> [splited_grad_vars] self.grad_var_mapping = self._create_vars_from_blocklist( self.origin_program, @@ -1729,13 +2066,6 @@ class DistributeTranspiler(object): shape=new_shape) new_inputs[key] = tmpvar - # var shape been changed - if new_shape != var.shape: - slice_var_args = self._slice_params_and_optimizes[ - param_var.name] - self._slice_params_and_optimizes[ - var.name] = [var, slice_var_args[1], tmpvar] - # change output's ParamOut variable outputs = self._get_output_map_from_op( self.origin_program.global_block().vars, opt_op) @@ -1763,8 +2093,8 @@ class DistributeTranspiler(object): # skip per trainer vars if g.name.find(".trainer_") == -1: # only param or grads have splited blocks - if self._orig_varname(g.name) in self.grad_name_to_param_name or\ - self._orig_varname(g.name) in self.param_name_to_grad_name: + if self._orig_varname(g.name) in self.grad_name_to_param_name or \ + self._orig_varname(g.name) in self.param_name_to_grad_name: grad_block = g break return grad_block -- GitLab
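
The relationship that the new VarsDistributed overview records, and the way the updated transpiler test consumes it, can be illustrated with a small standalone sketch. The stand-in tuples, shapes and endpoints below are invented for illustration; they are not the fluid classes themselves, only the query pattern (get_distributed_vars_by_ep plus a numel consistency check) is mirrored.

# Hedged sketch: simplified stand-ins for VarStruct/VarDistributed and the
# per-endpoint query of VarsDistributed; all names, shapes and endpoints
# below are made up for the example.
from collections import namedtuple
from functools import reduce

Var = namedtuple("Var", ["name", "shape"])
Dist = namedtuple("Dist", ["origin", "slice", "is_slice", "block_id",
                           "offset", "vtype", "endpoint"])

overview = []  # plays the role of VarsDistributed.distributed_vars

# a (10, 5) parameter "w" split row-wise across two parameter servers
origin = Var("w", (10, 5))
overview.append(Dist(origin, Var("w.block0", (6, 5)), True, 0, 0,  "Param", "127.0.0.1:6174"))
overview.append(Dist(origin, Var("w.block1", (4, 5)), True, 1, 30, "Param", "127.0.0.1:6175"))

def vars_by_ep(ep):
    # mirrors get_distributed_vars_by_ep without the optional vtype filter
    return [d for d in overview if d.endpoint == ep]

def numel(shape):
    return reduce(lambda x, y: x * y, shape)

ps1_var = vars_by_ep("127.0.0.1:6174")[0]
ps2_var = vars_by_ep("127.0.0.1:6175")[0]

# the invariant the updated unit test asserts: the slices on all endpoints
# together account for every element of the origin variable
assert numel(origin.shape) == numel(ps1_var.slice.shape) + numel(ps2_var.slice.shape)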
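
The offset recorded for each slice by _get_slice_var_info is the number of elements stored ahead of it: the dim-0 size of all earlier blocks of the same origin variable, multiplied by the flattened size of the trailing dimensions. A minimal sketch of that arithmetic, assuming row-wise splitting and using invented shapes:

# Hedged sketch of the offset arithmetic used by _get_slice_var_info.
from functools import reduce

def slice_offset(slice_shapes, block_idx):
    # slice_shapes: shapes of the split blocks of one origin var, in order
    dim1_flatten = reduce(lambda x, y: x * y, slice_shapes[0][1:], 1)
    skip_dim0 = sum(shape[0] for shape in slice_shapes[:block_idx])
    return skip_dim0 * dim1_flatten

# a (10, 5) parameter split into a 6-row block and a 4-row block:
assert slice_offset([(6, 5), (4, 5)], 0) == 0
assert slice_offset([(6, 5), (4, 5)], 1) == 30  # 6 rows * 5 trailing elements

This offset is presumably what consumers such as io.py (mentioned in the VarsDistributed docstring) need in order to place a pserver-side slice back at its position inside the full origin parameter.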
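
_get_distributed_optimizer_vars then records each remaining optimizer input (accumulator state such as momentum or Adam moments) in the same overview. The branch it takes hinges on whether the pserver-side shape of that input follows the parameter slice. A hedged paraphrase of the decision, using plain dictionaries instead of VarStruct and a generic register callback standing in for add_distributed_var:

# Hedged paraphrase of the shape decision in _get_distributed_optimizer_vars:
# an optimizer input is recorded as sliced only when its pserver-side shape
# equals the parameter slice's shape; otherwise it is stored whole on the
# endpoint. `register` stands in for VarsDistributed.add_distributed_var.
def record_optimizer_input(register, origin_shape, param_slice, new_shape, endpoint):
    if new_shape == param_slice["slice_shape"]:
        # the accumulator is split like the parameter: reuse its block id / offset
        register(shape=new_shape,
                 is_slice=param_slice["is_slice"],
                 block_id=param_slice["block_id"],
                 offset=param_slice["offset"],
                 vtype="Optimizer",
                 endpoint=endpoint)
    else:
        # state that keeps its origin shape is stored whole on this endpoint
        register(shape=origin_shape,
                 is_slice=False,
                 block_id=0,
                 offset=0,
                 vtype="Optimizer",
                 endpoint=endpoint)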