未验证 提交 aebb9f5b 编写于 作者: qq_22305325's avatar qq_22305325 提交者: GitHub

Refactor ctrl server and client (#4234)

* refactor CtrlClient

* RpcServer

* use Rpc

* Update rpc_client.cpp

add comment
Co-authored-by: Nlixinqi <lixinqi0703106@163.com>
Co-authored-by: Noneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
上级 363b8de9
......@@ -14,175 +14,13 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/control/ctrl_client.h"
#include "oneflow/core/job/machine_context.h"
#include "oneflow/core/job/env_desc.h"
namespace oneflow {
namespace {
const int32_t max_retry_num = 60;
const int64_t sleep_seconds = 10;
#define GRPC_CHECK(x) CHECK_EQ(x.error_code(), grpc::StatusCode::OK)
template<CtrlMethod ctrl_method>
class ClientCall final {
public:
OF_DISALLOW_COPY_AND_MOVE(ClientCall);
ClientCall() = default;
~ClientCall() = default;
CtrlRequest<ctrl_method>* mut_request() { return &request_; }
const CtrlResponse<ctrl_method>& response() const { return response_; }
void operator()(CtrlService::Stub* stub) {
grpc::ClientContext client_ctx;
GRPC_CHECK(stub->CallMethod<ctrl_method>(&client_ctx, request_, &response_));
}
private:
CtrlRequest<ctrl_method> request_;
CtrlResponse<ctrl_method> response_;
};
} // namespace
CtrlClient::~CtrlClient() {
{
std::unique_lock<std::mutex> lck(need_heartbeat_thread_stop_mtx_);
need_heartbeat_thread_stop_ = true;
}
heartbeat_thread_.join();
}
void CtrlClient::Barrier(const std::string& barrier_name) {
Barrier(barrier_name, Global<EnvDesc>::Get()->TotalMachineNum());
}
void CtrlClient::Barrier(const std::string& barrier_name, int32_t barrier_num) {
ClientCall<CtrlMethod::kBarrier> call;
call.mut_request()->set_name(barrier_name);
call.mut_request()->set_num(barrier_num);
call(GetMasterStub());
}
TryLockResult CtrlClient::TryLock(const std::string& name) {
{
std::unique_lock<std::mutex> lck(done_names_mtx_);
if (done_names_.find(name) != done_names_.end()) { return TryLockResult::kDone; }
}
ClientCall<CtrlMethod::kTryLock> call;
call.mut_request()->set_name(name);
call(GetResponsibleStub(name));
if (call.response().result() == TryLockResult::kDone) {
std::unique_lock<std::mutex> lck(done_names_mtx_);
done_names_.insert(name);
}
return call.response().result();
}
void CtrlClient::NotifyDone(const std::string& name) {
ClientCall<CtrlMethod::kNotifyDone> call;
call.mut_request()->set_name(name);
call(GetResponsibleStub(name));
}
void CtrlClient::WaitUntilDone(const std::string& name) {
ClientCall<CtrlMethod::kWaitUntilDone> call;
call.mut_request()->set_name(name);
call(GetResponsibleStub(name));
}
void CtrlClient::PushKV(const std::string& k, std::function<void(std::string*)> VSetter) {
ClientCall<CtrlMethod::kPushKV> call;
call.mut_request()->set_key(k);
VSetter(call.mut_request()->mutable_val());
call(GetResponsibleStub(k));
}
void CtrlClient::PushMasterKV(const std::string& k, std::function<void(std::string*)> VSetter) {
ClientCall<CtrlMethod::kPushKV> call;
call.mut_request()->set_key(k);
VSetter(call.mut_request()->mutable_val());
call(GetMasterStub());
}
void CtrlClient::PushKV(const std::string& k, const std::string& v) {
PushKV(k, [&](std::string* o) { *o = v; });
}
void CtrlClient::PushKV(const std::string& k, const PbMessage& msg) {
PushKV(k, [&](std::string* o) { msg.SerializeToString(o); });
}
void CtrlClient::PushMasterKV(const std::string& k, const PbMessage& msg) {
PushMasterKV(k, [&](std::string* o) { msg.SerializeToString(o); });
}
void CtrlClient::ClearKV(const std::string& k) {
ClientCall<CtrlMethod::kClearKV> call;
call.mut_request()->set_key(k);
call(GetResponsibleStub(k));
}
void CtrlClient::ClearMasterKV(const std::string& k) {
ClientCall<CtrlMethod::kClearKV> call;
call.mut_request()->set_key(k);
call(GetMasterStub());
}
void CtrlClient::PullKV(const std::string& k, std::function<void(const std::string&)> VGetter) {
ClientCall<CtrlMethod::kPullKV> call;
call.mut_request()->set_key(k);
call(GetResponsibleStub(k));
VGetter(call.response().val());
}
void CtrlClient::PullMasterKV(const std::string& k,
std::function<void(const std::string&)> VGetter) {
ClientCall<CtrlMethod::kPullKV> call;
call.mut_request()->set_key(k);
call(GetMasterStub());
VGetter(call.response().val());
}
void CtrlClient::PullKV(const std::string& k, std::string* v) {
PullKV(k, [&](const std::string& i) { *v = i; });
}
void CtrlClient::PullKV(const std::string& k, PbMessage* msg) {
PullKV(k, [&](const std::string& i) { msg->ParseFromString(i); });
}
void CtrlClient::PullMasterKV(const std::string& k, PbMessage* msg) {
PullMasterKV(k, [&](const std::string& i) { msg->ParseFromString(i); });
}
void CtrlClient::PushActEvent(const ActEvent& act_event) {
ClientCall<CtrlMethod::kPushActEvent> call;
*(call.mut_request()->mutable_act_event()) = act_event;
call(GetMasterStub());
}
void CtrlClient::Clear() {
ClientCall<CtrlMethod::kClear> call;
call(GetThisStub());
std::unique_lock<std::mutex> lck(done_names_mtx_);
done_names_.clear();
}
int32_t CtrlClient::IncreaseCount(const std::string& k, int32_t v) {
ClientCall<CtrlMethod::kIncreaseCount> call;
call.mut_request()->set_key(k);
call.mut_request()->set_val(v);
call(GetResponsibleStub(k));
return call.response().val();
}
void CtrlClient::EraseCount(const std::string& k) {
ClientCall<CtrlMethod::kEraseCount> call;
call.mut_request()->set_key(k);
call(GetResponsibleStub(k));
}
CtrlClient::CtrlClient() {
......@@ -219,36 +57,4 @@ CtrlClient::CtrlClient() {
});
}
void CtrlClient::LoadServer(const std::string& server_addr, CtrlService::Stub* stub) {
int32_t retry_idx = 0;
for (; retry_idx < max_retry_num; ++retry_idx) {
grpc::ClientContext client_ctx;
LoadServerRequest request;
request.set_addr(server_addr);
LoadServerResponse response;
grpc::Status st = stub->CallMethod<CtrlMethod::kLoadServer>(&client_ctx, request, &response);
if (st.error_code() == grpc::StatusCode::OK) {
LOG(INFO) << "LoadServer " << server_addr << " Successful at " << retry_idx << " times";
break;
} else if (st.error_code() == grpc::StatusCode::UNAVAILABLE) {
LOG(INFO) << "LoadServer " << server_addr << " Failed at " << retry_idx << " times"
<< " error_code " << st.error_code() << " error_message " << st.error_message();
std::this_thread::sleep_for(std::chrono::seconds(sleep_seconds));
continue;
} else {
LOG(FATAL) << st.error_message();
}
}
CHECK_LT(retry_idx, max_retry_num);
}
CtrlService::Stub* CtrlClient::GetThisStub() {
return stubs_[Global<MachineCtx>::Get()->this_machine_id()].get();
}
CtrlService::Stub* CtrlClient::GetResponsibleStub(const std::string& key) {
int64_t machine_id = (std::hash<std::string>{}(key)) % Global<EnvDesc>::Get()->TotalMachineNum();
return stubs_[machine_id].get();
}
} // namespace oneflow
......@@ -16,72 +16,18 @@ limitations under the License.
#ifndef ONEFLOW_CORE_CONTROL_CTRL_CLIENT_H_
#define ONEFLOW_CORE_CONTROL_CTRL_CLIENT_H_
#include "oneflow/core/actor/actor_message.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/control/ctrl_service.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/control/rpc_client.h"
namespace oneflow {
class CtrlClient final {
class CtrlClient final : public RpcClient {
public:
OF_DISALLOW_COPY_AND_MOVE(CtrlClient);
~CtrlClient();
void Barrier(const std::string& barrier_name);
void Barrier(const std::string& barrier_name, int32_t barrier_num);
TryLockResult TryLock(const std::string& name);
void NotifyDone(const std::string& name);
void WaitUntilDone(const std::string& name);
void PushKV(const std::string& k, std::function<void(std::string*)> VSetter);
void PushKV(const std::string& k, const std::string& v);
void PushKV(const std::string& k, const PbMessage& msg);
void PushMasterKV(const std::string& k, const PbMessage& msg);
template<typename T>
typename std::enable_if<std::is_arithmetic<T>::value>::type PushKVT(const std::string& k, T v) {
PushKV(k, std::to_string(v));
}
void ClearKV(const std::string& k);
void ClearMasterKV(const std::string& k);
void PullKV(const std::string& k, std::function<void(const std::string&)> VGetter);
void PullKV(const std::string& k, std::string* v);
void PullKV(const std::string& k, PbMessage* msg);
void PullMasterKV(const std::string& k, PbMessage* msg);
template<typename T>
typename std::enable_if<std::is_arithmetic<T>::value>::type PullKVT(const std::string& k, T* v) {
std::string v_str;
PullKV(k, &v_str);
*v = oneflow_cast<T>(v_str);
}
void PushActEvent(const ActEvent&);
void Clear();
int32_t IncreaseCount(const std::string& k, int32_t v);
int32_t IncreaseCount(const std::string& k) { return IncreaseCount(k, 1); }
void EraseCount(const std::string& k);
~CtrlClient() override = default;
private:
friend class Global<CtrlClient>;
CtrlClient();
void LoadServer(const std::string& server_addr, CtrlService::Stub* stub);
void PushMasterKV(const std::string& k, std::function<void(std::string*)> VSetter);
void PullMasterKV(const std::string& k, std::function<void(const std::string&)> VGetter);
CtrlService::Stub* GetMasterStub() { return stubs_[0].get(); }
CtrlService::Stub* GetThisStub();
CtrlService::Stub* GetResponsibleStub(const std::string& key);
std::vector<std::unique_ptr<CtrlService::Stub>> stubs_;
std::mutex done_names_mtx_;
HashSet<std::string> done_names_;
bool need_heartbeat_thread_stop_;
std::mutex need_heartbeat_thread_stop_mtx_;
std::thread heartbeat_thread_;
};
#define FILE_LINE_STR __FILE__ ":" OF_PP_STRINGIZE(__LINE__)
......
......@@ -21,23 +21,7 @@ limitations under the License.
namespace oneflow {
namespace {
int ExtractPortFromAddr(const std::string& addr) {
size_t pos = addr.find(':');
return oneflow_cast<int>(addr.substr(pos + 1));
}
} // namespace
CtrlServer::~CtrlServer() {
// NOTE(chengcheng): This enqueues a special event (with a null tag) that causes
// the completion queue to be shut down on the polling thread.
grpc::Alarm alarm(cq_.get(), gpr_now(GPR_CLOCK_MONOTONIC), nullptr);
loop_thread_.join();
}
CtrlServer::CtrlServer() : is_first_connect_(true), this_machine_addr_("") {
CtrlServer::CtrlServer() : RpcServer(), is_first_connect_(true), this_machine_addr_("") {
Init();
int port = Global<EnvDesc>::Get()->ctrl_port();
grpc::ServerBuilder server_builder;
......@@ -55,183 +39,15 @@ CtrlServer::CtrlServer() : is_first_connect_(true), this_machine_addr_("") {
loop_thread_ = std::thread(&CtrlServer::HandleRpcs, this);
}
void CtrlServer::HandleRpcs() {
EnqueueRequests();
void* tag = nullptr;
bool ok = false;
// NOTE(chengcheng): The is_shutdown bool flag make sure that 'ok = false' occurs ONLY after
// cq_->Shutdown() for security check.
bool is_shutdown = false;
// NOTE(chengcheng): The final end is that cq_->Next() get false and cq_ is empty with no item.
while (cq_->Next(&tag, &ok)) {
auto call = static_cast<CtrlCallIf*>(tag);
if (!ok) {
// NOTE(chengcheng): After call grpc_server_->Shutdown() and cq_->Shutdown(),
// there will trigger some cancel tag items on each RPC. And cq_->Next() can get these tag
// with ok = false. Then delete the tag with CtrlCallIf pointer for recovery.
CHECK(is_shutdown);
CHECK(call);
delete call;
continue;
}
if (call) {
call->Process();
} else {
// NOTE(chengcheng): A null `call` indicates that this is the shutdown alarm.
CHECK(!is_shutdown);
is_shutdown = true;
grpc_server_->Shutdown();
cq_->Shutdown();
// NOTE(chengcheng): You CANNOT use code 'break;' in this block because that
// there still be items in the cq_.
// 'break;'
}
void CtrlServer::OnLoadServer(CtrlCall<CtrlMethod::kLoadServer>* call) {
if (this->is_first_connect_) {
this->this_machine_addr_ = call->request().addr();
this->is_first_connect_ = false;
} else {
CHECK_EQ(call->request().addr(), this->this_machine_addr_);
}
}
void CtrlServer::Init() {
Add([this](CtrlCall<CtrlMethod::kLoadServer>* call) {
if (this->is_first_connect_) {
this->this_machine_addr_ = call->request().addr();
this->is_first_connect_ = false;
} else {
CHECK_EQ(call->request().addr(), this->this_machine_addr_);
}
call->SendResponse();
EnqueueRequest<CtrlMethod::kLoadServer>();
});
Add([this](CtrlCall<CtrlMethod::kBarrier>* call) {
const std::string& barrier_name = call->request().name();
int32_t barrier_num = call->request().num();
auto barrier_call_it = barrier_calls_.find(barrier_name);
if (barrier_call_it == barrier_calls_.end()) {
barrier_call_it =
barrier_calls_
.emplace(barrier_name, std::make_pair(std::list<CtrlCallIf*>{}, barrier_num))
.first;
}
CHECK_EQ(barrier_num, barrier_call_it->second.second);
barrier_call_it->second.first.push_back(call);
if (barrier_call_it->second.first.size() == barrier_call_it->second.second) {
for (CtrlCallIf* pending_call : barrier_call_it->second.first) {
pending_call->SendResponse();
}
barrier_calls_.erase(barrier_call_it);
}
EnqueueRequest<CtrlMethod::kBarrier>();
});
Add([this](CtrlCall<CtrlMethod::kTryLock>* call) {
const std::string& lock_name = call->request().name();
auto name2lock_status_it = name2lock_status_.find(lock_name);
if (name2lock_status_it == name2lock_status_.end()) {
call->mut_response()->set_result(TryLockResult::kLocked);
auto waiting_until_done_calls = new std::list<CtrlCallIf*>;
CHECK(name2lock_status_.emplace(lock_name, waiting_until_done_calls).second);
} else {
if (name2lock_status_it->second) {
call->mut_response()->set_result(TryLockResult::kDoing);
} else {
call->mut_response()->set_result(TryLockResult::kDone);
}
}
call->SendResponse();
EnqueueRequest<CtrlMethod::kTryLock>();
});
Add([this](CtrlCall<CtrlMethod::kNotifyDone>* call) {
const std::string& lock_name = call->request().name();
auto name2lock_status_it = name2lock_status_.find(lock_name);
auto waiting_calls = static_cast<std::list<CtrlCallIf*>*>(name2lock_status_it->second);
for (CtrlCallIf* waiting_call : *waiting_calls) { waiting_call->SendResponse(); }
delete waiting_calls;
name2lock_status_it->second = nullptr;
call->SendResponse();
EnqueueRequest<CtrlMethod::kNotifyDone>();
});
Add([this](CtrlCall<CtrlMethod::kWaitUntilDone>* call) {
const std::string& lock_name = call->request().name();
void* lock_status = name2lock_status_.at(lock_name);
if (lock_status) {
auto waiting_calls = static_cast<std::list<CtrlCallIf*>*>(lock_status);
waiting_calls->push_back(call);
} else {
call->SendResponse();
}
EnqueueRequest<CtrlMethod::kWaitUntilDone>();
});
Add([this](CtrlCall<CtrlMethod::kPushKV>* call) {
const std::string& k = call->request().key();
const std::string& v = call->request().val();
CHECK(kv_.emplace(k, v).second);
auto pending_kv_calls_it = pending_kv_calls_.find(k);
if (pending_kv_calls_it != pending_kv_calls_.end()) {
for (auto pending_call : pending_kv_calls_it->second) {
pending_call->mut_response()->set_val(v);
pending_call->SendResponse();
}
pending_kv_calls_.erase(pending_kv_calls_it);
}
call->SendResponse();
EnqueueRequest<CtrlMethod::kPushKV>();
});
Add([this](CtrlCall<CtrlMethod::kClearKV>* call) {
const std::string& k = call->request().key();
CHECK_EQ(kv_.erase(k), 1);
CHECK(pending_kv_calls_.find(k) == pending_kv_calls_.end());
call->SendResponse();
EnqueueRequest<CtrlMethod::kClearKV>();
});
Add([this](CtrlCall<CtrlMethod::kPullKV>* call) {
const std::string& k = call->request().key();
auto kv_it = kv_.find(k);
if (kv_it != kv_.end()) {
call->mut_response()->set_val(kv_it->second);
call->SendResponse();
} else {
pending_kv_calls_[k].push_back(call);
}
EnqueueRequest<CtrlMethod::kPullKV>();
});
Add([this](CtrlCall<CtrlMethod::kPushActEvent>* call) {
ActEvent act_event = call->request().act_event();
call->SendResponse();
Global<ActEventLogger>::Get()->PrintActEventToLogDir(act_event);
EnqueueRequest<CtrlMethod::kPushActEvent>();
});
Add([this](CtrlCall<CtrlMethod::kClear>* call) {
name2lock_status_.clear();
kv_.clear();
CHECK(pending_kv_calls_.empty()) << "size(): " << pending_kv_calls_.size()
<< ", begin()->key: " << pending_kv_calls_.begin()->first;
call->SendResponse();
EnqueueRequest<CtrlMethod::kClear>();
});
Add([this](CtrlCall<CtrlMethod::kIncreaseCount>* call) {
int32_t& count = count_[call->request().key()];
count += call->request().val();
call->mut_response()->set_val(count);
call->SendResponse();
EnqueueRequest<CtrlMethod::kIncreaseCount>();
});
Add([this](CtrlCall<CtrlMethod::kEraseCount>* call) {
CHECK_EQ(count_.erase(call->request().key()), 1);
call->SendResponse();
EnqueueRequest<CtrlMethod::kEraseCount>();
});
call->SendResponse();
EnqueueRequest<CtrlMethod::kLoadServer>();
}
} // namespace oneflow
......@@ -16,83 +16,20 @@ limitations under the License.
#ifndef ONEFLOW_CORE_CONTROL_CTRL_SERVER_H_
#define ONEFLOW_CORE_CONTROL_CTRL_SERVER_H_
#include <grpc++/alarm.h>
#include <grpc++/server_builder.h>
#include "oneflow/core/control/ctrl_call.h"
#include "oneflow/core/common/function_traits.h"
#include "oneflow/core/control/rpc_server.h"
namespace oneflow {
namespace {
template<size_t... Idx>
static std::tuple<std::function<void(CtrlCall<(CtrlMethod)Idx>*)>...> GetHandlerTuple(
std::index_sequence<Idx...>) {
return {};
}
} // namespace
class CtrlServer final {
class CtrlServer final : public RpcServer {
public:
OF_DISALLOW_COPY_AND_MOVE(CtrlServer);
~CtrlServer();
~CtrlServer() override {}
CtrlServer();
const std::string& this_machine_addr() { return this_machine_addr_; }
private:
void HandleRpcs();
void Init();
void EnqueueRequests() {
for_each_i(handlers_, helper{this}, std::make_index_sequence<kCtrlMethodNum>{});
}
template<CtrlMethod kMethod>
void EnqueueRequest() {
constexpr const size_t I = (size_t)kMethod;
auto handler = std::get<I>(handlers_);
auto call = new CtrlCall<(CtrlMethod)I>();
call->set_request_handler(std::bind(handler, call));
grpc_service_->RequestAsyncUnary(I, call->mut_server_ctx(), call->mut_request(),
call->mut_responder(), cq_.get(), cq_.get(), call);
}
template<typename F>
void Add(F f) {
using args_type = typename function_traits<F>::args_type;
using arg_type =
typename std::remove_pointer<typename std::tuple_element<0, args_type>::type>::type;
std::get<arg_type::value>(handlers_) = std::move(f);
}
struct helper {
helper(CtrlServer* s) : s_(s) {}
template<typename T, typename V>
void operator()(const T& t, V) {
s_->EnqueueRequest<(CtrlMethod)V::value>();
}
CtrlServer* s_;
};
using HandlerTuple = decltype(GetHandlerTuple(std::make_index_sequence<kCtrlMethodNum>{}));
HandlerTuple handlers_;
std::unique_ptr<CtrlService::AsyncService> grpc_service_;
std::unique_ptr<grpc::ServerCompletionQueue> cq_;
std::unique_ptr<grpc::Server> grpc_server_;
std::thread loop_thread_;
// Barrier
HashMap<std::string, std::pair<std::list<CtrlCallIf*>, int32_t>> barrier_calls_;
// TryLock, NotifyDone, WaitUntilDone
HashMap<std::string, void*> name2lock_status_;
// PushKV, ClearKV, PullKV
HashMap<std::string, std::string> kv_;
HashMap<std::string, std::list<CtrlCall<CtrlMethod::kPullKV>*>> pending_kv_calls_;
// IncreaseCount, EraseCount
HashMap<std::string, int32_t> count_;
void OnLoadServer(CtrlCall<CtrlMethod::kLoadServer>* call) override;
bool is_first_connect_;
std::string this_machine_addr_;
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/control/rpc_client.h"
#include "oneflow/core/job/machine_context.h"
#include "oneflow/core/job/env_desc.h"
namespace oneflow {
namespace {
const int32_t max_retry_num = 60;
const int64_t sleep_seconds = 10;
#define GRPC_CHECK(x) CHECK_EQ(x.error_code(), grpc::StatusCode::OK)
template<CtrlMethod ctrl_method>
class ClientCall final {
public:
OF_DISALLOW_COPY_AND_MOVE(ClientCall);
ClientCall() = default;
~ClientCall() = default;
CtrlRequest<ctrl_method>* mut_request() { return &request_; }
const CtrlResponse<ctrl_method>& response() const { return response_; }
void operator()(CtrlService::Stub* stub) {
grpc::ClientContext client_ctx;
GRPC_CHECK(stub->CallMethod<ctrl_method>(&client_ctx, request_, &response_));
}
private:
CtrlRequest<ctrl_method> request_;
CtrlResponse<ctrl_method> response_;
};
} // namespace
RpcClient::~RpcClient() {
{
std::unique_lock<std::mutex> lck(need_heartbeat_thread_stop_mtx_);
need_heartbeat_thread_stop_ = true;
}
heartbeat_thread_.join();
}
void RpcClient::Barrier(const std::string& barrier_name) {
// TODO(hanbinbin): depend world_size of Global<CtrlConf>
Barrier(barrier_name, Global<EnvDesc>::Get()->TotalMachineNum());
}
void RpcClient::Barrier(const std::string& barrier_name, int32_t barrier_num) {
ClientCall<CtrlMethod::kBarrier> call;
call.mut_request()->set_name(barrier_name);
call.mut_request()->set_num(barrier_num);
call(GetMasterStub());
}
TryLockResult RpcClient::TryLock(const std::string& name) {
{
std::unique_lock<std::mutex> lck(done_names_mtx_);
if (done_names_.find(name) != done_names_.end()) { return TryLockResult::kDone; }
}
ClientCall<CtrlMethod::kTryLock> call;
call.mut_request()->set_name(name);
call(GetResponsibleStub(name));
if (call.response().result() == TryLockResult::kDone) {
std::unique_lock<std::mutex> lck(done_names_mtx_);
done_names_.insert(name);
}
return call.response().result();
}
void RpcClient::NotifyDone(const std::string& name) {
ClientCall<CtrlMethod::kNotifyDone> call;
call.mut_request()->set_name(name);
call(GetResponsibleStub(name));
}
void RpcClient::WaitUntilDone(const std::string& name) {
ClientCall<CtrlMethod::kWaitUntilDone> call;
call.mut_request()->set_name(name);
call(GetResponsibleStub(name));
}
void RpcClient::PushKV(const std::string& k, std::function<void(std::string*)> VSetter) {
ClientCall<CtrlMethod::kPushKV> call;
call.mut_request()->set_key(k);
VSetter(call.mut_request()->mutable_val());
call(GetResponsibleStub(k));
}
void RpcClient::PushMasterKV(const std::string& k, std::function<void(std::string*)> VSetter) {
ClientCall<CtrlMethod::kPushKV> call;
call.mut_request()->set_key(k);
VSetter(call.mut_request()->mutable_val());
call(GetMasterStub());
}
void RpcClient::PushKV(const std::string& k, const std::string& v) {
PushKV(k, [&](std::string* o) { *o = v; });
}
void RpcClient::PushKV(const std::string& k, const PbMessage& msg) {
PushKV(k, [&](std::string* o) { msg.SerializeToString(o); });
}
void RpcClient::PushMasterKV(const std::string& k, const PbMessage& msg) {
PushMasterKV(k, [&](std::string* o) { msg.SerializeToString(o); });
}
void RpcClient::ClearKV(const std::string& k) {
ClientCall<CtrlMethod::kClearKV> call;
call.mut_request()->set_key(k);
call(GetResponsibleStub(k));
}
void RpcClient::ClearMasterKV(const std::string& k) {
ClientCall<CtrlMethod::kClearKV> call;
call.mut_request()->set_key(k);
call(GetMasterStub());
}
void RpcClient::PullKV(const std::string& k, std::function<void(const std::string&)> VGetter) {
ClientCall<CtrlMethod::kPullKV> call;
call.mut_request()->set_key(k);
call(GetResponsibleStub(k));
VGetter(call.response().val());
}
void RpcClient::PullMasterKV(const std::string& k,
std::function<void(const std::string&)> VGetter) {
ClientCall<CtrlMethod::kPullKV> call;
call.mut_request()->set_key(k);
call(GetMasterStub());
VGetter(call.response().val());
}
void RpcClient::PullKV(const std::string& k, std::string* v) {
PullKV(k, [&](const std::string& i) { *v = i; });
}
void RpcClient::PullKV(const std::string& k, PbMessage* msg) {
PullKV(k, [&](const std::string& i) { msg->ParseFromString(i); });
}
void RpcClient::PullMasterKV(const std::string& k, PbMessage* msg) {
PullMasterKV(k, [&](const std::string& i) { msg->ParseFromString(i); });
}
void RpcClient::PushActEvent(const ActEvent& act_event) {
ClientCall<CtrlMethod::kPushActEvent> call;
*(call.mut_request()->mutable_act_event()) = act_event;
call(GetMasterStub());
}
void RpcClient::Clear() {
ClientCall<CtrlMethod::kClear> call;
call(GetThisStub());
std::unique_lock<std::mutex> lck(done_names_mtx_);
done_names_.clear();
}
int32_t RpcClient::IncreaseCount(const std::string& k, int32_t v) {
ClientCall<CtrlMethod::kIncreaseCount> call;
call.mut_request()->set_key(k);
call.mut_request()->set_val(v);
call(GetResponsibleStub(k));
return call.response().val();
}
void RpcClient::EraseCount(const std::string& k) {
ClientCall<CtrlMethod::kEraseCount> call;
call.mut_request()->set_key(k);
call(GetResponsibleStub(k));
}
void RpcClient::LoadServer(const std::string& server_addr, CtrlService::Stub* stub) {
int32_t retry_idx = 0;
for (; retry_idx < max_retry_num; ++retry_idx) {
grpc::ClientContext client_ctx;
LoadServerRequest request;
request.set_addr(server_addr);
LoadServerResponse response;
grpc::Status st = stub->CallMethod<CtrlMethod::kLoadServer>(&client_ctx, request, &response);
if (st.error_code() == grpc::StatusCode::OK) {
LOG(INFO) << "LoadServer " << server_addr << " Successful at " << retry_idx << " times";
break;
} else if (st.error_code() == grpc::StatusCode::UNAVAILABLE) {
LOG(INFO) << "LoadServer " << server_addr << " Failed at " << retry_idx << " times"
<< " error_code " << st.error_code() << " error_message " << st.error_message();
std::this_thread::sleep_for(std::chrono::seconds(sleep_seconds));
continue;
} else {
LOG(FATAL) << st.error_message();
}
}
CHECK_LT(retry_idx, max_retry_num);
}
CtrlService::Stub* RpcClient::GetThisStub() {
// TODO(hanbinbin): depend rank_id of Global<CtrlConf>
return stubs_[Global<MachineCtx>::Get()->this_machine_id()].get();
}
CtrlService::Stub* RpcClient::GetResponsibleStub(const std::string& key) {
// TODO(hanbinbin): depend world_size of Global<CtrlConf>
int64_t machine_id = (std::hash<std::string>{}(key)) % Global<EnvDesc>::Get()->TotalMachineNum();
return stubs_[machine_id].get();
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_CONTROL_RPC_CLIENT_H_
#define ONEFLOW_CORE_CONTROL_RPC_CLIENT_H_
#include "oneflow/core/actor/actor_message.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/control/ctrl_service.h"
#include "oneflow/core/job/global_for.h"
namespace oneflow {
class RpcClient {
public:
OF_DISALLOW_COPY_AND_MOVE(RpcClient);
virtual ~RpcClient();
void Barrier(const std::string& barrier_name);
void Barrier(const std::string& barrier_name, int32_t barrier_num);
TryLockResult TryLock(const std::string& name);
void NotifyDone(const std::string& name);
void WaitUntilDone(const std::string& name);
void PushKV(const std::string& k, std::function<void(std::string*)> VSetter);
void PushKV(const std::string& k, const std::string& v);
void PushKV(const std::string& k, const PbMessage& msg);
void PushMasterKV(const std::string& k, const PbMessage& msg);
template<typename T>
typename std::enable_if<std::is_arithmetic<T>::value>::type PushKVT(const std::string& k, T v) {
PushKV(k, std::to_string(v));
}
void ClearKV(const std::string& k);
void ClearMasterKV(const std::string& k);
void PullKV(const std::string& k, std::function<void(const std::string&)> VGetter);
void PullKV(const std::string& k, std::string* v);
void PullKV(const std::string& k, PbMessage* msg);
void PullMasterKV(const std::string& k, PbMessage* msg);
template<typename T>
typename std::enable_if<std::is_arithmetic<T>::value>::type PullKVT(const std::string& k, T* v) {
std::string v_str;
PullKV(k, &v_str);
*v = oneflow_cast<T>(v_str);
}
void PushActEvent(const ActEvent&);
void Clear();
int32_t IncreaseCount(const std::string& k, int32_t v);
int32_t IncreaseCount(const std::string& k) { return IncreaseCount(k, 1); }
void EraseCount(const std::string& k);
protected:
RpcClient() = default;
void LoadServer(const std::string& server_addr, CtrlService::Stub* stub);
void PushMasterKV(const std::string& k, std::function<void(std::string*)> VSetter);
void PullMasterKV(const std::string& k, std::function<void(const std::string&)> VGetter);
CtrlService::Stub* GetMasterStub() { return stubs_[0].get(); }
CtrlService::Stub* GetThisStub();
CtrlService::Stub* GetResponsibleStub(const std::string& key);
std::vector<std::unique_ptr<CtrlService::Stub>> stubs_;
std::mutex done_names_mtx_;
HashSet<std::string> done_names_;
bool need_heartbeat_thread_stop_;
std::mutex need_heartbeat_thread_stop_mtx_;
std::thread heartbeat_thread_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_CONTROL_RPC_CLIENT_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/control/rpc_server.h"
#include "oneflow/core/actor/act_event_logger.h"
#include "oneflow/core/job/profiler.h"
#include "oneflow/core/job/env_desc.h"
#include "grpc/grpc_posix.h"
namespace oneflow {
RpcServer::~RpcServer() {
// NOTE(chengcheng): This enqueues a special event (with a null tag) that causes
// the completion queue to be shut down on the polling thread.
grpc::Alarm alarm(cq_.get(), gpr_now(GPR_CLOCK_MONOTONIC), nullptr);
loop_thread_.join();
}
void RpcServer::HandleRpcs() {
EnqueueRequests();
void* tag = nullptr;
bool ok = false;
// NOTE(chengcheng): The is_shutdown bool flag make sure that 'ok = false' occurs ONLY after
// cq_->Shutdown() for security check.
bool is_shutdown = false;
// NOTE(chengcheng): The final end is that cq_->Next() get false and cq_ is empty with no item.
while (cq_->Next(&tag, &ok)) {
auto call = static_cast<CtrlCallIf*>(tag);
if (!ok) {
// NOTE(chengcheng): After call grpc_server_->Shutdown() and cq_->Shutdown(),
// there will trigger some cancel tag items on each RPC. And cq_->Next() can get these tag
// with ok = false. Then delete the tag with CtrlCallIf pointer for recovery.
CHECK(is_shutdown);
CHECK(call);
delete call;
continue;
}
if (call) {
call->Process();
} else {
// NOTE(chengcheng): A null `call` indicates that this is the shutdown alarm.
CHECK(!is_shutdown);
is_shutdown = true;
grpc_server_->Shutdown();
cq_->Shutdown();
// NOTE(chengcheng): You CANNOT use code 'break;' in this block because that
// there still be items in the cq_.
// 'break;'
}
}
}
void RpcServer::Init() {
Add([this](CtrlCall<CtrlMethod::kLoadServer>* call) { OnLoadServer(call); });
Add([this](CtrlCall<CtrlMethod::kBarrier>* call) {
const std::string& barrier_name = call->request().name();
int32_t barrier_num = call->request().num();
auto barrier_call_it = barrier_calls_.find(barrier_name);
if (barrier_call_it == barrier_calls_.end()) {
barrier_call_it =
barrier_calls_
.emplace(barrier_name, std::make_pair(std::list<CtrlCallIf*>{}, barrier_num))
.first;
}
CHECK_EQ(barrier_num, barrier_call_it->second.second);
barrier_call_it->second.first.push_back(call);
if (barrier_call_it->second.first.size() == barrier_call_it->second.second) {
for (CtrlCallIf* pending_call : barrier_call_it->second.first) {
pending_call->SendResponse();
}
barrier_calls_.erase(barrier_call_it);
}
EnqueueRequest<CtrlMethod::kBarrier>();
});
Add([this](CtrlCall<CtrlMethod::kTryLock>* call) {
const std::string& lock_name = call->request().name();
auto name2lock_status_it = name2lock_status_.find(lock_name);
if (name2lock_status_it == name2lock_status_.end()) {
call->mut_response()->set_result(TryLockResult::kLocked);
auto waiting_until_done_calls = new std::list<CtrlCallIf*>;
CHECK(name2lock_status_.emplace(lock_name, waiting_until_done_calls).second);
} else {
if (name2lock_status_it->second) {
call->mut_response()->set_result(TryLockResult::kDoing);
} else {
call->mut_response()->set_result(TryLockResult::kDone);
}
}
call->SendResponse();
EnqueueRequest<CtrlMethod::kTryLock>();
});
Add([this](CtrlCall<CtrlMethod::kNotifyDone>* call) {
const std::string& lock_name = call->request().name();
auto name2lock_status_it = name2lock_status_.find(lock_name);
auto waiting_calls = static_cast<std::list<CtrlCallIf*>*>(name2lock_status_it->second);
for (CtrlCallIf* waiting_call : *waiting_calls) { waiting_call->SendResponse(); }
delete waiting_calls;
name2lock_status_it->second = nullptr;
call->SendResponse();
EnqueueRequest<CtrlMethod::kNotifyDone>();
});
Add([this](CtrlCall<CtrlMethod::kWaitUntilDone>* call) {
const std::string& lock_name = call->request().name();
void* lock_status = name2lock_status_.at(lock_name);
if (lock_status) {
auto waiting_calls = static_cast<std::list<CtrlCallIf*>*>(lock_status);
waiting_calls->push_back(call);
} else {
call->SendResponse();
}
EnqueueRequest<CtrlMethod::kWaitUntilDone>();
});
Add([this](CtrlCall<CtrlMethod::kPushKV>* call) {
const std::string& k = call->request().key();
const std::string& v = call->request().val();
CHECK(kv_.emplace(k, v).second);
auto pending_kv_calls_it = pending_kv_calls_.find(k);
if (pending_kv_calls_it != pending_kv_calls_.end()) {
for (auto pending_call : pending_kv_calls_it->second) {
pending_call->mut_response()->set_val(v);
pending_call->SendResponse();
}
pending_kv_calls_.erase(pending_kv_calls_it);
}
call->SendResponse();
EnqueueRequest<CtrlMethod::kPushKV>();
});
Add([this](CtrlCall<CtrlMethod::kClearKV>* call) {
const std::string& k = call->request().key();
CHECK_EQ(kv_.erase(k), 1);
CHECK(pending_kv_calls_.find(k) == pending_kv_calls_.end());
call->SendResponse();
EnqueueRequest<CtrlMethod::kClearKV>();
});
Add([this](CtrlCall<CtrlMethod::kPullKV>* call) {
const std::string& k = call->request().key();
auto kv_it = kv_.find(k);
if (kv_it != kv_.end()) {
call->mut_response()->set_val(kv_it->second);
call->SendResponse();
} else {
pending_kv_calls_[k].push_back(call);
}
EnqueueRequest<CtrlMethod::kPullKV>();
});
Add([this](CtrlCall<CtrlMethod::kPushActEvent>* call) {
ActEvent act_event = call->request().act_event();
call->SendResponse();
Global<ActEventLogger>::Get()->PrintActEventToLogDir(act_event);
EnqueueRequest<CtrlMethod::kPushActEvent>();
});
Add([this](CtrlCall<CtrlMethod::kClear>* call) {
name2lock_status_.clear();
kv_.clear();
CHECK(pending_kv_calls_.empty()) << "size(): " << pending_kv_calls_.size()
<< ", begin()->key: " << pending_kv_calls_.begin()->first;
call->SendResponse();
EnqueueRequest<CtrlMethod::kClear>();
});
Add([this](CtrlCall<CtrlMethod::kIncreaseCount>* call) {
int32_t& count = count_[call->request().key()];
count += call->request().val();
call->mut_response()->set_val(count);
call->SendResponse();
EnqueueRequest<CtrlMethod::kIncreaseCount>();
});
Add([this](CtrlCall<CtrlMethod::kEraseCount>* call) {
CHECK_EQ(count_.erase(call->request().key()), 1);
call->SendResponse();
EnqueueRequest<CtrlMethod::kEraseCount>();
});
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_CONTROL_RPC_SERVER_H_
#define ONEFLOW_CORE_CONTROL_RPC_SERVER_H_
#include <grpc++/alarm.h>
#include <grpc++/server_builder.h>
#include "oneflow/core/control/ctrl_call.h"
#include "oneflow/core/common/function_traits.h"
namespace oneflow {
namespace {
template<size_t... Idx>
static std::tuple<std::function<void(CtrlCall<(CtrlMethod)Idx>*)>...> GetHandlerTuple(
std::index_sequence<Idx...>) {
return {};
}
} // namespace
class RpcServer {
public:
OF_DISALLOW_COPY_AND_MOVE(RpcServer);
virtual ~RpcServer();
protected:
RpcServer() {}
void HandleRpcs();
void Init();
void EnqueueRequests() {
for_each_i(handlers_, helper{this}, std::make_index_sequence<kCtrlMethodNum>{});
}
template<CtrlMethod kMethod>
void EnqueueRequest() {
constexpr const size_t I = (size_t)kMethod;
auto handler = std::get<I>(handlers_);
auto call = new CtrlCall<(CtrlMethod)I>();
call->set_request_handler(std::bind(handler, call));
grpc_service_->RequestAsyncUnary(I, call->mut_server_ctx(), call->mut_request(),
call->mut_responder(), cq_.get(), cq_.get(), call);
}
template<typename F>
void Add(F f) {
using args_type = typename function_traits<F>::args_type;
using arg_type =
typename std::remove_pointer<typename std::tuple_element<0, args_type>::type>::type;
std::get<arg_type::value>(handlers_) = std::move(f);
}
virtual void OnLoadServer(CtrlCall<CtrlMethod::kLoadServer>* call) = 0;
struct helper {
helper(RpcServer* s) : s_(s) {}
template<typename T, typename V>
void operator()(const T& t, V) {
s_->EnqueueRequest<(CtrlMethod)V::value>();
}
RpcServer* s_;
};
using HandlerTuple = decltype(GetHandlerTuple(std::make_index_sequence<kCtrlMethodNum>{}));
HandlerTuple handlers_;
std::unique_ptr<CtrlService::AsyncService> grpc_service_;
std::unique_ptr<grpc::ServerCompletionQueue> cq_;
std::unique_ptr<grpc::Server> grpc_server_;
std::thread loop_thread_;
// Barrier
HashMap<std::string, std::pair<std::list<CtrlCallIf*>, int32_t>> barrier_calls_;
// TryLock, NotifyDone, WaitUntilDone
HashMap<std::string, void*> name2lock_status_;
// PushKV, ClearKV, PullKV
HashMap<std::string, std::string> kv_;
HashMap<std::string, std::list<CtrlCall<CtrlMethod::kPullKV>*>> pending_kv_calls_;
// IncreaseCount, EraseCount
HashMap<std::string, int32_t> count_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_CONTROL_RPC_SERVER_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册