提交 3fcd16ed 编写于 作者: F fengjiayi

init double buffer

上级 86263b2f
...@@ -112,5 +112,46 @@ void BatchReader::ReadNext(std::vector<LoDTensor>* out) { ...@@ -112,5 +112,46 @@ void BatchReader::ReadNext(std::vector<LoDTensor>* out) {
out->push_back(out_tensor); out->push_back(out_tensor);
} }
} }
void DoubleBufferReader::ReadNext(std::vector<LoDTensor>* out) {
std::unique_lock<std::mutex> lck(mtx_);
while (write_pos_ == read_pos_) {
buffer_not_empty_.wait(lck);
}
out->clear();
out->resize(buffer_[read_pos_].size());
// TODO(fengjiayi): This copy shall be reduced.
for (size_t i = 0; i < buffer_[read_pos_].size(); ++i) {
TensorCopy(buffer_[read_pos_][i], platform::CPUPlace(), &out[i]);
out[i].set_lod(buffer_[read_pos_][i].lod());
}
++read_pos_;
if (read_pos_ >= kDoubleBufferSize) {
read_pos_ = 0;
}
buffer_not_full_.notify_all();
}
bool DoubleBufferReader::HasNext() {
return reader_->HasNext() || !buffer_.empty();
}
void DoubleBufferReader::ProducerThreadFunc() {
while (reader_->HasNext()) {
std::unique_lock<std::mutex> lck(mtx);
while (((write_pos_ + 1) % kDoubleBufferSize) == read_pos_) {
buffer_not_full_.wait(lck);
}
reader_->ReadNext(&buffer_[write_pos_]);
++write_pos_;
if (write_pos_ >= kDoubleBufferSize) {
write_pos_ = 0;
}
buffer_not_empty_.notify_all();
}
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -16,10 +16,13 @@ ...@@ -16,10 +16,13 @@
#include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/threadpool.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
static constexpr size_t kDoubleBufferSize = 3;
class ReaderBase { class ReaderBase {
public: public:
explicit ReaderBase(const std::vector<DDim>& shapes) : shapes_(shapes) { explicit ReaderBase(const std::vector<DDim>& shapes) : shapes_(shapes) {
...@@ -135,6 +138,28 @@ class BatchReader : public DecoratedReader { ...@@ -135,6 +138,28 @@ class BatchReader : public DecoratedReader {
std::vector<std::vector<LoDTensor>> buffer_; std::vector<std::vector<LoDTensor>> buffer_;
}; };
class DoubleBufferReader : public DecoratedReader {
public:
DoubleBufferReader(ReaderBase* reader)
: DecoratedReader(reader), buffer_(kDoubleBufferSize) {
framework::Async(std::bind(&DoubleBufferReader::ProducerThreadFunc, this));
}
void ReadNext(std::vector<LoDTensor>* out) override;
bool HasNext() const override;
private:
void ProducerThreadFunc();
std::vector<std::vector<LoDTensor>> buffer_;
size_t write_pos_;
size_t read_pos_;
std::mutex mtx_;
std::condition_variable buffer_not_full_;
std::condition_variable buffer_not_empty_;
};
// The ReaderHolder is used as readers' unified wrapper, // The ReaderHolder is used as readers' unified wrapper,
// making it easier to access different type readers in Variables. // making it easier to access different type readers in Variables.
class ReaderHolder { class ReaderHolder {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册