提交 b012dc22 编写于 作者: J Juncheng 提交者: Jinhui Yuan

Dev refactor channel (#1181)

* add enum ChannelStatus

* merge CloseSendEnd and CloseReceiveEnd

* update channel_test


Former-commit-id: fda25987
上级 03c635ba
......@@ -3,7 +3,7 @@
namespace oneflow {
CommNet::~CommNet() {
  // Close the callback channel: once the queued callbacks drain, the
  // poller's Receive() returns kChannelStatusErrorClosed and its loop exits.
  ready_cbs_.Close();
  // Wait for the poller thread to finish before destroying its channel.
  ready_cb_poller_.join();
}
......@@ -79,7 +79,7 @@ CommNet::CommNet(const Plan& plan) {
ready_cb_poller_ = std::thread([this]() {
std::function<void()> cb;
while (ready_cbs_.Receive(&cb) == 0) { cb(); }
while (ready_cbs_.Receive(&cb) == kChannelStatusSuccess) { cb(); }
});
}
......
......@@ -5,69 +5,49 @@
namespace oneflow {
// Status codes returned by Channel::Send / Channel::Receive.
enum ChannelStatus { kChannelStatusSuccess = 0, kChannelStatusErrorClosed };

// A thread-safe, unbounded, multi-producer/multi-consumer FIFO channel.
//
// Producers enqueue items with Send(); consumers block in Receive() until an
// item is available or the channel is closed.  After Close():
//   * Send() fails immediately with kChannelStatusErrorClosed;
//   * Receive() keeps draining already-queued items, then returns
//     kChannelStatusErrorClosed once the queue is empty.
template<typename T>
class Channel final {
 public:
  OF_DISALLOW_COPY_AND_MOVE(Channel);
  Channel() : is_closed_(false) {}
  ~Channel() = default;

  // Enqueue a copy of `item` and wake one waiting receiver.
  // Returns kChannelStatusSuccess, or kChannelStatusErrorClosed if Close()
  // has already been called.
  ChannelStatus Send(const T& item);

  // Block until an item is available or the channel is closed and drained.
  // On success copies the front item into *item and returns
  // kChannelStatusSuccess; returns kChannelStatusErrorClosed when the
  // channel is closed and the queue is empty.
  ChannelStatus Receive(T* item);

  // Close the channel and wake all blocked receivers.  Items already queued
  // may still be received after Close().
  void Close();

 private:
  std::queue<T> queue_;
  mutable std::mutex mutex_;
  bool is_closed_;  // guarded by mutex_
  std::condition_variable cond_;
};

template<typename T>
ChannelStatus Channel<T>::Send(const T& item) {
  std::unique_lock<std::mutex> lock(mutex_);
  if (is_closed_) { return kChannelStatusErrorClosed; }
  queue_.push(item);
  cond_.notify_one();
  return kChannelStatusSuccess;
}

template<typename T>
ChannelStatus Channel<T>::Receive(T* item) {
  std::unique_lock<std::mutex> lock(mutex_);
  // Wait until there is something to take, or no more items will ever arrive.
  cond_.wait(lock, [this]() { return (!queue_.empty()) || is_closed_; });
  if (queue_.empty()) { return kChannelStatusErrorClosed; }
  *item = queue_.front();
  queue_.pop();
  return kChannelStatusSuccess;
}

template<typename T>
void Channel<T>::Close() {
  std::unique_lock<std::mutex> lock(mutex_);
  is_closed_ = true;
  // Wake every blocked receiver so it can drain the queue and exit.
  cond_.notify_all();
}
......
......@@ -5,14 +5,14 @@ namespace oneflow {
// Producer body for the channel test: sends every value in
// [range.begin(), range.end()), stopping early if the channel is closed.
void CallFromSenderThread(Channel<int>* channel, Range range) {
  for (int i = range.begin(); i < range.end(); ++i) {
    if (channel->Send(i) != kChannelStatusSuccess) { break; }
  }
}
// Consumer body for the channel test: receives until the channel is closed
// and drained, counting in *visit how many times each value is seen.
void CallFromReceiverThread(std::vector<int>* visit, Channel<int>* channel) {
  int num = -1;
  // Pass &num directly; the extra num_ptr alias added nothing.
  while (channel->Receive(&num) == kChannelStatusSuccess) { ++visit->at(num); }
}
TEST(Channel, 30sender40receiver) {
......@@ -35,9 +35,8 @@ TEST(Channel, 30sender40receiver) {
receivers.push_back(std::thread(CallFromReceiverThread, &visits[i], &channel));
}
for (std::thread& this_thread : senders) { this_thread.join(); }
channel.CloseSendEnd();
channel.Close();
for (std::thread& this_thread : receivers) { this_thread.join(); }
channel.CloseReceiveEnd();
for (int i = 0; i < range_num; ++i) {
int visit_count = 0;
for (int j = 0; j < receiver_num; j++) { visit_count += visits[j][i]; }
......
......@@ -16,7 +16,7 @@ GpuThread::GpuThread(int64_t thrd_id, int64_t dev_id) {
});
cb_event_poller_ = std::thread([this]() {
CudaCBEvent cb_event;
while (cb_event_chan_.Receive(&cb_event) == 0) {
while (cb_event_chan_.Receive(&cb_event) == kChannelStatusSuccess) {
CudaCheck(cudaEventSynchronize(cb_event.event));
cb_event.callback();
CudaCheck(cudaEventDestroy(cb_event.event));
......@@ -25,8 +25,7 @@ GpuThread::GpuThread(int64_t thrd_id, int64_t dev_id) {
}
GpuThread::~GpuThread() {
  // Close the CUDA-callback-event channel so the poller's Receive() loop
  // terminates once pending events are drained, then join the poller thread.
  cb_event_chan_.Close();
  cb_event_poller_.join();
}
......
......@@ -5,8 +5,7 @@ namespace oneflow {
Thread::~Thread() {
  // The actor thread must have processed every task before teardown.
  actor_thread_.join();
  CHECK(id2task_.empty());
  // Close the message channel; no further messages can be sent or received.
  msg_channel_.Close();
}
void Thread::AddTask(const TaskProto& task) {
......@@ -17,7 +16,7 @@ void Thread::AddTask(const TaskProto& task) {
void Thread::PollMsgChannel(const ThreadCtx& thread_ctx) {
ActorMsg msg;
while (true) {
CHECK_EQ(msg_channel_.Receive(&msg), 0);
CHECK_EQ(msg_channel_.Receive(&msg), kChannelStatusSuccess);
if (msg.msg_type() == ActorMsgType::kCmdMsg) {
if (msg.actor_cmd() == ActorCmd::kStopThread) {
CHECK(id2actor_ptr_.empty());
......
......@@ -8,15 +8,14 @@ ThreadPool::ThreadPool(int32_t thread_num)
Channel<std::function<void()>>* chan = &(work_chans_.at(i));
threads_[i] = std::thread([chan]() {
std::function<void()> work;
while (chan->Receive(&work) == 0) { work(); }
while (chan->Receive(&work) == kChannelStatusSuccess) { work(); }
});
}
}
ThreadPool::~ThreadPool() {
  FOR_RANGE(int32_t, i, 0, work_chans_.size()) {
    // Closing the channel makes the worker's Receive() loop exit after any
    // queued work is drained; then wait for that worker to finish.
    work_chans_.at(i).Close();
    threads_.at(i).join();
  }
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册