From 9421cf4605b570bb991bb0148d6f6e023c2eaf8e Mon Sep 17 00:00:00 2001 From: willzhang4a58 Date: Mon, 9 Oct 2017 17:04:41 +0800 Subject: [PATCH] make control module Former-commit-id: c259b77a7d36d7e796c219844c561a2c65311748 --- examples/mnist/train/resource.prototxt | 5 +- .../core/comm_network/ctrl_comm_network.cpp | 106 ------------- .../epoll/epoll_data_comm_network.cpp | 44 ++++-- .../epoll/epoll_data_comm_network.h | 7 +- .../rpc => control}/control.proto | 50 ++++-- .../{comm_network/rpc => control}/ctrl_call.h | 8 +- oneflow/core/control/ctrl_client.cpp | 146 ++++++++++++++++++ .../ctrl_client.h} | 59 +++---- .../rpc => control}/ctrl_server.cpp | 107 +++++++------ .../rpc => control}/ctrl_server.h | 24 +-- .../rpc => control}/ctrl_service.cpp | 2 +- .../rpc => control}/ctrl_service.h | 17 +- oneflow/core/job/resource.proto | 5 +- oneflow/core/job/runtime.cpp | 15 +- oneflow/core/job/runtime_context.cpp | 5 +- oneflow/core/job/runtime_context.h | 6 +- oneflow/core/job/scheduler.cpp | 37 +++-- oneflow/core/persistence/snapshot.h | 2 +- scripts/run_scheduler.sh | 2 +- 19 files changed, 378 insertions(+), 269 deletions(-) delete mode 100644 oneflow/core/comm_network/ctrl_comm_network.cpp rename oneflow/core/{comm_network/rpc => control}/control.proto (53%) rename oneflow/core/{comm_network/rpc => control}/ctrl_call.h (88%) create mode 100644 oneflow/core/control/ctrl_client.cpp rename oneflow/core/{comm_network/ctrl_comm_network.h => control/ctrl_client.h} (56%) rename oneflow/core/{comm_network/rpc => control}/ctrl_server.cpp (76%) rename oneflow/core/{comm_network/rpc => control}/ctrl_server.h (61%) rename oneflow/core/{comm_network/rpc => control}/ctrl_service.cpp (96%) rename oneflow/core/{comm_network/rpc => control}/ctrl_service.h (80%) diff --git a/examples/mnist/train/resource.prototxt b/examples/mnist/train/resource.prototxt index cd04070415..2a87610725 100644 --- a/examples/mnist/train/resource.prototxt +++ b/examples/mnist/train/resource.prototxt @@ -1,15 +1,14 @@ machine { addr: "192.168.1.11" + port: 9000 name: "192.168.1.11" } machine { addr: "192.168.1.13" + port: 9000 name: "192.168.1.13" } device_num_per_machine: 4 device_type: kCPU - -port_min: 9000 -port_max: 9100 diff --git a/oneflow/core/comm_network/ctrl_comm_network.cpp b/oneflow/core/comm_network/ctrl_comm_network.cpp deleted file mode 100644 index 383bb50992..0000000000 --- a/oneflow/core/comm_network/ctrl_comm_network.cpp +++ /dev/null @@ -1,106 +0,0 @@ -#include "oneflow/core/comm_network/ctrl_comm_network.h" -#include "oneflow/core/job/runtime_context.h" - -namespace oneflow { - -namespace { - -const int32_t max_retry_num = 60; -const int64_t sleep_seconds = 10; - -} // namespace - -CtrlCommNet::CtrlCommNet(uint16_t port) { - ctrl_server_.reset(new CtrlServer(RuntimeCtx::Singleton()->GetThisAddr() + ":" - + std::to_string(port))); - stubs_.reserve(JobDesc::Singleton()->TotalMachineNum()); - for (int64_t i = 0; i < JobDesc::Singleton()->TotalMachineNum(); ++i) { - stubs_.push_back(CtrlService::NewStub(RuntimeCtx::Singleton()->GetAddr(i) - + ":" + std::to_string(port))); - } - int32_t retry_idx = 0; - for (; retry_idx < max_retry_num; ++retry_idx) { - grpc::ClientContext client_ctx; - AddWorkerRequest request; - request.set_worker_addr(RuntimeCtx::Singleton()->GetThisAddr()); - AddWorkerResponse response; - grpc::Status st = - GetMasterStub()->AddWorker(&client_ctx, request, &response); - if (st.error_code() == grpc::StatusCode::OK) { - LOG(INFO) << "AddWorker Successful at " << retry_idx << " times"; - break; - } else if (st.error_code() == grpc::StatusCode::UNAVAILABLE) { - LOG(INFO) << "AddWorker Failed at " << retry_idx << " times"; - std::this_thread::sleep_for(std::chrono::seconds(sleep_seconds)); - continue; - } else { - LOG(FATAL) << st.error_message(); - } - } - CHECK_LT(retry_idx, max_retry_num); -} - -void CtrlCommNet::Barrier(const std::string& barrier_name) { - Barrier(barrier_name, JobDesc::Singleton()->TotalMachineNum()); -} - -void CtrlCommNet::Barrier(const std::string& barrier_name, - int32_t barrier_num) { - grpc::ClientContext client_ctx; - BarrierRequest request; - request.set_name(barrier_name); - request.set_num(barrier_num); - BarrierResponse response; - GetMasterStub()->Barrier(&client_ctx, request, &response); -} - -TryLockResult CtrlCommNet::TryLock(const std::string& name) { - if (done_names_.find(name) != done_names_.end()) { - return TryLockResult::kDone; - } - grpc::ClientContext client_ctx; - TryLockRequest request; - request.set_name(name); - TryLockResponse response; - GetResponsibleStub(name)->TryLock(&client_ctx, request, &response); - if (response.result() == TryLockResult::kDone) { - CHECK(done_names_.insert(name).second); - } - return response.result(); -} - -void CtrlCommNet::NotifyDone(const std::string& name) { - grpc::ClientContext client_ctx; - NotifyDoneRequest request; - request.set_name(name); - NotifyDoneResponse response; - GetResponsibleStub(name)->NotifyDone(&client_ctx, request, &response); -} - -void CtrlCommNet::WaitUntilDone(const std::string& name) { - grpc::ClientContext client_ctx; - WaitUntilDoneRequest request; - request.set_name(name); - WaitUntilDoneResponse response; - GetResponsibleStub(name)->WaitUntilDone(&client_ctx, request, &response); -} - -void CtrlCommNet::PublishPlan(const Plan* plan) { - ctrl_server_->PublishPlan(plan); -} - -void CtrlCommNet::FetchPlan(Plan* plan) { - grpc::ClientContext client_ctx; - FetchPlanRequest request; - FetchPlanResponse response; - GetMasterStub()->FetchPlan(&client_ctx, request, &response); - *plan = response.plan(); -} - -CtrlService::Stub* CtrlCommNet::GetResponsibleStub(const std::string& key) { - int64_t machine_id = - (std::hash{}(key)) % JobDesc::Singleton()->TotalMachineNum(); - return stubs_[machine_id].get(); -} - -} // namespace oneflow diff --git a/oneflow/core/comm_network/epoll/epoll_data_comm_network.cpp b/oneflow/core/comm_network/epoll/epoll_data_comm_network.cpp index 19c06c97dd..bb1f03805c 100644 --- a/oneflow/core/comm_network/epoll/epoll_data_comm_network.cpp +++ b/oneflow/core/comm_network/epoll/epoll_data_comm_network.cpp @@ -1,4 +1,5 @@ #include "oneflow/core/comm_network/epoll/epoll_data_comm_network.h" +#include "oneflow/core/control/ctrl_client.h" #include "oneflow/core/job/runtime_context.h" #ifdef PLATFORM_POSIX @@ -36,8 +37,8 @@ EpollDataCommNet::~EpollDataCommNet() { for (auto& pair : sockfd2helper_) { delete pair.second; } } -void EpollDataCommNet::Init(uint16_t port) { - DataCommNet::Singleton()->set_comm_network_ptr(new EpollDataCommNet(port)); +void EpollDataCommNet::Init() { + DataCommNet::Singleton()->set_comm_network_ptr(new EpollDataCommNet()); } const void* EpollDataCommNet::RegisterMemory(void* mem_ptr, size_t byte_size) { @@ -99,18 +100,18 @@ void EpollDataCommNet::SendSocketMsg(int64_t dst_machine_id, GetSocketHelper(dst_machine_id)->AsyncWrite(msg); } -EpollDataCommNet::EpollDataCommNet(uint16_t port) { +EpollDataCommNet::EpollDataCommNet() { mem_descs_.clear(); unregister_mem_descs_cnt_ = 0; pollers_.resize(JobDesc::Singleton()->CommNetIOWorkerNum(), nullptr); for (size_t i = 0; i < pollers_.size(); ++i) { pollers_[i] = new IOEventPoller; } - InitSockets(port); + InitSockets(); for (IOEventPoller* poller : pollers_) { poller->Start(); } } -void EpollDataCommNet::InitSockets(uint16_t port) { +void EpollDataCommNet::InitSockets() { int64_t this_machine_id = RuntimeCtx::Singleton()->this_machine_id(); int64_t total_machine_num = JobDesc::Singleton()->TotalMachineNum(); machine_id2sockfd_.assign(total_machine_num, -1); @@ -122,22 +123,31 @@ void EpollDataCommNet::InitSockets(uint16_t port) { return new SocketHelper(sockfd, poller); }; // listen - sockaddr_in this_sockaddr = GetSockAddr(this_machine_id, port); int listen_sockfd = socket(AF_INET, SOCK_STREAM, 0); - PCHECK(bind(listen_sockfd, reinterpret_cast(&this_sockaddr), - sizeof(this_sockaddr)) - == 0); - PCHECK(listen(listen_sockfd, total_machine_num) == 0); + uint16_t this_listen_port = 1024; + uint16_t listen_port_max = std::numeric_limits::max(); + for (; this_listen_port < listen_port_max; ++this_listen_port) { + sockaddr_in this_sockaddr = GetSockAddr(this_machine_id, this_listen_port); + int bind_result = + bind(listen_sockfd, reinterpret_cast(&this_sockaddr), + sizeof(this_sockaddr)); + if (bind_result == 0) { + PCHECK(listen(listen_sockfd, total_machine_num) == 0); + CtrlClient::Singleton()->PushPort(this_listen_port); + break; + } else { + PCHECK(errno == EACCES || errno == EADDRINUSE); + } + } + CHECK_LT(this_listen_port, listen_port_max); // connect FOR_RANGE(int64_t, peer_machine_id, this_machine_id + 1, total_machine_num) { - sockaddr_in peer_sockaddr = GetSockAddr(peer_machine_id, port); + uint16_t peer_port = CtrlClient::Singleton()->PullPort(peer_machine_id); + sockaddr_in peer_sockaddr = GetSockAddr(peer_machine_id, peer_port); int sockfd = socket(AF_INET, SOCK_STREAM, 0); - int rc = -1; - while (rc == -1) { - connect(sockfd, reinterpret_cast(&peer_sockaddr), - sizeof(peer_sockaddr)); - } - PCHECK(rc == 0); + PCHECK(connect(sockfd, reinterpret_cast(&peer_sockaddr), + sizeof(peer_sockaddr)) + == 0); CHECK(sockfd2helper_.emplace(sockfd, NewSocketHelper(sockfd)).second); machine_id2sockfd_[peer_machine_id] = sockfd; } diff --git a/oneflow/core/comm_network/epoll/epoll_data_comm_network.h b/oneflow/core/comm_network/epoll/epoll_data_comm_network.h index c8c32b87e0..e4dbb07b70 100644 --- a/oneflow/core/comm_network/epoll/epoll_data_comm_network.h +++ b/oneflow/core/comm_network/epoll/epoll_data_comm_network.h @@ -12,14 +12,13 @@ namespace oneflow { class EpollDataCommNet final : public DataCommNet { public: OF_DISALLOW_COPY_AND_MOVE(EpollDataCommNet); - EpollDataCommNet() = delete; ~EpollDataCommNet(); static EpollDataCommNet* Singleton() { return static_cast(DataCommNet::Singleton()); } - static void Init(uint16_t port); + static void Init(); const void* RegisterMemory(void* mem_ptr, size_t byte_size) override; void UnRegisterMemory(const void* token) override; @@ -33,8 +32,8 @@ class EpollDataCommNet final : public DataCommNet { void SendSocketMsg(int64_t dst_machine_id, const SocketMsg& msg); private: - EpollDataCommNet(uint16_t port); - void InitSockets(uint16_t port); + EpollDataCommNet(); + void InitSockets(); SocketHelper* GetSocketHelper(int64_t machine_id); // Memory Desc diff --git a/oneflow/core/comm_network/rpc/control.proto b/oneflow/core/control/control.proto similarity index 53% rename from oneflow/core/comm_network/rpc/control.proto rename to oneflow/core/control/control.proto index 8e9c4c9494..4a2d0cbaf3 100644 --- a/oneflow/core/comm_network/rpc/control.proto +++ b/oneflow/core/control/control.proto @@ -3,19 +3,18 @@ package oneflow; import "oneflow/core/job/plan.proto"; -message BarrierRequest { - string name = 1; - int32 num = 2; +message LoadServerRequest { } -message BarrierResponse { +message LoadServerResponse { } -message AddWorkerRequest { - string worker_addr = 1; +message BarrierRequest { + string name = 1; + int32 num = 2; } -message AddWorkerResponse { +message BarrierResponse { } enum TryLockResult { @@ -46,9 +45,42 @@ message WaitUntilDoneRequest { message WaitUntilDoneResponse { } -message FetchPlanRequest { +message PushPlanRequest { + Plan plan = 1; +} + +message PushPlanResponse { +} + +message ClearPlanRequest { +} + +message ClearPlanResponse { } -message FetchPlanResponse { +message PullPlanRequest { +} + +message PullPlanResponse { Plan plan = 1; } + +message PushPortRequest { + int32 port = 1; +} + +message PushPortResponse { +} + +message ClearPortRequest { +} + +message ClearPortResponse { +} + +message PullPortRequest { +} + +message PullPortResponse { + int32 port = 1; +} diff --git a/oneflow/core/comm_network/rpc/ctrl_call.h b/oneflow/core/control/ctrl_call.h similarity index 88% rename from oneflow/core/comm_network/rpc/ctrl_call.h rename to oneflow/core/control/ctrl_call.h index af26919861..c769d58c20 100644 --- a/oneflow/core/comm_network/rpc/ctrl_call.h +++ b/oneflow/core/control/ctrl_call.h @@ -1,7 +1,7 @@ -#ifndef ONEFLOW_CORE_COMM_NETWORK_RPC_CTRL_CALL_H_ -#define ONEFLOW_CORE_COMM_NETWORK_RPC_CTRL_CALL_H_ +#ifndef ONEFLOW_CORE_CONTROL_CTRL_CALL_H_ +#define ONEFLOW_CORE_CONTROL_CTRL_CALL_H_ -#include "oneflow/core/comm_network/rpc/ctrl_service.h" +#include "oneflow/core/control/ctrl_service.h" namespace oneflow { @@ -70,4 +70,4 @@ class CtrlCall final : public CtrlCallIf { } // namespace oneflow -#endif // ONEFLOW_CORE_COMM_NETWORK_RPC_CTRL_CALL_H_ +#endif // ONEFLOW_CORE_CONTROL_CTRL_CALL_H_ diff --git a/oneflow/core/control/ctrl_client.cpp b/oneflow/core/control/ctrl_client.cpp new file mode 100644 index 0000000000..0ad7d3bd8a --- /dev/null +++ b/oneflow/core/control/ctrl_client.cpp @@ -0,0 +1,146 @@ +#include "oneflow/core/control/ctrl_client.h" +#include "oneflow/core/job/runtime_context.h" + +namespace oneflow { + +namespace { + +const int32_t max_retry_num = 60; +const int64_t sleep_seconds = 10; + +} // namespace + +void CtrlClient::Barrier(const std::string& barrier_name) { + Barrier(barrier_name, JobDesc::Singleton()->TotalMachineNum()); +} + +void CtrlClient::Barrier(const std::string& barrier_name, int32_t barrier_num) { + grpc::ClientContext client_ctx; + BarrierRequest request; + request.set_name(barrier_name); + request.set_num(barrier_num); + BarrierResponse response; + GetMasterStub()->Barrier(&client_ctx, request, &response); +} + +TryLockResult CtrlClient::TryLock(const std::string& name) { + if (done_names_.find(name) != done_names_.end()) { + return TryLockResult::kDone; + } + grpc::ClientContext client_ctx; + TryLockRequest request; + request.set_name(name); + TryLockResponse response; + GetResponsibleStub(name)->TryLock(&client_ctx, request, &response); + if (response.result() == TryLockResult::kDone) { + CHECK(done_names_.insert(name).second); + } + return response.result(); +} + +void CtrlClient::NotifyDone(const std::string& name) { + grpc::ClientContext client_ctx; + NotifyDoneRequest request; + request.set_name(name); + NotifyDoneResponse response; + GetResponsibleStub(name)->NotifyDone(&client_ctx, request, &response); +} + +void CtrlClient::WaitUntilDone(const std::string& name) { + grpc::ClientContext client_ctx; + WaitUntilDoneRequest request; + request.set_name(name); + WaitUntilDoneResponse response; + GetResponsibleStub(name)->WaitUntilDone(&client_ctx, request, &response); +} + +void CtrlClient::PushPlan(const Plan& plan) { + grpc::ClientContext client_ctx; + PushPlanRequest request; + *(request.mutable_plan()) = plan; + PushPlanResponse response; + GetMasterStub()->PushPlan(&client_ctx, request, &response); +} + +void CtrlClient::ClearPlan() { + grpc::ClientContext client_ctx; + ClearPlanRequest request; + ClearPlanResponse response; + GetMasterStub()->ClearPlan(&client_ctx, request, &response); +} + +void CtrlClient::PullPlan(Plan* plan) { + grpc::ClientContext client_ctx; + PullPlanRequest request; + PullPlanResponse response; + GetMasterStub()->PullPlan(&client_ctx, request, &response); + *plan = response.plan(); +} + +void CtrlClient::PushPort(uint16_t port) { + grpc::ClientContext client_ctx; + PushPortRequest request; + request.set_port(port); + PushPortResponse response; + GetThisStub()->PushPort(&client_ctx, request, &response); +} + +void CtrlClient::ClearPort() { + grpc::ClientContext client_ctx; + ClearPortRequest request; + ClearPortResponse response; + GetThisStub()->ClearPort(&client_ctx, request, &response); +} + +uint16_t CtrlClient::PullPort(uint64_t machine_id) { + grpc::ClientContext client_ctx; + PullPortRequest request; + PullPortResponse response; + stubs_[machine_id]->PullPort(&client_ctx, request, &response); + return response.port(); +} + +CtrlClient::CtrlClient() { + stubs_.reserve(JobDesc::Singleton()->TotalMachineNum()); + for (int64_t i = 0; i < JobDesc::Singleton()->TotalMachineNum(); ++i) { + std::string addr = RuntimeCtx::Singleton()->GetCtrlAddr(i); + stubs_.push_back(CtrlService::NewStub(addr)); + LoadServer(addr, stubs_[i].get()); + } +} + +void CtrlClient::LoadServer(const std::string& server_addr, + CtrlService::Stub* stub) { + int32_t retry_idx = 0; + for (; retry_idx < max_retry_num; ++retry_idx) { + grpc::ClientContext client_ctx; + LoadServerRequest request; + LoadServerResponse response; + grpc::Status st = stub->LoadServer(&client_ctx, request, &response); + if (st.error_code() == grpc::StatusCode::OK) { + LOG(INFO) << "LoadServer " << server_addr << " Successful at " + << retry_idx << " times"; + break; + } else if (st.error_code() == grpc::StatusCode::UNAVAILABLE) { + LOG(INFO) << "LoadServer " << server_addr << " Failed at " << retry_idx + << " times"; + std::this_thread::sleep_for(std::chrono::seconds(sleep_seconds)); + continue; + } else { + LOG(FATAL) << st.error_message(); + } + } + CHECK_LT(retry_idx, max_retry_num); +} + +CtrlService::Stub* CtrlClient::GetThisStub() { + return stubs_[RuntimeCtx::Singleton()->this_machine_id()].get(); +} + +CtrlService::Stub* CtrlClient::GetResponsibleStub(const std::string& key) { + int64_t machine_id = + (std::hash{}(key)) % JobDesc::Singleton()->TotalMachineNum(); + return stubs_[machine_id].get(); +} + +} // namespace oneflow diff --git a/oneflow/core/comm_network/ctrl_comm_network.h b/oneflow/core/control/ctrl_client.h similarity index 56% rename from oneflow/core/comm_network/ctrl_comm_network.h rename to oneflow/core/control/ctrl_client.h index 0e69abfebc..8d7c13fde3 100644 --- a/oneflow/core/comm_network/ctrl_comm_network.h +++ b/oneflow/core/control/ctrl_client.h @@ -1,18 +1,17 @@ -#ifndef ONEFLOW_CORE_COMM_NETWORK_CTRL_COMM_NETWORK_H_ -#define ONEFLOW_CORE_COMM_NETWORK_CTRL_COMM_NETWORK_H_ +#ifndef ONEFLOW_CORE_CONTROL_CTRL_CLIENT_H_ +#define ONEFLOW_CORE_CONTROL_CTRL_CLIENT_H_ #include "oneflow/core/actor/actor_message.h" -#include "oneflow/core/comm_network/rpc/ctrl_server.h" +#include "oneflow/core/control/ctrl_service.h" namespace oneflow { -class CtrlCommNet final { +class CtrlClient final { public: - OF_DISALLOW_COPY_AND_MOVE(CtrlCommNet); - CtrlCommNet() = delete; - ~CtrlCommNet() = default; + OF_DISALLOW_COPY_AND_MOVE(CtrlClient); + ~CtrlClient() = default; - OF_SINGLETON(CtrlCommNet); + OF_SINGLETON(CtrlClient); void Barrier(const std::string& barrier_name); void Barrier(const std::string& barrier_name, int32_t barrier_num); @@ -21,37 +20,43 @@ class CtrlCommNet final { void NotifyDone(const std::string& name); void WaitUntilDone(const std::string& name); - void PublishPlan(const Plan* plan); - void FetchPlan(Plan* plan); + void PushPlan(const Plan& plan); + void ClearPlan(); + void PullPlan(Plan* plan); + + void PushPort(uint16_t port); + void ClearPort(); + uint16_t PullPort(uint64_t machine_id); private: - CtrlCommNet(uint16_t port); + CtrlClient(); + void LoadServer(const std::string& server_addr, CtrlService::Stub* stub); CtrlService::Stub* GetMasterStub() { return stubs_[0].get(); } + CtrlService::Stub* GetThisStub(); CtrlService::Stub* GetResponsibleStub(const std::string& key); - std::unique_ptr ctrl_server_; std::vector> stubs_; HashSet done_names_; }; #define FILE_LINE_STR __FILE__ ":" OF_PP_STRINGIZE(__LINE__) -#define OF_BARRIER() CtrlCommNet::Singleton()->Barrier(FILE_LINE_STR) - -#define OF_CALL_ONCE(name, ...) \ - do { \ - TryLockResult lock_ret = CtrlCommNet::Singleton()->TryLock(name); \ - if (lock_ret == TryLockResult::kLocked) { \ - __VA_ARGS__; \ - CtrlCommNet::Singleton()->NotifyDone(name); \ - } else if (lock_ret == TryLockResult::kDone) { \ - } else if (lock_ret == TryLockResult::kDoing) { \ - CtrlCommNet::Singleton()->WaitUntilDone(name); \ - } else { \ - UNEXPECTED_RUN(); \ - } \ +#define OF_BARRIER() CtrlClient::Singleton()->Barrier(FILE_LINE_STR) + +#define OF_CALL_ONCE(name, ...) \ + do { \ + TryLockResult lock_ret = CtrlClient::Singleton()->TryLock(name); \ + if (lock_ret == TryLockResult::kLocked) { \ + __VA_ARGS__; \ + CtrlClient::Singleton()->NotifyDone(name); \ + } else if (lock_ret == TryLockResult::kDone) { \ + } else if (lock_ret == TryLockResult::kDoing) { \ + CtrlClient::Singleton()->WaitUntilDone(name); \ + } else { \ + UNEXPECTED_RUN(); \ + } \ } while (0) } // namespace oneflow -#endif // ONEFLOW_CORE_COMM_NETWORK_CTRL_COMM_NETWORK_H_ +#endif // ONEFLOW_CORE_CONTROL_CTRL_CLIENT_H_ diff --git a/oneflow/core/comm_network/rpc/ctrl_server.cpp b/oneflow/core/control/ctrl_server.cpp similarity index 76% rename from oneflow/core/comm_network/rpc/ctrl_server.cpp rename to oneflow/core/control/ctrl_server.cpp index 2215462cae..77a9aca9cd 100644 --- a/oneflow/core/comm_network/rpc/ctrl_server.cpp +++ b/oneflow/core/control/ctrl_server.cpp @@ -1,20 +1,7 @@ -#include "oneflow/core/comm_network/rpc/ctrl_server.h" -#include "grpc++/alarm.h" -#include "oneflow/core/job/runtime_context.h" +#include "oneflow/core/control/ctrl_server.h" namespace oneflow { -#define ENQUEUE_REQUEST(method) \ - do { \ - auto call = new CtrlCall(); \ - call->set_request_handler( \ - std::bind(&CtrlServer::method##Handler, this, call)); \ - grpc_service_->RequestAsyncUnary( \ - static_cast(CtrlMethod::k##method), call->mut_server_ctx(), \ - call->mut_request(), call->mut_responder(), cq_.get(), cq_.get(), \ - call); \ - } while (0); - CtrlServer::~CtrlServer() { grpc::Alarm alarm(cq_.get(), gpr_now(GPR_CLOCK_MONOTONIC), nullptr); loop_thread_.join(); @@ -31,25 +18,21 @@ CtrlServer::CtrlServer(const std::string& server_addr) { cq_ = server_builder.AddCompletionQueue(); grpc_server_ = server_builder.BuildAndStart(); LOG(INFO) << "CtrlServer listening on " << server_addr; - added_worker_calls_.clear(); plan_ = nullptr; - pending_plan_calls_.clear(); + port_ = -1; loop_thread_ = std::thread(&CtrlServer::HandleRpcs, this); } -void CtrlServer::PublishPlan(const Plan* plan) { - std::unique_lock lck(plan_mtx_); - plan_ = plan; - if (plan_) { - for (auto call : pending_plan_calls_) { - *(call->mut_response()->mutable_plan()) = *plan; - call->SendResponse(); - } - pending_plan_calls_.clear(); - } else { - CHECK(pending_plan_calls_.empty()); - } -} +#define ENQUEUE_REQUEST(method) \ + do { \ + auto call = new CtrlCall(); \ + call->set_request_handler( \ + std::bind(&CtrlServer::method##Handler, this, call)); \ + grpc_service_->RequestAsyncUnary( \ + static_cast(CtrlMethod::k##method), call->mut_server_ctx(), \ + call->mut_request(), call->mut_responder(), cq_.get(), cq_.get(), \ + call); \ + } while (0); void CtrlServer::HandleRpcs() { OF_PP_FOR_EACH_TUPLE(ENQUEUE_REQUEST, CTRL_METHOD_SEQ); @@ -68,18 +51,10 @@ void CtrlServer::HandleRpcs() { } } -void CtrlServer::AddWorkerHandler( - CtrlCall* call) { - CHECK(RuntimeCtx::Singleton()->IsThisMachineMaster()); - added_worker_calls_.push_back(call); - LOG(INFO) << "Added Worker " << call->request().worker_addr(); - if (added_worker_calls_.size() == JobDesc::Singleton()->TotalMachineNum()) { - for (CtrlCallIf* pending_call : added_worker_calls_) { - pending_call->SendResponse(); - } - added_worker_calls_.clear(); - } - ENQUEUE_REQUEST(AddWorker); +void CtrlServer::LoadServerHandler( + CtrlCall* call) { + call->SendResponse(); + ENQUEUE_REQUEST(LoadServer); } void CtrlServer::BarrierHandler( @@ -153,16 +128,58 @@ void CtrlServer::WaitUntilDoneHandler( ENQUEUE_REQUEST(WaitUntilDone); } -void CtrlServer::FetchPlanHandler( - CtrlCall* call) { - std::unique_lock lck(plan_mtx_); +void CtrlServer::PushPlanHandler( + CtrlCall* call) { + plan_.reset(new Plan(call->request().plan())); + for (auto call : pending_plan_calls_) { + *(call->mut_response()->mutable_plan()) = *plan_; + call->SendResponse(); + } + ENQUEUE_REQUEST(PushPlan); +} + +void CtrlServer::ClearPlanHandler( + CtrlCall* call) { + plan_.reset(); + ENQUEUE_REQUEST(ClearPlan); +} + +void CtrlServer::PullPlanHandler( + CtrlCall* call) { if (plan_) { *(call->mut_response()->mutable_plan()) = *plan_; call->SendResponse(); } else { pending_plan_calls_.push_back(call); } - ENQUEUE_REQUEST(FetchPlan); + ENQUEUE_REQUEST(PullPlan); +} + +void CtrlServer::PushPortHandler( + CtrlCall* call) { + port_ = call->request().port(); + for (auto call : pending_port_calls_) { + call->mut_response()->set_port(port_); + call->SendResponse(); + } + ENQUEUE_REQUEST(PushPort); +} + +void CtrlServer::ClearPortHandler( + CtrlCall* call) { + port_ = -1; + ENQUEUE_REQUEST(ClearPort); +} + +void CtrlServer::PullPortHandler( + CtrlCall* call) { + if (port_ != -1) { + call->mut_response()->set_port(port_); + call->SendResponse(); + } else { + pending_port_calls_.push_back(call); + } + ENQUEUE_REQUEST(PullPort); } } // namespace oneflow diff --git a/oneflow/core/comm_network/rpc/ctrl_server.h b/oneflow/core/control/ctrl_server.h similarity index 61% rename from oneflow/core/comm_network/rpc/ctrl_server.h rename to oneflow/core/control/ctrl_server.h index 0a1735d74b..cdbacc1c7c 100644 --- a/oneflow/core/comm_network/rpc/ctrl_server.h +++ b/oneflow/core/control/ctrl_server.h @@ -1,8 +1,9 @@ -#ifndef ONEFLOW_CORE_COMM_NETWORK_RPC_CTRL_SERVER_H_ -#define ONEFLOW_CORE_COMM_NETWORK_RPC_CTRL_SERVER_H_ +#ifndef ONEFLOW_CORE_CONTROL_CTRL_SERVER_H_ +#define ONEFLOW_CORE_CONTROL_CTRL_SERVER_H_ +#include "grpc++/alarm.h" #include "grpc++/server_builder.h" -#include "oneflow/core/comm_network/rpc/ctrl_call.h" +#include "oneflow/core/control/ctrl_call.h" namespace oneflow { @@ -13,7 +14,6 @@ class CtrlServer final { ~CtrlServer(); CtrlServer(const std::string& server_addr); - void PublishPlan(const Plan* plan); private: void HandleRpcs(); @@ -29,19 +29,19 @@ class CtrlServer final { std::unique_ptr cq_; std::unique_ptr grpc_server_; std::thread loop_thread_; - // AddWorker - std::list added_worker_calls_; // Barrier HashMap, int32_t>> barrier_calls_; // TryLock, NotifyDone, WaitUntilDone - HashMap name2lock_status_; // TODO: erase outdated item - // FetchPlan - std::mutex plan_mtx_; - const Plan* plan_; - std::list*> pending_plan_calls_; + HashMap name2lock_status_; + // PushPlan, PullPlan + std::unique_ptr plan_; + std::list*> pending_plan_calls_; + // PushPort, ClearPort, PullPort + int32_t port_; + std::list*> pending_port_calls_; }; } // namespace oneflow -#endif // ONEFLOW_CORE_COMM_NETWORK_RPC_CTRL_SERVER_H_ +#endif // ONEFLOW_CORE_CONTROL_CTRL_SERVER_H_ diff --git a/oneflow/core/comm_network/rpc/ctrl_service.cpp b/oneflow/core/control/ctrl_service.cpp similarity index 96% rename from oneflow/core/comm_network/rpc/ctrl_service.cpp rename to oneflow/core/control/ctrl_service.cpp index 181429b638..321a565aa6 100644 --- a/oneflow/core/comm_network/rpc/ctrl_service.cpp +++ b/oneflow/core/control/ctrl_service.cpp @@ -1,4 +1,4 @@ -#include "oneflow/core/comm_network/rpc/ctrl_service.h" +#include "oneflow/core/control/ctrl_service.h" #include "grpc++/impl/codegen/client_unary_call.h" namespace oneflow { diff --git a/oneflow/core/comm_network/rpc/ctrl_service.h b/oneflow/core/control/ctrl_service.h similarity index 80% rename from oneflow/core/comm_network/rpc/ctrl_service.h rename to oneflow/core/control/ctrl_service.h index c77791bfda..7f2ec342a6 100644 --- a/oneflow/core/comm_network/rpc/ctrl_service.h +++ b/oneflow/core/control/ctrl_service.h @@ -1,5 +1,5 @@ -#ifndef ONEFLOW_CORE_COMM_NETWORK_RPC_CTRL_SERVICE_H_ -#define ONEFLOW_CORE_COMM_NETWORK_RPC_CTRL_SERVICE_H_ +#ifndef ONEFLOW_CORE_CONTROL_CTRL_SERVICE_H_ +#define ONEFLOW_CORE_CONTROL_CTRL_SERVICE_H_ #include "grpc++/grpc++.h" #include "grpc++/impl/codegen/async_stream.h" @@ -10,19 +10,24 @@ #include "grpc++/impl/codegen/status.h" #include "grpc++/impl/codegen/stub_options.h" #include "grpc++/impl/codegen/sync_stream.h" -#include "oneflow/core/comm_network/rpc/control.pb.h" #include "oneflow/core/common/preprocessor.h" #include "oneflow/core/common/util.h" +#include "oneflow/core/control/control.pb.h" namespace oneflow { #define CTRL_METHOD_SEQ \ - OF_PP_MAKE_TUPLE_SEQ(AddWorker) \ + OF_PP_MAKE_TUPLE_SEQ(LoadServer) \ OF_PP_MAKE_TUPLE_SEQ(Barrier) \ OF_PP_MAKE_TUPLE_SEQ(TryLock) \ OF_PP_MAKE_TUPLE_SEQ(NotifyDone) \ OF_PP_MAKE_TUPLE_SEQ(WaitUntilDone) \ - OF_PP_MAKE_TUPLE_SEQ(FetchPlan) + OF_PP_MAKE_TUPLE_SEQ(PushPlan) \ + OF_PP_MAKE_TUPLE_SEQ(ClearPlan) \ + OF_PP_MAKE_TUPLE_SEQ(PullPlan) \ + OF_PP_MAKE_TUPLE_SEQ(PushPort) \ + OF_PP_MAKE_TUPLE_SEQ(ClearPort) \ + OF_PP_MAKE_TUPLE_SEQ(PullPort) enum class CtrlMethod { #define MAKE_ENTRY(method) k##method, @@ -66,4 +71,4 @@ class CtrlService final { } // namespace oneflow -#endif // ONEFLOW_CORE_COMM_NETWORK_RPC_CTRL_SERVICE_H_ +#endif // ONEFLOW_CORE_CONTROL_CTRL_SERVICE_H_ diff --git a/oneflow/core/job/resource.proto b/oneflow/core/job/resource.proto index 7e9d52185b..47600eb726 100644 --- a/oneflow/core/job/resource.proto +++ b/oneflow/core/job/resource.proto @@ -3,7 +3,8 @@ package oneflow; message Machine { string addr = 1; // domain name or ip - string name = 4; + int32 port = 2; + string name = 3; } enum DeviceType { @@ -15,8 +16,6 @@ message Resource { repeated Machine machine = 1; int64 device_num_per_machine = 2; DeviceType device_type = 3; - int32 port_min = 4; - int32 port_max = 5; } // If one machine named "machine_xxx" and device_num_per_machine = 4 diff --git a/oneflow/core/job/runtime.cpp b/oneflow/core/job/runtime.cpp index 40ac4ba4c1..a01ef301f2 100644 --- a/oneflow/core/job/runtime.cpp +++ b/oneflow/core/job/runtime.cpp @@ -1,15 +1,11 @@ #include #include "oneflow/core/comm_network/epoll/epoll_data_comm_network.h" +#include "oneflow/core/control/ctrl_client.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/runtime_context.h" #include "oneflow/core/kernel/kernel_manager.h" #include "oneflow/core/thread/thread_manager.h" -DEFINE_string(plan_filepath, "", ""); -DEFINE_string(this_machine_name, "", ""); -DEFINE_int32(ctrl_port, -1, ""); -DEFINE_int32(data_port, -1, ""); - namespace oneflow { class Runtime final { @@ -82,10 +78,10 @@ void Runtime::NewAllSingleton(const Plan& plan, JobDesc::NewSingleton(plan.job_desc()); IDMgr::NewSingleton(); RuntimeCtx::NewSingleton(this_machine_name); - CtrlCommNet::NewSingleton(FLAGS_ctrl_port); + CtrlClient::NewSingleton(); KernelMgr::NewSingleton(plan); #ifdef PLATFORM_POSIX - EpollDataCommNet::Init(FLAGS_data_port); + EpollDataCommNet::Init(); #endif SnapshotMgr::NewSingleton(plan); RegstMgr::NewSingleton(); @@ -100,7 +96,7 @@ void Runtime::DeleteAllSingleton() { SnapshotMgr::DeleteSingleton(); delete DataCommNet::Singleton(); KernelMgr::DeleteSingleton(); - CtrlCommNet::DeleteSingleton(); + CtrlClient::DeleteSingleton(); RuntimeCtx::DeleteSingleton(); IDMgr::DeleteSingleton(); JobDesc::DeleteSingleton(); @@ -122,6 +118,9 @@ void Runtime::SendCmdMsg(const std::vector& tasks, } // namespace oneflow +DEFINE_string(plan_filepath, "", ""); +DEFINE_string(this_machine_name, "", ""); + int main(int argc, char** argv) { google::InitGoogleLogging(argv[0]); google::ParseCommandLineFlags(&argc, &argv, true); diff --git a/oneflow/core/job/runtime_context.cpp b/oneflow/core/job/runtime_context.cpp index 3139320038..6562c6d616 100644 --- a/oneflow/core/job/runtime_context.cpp +++ b/oneflow/core/job/runtime_context.cpp @@ -2,8 +2,9 @@ namespace oneflow { -std::string RuntimeCtx::GetAddr(int64_t machine_id) const { - return JobDesc::Singleton()->resource().machine(machine_id).addr(); +std::string RuntimeCtx::GetCtrlAddr(int64_t machine_id) const { + const Machine& mchn = JobDesc::Singleton()->resource().machine(machine_id); + return mchn.addr() + ":" + std::to_string(mchn.port()); } PersistentInStream* RuntimeCtx::GetDataInStream(const std::string& name) { diff --git a/oneflow/core/job/runtime_context.h b/oneflow/core/job/runtime_context.h index d4a3b53f23..9b4fbe5371 100644 --- a/oneflow/core/job/runtime_context.h +++ b/oneflow/core/job/runtime_context.h @@ -17,9 +17,9 @@ class RuntimeCtx final { int64_t this_machine_id() const { return this_machine_id_; } bool IsThisMachineMaster() const { return this_machine_id_ == 0; } - std::string GetThisAddr() const { return GetAddr(this_machine_id_); } - std::string GetMasterCtrlAddr() const { return GetAddr(0); } - std::string GetAddr(int64_t machine_id) const; + std::string GetThisCtrlAddr() const { return GetCtrlAddr(this_machine_id_); } + std::string GetMasterCtrlAddr() const { return GetCtrlAddr(0); } + std::string GetCtrlAddr(int64_t machine_id) const; ThreadSafeCounter& mut_model_init_cnt() { return model_init_cnt_; } diff --git a/oneflow/core/job/scheduler.cpp b/oneflow/core/job/scheduler.cpp index 149c59f639..9e364dd089 100644 --- a/oneflow/core/job/scheduler.cpp +++ b/oneflow/core/job/scheduler.cpp @@ -1,7 +1,8 @@ #include #include -#include "oneflow/core/comm_network/ctrl_comm_network.h" #include "oneflow/core/common/str_util.h" +#include "oneflow/core/control/ctrl_client.h" +#include "oneflow/core/control/ctrl_server.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/plan.pb.h" #include "oneflow/core/job/runtime_context.h" @@ -21,13 +22,12 @@ class Scheduler final { private: Scheduler() = default; - uint16_t GetNextPort(); void NewAllSingleton(const std::string& job_conf_filepath, const std::string& this_machine_name, char** env); - std::string GetEnvPrefix(); + void DeleteAllSingleton(); void SystemCall(const std::string& cmd); - uint16_t next_port_; + std::unique_ptr ctrl_server_; std::string env_prefix_; }; @@ -44,13 +44,13 @@ void Scheduler::Process(const std::string& job_conf_filepath, << "-plan_filepath=" << naive_plan_filepath; SystemCall(compile_cmd.str()); ParseProtoFromTextFile(naive_plan_filepath, plan.get()); - CtrlCommNet::Singleton()->PublishPlan(plan.get()); + CtrlClient::Singleton()->PushPlan(*plan); } else { - CtrlCommNet::Singleton()->FetchPlan(plan.get()); + CtrlClient::Singleton()->PullPlan(plan.get()); } OF_BARRIER(); if (RuntimeCtx::Singleton()->IsThisMachineMaster()) { - CtrlCommNet::Singleton()->PublishPlan(nullptr); + CtrlClient::Singleton()->ClearPlan(); } else { PrintProtoToTextFile(*plan, naive_plan_filepath); } @@ -58,15 +58,9 @@ void Scheduler::Process(const std::string& job_conf_filepath, std::stringstream runtime_cmd; runtime_cmd << "./runtime " << "-plan_filepath=" << naive_plan_filepath << " " - << "-this_machine_name=" << this_machine_name << " " - << "-ctrl_port=" << GetNextPort() << " " - << "-data_port=" << GetNextPort(); + << "-this_machine_name=" << this_machine_name; SystemCall(runtime_cmd.str()); -} - -uint16_t Scheduler::GetNextPort() { - CHECK_LE(next_port_, JobDesc::Singleton()->resource().port_max()); - return next_port_++; + DeleteAllSingleton(); } void Scheduler::NewAllSingleton(const std::string& job_conf_filepath, @@ -75,10 +69,11 @@ void Scheduler::NewAllSingleton(const std::string& job_conf_filepath, oneflow::JobConf job_conf; oneflow::ParseProtoFromTextFile(job_conf_filepath, &job_conf); JobDesc::NewSingleton(job_conf); - next_port_ = JobDesc::Singleton()->resource().port_min(); IDMgr::NewSingleton(); RuntimeCtx::NewSingleton(this_machine_name); - CtrlCommNet::NewSingleton(GetNextPort()); + ctrl_server_.reset( + new CtrlServer(RuntimeCtx::Singleton()->GetThisCtrlAddr())); + CtrlClient::NewSingleton(); env_prefix_ = ""; std::stringstream ss; while (*env) { @@ -88,6 +83,14 @@ void Scheduler::NewAllSingleton(const std::string& job_conf_filepath, env_prefix_ = ss.str(); } +void Scheduler::DeleteAllSingleton() { + CtrlClient::DeleteSingleton(); + ctrl_server_.reset(); + RuntimeCtx::DeleteSingleton(); + IDMgr::DeleteSingleton(); + JobDesc::DeleteSingleton(); +} + void Scheduler::SystemCall(const std::string& cmd) { LOG(INFO) << "SystemCall: [" << cmd << "]"; CHECK_EQ(std::system(cmd.c_str()), 0); diff --git a/oneflow/core/persistence/snapshot.h b/oneflow/core/persistence/snapshot.h index c91f0db07b..171b9928fe 100644 --- a/oneflow/core/persistence/snapshot.h +++ b/oneflow/core/persistence/snapshot.h @@ -1,7 +1,7 @@ #ifndef ONEFLOW_CORE_PERSISTENCE_SNAPSHOT_H_ #define ONEFLOW_CORE_PERSISTENCE_SNAPSHOT_H_ -#include "oneflow/core/comm_network/ctrl_comm_network.h" +#include "oneflow/core/control/ctrl_client.h" #include "oneflow/core/persistence/normal_persistent_in_stream.h" #include "oneflow/core/persistence/persistent_out_stream.h" diff --git a/scripts/run_scheduler.sh b/scripts/run_scheduler.sh index 38d3a34ca5..141931e425 100755 --- a/scripts/run_scheduler.sh +++ b/scripts/run_scheduler.sh @@ -8,7 +8,7 @@ SCHEDULER_CMD='GLOG_logtostderr=0 GLOG_log_dir=./log GLOG_v=0 GLOG_logbuflevel=- set +e for host in "${hosts[@]}" do - ssh $USER@$host "/usr/sbin/fuser -k 9000/tcp 9001/tcp 9002/tcp 9003/tcp 9004/tcp" + ssh $USER@$host "/usr/sbin/fuser -k 9000/tcp" done set -e -- GitLab