提交 93034430 编写于 作者: S Shiyuan Shang-Guan

rm lck


Former-commit-id: 3655dbad6d2d61ae84e0e484a3a5a2fed2fc66b2
上级 39ffdcb8
......@@ -7,9 +7,9 @@ CommNet::~CommNet() {
ready_cb_poller_.join();
}
void* CommNet::NewActorReadId() { return new ActorReadContext; }
void* CommNet::NewActorReadId() const { return new ActorReadContext; }
void CommNet::DeleteActorReadId(void* actor_read_id) {
void CommNet::DeleteActorReadId(void* actor_read_id) const {
auto actor_read_ctx = static_cast<ActorReadContext*>(actor_read_id);
CHECK(actor_read_ctx->waiting_list.empty());
delete actor_read_ctx;
......
......@@ -29,20 +29,21 @@ class CommNet {
virtual void RegisterMemoryDone() = 0;
// Stream
void* NewActorReadId();
void DeleteActorReadId(void* actor_read_id);
void* NewActorReadId() const;
void DeleteActorReadId(void* actor_read_id) const;
void Read(void* actor_read_id, int64_t src_machine_id, void* src_token, void* dst_token);
void AddReadCallBack(void* actor_read_id, std::function<void()> callback);
void ReadDone(void* read_id);
//
virtual void SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) = 0;
virtual void SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) const = 0;
protected:
CommNet(const Plan& plan);
virtual void DoRead(void* read_id, int64_t src_machine_id, void* src_token, void* dst_token) = 0;
const HashSet<int64_t>& peer_machine_id() { return peer_machine_id_; }
virtual void DoRead(void* read_id, int64_t src_machine_id, void* src_token,
void* dst_token) const = 0;
const HashSet<int64_t>& peer_machine_id() const { return peer_machine_id_; }
Channel<std::function<void()>> ready_cbs_;
......@@ -84,8 +85,8 @@ class CommNetIf : public CommNet {
}
protected:
virtual MemDescType* NewMemDesc(void* ptr, size_t byte_size) = 0;
const HashSet<MemDescType*>& mem_descs() { return mem_descs_; }
virtual MemDescType* NewMemDesc(void* ptr, size_t byte_size) const = 0;
const HashSet<MemDescType*>& mem_descs() const { return mem_descs_; }
private:
std::mutex mem_descs_mtx_;
......
......@@ -64,10 +64,15 @@ EpollCommNet::~EpollCommNet() {
}
void EpollCommNet::RegisterMemoryDone() {
// do nothing
for (void* dst_token : mem_descs()) {
CHECK(
dst_token2part_done_cnt_
.emplace(dst_token, std::shared_ptr<std::atomic<int32_t>>(new std::atomic<int32_t>(0)))
.second);
}
}
void EpollCommNet::SendActorMsg(int64_t dst_machine_id, const ActorMsg& actor_msg) {
void EpollCommNet::SendActorMsg(int64_t dst_machine_id, const ActorMsg& actor_msg) const {
SocketMsg msg;
msg.msg_type = SocketMsgType::kActor;
msg.actor_msg = actor_msg;
......@@ -75,7 +80,7 @@ void EpollCommNet::SendActorMsg(int64_t dst_machine_id, const ActorMsg& actor_ms
}
void EpollCommNet::RequestRead(int64_t dst_machine_id, void* src_token, void* dst_token,
void* read_id) {
void* read_id) const {
int32_t total_byte_size = static_cast<const SocketMemDesc*>(src_token)->byte_size;
int32_t offset = (total_byte_size + epoll_conf_.link_num() - 1) / epoll_conf_.link_num();
offset = RoundUp(offset, epoll_conf_.msg_segment_kbyte() * 1024);
......@@ -96,7 +101,7 @@ void EpollCommNet::RequestRead(int64_t dst_machine_id, void* src_token, void* ds
CHECK_EQ(total_byte_size, 0);
}
SocketMemDesc* EpollCommNet::NewMemDesc(void* ptr, size_t byte_size) {
SocketMemDesc* EpollCommNet::NewMemDesc(void* ptr, size_t byte_size) const {
SocketMemDesc* mem_desc = new SocketMemDesc;
mem_desc->mem_ptr = ptr;
mem_desc->byte_size = byte_size;
......@@ -185,29 +190,27 @@ void EpollCommNet::InitSockets() {
}
}
SocketHelper* EpollCommNet::GetSocketHelper(int64_t machine_id, int32_t link_index) {
SocketHelper* EpollCommNet::GetSocketHelper(int64_t machine_id, int32_t link_index) const {
int32_t sockfd = machine_id2sockfds_.at(machine_id * epoll_conf_.link_num() + link_index);
return sockfd2helper_.at(sockfd);
}
void EpollCommNet::DoRead(void* read_id, int64_t src_machine_id, void* src_token, void* dst_token) {
CHECK(read_id2part_done_cnt_.emplace(read_id, 0).second);
void EpollCommNet::DoRead(void* read_id, int64_t src_machine_id, void* src_token,
void* dst_token) const {
SocketMsg msg;
msg.msg_type = SocketMsgType::kRequestWrite;
msg.request_write_msg.src_token = src_token;
msg.request_write_msg.dst_machine_id = Global<MachineCtx>::Get()->this_machine_id();
msg.request_write_msg.dst_token = dst_token;
msg.request_write_msg.read_id = read_id;
*(dst_token2part_done_cnt_.at(dst_token)) = 0;
GetSocketHelper(src_machine_id, 0)->AsyncWrite(msg);
}
void EpollCommNet::PartReadDone(void* read_id, int32_t part_num) {
std::unique_lock<std::mutex> lck(part_done_cnt_mtx_);
int32_t& part_read_done_cnt = read_id2part_done_cnt_.at(read_id);
part_read_done_cnt++;
if (part_read_done_cnt == part_num) {
void EpollCommNet::PartReadDone(void* read_id, void* dst_token, int32_t part_num) {
if (dst_token2part_done_cnt_.at(dst_token)->fetch_add(1, std::memory_order_relaxed)
== (part_num - 1)) {
ReadDone(read_id);
read_id2part_done_cnt_.erase(read_id);
}
}
......
......@@ -18,24 +18,24 @@ class EpollCommNet final : public CommNetIf<SocketMemDesc> {
void RegisterMemoryDone() override;
void SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) override;
void RequestRead(int64_t dst_machine_id, void* src_token, void* dst_token, void* read_id);
void PartReadDone(void* read_id, int32_t part_num);
void SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) const override;
void RequestRead(int64_t dst_machine_id, void* src_token, void* dst_token, void* read_id) const;
void PartReadDone(void* read_id, void* dst_token, int32_t part_num);
private:
SocketMemDesc* NewMemDesc(void* ptr, size_t byte_size) override;
SocketMemDesc* NewMemDesc(void* ptr, size_t byte_size) const override;
EpollCommNet(const Plan& plan);
void InitSockets();
SocketHelper* GetSocketHelper(int64_t machine_id, int32_t link_index);
void DoRead(void* read_id, int64_t src_machine_id, void* src_token, void* dst_token) override;
SocketHelper* GetSocketHelper(int64_t machine_id, int32_t link_index) const;
void DoRead(void* read_id, int64_t src_machine_id, void* src_token,
void* dst_token) const override;
const EpollConf& epoll_conf_;
std::vector<IOEventPoller*> pollers_;
std::vector<int32_t> machine_id2sockfds_;
HashMap<int, SocketHelper*> sockfd2helper_;
std::mutex part_done_cnt_mtx_;
HashMap<void*, int32_t> read_id2part_done_cnt_;
HashMap<void*, std::shared_ptr<std::atomic<int32_t>>> dst_token2part_done_cnt_;
};
template<>
......
......@@ -64,6 +64,7 @@ void SocketReadHelper::SetStatusWhenMsgHeadDone() {
void SocketReadHelper::SetStatusWhenMsgBodyDone() {
if (cur_msg_.msg_type == SocketMsgType::kRequestRead) {
Global<EpollCommNet>::Get()->PartReadDone(cur_msg_.request_read_msg.read_id,
cur_msg_.request_read_msg.dst_token,
cur_msg_.request_read_msg.part_num);
}
SwitchToMsgHeadReadHandle();
......
......@@ -49,7 +49,7 @@ void IBVerbsCommNet::RegisterMemoryDone() {
Global<CtrlClient>::Get()->ClearKV(GenTokensMsgKey(this_machine_id));
}
void IBVerbsCommNet::SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) {
void IBVerbsCommNet::SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) const {
qp_vec_.at(dst_machine_id)->PostSendRequest(msg);
}
......@@ -101,7 +101,7 @@ IBVerbsCommNet::IBVerbsCommNet(const Plan& plan)
}
void IBVerbsCommNet::DoRead(void* read_id, int64_t src_machine_id, void* src_token,
void* dst_token) {
void* dst_token) const {
qp_vec_.at(src_machine_id)
->PostReadRequest(token2mem_desc_.at(src_machine_id).at(src_token),
*static_cast<const IBVerbsMemDesc*>(dst_token), read_id);
......
......@@ -23,15 +23,16 @@ class IBVerbsCommNet final : public CommNetIf<IBVerbsMemDesc> {
void RegisterMemoryDone() override;
void SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) override;
void SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) const override;
private:
IBVerbsMemDesc* NewMemDesc(void* ptr, size_t byte_size) override {
IBVerbsMemDesc* NewMemDesc(void* ptr, size_t byte_size) const override {
return new IBVerbsMemDesc(pd_, ptr, byte_size);
}
IBVerbsCommNet(const Plan&);
void DoRead(void* read_id, int64_t src_machine_id, void* src_token, void* dst_token) override;
void DoRead(void* read_id, int64_t src_machine_id, void* src_token,
void* dst_token) const override;
void PollCQ();
static const int32_t max_poll_wc_num_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册