提交 9421cf46 编写于 作者: W willzhang4a58

make control module


Former-commit-id: c259b77a
上级 6d431ed0
machine {
addr: "192.168.1.11"
port: 9000
name: "192.168.1.11"
}
machine {
addr: "192.168.1.13"
port: 9000
name: "192.168.1.13"
}
device_num_per_machine: 4
device_type: kCPU
port_min: 9000
port_max: 9100
#include "oneflow/core/comm_network/epoll/epoll_data_comm_network.h"
#include "oneflow/core/control/ctrl_client.h"
#include "oneflow/core/job/runtime_context.h"
#ifdef PLATFORM_POSIX
......@@ -36,8 +37,8 @@ EpollDataCommNet::~EpollDataCommNet() {
for (auto& pair : sockfd2helper_) { delete pair.second; }
}
void EpollDataCommNet::Init(uint16_t port) {
DataCommNet::Singleton()->set_comm_network_ptr(new EpollDataCommNet(port));
void EpollDataCommNet::Init() {
DataCommNet::Singleton()->set_comm_network_ptr(new EpollDataCommNet());
}
const void* EpollDataCommNet::RegisterMemory(void* mem_ptr, size_t byte_size) {
......@@ -99,18 +100,18 @@ void EpollDataCommNet::SendSocketMsg(int64_t dst_machine_id,
GetSocketHelper(dst_machine_id)->AsyncWrite(msg);
}
EpollDataCommNet::EpollDataCommNet(uint16_t port) {
EpollDataCommNet::EpollDataCommNet() {
mem_descs_.clear();
unregister_mem_descs_cnt_ = 0;
pollers_.resize(JobDesc::Singleton()->CommNetIOWorkerNum(), nullptr);
for (size_t i = 0; i < pollers_.size(); ++i) {
pollers_[i] = new IOEventPoller;
}
InitSockets(port);
InitSockets();
for (IOEventPoller* poller : pollers_) { poller->Start(); }
}
void EpollDataCommNet::InitSockets(uint16_t port) {
void EpollDataCommNet::InitSockets() {
int64_t this_machine_id = RuntimeCtx::Singleton()->this_machine_id();
int64_t total_machine_num = JobDesc::Singleton()->TotalMachineNum();
machine_id2sockfd_.assign(total_machine_num, -1);
......@@ -122,22 +123,31 @@ void EpollDataCommNet::InitSockets(uint16_t port) {
return new SocketHelper(sockfd, poller);
};
// listen
sockaddr_in this_sockaddr = GetSockAddr(this_machine_id, port);
int listen_sockfd = socket(AF_INET, SOCK_STREAM, 0);
PCHECK(bind(listen_sockfd, reinterpret_cast<sockaddr*>(&this_sockaddr),
sizeof(this_sockaddr))
== 0);
PCHECK(listen(listen_sockfd, total_machine_num) == 0);
uint16_t this_listen_port = 1024;
uint16_t listen_port_max = std::numeric_limits<uint16_t>::max();
for (; this_listen_port < listen_port_max; ++this_listen_port) {
sockaddr_in this_sockaddr = GetSockAddr(this_machine_id, this_listen_port);
int bind_result =
bind(listen_sockfd, reinterpret_cast<sockaddr*>(&this_sockaddr),
sizeof(this_sockaddr));
if (bind_result == 0) {
PCHECK(listen(listen_sockfd, total_machine_num) == 0);
CtrlClient::Singleton()->PushPort(this_listen_port);
break;
} else {
PCHECK(errno == EACCES || errno == EADDRINUSE);
}
}
CHECK_LT(this_listen_port, listen_port_max);
// connect
FOR_RANGE(int64_t, peer_machine_id, this_machine_id + 1, total_machine_num) {
sockaddr_in peer_sockaddr = GetSockAddr(peer_machine_id, port);
uint16_t peer_port = CtrlClient::Singleton()->PullPort(peer_machine_id);
sockaddr_in peer_sockaddr = GetSockAddr(peer_machine_id, peer_port);
int sockfd = socket(AF_INET, SOCK_STREAM, 0);
int rc = -1;
while (rc == -1) {
connect(sockfd, reinterpret_cast<sockaddr*>(&peer_sockaddr),
sizeof(peer_sockaddr));
}
PCHECK(rc == 0);
PCHECK(connect(sockfd, reinterpret_cast<sockaddr*>(&peer_sockaddr),
sizeof(peer_sockaddr))
== 0);
CHECK(sockfd2helper_.emplace(sockfd, NewSocketHelper(sockfd)).second);
machine_id2sockfd_[peer_machine_id] = sockfd;
}
......
......@@ -12,14 +12,13 @@ namespace oneflow {
class EpollDataCommNet final : public DataCommNet {
public:
OF_DISALLOW_COPY_AND_MOVE(EpollDataCommNet);
EpollDataCommNet() = delete;
~EpollDataCommNet();
static EpollDataCommNet* Singleton() {
return static_cast<EpollDataCommNet*>(DataCommNet::Singleton());
}
static void Init(uint16_t port);
static void Init();
const void* RegisterMemory(void* mem_ptr, size_t byte_size) override;
void UnRegisterMemory(const void* token) override;
......@@ -33,8 +32,8 @@ class EpollDataCommNet final : public DataCommNet {
void SendSocketMsg(int64_t dst_machine_id, const SocketMsg& msg);
private:
EpollDataCommNet(uint16_t port);
void InitSockets(uint16_t port);
EpollDataCommNet();
void InitSockets();
SocketHelper* GetSocketHelper(int64_t machine_id);
// Memory Desc
......
......@@ -3,19 +3,18 @@ package oneflow;
import "oneflow/core/job/plan.proto";
message BarrierRequest {
string name = 1;
int32 num = 2;
message LoadServerRequest {
}
message BarrierResponse {
message LoadServerResponse {
}
message AddWorkerRequest {
string worker_addr = 1;
message BarrierRequest {
string name = 1;
int32 num = 2;
}
message AddWorkerResponse {
message BarrierResponse {
}
enum TryLockResult {
......@@ -46,9 +45,42 @@ message WaitUntilDoneRequest {
message WaitUntilDoneResponse {
}
message FetchPlanRequest {
message PushPlanRequest {
Plan plan = 1;
}
message PushPlanResponse {
}
message ClearPlanRequest {
}
message ClearPlanResponse {
}
message FetchPlanResponse {
message PullPlanRequest {
}
message PullPlanResponse {
Plan plan = 1;
}
// Request carrying a port number to be stored by a control server.
message PushPortRequest {
int32 port = 1;
}
// Empty acknowledgement for PushPort.
message PushPortResponse {
}
// Request asking a control server to forget its stored port (no payload).
message ClearPortRequest {
}
// Empty acknowledgement for ClearPort.
message ClearPortResponse {
}
// Request asking a control server for its stored port (no payload).
message PullPortRequest {
}
// Reply to PullPort carrying the stored port number.
message PullPortResponse {
int32 port = 1;
}
#ifndef ONEFLOW_CORE_COMM_NETWORK_RPC_CTRL_CALL_H_
#define ONEFLOW_CORE_COMM_NETWORK_RPC_CTRL_CALL_H_
#ifndef ONEFLOW_CORE_CONTROL_CTRL_CALL_H_
#define ONEFLOW_CORE_CONTROL_CTRL_CALL_H_
#include "oneflow/core/comm_network/rpc/ctrl_service.h"
#include "oneflow/core/control/ctrl_service.h"
namespace oneflow {
......@@ -70,4 +70,4 @@ class CtrlCall final : public CtrlCallIf {
} // namespace oneflow
#endif // ONEFLOW_CORE_COMM_NETWORK_RPC_CTRL_CALL_H_
#endif // ONEFLOW_CORE_CONTROL_CTRL_CALL_H_
#include "oneflow/core/comm_network/ctrl_comm_network.h"
#include "oneflow/core/control/ctrl_client.h"
#include "oneflow/core/job/runtime_context.h"
namespace oneflow {
......@@ -10,42 +10,11 @@ const int64_t sleep_seconds = 10;
} // namespace
CtrlCommNet::CtrlCommNet(uint16_t port) {
ctrl_server_.reset(new CtrlServer(RuntimeCtx::Singleton()->GetThisAddr() + ":"
+ std::to_string(port)));
stubs_.reserve(JobDesc::Singleton()->TotalMachineNum());
for (int64_t i = 0; i < JobDesc::Singleton()->TotalMachineNum(); ++i) {
stubs_.push_back(CtrlService::NewStub(RuntimeCtx::Singleton()->GetAddr(i)
+ ":" + std::to_string(port)));
}
int32_t retry_idx = 0;
for (; retry_idx < max_retry_num; ++retry_idx) {
grpc::ClientContext client_ctx;
AddWorkerRequest request;
request.set_worker_addr(RuntimeCtx::Singleton()->GetThisAddr());
AddWorkerResponse response;
grpc::Status st =
GetMasterStub()->AddWorker(&client_ctx, request, &response);
if (st.error_code() == grpc::StatusCode::OK) {
LOG(INFO) << "AddWorker Successful at " << retry_idx << " times";
break;
} else if (st.error_code() == grpc::StatusCode::UNAVAILABLE) {
LOG(INFO) << "AddWorker Failed at " << retry_idx << " times";
std::this_thread::sleep_for(std::chrono::seconds(sleep_seconds));
continue;
} else {
LOG(FATAL) << st.error_message();
}
}
CHECK_LT(retry_idx, max_retry_num);
}
void CtrlCommNet::Barrier(const std::string& barrier_name) {
void CtrlClient::Barrier(const std::string& barrier_name) {
Barrier(barrier_name, JobDesc::Singleton()->TotalMachineNum());
}
void CtrlCommNet::Barrier(const std::string& barrier_name,
int32_t barrier_num) {
void CtrlClient::Barrier(const std::string& barrier_name, int32_t barrier_num) {
grpc::ClientContext client_ctx;
BarrierRequest request;
request.set_name(barrier_name);
......@@ -54,7 +23,7 @@ void CtrlCommNet::Barrier(const std::string& barrier_name,
GetMasterStub()->Barrier(&client_ctx, request, &response);
}
TryLockResult CtrlCommNet::TryLock(const std::string& name) {
TryLockResult CtrlClient::TryLock(const std::string& name) {
if (done_names_.find(name) != done_names_.end()) {
return TryLockResult::kDone;
}
......@@ -69,7 +38,7 @@ TryLockResult CtrlCommNet::TryLock(const std::string& name) {
return response.result();
}
void CtrlCommNet::NotifyDone(const std::string& name) {
void CtrlClient::NotifyDone(const std::string& name) {
grpc::ClientContext client_ctx;
NotifyDoneRequest request;
request.set_name(name);
......@@ -77,7 +46,7 @@ void CtrlCommNet::NotifyDone(const std::string& name) {
GetResponsibleStub(name)->NotifyDone(&client_ctx, request, &response);
}
void CtrlCommNet::WaitUntilDone(const std::string& name) {
void CtrlClient::WaitUntilDone(const std::string& name) {
grpc::ClientContext client_ctx;
WaitUntilDoneRequest request;
request.set_name(name);
......@@ -85,19 +54,90 @@ void CtrlCommNet::WaitUntilDone(const std::string& name) {
GetResponsibleStub(name)->WaitUntilDone(&client_ctx, request, &response);
}
void CtrlCommNet::PublishPlan(const Plan* plan) {
ctrl_server_->PublishPlan(plan);
void CtrlClient::PushPlan(const Plan& plan) {
grpc::ClientContext client_ctx;
PushPlanRequest request;
*(request.mutable_plan()) = plan;
PushPlanResponse response;
GetMasterStub()->PushPlan(&client_ctx, request, &response);
}
void CtrlCommNet::FetchPlan(Plan* plan) {
void CtrlClient::ClearPlan() {
grpc::ClientContext client_ctx;
FetchPlanRequest request;
FetchPlanResponse response;
GetMasterStub()->FetchPlan(&client_ctx, request, &response);
ClearPlanRequest request;
ClearPlanResponse response;
GetMasterStub()->ClearPlan(&client_ctx, request, &response);
}
// Retrieves the plan held by the master control server (stub 0) through the
// PullPlan RPC and copies it into `*plan`. The call is synchronous; the
// returned grpc::Status is not inspected.
void CtrlClient::PullPlan(Plan* plan) {
  PullPlanRequest pull_req;
  PullPlanResponse pull_rsp;
  grpc::ClientContext rpc_ctx;
  GetMasterStub()->PullPlan(&rpc_ctx, pull_req, &pull_rsp);
  const Plan& received = pull_rsp.plan();
  *plan = received;
}
CtrlService::Stub* CtrlCommNet::GetResponsibleStub(const std::string& key) {
// Stores `port` on this machine's own control server through the PushPort
// RPC. The call is synchronous; the returned grpc::Status is not inspected.
void CtrlClient::PushPort(uint16_t port) {
  PushPortRequest push_req;
  push_req.set_port(port);
  PushPortResponse push_rsp;
  grpc::ClientContext rpc_ctx;
  CtrlService::Stub* self_stub = GetThisStub();
  self_stub->PushPort(&rpc_ctx, push_req, &push_rsp);
}
void CtrlClient::ClearPort() {
grpc::ClientContext client_ctx;
ClearPortRequest request;
ClearPortResponse response;
GetThisStub()->ClearPort(&client_ctx, request, &response);
}
// Fetches the port stored on the control server of machine `machine_id`
// through the synchronous PullPort RPC.
//
// Fails fast (CHECK) instead of returning garbage when:
//   - `machine_id` does not index an existing stub,
//   - the RPC itself did not complete with grpc::StatusCode::OK,
//   - the returned int32 does not fit a usable TCP port (1..65535), which
//     would otherwise be silently truncated by the uint16_t return type.
uint16_t CtrlClient::PullPort(uint64_t machine_id) {
  CHECK_LT(machine_id, stubs_.size());
  grpc::ClientContext client_ctx;
  PullPortRequest request;
  PullPortResponse response;
  grpc::Status st =
      stubs_[machine_id]->PullPort(&client_ctx, request, &response);
  CHECK(st.error_code() == grpc::StatusCode::OK) << st.error_message();
  const int32_t port = response.port();
  CHECK_GT(port, 0);
  CHECK_LE(port, 65535);
  return static_cast<uint16_t>(port);
}
// Creates one stub per machine listed in the job description, in machine-id
// order, and runs LoadServer against each stub right after creating it.
CtrlClient::CtrlClient() {
  stubs_.reserve(JobDesc::Singleton()->TotalMachineNum());
  int64_t machine_id = 0;
  while (machine_id < JobDesc::Singleton()->TotalMachineNum()) {
    const std::string ctrl_addr =
        RuntimeCtx::Singleton()->GetCtrlAddr(machine_id);
    stubs_.push_back(CtrlService::NewStub(ctrl_addr));
    CtrlService::Stub* new_stub = stubs_[machine_id].get();
    LoadServer(ctrl_addr, new_stub);
    ++machine_id;
  }
}
void CtrlClient::LoadServer(const std::string& server_addr,
CtrlService::Stub* stub) {
int32_t retry_idx = 0;
for (; retry_idx < max_retry_num; ++retry_idx) {
grpc::ClientContext client_ctx;
LoadServerRequest request;
LoadServerResponse response;
grpc::Status st = stub->LoadServer(&client_ctx, request, &response);
if (st.error_code() == grpc::StatusCode::OK) {
LOG(INFO) << "LoadServer " << server_addr << " Successful at "
<< retry_idx << " times";
break;
} else if (st.error_code() == grpc::StatusCode::UNAVAILABLE) {
LOG(INFO) << "LoadServer " << server_addr << " Failed at " << retry_idx
<< " times";
std::this_thread::sleep_for(std::chrono::seconds(sleep_seconds));
continue;
} else {
LOG(FATAL) << st.error_message();
}
}
CHECK_LT(retry_idx, max_retry_num);
}
// Returns the (non-owning) stub that talks to this machine's own control
// server, i.e. the entry of `stubs_` indexed by this machine's id.
CtrlService::Stub* CtrlClient::GetThisStub() {
  const int64_t self_id = RuntimeCtx::Singleton()->this_machine_id();
  return stubs_[self_id].get();
}
CtrlService::Stub* CtrlClient::GetResponsibleStub(const std::string& key) {
int64_t machine_id =
(std::hash<std::string>{}(key)) % JobDesc::Singleton()->TotalMachineNum();
return stubs_[machine_id].get();
......
#ifndef ONEFLOW_CORE_COMM_NETWORK_CTRL_COMM_NETWORK_H_
#define ONEFLOW_CORE_COMM_NETWORK_CTRL_COMM_NETWORK_H_
#ifndef ONEFLOW_CORE_CONTROL_CTRL_CLIENT_H_
#define ONEFLOW_CORE_CONTROL_CTRL_CLIENT_H_
#include "oneflow/core/actor/actor_message.h"
#include "oneflow/core/comm_network/rpc/ctrl_server.h"
#include "oneflow/core/control/ctrl_service.h"
namespace oneflow {
class CtrlCommNet final {
class CtrlClient final {
public:
OF_DISALLOW_COPY_AND_MOVE(CtrlCommNet);
CtrlCommNet() = delete;
~CtrlCommNet() = default;
OF_DISALLOW_COPY_AND_MOVE(CtrlClient);
~CtrlClient() = default;
OF_SINGLETON(CtrlCommNet);
OF_SINGLETON(CtrlClient);
void Barrier(const std::string& barrier_name);
void Barrier(const std::string& barrier_name, int32_t barrier_num);
......@@ -21,37 +20,43 @@ class CtrlCommNet final {
void NotifyDone(const std::string& name);
void WaitUntilDone(const std::string& name);
void PublishPlan(const Plan* plan);
void FetchPlan(Plan* plan);
void PushPlan(const Plan& plan);
void ClearPlan();
void PullPlan(Plan* plan);
void PushPort(uint16_t port);
void ClearPort();
uint16_t PullPort(uint64_t machine_id);
private:
CtrlCommNet(uint16_t port);
CtrlClient();
void LoadServer(const std::string& server_addr, CtrlService::Stub* stub);
CtrlService::Stub* GetMasterStub() { return stubs_[0].get(); }
CtrlService::Stub* GetThisStub();
CtrlService::Stub* GetResponsibleStub(const std::string& key);
std::unique_ptr<CtrlServer> ctrl_server_;
std::vector<std::unique_ptr<CtrlService::Stub>> stubs_;
HashSet<std::string> done_names_;
};
#define FILE_LINE_STR __FILE__ ":" OF_PP_STRINGIZE(__LINE__)
#define OF_BARRIER() CtrlCommNet::Singleton()->Barrier(FILE_LINE_STR)
#define OF_CALL_ONCE(name, ...) \
do { \
TryLockResult lock_ret = CtrlCommNet::Singleton()->TryLock(name); \
if (lock_ret == TryLockResult::kLocked) { \
__VA_ARGS__; \
CtrlCommNet::Singleton()->NotifyDone(name); \
} else if (lock_ret == TryLockResult::kDone) { \
} else if (lock_ret == TryLockResult::kDoing) { \
CtrlCommNet::Singleton()->WaitUntilDone(name); \
} else { \
UNEXPECTED_RUN(); \
} \
#define OF_BARRIER() CtrlClient::Singleton()->Barrier(FILE_LINE_STR)
#define OF_CALL_ONCE(name, ...) \
do { \
TryLockResult lock_ret = CtrlClient::Singleton()->TryLock(name); \
if (lock_ret == TryLockResult::kLocked) { \
__VA_ARGS__; \
CtrlClient::Singleton()->NotifyDone(name); \
} else if (lock_ret == TryLockResult::kDone) { \
} else if (lock_ret == TryLockResult::kDoing) { \
CtrlClient::Singleton()->WaitUntilDone(name); \
} else { \
UNEXPECTED_RUN(); \
} \
} while (0)
} // namespace oneflow
#endif // ONEFLOW_CORE_COMM_NETWORK_CTRL_COMM_NETWORK_H_
#endif // ONEFLOW_CORE_CONTROL_CTRL_CLIENT_H_
#include "oneflow/core/comm_network/rpc/ctrl_server.h"
#include "grpc++/alarm.h"
#include "oneflow/core/job/runtime_context.h"
#include "oneflow/core/control/ctrl_server.h"
namespace oneflow {
#define ENQUEUE_REQUEST(method) \
do { \
auto call = new CtrlCall<method##Request, method##Response>(); \
call->set_request_handler( \
std::bind(&CtrlServer::method##Handler, this, call)); \
grpc_service_->RequestAsyncUnary( \
static_cast<int32_t>(CtrlMethod::k##method), call->mut_server_ctx(), \
call->mut_request(), call->mut_responder(), cq_.get(), cq_.get(), \
call); \
} while (0);
CtrlServer::~CtrlServer() {
grpc::Alarm alarm(cq_.get(), gpr_now(GPR_CLOCK_MONOTONIC), nullptr);
loop_thread_.join();
......@@ -31,25 +18,21 @@ CtrlServer::CtrlServer(const std::string& server_addr) {
cq_ = server_builder.AddCompletionQueue();
grpc_server_ = server_builder.BuildAndStart();
LOG(INFO) << "CtrlServer listening on " << server_addr;
added_worker_calls_.clear();
plan_ = nullptr;
pending_plan_calls_.clear();
port_ = -1;
loop_thread_ = std::thread(&CtrlServer::HandleRpcs, this);
}
void CtrlServer::PublishPlan(const Plan* plan) {
std::unique_lock<std::mutex> lck(plan_mtx_);
plan_ = plan;
if (plan_) {
for (auto call : pending_plan_calls_) {
*(call->mut_response()->mutable_plan()) = *plan;
call->SendResponse();
}
pending_plan_calls_.clear();
} else {
CHECK(pending_plan_calls_.empty());
}
}
// ENQUEUE_REQUEST(method): allocates a new CtrlCall for `method`, binds
// CtrlServer::<method>Handler (with `this` and the new call) as its request
// handler, and registers the call with the gRPC service via
// RequestAsyncUnary, using `cq_` for both completion-queue arguments and the
// call pointer as the tag.
// Must be expanded inside a CtrlServer member function: it uses `this`,
// `grpc_service_` and `cq_`.
// NOTE(review): the trailing ';' after `while (0)` looks redundant, but
// OF_PP_FOR_EACH_TUPLE(ENQUEUE_REQUEST, CTRL_METHOD_SEQ) presumably expands
// the macro back-to-back with no separator, so the ';' is likely required --
// confirm before removing it.
#define ENQUEUE_REQUEST(method) \
do { \
auto call = new CtrlCall<method##Request, method##Response>(); \
call->set_request_handler( \
std::bind(&CtrlServer::method##Handler, this, call)); \
grpc_service_->RequestAsyncUnary( \
static_cast<int32_t>(CtrlMethod::k##method), call->mut_server_ctx(), \
call->mut_request(), call->mut_responder(), cq_.get(), cq_.get(), \
call); \
} while (0);
void CtrlServer::HandleRpcs() {
OF_PP_FOR_EACH_TUPLE(ENQUEUE_REQUEST, CTRL_METHOD_SEQ);
......@@ -68,18 +51,10 @@ void CtrlServer::HandleRpcs() {
}
}
void CtrlServer::AddWorkerHandler(
CtrlCall<AddWorkerRequest, AddWorkerResponse>* call) {
CHECK(RuntimeCtx::Singleton()->IsThisMachineMaster());
added_worker_calls_.push_back(call);
LOG(INFO) << "Added Worker " << call->request().worker_addr();
if (added_worker_calls_.size() == JobDesc::Singleton()->TotalMachineNum()) {
for (CtrlCallIf* pending_call : added_worker_calls_) {
pending_call->SendResponse();
}
added_worker_calls_.clear();
}
ENQUEUE_REQUEST(AddWorker);
// Handles a LoadServer request: replies immediately with the (empty)
// response, then registers a fresh call object for the next LoadServer
// request.
void CtrlServer::LoadServerHandler(
CtrlCall<LoadServerRequest, LoadServerResponse>* call) {
// Nothing to compute; the reply itself is the whole result.
call->SendResponse();
// Re-arm: enqueue a new CtrlCall so a subsequent LoadServer can be served.
ENQUEUE_REQUEST(LoadServer);
}
void CtrlServer::BarrierHandler(
......@@ -153,16 +128,58 @@ void CtrlServer::WaitUntilDoneHandler(
ENQUEUE_REQUEST(WaitUntilDone);
}
void CtrlServer::FetchPlanHandler(
CtrlCall<FetchPlanRequest, FetchPlanResponse>* call) {
std::unique_lock<std::mutex> lck(plan_mtx_);
// Handles a PushPlan request: stores a copy of the pushed plan, answers every
// PullPlan call that was parked while no plan was available, acknowledges the
// push itself, and registers a fresh call object for the next PushPlan.
//
// Fixes relative to the previous version:
//   - the pushing call is now answered; without it the pusher's synchronous
//     RPC never receives a reply;
//   - the pending list is cleared after being served, so a later push does not
//     answer the same (already completed) calls again;
//   - the loop variable no longer shadows the `call` parameter.
void CtrlServer::PushPlanHandler(
    CtrlCall<PushPlanRequest, PushPlanResponse>* call) {
  plan_.reset(new Plan(call->request().plan()));
  for (auto* pending_call : pending_plan_calls_) {
    *(pending_call->mut_response()->mutable_plan()) = *plan_;
    pending_call->SendResponse();
  }
  pending_plan_calls_.clear();
  // Acknowledge the push only after the plan is stored and waiters are served.
  call->SendResponse();
  ENQUEUE_REQUEST(PushPlan);
}
// Handles a ClearPlan request: drops the stored plan, acknowledges the
// request, and registers a fresh call object for the next ClearPlan.
//
// The acknowledgement was previously missing, so the caller's synchronous
// ClearPlan RPC never received a reply.
void CtrlServer::ClearPlanHandler(
    CtrlCall<ClearPlanRequest, ClearPlanResponse>* call) {
  plan_.reset();
  call->SendResponse();
  ENQUEUE_REQUEST(ClearPlan);
}
void CtrlServer::PullPlanHandler(
CtrlCall<PullPlanRequest, PullPlanResponse>* call) {
if (plan_) {
*(call->mut_response()->mutable_plan()) = *plan_;
call->SendResponse();
} else {
pending_plan_calls_.push_back(call);
}
ENQUEUE_REQUEST(FetchPlan);
ENQUEUE_REQUEST(PullPlan);
}
// Handles a PushPort request: records the pushed port, answers every PullPort
// call that was parked while no port was available, acknowledges the push
// itself, and registers a fresh call object for the next PushPort.
//
// Fixes relative to the previous version:
//   - the pushing call is now answered; without it the pusher's synchronous
//     RPC never receives a reply;
//   - the pending list is cleared after being served, so a later push does not
//     answer the same (already completed) calls again;
//   - the loop variable no longer shadows the `call` parameter.
void CtrlServer::PushPortHandler(
    CtrlCall<PushPortRequest, PushPortResponse>* call) {
  port_ = call->request().port();
  for (auto* pending_call : pending_port_calls_) {
    pending_call->mut_response()->set_port(port_);
    pending_call->SendResponse();
  }
  pending_port_calls_.clear();
  // Acknowledge the push only after the port is stored and waiters are served.
  call->SendResponse();
  ENQUEUE_REQUEST(PushPort);
}
// Handles a ClearPort request: resets the stored port to the "not set"
// sentinel (-1), acknowledges the request, and registers a fresh call object
// for the next ClearPort.
//
// The acknowledgement was previously missing, so the caller's synchronous
// ClearPort RPC never received a reply.
void CtrlServer::ClearPortHandler(
    CtrlCall<ClearPortRequest, ClearPortResponse>* call) {
  port_ = -1;
  call->SendResponse();
  ENQUEUE_REQUEST(ClearPort);
}
// Handles a PullPort request. If a port is currently stored (port_ != -1) the
// call is answered right away with that port; otherwise the call is parked in
// pending_port_calls_. Either way a fresh call object is registered for the
// next PullPort.
void CtrlServer::PullPortHandler(
    CtrlCall<PullPortRequest, PullPortResponse>* call) {
  const bool port_missing = (port_ == -1);
  if (port_missing) {
    pending_port_calls_.push_back(call);
  } else {
    call->mut_response()->set_port(port_);
    call->SendResponse();
  }
  ENQUEUE_REQUEST(PullPort);
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_COMM_NETWORK_RPC_CTRL_SERVER_H_
#define ONEFLOW_CORE_COMM_NETWORK_RPC_CTRL_SERVER_H_
#ifndef ONEFLOW_CORE_CONTROL_CTRL_SERVER_H_
#define ONEFLOW_CORE_CONTROL_CTRL_SERVER_H_
#include "grpc++/alarm.h"
#include "grpc++/server_builder.h"
#include "oneflow/core/comm_network/rpc/ctrl_call.h"
#include "oneflow/core/control/ctrl_call.h"
namespace oneflow {
......@@ -13,7 +14,6 @@ class CtrlServer final {
~CtrlServer();
CtrlServer(const std::string& server_addr);
void PublishPlan(const Plan* plan);
private:
void HandleRpcs();
......@@ -29,19 +29,19 @@ class CtrlServer final {
std::unique_ptr<grpc::ServerCompletionQueue> cq_;
std::unique_ptr<grpc::Server> grpc_server_;
std::thread loop_thread_;
// AddWorker
std::list<CtrlCallIf*> added_worker_calls_;
// Barrier
HashMap<std::string, std::pair<std::list<CtrlCallIf*>, int32_t>>
barrier_calls_;
// TryLock, NotifyDone, WaitUntilDone
HashMap<std::string, void*> name2lock_status_; // TODO: erase outdated item
// FetchPlan
std::mutex plan_mtx_;
const Plan* plan_;
std::list<CtrlCall<FetchPlanRequest, FetchPlanResponse>*> pending_plan_calls_;
HashMap<std::string, void*> name2lock_status_;
// PushPlan, PullPlan
std::unique_ptr<Plan> plan_;
std::list<CtrlCall<PullPlanRequest, PullPlanResponse>*> pending_plan_calls_;
// PushPort, ClearPort, PullPort
int32_t port_;
std::list<CtrlCall<PullPortRequest, PullPortResponse>*> pending_port_calls_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_COMM_NETWORK_RPC_CTRL_SERVER_H_
#endif // ONEFLOW_CORE_CONTROL_CTRL_SERVER_H_
#include "oneflow/core/comm_network/rpc/ctrl_service.h"
#include "oneflow/core/control/ctrl_service.h"
#include "grpc++/impl/codegen/client_unary_call.h"
namespace oneflow {
......
#ifndef ONEFLOW_CORE_COMM_NETWORK_RPC_CTRL_SERVICE_H_
#define ONEFLOW_CORE_COMM_NETWORK_RPC_CTRL_SERVICE_H_
#ifndef ONEFLOW_CORE_CONTROL_CTRL_SERVICE_H_
#define ONEFLOW_CORE_CONTROL_CTRL_SERVICE_H_
#include "grpc++/grpc++.h"
#include "grpc++/impl/codegen/async_stream.h"
......@@ -10,19 +10,24 @@
#include "grpc++/impl/codegen/status.h"
#include "grpc++/impl/codegen/stub_options.h"
#include "grpc++/impl/codegen/sync_stream.h"
#include "oneflow/core/comm_network/rpc/control.pb.h"
#include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/control/control.pb.h"
namespace oneflow {
#define CTRL_METHOD_SEQ \
OF_PP_MAKE_TUPLE_SEQ(AddWorker) \
OF_PP_MAKE_TUPLE_SEQ(LoadServer) \
OF_PP_MAKE_TUPLE_SEQ(Barrier) \
OF_PP_MAKE_TUPLE_SEQ(TryLock) \
OF_PP_MAKE_TUPLE_SEQ(NotifyDone) \
OF_PP_MAKE_TUPLE_SEQ(WaitUntilDone) \
OF_PP_MAKE_TUPLE_SEQ(FetchPlan)
OF_PP_MAKE_TUPLE_SEQ(PushPlan) \
OF_PP_MAKE_TUPLE_SEQ(ClearPlan) \
OF_PP_MAKE_TUPLE_SEQ(PullPlan) \
OF_PP_MAKE_TUPLE_SEQ(PushPort) \
OF_PP_MAKE_TUPLE_SEQ(ClearPort) \
OF_PP_MAKE_TUPLE_SEQ(PullPort)
enum class CtrlMethod {
#define MAKE_ENTRY(method) k##method,
......@@ -66,4 +71,4 @@ class CtrlService final {
} // namespace oneflow
#endif // ONEFLOW_CORE_COMM_NETWORK_RPC_CTRL_SERVICE_H_
#endif // ONEFLOW_CORE_CONTROL_CTRL_SERVICE_H_
......@@ -3,7 +3,8 @@ package oneflow;
message Machine {
string addr = 1; // domain name or ip
string name = 4;
int32 port = 2;
string name = 3;
}
enum DeviceType {
......@@ -15,8 +16,6 @@ message Resource {
repeated Machine machine = 1;
int64 device_num_per_machine = 2;
DeviceType device_type = 3;
int32 port_min = 4;
int32 port_max = 5;
}
// If one machine named "machine_xxx" and device_num_per_machine = 4
......
#include <gflags/gflags.h>
#include "oneflow/core/comm_network/epoll/epoll_data_comm_network.h"
#include "oneflow/core/control/ctrl_client.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/job/runtime_context.h"
#include "oneflow/core/kernel/kernel_manager.h"
#include "oneflow/core/thread/thread_manager.h"
DEFINE_string(plan_filepath, "", "");
DEFINE_string(this_machine_name, "", "");
DEFINE_int32(ctrl_port, -1, "");
DEFINE_int32(data_port, -1, "");
namespace oneflow {
class Runtime final {
......@@ -82,10 +78,10 @@ void Runtime::NewAllSingleton(const Plan& plan,
JobDesc::NewSingleton(plan.job_desc());
IDMgr::NewSingleton();
RuntimeCtx::NewSingleton(this_machine_name);
CtrlCommNet::NewSingleton(FLAGS_ctrl_port);
CtrlClient::NewSingleton();
KernelMgr::NewSingleton(plan);
#ifdef PLATFORM_POSIX
EpollDataCommNet::Init(FLAGS_data_port);
EpollDataCommNet::Init();
#endif
SnapshotMgr::NewSingleton(plan);
RegstMgr::NewSingleton();
......@@ -100,7 +96,7 @@ void Runtime::DeleteAllSingleton() {
SnapshotMgr::DeleteSingleton();
delete DataCommNet::Singleton();
KernelMgr::DeleteSingleton();
CtrlCommNet::DeleteSingleton();
CtrlClient::DeleteSingleton();
RuntimeCtx::DeleteSingleton();
IDMgr::DeleteSingleton();
JobDesc::DeleteSingleton();
......@@ -122,6 +118,9 @@ void Runtime::SendCmdMsg(const std::vector<const TaskProto*>& tasks,
} // namespace oneflow
DEFINE_string(plan_filepath, "", "");
DEFINE_string(this_machine_name, "", "");
int main(int argc, char** argv) {
google::InitGoogleLogging(argv[0]);
google::ParseCommandLineFlags(&argc, &argv, true);
......
......@@ -2,8 +2,9 @@
namespace oneflow {
std::string RuntimeCtx::GetAddr(int64_t machine_id) const {
return JobDesc::Singleton()->resource().machine(machine_id).addr();
std::string RuntimeCtx::GetCtrlAddr(int64_t machine_id) const {
const Machine& mchn = JobDesc::Singleton()->resource().machine(machine_id);
return mchn.addr() + ":" + std::to_string(mchn.port());
}
PersistentInStream* RuntimeCtx::GetDataInStream(const std::string& name) {
......
......@@ -17,9 +17,9 @@ class RuntimeCtx final {
int64_t this_machine_id() const { return this_machine_id_; }
bool IsThisMachineMaster() const { return this_machine_id_ == 0; }
std::string GetThisAddr() const { return GetAddr(this_machine_id_); }
std::string GetMasterCtrlAddr() const { return GetAddr(0); }
std::string GetAddr(int64_t machine_id) const;
std::string GetThisCtrlAddr() const { return GetCtrlAddr(this_machine_id_); }
std::string GetMasterCtrlAddr() const { return GetCtrlAddr(0); }
std::string GetCtrlAddr(int64_t machine_id) const;
ThreadSafeCounter& mut_model_init_cnt() { return model_init_cnt_; }
......
#include <gflags/gflags.h>
#include <glog/logging.h>
#include "oneflow/core/comm_network/ctrl_comm_network.h"
#include "oneflow/core/common/str_util.h"
#include "oneflow/core/control/ctrl_client.h"
#include "oneflow/core/control/ctrl_server.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/job/plan.pb.h"
#include "oneflow/core/job/runtime_context.h"
......@@ -21,13 +22,12 @@ class Scheduler final {
private:
Scheduler() = default;
uint16_t GetNextPort();
void NewAllSingleton(const std::string& job_conf_filepath,
const std::string& this_machine_name, char** env);
std::string GetEnvPrefix();
void DeleteAllSingleton();
void SystemCall(const std::string& cmd);
uint16_t next_port_;
std::unique_ptr<CtrlServer> ctrl_server_;
std::string env_prefix_;
};
......@@ -44,13 +44,13 @@ void Scheduler::Process(const std::string& job_conf_filepath,
<< "-plan_filepath=" << naive_plan_filepath;
SystemCall(compile_cmd.str());
ParseProtoFromTextFile(naive_plan_filepath, plan.get());
CtrlCommNet::Singleton()->PublishPlan(plan.get());
CtrlClient::Singleton()->PushPlan(*plan);
} else {
CtrlCommNet::Singleton()->FetchPlan(plan.get());
CtrlClient::Singleton()->PullPlan(plan.get());
}
OF_BARRIER();
if (RuntimeCtx::Singleton()->IsThisMachineMaster()) {
CtrlCommNet::Singleton()->PublishPlan(nullptr);
CtrlClient::Singleton()->ClearPlan();
} else {
PrintProtoToTextFile(*plan, naive_plan_filepath);
}
......@@ -58,15 +58,9 @@ void Scheduler::Process(const std::string& job_conf_filepath,
std::stringstream runtime_cmd;
runtime_cmd << "./runtime "
<< "-plan_filepath=" << naive_plan_filepath << " "
<< "-this_machine_name=" << this_machine_name << " "
<< "-ctrl_port=" << GetNextPort() << " "
<< "-data_port=" << GetNextPort();
<< "-this_machine_name=" << this_machine_name;
SystemCall(runtime_cmd.str());
}
uint16_t Scheduler::GetNextPort() {
CHECK_LE(next_port_, JobDesc::Singleton()->resource().port_max());
return next_port_++;
DeleteAllSingleton();
}
void Scheduler::NewAllSingleton(const std::string& job_conf_filepath,
......@@ -75,10 +69,11 @@ void Scheduler::NewAllSingleton(const std::string& job_conf_filepath,
oneflow::JobConf job_conf;
oneflow::ParseProtoFromTextFile(job_conf_filepath, &job_conf);
JobDesc::NewSingleton(job_conf);
next_port_ = JobDesc::Singleton()->resource().port_min();
IDMgr::NewSingleton();
RuntimeCtx::NewSingleton(this_machine_name);
CtrlCommNet::NewSingleton(GetNextPort());
ctrl_server_.reset(
new CtrlServer(RuntimeCtx::Singleton()->GetThisCtrlAddr()));
CtrlClient::NewSingleton();
env_prefix_ = "";
std::stringstream ss;
while (*env) {
......@@ -88,6 +83,14 @@ void Scheduler::NewAllSingleton(const std::string& job_conf_filepath,
env_prefix_ = ss.str();
}
// Tears down everything NewAllSingleton created, in the reverse of the
// creation order (JobDesc, IDMgr, RuntimeCtx, ctrl_server_, CtrlClient).
void Scheduler::DeleteAllSingleton() {
// The client goes first, then the server it may hold stubs to.
CtrlClient::DeleteSingleton();
ctrl_server_.reset();
RuntimeCtx::DeleteSingleton();
IDMgr::DeleteSingleton();
JobDesc::DeleteSingleton();
}
void Scheduler::SystemCall(const std::string& cmd) {
LOG(INFO) << "SystemCall: [" << cmd << "]";
CHECK_EQ(std::system(cmd.c_str()), 0);
......
#ifndef ONEFLOW_CORE_PERSISTENCE_SNAPSHOT_H_
#define ONEFLOW_CORE_PERSISTENCE_SNAPSHOT_H_
#include "oneflow/core/comm_network/ctrl_comm_network.h"
#include "oneflow/core/control/ctrl_client.h"
#include "oneflow/core/persistence/normal_persistent_in_stream.h"
#include "oneflow/core/persistence/persistent_out_stream.h"
......
......@@ -8,7 +8,7 @@ SCHEDULER_CMD='GLOG_logtostderr=0 GLOG_log_dir=./log GLOG_v=0 GLOG_logbuflevel=-
set +e
for host in "${hosts[@]}"
do
ssh $USER@$host "/usr/sbin/fuser -k 9000/tcp 9001/tcp 9002/tcp 9003/tcp 9004/tcp"
ssh $USER@$host "/usr/sbin/fuser -k 9000/tcp"
done
set -e
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册