提交 6703ae6a 编写于 作者: S Shiyuan Shang-Guan

update


Former-commit-id: af8ca7ba701ef2eb4c9e003a79d09bc92b3f8243
上级 a5aaf582
......@@ -56,7 +56,7 @@ uint16_t PullPort(int64_t machine_id) {
EpollCommNet::~EpollCommNet() {
for (size_t i = 0; i < pollers_.size(); ++i) {
LOG(INFO) << "CommNet Thread " << i << " finish";
pollers_[i]->Stop();
pollers_.at(i)->Stop();
}
OF_BARRIER();
for (IOEventPoller* poller : pollers_) { delete poller; }
......@@ -71,9 +71,7 @@ void EpollCommNet::SendActorMsg(int64_t dst_machine_id, const ActorMsg& actor_ms
SocketMsg msg;
msg.msg_type = SocketMsgType::kActor;
msg.actor_msg = actor_msg;
int32_t link_i =
std::uniform_int_distribution<int32_t>(0, epoll_conf_.link_num() - 1)(random_gen_);
GetSocketHelper(dst_machine_id, link_i)->AsyncWrite(msg);
GetSocketHelper(dst_machine_id, 0)->AsyncWrite(msg);
}
void EpollCommNet::SendSocketMsg(int64_t dst_machine_id, const SocketMsg& total_msg) {
......@@ -81,7 +79,7 @@ void EpollCommNet::SendSocketMsg(int64_t dst_machine_id, const SocketMsg& total_
static_cast<const SocketMemDesc*>(total_msg.request_read_msg.src_token);
int32_t total_byte_size = src_mem_desc->byte_size;
int32_t offset = (total_byte_size + epoll_conf_.link_num() - 1) / epoll_conf_.link_num();
offset = RoundUp(offset, kCacheLineSize);
offset = RoundUp(offset, epoll_conf_.msg_segment_kbyte() * 1024);
int32_t part_num = (total_byte_size + offset - 1) / offset;
for (int32_t link_i = 0; link_i < part_num; ++link_i) {
int32_t byte_size = (total_byte_size > offset) ? (offset) : (total_byte_size);
......@@ -109,7 +107,7 @@ SocketMemDesc* EpollCommNet::NewMemDesc(void* ptr, size_t byte_size) {
// Constructs the epoll-based comm net.
//
// Steps, in order:
//   1. Size pollers_ to the configured number of comm-net workers.
//   2. Allocate one IOEventPoller per slot.
//   3. Set up the sockets (InitSockets hands pollers to the SocketHelpers it
//      creates, so the pollers must already exist at this point).
//   4. Start every poller.
//
// The pollers are owned as raw pointers and released in the destructor.
EpollCommNet::EpollCommNet(const Plan& plan)
    : CommNetIf(plan), epoll_conf_(Global<JobDesc>::Get()->epoll_conf()) {
  pollers_.resize(Global<JobDesc>::Get()->CommNetWorkerNum(), nullptr);
  // Allocate each poller exactly once. Running a second allocation loop over
  // the same slots would overwrite the pointers and leak the first batch.
  for (size_t i = 0; i < pollers_.size(); ++i) { pollers_.at(i) = new IOEventPoller; }
  InitSockets();
  for (IOEventPoller* poller : pollers_) { poller->Start(); }
}
......@@ -122,7 +120,7 @@ void EpollCommNet::InitSockets() {
sockfd2helper_.clear();
size_t poller_idx = 0;
auto NewSocketHelper = [&](int32_t sockfd) {
IOEventPoller* poller = pollers_[poller_idx];
IOEventPoller* poller = pollers_.at(poller_idx);
poller_idx = (poller_idx + 1) % pollers_.size();
return new SocketHelper(sockfd, poller);
};
......@@ -160,7 +158,7 @@ void EpollCommNet::InitSockets() {
PCHECK(connect(sockfd, reinterpret_cast<sockaddr*>(&peer_sockaddr), sizeof(peer_sockaddr))
== 0);
CHECK(sockfd2helper_.emplace(sockfd, NewSocketHelper(sockfd)).second);
machine_id2sockfds_[peer_mchn_id * epoll_conf_.link_num() + link_i] = sockfd;
machine_id2sockfds_.at(peer_mchn_id * epoll_conf_.link_num() + link_i) = sockfd;
}
}
......@@ -173,7 +171,7 @@ void EpollCommNet::InitSockets() {
PCHECK(sockfd != -1);
CHECK(sockfd2helper_.emplace(sockfd, NewSocketHelper(sockfd)).second);
int64_t peer_mchn_id = GetMachineId(peer_sockaddr);
machine_id2sockfds_[peer_mchn_id * epoll_conf_.link_num() + link_i] = sockfd;
machine_id2sockfds_.at(peer_mchn_id * epoll_conf_.link_num() + link_i) = sockfd;
}
}
PCHECK(close(listen_sockfd) == 0);
......@@ -183,13 +181,13 @@ void EpollCommNet::InitSockets() {
FOR_RANGE(int64_t, machine_id, 0, total_machine_num) {
FOR_RANGE(int32_t, link_i, 0, epoll_conf_.link_num()) {
LOG(INFO) << "machine: " << machine_id << ", link index: " << link_i << ", sockfd: "
<< machine_id2sockfds_[machine_id * epoll_conf_.link_num() + link_i];
<< machine_id2sockfds_.at(machine_id * epoll_conf_.link_num() + link_i);
}
}
}
// Looks up the SocketHelper that serves one link to a peer machine.
//
// machine_id2sockfds_ is a flat table holding link_num() socket fds per
// machine, so the slot for (machine_id, link_index) is
// machine_id * link_num() + link_index. The vector access is bounds-checked
// via at(), so an out-of-range machine/link pair fails loudly instead of
// reading past the end of the table.
SocketHelper* EpollCommNet::GetSocketHelper(int64_t machine_id, int32_t link_index) {
  // Single declaration: declaring sockfd twice in this scope is a redefinition error.
  const int32_t sockfd = machine_id2sockfds_.at(machine_id * epoll_conf_.link_num() + link_index);
  return sockfd2helper_.at(sockfd);
}
......@@ -201,9 +199,7 @@ void EpollCommNet::DoRead(void* read_id, int64_t src_machine_id, void* src_token
msg.request_write_msg.dst_machine_id = Global<MachineCtx>::Get()->this_machine_id();
msg.request_write_msg.dst_token = dst_token;
msg.request_write_msg.read_id = read_id;
int32_t link_i =
std::uniform_int_distribution<int32_t>(0, epoll_conf_.link_num() - 1)(random_gen_);
GetSocketHelper(src_machine_id, link_i)->AsyncWrite(msg);
GetSocketHelper(src_machine_id, 0)->AsyncWrite(msg);
}
void EpollCommNet::PartReadDone(void* read_id, int32_t part_num) {
......
......@@ -34,7 +34,6 @@ class EpollCommNet final : public CommNetIf<SocketMemDesc> {
std::vector<IOEventPoller*> pollers_;
std::vector<int32_t> machine_id2sockfds_;
HashMap<int, SocketHelper*> sockfd2helper_;
std::mt19937 random_gen_;
std::mutex part_done_cnt_mtx_;
HashMap<void*, int32_t> read_id2part_done_cnt_;
};
......
......@@ -177,7 +177,6 @@ inline double GetCurTime() {
// Size constants. The unit is not stated here.
// NOTE(review): presumably byte counts used for alignment — confirm against callers.
const size_t kCudaAlignSize = 8;
const size_t kCudaMemAllocAlignSize = 256;
const size_t kCacheLineSize = 64;
// Rounds n up to the nearest multiple of val; a value that is already a
// multiple is returned unchanged. val is used as a divisor and must be non-zero.
inline size_t RoundUp(size_t n, size_t val) {
  const size_t chunk_count = (n + val - 1) / val;
  return chunk_count * val;
}
size_t GetAvailableCpuMemSize();
......
......@@ -52,6 +52,7 @@ message ExperimentalRunConf {
// Settings for the epoll-based comm net.
message EpollConf {
// Number of socket links kept per peer machine (default 5). The socket table
// is indexed as machine_id * link_num + link index.
optional int32 link_num = 1 [default = 5];
// Message segment size in kilobytes (default 256). SendSocketMsg multiplies it
// by 1024 and rounds each per-link chunk size up to a multiple of that value.
optional int32 msg_segment_kbyte = 2 [default = 256];
}
message IBVerbsConf {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册