未验证 提交 e576345f 编写于 作者: Y yuyang18

Try to speed up buffered reader

上级 61b3a597
...@@ -18,10 +18,7 @@ ...@@ -18,10 +18,7 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace reader { namespace reader {
BufferedReader::~BufferedReader() { BufferedReader::~BufferedReader() { reader_->Shutdown(); }
reader_->Shutdown();
buffer_.clear();
}
BufferedReader::BufferedReader( BufferedReader::BufferedReader(
const std::shared_ptr<framework::ReaderBase> &reader, const std::shared_ptr<framework::ReaderBase> &reader,
const platform::Place &place, size_t buffer_size) const platform::Place &place, size_t buffer_size)
...@@ -29,43 +26,60 @@ BufferedReader::BufferedReader( ...@@ -29,43 +26,60 @@ BufferedReader::BufferedReader(
thread_pool_(1), thread_pool_(1),
place_(place), place_(place),
buffer_size_(buffer_size) { buffer_size_(buffer_size) {
cpu_buffer_.resize(buffer_size);
gpu_buffer_.resize(buffer_size);
AppendFutureToBatchSize(); AppendFutureToBatchSize();
} }
void BufferedReader::AppendFutureToBatchSize() { void BufferedReader::AppendFutureToBatchSize() {
while (buffer_.size() < buffer_size_) { PADDLE_ENFORCE_EQ(position_.size(), 0U);
AppendFuture(); for (size_t i = 0; i < buffer_size_; ++i) {
AppendFuture(i);
} }
} }
void BufferedReader::AppendFuture() { void BufferedReader::AppendFuture(size_t i) {
buffer_.emplace_back(thread_pool_.enqueue([this] { position_.emplace(thread_pool_.enqueue([this, i]() -> size_t {
TensorVec cpu_buffer; TensorVec &cpu = cpu_buffer_[i];
reader_->ReadNext(&cpu_buffer); reader_->ReadNext(&cpu);
if (platform::is_gpu_place(place_)) {
TensorVec gpu_buffer;
for (size_t i = 0; i < cpu_buffer.size(); ++i) { if (cpu.empty()) {
gpu_buffer.emplace_back(); return -1UL;
framework::TensorCopySync(cpu_buffer[i], place_, &gpu_buffer.back());
} }
cpu_buffer = gpu_buffer; if (platform::is_gpu_place(place_)) {
TensorVec &gpu = gpu_buffer_[i];
gpu.resize(cpu.size());
for (size_t i = 0; i < cpu.size(); ++i) {
framework::TensorCopySync(cpu[i], place_, &gpu[i]);
}
} }
return cpu_buffer; return i;
})); }));
} }
void BufferedReader::ShutdownImpl() { void BufferedReader::ShutdownImpl() {
reader_->Shutdown(); reader_->Shutdown();
buffer_.clear(); while (!position_.empty()) {
position_.pop();
}
} }
void BufferedReader::StartImpl() { void BufferedReader::StartImpl() {
reader_->Start(); reader_->Start();
AppendFutureToBatchSize(); AppendFutureToBatchSize();
} }
void BufferedReader::ReadNextImpl(std::vector<framework::LoDTensor> *out) { void BufferedReader::ReadNextImpl(std::vector<framework::LoDTensor> *out) {
PADDLE_ENFORCE_EQ(buffer_.size(), buffer_size_); if (position_.empty()) {
*out = buffer_.front().get(); out->clear();
buffer_.pop_front(); return;
AppendFuture(); }
size_t i = position_.front().get();
position_.pop();
if (i == -1UL) {
ReadNextImpl(out);
return;
}
*out = platform::is_gpu_place(place_) ? gpu_buffer_[i] : cpu_buffer_[i];
AppendFuture(i);
} }
} // namespace reader } // namespace reader
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include <list> #include <list>
#include <queue>
#include <vector> #include <vector>
#include "ThreadPool.h" #include "ThreadPool.h"
#include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/reader.h"
...@@ -36,7 +37,7 @@ class BufferedReader : public framework::DecoratedReader { ...@@ -36,7 +37,7 @@ class BufferedReader : public framework::DecoratedReader {
private: private:
void AppendFutureToBatchSize(); void AppendFutureToBatchSize();
void AppendFuture(); void AppendFuture(size_t i);
protected: protected:
void ShutdownImpl() override; void ShutdownImpl() override;
...@@ -47,7 +48,10 @@ class BufferedReader : public framework::DecoratedReader { ...@@ -47,7 +48,10 @@ class BufferedReader : public framework::DecoratedReader {
ThreadPool thread_pool_; ThreadPool thread_pool_;
platform::Place place_; platform::Place place_;
const size_t buffer_size_; const size_t buffer_size_;
std::list<VecFuture> buffer_;
std::queue<std::future<size_t>> position_;
std::vector<TensorVec> cpu_buffer_;
std::vector<TensorVec> gpu_buffer_;
}; };
} // namespace reader } // namespace reader
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册