diff --git a/oneflow/core/comm_network/comm_network.cpp b/oneflow/core/comm_network/comm_network.cpp index 9ed0fa803bf208cf1f4c131b3894725f7373b5cf..ca186e39e49098714587e68ee277cc03b51ecccf 100644 --- a/oneflow/core/comm_network/comm_network.cpp +++ b/oneflow/core/comm_network/comm_network.cpp @@ -3,7 +3,7 @@ namespace oneflow { CommNet::~CommNet() { - ready_cbs_.CloseSendEnd(); + ready_cbs_.Close(); ready_cb_poller_.join(); } @@ -79,7 +79,7 @@ CommNet::CommNet(const Plan& plan) { ready_cb_poller_ = std::thread([this]() { std::function cb; - while (ready_cbs_.Receive(&cb) == 0) { cb(); } + while (ready_cbs_.Receive(&cb) == kChannelStatusSuccess) { cb(); } }); } diff --git a/oneflow/core/common/channel.h b/oneflow/core/common/channel.h index b00065a449d16d9b7f02a23c28f2d69d25f5ed69..8a36a7590ab5592a65c267ac04a91929c9edd7eb 100644 --- a/oneflow/core/common/channel.h +++ b/oneflow/core/common/channel.h @@ -5,69 +5,49 @@ namespace oneflow { +enum ChannelStatus { kChannelStatusSuccess = 0, kChannelStatusErrorClosed }; + template class Channel final { public: OF_DISALLOW_COPY_AND_MOVE(Channel); - Channel() : is_send_closed_(false), is_receive_closed_(false) {} + Channel() : is_closed_(false) {} ~Channel() = default; - // return code - // 0 : success send item - // -1 : fail (send end has been closed) - int Send(const T& item); - - // If the channel is empty, the thread calling Receive() would be blocked. - // return value - // 0: success -- if successfully get the item ref in val_ - // -1: fail -- when the channel tell the owner thread should exit - int Receive(T* item); - - // close the channel's send end, the thread can't send item to the channel - void CloseSendEnd(); - - // close the channel's receive end , the thread can't receive item from - // channel - void CloseReceiveEnd(); + ChannelStatus Send(const T& item); + ChannelStatus Receive(T* item); + void Close(); private: - std::queue val_; + std::queue queue_; mutable std::mutex mutex_; - bool is_send_closed_; - bool is_receive_closed_; + bool is_closed_; std::condition_variable cond_; }; template -int Channel::Send(const T& item) { +ChannelStatus Channel::Send(const T& item) { std::unique_lock lock(mutex_); - if (is_send_closed_) { return -1; } - val_.push(item); + if (is_closed_) { return kChannelStatusErrorClosed; } + queue_.push(item); cond_.notify_one(); - return 0; + return kChannelStatusSuccess; } template -int Channel::Receive(T* item) { +ChannelStatus Channel::Receive(T* item) { std::unique_lock lock(mutex_); - cond_.wait(lock, [this]() { return !val_.empty() || is_receive_closed_ || is_send_closed_; }); - if (val_.empty() || is_receive_closed_) { return -1; } - *item = val_.front(); - val_.pop(); - return 0; -} - -template -void Channel::CloseSendEnd() { - std::unique_lock lock(mutex_); - is_send_closed_ = true; - cond_.notify_all(); + cond_.wait(lock, [this]() { return (!queue_.empty()) || is_closed_; }); + if (queue_.empty()) { return kChannelStatusErrorClosed; } + *item = queue_.front(); + queue_.pop(); + return kChannelStatusSuccess; } template -void Channel::CloseReceiveEnd() { +void Channel::Close() { std::unique_lock lock(mutex_); - is_receive_closed_ = true; + is_closed_ = true; cond_.notify_all(); } diff --git a/oneflow/core/common/channel_test.cpp b/oneflow/core/common/channel_test.cpp index f5ad210bbafc5a0133f799954b3788cea6e07712..6909ff3d3fb4b5e3819a5cea052f6608deefcefd 100644 --- a/oneflow/core/common/channel_test.cpp +++ b/oneflow/core/common/channel_test.cpp @@ -5,14 +5,14 @@ namespace oneflow { void CallFromSenderThread(Channel* channel, Range range) { for (int i = range.begin(); i < range.end(); ++i) { - if (channel->Send(i) == -1) { break; } + if (channel->Send(i) != kChannelStatusSuccess) { break; } } } void CallFromReceiverThread(std::vector* visit, Channel* channel) { int num = -1; int* num_ptr = # - while (channel->Receive(num_ptr) == 0) { ++visit->at(*num_ptr); } + while (channel->Receive(num_ptr) == kChannelStatusSuccess) { ++visit->at(*num_ptr); } } TEST(Channel, 30sender40receiver) { @@ -35,9 +35,8 @@ TEST(Channel, 30sender40receiver) { receivers.push_back(std::thread(CallFromReceiverThread, &visits[i], &channel)); } for (std::thread& this_thread : senders) { this_thread.join(); } - channel.CloseSendEnd(); + channel.Close(); for (std::thread& this_thread : receivers) { this_thread.join(); } - channel.CloseReceiveEnd(); for (int i = 0; i < range_num; ++i) { int visit_count = 0; for (int j = 0; j < receiver_num; j++) { visit_count += visits[j][i]; } diff --git a/oneflow/core/thread/gpu_thread.cpp b/oneflow/core/thread/gpu_thread.cpp index ec0aff487bb8cce89d0f3abc08739b7d17110fc3..8e36bffdaaeca2ee6d948171c83fb91d1a024eee 100644 --- a/oneflow/core/thread/gpu_thread.cpp +++ b/oneflow/core/thread/gpu_thread.cpp @@ -16,7 +16,7 @@ GpuThread::GpuThread(int64_t thrd_id, int64_t dev_id) { }); cb_event_poller_ = std::thread([this]() { CudaCBEvent cb_event; - while (cb_event_chan_.Receive(&cb_event) == 0) { + while (cb_event_chan_.Receive(&cb_event) == kChannelStatusSuccess) { CudaCheck(cudaEventSynchronize(cb_event.event)); cb_event.callback(); CudaCheck(cudaEventDestroy(cb_event.event)); @@ -25,8 +25,7 @@ GpuThread::GpuThread(int64_t thrd_id, int64_t dev_id) { } GpuThread::~GpuThread() { - cb_event_chan_.CloseSendEnd(); - cb_event_chan_.CloseReceiveEnd(); + cb_event_chan_.Close(); cb_event_poller_.join(); } diff --git a/oneflow/core/thread/thread.cpp b/oneflow/core/thread/thread.cpp index b9119d4b028b7966f6223db10721741f8e546278..93918097300d28f0e2d10f26b62971e818a93eaa 100644 --- a/oneflow/core/thread/thread.cpp +++ b/oneflow/core/thread/thread.cpp @@ -5,8 +5,7 @@ namespace oneflow { Thread::~Thread() { actor_thread_.join(); CHECK(id2task_.empty()); - msg_channel_.CloseSendEnd(); - msg_channel_.CloseReceiveEnd(); + msg_channel_.Close(); } void Thread::AddTask(const TaskProto& task) { @@ -17,7 +16,7 @@ void Thread::AddTask(const TaskProto& task) { void Thread::PollMsgChannel(const ThreadCtx& thread_ctx) { ActorMsg msg; while (true) { - CHECK_EQ(msg_channel_.Receive(&msg), 0); + CHECK_EQ(msg_channel_.Receive(&msg), kChannelStatusSuccess); if (msg.msg_type() == ActorMsgType::kCmdMsg) { if (msg.actor_cmd() == ActorCmd::kStopThread) { CHECK(id2actor_ptr_.empty()); diff --git a/oneflow/core/thread/thread_pool.cpp b/oneflow/core/thread/thread_pool.cpp index 5a124f830ebaada0637daa4cf293d34fd623daf2..b6391233537c276fe8988e4f0fed001b3a31f079 100644 --- a/oneflow/core/thread/thread_pool.cpp +++ b/oneflow/core/thread/thread_pool.cpp @@ -8,15 +8,14 @@ ThreadPool::ThreadPool(int32_t thread_num) Channel>* chan = &(work_chans_.at(i)); threads_[i] = std::thread([chan]() { std::function work; - while (chan->Receive(&work) == 0) { work(); } + while (chan->Receive(&work) == kChannelStatusSuccess) { work(); } }); } } ThreadPool::~ThreadPool() { FOR_RANGE(int32_t, i, 0, work_chans_.size()) { - work_chans_.at(i).CloseSendEnd(); - work_chans_.at(i).CloseReceiveEnd(); + work_chans_.at(i).Close(); threads_.at(i).join(); } }