提交 9072ef48 编写于 作者: Y Yi Zhu 提交者: Will Zhang

IBVerbsCommNet and EndpointManager (#479)

* add endpoint_manager & rdma_comm_network

* refine code

* refine code

* fix macro guardian bug

* refine endpoint_manager && ibverbs_comm_net

* merge master

* refine code

* replace ptr with normal var
上级 39cc7788
#include "oneflow/core/comm_network/ibverbs/endpoint_manager.h"
#include "oneflow/core/comm_network/comm_network.h"
#include "oneflow/core/actor/actor_message_bus.h"
#if defined(WITH_RDMA) && defined(PLATFORM_POSIX)
namespace oneflow {
namespace {
// Builds the CtrlClient key under which machine `src_machine_id` publishes the
// connection info meant for machine `dst_machine_id`:
// "IBVerbsConnInfo/<src> <dst>".
std::string GenConnInfoKey(int64_t src_machine_id, int64_t dst_machine_id) {
  std::string key("IBVerbsConnInfo/");
  key += std::to_string(src_machine_id);
  key += " ";
  key += std::to_string(dst_machine_id);
  return key;
}
} // namespace
// Opens the first device reported by libibverbs, allocates the protection
// domain and the send/recv completion queues used by every connection, then
// connects to the peers (InitRdma) and launches the polling thread (Start).
EndpointManager::EndpointManager() {
  // Init Adapter
  ibv_device** device_list = ibv_get_device_list(nullptr);
  // ibv_get_device_list() returns NULL on failure, and the returned array is
  // NULL-terminated, so its first entry is NULL when no device is present.
  // Fail loudly instead of dereferencing a null pointer.
  CHECK(device_list);
  ibv_device* device = device_list[0];
  CHECK(device);
  context_ = ibv_open_device(device);
  CHECK(context_);
  pd_ = ibv_alloc_pd(context_);
  CHECK(pd_);
  // Init env
  // NOTE(review): both CQs are created with 10 entries while InitRdma
  // pre-posts kPrePostRecvNum receives per peer; confirm this depth is enough.
  send_cq_ = ibv_create_cq(context_, 10, nullptr, nullptr, 0);  // cqe
  CHECK(send_cq_);
  recv_cq_ = ibv_create_cq(context_, 10, nullptr, nullptr, 0);  // cqe
  CHECK(recv_cq_);
  ibv_free_device_list(device_list);
  InitRdma();
  Start();
}
// Stops the polling thread first, then deletes every send/recv message buffer
// together with its memory descriptor, deletes the connections, and finally
// destroys the completion queues, the protection domain and the device
// context (each only when non-null).
EndpointManager::~EndpointManager() {
  Stop();
  for (auto& send_entry : send_msg2mem_desc_) {
    delete send_entry.first;
    delete send_entry.second;
  }
  for (auto& recv_entry : recv_msg2mem_desc_) {
    delete recv_entry.first;
    delete recv_entry.second;
  }
  for (auto& conn_entry : connection_pool_) { delete conn_entry.second; }
  if (send_cq_) { CHECK_EQ(ibv_destroy_cq(send_cq_), 0); }
  if (recv_cq_) { CHECK_EQ(ibv_destroy_cq(recv_cq_), 0); }
  if (pd_) { CHECK_EQ(ibv_dealloc_pd(pd_), 0); }
  if (context_) { CHECK_EQ(ibv_close_device(context_), 0); }
}
// Creates one IBVerbsConnection per remote machine and exchanges connection
// info with the peers through CtrlClient. The work is done in two passes over
// the machine ids, each followed by OF_BARRIER().
void EndpointManager::InitRdma() {
int64_t total_machine_num = JobDesc::Singleton()->TotalMachineNum();
int64_t this_machine_id = MachineCtx::Singleton()->this_machine_id();
// Pass 1: create a connection for every other machine, remember it in
// connection_pool_, and push this machine's connection info under the
// (this -> peer) key.
FOR_RANGE(int64_t, peer_machine_id, 0, total_machine_num) {
if (peer_machine_id == this_machine_id) { continue; }
IBVerbsConnection* conn = NewIBVerbsConnection();
connection_pool_.emplace(peer_machine_id, conn);
CtrlClient::Singleton()->PushKV(
GenConnInfoKey(this_machine_id, peer_machine_id),
conn->mut_this_machine_conn_info());
}
// NOTE(review): presumably guarantees every machine has pushed its info
// before anyone starts pulling — confirm against OF_BARRIER's definition.
OF_BARRIER();
// Pass 2: pull the info each peer pushed under the (peer -> this) key,
// pre-post kPrePostRecvNum receive buffers on the connection, then complete
// the connection.
FOR_RANGE(int64_t, peer_machine_id, 0, total_machine_num) {
if (peer_machine_id == this_machine_id) { continue; }
IBVerbsConnection* conn = connection_pool_[peer_machine_id];
CtrlClient::Singleton()->PullKV(
GenConnInfoKey(peer_machine_id, this_machine_id),
conn->mut_peer_machine_conn_info_ptr());
for (size_t i = 0; i != kPrePostRecvNum; ++i) {
// Each receive buffer is a heap ActorMsg with its own memory descriptor;
// both maps are keyed by the buffer pointer so PollRecvQueue can find the
// connection and descriptor again and re-post the buffer.
ActorMsg* actor_msg = new ActorMsg;
auto ibverbs_mem_desc = NewIBVerbsMemDesc(actor_msg, sizeof(ActorMsg));
recv_msg2conn_ptr_.emplace(actor_msg, conn);
recv_msg2mem_desc_.emplace(actor_msg, ibverbs_mem_desc);
conn->PostRecvRequest(actor_msg, ibverbs_mem_desc);
}
conn->CompleteConnection();
}
OF_BARRIER();
}
// Returns a newly allocated IBVerbsMemDesc constructed from this manager's
// protection domain, `mem_ptr` and `byte_size`. The object is created with
// `new` and is not tracked here, so the caller is responsible for it.
IBVerbsMemDesc* EndpointManager::NewIBVerbsMemDesc(void* mem_ptr,
                                                   size_t byte_size) {
  IBVerbsMemDesc* mem_desc = new IBVerbsMemDesc(pd_, mem_ptr, byte_size);
  return mem_desc;
}
// Creates a new IBVerbsConnection backed by a fresh RC queue pair that uses
// the shared send/recv completion queues, and fills in this machine's side of
// the connection info (lid, qpn, psn, gid) from port 1 of the opened device.
// The connection is allocated with `new`; the caller is responsible for it.
IBVerbsConnection* EndpointManager::NewIBVerbsConnection() {
  IBVerbsConnection* conn = new IBVerbsConnection();
  // Init queue pair
  ibv_qp_init_attr qp_init_attr;
  memset(&qp_init_attr, 0, sizeof(qp_init_attr));
  qp_init_attr.qp_context = nullptr;
  qp_init_attr.send_cq = send_cq_;
  qp_init_attr.recv_cq = recv_cq_;
  qp_init_attr.qp_type = IBV_QPT_RC;
  qp_init_attr.srq = nullptr;
  qp_init_attr.sq_sig_all = 1;
  qp_init_attr.cap.max_send_wr = 10;
  // InitRdma pre-posts kPrePostRecvNum receive requests on every connection,
  // so the receive queue must be sized to hold at least that many; requesting
  // fewer only works when the driver happens to round the capacity up.
  qp_init_attr.cap.max_recv_wr = kPrePostRecvNum;
  qp_init_attr.cap.max_send_sge = 1;
  qp_init_attr.cap.max_recv_sge = 1;
  ibv_qp* qp_ptr = ibv_create_qp(pd_, &qp_init_attr);
  CHECK(qp_ptr);
  // Init connection info
  ibv_port_attr attr;
  CHECK_EQ(ibv_query_port(context_, static_cast<uint8_t>(1), &attr), 0);
  srand(static_cast<unsigned>(time(nullptr)));
  auto* this_conn_info = conn->mut_this_machine_conn_info_ptr();
  this_conn_info->set_lid(attr.lid);
  this_conn_info->set_qpn(qp_ptr->qp_num);
  // The packet sequence number is masked to its low 24 bits.
  this_conn_info->set_psn(static_cast<uint32_t>(rand()) & 0xffffff);
  union ibv_gid gid;
  CHECK_EQ(ibv_query_gid(context_, static_cast<uint8_t>(1), 0, &gid), 0);
  this_conn_info->set_snp(gid.global.subnet_prefix);
  this_conn_info->set_iid(gid.global.interface_id);
  conn->set_ibv_mtu(attr.active_mtu);
  conn->set_ibv_qp_ptr(qp_ptr);
  return conn;
}
// Posts a read request on the connection to `src_machine_id`, passing
// `read_ctx`, the local descriptor and the remote descriptor proto straight to
// IBVerbsConnection::PostReadRequest. Aborts (CHECK) when no connection to
// that machine exists.
void EndpointManager::Read(void* read_ctx, int64_t src_machine_id,
                           IBVerbsMemDesc* local_mem_desc,
                           IBVerbsMemDescProto& remote_mem_desc_proto) {
  auto conn_it = connection_pool_.find(src_machine_id);
  CHECK(conn_it != connection_pool_.end());
  conn_it->second->PostReadRequest(read_ctx, local_mem_desc,
                                   remote_mem_desc_proto);
}
void EndpointManager::SendActorMsg(int64_t dst_machine_id,
const ActorMsg& msg) {
auto iter = connection_pool_.find(dst_machine_id);
CHECK(iter != connection_pool_.end());
IBVerbsConnection* conn = iter->second;
std::tuple<ActorMsg*, IBVerbsMemDesc*> allocate_ret = AllocateSendMsg();
ActorMsg* msg_ptr = std::get<0>(allocate_ret);
*msg_ptr = msg;
conn->PostSendRequest(msg_ptr, std::get<1>(allocate_ret));
}
void EndpointManager::Start() {
poll_state_ = true;
poll_thread_ = std::thread(&EndpointManager::PollLoop, this);
}
// Clears the polling flag so PollLoop() returns, then blocks until the polling
// thread has finished.
void EndpointManager::Stop() {
  poll_state_ = false;
  poll_thread_.join();
}
// Body of the polling thread: alternately polls the send and the recv
// completion queue for as long as poll_state_ stays true.
void EndpointManager::PollLoop() {
  while (poll_state_) {
    PollSendQueue();
    PollRecvQueue();
  }
}
void EndpointManager::PollSendQueue() {
ibv_wc wc;
int32_t len = ibv_poll_cq(send_cq_, 1, &wc);
if (len <= 0) { return; }
if (wc.status != IBV_WC_SUCCESS) {
LOG(FATAL) << "PollSendQueue Error Code: " << wc.status;
}
switch (wc.opcode) {
case IBV_WC_SEND: {
ReleaseSendMsg(reinterpret_cast<ActorMsg*>(wc.wr_id));
return;
}
case IBV_WC_RDMA_READ: {
CommNet::Singleton()->ReadDone(reinterpret_cast<void*>(wc.wr_id));
return;
}
default: return;
}
}
// Polls at most one completion from the recv CQ. A non-success status is
// fatal. The received ActorMsg (its buffer pointer is carried in wr_id) is
// handed to ActorMsgBus, after which the same buffer is re-posted as a receive
// on the connection it belongs to.
void EndpointManager::PollRecvQueue() {
  ibv_wc wc;
  int32_t len = ibv_poll_cq(recv_cq_, 1, &wc);
  if (len <= 0) { return; }
  if (wc.status != IBV_WC_SUCCESS) {
    LOG(FATAL) << "PollRecvQueue Error Code: " << wc.status;
  }
  ActorMsg* msg = reinterpret_cast<ActorMsg*>(wc.wr_id);
  ActorMsgBus::Singleton()->SendMsg(*msg);
  // Look up the connection once (instead of find + at) and verify both
  // lookups before dereferencing: an unknown buffer must abort here rather
  // than dereference an end() iterator.
  auto conn_it = recv_msg2conn_ptr_.find(msg);
  CHECK(conn_it != recv_msg2conn_ptr_.end());
  auto mem_desc_it = recv_msg2mem_desc_.find(msg);
  CHECK(mem_desc_it != recv_msg2mem_desc_.end());
  conn_it->second->PostRecvRequest(msg, mem_desc_it->second);
}
std::tuple<ActorMsg*, IBVerbsMemDesc*> EndpointManager::AllocateSendMsg() {
std::unique_lock<std::mutex> lck(send_msg_pool_mutex_);
if (send_msg_pool_.empty()) {
ActorMsg* msg = new ActorMsg;
IBVerbsMemDesc* mem_desc = NewIBVerbsMemDesc(msg, sizeof(ActorMsg));
send_msg2mem_desc_.emplace(msg, mem_desc);
send_msg_pool_.push(msg);
}
ActorMsg* ret_msg = send_msg_pool_.front();
send_msg_pool_.pop();
auto send_msg2mem_desc_it = send_msg2mem_desc_.find(ret_msg);
CHECK(send_msg2mem_desc_it != send_msg2mem_desc_.end());
return std::make_tuple(ret_msg, send_msg2mem_desc_it->second);
}
// Puts a send buffer back into the pool under send_msg_pool_mutex_ so a later
// AllocateSendMsg() call can reuse it.
void EndpointManager::ReleaseSendMsg(ActorMsg* msg) {
  std::lock_guard<std::mutex> guard(send_msg_pool_mutex_);
  send_msg_pool_.push(msg);
}
} // namespace oneflow
#endif // WITH_RDMA && PLATFORM_POSIX
#ifndef ONEFLOW_CORE_COMM_NETWORK_IBVERBS_ENDPOINT_MANAGER_H_
#define ONEFLOW_CORE_COMM_NETWORK_IBVERBS_ENDPOINT_MANAGER_H_
#include "oneflow/core/comm_network/ibverbs/ibverbs_connection.h"
#if defined(WITH_RDMA) && defined(PLATFORM_POSIX)
#include "oneflow/core/control/ctrl_client.h"
#include "oneflow/core/job/machine_context.h"
#include "oneflow/core/job/job_desc.h"
#include <netdb.h>
#include <arpa/inet.h>
#include <atomic>
namespace oneflow {
// Owns the ibverbs device context, protection domain, completion queues and
// one IBVerbsConnection per remote machine, plus a background thread that
// polls the completion queues. Construction connects to all peers and starts
// the thread; destruction stops it and releases everything.
class EndpointManager {
 public:
  OF_DISALLOW_COPY_AND_MOVE(EndpointManager);
  EndpointManager();
  ~EndpointManager();

  // Returns a newly allocated memory descriptor built from pd_, `mem_ptr` and
  // `byte_size`.
  IBVerbsMemDesc* NewIBVerbsMemDesc(void* mem_ptr, size_t byte_size);
  // Returns a newly allocated connection with its queue pair created and this
  // machine's connection info filled in.
  IBVerbsConnection* NewIBVerbsConnection();
  // Posts a read request on the connection to `src_machine_id`.
  void Read(void* read_ctx, int64_t src_machine_id,
            IBVerbsMemDesc* local_mem_desc,
            IBVerbsMemDescProto& remote_mem_desc_proto);
  // Copies `msg` into a pooled send buffer and posts it to `dst_machine_id`.
  void SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg);

 private:
  void InitRdma();
  void Start();
  void Stop();
  void PollLoop();
  void PollSendQueue();
  void PollRecvQueue();
  std::tuple<ActorMsg*, IBVerbsMemDesc*> AllocateSendMsg();
  void ReleaseSendMsg(ActorMsg* msg);

  // Number of receive buffers pre-posted on every connection.
  enum { kPrePostRecvNum = 15 };  // TODO

  // Receive buffers, keyed by buffer pointer, mapped to the connection they
  // are posted on and to their memory descriptor.
  HashMap<ActorMsg*, IBVerbsConnection*> recv_msg2conn_ptr_;
  HashMap<ActorMsg*, IBVerbsMemDesc*> recv_msg2mem_desc_;
  // Peer machine id -> connection to that machine.
  HashMap<int64_t, IBVerbsConnection*> connection_pool_;

  // Pool of reusable send buffers; send_msg_pool_mutex_ is held by
  // AllocateSendMsg() and ReleaseSendMsg() while they touch the pool.
  std::mutex send_msg_pool_mutex_;
  std::queue<ActorMsg*> send_msg_pool_;
  HashMap<ActorMsg*, IBVerbsMemDesc*> send_msg2mem_desc_;

  std::thread poll_thread_;
  // Written by Start()/Stop() and read concurrently by PollLoop() on
  // poll_thread_, so it must be atomic: a plain bool here is a data race.
  std::atomic<bool> poll_state_{false};

  ibv_context* context_ = nullptr;
  ibv_pd* pd_ = nullptr;
  ibv_cq* send_cq_ = nullptr;
  ibv_cq* recv_cq_ = nullptr;
};
} // namespace oneflow
#endif // WITH_RDMA && PLATFORM_POSIX
#endif // ONEFLOW_CORE_COMM_NETWORK_IBVERBS_ENDPOINT_MANAGER_H_
#include "oneflow/core/comm_network/ibverbs/ibverbs_comm_network.h"
#include "oneflow/core/comm_network/ibverbs/ibverbs_tokens_message.pb.h"
#if defined(WITH_RDMA) && defined(PLATFORM_POSIX)
namespace oneflow {
namespace {
// Builds the CtrlClient key under which machine `machine_id` publishes its
// memory-token message: "IBVerbsTokensMsg/<machine_id>".
std::string GenTokensMsgKey(int64_t machine_id) {
  std::string key("IBVerbsTokensMsg/");
  key.append(std::to_string(machine_id));
  return key;
}
} // namespace
// Creates a new IBVerbsCommNet and installs it through
// CommNet::Singleton()->set_comm_network_ptr().
void IBVerbsCommNet::Init() {
CommNet::Singleton()->set_comm_network_ptr(new IBVerbsCommNet());
}
const void* IBVerbsCommNet::RegisterMemory(void* mem_ptr, size_t byte_size) {
IBVerbsMemDesc* ibverbs_mem_desc =
endpoint_manager_.NewIBVerbsMemDesc(mem_ptr, byte_size);
mem_desc_mgr_.RegisterMemDesc(ibverbs_mem_desc);
return ibverbs_mem_desc;
}
// Forwards to MemDescMgr::UnRegisterMemDesc().
// NOTE(review): `token` is not passed on; confirm the manager does not need to
// know which descriptor is being unregistered.
void IBVerbsCommNet::UnRegisterMemory(const void* token) {
  mem_desc_mgr_.UnRegisterMemDesc();
}
// Publishes the proto form of every locally registered memory descriptor,
// keyed by its token (the descriptor's address), then pulls every peer's table
// and merges it into token2mem_desc_proto_. A token that appears twice across
// peers is fatal (CHECK on the insert result).
void IBVerbsCommNet::RegisterMemoryDone() {
  int64_t total_machine_num = JobDesc::Singleton()->TotalMachineNum();
  int64_t this_machine_id = MachineCtx::Singleton()->this_machine_id();
  IBVerbsTokensMsg this_machine_tokens_msg;
  auto* local_token_table =
      this_machine_tokens_msg.mutable_token2mem_desc_proto();
  // Bind by const reference: the list is only iterated, so copying it is
  // unnecessary (a by-value return is still kept alive by the reference).
  const auto& mem_descs = mem_desc_mgr_.mem_descs();
  for (auto mem_desc : mem_descs) {
    local_token_table->insert({reinterpret_cast<uint64_t>(mem_desc),
                               mem_desc->IBVerbsMemDescToProto()});
  }
  CtrlClient::Singleton()->PushKV(GenTokensMsgKey(this_machine_id),
                                  this_machine_tokens_msg);
  OF_BARRIER();
  FOR_RANGE(int64_t, peer_machine_id, 0, total_machine_num) {
    // Use the id cached above instead of querying MachineCtx on every pass.
    if (peer_machine_id == this_machine_id) { continue; }
    IBVerbsTokensMsg peer_machine_tokens_msg;
    CtrlClient::Singleton()->PullKV(GenTokensMsgKey(peer_machine_id),
                                    &peer_machine_tokens_msg);
    for (const auto& pair : peer_machine_tokens_msg.token2mem_desc_proto()) {
      CHECK(token2mem_desc_proto_.insert({pair.first, pair.second}).second);
    }
  }
  OF_BARRIER();
}
void* IBVerbsCommNet::Read(void* actor_read_id, int64_t src_machine_id,
const void* src_token, const void* dst_token) {
auto actor_read_ctx = static_cast<ActorReadContext*>(actor_read_id);
ReadContext* read_ctx = NewReadCtxInActorReadCtx(actor_read_ctx);
IBVerbsMemDescProto& remote_mem_desc_proto =
token2mem_desc_proto_[reinterpret_cast<uint64_t>(src_token)];
auto local_mem_desc = const_cast<IBVerbsMemDesc*>(
static_cast<const IBVerbsMemDesc*>(dst_token));
void* read_done_id =
new std::tuple<ActorReadContext*, ReadContext*>(actor_read_ctx, read_ctx);
endpoint_manager_.Read(read_done_id, src_machine_id, local_mem_desc,
remote_mem_desc_proto);
return read_ctx;
}
// Delegates sending `msg` to machine `dst_machine_id` to the endpoint manager.
void IBVerbsCommNet::SendActorMsg(int64_t dst_machine_id,
                                  const ActorMsg& msg) {
  endpoint_manager_.SendActorMsg(dst_machine_id, msg);
}
} // namespace oneflow
#endif // WITH_RDMA && PLATFORM_POSIX
#ifndef ONEFLOW_CORE_COMM_NETWORK_IBVERBS_IBVERBS_COMM_NETWORK_H_
#define ONEFLOW_CORE_COMM_NETWORK_IBVERBS_IBVERBS_COMM_NETWORK_H_
#include "oneflow/core/common/platform.h"
#include "oneflow/core/comm_network/comm_network.h"
#include "oneflow/core/comm_network/memory_desc_manager.h"
#include "oneflow/core/comm_network/ibverbs/endpoint_manager.h"
#if defined(WITH_RDMA) && defined(PLATFORM_POSIX)
namespace oneflow {
// CommNet implementation on top of ibverbs. Memory registration goes through
// mem_desc_mgr_, and reads and actor messages go through endpoint_manager_.
class IBVerbsCommNet final : public CommNet {
public:
OF_DISALLOW_COPY_AND_MOVE(IBVerbsCommNet);
~IBVerbsCommNet() = default;
// Returns CommNet::Singleton() downcast to IBVerbsCommNet*.
static IBVerbsCommNet* Singleton() {
return static_cast<IBVerbsCommNet*>(CommNet::Singleton());
}
// Creates the IBVerbsCommNet instance and installs it in CommNet.
static void Init();
// Registers [mem_ptr, byte_size] and returns an opaque token for it.
const void* RegisterMemory(void* mem_ptr, size_t byte_size) override;
void UnRegisterMemory(const void* token) override;
// Exchanges the registered-memory tables with all peer machines.
void RegisterMemoryDone() override;
// Starts a read from `src_token` on `src_machine_id` into `dst_token`.
void* Read(void* actor_read_id, int64_t src_machine_id, const void* src_token,
const void* dst_token) override;
void SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) override;
private:
// Private: instances are created only by Init().
IBVerbsCommNet() = default;
MemDescMgr<IBVerbsMemDesc> mem_desc_mgr_;
EndpointManager endpoint_manager_;
// Peer tokens mapped to their memory descriptor protos; filled by
// RegisterMemoryDone() and looked up by Read().
HashMap<uint64_t, IBVerbsMemDescProto> token2mem_desc_proto_;
};
} // namespace oneflow
#endif // WITH_RDMA && PLATFORM_POSIX
#endif // ONEFLOW_CORE_COMM_NETWORK_IBVERBS_IBVERBS_COMM_NETWORK_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册