From 78132f334207ad67434f25ab800c261a18526daf Mon Sep 17 00:00:00 2001 From: willzhang4a58 Date: Sun, 15 Oct 2017 13:29:49 +0800 Subject: [PATCH] refine mutex in read callback --- .../epoll/epoll_data_comm_network.cpp | 43 ++++++++++--------- .../epoll/epoll_data_comm_network.h | 3 +- 2 files changed, 25 insertions(+), 21 deletions(-) diff --git a/oneflow/core/comm_network/epoll/epoll_data_comm_network.cpp b/oneflow/core/comm_network/epoll/epoll_data_comm_network.cpp index 0b17b73e7d..20c17af3a4 100644 --- a/oneflow/core/comm_network/epoll/epoll_data_comm_network.cpp +++ b/oneflow/core/comm_network/epoll/epoll_data_comm_network.cpp @@ -117,6 +117,7 @@ void EpollDataCommNet::AddReadCallBack(void* actor_read_id, void* read_id, break; } else { actor_read_ctx->read_ctx_list.back()->cbl.push_back(callback); + return; } } while (0); callback(); @@ -125,7 +126,10 @@ void EpollDataCommNet::AddReadCallBack(void* actor_read_id, void* read_id, void EpollDataCommNet::AddReadCallBackDone(void* actor_read_id, void* read_id) { auto actor_read_ctx = static_cast(actor_read_id); ReadContext* read_ctx = static_cast(read_id); - IncreaseDoneCnt(actor_read_ctx, read_ctx); + if (IncreaseDoneCnt(read_ctx) == 2) { + FinishOneReadContext(actor_read_ctx, read_ctx); + delete read_ctx; + } } void EpollDataCommNet::ReadDone(void* read_done_id) { @@ -134,27 +138,26 @@ void EpollDataCommNet::ReadDone(void* read_done_id) { auto actor_read_ctx = std::get<0>(*parsed_read_done_id); auto read_ctx = std::get<1>(*parsed_read_done_id); delete parsed_read_done_id; - IncreaseDoneCnt(actor_read_ctx, read_ctx); -} - -void EpollDataCommNet::IncreaseDoneCnt(ActorReadContext* actor_read_ctx, - ReadContext* read_ctx) { - do { - std::unique_lock lck(read_ctx->done_cnt_mtx); - read_ctx->done_cnt += 1; - if (read_ctx->done_cnt == 2) { - break; - } else { - return; + if (IncreaseDoneCnt(read_ctx) == 2) { + { + std::unique_lock lck(actor_read_ctx->read_ctx_list_mtx); + FinishOneReadContext(actor_read_ctx, read_ctx); } - } while (0); - { - std::unique_lock lck(actor_read_ctx->read_ctx_list_mtx); - CHECK_EQ(actor_read_ctx->read_ctx_list.front(), read_ctx); - actor_read_ctx->read_ctx_list.pop_front(); - for (std::function& callback : read_ctx->cbl) { callback(); } + delete read_ctx; } - delete read_ctx; +} + +int8_t EpollDataCommNet::IncreaseDoneCnt(ReadContext* read_ctx) { + std::unique_lock lck(read_ctx->done_cnt_mtx); + read_ctx->done_cnt += 1; + return read_ctx->done_cnt; +} + +void EpollDataCommNet::FinishOneReadContext(ActorReadContext* actor_read_ctx, + ReadContext* read_ctx) { + CHECK_EQ(actor_read_ctx->read_ctx_list.front(), read_ctx); + actor_read_ctx->read_ctx_list.pop_front(); + for (std::function& callback : read_ctx->cbl) { callback(); } } void EpollDataCommNet::SendActorMsg(int64_t dst_machine_id, diff --git a/oneflow/core/comm_network/epoll/epoll_data_comm_network.h b/oneflow/core/comm_network/epoll/epoll_data_comm_network.h index 70176d0376..eb799bff14 100644 --- a/oneflow/core/comm_network/epoll/epoll_data_comm_network.h +++ b/oneflow/core/comm_network/epoll/epoll_data_comm_network.h @@ -47,7 +47,8 @@ class EpollDataCommNet final : public DataCommNet { std::list read_ctx_list; }; EpollDataCommNet(); - void IncreaseDoneCnt(ActorReadContext* actor_read_ctx, ReadContext* read_ctx); + int8_t IncreaseDoneCnt(ReadContext*); + void FinishOneReadContext(ActorReadContext*, ReadContext*); void InitSockets(); SocketHelper* GetSocketHelper(int64_t machine_id); -- GitLab