提交 b60da672 编写于 作者: C chengduo 提交者: Abhinav Arora

Refine buffer channel (#8098)

* refine buffer channel

*  refine Receive and Send

* follow comments
上级 022e5dee
...@@ -23,8 +23,8 @@ namespace framework { ...@@ -23,8 +23,8 @@ namespace framework {
template <typename T> template <typename T>
class Channel { class Channel {
public: public:
virtual void Send(T*) = 0; virtual bool Send(T*) = 0;
virtual void Receive(T*) = 0; virtual bool Receive(T*) = 0;
virtual size_t Cap() = 0; virtual size_t Cap() = 0;
virtual void Close() = 0; virtual void Close() = 0;
virtual ~Channel() {} virtual ~Channel() {}
......
...@@ -30,8 +30,8 @@ class Buffered : public paddle::framework::Channel<T> { ...@@ -30,8 +30,8 @@ class Buffered : public paddle::framework::Channel<T> {
friend void paddle::framework::CloseChannel<T>(Channel<T>*); friend void paddle::framework::CloseChannel<T>(Channel<T>*);
public: public:
virtual void Send(T*); virtual bool Send(T*);
virtual void Receive(T*); virtual bool Receive(T*);
virtual size_t Cap() { return cap_; } virtual size_t Cap() { return cap_; }
virtual void Close(); virtual void Close();
virtual ~Buffered(); virtual ~Buffered();
...@@ -48,33 +48,36 @@ class Buffered : public paddle::framework::Channel<T> { ...@@ -48,33 +48,36 @@ class Buffered : public paddle::framework::Channel<T> {
PADDLE_ENFORCE_GT(cap, 0); PADDLE_ENFORCE_GT(cap, 0);
} }
void NotifyAllSenders(std::unique_lock<std::mutex>*);
void NotifyAllParticipants(std::unique_lock<std::mutex>*); void NotifyAllParticipants(std::unique_lock<std::mutex>*);
}; };
template <typename T> template <typename T>
void Buffered<T>::Send(T* item) { bool Buffered<T>::Send(T* item) {
std::unique_lock<std::mutex> lock(mu_); std::unique_lock<std::mutex> lock(mu_);
full_cond_var_.wait(lock, full_cond_var_.wait(lock,
[this]() { return channel_.size() < cap_ || closed_; }); [this]() { return channel_.size() < cap_ || closed_; });
bool ret = false;
if (!closed_) { if (!closed_) {
channel_.push_back(std::move(*item)); channel_.push_back(std::move(*item));
lock.unlock(); lock.unlock();
empty_cond_var_.notify_one(); empty_cond_var_.notify_one();
ret = true;
} }
return ret;
} }
template <typename T> template <typename T>
void Buffered<T>::Receive(T* item) { bool Buffered<T>::Receive(T* item) {
std::unique_lock<std::mutex> lock(mu_); std::unique_lock<std::mutex> lock(mu_);
empty_cond_var_.wait(lock, [this]() { return !channel_.empty() || closed_; }); empty_cond_var_.wait(lock, [this]() { return !channel_.empty() || closed_; });
bool ret = false;
if (!closed_) { if (!closed_) {
*item = std::move(channel_.front()); *item = std::move(channel_.front());
channel_.pop_front(); channel_.pop_front();
NotifyAllSenders(&lock); full_cond_var_.notify_one();
} else { ret = true;
item = nullptr;
} }
return ret;
} }
template <typename T> template <typename T>
...@@ -92,12 +95,6 @@ Buffered<T>::~Buffered() { ...@@ -92,12 +95,6 @@ Buffered<T>::~Buffered() {
NotifyAllParticipants(&lock); NotifyAllParticipants(&lock);
} }
template <typename T>
void Buffered<T>::NotifyAllSenders(std::unique_lock<std::mutex>* lock) {
lock->unlock();
full_cond_var_.notify_all();
}
template <typename T> template <typename T>
void Buffered<T>::NotifyAllParticipants(std::unique_lock<std::mutex>* lock) { void Buffered<T>::NotifyAllParticipants(std::unique_lock<std::mutex>* lock) {
lock->unlock(); lock->unlock();
......
...@@ -29,8 +29,8 @@ class UnBuffered : public paddle::framework::Channel<T> { ...@@ -29,8 +29,8 @@ class UnBuffered : public paddle::framework::Channel<T> {
friend void paddle::framework::CloseChannel<T>(Channel<T>*); friend void paddle::framework::CloseChannel<T>(Channel<T>*);
public: public:
virtual void Send(T*); virtual bool Send(T*);
virtual void Receive(T*); virtual bool Receive(T*);
virtual size_t Cap() { return 0; } virtual size_t Cap() { return 0; }
virtual void Close(); virtual void Close();
virtual ~UnBuffered(); virtual ~UnBuffered();
...@@ -57,7 +57,7 @@ class UnBuffered : public paddle::framework::Channel<T> { ...@@ -57,7 +57,7 @@ class UnBuffered : public paddle::framework::Channel<T> {
// This function implements the concept of how data should // This function implements the concept of how data should
// be sent from a writer to a reader. // be sent from a writer to a reader.
template <typename T> template <typename T>
void UnBuffered<T>::Send(T* data) { bool UnBuffered<T>::Send(T* data) {
// Prevent other writers from entering // Prevent other writers from entering
std::unique_lock<std::recursive_mutex> writer_lock(mu_write_); std::unique_lock<std::recursive_mutex> writer_lock(mu_write_);
writer_found_ = true; writer_found_ = true;
...@@ -66,6 +66,7 @@ void UnBuffered<T>::Send(T* data) { ...@@ -66,6 +66,7 @@ void UnBuffered<T>::Send(T* data) {
cv_writer_.wait(cv_lock, cv_writer_.wait(cv_lock,
[this]() { return reader_found_ == true || closed_; }); [this]() { return reader_found_ == true || closed_; });
cv_reader_.notify_one(); cv_reader_.notify_one();
bool ret = false;
if (!closed_) { if (!closed_) {
std::unique_lock<std::mutex> channel_lock(mu_ch_); std::unique_lock<std::mutex> channel_lock(mu_ch_);
item = data; item = data;
...@@ -74,14 +75,16 @@ void UnBuffered<T>::Send(T* data) { ...@@ -74,14 +75,16 @@ void UnBuffered<T>::Send(T* data) {
channel_lock.lock(); channel_lock.lock();
cv_channel_.wait(channel_lock, cv_channel_.wait(channel_lock,
[this]() { return item == nullptr || closed_; }); [this]() { return item == nullptr || closed_; });
ret = true;
} }
writer_found_ = false; writer_found_ = false;
return ret;
} }
// This function implements the concept of how // This function implements the concept of how
// data that was sent by a writer is read from a reader. // data that was sent by a writer is read from a reader.
template <typename T> template <typename T>
void UnBuffered<T>::Receive(T* data) { bool UnBuffered<T>::Receive(T* data) {
// Prevent other readers from entering // Prevent other readers from entering
std::unique_lock<std::recursive_mutex> read_lock{mu_read_}; std::unique_lock<std::recursive_mutex> read_lock{mu_read_};
reader_found_ = true; reader_found_ = true;
...@@ -90,6 +93,7 @@ void UnBuffered<T>::Receive(T* data) { ...@@ -90,6 +93,7 @@ void UnBuffered<T>::Receive(T* data) {
cv_reader_.wait(cv_lock, cv_reader_.wait(cv_lock,
[this]() { return writer_found_ == true || closed_; }); [this]() { return writer_found_ == true || closed_; });
cv_writer_.notify_one(); cv_writer_.notify_one();
bool ret = false;
if (!closed_) { if (!closed_) {
std::unique_lock<std::mutex> lock_ch{mu_ch_}; std::unique_lock<std::mutex> lock_ch{mu_ch_};
// Reader should wait for the writer to first write its data // Reader should wait for the writer to first write its data
...@@ -98,10 +102,12 @@ void UnBuffered<T>::Receive(T* data) { ...@@ -98,10 +102,12 @@ void UnBuffered<T>::Receive(T* data) {
*data = std::move(*item); *data = std::move(*item);
item = nullptr; item = nullptr;
lock_ch.unlock(); lock_ch.unlock();
ret = true;
} }
cv_channel_.notify_one(); cv_channel_.notify_one();
} }
reader_found_ = false; reader_found_ = false;
return ret;
} }
// This function implements the sequence of events // This function implements the sequence of events
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册