From dc34effd35ae8aabf544919252d82796095cb507 Mon Sep 17 00:00:00 2001 From: yuyang18 Date: Sat, 14 Jul 2018 14:02:54 +0800 Subject: [PATCH] Extract buffered reader --- paddle/fluid/operators/reader/CMakeLists.txt | 3 +- .../fluid/operators/reader/buffered_reader.cc | 73 +++++++++++++++++++ .../fluid/operators/reader/buffered_reader.h | 55 ++++++++++++++ .../reader/create_double_buffer_reader_op.cc | 73 +------------------ 4 files changed, 131 insertions(+), 73 deletions(-) create mode 100644 paddle/fluid/operators/reader/buffered_reader.cc create mode 100644 paddle/fluid/operators/reader/buffered_reader.h diff --git a/paddle/fluid/operators/reader/CMakeLists.txt b/paddle/fluid/operators/reader/CMakeLists.txt index 9dbcc35e6..c6df2646c 100644 --- a/paddle/fluid/operators/reader/CMakeLists.txt +++ b/paddle/fluid/operators/reader/CMakeLists.txt @@ -15,12 +15,13 @@ function(reader_library TARGET_NAME) PARENT_SCOPE) endfunction() +cc_library(buffered_reader SRCS buffered_reader.cc DEPS reader simple_threadpool) reader_library(open_files_op SRCS open_files_op.cc) reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc) reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc) reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc) reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc) -reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc) +reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc DEPS buffered_reader) reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc) reader_library(create_custom_reader_op SRCS create_custom_reader_op.cc) reader_library(create_py_reader_op SRCS create_py_reader_op.cc) diff --git a/paddle/fluid/operators/reader/buffered_reader.cc b/paddle/fluid/operators/reader/buffered_reader.cc new file mode 100644 index 000000000..a020e0c68 --- /dev/null +++ b/paddle/fluid/operators/reader/buffered_reader.cc @@ -0,0 +1,73 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/reader/buffered_reader.h" +#include + +namespace paddle { +namespace operators { +namespace reader { +BufferedReader::~BufferedReader() { + reader_->Shutdown(); + buffer_.clear(); +} +BufferedReader::BufferedReader( + const std::shared_ptr &reader, + const platform::Place &place, size_t buffer_size) + : framework::DecoratedReader(reader), + thread_pool_(1), + place_(place), + buffer_size_(buffer_size) { + AppendFutureToBatchSize(); +} +void BufferedReader::AppendFutureToBatchSize() { + while (buffer_.size() < buffer_size_) { + AppendFuture(); + } +} +void BufferedReader::AppendFuture() { + buffer_.emplace_back(thread_pool_.enqueue([this] { + TensorVec cpu_buffer; + reader_->ReadNext(&cpu_buffer); + if (platform::is_gpu_place(place_)) { + TensorVec gpu_buffer; + + for (size_t i = 0; i < cpu_buffer.size(); ++i) { + gpu_buffer.emplace_back(); + framework::TensorCopySync(cpu_buffer[i], place_, &gpu_buffer.back()); + } + + cpu_buffer = gpu_buffer; + } + return cpu_buffer; + })); +} +void BufferedReader::ShutdownImpl() { + reader_->Shutdown(); + buffer_.clear(); +} +void BufferedReader::StartImpl() { + reader_->Start(); + AppendFutureToBatchSize(); +} +void BufferedReader::ReadNextImpl(std::vector *out) { + PADDLE_ENFORCE_EQ(buffer_.size(), buffer_size_); + *out = buffer_.front().get(); + buffer_.pop_front(); + AppendFuture(); +} + +} // namespace reader +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/reader/buffered_reader.h b/paddle/fluid/operators/reader/buffered_reader.h new file mode 100644 index 000000000..eb702a232 --- /dev/null +++ b/paddle/fluid/operators/reader/buffered_reader.h @@ -0,0 +1,55 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "ThreadPool.h" +#include "paddle/fluid/framework/reader.h" + +namespace paddle { +namespace operators { +namespace reader { + +class BufferedReader : public framework::DecoratedReader { + using TensorVec = std::vector; + using VecFuture = std::future; + + public: + BufferedReader(const std::shared_ptr& reader, + const platform::Place& place, size_t buffer_size); + + ~BufferedReader() override; + + private: + void AppendFutureToBatchSize(); + + void AppendFuture(); + + protected: + void ShutdownImpl() override; + void StartImpl() override; + void ReadNextImpl(std::vector* out) override; + + private: + ThreadPool thread_pool_; + platform::Place place_; + const size_t buffer_size_; + std::list buffer_; +}; + +} // namespace reader +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc index b452a7815..ed719f91d 100644 --- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc +++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc @@ -12,83 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include // NOLINT - -#include "ThreadPool.h" -#include "paddle/fluid/operators/reader/blocking_queue.h" +#include "paddle/fluid/operators/reader/buffered_reader.h" #include "paddle/fluid/operators/reader/reader_op_registry.h" namespace paddle { namespace operators { namespace reader { -class BufferedReader : public framework::DecoratedReader { - using TensorVec = std::vector; - using VecFuture = std::future; - - public: - BufferedReader(const std::shared_ptr& reader, - const platform::Place& place, size_t buffer_size) - : framework::DecoratedReader(reader), - thread_pool_(1), - place_(place), - buffer_size_(buffer_size) { - AppendFutureToBatchSize(); - } - - ~BufferedReader() override { - reader_->Shutdown(); - buffer_.clear(); - } - - private: - void AppendFutureToBatchSize() { - while (buffer_.size() < buffer_size_) { - AppendFuture(); - } - } - - void AppendFuture() { - buffer_.emplace_back(thread_pool_.enqueue([this] { - TensorVec cpu_buffer; - reader_->ReadNext(&cpu_buffer); - if (platform::is_gpu_place(place_)) { - TensorVec gpu_buffer; - - for (size_t i = 0; i < cpu_buffer.size(); ++i) { - gpu_buffer.emplace_back(); - framework::TensorCopySync(cpu_buffer[i], place_, &gpu_buffer.back()); - } - - cpu_buffer = gpu_buffer; - } - return cpu_buffer; - })); - } - - protected: - void ShutdownImpl() override { - reader_->Shutdown(); - buffer_.clear(); - } - void StartImpl() override { - reader_->Start(); - AppendFutureToBatchSize(); - } - void ReadNextImpl(std::vector* out) override { - std::cerr << "Read" << std::endl; - PADDLE_ENFORCE_EQ(buffer_.size(), buffer_size_); - *out = buffer_.front().get(); - buffer_.pop_front(); - AppendFuture(); - } - - private: - ThreadPool thread_pool_; - platform::Place place_; - const size_t buffer_size_; - std::list buffer_; -}; - class CreateDoubleBufferReaderOp : public framework::OperatorBase { public: using framework::OperatorBase::OperatorBase; -- GitLab