未验证 提交 832deee4 编写于 作者: F fengjiayi 提交者: GitHub

Merge pull request #9178 from JiayiFeng/fix_bugs_in_reader

Fix bugs in c++ readers
...@@ -48,20 +48,24 @@ class DoubleBufferReader : public framework::DecoratedReader { ...@@ -48,20 +48,24 @@ class DoubleBufferReader : public framework::DecoratedReader {
void start_thread() { void start_thread() {
buffer_ = framework::MakeChannel<Item>(kDoubleBufferSize); buffer_ = framework::MakeChannel<Item>(kDoubleBufferSize);
std::thread prefetch([this] { PrefetchThreadFunc(); }); prefetcher_ = std::thread([this] { PrefetchThreadFunc(); });
prefetch.detach();
} }
void ReadNext(std::vector<framework::LoDTensor>* out) override; void ReadNext(std::vector<framework::LoDTensor>* out) override;
void ReInit() override; void ReInit() override;
~DoubleBufferReader() { buffer_->Close(); } ~DoubleBufferReader() {
buffer_->Close();
prefetcher_.join();
delete buffer_;
}
bool HasNext() const override; bool HasNext() const override;
private: private:
void PrefetchThreadFunc(); void PrefetchThreadFunc();
std::thread prefetcher_;
framework::Channel<Item>* buffer_; framework::Channel<Item>* buffer_;
platform::Place place_; platform::Place place_;
std::vector<std::unique_ptr<platform::DeviceContext>> ctxs_; std::vector<std::unique_ptr<platform::DeviceContext>> ctxs_;
...@@ -134,6 +138,8 @@ void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) { ...@@ -134,6 +138,8 @@ void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
void DoubleBufferReader::ReInit() { void DoubleBufferReader::ReInit() {
reader_->ReInit(); reader_->ReInit();
buffer_->Close(); buffer_->Close();
prefetcher_.join();
delete buffer_;
start_thread(); start_thread();
} }
...@@ -159,11 +165,12 @@ void DoubleBufferReader::PrefetchThreadFunc() { ...@@ -159,11 +165,12 @@ void DoubleBufferReader::PrefetchThreadFunc() {
if (!buffer_->Send(&batch)) { if (!buffer_->Send(&batch)) {
VLOG(5) << "WARNING: The double buffer channel has been closed. The " VLOG(5) << "WARNING: The double buffer channel has been closed. The "
"prefetch thread terminates."; "prefetch thread will terminate.";
break; break;
} }
} }
buffer_->Close(); buffer_->Close();
VLOG(5) << "Prefetch thread terminates.";
} }
bool DoubleBufferReader::HasNext() const { bool DoubleBufferReader::HasNext() const {
......
...@@ -34,6 +34,9 @@ class ShuffleReader : public framework::DecoratedReader { ...@@ -34,6 +34,9 @@ class ShuffleReader : public framework::DecoratedReader {
} }
void ReadNext(std::vector<framework::LoDTensor>* out) override { void ReadNext(std::vector<framework::LoDTensor>* out) override {
if (!HasNext()) {
PADDLE_THROW("There is no next data!");
}
if (iteration_pos_ >= buffer_.size()) { if (iteration_pos_ >= buffer_.size()) {
VLOG(10) << "Resetting shuffle buffer"; VLOG(10) << "Resetting shuffle buffer";
ReadIntoBuffers(); ReadIntoBuffers();
...@@ -50,7 +53,6 @@ class ShuffleReader : public framework::DecoratedReader { ...@@ -50,7 +53,6 @@ class ShuffleReader : public framework::DecoratedReader {
buffer_.clear(); buffer_.clear();
buffer_.reserve(buffer_size_); buffer_.reserve(buffer_size_);
iteration_pos_ = 0; iteration_pos_ = 0;
PADDLE_ENFORCE(reader_->HasNext());
for (size_t i = 0; i < buffer_size_; ++i) { for (size_t i = 0; i < buffer_size_; ++i) {
if (!reader_->HasNext()) { if (!reader_->HasNext()) {
break; break;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册