提交 9ee5d499 编写于 作者: W willzhang4a58

barrier

上级 5e5b38aa
......@@ -2,6 +2,8 @@ syntax = "proto3";
package oneflow;
// Request message for the Barrier RPC.
message BarrierRequest {
// Name of the barrier; the server groups pending calls by this name.
string name = 1;
// Number of calls the server collects under `name` before it answers them.
int32 num = 2;
}
message BarrierResponse {
......
......@@ -40,4 +40,24 @@ void CtrlCommNet::Init() {
CHECK_LT(retry_idx, max_retry_num);
}
// Convenience overload: forwards to Barrier(barrier_name, barrier_num) with the
// job's total machine count as the number of participants.
void CtrlCommNet::Barrier(const std::string& barrier_name) {
  const auto total_machine_num = JobDesc::Singleton()->TotalMachineNum();
  Barrier(barrier_name, total_machine_num);
}
void CtrlCommNet::Barrier(const std::string& barrier_name,
int32_t barrier_num) {
grpc::ClientContext client_ctx;
BarrierRequest request;
request.set_name(barrier_name);
request.set_num(barrier_num);
BarrierResponse response;
GetMasterStub()->Barrier(&client_ctx, request, &response);
}
// Maps `key` to one of the stubs in stubs_: the key's std::hash value modulo
// the job's total machine count is used as the index.
CtrlService::Stub* CtrlCommNet::GetResponsibleStub(const std::string& key) {
  const std::size_t key_hash = std::hash<std::string>{}(key);
  int64_t machine_id = key_hash % JobDesc::Singleton()->TotalMachineNum();
  return stubs_[machine_id].get();
}
} // namespace oneflow
......@@ -15,7 +15,8 @@ class CtrlCommNet final {
void Init();
void Barrier(const std::string& barrier_name) {}
// Barrier overload that uses JobDesc::Singleton()->TotalMachineNum() as the
// participant count.
void Barrier(const std::string& barrier_name);
// Sends a BarrierRequest carrying barrier_name and barrier_num to the master
// stub through the Barrier RPC.
void Barrier(const std::string& barrier_name, int32_t barrier_num);
// 0 : locked
// 1 : done
......@@ -36,6 +37,7 @@ class CtrlCommNet final {
private:
CtrlCommNet() = default;
// Returns the stub stored at index 0 of stubs_; Barrier() issues its RPC
// through this stub.
CtrlService::Stub* GetMasterStub() { return stubs_[0].get(); }
// Returns the stub chosen by hashing `key` modulo the total machine count.
CtrlService::Stub* GetResponsibleStub(const std::string& key);
std::unique_ptr<CtrlServer> ctrl_server_;
std::vector<std::unique_ptr<CtrlService::Stub>> stubs_;
......
......@@ -18,9 +18,8 @@ namespace oneflow {
// Stops the RPC loop thread, then shuts down the gRPC server and its
// completion queue.
//
// The previous text reset grpc_server_ and cq_ before calling Shutdown() on
// them, which dereferences null unique_ptrs. The explicit resets are dropped:
// the owning unique_ptr members release the objects when the CtrlServer itself
// is destroyed.
CtrlServer::~CtrlServer() {
  // Post an immediately-expiring alarm with a null tag on the completion
  // queue. NOTE(review): HandleRpcs appears to leave its loop when it dequeues
  // a null call, which is what lets the join below return - confirm.
  grpc::Alarm alarm(cq_.get(), gpr_now(GPR_CLOCK_MONOTONIC), nullptr);
  loop_thread_.join();
  // Shut the server down before the completion queue it was registered with.
  grpc_server_->Shutdown();
  cq_->Shutdown();
}
CtrlServer::CtrlServer(const std::string& server_addr) {
......@@ -48,7 +47,6 @@ void CtrlServer::HandleRpcs() {
if (call) {
call->Process();
} else {
cq_->Shutdown();
break;
}
}
......@@ -61,14 +59,36 @@ void CtrlServer::AddWorkerHandler(CtrlCallIf* call) {
auto addworker_call = static_cast<AddWorkerCtrlCall*>(call);
LOG(INFO) << "Add Worker " << addworker_call->request().worker_ctrl_addr();
if (added_worker_calls_.size() == JobDesc::Singleton()->TotalMachineNum()) {
for (CtrlCallIf* added_call : added_worker_calls_) {
added_call->SendResponse();
for (CtrlCallIf* pending_call : added_worker_calls_) {
pending_call->SendResponse();
}
added_worker_calls_.clear();
}
ENQUEUE_REQUEST(AddWorker);
}
void CtrlServer::BarrierHandler(CtrlCallIf* call) { TODO(); }
// Handles one incoming Barrier call.
//
// Calls are grouped in barrier_calls_ by the request's `name`. The first call
// for a name records the request's `num` as the expected participant count;
// every later call for that name must carry the same `num` (CHECK_EQ). Once
// the number of pending calls reaches the expected count, a response is sent
// to each of them and the entry is removed. Finally a new Barrier request is
// enqueued via ENQUEUE_REQUEST.
//
// call: the CtrlCallIf for this RPC; it is cast to
//       CtrlCall<BarrierRequest, BarrierResponse> to read the request.
void CtrlServer::BarrierHandler(CtrlCallIf* call) {
  using BarrierCtrlCall = CtrlCall<BarrierRequest, BarrierResponse>;
  auto barrier_call = static_cast<BarrierCtrlCall*>(call);
  const std::string& barrier_name = barrier_call->request().name();
  const int32_t barrier_num = barrier_call->request().num();
  // A non-positive participant count can never match the pending-list size,
  // so such a barrier would keep its callers waiting forever. Reject it.
  CHECK_GT(barrier_num, 0);
  auto barrier_call_it = barrier_calls_.find(barrier_name);
  if (barrier_call_it == barrier_calls_.end()) {
    barrier_call_it =
        barrier_calls_
            .emplace(barrier_name,
                     std::make_pair(std::list<CtrlCallIf*>{}, barrier_num))
            .first;
  }
  std::list<CtrlCallIf*>& pending_calls = barrier_call_it->second.first;
  const int32_t expected_num = barrier_call_it->second.second;
  CHECK_EQ(barrier_num, expected_num);
  pending_calls.push_back(call);
  // expected_num is known to be positive here, so the cast is lossless and
  // the comparison no longer mixes signed and unsigned operands.
  if (pending_calls.size() == static_cast<size_t>(expected_num)) {
    for (CtrlCallIf* pending_call : pending_calls) {
      pending_call->SendResponse();
    }
    barrier_calls_.erase(barrier_call_it);
  }
  ENQUEUE_REQUEST(Barrier);
}
} // namespace oneflow
......@@ -27,8 +27,10 @@ class CtrlServer final {
std::unique_ptr<CtrlService::AsyncService> grpc_service_;
std::unique_ptr<grpc::ServerCompletionQueue> cq_;
std::unique_ptr<grpc::Server> grpc_server_;
std::list<CtrlCallIf*> added_worker_calls_;
std::thread loop_thread_;
std::list<CtrlCallIf*> added_worker_calls_;
HashMap<std::string, std::pair<std::list<CtrlCallIf*>, int32_t>>
barrier_calls_;
};
} // namespace oneflow
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册