Commit b012dc22 authored by Juncheng, committed by Jinhui Yuan

Dev refactor channel (#1181)

* add enum ChannelStatus

* merge CloseSendEnd and CloseReceiveEnd

* update channel_test


Former-commit-id: fda25987
Parent 03c635ba
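
This commit merges the separate CloseSendEnd()/CloseReceiveEnd() calls into a single Close() and replaces the bare int return codes with the ChannelStatus enum. A minimal usage sketch of the refactored interface follows; the function name ProducerConsumerSketch, the int payload, and the include path are illustrative assumptions, not part of this commit:

    #include <thread>
    #include "oneflow/core/common/channel.h"  // assumed header path for Channel<T>

    void ProducerConsumerSketch() {
      oneflow::Channel<int> channel;
      std::thread consumer([&channel]() {
        int item = -1;
        // Receive() blocks until an item arrives or the channel is closed;
        // after Close(), it still drains any buffered items before failing.
        while (channel.Receive(&item) == oneflow::kChannelStatusSuccess) {
          // process item
        }
      });
      for (int i = 0; i < 100; ++i) { channel.Send(i); }
      channel.Close();  // single call replaces CloseSendEnd() + CloseReceiveEnd()
      consumer.join();
    }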
@@ -3,7 +3,7 @@
 namespace oneflow {
 CommNet::~CommNet() {
-  ready_cbs_.CloseSendEnd();
+  ready_cbs_.Close();
   ready_cb_poller_.join();
 }
@@ -79,7 +79,7 @@ CommNet::CommNet(const Plan& plan) {
   ready_cb_poller_ = std::thread([this]() {
     std::function<void()> cb;
-    while (ready_cbs_.Receive(&cb) == 0) { cb(); }
+    while (ready_cbs_.Receive(&cb) == kChannelStatusSuccess) { cb(); }
   });
 }
...
@@ -5,69 +5,49 @@
 namespace oneflow {
+enum ChannelStatus { kChannelStatusSuccess = 0, kChannelStatusErrorClosed };
 template<typename T>
 class Channel final {
  public:
  OF_DISALLOW_COPY_AND_MOVE(Channel);
-  Channel() : is_send_closed_(false), is_receive_closed_(false) {}
+  Channel() : is_closed_(false) {}
  ~Channel() = default;
-  // return code
-  //   0 : success send item
-  //   -1 : fail (send end has been closed)
-  int Send(const T& item);
-  // If the channel is empty, the thread calling Receive() would be blocked.
-  // return value
-  //   0: success -- if successfully get the item ref in val_
-  //   -1: fail -- when the channel tell the owner thread should exit
-  int Receive(T* item);
-  // close the channel's send end, the thread can't send item to the channel
-  void CloseSendEnd();
-  // close the channel's receive end, the thread can't receive item from
-  // channel
-  void CloseReceiveEnd();
+  ChannelStatus Send(const T& item);
+  ChannelStatus Receive(T* item);
+  void Close();
  private:
-  std::queue<T> val_;
+  std::queue<T> queue_;
   mutable std::mutex mutex_;
-  bool is_send_closed_;
-  bool is_receive_closed_;
+  bool is_closed_;
   std::condition_variable cond_;
 };
 template<typename T>
-int Channel<T>::Send(const T& item) {
+ChannelStatus Channel<T>::Send(const T& item) {
   std::unique_lock<std::mutex> lock(mutex_);
-  if (is_send_closed_) { return -1; }
-  val_.push(item);
+  if (is_closed_) { return kChannelStatusErrorClosed; }
+  queue_.push(item);
   cond_.notify_one();
-  return 0;
+  return kChannelStatusSuccess;
 }
 template<typename T>
-int Channel<T>::Receive(T* item) {
+ChannelStatus Channel<T>::Receive(T* item) {
   std::unique_lock<std::mutex> lock(mutex_);
-  cond_.wait(lock, [this]() { return !val_.empty() || is_receive_closed_ || is_send_closed_; });
-  if (val_.empty() || is_receive_closed_) { return -1; }
-  *item = val_.front();
-  val_.pop();
-  return 0;
+  cond_.wait(lock, [this]() { return (!queue_.empty()) || is_closed_; });
+  if (queue_.empty()) { return kChannelStatusErrorClosed; }
+  *item = queue_.front();
+  queue_.pop();
+  return kChannelStatusSuccess;
 }
 template<typename T>
-void Channel<T>::CloseSendEnd() {
-  std::unique_lock<std::mutex> lock(mutex_);
-  is_send_closed_ = true;
-  cond_.notify_all();
-}
-template<typename T>
-void Channel<T>::CloseReceiveEnd() {
+void Channel<T>::Close() {
   std::unique_lock<std::mutex> lock(mutex_);
-  is_receive_closed_ = true;
+  is_closed_ = true;
   cond_.notify_all();
 }
...
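
A note on the semantics visible in the new Receive() above: with the single is_closed_ flag, Send() fails immediately once Close() has been called, while Receive() keeps returning items that are still buffered and only reports kChannelStatusErrorClosed after the queue is empty. A small illustrative sketch of that drain behavior (not part of the commit):

    oneflow::Channel<int> channel;
    channel.Send(1);
    channel.Send(2);
    channel.Close();

    int item = -1;
    // Sending on a closed channel fails right away.
    CHECK_EQ(channel.Send(3), oneflow::kChannelStatusErrorClosed);
    // Items buffered before Close() are still delivered.
    CHECK_EQ(channel.Receive(&item), oneflow::kChannelStatusSuccess);  // item == 1
    CHECK_EQ(channel.Receive(&item), oneflow::kChannelStatusSuccess);  // item == 2
    // Once the queue drains, Receive() reports the closed status.
    CHECK_EQ(channel.Receive(&item), oneflow::kChannelStatusErrorClosed);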
@@ -5,14 +5,14 @@ namespace oneflow {
 void CallFromSenderThread(Channel<int>* channel, Range range) {
   for (int i = range.begin(); i < range.end(); ++i) {
-    if (channel->Send(i) == -1) { break; }
+    if (channel->Send(i) != kChannelStatusSuccess) { break; }
   }
 }
 void CallFromReceiverThread(std::vector<int>* visit, Channel<int>* channel) {
   int num = -1;
   int* num_ptr = &num;
-  while (channel->Receive(num_ptr) == 0) { ++visit->at(*num_ptr); }
+  while (channel->Receive(num_ptr) == kChannelStatusSuccess) { ++visit->at(*num_ptr); }
 }
 TEST(Channel, 30sender40receiver) {
@@ -35,9 +35,8 @@ TEST(Channel, 30sender40receiver) {
     receivers.push_back(std::thread(CallFromReceiverThread, &visits[i], &channel));
   }
   for (std::thread& this_thread : senders) { this_thread.join(); }
-  channel.CloseSendEnd();
+  channel.Close();
   for (std::thread& this_thread : receivers) { this_thread.join(); }
-  channel.CloseReceiveEnd();
   for (int i = 0; i < range_num; ++i) {
     int visit_count = 0;
     for (int j = 0; j < receiver_num; j++) { visit_count += visits[j][i]; }
...
@@ -16,7 +16,7 @@ GpuThread::GpuThread(int64_t thrd_id, int64_t dev_id) {
   });
   cb_event_poller_ = std::thread([this]() {
     CudaCBEvent cb_event;
-    while (cb_event_chan_.Receive(&cb_event) == 0) {
+    while (cb_event_chan_.Receive(&cb_event) == kChannelStatusSuccess) {
       CudaCheck(cudaEventSynchronize(cb_event.event));
       cb_event.callback();
       CudaCheck(cudaEventDestroy(cb_event.event));
@@ -25,8 +25,7 @@ GpuThread::GpuThread(int64_t thrd_id, int64_t dev_id) {
 }
 GpuThread::~GpuThread() {
-  cb_event_chan_.CloseSendEnd();
-  cb_event_chan_.CloseReceiveEnd();
+  cb_event_chan_.Close();
   cb_event_poller_.join();
 }
...
@@ -5,8 +5,7 @@ namespace oneflow {
 Thread::~Thread() {
   actor_thread_.join();
   CHECK(id2task_.empty());
-  msg_channel_.CloseSendEnd();
-  msg_channel_.CloseReceiveEnd();
+  msg_channel_.Close();
 }
 void Thread::AddTask(const TaskProto& task) {
@@ -17,7 +16,7 @@ void Thread::AddTask(const TaskProto& task) {
 void Thread::PollMsgChannel(const ThreadCtx& thread_ctx) {
   ActorMsg msg;
   while (true) {
-    CHECK_EQ(msg_channel_.Receive(&msg), 0);
+    CHECK_EQ(msg_channel_.Receive(&msg), kChannelStatusSuccess);
     if (msg.msg_type() == ActorMsgType::kCmdMsg) {
       if (msg.actor_cmd() == ActorCmd::kStopThread) {
         CHECK(id2actor_ptr_.empty());
...
@@ -8,15 +8,14 @@ ThreadPool::ThreadPool(int32_t thread_num)
     Channel<std::function<void()>>* chan = &(work_chans_.at(i));
     threads_[i] = std::thread([chan]() {
       std::function<void()> work;
-      while (chan->Receive(&work) == 0) { work(); }
+      while (chan->Receive(&work) == kChannelStatusSuccess) { work(); }
     });
   }
 }
 ThreadPool::~ThreadPool() {
   FOR_RANGE(int32_t, i, 0, work_chans_.size()) {
-    work_chans_.at(i).CloseSendEnd();
-    work_chans_.at(i).CloseReceiveEnd();
+    work_chans_.at(i).Close();
     threads_.at(i).join();
   }
 }
...