提交 35e1e0d5 编写于 作者: F fengjiayi

uses channel to replace the traditional buffer

上级 b3a11fdf
...@@ -12,42 +12,35 @@ ...@@ -12,42 +12,35 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <condition_variable>
#include <mutex>
#include <thread> #include <thread>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h" #include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace reader { namespace reader {
static constexpr size_t kDoubleBufferSize = 3; static constexpr size_t kDoubleBufferSize = 2;
class DoubleBufferReader : public framework::DecoratedReader { class DoubleBufferReader : public framework::DecoratedReader {
public: public:
explicit DoubleBufferReader(ReaderBase* reader) explicit DoubleBufferReader(ReaderBase* reader)
: DecoratedReader(reader), : DecoratedReader(reader),
buffer_(kDoubleBufferSize), buffer_(framework::MakeChannel<std::vector<framework::LoDTensor>>(
write_pos_(0), kDoubleBufferSize)) {
read_pos_(0) { std::thread prefetch(&DoubleBufferReader::PrefetchThreadFunc, this);
std::thread prefetch(
std::bind(&DoubleBufferReader::PrefetchThreadFunc, this));
prefetch.detach(); prefetch.detach();
} }
void ReadNext(std::vector<framework::LoDTensor>* out) override; void ReadNext(std::vector<framework::LoDTensor>* out) override;
bool HasNext() const override; void ReInit() override;
~DoubleBufferReader() { buffer_->Close(); }
private: private:
void PrefetchThreadFunc(); void PrefetchThreadFunc();
std::vector<std::vector<framework::LoDTensor>> buffer_; framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
size_t write_pos_;
size_t read_pos_;
std::mutex mtx_;
std::condition_variable buffer_not_full_;
std::condition_variable buffer_not_empty_;
}; };
class CreateDoubleBufferReaderOp : public framework::OperatorBase { class CreateDoubleBufferReaderOp : public framework::OperatorBase {
...@@ -80,44 +73,36 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase { ...@@ -80,44 +73,36 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
}; };
void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) { void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
std::unique_lock<std::mutex> lck(mtx_);
while (write_pos_ == read_pos_) {
buffer_not_empty_.wait(lck);
}
out->clear(); out->clear();
out->reserve(buffer_[read_pos_].size()); buffer_->Receive(out);
// TODO(fengjiayi): This copy shall be reduced.
for (size_t i = 0; i < buffer_[read_pos_].size(); ++i) {
framework::LoDTensor dst;
TensorCopy(buffer_[read_pos_][i], platform::CPUPlace(), &dst);
dst.set_lod(buffer_[read_pos_][i].lod());
out->push_back(dst);
}
++read_pos_;
if (read_pos_ >= kDoubleBufferSize) {
read_pos_ = 0;
}
buffer_not_full_.notify_all();
} }
bool DoubleBufferReader::HasNext() const { void DoubleBufferReader::ReInit() {
return reader_->HasNext() || !buffer_.empty(); reader_->ReInit();
buffer_->Close();
// The existing prefetch thread will terminate for the buffer_ is closed.
buffer_ = framework::MakeChannel<std::vector<framework::LoDTensor>>(
kDoubleBufferSize);
std::thread prefetch(&DoubleBufferReader::PrefetchThreadFunc, this);
prefetch.detach();
} }
void DoubleBufferReader::PrefetchThreadFunc() { void DoubleBufferReader::PrefetchThreadFunc() {
while (reader_->HasNext()) { VLOG(5) << "A new prefetch thread starts.";
std::unique_lock<std::mutex> lck(mtx_); while (true) {
while (((write_pos_ + 1) % kDoubleBufferSize) == read_pos_) { std::vector<framework::LoDTensor> batch;
buffer_not_full_.wait(lck); reader_->ReadNext(&batch);
if (batch.empty()) {
// EOF
buffer_->Close();
VLOG(5) << "Reached the end of the file. The prefetch thread terminates.";
break;
} }
reader_->ReadNext(&buffer_[write_pos_]); if (!buffer_->Send(&batch)) {
++write_pos_; VLOG(5) << "WARNING: The double buffer channel has been closed. The "
if (write_pos_ >= kDoubleBufferSize) { "prefetch thread terminates.";
write_pos_ = 0; break;
} }
buffer_not_empty_.notify_all();
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册