提交 b60da672 编写于 作者: C chengduo 提交者: Abhinav Arora

Refine buffer channel (#8098)

* refine buffer channel

*  refine Receive and Send

* follow comments
上级 022e5dee
......@@ -23,8 +23,8 @@ namespace framework {
template <typename T>
class Channel {
public:
virtual void Send(T*) = 0;
virtual void Receive(T*) = 0;
virtual bool Send(T*) = 0;
virtual bool Receive(T*) = 0;
virtual size_t Cap() = 0;
virtual void Close() = 0;
virtual ~Channel() {}
......
......@@ -30,8 +30,8 @@ class Buffered : public paddle::framework::Channel<T> {
friend void paddle::framework::CloseChannel<T>(Channel<T>*);
public:
virtual void Send(T*);
virtual void Receive(T*);
virtual bool Send(T*);
virtual bool Receive(T*);
virtual size_t Cap() { return cap_; }
virtual void Close();
virtual ~Buffered();
......@@ -48,33 +48,36 @@ class Buffered : public paddle::framework::Channel<T> {
PADDLE_ENFORCE_GT(cap, 0);
}
void NotifyAllSenders(std::unique_lock<std::mutex>*);
void NotifyAllParticipants(std::unique_lock<std::mutex>*);
};
template <typename T>
void Buffered<T>::Send(T* item) {
bool Buffered<T>::Send(T* item) {
std::unique_lock<std::mutex> lock(mu_);
full_cond_var_.wait(lock,
[this]() { return channel_.size() < cap_ || closed_; });
bool ret = false;
if (!closed_) {
channel_.push_back(std::move(*item));
lock.unlock();
empty_cond_var_.notify_one();
ret = true;
}
return ret;
}
template <typename T>
void Buffered<T>::Receive(T* item) {
bool Buffered<T>::Receive(T* item) {
std::unique_lock<std::mutex> lock(mu_);
empty_cond_var_.wait(lock, [this]() { return !channel_.empty() || closed_; });
bool ret = false;
if (!closed_) {
*item = std::move(channel_.front());
channel_.pop_front();
NotifyAllSenders(&lock);
} else {
item = nullptr;
full_cond_var_.notify_one();
ret = true;
}
return ret;
}
template <typename T>
......@@ -92,12 +95,6 @@ Buffered<T>::~Buffered() {
NotifyAllParticipants(&lock);
}
template <typename T>
void Buffered<T>::NotifyAllSenders(std::unique_lock<std::mutex>* lock) {
lock->unlock();
full_cond_var_.notify_all();
}
template <typename T>
void Buffered<T>::NotifyAllParticipants(std::unique_lock<std::mutex>* lock) {
lock->unlock();
......
......@@ -29,8 +29,8 @@ class UnBuffered : public paddle::framework::Channel<T> {
friend void paddle::framework::CloseChannel<T>(Channel<T>*);
public:
virtual void Send(T*);
virtual void Receive(T*);
virtual bool Send(T*);
virtual bool Receive(T*);
virtual size_t Cap() { return 0; }
virtual void Close();
virtual ~UnBuffered();
......@@ -57,7 +57,7 @@ class UnBuffered : public paddle::framework::Channel<T> {
// This function implements the concept of how data should
// be sent from a writer to a reader.
template <typename T>
void UnBuffered<T>::Send(T* data) {
bool UnBuffered<T>::Send(T* data) {
// Prevent other writers from entering
std::unique_lock<std::recursive_mutex> writer_lock(mu_write_);
writer_found_ = true;
......@@ -66,6 +66,7 @@ void UnBuffered<T>::Send(T* data) {
cv_writer_.wait(cv_lock,
[this]() { return reader_found_ == true || closed_; });
cv_reader_.notify_one();
bool ret = false;
if (!closed_) {
std::unique_lock<std::mutex> channel_lock(mu_ch_);
item = data;
......@@ -74,14 +75,16 @@ void UnBuffered<T>::Send(T* data) {
channel_lock.lock();
cv_channel_.wait(channel_lock,
[this]() { return item == nullptr || closed_; });
ret = true;
}
writer_found_ = false;
return ret;
}
// This function implements the concept of how
// data that was sent by a writer is read from a reader.
template <typename T>
void UnBuffered<T>::Receive(T* data) {
bool UnBuffered<T>::Receive(T* data) {
// Prevent other readers from entering
std::unique_lock<std::recursive_mutex> read_lock{mu_read_};
reader_found_ = true;
......@@ -90,6 +93,7 @@ void UnBuffered<T>::Receive(T* data) {
cv_reader_.wait(cv_lock,
[this]() { return writer_found_ == true || closed_; });
cv_writer_.notify_one();
bool ret = false;
if (!closed_) {
std::unique_lock<std::mutex> lock_ch{mu_ch_};
// Reader should wait for the writer to first write its data
......@@ -98,10 +102,12 @@ void UnBuffered<T>::Receive(T* data) {
*data = std::move(*item);
item = nullptr;
lock_ch.unlock();
ret = true;
}
cv_channel_.notify_one();
}
reader_found_ = false;
return ret;
}
// This function implements the sequence of events
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册