未验证 提交 e9583166 编写于 作者: S seemingwang 提交者: GitHub

fix block_queue problem (#34461)

上级 92d8fed8
...@@ -61,38 +61,65 @@ using Variable = framework::Variable; ...@@ -61,38 +61,65 @@ using Variable = framework::Variable;
template <typename T> template <typename T>
class BlockingQueue { class BlockingQueue {
public: public:
explicit BlockingQueue(size_t capacity) : capacity_(capacity) { explicit BlockingQueue(size_t capacity) : capacity_(capacity) {}
PADDLE_ENFORCE_GT(capacity_, 0,
platform::errors::InvalidArgument(
"The capacity must be greater than 0."));
}
bool Push(const T &elem) { bool Push(const T &elem) {
{
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [&] { return queue_.size() < capacity_; }); WaitForWrite(lock);
queue_.push_back(elem); queue_.push_back(elem);
Notify();
return true;
}
bool WaitForWrite(std::unique_lock<std::mutex> &lock) { // NOLINT
while (FullUnlocked()) {
if (empty_waiters_ != 0) {
empty_cond_.notify_one();
}
full_waiters_++;
full_cond_.wait(lock);
full_waiters_--;
} }
cv_.notify_one();
return true; return true;
} }
bool WaitForRead(std::unique_lock<std::mutex> &lock) { // NOLINT
while (EmptyUnlocked()) {
if (full_waiters_ != 0) {
full_cond_.notify_one();
}
empty_waiters_++;
empty_cond_.wait(lock);
empty_waiters_--;
}
return true;
}
bool EmptyUnlocked() { return queue_.empty(); }
bool FullUnlocked() { return queue_.size() >= capacity_; }
void Notify() {
if (empty_waiters_ != 0 && (!EmptyUnlocked())) {
empty_cond_.notify_one();
}
if (full_waiters_ != 0 && (!FullUnlocked())) {
full_cond_.notify_one();
}
}
bool Push(T &&elem) { bool Push(T &&elem) {
{
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [&] { return queue_.size() < capacity_; }); WaitForWrite(lock);
queue_.emplace_back(std::move(elem)); queue_.emplace_back(std::move(elem));
}
cv_.notify_one(); Notify();
return true; return true;
} }
T Pop() { T Pop() {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [=] { return !queue_.empty(); }); WaitForRead(lock);
T rc(std::move(queue_.front())); T rc(std::move(queue_.front()));
queue_.pop_front(); queue_.pop_front();
cv_.notify_one(); Notify();
return rc; return rc;
} }
...@@ -107,11 +134,14 @@ class BlockingQueue { ...@@ -107,11 +134,14 @@ class BlockingQueue {
} }
private: private:
int empty_waiters_ = 0;
int full_waiters_ = 0;
std::condition_variable empty_cond_;
std::condition_variable full_cond_;
const size_t capacity_; const size_t capacity_;
std::deque<T> queue_; std::deque<T> queue_;
mutable std::mutex mutex_; mutable std::mutex mutex_;
std::condition_variable cv_;
}; };
template <typename T, int MajorType = Eigen::RowMajor, template <typename T, int MajorType = Eigen::RowMajor,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册