Commit 304b6b71 authored by: F fengjiayi

Follow comments

Parent 4cb63d84
@@ -25,6 +25,11 @@ namespace reader {
 template <typename T>
 class BlockingQueue {
+  // BlockingQueue is for buffered reading and is supposed to be used only
+  // within the reader package. Ideally we could and should have used
+  // framework::Channel, but it currently has a deadlock bug. BlockingQueue
+  // is a workaround: a simplified version of framework::Channel that doesn't
+  // support GPU and only implements a buffered blocking queue.
  public:
   explicit BlockingQueue(size_t capacity)
       : capacity_(capacity), closed_(false) {
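As a reading aid (not part of the commit): the interface BlockingQueue ends up with, assembled from the members and methods visible in these hunks. The member types and the Close() method are inferences from usage, not shown in the diff.

    // Sketch of the post-commit interface; member types are inferred from
    // usage (push_back/emplace_back/front suggest a std::deque) and may
    // differ from the real header.
    #include <condition_variable>
    #include <deque>
    #include <mutex>

    template <typename T>
    class BlockingQueue {
     public:
      explicit BlockingQueue(size_t capacity);
      bool Send(const T& elem);  // blocks while full; returns false once closed
      bool Send(T&& elem);
      bool Receive(T* elem);     // blocks while empty; false when closed and drained
      void Close();              // assumed: closing happens outside the shown hunks
      bool IsClosed();
      size_t Cap();

     private:
      size_t capacity_;
      bool closed_;
      std::deque<T> queue_;
      std::mutex mutex_;
      std::condition_variable send_cv_;
      std::condition_variable receive_cv_;
    };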
@@ -37,26 +42,28 @@ class BlockingQueue {
     std::unique_lock<std::mutex> lock(mutex_);
     send_cv_.wait(lock, [&] { return queue_.size() < capacity_ || closed_; });
     if (closed_) {
+      VLOG(5)
+          << "WARNING: Sending an element to a closed reader::BlockingQueue.";
       return false;
-    } else {
-      PADDLE_ENFORCE_LT(queue_.size(), capacity_);
-      queue_.push_back(elem);
-      receive_cv_.notify_one();
-      return true;
     }
+    PADDLE_ENFORCE_LT(queue_.size(), capacity_);
+    queue_.push_back(elem);
+    receive_cv_.notify_one();
+    return true;
   }

   bool Send(T&& elem) {
     std::unique_lock<std::mutex> lock(mutex_);
     send_cv_.wait(lock, [&] { return queue_.size() < capacity_ || closed_; });
     if (closed_) {
+      VLOG(5)
+          << "WARNING: Sending an element to a closed reader::BlockingQueue.";
       return false;
-    } else {
-      PADDLE_ENFORCE_LT(queue_.size(), capacity_);
-      queue_.emplace_back(std::move(elem));
-      receive_cv_.notify_one();
-      return true;
     }
+    PADDLE_ENFORCE_LT(queue_.size(), capacity_);
+    queue_.emplace_back(std::move(elem));
+    receive_cv_.notify_one();
+    return true;
   }

   bool Receive(T* elem) {
@@ -86,16 +93,6 @@ class BlockingQueue {
     return closed_;
   }

-  bool CanSend() {
-    std::lock_guard<std::mutex> lock(mutex_);
-    return !closed_ && queue_.size() < capacity_;
-  }
-
-  bool CanReceive() {
-    std::lock_guard<std::mutex> lock(mutex_);
-    return !queue_.empty();
-  }
-
   size_t Cap() {
     std::lock_guard<std::mutex> lock(mutex_);
     return capacity_;
...
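CanSend() and CanReceive() can go because Receive() itself now blocks on a condition variable and reports closure through its return value, mirroring the Send() overloads above. The body of Receive() sits outside the shown hunks; a sketch consistent with the visible code (same member names, deque-backed queue_) would be:

    // Sketch only: the real Receive() body is not part of the visible hunks.
    bool Receive(T* elem) {
      std::unique_lock<std::mutex> lock(mutex_);
      // Sleep until an element arrives or the queue is closed; this replaces
      // the callers' busy-wait on CanReceive().
      receive_cv_.wait(lock, [&] { return !queue_.empty() || closed_; });
      if (!queue_.empty()) {
        *elem = queue_.front();
        queue_.pop_front();
        send_cv_.notify_one();  // wake a sender blocked on a full queue
        return true;
      }
      return false;  // closed and fully drained
    }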
@@ -55,8 +55,6 @@ class DoubleBufferReader : public framework::DecoratedReader {
   ~DoubleBufferReader() { EndPrefetcher(); }

  private:
-  bool HasNext() const;
-
   void StartPrefetcher() {
     channel_ = new reader::BlockingQueue<size_t>(kChannelSize);
     prefetcher_ = std::thread([this] { PrefetchThreadFunc(); });
@@ -139,17 +137,16 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
 };

 void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
-  out->clear();
-  if (HasNext()) {
-    size_t cached_tensor_id;
-    channel_->Receive(&cached_tensor_id);
+  size_t cached_tensor_id;
+  if (channel_->Receive(&cached_tensor_id)) {
     if (platform::is_gpu_place(place_)) {
       *out = gpu_tensor_cache_[cached_tensor_id];
-      ctxs_[cached_tensor_id]->Wait();
     } else {
       // CPU place
       *out = cpu_tensor_cache_[cached_tensor_id];
     }
+  } else {
+    out->clear();
   }
 }
@@ -159,12 +156,6 @@ void DoubleBufferReader::ReInit() {
   StartPrefetcher();
 }

-bool DoubleBufferReader::HasNext() const {
-  while (!channel_->IsClosed() && !channel_->CanReceive()) {
-  }
-  return channel_->CanReceive();
-}
-
 void DoubleBufferReader::PrefetchThreadFunc() {
   VLOG(5) << "A new prefetch thread starts.";
   size_t cached_tensor_id = 0;
...
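With HasNext() deleted, end-of-data is now signalled by ReadNext() clearing *out, in line with the reader convention used elsewhere in the codebase. A caller loop under that convention (a usage sketch; `reader` is a hypothetical DoubleBufferReader pointer, not code from this commit):

    std::vector<framework::LoDTensor> batch;
    reader->ReadNext(&batch);
    while (!batch.empty()) {  // an empty output marks exhaustion
      // ... consume batch ...
      reader->ReadNext(&batch);
    }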
@@ -37,7 +37,6 @@ class MultiFileReader : public framework::ReaderBase {
   ~MultiFileReader() { EndScheduler(); }

  private:
-  bool HasNext();
   void StartNewScheduler();
   void EndScheduler();
   void ScheduleThreadFunc();
@@ -54,9 +53,8 @@ class MultiFileReader : public framework::ReaderBase {
 };

 void MultiFileReader::ReadNext(std::vector<framework::LoDTensor>* out) {
-  out->clear();
-  if (HasNext()) {
-    buffer_->Receive(out);
+  if (!buffer_->Receive(out)) {
+    out->clear();
   }
 }
@@ -65,12 +63,6 @@ void MultiFileReader::ReInit() {
   StartNewScheduler();
 }

-bool MultiFileReader::HasNext() {
-  while (!buffer_->IsClosed() && !buffer_->CanReceive()) {
-  }
-  return buffer_->CanReceive();
-}
-
 void MultiFileReader::StartNewScheduler() {
   size_t thread_num = prefetchers_.size();
   waiting_file_idx_ = new reader::BlockingQueue<size_t>(file_names_.size());
...
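The same pattern applies on the producer side: prefetch threads no longer need to poll CanSend() and can push until Send() returns false, which means the buffer was closed (e.g. by ReInit() or EndScheduler()). A hypothetical prefetch loop illustrating the contract; the real thread function bodies are outside the shown hunks, and the buffer element type is inferred from `buffer_->Receive(out)`:

    // Illustrative only: the function name and element type are assumptions.
    void PrefetchLoop(
        framework::ReaderBase* file_reader,
        reader::BlockingQueue<std::vector<framework::LoDTensor>>* buffer) {
      std::vector<framework::LoDTensor> data;
      while (true) {
        file_reader->ReadNext(&data);
        if (data.empty()) break;  // the underlying file reader is exhausted
        if (!buffer->Send(std::move(data))) break;  // buffer closed; stop early
      }
    }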