From 7f2aa2db3c69cb9ebb8bae9e19280e75f964e1d0 Mon Sep 17 00:00:00 2001 From: Chengmo Date: Sun, 30 Aug 2020 17:44:39 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90paddle.fleet=E3=80=91Support=20Heter?= =?UTF-8?q?=20Parameter=20Server=20(#25998)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Support Heter Parameter Server --- .../operators/distributed/CMakeLists.txt | 2 +- .../operators/distributed/grpc/grpc_client.cc | 82 ++ .../operators/distributed/grpc/grpc_client.h | 32 + .../operators/distributed/grpc/grpc_serde.cc | 16 +- .../operators/distributed/grpc/grpc_serde.h | 5 + .../operators/distributed/grpc/grpc_server.cc | 48 + .../operators/distributed/grpc/grpc_service.h | 6 +- .../operators/distributed/request_handler.h | 2 + .../distributed/request_handler_impl.cc | 16 + .../distributed/request_handler_impl.h | 11 + .../fluid/operators/distributed/rpc_client.h | 6 + .../operators/distributed/rpc_server_test.cc | 109 ++- .../operators/distributed/send_recv.proto.in | 2 +- .../operators/distributed/variable_response.h | 7 + .../distributed_ops/listen_and_serv_op.cc | 9 +- .../distributed_ops/listen_and_serv_op.h | 2 + .../distributed_ops/send_and_recv_op.cc | 98 ++ .../distributed/fleet/base/fleet_base.py | 15 +- .../distributed/fleet/base/role_maker.py | 131 ++- .../fleet/meta_optimizers/__init__.py | 4 +- ...py => parameter_server_graph_optimizer.py} | 9 +- ...mizer.py => parameter_server_optimizer.py} | 28 +- .../fleet/runtime/parameter_server_runtime.py | 24 +- .../fluid/incubate/fleet/base/fleet_base.py | 2 +- .../fluid/incubate/fleet/base/role_maker.py | 1 - .../distribute_transpiler/__init__.py | 8 +- .../parameter_server/ir/heter_trainer_pass.py | 100 ++ .../fleet/parameter_server/ir/pserver_pass.py | 6 +- .../fleet/parameter_server/ir/public.py | 67 +- .../fleet/parameter_server/ir/trainer_pass.py | 880 +++++++++++++++++- .../tests/unittests/ctr_dataset_reader.py | 58 +- 
.../tests/unittests/dist_fleet_heter_ctr.py | 220 +++++ .../unittests/test_dist_fleet_heter_base.py | 388 ++++++++ .../unittests/test_dist_fleet_heter_ctr.py | 56 ++ .../test_dist_fleet_heter_program.py | 139 +++ 35 files changed, 2506 insertions(+), 83 deletions(-) create mode 100644 paddle/fluid/operators/distributed_ops/send_and_recv_op.cc rename python/paddle/distributed/fleet/meta_optimizers/{async_graph_execution_optimizer.py => parameter_server_graph_optimizer.py} (88%) rename python/paddle/distributed/fleet/meta_optimizers/{async_optimizer.py => parameter_server_optimizer.py} (82%) create mode 100644 python/paddle/fluid/incubate/fleet/parameter_server/ir/heter_trainer_pass.py create mode 100644 python/paddle/fluid/tests/unittests/dist_fleet_heter_ctr.py create mode 100644 python/paddle/fluid/tests/unittests/test_dist_fleet_heter_base.py create mode 100644 python/paddle/fluid/tests/unittests/test_dist_fleet_heter_ctr.py create mode 100644 python/paddle/fluid/tests/unittests/test_dist_fleet_heter_program.py diff --git a/paddle/fluid/operators/distributed/CMakeLists.txt b/paddle/fluid/operators/distributed/CMakeLists.txt index a033611f478..e584e025088 100644 --- a/paddle/fluid/operators/distributed/CMakeLists.txt +++ b/paddle/fluid/operators/distributed/CMakeLists.txt @@ -56,7 +56,7 @@ endif() cc_test(rpc_server_test SRCS rpc_server_test.cc - DEPS ${RPC_DEPS} executor scope proto_desc lookup_sparse_table_read_op) + DEPS ${RPC_DEPS} executor scope proto_desc lookup_sparse_table_read_op scale_op) cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler scope) cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory) cc_library(parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory) diff --git a/paddle/fluid/operators/distributed/grpc/grpc_client.cc b/paddle/fluid/operators/distributed/grpc/grpc_client.cc index edbe945cd72..0983b4a406e 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_client.cc +++ 
b/paddle/fluid/operators/distributed/grpc/grpc_client.cc @@ -132,6 +132,15 @@ void ProcGetResponse(const VarHandle& var_h, &trainer_id); } +void ProcGetRecvResponse(const VarHandle& var_h, + const ::grpc::ByteBuffer& ret_msg) { + VLOG(4) << "ProcGetRecvResponse"; + framework::Variable* outvar = nullptr; + int trainer_id; + DeserializeRecvFromByteBuffer(ret_msg, *var_h.ctx(), var_h.scope(), &outvar, + &trainer_id); +} + template void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) { ::grpc::Slice slice(proto.ByteSizeLong()); @@ -482,6 +491,79 @@ VarHandlePtr GRPCClient::AsyncDistributeNotify( return h; } +VarHandlePtr GRPCClient::AsyncSendAndRecv(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& send_var_name, + const std::string& recv_var_name, + const std::string& table_name, + int64_t time_out) { + const platform::DeviceContext* p_ctx = &ctx; + const std::string ep_val = ep; + const std::string send_var_name_val = send_var_name; + const std::string recv_var_name_val = recv_var_name; + const std::string table_name_val = table_name; + const framework::Scope* p_scope = &scope; + const auto ch = GetChannel(ep_val); + const std::string method = kSendAndRecvRPC; + VLOG(4) << "GRPCClient::SendAndRecv Begin ,Send_var_name: " + << send_var_name_val << " Recv_var_name: " << recv_var_name_val; + int retry_times_ = 0; + + while (true) { + SendAndRecvProcessor* s = new SendAndRecvProcessor(ch); + VarHandlePtr h( + new VarHandle(ep, method, send_var_name_val, p_ctx, p_scope)); + VarHandlePtr h_recv( + new VarHandle(ep, method, recv_var_name_val, p_ctx, p_scope)); + s->Prepare(h, time_out); + s->RecvPrepare(h_recv); + + framework::AsyncIO([send_var_name_val, recv_var_name_val, table_name_val, + p_scope, p_ctx, s, method, h, this] { + auto* send_var = p_scope->FindVar(send_var_name_val); + send_var->GetMutable()->set_lod({}); + ::grpc::ByteBuffer buf; + VLOG(4) << "SerializeToByteBuffer: 
send_var_name_val: " + << send_var_name_val + << " recv_var_name_val: " << recv_var_name_val; + SerializeToByteBuffer(send_var_name_val, send_var, *p_ctx, &buf, + recv_var_name_val, trainer_id_, table_name_val); + + VLOG(3) << s->GetVarHandlePtr()->String() << " begin"; + + // stub context + s->response_call_back_ = ProcGetRecvResponse; + + platform::RecordRPCEvent record_event(method); + + auto call = s->stub_g_.PrepareUnaryCall( + s->context_.get(), "/sendrecv.SendRecvService/SendAndRecvVariable", + buf, &cq_); + call->StartCall(); + call->Finish(&s->reply_, &s->status_, reinterpret_cast(s)); + + if (UNLIKELY(platform::IsProfileEnabled())) { + h->Wait(); + } + }); + req_count_++; + + if (FLAGS_rpc_retry_times > 0 && retry_times_ < FLAGS_rpc_retry_times) { + h->Wait(); + if (h->should_retry) { + VLOG(3) << "rpc call failed, retry times " << retry_times_; + retry_times_++; + std::random_device rd; + std::this_thread::sleep_for(std::chrono::milliseconds(rd() % 5)); + continue; + } + } + + return h; + } +} + bool GRPCClient::Wait() { std::unique_lock lk(sync_mutex_); sync_cond_.wait(lk, [this] { return (req_count_ == 0 || ok_ == false); }); diff --git a/paddle/fluid/operators/distributed/grpc/grpc_client.h b/paddle/fluid/operators/distributed/grpc/grpc_client.h index bd9f25567dc..6b6249540c6 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_client.h +++ b/paddle/fluid/operators/distributed/grpc/grpc_client.h @@ -53,6 +53,8 @@ namespace distributed { void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg); +void ProcGetRecvResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg); + class BaseProcessor { public: BaseProcessor() { context_ = nullptr; } @@ -131,6 +133,28 @@ class GetProcessor : public BaseProcessor { RequestGetCallBack response_call_back_ = ProcGetResponse; }; +class SendAndRecvProcessor : public BaseProcessor { + public: + explicit SendAndRecvProcessor(std::shared_ptr ch) + : BaseProcessor(), stub_g_(ch) {} + + virtual 
~SendAndRecvProcessor() {} + + void ProcessImpl() override { + if (response_call_back_) { + response_call_back_(*var_h_recv_.get(), reply_); + var_h_recv_->Finish(true); + } + } + + void RecvPrepare(VarHandlePtr h_recv) { var_h_recv_ = h_recv; } + + ::grpc::ByteBuffer reply_; + ::grpc::GenericStub stub_g_; + RequestGetCallBack response_call_back_ = ProcGetResponse; + VarHandlePtr var_h_recv_; +}; + class BatchBarrierProcessor : public BaseProcessor { public: explicit BatchBarrierProcessor(std::shared_ptr ch) @@ -231,6 +255,14 @@ class GRPCClient : public RPCClient { const framework::Scope& scope, const std::string& var_name, int64_t time_out = FLAGS_rpc_deadline) override; + VarHandlePtr AsyncSendAndRecv(const std::string& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& send_var_name, + const std::string& recv_var_name, + const std::string& table_name = "", + int64_t time_out = FLAGS_rpc_deadline) override; + VarHandlePtr AsyncSendComplete( const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override; diff --git a/paddle/fluid/operators/distributed/grpc/grpc_serde.cc b/paddle/fluid/operators/distributed/grpc/grpc_serde.cc index bb9719eaad0..eddd89cf20c 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_serde.cc +++ b/paddle/fluid/operators/distributed/grpc/grpc_serde.cc @@ -76,7 +76,6 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, PADDLE_THROW("Serialize does not support type: %s", typeid(var->Type()).name()); } - std::string header; request.AppendToString(&header); auto buffer = std::unique_ptr(new char[1024]); @@ -101,7 +100,6 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, } #endif PADDLE_ENFORCE_NOT_NULL(payload); - e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload->memory_size()); if (payload->memory_size() >= std::numeric_limits::max()) { @@ -140,7 +138,6 @@ void SerializeToByteBuffer(const std::string& name, 
framework::Variable* var, ::grpc::Slice::STEAL_REF); num_slices = 4; } - ::grpc::ByteBuffer tmp(&slices[0], num_slices); msg->Swap(&tmp); } @@ -156,6 +153,19 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, *trainer_id = resp.GetTrainerId(); } +void DeserializeRecvFromByteBuffer(const ::grpc::ByteBuffer& msg, + const platform::DeviceContext& ctx, + const framework::Scope* scope, + framework::Variable** var, int* trainer_id) { + platform::RecordRPCEvent record_event("deserial"); + operators::distributed::GRPCVariableResponse resp(scope, &ctx); + PADDLE_ENFORCE_EQ( + resp.Parse(msg), 0, + platform::errors::InvalidArgument("parse bytebuffer to tensor error!")); + *var = resp.GetRecvVar(); + *trainer_id = resp.GetTrainerId(); +} + } // namespace distributed } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/distributed/grpc/grpc_serde.h b/paddle/fluid/operators/distributed/grpc/grpc_serde.h index c9a57beb3a6..30e6907656e 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_serde.h +++ b/paddle/fluid/operators/distributed/grpc/grpc_serde.h @@ -47,6 +47,11 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, const framework::Scope* scope, framework::Variable** var, int* trainer_id); +void DeserializeRecvFromByteBuffer(const ::grpc::ByteBuffer& msg, + const platform::DeviceContext& ctx, + const framework::Scope* scope, + framework::Variable** var, int* trainer_id); + } // namespace distributed } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/distributed/grpc/grpc_server.cc b/paddle/fluid/operators/distributed/grpc/grpc_server.cc index e7effcc1805..5c0232a50a9 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_server.cc +++ b/paddle/fluid/operators/distributed/grpc/grpc_server.cc @@ -28,6 +28,7 @@ DECLARE_int32(rpc_retry_bind_port); namespace paddle { namespace operators { namespace distributed { + enum CallStatus { PROCESS = 0, FINISH }; // reference: @@ -433,6 
+434,51 @@ class RequestNotify final : public RequestBase { ServerAsyncResponseWriter responder_; }; +class RequestSendAndRecv final : public RequestBase { + public: + explicit RequestSendAndRecv(GrpcService::AsyncService* service, + ::grpc::ServerCompletionQueue* cq, + RequestHandler* request_handler, int req_id) + : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) { + request_.reset(new GRPCVariableResponse( + request_handler->scope(), request_handler->dev_ctx(), + request_handler->distributed_mode())); + + int method_id = + static_cast(distributed::GrpcMethod::kRequestSendAndRecv); + + service_->RequestAsyncUnary( + method_id, &ctx_, request_.get(), &responder_, cq_, cq_, + reinterpret_cast(static_cast(req_id))); + } + + virtual ~RequestSendAndRecv() {} + std::string GetReqName() override { return request_->Varname(); } + + void Process() override { + std::string in_var_name = request_->Varname(); + std::string out_var_name = request_->OutVarname(); + std::string table_name = request_->TableName(); + int trainer_id = request_->GetTrainerId(); + + VLOG(4) << "RequestSendAndRecv, in_var_name: " << in_var_name + << " out_var_name: " << out_var_name << " trainer: " << trainer_id; + auto scope = request_->GetMutableLocalScope(); + auto invar = scope->FindVar(in_var_name); + framework::Variable* outvar = nullptr; + request_handler_->Handle(in_var_name, scope, invar, &outvar, trainer_id, + out_var_name, table_name); + SerializeToByteBuffer(out_var_name, outvar, *request_handler_->dev_ctx(), + &reply_); + Finish(reply_, &responder_); + } + + protected: + std::shared_ptr request_; + ::grpc::ByteBuffer reply_; + ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_; +}; + void AsyncGRPCServer::WaitServerReady() { VLOG(4) << "AsyncGRPCServer is waiting server ready"; std::unique_lock lock(this->mutex_ready_); @@ -586,6 +632,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name, b = new RequestCheckpointNotify(service_.get(), 
cq.get(), handler, req_id); } else if (rpc_name == kRequestNotify) { b = new RequestNotify(service_.get(), cq.get(), handler, req_id); + } else if (rpc_name == kRequestSendAndRecv) { + b = new RequestSendAndRecv(service_.get(), cq.get(), handler, req_id); } else { PADDLE_ENFORCE(false, "not supported rpc"); } diff --git a/paddle/fluid/operators/distributed/grpc/grpc_service.h b/paddle/fluid/operators/distributed/grpc/grpc_service.h index 45152293896..95b6810ec61 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_service.h +++ b/paddle/fluid/operators/distributed/grpc/grpc_service.h @@ -85,10 +85,12 @@ enum class GrpcMethod { kGetMonomerVariable, kGetMonomerBarrier, kRequestNotify, + kRequestSendAndRecv, + // when you add new handler, change kGrpcNumMethods at the same time! }; static const int kGrpcNumMethods = - static_cast(GrpcMethod::kRequestNotify) + 1; + static_cast(GrpcMethod::kRequestSendAndRecv) + 1; inline const char* GrpcMethodName(GrpcMethod id) { switch (id) { @@ -108,6 +110,8 @@ inline const char* GrpcMethodName(GrpcMethod id) { return "/sendrecv.SendRecvService/CheckpointNotify"; case GrpcMethod::kRequestNotify: return "/sendrecv.SendRecvService/DistributeNotify"; + case GrpcMethod::kRequestSendAndRecv: + return "/sendrecv.SendRecvService/SendAndRecvVariable"; } // Shouldn't be reached. 
diff --git a/paddle/fluid/operators/distributed/request_handler.h b/paddle/fluid/operators/distributed/request_handler.h index 59531c0ec78..44359af1b1b 100644 --- a/paddle/fluid/operators/distributed/request_handler.h +++ b/paddle/fluid/operators/distributed/request_handler.h @@ -46,6 +46,7 @@ constexpr char kRequestCheckpoint[] = "RequestCheckpoint"; constexpr char kRequestPassBarrier[] = "RequestPassBarrier"; constexpr char kRequestGetNoBarrier[] = "GetVariableNoBarrier"; constexpr char kRequestNotify[] = "RequestNotify"; +constexpr char kRequestSendAndRecv[] = "RequestSendAndRecv"; constexpr char kSendRPC[] = "SendRPC"; constexpr char kGetRPC[] = "GetRPC"; @@ -57,6 +58,7 @@ constexpr char kFetchBarrierRPC[] = "FetchBarrierRPC"; constexpr char kSendMonomerFetchBarrierRPC[] = "SendMonomerFetchBarrierRPC"; constexpr char kSendCompleteRPC[] = "SendCompleteRPC"; constexpr char kCheckPointNotifyRPC[] = "CheckPointNotifyRPC"; +constexpr char kSendAndRecvRPC[] = "SendAndRecvRPC"; constexpr int64_t kPrefetchTimeout = 60000; #define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV" diff --git a/paddle/fluid/operators/distributed/request_handler_impl.cc b/paddle/fluid/operators/distributed/request_handler_impl.cc index e99b0ed4072..761a4edc523 100644 --- a/paddle/fluid/operators/distributed/request_handler_impl.cc +++ b/paddle/fluid/operators/distributed/request_handler_impl.cc @@ -325,6 +325,22 @@ bool RequestNotifyHandler::Handle(const std::string &varname, return true; } +bool RequestSendAndRecvHandler::Handle(const std::string &varname, + framework::Scope *Scope, + framework::Variable *var, + framework::Variable **outvar, + const int trainer_id, + const std::string &out_var_name, + const std::string &table_name) { + VLOG(3) << "SendAndRecvHandle: " << varname + << " out_var_name: " << out_var_name + << " , trainer_id: " << trainer_id; + + executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(), Scope); + *outvar = Scope->FindVar(out_var_name); + return true; +} + 
} // namespace distributed } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/distributed/request_handler_impl.h b/paddle/fluid/operators/distributed/request_handler_impl.h index f22a133c2d5..42621724e68 100644 --- a/paddle/fluid/operators/distributed/request_handler_impl.h +++ b/paddle/fluid/operators/distributed/request_handler_impl.h @@ -176,6 +176,17 @@ class RequestNotifyHandler final : public RequestHandler { std::unordered_map decay_counters; }; +class RequestSendAndRecvHandler final : public RequestHandler { + public: + explicit RequestSendAndRecvHandler(int distributed_mode) + : RequestHandler(distributed_mode) {} + virtual ~RequestSendAndRecvHandler() {} + bool Handle(const std::string& varname, framework::Scope* Scope, + framework::Variable* var, framework::Variable** outvar, + const int trainer_id, const std::string& out_var_name = "", + const std::string& table_name = "") override; +}; + } // namespace distributed } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/distributed/rpc_client.h b/paddle/fluid/operators/distributed/rpc_client.h index 62313222775..69a5e327431 100644 --- a/paddle/fluid/operators/distributed/rpc_client.h +++ b/paddle/fluid/operators/distributed/rpc_client.h @@ -85,6 +85,12 @@ class RPCClient { const framework::Scope& scope, const std::string& var_name, int64_t time_out = FLAGS_rpc_deadline) = 0; + virtual VarHandlePtr AsyncSendAndRecv( + const std::string& ep, const platform::DeviceContext& ctx, + const framework::Scope& scope, const std::string& send_var_name, + const std::string& recv_var_name, const std::string& table_name = "", + int64_t time_out = FLAGS_rpc_deadline) = 0; + virtual VarHandlePtr AsyncSendComplete( const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) = 0; diff --git a/paddle/fluid/operators/distributed/rpc_server_test.cc b/paddle/fluid/operators/distributed/rpc_server_test.cc index 67e11120b80..5ce7ac85269 100644 --- 
a/paddle/fluid/operators/distributed/rpc_server_test.cc +++ b/paddle/fluid/operators/distributed/rpc_server_test.cc @@ -35,27 +35,24 @@ namespace platform = paddle::platform; namespace distributed = paddle::operators::distributed; USE_NO_KERNEL_OP(lookup_sparse_table_read); +USE_OP(scale); std::unique_ptr g_rpc_service; std::unique_ptr g_req_handler; -framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) { +framework::BlockDesc* AppendSendAndRecvBlock(framework::ProgramDesc* program) { auto root_block = program->MutableBlock(0); auto* block = program->AppendBlock(*root_block); - framework::VariableNameMap input({{"W", {"w"}}, {"Ids", {"ids"}}}); - framework::VariableNameMap output({{"Output", {"out"}}}); - auto op = block->AppendOp(); - op->SetType("lookup_sparse_table_read"); - op->SetInput("W", {"w"}); - op->SetInput("Ids", {"ids"}); - op->SetOutput("Out", {"out"}); - op->SetAttr("tablename", {"w"}); - op->SetAttr("value_names", {"Param"}); - - auto& out = *root_block->Var("out"); + framework::OpDesc* op = block->AppendOp(); + op->SetType("scale"); + op->SetInput("X", {"x"}); + op->SetOutput("Out", {"res"}); + op->SetAttr("scale", 0.5f); + + auto& out = *root_block->Var("res"); out.SetType(framework::proto::VarType::LOD_TENSOR); - out.SetShape({10, 10}); + out.SetShape({1, 10}); return block; } @@ -69,6 +66,12 @@ void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) { auto ids_var = scope->Var("ids"); ids_var->GetMutable(); + + auto x_var = scope->Var("x"); + x_var->GetMutable(); + + auto res_var = scope->Var("res"); + res_var->GetMutable(); } void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place, @@ -78,6 +81,11 @@ void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place, int64_t* ids_ptr = ids_var->mutable_data(framework::DDim({rows_numel, 1}), *place); for (int64_t i = 0; i < rows_numel; ++i) ids_ptr[i] = i * 2; + + auto x_var = scope->Var("x")->GetMutable(); + float* x_ptr = + 
x_var->mutable_data(framework::DDim({1, rows_numel}), *place); + for (int64_t i = 0; i < rows_numel; ++i) x_ptr[i] = 1.0; } void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place, @@ -124,6 +132,38 @@ void StartServer(const std::string& rpc_name) { server_thread.join(); } +void StartSendAndRecvServer(const std::string& rpc_name) { + framework::ProgramDesc program; + framework::Scope scope; + platform::CPUPlace place; + framework::Executor exe(place); + platform::CPUDeviceContext ctx(place); + auto block = AppendSendAndRecvBlock(&program); + std::string in_var_name("x"); + std::vector prefetch_block_ids{block->ID()}; + auto prepared = exe.Prepare(program, prefetch_block_ids); + InitTensorsOnServer(&scope, &place, 10); + + std::unordered_map> + grad_to_prepared_ctx; + grad_to_prepared_ctx[in_var_name] = prepared[0]; + + g_req_handler->SetProgram(&program); + g_req_handler->SetGradToPreparedCtx(&grad_to_prepared_ctx); + g_req_handler->SetDevCtx(&ctx); + g_req_handler->SetScope(&scope); + g_req_handler->SetExecutor(&exe); + + g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get()); + g_req_handler->SetRPCServer(g_rpc_service.get()); + + std::thread server_thread( + std::bind(&distributed::RPCServer::StartServer, g_rpc_service.get())); + + server_thread.join(); +} + TEST(COMPLETE, CPU) { setenv("http_proxy", "", 1); setenv("https_proxy", "", 1); @@ -147,3 +187,46 @@ TEST(COMPLETE, CPU) { g_rpc_service.reset(nullptr); g_req_handler.reset(nullptr); } + +TEST(SENDANDRECV, CPU) { + setenv("http_proxy", "", 1); + setenv("https_proxy", "", 1); + g_req_handler.reset(new distributed::RequestSendAndRecvHandler( + distributed::DistributedMode::kAsync)); + g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1)); + distributed::RPCClient* client = + distributed::RPCClient::GetInstance(0); + PADDLE_ENFORCE_NE(client, nullptr, + platform::errors::InvalidArgument( + "Client Start Fail, Check Your Code & Env")); + std::thread server_thread(StartSendAndRecvServer, 
+ distributed::kRequestSendAndRecv); + g_rpc_service->WaitServerReady(); + int port = g_rpc_service->GetSelectedPort(); + std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port); + + framework::Scope scope; + platform::CPUPlace place; + platform::CPUDeviceContext ctx(place); + + // create var on local scope + int64_t rows_numel = 10; + InitTensorsOnClient(&scope, &place, rows_numel); + std::string in_var_name("x"); + std::string out_var_name("res"); + + client->AsyncSendAndRecv(ep, ctx, scope, in_var_name, out_var_name); + client->Wait(); + auto var = scope.Var(out_var_name); + auto value = var->GetMutable(); + auto ptr = value->mutable_data(place); + + for (int64_t i = 0; i < rows_numel; ++i) { + EXPECT_EQ(ptr[i], 0.5); + } + g_rpc_service->ShutDown(); + server_thread.join(); + LOG(INFO) << "begin reset"; + g_rpc_service.reset(nullptr); + g_req_handler.reset(nullptr); +} diff --git a/paddle/fluid/operators/distributed/send_recv.proto.in b/paddle/fluid/operators/distributed/send_recv.proto.in index 0337b72181c..a333642bd16 100644 --- a/paddle/fluid/operators/distributed/send_recv.proto.in +++ b/paddle/fluid/operators/distributed/send_recv.proto.in @@ -29,7 +29,7 @@ service SendRecvService { rpc CheckpointNotify(VariableMessage) returns (VoidMessage) {} rpc DistributeNotify(VariableMessage) returns (VoidMessage) {} - + rpc SendAndRecvVariable(VariableMessage) returns (VariableMessage) {} rpc GetMonomerVariable(VariableMessage) returns (VariableMessage) {} rpc GetMonomerBarrier(VariableMessage) returns (VoidMessage) {} } diff --git a/paddle/fluid/operators/distributed/variable_response.h b/paddle/fluid/operators/distributed/variable_response.h index 3cabcd22cd5..d979cd8a881 100644 --- a/paddle/fluid/operators/distributed/variable_response.h +++ b/paddle/fluid/operators/distributed/variable_response.h @@ -96,6 +96,13 @@ class VariableResponse { return scope_->FindVar(meta_.varname()); } + framework::Variable* GetRecvVar() { + if (create_scope_) { + return 
local_scope_->Var(meta_.out_varname()); + } + return scope_->FindVar(meta_.out_varname()); + } + int GetTrainerId() { return static_cast(meta_.trainer_id()); } protected: diff --git a/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc b/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc index 5869407be5a..5e1e408eb2c 100644 --- a/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc +++ b/paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc @@ -268,7 +268,6 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, size_t num_blocks = program->Size(); PADDLE_ENFORCE_GE(num_blocks, 2, "server program should have at least 2 blocks"); - std::vector block_list; for (size_t blkid = 1; blkid < num_blocks; ++blkid) { block_list.push_back(blkid); @@ -295,6 +294,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, request_send_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx); request_get_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx); request_prefetch_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx); + request_send_and_recv_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx); while (true) { if (rpc_service_->IsExit()) { @@ -394,6 +394,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, new distributed::RequestGetNoBarrierHandler()); request_notify_handler_.reset( new distributed::RequestNotifyHandler(distributed_mode, fan_in)); + request_send_and_recv_handler_.reset( + new distributed::RequestSendAndRecvHandler(distributed_mode)); rpc_service_->RegisterRPC(distributed::kRequestSend, request_send_handler_.get(), rpc_send_thread_num); @@ -408,6 +410,9 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, request_get_no_barrier_handler_.get()); rpc_service_->RegisterRPC(distributed::kRequestNotify, request_notify_handler_.get(), rpc_send_thread_num); + rpc_service_->RegisterRPC(distributed::kRequestSendAndRecv, + request_send_and_recv_handler_.get(), + rpc_get_thread_num); 
auto optimize_blocks = Attr>(kOptimizeBlocks); @@ -416,6 +421,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, "optimize blocks is less than 1. Optimize blocks " "should be 1 at least on the pserver side.")); auto *program = optimize_blocks[0]->Program(); + framework::Executor executor(dev_place); std::shared_ptr ckpt_pre_context = nullptr; @@ -488,6 +494,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, f(request_checkpoint_handler_.get()); f(request_get_no_barrier_handler_.get()); f(request_notify_handler_.get()); + f(request_send_and_recv_handler_.get()); // register SIGINT(from ctrl+C) and SIGTERM(from kill) signal handlers signal(SIGINT, SignalHandler::StopAndExit); diff --git a/paddle/fluid/operators/distributed_ops/listen_and_serv_op.h b/paddle/fluid/operators/distributed_ops/listen_and_serv_op.h index 369743dfb23..b41e4e87722 100644 --- a/paddle/fluid/operators/distributed_ops/listen_and_serv_op.h +++ b/paddle/fluid/operators/distributed_ops/listen_and_serv_op.h @@ -99,6 +99,8 @@ class ListenAndServOp : public framework::OperatorBase { mutable std::shared_ptr request_checkpoint_handler_; mutable std::shared_ptr request_notify_handler_; + mutable std::shared_ptr + request_send_and_recv_handler_; mutable std::shared_ptr server_thread_; mutable std::vector sparse_vars_; diff --git a/paddle/fluid/operators/distributed_ops/send_and_recv_op.cc b/paddle/fluid/operators/distributed_ops/send_and_recv_op.cc new file mode 100644 index 00000000000..00cdbe70ca4 --- /dev/null +++ b/paddle/fluid/operators/distributed_ops/send_and_recv_op.cc @@ -0,0 +1,98 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include // NOLINT +#include + +#include "paddle/fluid/framework/blocking_queue.h" +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/distributed/communicator.h" +#include "paddle/fluid/operators/distributed/communicator_common.h" +#include "paddle/fluid/operators/distributed/distributed.h" +#include "paddle/fluid/operators/distributed/parameter_send.h" +#include "paddle/fluid/operators/distributed_ops/send_recv_util.h" +#include "paddle/fluid/platform/profiler.h" + +namespace paddle { +namespace operators { + +template +class SendAndRecvKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto& scope = ctx.scope(); + const auto& place = ctx.GetPlace(); + auto send_var_name = ctx.Attr("send_var_name"); + auto recv_var_name = ctx.Attr("recv_var_name"); + auto epmap = ctx.Attr("endpoint"); + auto trainer_id = ctx.Attr("trainer_id"); + + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto& context = *pool.Get(place); + + distributed::RPCClient* rpc_client = + distributed::RPCClient::GetInstance(trainer_id); + VLOG(3) << "SendAndRecvOp Send_var_name: " << send_var_name + << " Recv_var_name: " << recv_var_name; + distributed::VarHandlePtr rets = rpc_client->AsyncSendAndRecv( + epmap, context, scope, send_var_name, recv_var_name); + rets->Wait(); + } +}; + +class SendAndRecvOp : public framework::OperatorWithKernel { + public: + using 
framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override {} + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + return framework::OpKernelType(data_type, platform::CPUPlace()); + } +}; + +class SendAndRecvOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("X", "Tensor Input variable to be sent").AsDuplicable(); + AddOutput("Out", "Tensor Output varibale to be recv").AsDuplicable(); + AddAttr("send_var_name", "Send Tensor's name") + .SetDefault(std::string("")); + AddAttr("recv_var_name", "Recv Tensor's name") + .SetDefault(std::string("")); + AddAttr("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0); + AddAttr("endpoint", "Server endpoint") + .SetDefault({"127.0.0.1:6164"}); + AddComment(R"DOC( + SendAndRecv operator + This operator will send variables to listen_and_serve op at the parameter server. + And recv variable from parameter server of send variable's scope. 
+ )DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(send_and_recv, ops::SendAndRecvOp, ops::SendAndRecvOpMaker); + +REGISTER_OP_CPU_KERNEL( + send_and_recv, + ops::SendAndRecvKernel) diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index eb2cb19eaec..f4a16d0de17 100644 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -85,7 +85,7 @@ class Fleet(object): This function is responsible for the distributed architecture what you want to run your code behind,such as Transpiler, Collective in PaddleCloudRoleMaker or UserDefinedRoleMaker - + """ if isinstance(role_maker, RoleMakerBase): self._role_maker = role_maker @@ -112,7 +112,7 @@ class Fleet(object): Returns: bool: True if this is the first node of worker, False if not. - + """ return self._role_maker.is_first_worker() @@ -200,7 +200,8 @@ class Fleet(object): bool: True if this is a node of server, False if not. 
""" - return self._role_maker.is_server() + return self._role_maker.is_server( + ) or self._role_maker._is_heter_worker() @property def util(self): @@ -372,10 +373,10 @@ class Fleet(object): can_not_apply_optimizer_list.append(opt) # combine recalled meta optimizers to be a valid meta optimizer meta_optimizer, graph_optimizer = \ - self.strategy_compiler.generate_optimizer( - loss, self._role_maker, self.user_defined_optimizer, - self.user_defined_strategy, valid_optimizer_list, - valid_graph_optimizer_list) + self.strategy_compiler.generate_optimizer( + loss, self._role_maker, self.user_defined_optimizer, + self.user_defined_strategy, valid_optimizer_list, + valid_graph_optimizer_list) valid_strategy = self.strategy_compiler._get_valid_strategy( self.user_defined_strategy, can_not_apply_optimizer_list) diff --git a/python/paddle/distributed/fleet/base/role_maker.py b/python/paddle/distributed/fleet/base/role_maker.py index 3d159a63122..25f2d0dd3f4 100644 --- a/python/paddle/distributed/fleet/base/role_maker.py +++ b/python/paddle/distributed/fleet/base/role_maker.py @@ -14,6 +14,7 @@ """Defination of Role Makers.""" import os import numpy as np +import warnings from multiprocessing import Process, Manager import paddle.fluid as fluid @@ -23,6 +24,7 @@ import paddle.fluid as fluid class Role: WORKER = 1 SERVER = 2 + HETER_WORKER = 3 class RoleMakerBase(object): @@ -40,6 +42,11 @@ class RoleMakerBase(object): self._role = None self._current_id = -1 + # for heter parameter server mode + self._heter_trainer_endpoints = [] + self._heter_trainer_device = "CPU" + self._is_heter_parameter_server_mode = False + self._node_type = None self._node_type_comm = None self._all_comm = None @@ -163,12 +170,58 @@ class RoleMakerBase(object): """ print("warning: RoleMakerBase does not have barrier worker.") + def _is_heter_worker(self): + """ + Return is_heter_worker() of current process + """ + warnings.warn("RoleMakerBase does not have function: _is_heter_worker.") + return False 
+ + def _heter_worker_num(self): + """ + Get current total heter-worker number. + + Returns: + int: heter_worker number + """ + warnings.warn( + "RoleMakerBase does not have function: _heter_worker_num.") + return 0 + + def _get_heter_worker_endpoints(self): + """ + Returns: + string: all heter_trainers'endpoints + """ + assert self._heter_trainer_endpoints != [] + return self._heter_trainer_endpoints + + def _get_heter_worker_endpoint(self): + """ + Returns: + int: corresponding heter_trainer's endpoint + + e.g: if we have 4 cpu-trainer(default), 2 gpu-trainer(heter) + then No.0 and No.2 cpu-trainer will work with No.0 gpu-trainer + and No.1 and No.3 cpu-trainer will work with No.1 gpu-trainerr + """ + assert self._heter_trainer_endpoints != [] + return self._heter_trainer_endpoints[(self._current_id + 1) % + self._heter_worker_num()] + + def _get_heter_worker_device(self): + """ + Returns: + string: heter_trainer's device of current node, e.g: CPU/GPU/XPU + """ + return self._heter_trainer_device.upper() + class PaddleCloudRoleMaker(RoleMakerBase): def __init__(self, is_collective=False, **kwargs): super(PaddleCloudRoleMaker, self).__init__() self._is_collective = is_collective - self._init_gloo = False #default no init gloo + self._init_gloo = False # default no init gloo self._kwargs = kwargs self._role_is_generated = False @@ -278,10 +331,7 @@ class PaddleCloudRoleMaker(RoleMakerBase): """ get index of current node """ - if self.is_server(): - return self.server_index() - elif self.is_worker(): - return self.worker_index() + return self._current_id def worker_num(self): """ @@ -323,6 +373,22 @@ class PaddleCloudRoleMaker(RoleMakerBase): self.generate_role() return self._server_endpoints + def _heter_worker_num(self): + """ + get heter worker nums + """ + if not self._role_is_generated: + self.generate_role() + return self._heter_trainers_num + + def _is_heter_worker(self): + """ + whether current process is heter worker + """ + if not self._role_is_generated: 
+ self.generate_role() + return self._role == Role.HETER_WORKER + def _get_rank(self): """ get current rank in all workers and pservers @@ -342,17 +408,47 @@ class PaddleCloudRoleMaker(RoleMakerBase): def _ps_env(self): try: # Environment variable PADDLE_PSERVERS_IP_PORT_LIST must be set - # format: string(ip:port), eg. 127.0.0.1:6001 - self._server_endpoints = os.environ[ - "PADDLE_PSERVERS_IP_PORT_LIST"].split(",") + # format: string(ip:port,ip:port), eg. 127.0.0.1:6001,127.0.0.1:6002 + self._server_endpoints = os.getenv("PADDLE_PSERVERS_IP_PORT_LIST", + "").split(",") + assert self._server_endpoints != "" self._worker_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS", "").split(",") + assert self._server_endpoints != "" trainers_num = int(os.environ["PADDLE_TRAINERS_NUM"]) training_role = os.environ["TRAINING_ROLE"] - if training_role not in ["TRAINER", "PSERVER"]: - raise ValueError("TRAINING_ROLE must be PSERVER or TRAINER") + if training_role not in ["TRAINER", "PSERVER", "HETER_TRAINER"]: + raise ValueError( + "TRAINING_ROLE must be PSERVER or TRAINER or HETER_TRAINER, but get {}, please check your environment.". + format(training_role)) + + # For heter parameter server env setting + heter_trainer_eplist = os.getenv( + "PADDLE_HETER_TRAINER_IP_PORT_LIST", None) + heter_trainer_device = os.getenv("PADDLE_HETER_TRAINER_DEVICE", + None) + if heter_trainer_eplist and heter_trainer_device: + try: + heter_trainer_eplist = os.environ[ + "PADDLE_HETER_TRAINER_IP_PORT_LIST"].split(",") + except: + raise ValueError( + "Can not Find PADDLE_HETER_TRAINER_IP_PORT_LIST in env or its format doesn't match the requirement: 'IP:PORT,IP:PORT' ." + ) + + self._is_heter_parameter_server_mode = True + heter_trainers_num = len(heter_trainer_eplist) + current_node_device = heter_trainer_device.upper() + if current_node_device not in ["CPU", "GPU", "XPU"]: + raise ValueError( + "Heter Trainer doesn't support {} device now, please use CPU / GPU / XPU(KunLun)". 
+ format(heter_trainer_device)) + self._heter_trainer_device = current_node_device + else: + self._is_heter_parameter_server_mode = False + heter_trainers_num = 0 if training_role == "TRAINER": role = Role.WORKER @@ -365,17 +461,26 @@ class PaddleCloudRoleMaker(RoleMakerBase): ip = os.environ["POD_IP"] self._cur_endpoint = ip + ":" + port current_id = self._server_endpoints.index(self._cur_endpoint) + elif training_role == "HETER_TRAINER": + role = Role.HETER_WORKER + cur_ip = os.environ["POD_IP"] + cur_port = os.environ["PADDLE_PORT"] + curr_endpoint = ":".join([cur_ip, cur_port]) + current_id = heter_trainer_eplist.index(curr_endpoint) else: - raise ValueError("TRAINING_ROLE must be PSERVER or TRAINER") - except ValueError as ve: + raise ValueError( + "TRAINING_ROLE must be PSERVER or TRAINER or HETER_TRAINER") + except ValueError as e: raise ValueError( - "something wrong with PaddleCloud, please check environment") + "Something wrong with PaddleCloud, please check environment") self._trainers_num = trainers_num self._role = role self._current_id = current_id self._node_num = len( set([x.split(':')[0] for x in self._worker_endpoints])) + self._heter_trainers_num = heter_trainers_num + self._heter_trainer_endpoints = heter_trainer_eplist def _collective_env(self): self._current_id = int(os.getenv("PADDLE_TRAINER_ID", "0")) diff --git a/python/paddle/distributed/fleet/meta_optimizers/__init__.py b/python/paddle/distributed/fleet/meta_optimizers/__init__.py index 78b2b8117b9..d98b2ef3e2a 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/__init__.py +++ b/python/paddle/distributed/fleet/meta_optimizers/__init__.py @@ -15,10 +15,10 @@ from .amp_optimizer import AMPOptimizer from .recompute_optimizer import RecomputeOptimizer from .gradient_merge_optimizer import GradientMergeOptimizer from .graph_execution_optimizer import GraphExecutionOptimizer -from .async_optimizer import AsyncMetaOptimizer +from .parameter_server_optimizer import 
ParameterServerOptimizer from .pipeline_optimizer import PipelineOptimizer from .localsgd_optimizer import LocalSGDOptimizer from .lars_optimizer import LarsOptimizer -from .async_graph_execution_optimizer import AsyncGraphExecutionOptimizer +from .parameter_server_graph_optimizer import ParameterServerGraphOptimizer from .dgc_optimizer import DGCOptimizer from .lamb_optimizer import LambOptimizer diff --git a/python/paddle/distributed/fleet/meta_optimizers/async_graph_execution_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/parameter_server_graph_optimizer.py similarity index 88% rename from python/paddle/distributed/fleet/meta_optimizers/async_graph_execution_optimizer.py rename to python/paddle/distributed/fleet/meta_optimizers/parameter_server_graph_optimizer.py index c0dee220aaf..878ed7422d7 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/async_graph_execution_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/parameter_server_graph_optimizer.py @@ -13,12 +13,12 @@ from paddle import fluid from paddle.fluid import compiler -from .async_optimizer import AsyncMetaOptimizer +from .parameter_server_optimizer import ParameterServerOptimizer -class AsyncGraphExecutionOptimizer(AsyncMetaOptimizer): +class ParameterServerGraphOptimizer(ParameterServerOptimizer): def __init__(self, optimizer): - super(AsyncGraphExecutionOptimizer, self).__init__(optimizer) + super(ParameterServerGraphOptimizer, self).__init__(optimizer) self.inner_opt = optimizer # we do not allow meta optimizer to be inner optimizer currently self.meta_optimizers_white_list = [] @@ -31,6 +31,9 @@ class AsyncGraphExecutionOptimizer(AsyncMetaOptimizer): if self.role_maker.is_server(): return False + if self.role_maker._is_heter_parameter_server_mode: + return False + return True def _disable_strategy(self, dist_strategy): diff --git a/python/paddle/distributed/fleet/meta_optimizers/async_optimizer.py 
b/python/paddle/distributed/fleet/meta_optimizers/parameter_server_optimizer.py similarity index 82% rename from python/paddle/distributed/fleet/meta_optimizers/async_optimizer.py rename to python/paddle/distributed/fleet/meta_optimizers/parameter_server_optimizer.py index b6543549728..ecb198bedf9 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/async_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/parameter_server_optimizer.py @@ -15,9 +15,9 @@ from paddle import fluid from .meta_optimizer_base import MetaOptimizerBase -class AsyncMetaOptimizer(MetaOptimizerBase): +class ParameterServerOptimizer(MetaOptimizerBase): def __init__(self, optimizer): - super(AsyncMetaOptimizer, self).__init__(optimizer) + super(ParameterServerOptimizer, self).__init__(optimizer) self.inner_opt = optimizer # we do not allow meta optimizer to be inner optimizer currently self.meta_optimizers_white_list = [] @@ -68,6 +68,21 @@ class AsyncMetaOptimizer(MetaOptimizerBase): _startup = worker.init_from_server_pass(_startup, compiled_config) _startup = worker.delet_extra_optimizes_pass(_startup, compiled_config) + + # for heter program + if self.role_maker._is_heter_parameter_server_mode: + from paddle.fluid.incubate.fleet.parameter_server.ir import heter_trainer_pass as heter_worker + if self.role_maker._is_heter_worker(): + # for heter worker + _main = heter_worker.split_heter_worker_ops_pass( + _main, compiled_config) + else: + # for default worker + _main = heter_worker.split_trainer_ops_pass(_main, + compiled_config) + # for startup change + _startup = heter_worker.delete_startup_useless_ops_var_pass( + _startup, _main, compiled_config) else: _main = worker.append_send_ops_pass(_main, compiled_config) _startup = _startup @@ -129,9 +144,12 @@ class AsyncMetaOptimizer(MetaOptimizerBase): _origin_startup_program, strategy, self.role_maker) - main_program, startup_program = \ - self._build_trainer_programs(compiled_config) if self.role_maker.is_worker() \ - else 
self._build_pserver_programs(compiled_config) + if self.role_maker.is_worker() or self.role_maker._is_heter_worker(): + main_program, startup_program = self._build_trainer_programs( + compiled_config) + elif self.role_maker.is_server(): + main_program, startup_program = self._build_pserver_programs( + compiled_config) loss.block.program = main_program fluid.framework.switch_startup_program(startup_program) diff --git a/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py b/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py index c731ed08893..1741f10ccb1 100644 --- a/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py +++ b/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py @@ -196,6 +196,18 @@ class ParameterServerRuntime(RuntimeBase): else: warnings.warn("communicator has been initialized, skip") + def _get_executor(self): + if self.role_maker._is_heter_worker(): + if self.role_maker._get_heter_worker_device() == "GPU": + gpu_id = int(os.getenv("FLAGS_selected_gpus", "0")) + executor = Executor(fluid.CUDAPlace(gpu_id)) + else: + raise ValueError("Not Support Device {}".format( + self.role_maker._get_heter_worker_device())) + else: + executor = fluid.Executor(fluid.CPUPlace()) + return executor + def _init_server(self, *args, **kwargs): if len(args) > 1: raise ValueError("init server can only accept 1 args: `dirname`") @@ -204,9 +216,15 @@ class ParameterServerRuntime(RuntimeBase): else: model_dirname = None - executor = fluid.Executor(fluid.CPUPlace()) + if self.role_maker._is_heter_worker(): + self._init_worker() + + executor = self._get_executor() executor.run(fluid.default_startup_program()) + if self.role_maker._is_heter_worker(): + return + if not model_dirname: return @@ -237,12 +255,12 @@ class ParameterServerRuntime(RuntimeBase): # self._load_sparse_params(dirname=model_dir, varnames=distribtued_varnames) def _run_server(self): - executor = fluid.Executor(fluid.CPUPlace()) + executor = 
self._get_executor() executor.run(fluid.default_main_program()) def _stop_worker(self): self._communicator.stop() - executor = fluid.Executor(fluid.CPUPlace()) + executor = self._get_executor() executor.close() def _get_optimizer_status(self, op, param_name): diff --git a/python/paddle/fluid/incubate/fleet/base/fleet_base.py b/python/paddle/fluid/incubate/fleet/base/fleet_base.py index f885e51ef7f..40cc2d2dd4e 100644 --- a/python/paddle/fluid/incubate/fleet/base/fleet_base.py +++ b/python/paddle/fluid/incubate/fleet/base/fleet_base.py @@ -145,7 +145,7 @@ class Fleet(object): Returns: bool: True if this is a node of server, - False if not. + False if not """ return self._role_maker.is_server() diff --git a/python/paddle/fluid/incubate/fleet/base/role_maker.py b/python/paddle/fluid/incubate/fleet/base/role_maker.py index 7f8db694d36..be27a7c5214 100644 --- a/python/paddle/fluid/incubate/fleet/base/role_maker.py +++ b/python/paddle/fluid/incubate/fleet/base/role_maker.py @@ -343,7 +343,6 @@ class MPISymetricRoleMaker(MPIRoleMaker): def get_pserver_endpoints(self): """ get pserver endpoints - Returns: endpoints(list): pserver endpoints """ diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py index 1a7a82fbfac..236cb458be4 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py @@ -467,7 +467,7 @@ class FleetTranspiler(Fleet): opts = public._get_optimize_ops(self._origin_main_program) for op in opts: if "Param" in op.input_names and \ - "LearningRate" in op.input_names and op.input("Param")[0] == param_name: + "LearningRate" in op.input_names and op.input("Param")[0] == param_name: return op def _save_dense_params(self, executor, dirname, context, main_program): @@ -700,8 +700,8 @@ if you would like to 
save all variables in a return False if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \ - var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \ - var.desc.type() == core.VarDesc.VarType.READER: + var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \ + var.desc.type() == core.VarDesc.VarType.READER: return False return var.persistable @@ -846,4 +846,4 @@ class ParameterServerOptimizer(DistributedOptimizer): fleet.compiled_config = compiled_config fleet.main_program, fleet.startup_program = \ self._build_trainer_programs(compiled_config) if fleet.is_worker() \ - else self._build_pserver_programs(compiled_config) + else self._build_pserver_programs(compiled_config) diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/ir/heter_trainer_pass.py b/python/paddle/fluid/incubate/fleet/parameter_server/ir/heter_trainer_pass.py new file mode 100644 index 00000000000..e8668e39bd4 --- /dev/null +++ b/python/paddle/fluid/incubate/fleet/parameter_server/ir/heter_trainer_pass.py @@ -0,0 +1,100 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function +import warnings + +import paddle.fluid.core as core +import paddle.fluid.framework as framework + +from paddle.fluid.transpiler.details.program_utils import delete_ops +from paddle.fluid.incubate.fleet.parameter_server.ir.trainer_pass import find_heter_ops +from paddle.fluid.incubate.fleet.parameter_server.ir.trainer_pass import create_heter_program +from paddle.fluid.incubate.fleet.parameter_server.ir.trainer_pass import create_trainer_program +from paddle.fluid.incubate.fleet.parameter_server.ir.trainer_pass import find_block_joints +from paddle.fluid.incubate.fleet.parameter_server.ir.trainer_pass import find_op_input_output +from paddle.fluid.incubate.fleet.parameter_server.ir.trainer_pass import get_vars_name_in_block + + +def split_heter_worker_ops_pass(program, config): + """ + split heter worker program from origin-program + 1. find heter op (located on different device) + 2. find input&output of every heter-block + 3. create heter worker program, add listen&serv op + """ + default_deveice = "cpu" + program, heter_ops, _, program_block_ops = find_heter_ops(program, + default_deveice) + if len(heter_ops) == 0: + warnings.warn( + "Currently running in Heter Parameter Server mode, but no OP running on heterogeneous devices, Please check your code." + ) + return program + + current_device = "gpu" + if current_device not in heter_ops: + raise ValueError("Op which run on device {} not exist.".format( + current_device)) + + block_vars_detail = find_block_joints(program, program_block_ops, heter_ops) + heter_program = framework.Program() + create_heter_program(program, config, heter_program, heter_ops, + block_vars_detail, current_device) + return heter_program + + +def split_trainer_ops_pass(program, config): + """ + split cpu-trainer program from origin-program + 1. find heter op (located on different device) + 2. find input&output of every heter-block + 3. 
create cpu-trainer program, add send&recv op + """ + # Todo: support user define default_device (MrChengmo) + default_deveice = "cpu" + program, heter_ops, _, program_block_ops = find_heter_ops(program, + default_deveice) + block_vars_detail = find_block_joints(program, program_block_ops, heter_ops) + create_trainer_program(program, config, heter_ops, block_vars_detail) + return program + + +def delete_startup_useless_ops_var_pass(startup_program, main_program, config): + """ + delete variable which not used in current main_program + """ + # find all op and its var + vars_in_main_program = get_vars_name_in_block(main_program.global_block()) + + block_nums = startup_program.num_blocks + for block_index in range(1, block_nums): + current_block = startup_program.block(block_index) + # delete useless op + need_delete_op = [] + for op in current_block.ops: + inputs, outputs = find_op_input_output(startup_program, + current_block, op) + inputs += outputs + # Todo: delete some concat op + if list(set(inputs) & set(vars_in_main_program)) == None: + need_delete_op.append(op) + delete_ops(current_block, need_delete_op) + + # delete useless var + for var in current_block.vars: + if var.name not in vars_in_main_program: + startup_program._remove_var(var.name) + + return startup_program diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/ir/pserver_pass.py b/python/paddle/fluid/incubate/fleet/parameter_server/ir/pserver_pass.py index 765c18283b4..05deff10a2e 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/ir/pserver_pass.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/ir/pserver_pass.py @@ -37,7 +37,7 @@ LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched def _is_optimizer_op(op): if "Param" in op.input_names and \ - "LearningRate" in op.input_names: + "LearningRate" in op.input_names: return True return False @@ -49,7 +49,7 @@ def _same_or_split_var(p_name, var_name): def _get_optimizer_input_shape(op_type, 
varkey, orig_shape, param_shape): """ Returns the shape for optimizer inputs that need to be reshaped when - Param and Grad is split to multiple servers. + Param and Grad is split to multiple servers. """ # HACK(typhoonzero) : Should use functions of corresponding optimizer in # optimizer.py to get the shape, do not bind this in the transpiler. @@ -542,7 +542,7 @@ def add_optimizer_pass(program, config): for _, op in enumerate(optimize_ops): # optimizer is connected to itself if op.attr(OP_ROLE_VAR_ATTR_NAME)[0] == optimize_target_param_name and \ - op not in global_ops: + op not in global_ops: __append_optimize_op__(op, per_opt_block, grad_to_block_id, merged_var, lr_ops) diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py b/python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py index f9889997d9e..378c8fc23d7 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py @@ -12,33 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Copyright(c) 2020 PaddlePaddle Authors.All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0(the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http: // www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- from __future__ import print_function from functools import reduce import collections import math import os +import warnings import six +import paddle.fluid as fluid from paddle.fluid import core from paddle.fluid.core import CommContext +import paddle.fluid.framework as framework from paddle.fluid.incubate.fleet.parameter_server.mode import DistributedMode from paddle.fluid.incubate.fleet.parameter_server.ir import vars_metatools from paddle.fluid.incubate.fleet.parameter_server.ir.ps_dispatcher import RoundRobin, PSDispatcher +from paddle.fluid.transpiler.details.program_utils import delete_ops OP_NAME_SCOPE = "op_namescope" CLIP_OP_NAME_SCOPE = "@CLIP" @@ -58,8 +48,8 @@ def _get_lr_ops(program): for index, op in enumerate(program.global_block().ops): role_id = int(op.attr(RPC_OP_ROLE_ATTR_NAME)) if role_id == int(LR_SCHED_OP_ROLE_ATTR_VALUE) or \ - role_id == int(LR_SCHED_OP_ROLE_ATTR_VALUE) | \ - int(OPT_OP_ROLE_ATTR_VALUE): + role_id == int(LR_SCHED_OP_ROLE_ATTR_VALUE) | \ + int(OPT_OP_ROLE_ATTR_VALUE): lr_ops.append(op) return lr_ops @@ -122,9 +112,20 @@ class MergedVariable: self.offsets = offsets +def Singleton(cls): + _instance = {} + + def _singleton(*args, **kargs): + if cls not in _instance: + _instance[cls] = cls(*args, **kargs) + return _instance[cls] + + return _singleton + + +@Singleton class CompileTimeStrategy(object): def __init__(self, main_program, startup_program, strategy, role_maker): - self.min_block_size = 8192 self.origin_main_program = main_program @@ -177,6 +178,12 @@ class CompileTimeStrategy(object): def get_ps_endpoints(self): return self.role_maker.get_pserver_endpoints() + def get_heter_worker_endpoints(self): + return self.role_maker._get_heter_worker_endpoints() + + def get_heter_worker_endpoint(self): + return self.role_maker._get_heter_worker_endpoint() + def get_origin_programs(self): return self.origin_main_program, self.origin_startup_program @@ -810,6 +817,30 @@ class CompileTimeStrategy(object): return 
sparse_param_grads, dense_param_grads + def remove_var_pair_by_grad(self, var_name): + + for index, pair in enumerate(self.merged_variables_pairs): + var = pair[0] + var_grad = pair[1] + if var_grad.merged_var.name == var_name: + del self.merged_variables_pairs[index] + + for index, pair in enumerate(self.merged_dense_pairs): + var = pair[0] + var_grad = pair[1] + if var_grad.merged_var.name == var_name: + del self.merged_dense_pairs[index] + return + + for index, pair in enumerate(self.merged_sparse_pairs): + var = pair[0] + var_grad = pair[1] + if var_grad.merged_var.name == var_name: + del self.merged_sparse_pairs[index] + return + + print("Not find {} in self.merge_pairs".format(var_name)) + def _is_opt_role_op(op): # NOTE : depend on oprole to find out whether this op is for @@ -817,7 +848,7 @@ def _is_opt_role_op(op): op_maker = core.op_proto_and_checker_maker optimize_role = core.op_proto_and_checker_maker.OpRole.Optimize if op_maker.kOpRoleAttrName() in op.attr_names and \ - int(op.all_attrs()[op_maker.kOpRoleAttrName()]) == int(optimize_role): + int(op.all_attrs()[op_maker.kOpRoleAttrName()]) == int(optimize_role): return True return False diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py b/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py index 912eee0df0a..201b3863a4b 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py @@ -13,7 +13,13 @@ # limitations under the License. 
from __future__ import print_function +import six +import collections +import warnings +import math +from functools import reduce +import paddle.fluid as fluid import paddle.fluid.core as core import paddle.fluid.framework as framework @@ -34,6 +40,10 @@ LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched OPT_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Optimize op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName() +DEVICE_LIST = ["cpu", "gpu", "xpu"] +COMMUNICATE_OPS_TYPE = ["send", "recv", "fetch_barrier", "send_barrier"] +DEFAULT_DEVICE = 'cpu' + def delete_optimizer_pass(program, config): def _delete_optimizer_op_and_vars(_program, optimize_ops): @@ -250,7 +260,7 @@ def fake_init_ops_pass(program, config): return list(set(dist_varnames + sparse_varnames)) def _fake_init_sparsetable(sparse_table_names): - #delete table init op + # delete table init op for table_name in sparse_table_names: table_var = program.global_block().vars[table_name] table_param_init_op = [] @@ -307,3 +317,871 @@ def delet_extra_optimizes_pass(program, config): program.global_block()._remove_var(var) return program + + +def find_heter_ops(program, default_device="cpu"): + if default_device not in DEVICE_LIST: + raise ValueError("Given device {} is not in device list {}".format( + default_device, DEVICE_LIST)) + + def _is_heter_op(op, current_heter_device, default_device="cpu"): + heter_devices = list(DEVICE_LIST) + heter_devices.remove(default_device) + op_device = op.attr("op_device") + op_type = op.type + if op_device in heter_devices: + return True + elif op_type in COMMUNICATE_OPS_TYPE and current_heter_device != default_device: + # for distributed communciate ops: send & recv & barrier etc. 
+ # Todo: need update this method + op._set_attr('op_device', current_heter_device) + return True + elif op_device == None or op_device == default_device: + op._set_attr('op_device', default_device) + return False + return False + + def _is_same_device(op, pre_device, default_device="cpu"): + op_device = op.attr("op_device") + if op_device == pre_device: + return True + if pre_device == default_device: + return True + return False + + def _append_heter_op(op, current_heter_block_ops, heter_ops): + op_device = op.attr("op_device") + if op_device not in heter_ops: + heter_ops[op_device] = {} + current_heter_block_ops.append(op) + + origin_porgram = program.clone() + block = program.global_block() + + program_block_ops = [] + default_ops = {default_device: {}} + heter_ops = {} + block_index = 0 + # heter_ops: {"gpu": {1:[op1, op2, ...], 2:[op1, op2, ...] }; "xpu": {3:[op1, op2, ...], 4:[op1, op2, ...] }} + + current_heter_block_ops = [] + current_default_block_ops = [] + current_heter_device = default_device + is_heter = False + for op in block.ops: + if _is_heter_op(op, current_heter_device, default_device): + # for gpu/xpu-op + is_heter = True + + # for cpu-op block append + if len(current_default_block_ops) > 1: + default_ops[default_device][ + block_index] = current_default_block_ops + program_block_ops.append(current_default_block_ops) + current_default_block_ops = [] + block_index += 1 + + if _is_same_device(op, current_heter_device, default_device): + # for gpu-op, gpu-op -> gpu-op,... + current_heter_device = op.attr("op_device") + _append_heter_op(op, current_heter_block_ops, heter_ops) + else: + # for gpu-op -> xpu-op, ... 
+ op_device = current_heter_block_ops[0].attr("op_device") + heter_ops[op_device][block_index] = current_heter_block_ops + program_block_ops.append(current_heter_block_ops) + block_index += 1 + current_heter_block_ops = [] + current_heter_device = op.attr("op_device") + _append_heter_op(op, current_heter_block_ops, heter_ops) + + elif is_heter: + # for gpu/xpu-op -> cpu-op + op_device = current_heter_block_ops[0].attr("op_device") + heter_ops[op_device][block_index] = current_heter_block_ops + program_block_ops.append(current_heter_block_ops) + block_index += 1 + current_heter_block_ops = [] + current_heter_device = default_device + is_heter = False + current_default_block_ops.append(op) + else: + # for cpu-op + current_default_block_ops.append(op) + + if current_default_block_ops != []: + default_ops[default_device][block_index] = current_default_block_ops + program_block_ops.append(current_default_block_ops) + + if current_heter_block_ops != []: + op_device = current_heter_block_ops[0].attr("op_device") + heter_ops[op_device][block_index] = current_heter_block_ops + program_block_ops.append(current_heter_block_ops) + + if len(heter_ops) == 0: + warnings.warn( + "No heterogeneous OP was found in your program , " + " please using fluid.device_guard() to run OPs on different device.") + + total_heter_ops = 0 + heter_blocks = 0 + for device in heter_ops.keys(): + heter_block_dict = heter_ops[device] + heter_blocks += len(heter_block_dict) + for _, heter_block in heter_block_dict.items(): + total_heter_ops += len(heter_block) + print( + "There are {} OPs in your main_program, and contains {} heter-OPs which is made up of {} heter-blocks.". 
+ format(len(block.ops), total_heter_ops, heter_blocks)) + return origin_porgram, heter_ops, default_ops, program_block_ops + + +def create_heter_program(program, config, heter_program, heter_ops, + block_var_detail, current_device): + # add heter op + optimizer_block = [] + grad_to_block_id = [] + send_grad_var_list = [] + + pre_block_idx = heter_program.num_blocks - 1 + for index, heter_block_ops in heter_ops[current_device].items(): + heter_block = heter_program._create_block(pre_block_idx) + optimizer_block.append(heter_block) + for _, op in enumerate(heter_block_ops): + block_append_op(heter_program, program, heter_block, op) + + # add relate variables + inputs = _get_input_map_from_op(program.global_block().vars, op) + add_vars_by_op_map(inputs, heter_program) + + outputs = _get_output_map_from_op(program.global_block().vars, op) + add_vars_by_op_map(outputs, heter_program) + + entrance_vars = block_var_detail[index]["entrance"] + add_vars_by_var_list(entrance_vars, program, heter_program) + exit_vars = block_var_detail[index]["exit"] + add_vars_by_var_list(exit_vars, program, heter_program) + + comm_info = get_communicate_var_info(program, index, entrance_vars, + exit_vars) + + grad_to_block_id.append(comm_info["block_input_var_name"] + ":" + str( + heter_block.idx)) + + # create slice op + first_op_index = 0 + + get_type_var_name = comm_info["input_var_reshape_name"][0].split( + ".input_reshape@Heter")[0] + get_type_var = heter_program.global_block().vars[get_type_var_name] + + insert_recv_slice_op( + heter_program, heter_block, first_op_index, + comm_info["block_input_var_name"], + (-1, sum(comm_info["input_var_reshape_dim"])), get_type_var.dtype, + get_type_var.type, comm_info["input_var_reshape_name"], [ + (-1, comm_info["input_var_reshape_dim"][i]) + for i in range(len(comm_info["input_var_reshape_dim"])) + ]) + first_op_index += len(comm_info["input_var_reshape_dim"]) + # create reshape op + for i in range(len(comm_info["input_var_reshape_name"])): + 
var_name = entrance_vars[i] + insert_reshape_op( + heter_program, + heter_block, + first_op_index, + comm_info["input_var_reshape_name"][i], + var_name, ) + first_op_index += 1 + + first_op_index = len(heter_block.ops) + + # create send reshape op + for i in range(len(exit_vars)): + insert_reshape_op(heter_program, heter_block, first_op_index, + exit_vars[i], + comm_info["output_var_reshape_name"][i], + [-1, comm_info["output_var_reshape_dim"][i]]) + first_op_index += 1 + + # create send concat op + insert_send_concat_op(heter_program, heter_block, first_op_index, + comm_info["output_var_reshape_name"], + comm_info["block_output_var_name"], + [-1, sum(comm_info["output_var_reshape_dim"])]) + check_op_device(heter_block, current_device) + send_grad_var_list = send_grad_var_list + add_heter_send_op( + program, heter_program, heter_block, block_var_detail[index]) + + # add step conter + send_input_vars = [] + dummy_output = [] + trainer_id = config.get_role_id() + pserver_endpoints = config.get_ps_endpoints() + optimizer_block[-1].append_op( + type="send", + inputs={"X": send_input_vars}, + outputs={"Out": dummy_output}, + attrs={ + "send_varnames": [STEP_COUNTER], + "merge_add": True, + "use_send_handler": False, + "endpoints": pserver_endpoints + }) + + # add info in listen&serv + attrs = { + "grad_to_block_id": grad_to_block_id, + "sparse_grad_to_param": None, + "lr_decay_block_id": None, + "dense_optimize_blocks": None, + "sparse_optimize_blocks": None, + "optimize_blocks": optimizer_block, + + # runtime attribute + "endpoint": config.get_heter_worker_endpoint(), + "pserver_id": config.get_role_id(), + "Fanin": config.get_trainers(), + "distributed_mode": config.get_distributed_mode(), + "rpc_get_thread_num": 12, + "rpc_send_thread_num": 12, + "rpc_prefetch_thread_num": 12 + } + + # append the listen_and_serv op + heter_program.global_block().append_op( + type="listen_and_serv", inputs={'X': []}, outputs={}, attrs=attrs) + + 
check_heter_compile_time_strategy(program, config, send_grad_var_list) + + +def check_heter_compile_time_strategy(program, config, send_grad_var_list): + origin_grad_var_list = [] + for _, var_grad in config.merged_variables_pairs: + origin_grad_var_list.append(var_grad.merged_var.name) + + origin_grad_var_list = list(set(origin_grad_var_list)) + send_grad_var_list = list(set(send_grad_var_list)) + useless_grad_var_list = list( + set(origin_grad_var_list) - set(send_grad_var_list)) + + for useless_grad_var in useless_grad_var_list: + config.remove_var_pair_by_grad(useless_grad_var) + + +def create_trainer_program(program, config, heter_ops, block_var_detail): + for device in heter_ops.keys(): + for heter_block_index in sorted(heter_ops[device]): + replace_ops_by_communicate_op(program, config, heter_block_index, + heter_ops[device][heter_block_index], + block_var_detail) + remove_trainer_send_op(program, config, heter_block_index, + block_var_detail) + deleter_trainer_useless_var(program) + check_op_device(program.global_block(), DEFAULT_DEVICE) + + +def replace_ops_by_communicate_op(program, config, heter_block_index, ops_list, + block_var_detail): + all_op = program.global_block().ops + start_op = ops_list[0] + first_op_idx = -1 + for op in all_op: + if is_same_op(op, start_op): + first_op_idx = all_op.index(op) + break + assert first_op_idx != -1 + delete_same_ops(program.global_block(), ops_list) + + mode = config.get_distributed_mode() + heter_worker_endpoint = config.get_heter_worker_endpoint() + entrance_var = block_var_detail[heter_block_index]["entrance"] + exit_var = block_var_detail[heter_block_index]["exit"] + + default_device_comm_info = get_communicate_var_info( + program, heter_block_index - 1, + block_var_detail[heter_block_index - 1]["entrance"], + block_var_detail[heter_block_index - 1]["exit"]) + comm_info = get_communicate_var_info(program, heter_block_index, + entrance_var, exit_var) + + # create reshape op + for i in range(len(entrance_var)): 
+ insert_reshape_op( + program, + program.global_block(), first_op_idx, entrance_var[i], + default_device_comm_info["output_var_reshape_name"][i], + [-1, default_device_comm_info["output_var_reshape_dim"][i]]) + first_op_idx += 1 + + # create concat op + insert_send_concat_op( + program, + program.global_block(), first_op_idx, + default_device_comm_info["output_var_reshape_name"], + default_device_comm_info["block_output_var_name"], + [-1, sum(default_device_comm_info["output_var_reshape_dim"])]) + first_op_idx += 1 + + # create send op + send_input_vars = [ + program.global_block().vars[default_device_comm_info[ + "block_output_var_name"]] + ] + + get_type_var_name = comm_info["output_var_reshape_name"][0].split( + ".output_reshape@Heter")[0] + get_type_var = program.global_block().vars[get_type_var_name] + + program.global_block().create_var( + name=comm_info["block_output_var_name"], + shape=(-1, sum(comm_info["output_var_reshape_dim"])), + dtype=get_type_var.dtype, + type=get_type_var.type) + + recv_vars = [ + program.global_block().vars[comm_info["block_output_var_name"]] + ] + + program.global_block()._insert_op( + index=first_op_idx, + type="send_and_recv", + inputs={"X": send_input_vars}, + outputs={"Out": recv_vars}, + attrs={ + "send_var_name": default_device_comm_info["block_output_var_name"], + "recv_var_name": comm_info["block_output_var_name"], + "endpoint": heter_worker_endpoint, + "trainer_id": config.get_role_id(), + RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE + }) + first_op_idx += 1 + + # recv + # create slice op + insert_recv_slice_op( + program, + program.global_block(), first_op_idx, + comm_info["block_output_var_name"], + (-1, sum(comm_info["output_var_reshape_dim"])), get_type_var.dtype, + get_type_var.type, comm_info["output_var_reshape_name"], [ + (-1, comm_info["output_var_reshape_dim"][i]) + for i in range(len(comm_info["output_var_reshape_dim"])) + ]) + + first_op_idx += len(comm_info["output_var_reshape_dim"]) + + # create reshape op 
+ for i in range(len(comm_info["output_var_reshape_name"])): + var_name = comm_info["output_var_reshape_name"][i].split( + ".output_reshape@Heter")[0] + insert_reshape_op( + program, + program.global_block(), + first_op_idx, + comm_info["output_var_reshape_name"][i], + var_name, ) + first_op_idx += 1 + + +def remove_trainer_send_op(program, config, heter_block_index, + block_var_detaile): + # if trainer do FF->BP->SEND, it has follow vars: var, var@GRAD + # if trainer only do SEND, it has one var: var@GRAD + # Delete Send op ,if trainer doesn't has pair var (var<->var@GRAD) + persistables = block_var_detaile[heter_block_index]["persistables"] + need_remove_send_op = [] + need_remove_grad_var = [] + for op in find_send_op(program): + input_list, _ = find_op_input_output(program, + program.global_block(), op) + for var_name in input_list: + origin_var_name = var_name.split("@GRAD")[0] + if origin_var_name in persistables: + need_remove_send_op.append(op) + need_remove_grad_var.append(var_name) + need_remove_send_op = list(set(need_remove_send_op)) + delete_ops(program.global_block(), need_remove_send_op) + for grad_var_name in need_remove_grad_var: + config.remove_var_pair_by_grad(grad_var_name) + + +def add_heter_send_op(program, heter_program, block, block_var_detail): + def _get_send_op_dict(): + send_op_dict = {} + send_op_list = find_send_op(program) + for op in send_op_list: + input_list, _ = find_op_input_output(program, + program.global_block(), op) + for var in input_list: + send_op_dict[var] = op + return send_op_dict + + send_grad_var_list = [] + send_op_dict = _get_send_op_dict() + for persistable_var in block_var_detail["persistables"]: + # check var_name == var@GRAD + if "@GRAD" not in persistable_var: + continue + if "GRAD" != persistable_var.split("@")[-1]: + continue + if persistable_var not in send_op_dict: + continue + block_append_op(program, heter_program, block, + send_op_dict[persistable_var]) + send_grad_var_list.append(persistable_var) + 
return send_grad_var_list + + +def find_send_op(program): + send_op_list = [] + for op in program.global_block().ops: + if op.type == "send": + send_op_list.append(op) + return send_op_list + + +def get_communicate_var_info(program, block_index, entrance_var_list, + exit_var_list): + input_var_reshape_dim = [] + input_var_reshape_name = [] + block_input_var_name = "joint_{}_{}@Heter".format(block_index - 1, + block_index) + output_var_reshape_dim = [] + output_var_reshape_name = [] + block_output_var_name = "joint_{}_{}@Heter".format(block_index, + block_index + 1) + entrance_var_list.sort() + exit_var_list.sort() + # input + # Heter_SERVER_BLOCK_index@JOINT_VAR -> slice -> var@Heter_SERVER_BLOCK@INPUT_RESHAPE_VAR -> reshape -> var + for name in entrance_var_list: + var = program.global_block().vars[name] + shape = var.shape + if len(shape) < 2 or shape[0] != -1: + raise ValueError( + "Variable {} not support heter training. its shape is {}". + format(name, shape)) + recv_var_dim = -1 * reduce(lambda x, y: x * y, shape) + input_var_reshape_dim.append(recv_var_dim) + input_var_reshape_name.append("{}.input_reshape@Heter".format(name)) + + # output + # var -> reshape -> var@Heter_SERVER_BLOCK@INPUT_RESHAPE_VAR -> concat -> Heter_SERVER_BLOCK_index@JOINT_VAR + for var_name in exit_var_list: + var = program.global_block().vars[var_name] + shape = var.shape + if len(shape) < 2 or shape[0] != -1: + raise ValueError( + "Variable {} not support heter training. its shape is {}". 
+ format(var_name, shape)) + send_reshape_dim = -1 * reduce(lambda x, y: x * y, shape) + output_var_reshape_dim.append(send_reshape_dim) + output_var_reshape_name.append("{}.output_reshape@Heter".format( + var_name)) + + info = { + "input_var_reshape_dim": input_var_reshape_dim, + "input_var_reshape_name": input_var_reshape_name, + "block_input_var_name": block_input_var_name, + "output_var_reshape_dim": output_var_reshape_dim, + "output_var_reshape_name": output_var_reshape_name, + "block_output_var_name": block_output_var_name + } + + return info + + +def find_block_joints(program, program_block_ops_list, heter_ops): + block_var_detail = find_entrance_exit_private(program, + program_block_ops_list) + block_var_detail = entrance_exit_check(program, program_block_ops_list, + block_var_detail, heter_ops) + block_var_detail = delete_block_useless_exit( + program, program_block_ops_list, block_var_detail) + return block_var_detail + + +def find_entrance_exit_private(program, program_block_ops_list): + block_var_detail = [] + persistables = [] + for index, block_op_list in enumerate(program_block_ops_list): + block_input, block_output = find_ops_list_input_output(program, + block_op_list) + persistables = screen_persistables( + program, block_input) + screen_persistables(program, block_output) + # find entrance & exit + block_private_vars = list(set(block_input) & set(block_output)) + block_entrance = list(set(block_input) - set(block_private_vars)) + block_exit = list(set(block_output) - set(block_private_vars)) + detail = { + "entrance": block_entrance, + "exit": block_exit, + "private": block_private_vars, + "persistables": persistables + } + block_var_detail.append(detail) + return block_var_detail + + +def entrance_exit_check(program, program_block_ops_list, block_var_detail, + heter_ops): + for index in range(len(block_var_detail) - 1, -1, -1): + if index - 1 < 0: + break + previous_block_exit = block_var_detail[index - 1]["exit"] + previous_block_exit.sort() + 
current_block_entrance = block_var_detail[index]["entrance"] + current_block_entrance.sort() + if previous_block_exit == current_block_entrance: + continue + exist_vars = list( + set(previous_block_exit) & set(current_block_entrance)) + need_add_vars = list(set(current_block_entrance) - set(exist_vars)) + need_add_vars = find_need_var_from_previous_block( + need_add_vars, block_var_detail, index, heter_ops) + + previous_block_private = block_var_detail[index - 1]["private"] + previous_block_entrance = block_var_detail[index - 1]["entrance"] + for var in need_add_vars: + if var not in previous_block_private and var not in previous_block_entrance: + previous_block_entrance.append(var) + previous_block_exit.append(var) + return block_var_detail + + +def find_need_var_from_previous_block(need_add_vars, block_var_detail, + current_index, heter_ops): + # create index_device_map + index_device_map = {} + for index in range(len(block_var_detail)): + index_device_map[index] = DEFAULT_DEVICE + for device in heter_ops: + for index in heter_ops[device].keys(): + index_device_map[index] = device + + pre_index = current_index - 1 + need_ignore_var = [] + + # if need_add_var in current device, no need communicate + for var in need_add_vars: + while (pre_index >= 0): + previous_block_private = block_var_detail[pre_index]["private"] + previous_block_exit = block_var_detail[pre_index]["exit"] + previous_block_entrance = block_var_detail[pre_index]["entrance"] + total_var = previous_block_private + previous_block_exit + previous_block_entrance + if var in total_var: + if index_device_map[current_index] == index_device_map[ + pre_index] and index_device_map[ + current_index] == DEFAULT_DEVICE: + need_ignore_var.append(var) + break + pre_index -= 1 + + need_add_vars = list(set(need_add_vars).difference(set(need_ignore_var))) + return need_add_vars + + +def delete_block_useless_exit(program, program_block_ops_list, + block_var_detail): + for index in range(len(block_var_detail)): + if 
index == len(block_var_detail) - 1: + break + current_block_exit = block_var_detail[index]["exit"] + next_block_entrance = block_var_detail[index + 1]["entrance"] + need_delete_var = [] + for var in current_block_exit: + if var not in next_block_entrance: + need_delete_var.append(var) + + for var in need_delete_var: + current_block_exit.remove(var) + + return block_var_detail + + +def check_op_device(block, device): + for op in block.ops: + op._set_attr('op_device', device) + + +def screen_persistables(program, var_list): + need_remove = [] + for var_name in var_list: + if "@GRAD" in var_name: + origin_var_name = var_name.split("@GRAD")[0] + var = program.global_block().vars[origin_var_name] + else: + var = program.global_block().vars[var_name] + + if fluid.io.is_persistable(var): + need_remove.append(var_name) + + for var_name in need_remove: + var_list.remove(var_name) + return need_remove + + +def insert_reshape_op(program, + block, + index, + var_name, + new_var_name, + new_var_shape=None): + input_var = program.global_block().vars[var_name] + + if new_var_name not in program.global_block().vars: + out = program.global_block().create_var( + name=new_var_name, + shape=new_var_shape, + dtype=input_var.dtype, + type=input_var.type) + else: + out = program.global_block().vars[new_var_name] + new_var_shape = out.shape + + x_shape = program.global_block().create_var( + name="{}.xshape@Heter".format(var_name), dtype=input_var.dtype) + block._insert_op( + index=index, + type="reshape2", + inputs={"X": input_var}, + attrs={'shape': new_var_shape}, + outputs={"Out": out, + "XShape": x_shape}) + + +def insert_send_concat_op(program, block, index, var_name_list, new_var_name, + new_var_shape): + input_var_list = [ + program.global_block().vars[var_name] for var_name in var_name_list + ] + + out = program.global_block().create_var( + name=new_var_name, + shape=new_var_shape, + dtype=input_var_list[0].dtype, + type=input_var_list[0].type) + + block._insert_op( + index=index, 
+ type='concat', + inputs={"X": input_var_list}, + outputs={'Out': [out]}, + attrs={'axis': -1, + 'use_stack': False}) + + +def insert_recv_slice_op(program, block, index, var_name, var_shape, dtype, + type, new_var_name_list, new_var_shape_list): + + if var_name not in program.global_block().vars: + input_var = program.global_block().create_var( + name=var_name, shape=var_shape, dtype=dtype, type=type) + else: + input_var = program.global_block().vars[var_name] + + out_list = [] + for i in range(len(new_var_name_list)): + if new_var_name_list[i] not in program.global_block().vars: + out = program.global_block().create_var( + name=new_var_name_list[i], + shape=new_var_shape_list[i], + dtype=input_var.dtype, + type=input_var.type) + else: + out = program.global_block().vars[new_var_name_list[i]] + out_list.append(out) + + start_index = 0 + end_index = 0 + for i in range(len(new_var_name_list)): + starts = [] + ends = [] + attrs = {'axes': [1]} + end_index += new_var_shape_list[i][1] + starts.append(start_index) + ends.append(end_index) + attrs['starts'] = starts + attrs['ends'] = ends + + block._insert_op( + index=index, + type='slice', + inputs={'Input': input_var}, + attrs=attrs, + outputs={'Out': out_list[i]}) + start_index = end_index + index += 1 + + +def deleter_trainer_useless_var(program): + porgram_useful_var_list = [] + for op in program.global_block().ops: + input_var_list, output_var_list = find_op_input_output( + program, program.global_block(), op) + op_var_list = list(set(input_var_list).union(set(output_var_list))) + porgram_useful_var_list = list( + set(porgram_useful_var_list).union(set(op_var_list))) + + program_useless_var_list = list( + set(get_vars_name_in_block(program.global_block())).difference( + set(porgram_useful_var_list))) + for var in program_useless_var_list: + program.global_block()._remove_var(var) + return program_useless_var_list + + +def block_append_op(program, origin_program, block, op): + inputs = 
_get_input_map_from_op(origin_program.global_block().vars, op) + for key, varlist in six.iteritems(inputs): + if not isinstance(varlist, list): + varlist = [varlist] + for var in varlist: + if var.name not in program.global_block().vars: + program.global_block()._clone_variable(var) + + outputs = _get_output_map_from_op(origin_program.global_block().vars, op) + for key, varlist in six.iteritems(outputs): + if not isinstance(varlist, list): + varlist = [varlist] + for var in varlist: + if var.name not in program.global_block().vars: + program.global_block()._clone_variable(var) + + if "_grad" not in op.type: + # for forward op + return block.append_op( + type=op.type, inputs=inputs, outputs=outputs, attrs=op.all_attrs()) + else: + # for grad op + op_desc = op.desc + op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName() + backward = core.op_proto_and_checker_maker.OpRole.Backward + device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName() + + # append grad op + new_op_desc = block.desc.append_op() + new_op_desc.copy_from(op_desc) + new_op_desc._set_attr(op_role_attr_name, backward) + + # set device gard + if op.desc.has_attr(device_attr_name): + op_device = op_desc.attr(device_attr_name) + new_op_desc._set_attr(device_attr_name, op_device) + block._sync_with_cpp() + + +def add_vars_by_op_map(var_map, program): + for key, varlist in six.iteritems(var_map): + if not isinstance(varlist, list): + varlist = [varlist] + for i in range(len(varlist)): + var = varlist[i] + if var.name not in program.global_block().vars: + program.global_block()._clone_variable(var) + + +def add_vars_by_var_list(var_name_list, origin_program, program): + for var_name in var_name_list: + if var_name not in program.global_block().vars: + var = origin_program.global_block().vars[var_name] + program.global_block()._clone_variable(var) + + +def get_varlist_from_op_map(var_map): + var_list = [] + for key, varlist in six.iteritems(var_map): + if not isinstance(varlist, 
list): + varlist = [varlist] + for i in range(len(varlist)): + var = varlist[i] + var_list.append(var.name) + return var_list + + +def find_ops_list_input_output(program, ops_list): + input_var_list = [] + output_var_list = [] + for op in ops_list: + inputs = _get_input_map_from_op(program.global_block().vars, op) + input_var_list += get_varlist_from_op_map(inputs) + outputs = _get_output_map_from_op(program.global_block().vars, op) + output_var_list += get_varlist_from_op_map(outputs) + + input_var_list = list(set(input_var_list)) + output_var_list = list(set(output_var_list)) + return input_var_list, output_var_list + + +def find_op_input_output(program, block, op): + input_var_list = [] + output_var_list = [] + inputs = _get_input_map_from_op(block.vars, op) + input_var_list += get_varlist_from_op_map(inputs) + outputs = _get_output_map_from_op(block.vars, op) + output_var_list += get_varlist_from_op_map(outputs) + input_var_list = list(set(input_var_list)) + output_var_list = list(set(output_var_list)) + return input_var_list, output_var_list + + +def get_vars_name_in_block(block): + vars_list = block.vars.keys() + vars_name_list = [var_name for var_name in vars_list] + return vars_name_list + + +def is_same_op(op1, op2): + if str(op1) != str(op2): + return False + return True + + +def _get_input_map_from_op(varmap, op): + """Returns a dict from op input name to the vars in varmap.""" + iomap = collections.OrderedDict() + for key in op.input_names: + vars = [] + for varname in op.input(key): + if varname == "@EMPTY@": + continue + if "lod_tensor_blocking_queue" in varname: + continue + vars.append(varmap[varname]) + if len(vars) == 1: + iomap[key] = vars[0] + else: + iomap[key] = vars + return iomap + + +def _get_output_map_from_op(varmap, op): + """Returns a dict from op output name to the vars in varmap.""" + iomap = collections.OrderedDict() + for key in op.output_names: + vars = [] + for varname in op.output(key): + if varname == "@EMPTY@": + continue + if 
"lod_tensor_blocking_queue" in varname: + continue + vars.append(varmap[varname]) + if len(vars) == 1: + iomap[key] = vars[0] + else: + iomap[key] = vars + return iomap + + +def delete_same_ops(block, ops): + for op in ops: + try: + for origin_op in block.ops: + if is_same_op(origin_op, op): + idx = list(block.ops).index(origin_op) + block._remove_op(idx) + break + except Exception as e: + print(e) diff --git a/python/paddle/fluid/tests/unittests/ctr_dataset_reader.py b/python/paddle/fluid/tests/unittests/ctr_dataset_reader.py index fe7513ae842..863c001f226 100644 --- a/python/paddle/fluid/tests/unittests/ctr_dataset_reader.py +++ b/python/paddle/fluid/tests/unittests/ctr_dataset_reader.py @@ -17,8 +17,9 @@ from __future__ import print_function import os import logging import tarfile - +import tempfile import random +import warnings import paddle import paddle.fluid.incubate.data_generator as data_generator @@ -57,7 +58,7 @@ def load_dnn_input_record(sent): def load_lr_input_record(sent): res = [] for _ in [x.split(':') for x in sent.split()]: - res.append(int(_[0])) + res.append(int(_[0]) % 10000) return res @@ -120,9 +121,62 @@ def prepare_data(): lr_input_dim = res[1] logger.info('dnn input dim: %d' % dnn_input_dim) logger.info('lr input dim: %d' % lr_input_dim) + return dnn_input_dim, lr_input_dim, train_file_path +def gen_fake_line(dnn_data_num=7, + dnn_data_range=1e5, + lr_data_num=5, + lr_data_range=1e5): + line = "" + + # for deep data + for index in range(dnn_data_num): + data = str(random.randint(0, dnn_data_range - 1)) + if index < dnn_data_num - 1: + data += " " + line += data + line += "\t" + + # for wide data + for index in range(lr_data_num): + data = str(random.randint(0, lr_data_range - 1)) + ":" + str(1) + if index < lr_data_num - 1: + data += " " + line += data + line += "\t" + + # for label + line += str(random.randint(0, 1)) + line += "\n" + return line + + +def prepare_fake_data(file_nums=8, file_lines=1000): + """ + Create fake data with same 
type as avazu_ctr_data + """ + file_dir = tempfile.mkdtemp() + warnings.warn("Fake data write in {}".format(file_dir)) + for file_index in range(file_nums): + with open( + os.path.join(file_dir, + "ctr_train_data_part_{}".format(file_index)), + 'w+') as fin: + file_str = "" + for line_index in range(file_lines): + file_str += gen_fake_line() + fin.write(file_str) + warnings.warn("Write done ctr_train_data_part_{}".format( + file_index)) + + file_list = [os.path.join(file_dir, x) for x in os.listdir(file_dir)] + assert len(file_list) == file_nums + + return file_list + + if __name__ == "__main__": pairwise_reader = DatasetCtrReader() pairwise_reader.run_from_stdin() diff --git a/python/paddle/fluid/tests/unittests/dist_fleet_heter_ctr.py b/python/paddle/fluid/tests/unittests/dist_fleet_heter_ctr.py new file mode 100644 index 00000000000..0de898d6dde --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dist_fleet_heter_ctr.py @@ -0,0 +1,220 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""
+Distribute CTR model for test fleet api
+"""
+
+from __future__ import print_function
+
+import shutil
+import tempfile
+import time
+
+import paddle
+import paddle.fluid as fluid
+import os
+import numpy as np
+
+import ctr_dataset_reader
+from test_dist_fleet_heter_base import runtime_main, FleetDistHeterRunnerBase
+from dist_fleet_ctr import TestDistCTR2x2, fake_ctr_reader
+from paddle.distributed.fleet.base.util_factory import fleet_util
+
+# Fix seed for test
+fluid.default_startup_program().random_seed = 1
+fluid.default_main_program().random_seed = 1
+
+
+class TestHeterPsCTR2x2(FleetDistHeterRunnerBase):
+    """
+    For test CTR model, using Fleet api
+
+    Provides the network and the two training loops that the base runner
+    (FleetDistHeterRunnerBase) leaves abstract.
+    """
+
+    def net(self, args, batch_size=4, lr=0.01):
+        """
+        network definition
+
+        Args:
+            args: parsed command-line arguments; only args.reader is read here
+            batch_size(int): the size of mini-batch for training
+                (not referenced by this implementation)
+            lr(float): learning rate of training
+                (not referenced by this implementation)
+        Returns:
+            avg_cost: LoDTensor of cost.
+
+        Side effects: sets self.feeds, self.train_file_path, self.avg_cost,
+        self.predict and, when args.reader == "pyreader", self.reader.
+        """
+        dnn_input_dim, lr_input_dim = int(1e5), int(1e5)
+
+        dnn_data = fluid.layers.data(
+            name="dnn_data",
+            shape=[-1, 1],
+            dtype="int64",
+            lod_level=1,
+            append_batch_size=False)
+        lr_data = fluid.layers.data(
+            name="lr_data",
+            shape=[-1, 1],
+            dtype="int64",
+            lod_level=1,
+            append_batch_size=False)
+        label = fluid.layers.data(
+            name="click",
+            shape=[-1, 1],
+            dtype="float32",
+            lod_level=0,
+            append_batch_size=False)
+
+        datas = [dnn_data, lr_data, label]
+
+        # The PyReader is only created for the pyreader mode; the dataset
+        # mode feeds through self.feeds instead (see do_dataset_training).
+        if args.reader == "pyreader":
+            self.reader = fluid.io.PyReader(
+                feed_list=datas,
+                capacity=64,
+                iterable=False,
+                use_double_buffer=False)
+
+        # build dnn model
+        dnn_layer_dims = [128, 64, 32, 1]
+        dnn_embedding = fluid.layers.embedding(
+            is_distributed=False,
+            input=dnn_data,
+            size=[dnn_input_dim, dnn_layer_dims[0]],
+            param_attr=fluid.ParamAttr(
+                name="deep_embedding",
+                initializer=fluid.initializer.Constant(value=0.01)),
+            is_sparse=True)
+        dnn_pool = fluid.layers.sequence_pool(
+            input=dnn_embedding, pool_type="sum")
+        dnn_out = dnn_pool
+
+        # build lr model
+        lr_embbding = fluid.layers.embedding(
+            is_distributed=False,
+            input=lr_data,
+            size=[lr_input_dim, 1],
+            param_attr=fluid.ParamAttr(
+                name="wide_embedding",
+                initializer=fluid.initializer.Constant(value=0.01)),
+            is_sparse=True)
+        lr_pool = fluid.layers.sequence_pool(input=lr_embbding, pool_type="sum")
+
+        # NOTE(review): ops created under device_guard("gpu") are presumably
+        # the ones split off to the heter trainer -- confirm against the
+        # heter trainer pass.
+        with fluid.device_guard("gpu"):
+            # Stack the remaining dense layers (64, 32, 1) on the pooled
+            # deep embedding.
+            for i, dim in enumerate(dnn_layer_dims[1:]):
+                fc = fluid.layers.fc(
+                    input=dnn_out,
+                    size=dim,
+                    act="relu",
+                    param_attr=fluid.ParamAttr(
+                        initializer=fluid.initializer.Constant(value=0.01)),
+                    name='dnn-fc-%d' % i)
+                dnn_out = fc
+
+            # Join the deep and wide branches and classify into 2 classes.
+            merge_layer = fluid.layers.concat(input=[dnn_out, lr_pool], axis=1)
+            label = fluid.layers.cast(label, dtype="int64")
+            predict = fluid.layers.fc(input=merge_layer, size=2, act='softmax')
+
+            cost = fluid.layers.cross_entropy(input=predict, label=label)
+            avg_cost = fluid.layers.mean(x=cost)
+            fluid.layers.Print(avg_cost, message="avg_cost")
+
+        self.feeds = datas
+        # Placeholder names only; do_dataset_training builds its own file
+        # list with ctr_dataset_reader.prepare_fake_data().
+        self.train_file_path = ["fake1", "fake2"]
+        self.avg_cost = avg_cost
+        self.predict = predict
+
+        return avg_cost
+
+    def check_model_right(self, dirname):
+        """
+        Parse the program saved at <dirname>/__model__ and write its text
+        form to <dirname>/__model__.proto.
+
+        No assertion is made on the content; the check is only that reading
+        and parsing do not raise.
+
+        Args:
+            dirname(str): directory holding the saved inference model
+        """
+        model_filename = os.path.join(dirname, "__model__")
+
+        with open(model_filename, "rb") as f:
+            program_desc_str = f.read()
+
+        program = fluid.Program.parse_from_string(program_desc_str)
+        with open(os.path.join(dirname, "__model__.proto"), "w") as wn:
+            wn.write(str(program))
+
+    def do_pyreader_training(self, fleet):
+        """
+        do training using the PyReader created in net(), fed by
+        fake_ctr_reader() in batches of 4, for a single epoch
+        Args:
+            fleet(Fleet api): the fleet object of Parameter Server, define distribute training role
+        """
+
+        exe = fluid.Executor(fluid.CPUPlace())
+        fleet.init_worker()
+        exe.run(fluid.default_startup_program())
+        batch_size = 4
+        train_reader = paddle.batch(fake_ctr_reader(), batch_size=batch_size)
+        self.reader.decorate_sample_list_generator(train_reader)
+
+        for epoch_id in range(1):
+            self.reader.start()
+            try:
+                pass_start = time.time()
+                # Runs until the reader is exhausted, which surfaces as
+                # EOFException and is the only way out of this loop.
+                while True:
+                    exe.run(program=fluid.default_main_program())
+
+                pass_time = time.time() - pass_start
+            except fluid.core.EOFException:
+                self.reader.reset()
+
+        fleet.stop_worker()
+
+    def do_dataset_training(self, fleet):
+        """
+        do training using a dataset built from generated fake CTR files,
+        for a single epoch; when the SAVE_MODEL environment variable is "1"
+        the inference model is also saved to a temp dir, parsed back and
+        removed
+        Args:
+            fleet(Fleet api): the fleet object of Parameter Server, define distribute training role
+        """
+        train_file_list = ctr_dataset_reader.prepare_fake_data()
+
+        exe = fluid.Executor(fluid.CPUPlace())
+
+        fleet.init_worker()
+        exe.run(fluid.default_startup_program())
+
+        thread_num = 1
+        batch_size = 128
+        # Each worker trains on its own shard of the generated files.
+        filelist = fleet_util.get_file_shard(train_file_list)
+        print("filelist: {}".format(filelist))
+
+        # config dataset
+        dataset = paddle.distributed.fleet.DatasetFactory().create_dataset()
+        dataset.set_batch_size(batch_size)
+        dataset.set_use_var(self.feeds)
+        # Raw lines are piped through ctr_dataset_reader.py, so that script
+        # must be resolvable from the working directory.
+        pipe_command = 'python ctr_dataset_reader.py'
+        dataset.set_pipe_command(pipe_command)
+
+        dataset.set_filelist(filelist)
+        dataset.set_thread(thread_num)
+
+        for epoch_id in range(1):
+            pass_start = time.time()
+            dataset.set_filelist(filelist)
+            exe.train_from_dataset(
+                program=fluid.default_main_program(),
+                dataset=dataset,
+                fetch_list=[self.avg_cost],
+                fetch_info=["cost"],
+                print_period=2,
+                debug=int(os.getenv("Debug", "0")))
+            pass_time = time.time() - pass_start
+            print("do_dataset_training done. using time {}".format(pass_time))
+        if os.getenv("SAVE_MODEL") == "1":
+            model_dir = tempfile.mkdtemp()
+            fleet.save_inference_model(exe, model_dir,
+                                       [feed.name for feed in self.feeds],
+                                       self.avg_cost)
+            self.check_model_right(model_dir)
+            shutil.rmtree(model_dir)
+
+        fleet.stop_worker()
+        print("do_dataset_training stop worker.")
+
+
+if __name__ == "__main__":
+    runtime_main(TestHeterPsCTR2x2)
diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_heter_base.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_heter_base.py
new file mode 100644
index 00000000000..4d744c8299f
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_heter_base.py
@@ -0,0 +1,388 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+"""
+    high level unit test for distribute fleet.
+"""
+
+import os
+import sys
+import subprocess
+
+import six
+import shutil
+import numpy as np
+import argparse
+from contextlib import closing
+import socket
+import time
+import tempfile
+import unittest
+
+import paddle
+import paddle.fluid as fluid
+import paddle.distributed.fleet.base.role_maker as role_maker
+from paddle.distributed.fleet.base.util_factory import fleet_util
+from paddle.distributed.fleet import fleet
+
+__all__ = ['FleetDistHeterRunnerBase', 'TestFleetHeterBase', 'runtime_main']
+
+RUN_STEP = 5
+LEARNING_RATE = 0.01
+# First port of the block handed out to the next test case; 0 means "not
+# configured", in which case free ports are probed instead (see setUp).
+DIST_UT_PORT = 0
+
+
+class FleetDistHeterRunnerBase(object):
+    """
+        run_pserver,run_trainer : after init role, using transpiler split program
+        net : implement by child class, the network of model
+        do training : exe run program
+    """
+
+    def build_role(self, args):
+        """
+        Export the cluster layout described by args as environment
+        variables and build a PaddleCloudRoleMaker.
+
+        Args:
+            args: namespace produced by runtime_main's argument parser
+        Returns:
+            the role maker, also stored on self.role
+        """
+        environs = {}
+        environs["PADDLE_PSERVERS_IP_PORT_LIST"] = args.endpoints
+        environs["PADDLE_TRAINER_ENDPOINTS"] = args.trainer_endpoints
+        environs[
+            "PADDLE_HETER_TRAINER_IP_PORT_LIST"] = args.heter_trainer_endpoints
+        environs["PADDLE_HETER_TRAINER_DEVICE"] = args.heter_trainer_device
+        environs["TRAINING_ROLE"] = args.role.upper()
+        environs["PADDLE_TRAINERS_NUM"] = args.trainers
+        environs["PADDLE_TRAINER_ID"] = args.current_id
+        # Server-like roles additionally publish their own ip/port, taken
+        # from the current_id-th entry of the matching endpoint list.
+        if args.role.upper() == "PSERVER":
+            environs["POD_IP"] = args.endpoints.split(",")[int(
+                args.current_id)].split(":")[0]
+            environs["PADDLE_PORT"] = args.endpoints.split(",")[int(
+                args.current_id)].split(":")[1]
+        elif args.role.upper() == "HETER_TRAINER":
+            environs["POD_IP"] = args.heter_trainer_endpoints.split(",")[int(
+                args.current_id)].split(":")[0]
+            environs["PADDLE_PORT"] = args.heter_trainer_endpoints.split(",")[
+                int(args.current_id)].split(":")[1]
+            environs["FLAGS_selected_gpus"] = args.current_id
+
+        for k, v in environs.items():
+            os.environ[k] = str(v)
+
+        self.role = role_maker.PaddleCloudRoleMaker()
+        return self.role
+
+    def build_strategy(self, args):
+        """Create an a_sync DistributedStrategy, store it on self.strategy
+        and return it. args is not used."""
+        self.strategy = paddle.distributed.fleet.DistributedStrategy()
+        self.strategy.a_sync = True
+
+        return self.strategy
+
+    def build_optimizer(self, avg_cost, strategy):
+        """Wrap SGD(LEARNING_RATE) in the fleet distributed optimizer and
+        minimize avg_cost with it."""
+        optimizer = fluid.optimizer.SGD(LEARNING_RATE)
+        optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
+        optimizer.minimize(avg_cost)
+
+    def run_pserver(self, args):
+        """Initialise and run the fleet server. runtime_main uses this for
+        both the pserver and the heter_trainer role."""
+        fleet.init_server()
+        fleet.run_server()
+
+    def run_dataset_trainer(self, args):
+        """Run the child class's dataset training loop."""
+        out = self.do_dataset_training(fleet)
+
+    def run_pyreader_trainer(self, args):
+        """Run the child class's pyreader training loop."""
+        out = self.do_pyreader_training(fleet)
+
+    def net(self, args, batch_size=4, lr=0.01):
+        """Build the network and return its loss; must be overridden."""
+        raise NotImplementedError(
+            "net should be implemented by child classes.")
+
+    def do_dataset_training(self, fleet):
+        """Train with a dataset; must be overridden."""
+        raise NotImplementedError(
+            "do_dataset_training should be implemented by child classes.")
+
+    def do_pyreader_training(self, fleet):
+        """Train with a pyreader; must be overridden."""
+        raise NotImplementedError(
+            "do_pyreader_training should be implemented by child classes.")
+
+
+class TestFleetHeterBase(unittest.TestCase):
+    """
+        start_pserver,start_trainer : add start cmd to test
+        run_cluster : using multi process to test distribute program
+    """
+
+    def _setup_config(self):
+        """Hook for subclasses to override the defaults set in setUp."""
+        raise NotImplementedError("tests should have _setup_config implemented")
+
+    def tearDown(self):
+        """Print the wall-clock duration of the test case."""
+        t = time.time() - self.startTime
+        print('%s: %.3f' % (self.__class__.__name__, t))
+
+    def setUp(self):
+        """
+        Set the default cluster shape (2 pservers, 2 trainers, 2 heter
+        trainers on localhost), pick their ports, then let the subclass
+        adjust the configuration through _setup_config.
+        """
+        self.startTime = time.time()
+
+        self._mode = "async"
+        self._reader = "pyreader"
+        self._trainers = 2
+        self._pservers = 2
+        self._port_set = set()
+
+        self._heter_device = "gpu"
+
+        global DIST_UT_PORT
+        if DIST_UT_PORT == 0 and os.getenv("PADDLE_DIST_UT_PORT"):
+            DIST_UT_PORT = int(os.getenv("PADDLE_DIST_UT_PORT"))
+
+        if DIST_UT_PORT:
+            # A base port was configured: take six consecutive ports and
+            # advance the counter so the next test case gets a fresh range.
+            print("set begin_port:", DIST_UT_PORT)
+            self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
+                DIST_UT_PORT, DIST_UT_PORT + 1)
+            self._tr_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
+                DIST_UT_PORT + 2, DIST_UT_PORT + 3)
+            self._heter_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
+                DIST_UT_PORT + 4, DIST_UT_PORT + 5)
+            DIST_UT_PORT += 6
+        else:
+            self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
+                self._find_free_port(), self._find_free_port())
+            self._tr_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
+                self._find_free_port(), self._find_free_port())
+            self._heter_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
+                self._find_free_port(), self._find_free_port())
+
+        self._python_interp = sys.executable
+        self._geo_sgd_need_push_nums = 5
+        self._grad_clip_mode = 0
+        self._setup_config()
+
+    def _find_free_port(self):
+        """Return a port the OS reports as free and that this test case has
+        not handed out before."""
+
+        def __free_port():
+            # Binding to port 0 lets the OS choose; the socket is closed
+            # again on leaving the with-block.
+            with closing(socket.socket(socket.AF_INET,
+                                       socket.SOCK_STREAM)) as s:
+                s.bind(('', 0))
+                return s.getsockname()[1]
+
+        while True:
+            port = __free_port()
+            if port not in self._port_set:
+                self._port_set.add(port)
+                return port
+
+    def _start_pserver(self, cmd, required_envs):
+        """
+        Start the two parameter servers.
+
+        Args:
+            cmd(str): command line with one "{}" placeholder for the id
+            required_envs(dict): environment for the child processes
+        Returns:
+            (ps0_proc, ps1_proc, ps0_pipe, ps1_pipe); the pipes are the
+            open stderr log files under the system temp dir.
+        """
+        ps0_cmd, ps1_cmd = cmd.format(0), cmd.format(1)
+
+        ps0_pipe = open(tempfile.gettempdir() + "/ps0_err.log", "wb+")
+        ps1_pipe = open(tempfile.gettempdir() + "/ps1_err.log", "wb+")
+
+        ps0_proc = subprocess.Popen(
+            ps0_cmd.strip().split(" "),
+            stdout=subprocess.PIPE,
+            stderr=ps0_pipe,
+            env=required_envs)
+        ps1_proc = subprocess.Popen(
+            ps1_cmd.strip().split(" "),
+            stdout=subprocess.PIPE,
+            stderr=ps1_pipe,
+            env=required_envs)
+        return ps0_proc, ps1_proc, ps0_pipe, ps1_pipe
+
+    def _start_trainer(self, cmd, required_envs):
+        """
+        Start the two trainers; stdout and stderr go to tr{0,1}_out.log and
+        tr{0,1}_err.log under the system temp dir.
+
+        Returns:
+            (tr0_proc, tr1_proc, tr0_pipe, tr1_pipe); the pipes are the
+            open stderr log files.
+        """
+        tr0_cmd, tr1_cmd = cmd.format(0), cmd.format(1)
+
+        tr0_pipe = open(tempfile.gettempdir() + "/tr0_err.log", "wb+")
+        tr1_pipe = open(tempfile.gettempdir() + "/tr1_err.log", "wb+")
+
+        tr0_out = open(tempfile.gettempdir() + "/tr0_out.log", "wb+")
+        tr1_out = open(tempfile.gettempdir() + "/tr1_out.log", "wb+")
+
+        tr0_proc = subprocess.Popen(
+            tr0_cmd.strip().split(" "),
+            stdout=tr0_out,
+            stderr=tr0_pipe,
+            env=required_envs)
+        tr1_proc = subprocess.Popen(
+            tr1_cmd.strip().split(" "),
+            stdout=tr1_out,
+            stderr=tr1_pipe,
+            env=required_envs)
+
+        return tr0_proc, tr1_proc, tr0_pipe, tr1_pipe
+
+    def _start_heter_trainer(self, cmd, required_envs):
+        """
+        Start the two heter trainers; stdout and stderr go to
+        heter{0,1}_out.log and heter{0,1}_err.log under the system temp dir.
+
+        Returns:
+            (heter0_proc, heter1_proc, heter0_pipe, heter1_pipe); the pipes
+            are the open stderr log files.
+        """
+        heter0_cmd, heter1_cmd = cmd.format(0), cmd.format(1)
+
+        heter0_pipe = open(tempfile.gettempdir() + "/heter0_err.log", "wb+")
+        heter1_pipe = open(tempfile.gettempdir() + "/heter1_err.log", "wb+")
+        heter0_out = open(tempfile.gettempdir() + "/heter0_out.log", "wb+")
+        heter1_out = open(tempfile.gettempdir() + "/heter1_out.log", "wb+")
+
+        heter0_proc = subprocess.Popen(
+            heter0_cmd.strip().split(" "),
+            stdout=heter0_out,
+            stderr=heter0_pipe,
+            env=required_envs)
+        heter1_proc = subprocess.Popen(
+            heter1_cmd.strip().split(" "),
+            stdout=heter1_out,
+            stderr=heter1_pipe,
+            env=required_envs)
+
+        return heter0_proc, heter1_proc, heter0_pipe, heter1_pipe
+
+    def _run_cluster(self, model, envs):
+        """
+        Launch the six-process local cluster for the model script, wait for
+        both trainers to exit, shut everything else down and assert that
+        both trainers exited with status 0.
+
+        Args:
+            model(str): path of the script run by every process
+            envs(dict): environment for the child processes; gains a
+                COVERAGE_FILE entry when WITH_COVERAGE is ON
+        Returns:
+            (0, 0) placeholder; no losses are collected.
+        """
+        env = {'GRAD_CLIP': str(self._grad_clip_mode)}
+        python_path = self._python_interp
+        gloo_path = tempfile.mkdtemp()
+
+        if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
+            envs['COVERAGE_FILE'] = os.getenv('COVERAGE_FILE', '')
+            python_path += " -m coverage run --branch -p"
+        env.update(envs)
+
+        # "{{}}" survives this format() as "{}", which _start_* later fills
+        # with the process id.
+        tr_cmd = "{0} {1} --role trainer --endpoints {2} --trainer_endpoints {3} --current_id {{}} --trainers {4} --mode {5} --geo_sgd_need_push_nums {6} --reader {7} --gloo_path {8} --heter_trainer_endpoints {9} --heter_trainer_device {10}".format(
+            python_path, model, self._ps_endpoints, self._tr_endpoints,
+            self._trainers, self._mode, self._geo_sgd_need_push_nums,
+            self._reader, gloo_path, self._heter_endpoints, self._heter_device)
+
+        ps_cmd = "{0} {1} --role pserver --endpoints {2} --trainer_endpoints {3} --current_id {{}} --trainers {4} --mode {5} --geo_sgd_need_push_nums {6} --reader {7} --gloo_path {8} --heter_trainer_endpoints {9} --heter_trainer_device {10}".format(
+            python_path, model, self._ps_endpoints, self._tr_endpoints,
+            self._trainers, self._mode, self._geo_sgd_need_push_nums,
+            self._reader, gloo_path, self._heter_endpoints, self._heter_device)
+
+        heter_cmd = "{0} {1} --role heter_trainer --endpoints {2} --trainer_endpoints {3} --current_id {{}} --trainers {4} --mode {5} --geo_sgd_need_push_nums {6} --reader {7} --gloo_path {8} --heter_trainer_endpoints {9} --heter_trainer_device {10}".format(
+            python_path, model, self._ps_endpoints, self._tr_endpoints,
+            self._trainers, self._mode, self._geo_sgd_need_push_nums,
+            self._reader, gloo_path, self._heter_endpoints, self._heter_device)
+
+        # Launch the cluster: pservers first, then trainers, then heter
+        # trainers.
+        ps0, ps1, ps0_pipe, ps1_pipe = self._start_pserver(ps_cmd, env)
+        tr0, tr1, tr0_pipe, tr1_pipe = self._start_trainer(tr_cmd, env)
+        heter0, heter1, heter0_pipe, heter1_pipe = self._start_heter_trainer(
+            heter_cmd, env)
+
+        # Wait until trainer process terminate
+        while True:
+            stat0 = tr0.poll()
+            time.sleep(0.1)
+            if stat0 is not None:
+                break
+
+        while True:
+            stat1 = tr1.poll()
+            time.sleep(0.1)
+            if stat1 is not None:
+                break
+
+        tr0_out, tr0_err = tr0.communicate()
+        tr1_out, tr1_err = tr1.communicate()
+        print("tr end communicate")
+
+        # Each trainer reports its own exit status; both are asserted below.
+        tr0_ret = tr0.returncode
+        tr1_ret = tr1.returncode
+        print("tr get returncode: {}".format(tr0_ret))
+        if tr0_ret != 0:
+            print(
+                "========================Error tr0_err begin==========================="
+            )
+            os.system("cat {}".format(tempfile.gettempdir() + "/tr0_err.log"))
+            print(
+                "========================Error tr0_err end==========================="
+            )
+
+        if tr1_ret != 0:
+            print(
+                "========================Error tr1_err begin==========================="
+            )
+            os.system("cat {}".format(tempfile.gettempdir() + "/tr1_err.log"))
+            print(
+                "========================Error tr1_err end==========================="
+            )
+
+        # Release everything BEFORE asserting: the pservers and heter
+        # trainers never exit on their own, so a failed assertion must not
+        # leave them (or the gloo temp dir) behind.
+        tr0_pipe.close()
+        tr1_pipe.close()
+        ps0_pipe.close()
+        ps1_pipe.close()
+        heter0_pipe.close()
+        heter1_pipe.close()
+
+        ps0.terminate()
+        ps1.terminate()
+        heter0.terminate()
+        heter1.terminate()
+
+        shutil.rmtree(gloo_path)
+
+        self.assertEqual(tr0_ret, 0, "something wrong in tr0, please check")
+        self.assertEqual(tr1_ret, 0, "something wrong in tr1, please check")
+        return 0, 0
+
+    def check_with_place(self,
+                         model_file,
+                         delta=1e-3,
+                         check_error_log=False,
+                         need_envs={}):
+        """
+        Run model_file on the local cluster with a minimal environment.
+
+        Args:
+            model_file(str): script passed to _run_cluster
+            delta(float): not used by this implementation
+            check_error_log(bool): when True, enable verbose glog output
+                (GLOG_v=3, logged to stderr)
+            need_envs(dict): extra environment entries; only read here
+        """
+        required_envs = {
+            "PATH": os.getenv("PATH", ""),
+            "PYTHONPATH": os.getenv("PYTHONPATH", ""),
+            "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
+            "FLAGS_rpc_deadline": "5000",  # 5sec to fail fast
+            "http_proxy": ""
+        }
+
+        required_envs.update(need_envs)
+
+        if check_error_log:
+            required_envs["GLOG_v"] = "3"
+            required_envs["GLOG_logtostderr"] = "1"
+
+        tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)
+
+
+def runtime_main(test_class):
+    """
+    Entry point of every cluster process: parse the command line, build
+    role, strategy, network and optimizer through test_class, then either
+    serve (pserver / heter_trainer) or train (trainer).
+
+    Args:
+        test_class: FleetDistHeterRunnerBase subclass to instantiate
+    """
+    parser = argparse.ArgumentParser(description='Run Fleet test.')
+    parser.add_argument(
+        '--role',
+        type=str,
+        required=True,
+        choices=['pserver', 'trainer', 'heter_trainer'])
+    parser.add_argument('--endpoints', type=str, required=False, default="")
+    parser.add_argument(
+        '--trainer_endpoints', type=str, required=False, default="")
+    parser.add_argument(
+        '--heter_trainer_endpoints', type=str, required=False, default="")
+    parser.add_argument(
+        '--heter_trainer_device', type=str, required=False, default="gpu")
+    parser.add_argument('--gloo_path', type=str, required=False, default="")
+    parser.add_argument('--current_id', type=int, required=False, default=0)
+    parser.add_argument('--trainers', type=int, required=False, default=1)
+    parser.add_argument('--mode', type=str, required=False, default='async')
+    parser.add_argument(
+        '--geo_sgd_need_push_nums', type=int, required=False, default=2)
+    parser.add_argument('--reader', type=str, required=False, default='dataset')
+    args = parser.parse_args()
+
+    model = test_class()
+    role = model.build_role(args)
+    fleet.init(role)
+    strategy = model.build_strategy(args)
+    avg_cost = model.net(args)
+    model.build_optimizer(avg_cost, strategy)
+    fleet_util._set_strategy(strategy)
+    fleet_util._set_role_maker(role)
+
+    # Heter trainers are driven through the same server loop as pservers.
+    if args.role == "pserver" or args.role == "heter_trainer":
+        model.run_pserver(args)
+    else:
+        if args.reader == "dataset":
+            model.run_dataset_trainer(args)
+        else:
+            model.run_pyreader_trainer(args)
diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_heter_ctr.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_heter_ctr.py
new file mode 100644
index 00000000000..c3ffd50dc8d
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_heter_ctr.py
@@ -0,0 +1,56 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import os
+import unittest
+import tempfile
+from test_dist_fleet_heter_base import TestFleetHeterBase
+
+
+class TestDistHeterDatasetAsync2x2(TestFleetHeterBase):
+    """Async heter parameter-server run of the CTR model, fed by a dataset."""
+
+    def _setup_config(self):
+        """Select async mode with the dataset reader."""
+        self._mode, self._reader = "async", "dataset"
+
+    def check_with_place(self,
+                         model_file,
+                         delta=1e-3,
+                         check_error_log=False,
+                         need_envs={}):
+        """
+        Run model_file on the local cluster with a minimal environment that
+        also pins CPU_NUM to 1.
+
+        Args:
+            model_file(str): script passed to _run_cluster
+            delta(float): not used by this implementation
+            check_error_log(bool): when True, enable verbose glog output
+                (GLOG_v=4, logged to stderr)
+            need_envs(dict): extra environment entries; only read here
+        """
+        cluster_env = dict(
+            PATH=os.getenv("PATH", ""),
+            PYTHONPATH=os.getenv("PYTHONPATH", ""),
+            LD_LIBRARY_PATH=os.getenv("LD_LIBRARY_PATH", ""),
+            FLAGS_rpc_deadline="5000",  # 5sec to fail fast
+            http_proxy="",
+            CPU_NUM="1")
+        cluster_env.update(need_envs)
+
+        # Logging switches are applied last so they win over need_envs.
+        if check_error_log:
+            cluster_env.update(GLOG_v="4", GLOG_logtostderr="1")
+
+        _tr0_result, _tr1_result = self._run_cluster(model_file, cluster_env)
+
+    def test_dist_train(self):
+        """Run dist_fleet_heter_ctr.py on the cluster with error logging on."""
+        model_script = "dist_fleet_heter_ctr.py"
+        self.check_with_place(
+            model_script, check_error_log=True, delta=1e-5)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_heter_program.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_heter_program.py
new file mode 100644
index 00000000000..33690396612
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_heter_program.py
@@ -0,0 +1,139 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import unittest
+import paddle
+import os
+import math
+import paddle.fluid as fluid
+import paddle.distributed.fleet.base.role_maker as role_maker
+from paddle.distributed.fleet.base.util_factory import fleet_util
+from paddle.distributed.fleet import fleet
+
+
+class TestDistFleetHeterProgram(unittest.TestCase):
+    """
+    Builds a network whose layers alternate between device_guard("gpu") and
+    device_guard("cpu") and minimizes it as heter trainer 0. The test passes
+    when program construction completes without raising; nothing is run.
+    """
+
+    def build_role(self):
+        """
+        Export a fixed 2 pserver / 2 trainer / 2 heter trainer layout on
+        localhost, with this process as heter trainer 0, and build the role
+        maker from it.
+
+        Returns:
+            the role maker, also stored on self.role
+        """
+        environs = {}
+        environs[
+            "PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36012,127.0.0.1:36013"
+        environs["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36014,127.0.0.1:36015"
+        environs[
+            "PADDLE_HETER_TRAINER_IP_PORT_LIST"] = "127.0.0.1:36016,127.0.0.1:36017"
+        environs["PADDLE_HETER_TRAINER_DEVICE"] = "gpu"
+        environs["TRAINING_ROLE"] = "HETER_TRAINER"
+        environs["PADDLE_TRAINERS_NUM"] = 2
+        environs["PADDLE_TRAINER_ID"] = 0
+        # Own endpoint: the first entry of the heter trainer list above.
+        environs["POD_IP"] = "127.0.0.1"
+        environs["PADDLE_PORT"] = "36016"
+        environs["FLAGS_selected_gpus"] = 0
+
+        # NOTE: these are written to the process environment and are not
+        # restored afterwards.
+        for k, v in environs.items():
+            os.environ[k] = str(v)
+
+        self.role = role_maker.PaddleCloudRoleMaker()
+        return self.role
+
+    def build_strategy(self):
+        """Create an a_sync DistributedStrategy, store it on self.strategy
+        and return it."""
+        self.strategy = paddle.distributed.fleet.DistributedStrategy()
+        self.strategy.a_sync = True
+        return self.strategy
+
+    def build_input(self):
+        """
+        Declare the input variables.
+
+        Returns:
+            list: [dense_input, C1 .. C26, label]; build_net relies on the
+            dense input being first and the label last.
+        """
+        dense_input = fluid.layers.data(
+            name="dense_input", shape=[10], dtype="float32")
+
+        sparse_input_ids = [
+            fluid.layers.data(
+                name="C" + str(i), shape=[1], lod_level=1, dtype="int64")
+            for i in range(1, 27)
+        ]
+
+        label = fluid.layers.data(name="label", shape=[1], dtype="float32")
+
+        inputs = [dense_input] + sparse_input_ids + [label]
+        return inputs
+
+    def build_net(self, inputs):
+        """
+        Build embeddings plus four fc layers, switching device_guard scope
+        between "gpu" and "cpu" from one layer to the next.
+
+        Args:
+            inputs(list): as returned by build_input
+        Returns:
+            the summed cross-entropy cost (reduce_sum, despite the name
+            avg_cost)
+        """
+
+        def embedding_layer(input):
+            # Every sparse slot uses the same parameter name, so all slots
+            # refer to one "SparseFeatFactors" table.
+            return fluid.layers.embedding(
+                input=input,
+                is_sparse=True,
+                size=[100001, 10],
+                param_attr=fluid.ParamAttr(
+                    name="SparseFeatFactors",
+                    initializer=fluid.initializer.Uniform()), )
+
+        # inputs[1:-1] are the sparse slots, inputs[0:1] the dense input.
+        sparse_embed_seq = list(map(embedding_layer, inputs[1:-1]))
+
+        concated = fluid.layers.concat(sparse_embed_seq + inputs[0:1], axis=1)
+
+        with fluid.device_guard("gpu"):
+            fc1 = fluid.layers.fc(
+                input=concated,
+                size=400,
+                act="relu",
+                param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
+                    scale=1 / math.sqrt(concated.shape[1]))),
+                name="fc1")
+
+        with fluid.device_guard("cpu"):
+            fc2 = fluid.layers.fc(input=fc1,
+                                  size=400,
+                                  act="relu",
+                                  param_attr=fluid.ParamAttr(
+                                      initializer=fluid.initializer.Normal(
+                                          scale=1 / math.sqrt(fc1.shape[1]))),
+                                  name="fc2")
+
+        with fluid.device_guard("gpu"):
+            fc3 = fluid.layers.fc(input=fc2,
+                                  size=400,
+                                  act="relu",
+                                  param_attr=fluid.ParamAttr(
+                                      initializer=fluid.initializer.Normal(
+                                          scale=1 / math.sqrt(fc2.shape[1]))),
+                                  name="fc3")
+
+        with fluid.device_guard("cpu"):
+            predict = fluid.layers.fc(
+                input=fc3,
+                size=2,
+                act="softmax",
+                param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
+                    scale=1 / math.sqrt(fc3.shape[1]))), )
+
+        with fluid.device_guard("gpu"):
+            labels = fluid.layers.cast(inputs[-1], dtype="int64")
+            cost = fluid.layers.cross_entropy(input=predict, label=labels)
+            avg_cost = fluid.layers.reduce_sum(cost)
+
+        return avg_cost
+
+    def build_optimizer(self, avg_cost, strategy):
+        """Wrap SGD(1e-2) in the fleet distributed optimizer and minimize
+        avg_cost with it."""
+        optimizer = fluid.optimizer.SGD(1e-2)
+        optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
+        optimizer.minimize(avg_cost)
+
+    def test(self):
+        """Build role, strategy, inputs, network and optimizer in order; the
+        test has no assertions and fails only if one of the steps raises."""
+        role = self.build_role()
+        fleet.init(role)
+        strategy = self.build_strategy()
+        inputs = self.build_input()
+        avg_cost = self.build_net(inputs)
+        self.build_optimizer(avg_cost, strategy)
+
+
+if __name__ == "__main__":
+    unittest.main()
-- 
GitLab