diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 397c9f739452e5130dad28a763b92cf76720ec61..ec252929d5584c211cea7fa52004ecdfdf586a85 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -27,6 +27,7 @@ cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory) nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor) cc_library(reader SRCS reader.cc DEPS lod_tensor ddim) +cc_test(reader_test SRCS reader_test.cc DEPS reader) cc_test(variable_test SRCS variable_test.cc) diff --git a/paddle/fluid/framework/reader.cc b/paddle/fluid/framework/reader.cc index 0b36f1116d15004b355e854e101abb9ad3297836..5897d320a8b7e5af541098cadff8e78f8324949c 100644 --- a/paddle/fluid/framework/reader.cc +++ b/paddle/fluid/framework/reader.cc @@ -13,29 +13,61 @@ // limitations under the License. #include "paddle/fluid/framework/reader.h" +#include namespace paddle { namespace framework { -ReaderBase::~ReaderBase() {} -FileReader::FileReader(const std::vector &dims) : dims_(dims) {} - -void FileReader::ReadNext(std::vector *out) { +void ReaderBase::ReadNext(std::vector *out) { + std::lock_guard lock(mu_); + PADDLE_ENFORCE_EQ(status_, ReaderStatus::kRunning); ReadNextImpl(out); - if (out->empty()) { - return; - } +} - PADDLE_ENFORCE_EQ(out->size(), dims_.size()); - for (size_t i = 0; i < dims_.size(); ++i) { - auto &actual = (*out)[i].dims(); - auto &expect = dims_[i]; +void ReaderBase::InsertDecoratedReader( + const std::shared_ptr &decorated_reader) { + std::lock_guard guard(mu_); + decorated_readers_.emplace_back(decorated_reader); +} - PADDLE_ENFORCE_EQ(actual.size(), expect.size()); - for (int j = 0; j < actual.size(); ++j) { - // PADDLE_ENFORCE(actual[i] == expect[i] || expect[i] == -1); +std::unordered_set ReaderBase::GetEndPoints() { + std::unordered_set result; + std::deque queue; + queue.emplace_back(this); + while (!queue.empty()) { // BFS search + auto *front = queue.front(); + queue.pop_front(); + if (front->decorated_readers_.empty()) { + result.emplace(front); + } else { + for (auto &reader : front->decorated_readers_) { + if (auto *reader_ptr = reader.lock().get()) { + queue.emplace_back(reader_ptr); + } + } } } + + return result; } + +void ReaderBase::Shutdown() { + std::lock_guard lock(mu_); + if (status_ != ReaderStatus::kStopped) { + ShutdownImpl(); + status_ = ReaderStatus::kStopped; + } +} + +void ReaderBase::Start() { + std::lock_guard lock(mu_); + if (status_ != ReaderStatus::kRunning) { + StartImpl(); + status_ = ReaderStatus::kRunning; + } +} + +ReaderBase::~ReaderBase() { Shutdown(); } + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/reader.h b/paddle/fluid/framework/reader.h index 64d4ceab624312ed366d7e835072899f1f033a88..6c4432cb7a70853e19460b1980d621c02caed970 100644 --- a/paddle/fluid/framework/reader.h +++ b/paddle/fluid/framework/reader.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include "paddle/fluid/framework/ddim.h" @@ -24,61 +25,116 @@ namespace paddle { namespace framework { +enum ReaderStatus { kRunning, kStopped }; + class ReaderBase { public: - virtual void ReadNext(std::vector* out) = 0; + void ReadNext(std::vector* out); + + void Shutdown(); - virtual void ReInit() = 0; + void Start(); + + // Return the readers which are the end of decorating chain. Basically + // they are readers just before read op. + std::unordered_set GetEndPoints(); virtual ~ReaderBase(); + + protected: + virtual void ReadNextImpl(std::vector* out) = 0; + + virtual void ShutdownImpl() {} + + virtual void StartImpl() {} + + ReaderStatus status_{kRunning}; + + mutable std::mutex mu_; + + private: + friend class DecoratedReader; + // These methods can be only invoked inside DecoratedReader to record the + // decorating chain. + void InsertDecoratedReader( + const std::shared_ptr& decorated_reader); + // A set of which readers that decorated this reader. + std::vector> decorated_readers_; }; -class DecoratedReader : public ReaderBase { +class DecoratedReader : public ReaderBase, + public std::enable_shared_from_this { public: explicit DecoratedReader(const std::shared_ptr& reader) : ReaderBase(), reader_(reader) { PADDLE_ENFORCE_NOT_NULL(reader_); } - void ReInit() override { reader_->ReInit(); } + void RegisterDecorateChain() { + reader_->InsertDecoratedReader(shared_from_this()); + } protected: - std::shared_ptr reader_; -}; - -class FileReader : public ReaderBase { - public: - explicit FileReader(const std::vector& dims); - - void ReadNext(std::vector* out) override; + void ShutdownImpl() override { reader_->Shutdown(); } - protected: - virtual void ReadNextImpl(std::vector* out) = 0; + void StartImpl() override { reader_->Start(); } - private: - std::vector dims_; + std::shared_ptr reader_; }; +// FileReader is just a conceptual class. +class FileReader : public ReaderBase {}; + // The ReaderHolder is used as reader' unified wrapper, // making it easier to access different type reader in Variables. class ReaderHolder { public: - void Reset(ReaderBase* reader) { reader_.reset(reader); } + template + void Reset(const std::shared_ptr& reader) { + auto reader_base = std::dynamic_pointer_cast(reader); + PADDLE_ENFORCE_NOT_NULL(reader_base); + reader_ = reader_base; + } - std::shared_ptr Get() const { return reader_; } + const std::shared_ptr& Get() const { return reader_; } void ReadNext(std::vector* out) { PADDLE_ENFORCE_NOT_NULL(reader_); reader_->ReadNext(out); } - void ReInit() { + + void ResetAll() { + auto end_readers = reader_->GetEndPoints(); + for (auto* reader : end_readers) { + reader->Shutdown(); + } + for (auto* reader : end_readers) { + reader->Start(); + } + } + + void Shutdown() { + PADDLE_ENFORCE_NOT_NULL(reader_); + reader_->Shutdown(); + } + + void Start() { PADDLE_ENFORCE_NOT_NULL(reader_); - reader_->ReInit(); + reader_->Start(); } + operator const std::shared_ptr&() const { return this->reader_; } + private: std::shared_ptr reader_; }; +template +inline std::shared_ptr MakeDecoratedReader(ARGS&&... args) { + std::shared_ptr reader(new T(std::forward(args)...)); + reader->RegisterDecorateChain(); + return reader; +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/reader_test.cc b/paddle/fluid/framework/reader_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..f0d07cb7c1367576084b9494e7758103bb45d1e5 --- /dev/null +++ b/paddle/fluid/framework/reader_test.cc @@ -0,0 +1,52 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/reader.h" +#include +#include "gtest/gtest.h" + +class StubDecoratedReader : public paddle::framework::DecoratedReader { + public: + explicit StubDecoratedReader(const std::shared_ptr &reader) + : DecoratedReader(reader) {} + + void ReadNextImpl(std::vector *out) override {} +}; + +class StubRootReader : public paddle::framework::ReaderBase { + public: + void ReadNextImpl(std::vector *out) override {} +}; + +TEST(READER, decorate_chain) { + auto root = std::make_shared(); + auto end_point1 = + paddle::framework::MakeDecoratedReader(root); + auto end_point2 = + paddle::framework::MakeDecoratedReader(root); + + { + auto endpoints = root->GetEndPoints(); + ASSERT_EQ(endpoints.size(), 2U); + ASSERT_NE(endpoints.count(end_point1.get()), 0); + ASSERT_NE(endpoints.count(end_point2.get()), 0); + } + + { + auto end_point3 = + paddle::framework::MakeDecoratedReader(root); + ASSERT_EQ(root->GetEndPoints().size(), 3U); + } + { ASSERT_EQ(root->GetEndPoints().size(), 2U); } +} diff --git a/paddle/fluid/operators/reader/CMakeLists.txt b/paddle/fluid/operators/reader/CMakeLists.txt index a39c8a00538875e4e3284898230a6cb0693b7a12..9dbcc35e6f5bb01c159980a49dd4b4c9d37d2aab 100644 --- a/paddle/fluid/operators/reader/CMakeLists.txt +++ b/paddle/fluid/operators/reader/CMakeLists.txt @@ -22,7 +22,6 @@ reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc) reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc) reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc) reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc) -reader_library(create_threaded_reader_op SRCS create_threaded_reader_op.cc) reader_library(create_custom_reader_op SRCS create_custom_reader_op.cc) reader_library(create_py_reader_op SRCS create_py_reader_op.cc) diff --git a/paddle/fluid/operators/reader/create_batch_reader_op.cc b/paddle/fluid/operators/reader/create_batch_reader_op.cc index ecbae3894d551186f53625a6cc9cfdb36adc8d2d..1dbafd23e92732bdaf0d263a01e267227786d839 100644 --- a/paddle/fluid/operators/reader/create_batch_reader_op.cc +++ b/paddle/fluid/operators/reader/create_batch_reader_op.cc @@ -20,15 +20,19 @@ namespace reader { class BatchReader : public framework::DecoratedReader { public: - BatchReader(const std::shared_ptr& reader, int batch_size) - : DecoratedReader(reader), batch_size_(batch_size) { + BatchReader(const std::shared_ptr& reader, int batch_size, + bool discard_leftover) + : DecoratedReader(reader), + batch_size_(batch_size), + discard_leftover_(discard_leftover) { buffer_.reserve(batch_size_); } - void ReadNext(std::vector* out) override; + void ReadNextImpl(std::vector* out) override; private: int batch_size_; + bool discard_leftover_; std::vector> buffer_; }; @@ -46,8 +50,9 @@ class CreateBatchReaderOp : public framework::OperatorBase { } const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) ->Get(); - out->Reset( - new BatchReader(underlying_reader.Get(), Attr("batch_size"))); + out->Reset(framework::MakeDecoratedReader( + underlying_reader, Attr("batch_size"), + Attr("discard_leftover"))); } }; @@ -57,6 +62,10 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase { AddAttr("batch_size", "How many instances the batch reader yields each time.") .GreaterThan(0); + AddAttr("discard_leftover", + "If true, the leftover instances that are not enough for a " + "new batch will be discarded.") + .SetDefault(true); AddComment(R"DOC( CreateBatchReader Operator @@ -66,7 +75,7 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase { } }; -void BatchReader::ReadNext(std::vector* out) { +void BatchReader::ReadNextImpl(std::vector* out) { buffer_.clear(); buffer_.reserve(batch_size_); for (int i = 0; i < batch_size_; ++i) { @@ -77,6 +86,9 @@ void BatchReader::ReadNext(std::vector* out) { break; } } + if (discard_leftover_ && buffer_.size() < batch_size_) { + buffer_.clear(); + } // Concat instances out->clear(); if (buffer_.empty()) { diff --git a/paddle/fluid/operators/reader/create_custom_reader_op.cc b/paddle/fluid/operators/reader/create_custom_reader_op.cc index a75c6d4c567ac93f37b38070421133af305f20a3..85394b336fc967fc6973131fbedda4c796825185 100644 --- a/paddle/fluid/operators/reader/create_custom_reader_op.cc +++ b/paddle/fluid/operators/reader/create_custom_reader_op.cc @@ -33,7 +33,7 @@ class CustomReader : public framework::DecoratedReader { source_var_names_(source_var_names), sink_var_names_(sink_var_names) {} - void ReadNext(std::vector* out) override; + void ReadNextImpl(std::vector* out) override; private: const framework::ProgramDesc program_; @@ -60,10 +60,10 @@ class CreateCustomReaderOp : public framework::OperatorBase { } const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) ->Get(); - out->Reset( - new CustomReader(underlying_reader.Get(), *sub_block, - Attr>("source_var_names"), - Attr>("sink_var_names"))); + out->Reset(framework::MakeDecoratedReader( + underlying_reader, *sub_block, + Attr>("source_var_names"), + Attr>("sink_var_names"))); } }; @@ -143,7 +143,7 @@ class CustomReaderInferVarType : public framework::VarTypeInference { } }; -void CustomReader::ReadNext(std::vector* out) { +void CustomReader::ReadNextImpl(std::vector* out) { out->clear(); std::vector underlying_outs; reader_->ReadNext(&underlying_outs); diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc index 5f734489a81764875988f440696682570ff4d1d7..7b14370f4fd64e8fd5b8d9038006494b88d671dc 100644 --- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc +++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc @@ -23,13 +23,13 @@ namespace reader { // 'Double buffer' means we shall maintain two batches of input data at the same // time. So the kCacheSize shoul be at least 2. -static constexpr size_t kCacheSize = 5; +static constexpr size_t kCacheSize = 3; // There will be two bacthes out of the channel during training: // 1. the one waiting to be sent to the channel // 2. the one just be received from the channel, which is also being used by // subsequent operators. // So the channel size should be kChacheSize - 2 -static constexpr size_t kChannelSize = 3; // kCacheSize - 2 +static constexpr size_t kChannelSize = 1; // kCacheSize - 2 class DoubleBufferReader : public framework::DecoratedReader { public: @@ -50,12 +50,21 @@ class DoubleBufferReader : public framework::DecoratedReader { StartPrefetcher(); } - void ReadNext(std::vector* out) override; - void ReInit() override; + void ReadNextImpl(std::vector* out) override; ~DoubleBufferReader() { EndPrefetcher(); } private: + void ShutdownImpl() override { + EndPrefetcher(); + reader_->Shutdown(); + } + + void StartImpl() override { + reader_->Start(); + StartPrefetcher(); + } + void StartPrefetcher() { channel_ = new reader::BlockingQueue(kChannelSize); prefetcher_ = std::thread([this] { PrefetchThreadFunc(); }); @@ -109,7 +118,8 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase { place = platform::CUDAPlace(static_cast(num)); } - out->Reset(new DoubleBufferReader(underlying_reader.Get(), place)); + out->Reset(framework::MakeDecoratedReader( + underlying_reader, place)); } }; @@ -136,7 +146,7 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase { } }; -void DoubleBufferReader::ReadNext(std::vector* out) { +void DoubleBufferReader::ReadNextImpl(std::vector* out) { size_t cached_tensor_id; if (channel_->Receive(&cached_tensor_id)) { if (platform::is_gpu_place(place_)) { @@ -150,12 +160,6 @@ void DoubleBufferReader::ReadNext(std::vector* out) { } } -void DoubleBufferReader::ReInit() { - reader_->ReInit(); - EndPrefetcher(); - StartPrefetcher(); -} - void DoubleBufferReader::PrefetchThreadFunc() { VLOG(5) << "A new prefetch thread starts."; size_t cached_tensor_id = 0; diff --git a/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc b/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc index 19b54110b9aeece33b8d6c73612ae0e12dbfafbd..0a225597d34f43c7fb82aeae2552cdf16c8ba566 100644 --- a/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc +++ b/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc @@ -24,23 +24,22 @@ class MultiPassReader : public framework::DecoratedReader { MultiPassReader(const std::shared_ptr& reader, int pass_num) : DecoratedReader(reader), pass_num_(pass_num), pass_count_(0) {} - void ReadNext(std::vector* out) override { + void ReadNextImpl(std::vector* out) override { reader_->ReadNext(out); - if (out->empty()) { + if (out->empty() && pass_count_ < pass_num_ - 1) { + reader_->Shutdown(); + reader_->Start(); + reader_->ReadNext(out); ++pass_count_; - if (pass_count_ < pass_num_) { - reader_->ReInit(); - reader_->ReadNext(out); - } } } - void ReInit() override { + private: + void StartImpl() override { pass_count_ = 0; - reader_->ReInit(); + reader_->Start(); } - private: int pass_num_; mutable int pass_count_; }; @@ -60,7 +59,8 @@ class CreateMultiPassReaderOp : public framework::OperatorBase { const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) ->Get(); int pass_num = Attr("pass_num"); - out->Reset(new MultiPassReader(underlying_reader.Get(), pass_num)); + out->Reset(framework::MakeDecoratedReader( + underlying_reader, pass_num)); } }; diff --git a/paddle/fluid/operators/reader/create_py_reader_op.cc b/paddle/fluid/operators/reader/create_py_reader_op.cc index 36587360f7347a10e01d4e994482027d9a9bb5d0..d41124279930e92138e7e6a5ab045659a415eb6d 100644 --- a/paddle/fluid/operators/reader/create_py_reader_op.cc +++ b/paddle/fluid/operators/reader/create_py_reader_op.cc @@ -19,22 +19,27 @@ namespace paddle { namespace operators { namespace reader { -class PyReader : public framework::ReaderBase { +class PyReader : public framework::FileReader { public: - explicit PyReader(const std::shared_ptr& queue) { + explicit PyReader(const std::shared_ptr& queue) + : framework::FileReader() { PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null"); queue_ = queue; } - void ReadNext(std::vector* out) override { + void ReadNextImpl(std::vector* out) override { bool success; *out = queue_->Pop(&success); if (!success) out->clear(); } - void ReInit() override {} - private: + void ShutdownImpl() override { /* TODO */ + } + + void StartImpl() override { /* TODO */ + } + std::shared_ptr queue_; }; @@ -51,14 +56,14 @@ class CreatePyReaderOp : public framework::OperatorBase { const std::string& queue_name = Input("blocking_queue"); auto* queue_holder_var = scope.FindVar(queue_name); - PADDLE_ENFORCE( - queue_holder_var != nullptr, + PADDLE_ENFORCE_NOT_NULL( + queue_holder_var, "No LoDTensorBlockingQueueHolder variable with name %s found", queue_name); auto* queue_holder = queue_holder_var->template GetMutable(); - out->Reset(new PyReader(queue_holder->GetQueue())); + out->Reset(std::make_shared(queue_holder->GetQueue())); } }; diff --git a/paddle/fluid/operators/reader/create_random_data_generator_op.cc b/paddle/fluid/operators/reader/create_random_data_generator_op.cc index 5b7e8a063a034f0be056065826fca0fe807bc9a7..e5c116dfcd71ef40597ca19d1da0b51038baaad1 100644 --- a/paddle/fluid/operators/reader/create_random_data_generator_op.cc +++ b/paddle/fluid/operators/reader/create_random_data_generator_op.cc @@ -19,11 +19,11 @@ namespace operators { namespace reader { template -class RandomDataGenerator : public framework::ReaderBase { +class RandomDataGenerator : public framework::FileReader { public: RandomDataGenerator(const std::vector& shapes, float low, float high) - : framework::ReaderBase(), low_(low), high_(high), shapes_(shapes) { + : framework::FileReader(), low_(low), high_(high), shapes_(shapes) { PADDLE_ENFORCE_LE(low, high, "'low' shouldn't be greater than 'high'.(%f vs %f)", low, high); @@ -32,7 +32,7 @@ class RandomDataGenerator : public framework::ReaderBase { dist_ = std::uniform_real_distribution(low_, high_); } - void ReadNext(std::vector* out) override { + void ReadNextImpl(std::vector* out) override { out->clear(); out->reserve(shapes_.size()); for (const framework::DDim& shape : shapes_) { @@ -51,8 +51,6 @@ class RandomDataGenerator : public framework::ReaderBase { } } - void ReInit() override { return; } - private: float low_; float high_; @@ -79,8 +77,8 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase { std::vector shapes = RestoreShapes(shape_concat, ranks); auto* out = scope.FindVar(Output("Out")) ->template GetMutable(); - out->Reset(new RandomDataGenerator(shapes, Attr("low"), - Attr("high"))); + out->Reset(std::make_shared>( + shapes, Attr("low"), Attr("high"))); } }; diff --git a/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc index 559827f08494af6730aafa1e67c46a47c21dedf6..b32f09b22524c8b67ce57cc6022ef46efc2e828d 100644 --- a/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc +++ b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc @@ -21,10 +21,8 @@ namespace reader { template class RecordIOFileReader : public framework::FileReader { public: - explicit RecordIOFileReader(const std::string& filename, - const std::vector& dims) - : FileReader(dims), - scanner_(filename), + explicit RecordIOFileReader(const std::string& filename) + : scanner_(filename), dev_ctx_(*platform::DeviceContextPool::Instance().Get( platform::CPUPlace())) { if (ThreadSafe) { @@ -33,8 +31,6 @@ class RecordIOFileReader : public framework::FileReader { LOG(INFO) << "Creating file reader" << filename; } - void ReInit() override { scanner_.Reset(); } - protected: void ReadNextImpl(std::vector* out) override { if (ThreadSafe) { @@ -45,6 +41,8 @@ class RecordIOFileReader : public framework::FileReader { } } + void StartImpl() override { scanner_.Reset(); } + private: std::unique_ptr mutex_; recordio::Scanner scanner_; @@ -58,20 +56,11 @@ class CreateRecordIOReaderOp : public framework::OperatorBase { private: void RunImpl(const framework::Scope& scope, const platform::Place& dev_place) const override { - const auto& shape_concat = Attr>("shape_concat"); - const auto& ranks = Attr>("ranks"); - PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty()); - PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0), - static_cast(shape_concat.size()), - "The accumulate of all ranks should be equal to the " - "shape concat's length."); std::string filename = Attr("filename"); - auto* out = scope.FindVar(Output("Out")) ->template GetMutable(); - out->Reset(new RecordIOFileReader( - filename, RestoreShapes(shape_concat, ranks))); + out->Reset(std::make_shared>(filename)); } }; diff --git a/paddle/fluid/operators/reader/create_shuffle_reader_op.cc b/paddle/fluid/operators/reader/create_shuffle_reader_op.cc index 57e8e21214b7c99e52550fe51a67c9b5201cb46f..4b308abc290c10a8a5846672e719b503dfc79b21 100644 --- a/paddle/fluid/operators/reader/create_shuffle_reader_op.cc +++ b/paddle/fluid/operators/reader/create_shuffle_reader_op.cc @@ -34,7 +34,7 @@ class ShuffleReader : public framework::DecoratedReader { ReloadBuffer(); } - void ReadNext(std::vector* out) override { + void ReadNextImpl(std::vector* out) override { out->clear(); if (iteration_pos_ >= buffer_.size()) { VLOG(10) << "Resetting shuffle buffer"; @@ -47,6 +47,17 @@ class ShuffleReader : public framework::DecoratedReader { } private: + void ShutdownImpl() override { + buffer_.clear(); + iteration_pos_ = 0; + reader_->Shutdown(); + } + + void StartImpl() override { + reader_->Start(); + ReloadBuffer(); + } + void ReloadBuffer() { buffer_.clear(); buffer_.reserve(buffer_size_); @@ -86,9 +97,8 @@ class CreateShuffleReaderOp : public framework::OperatorBase { } const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) ->Get(); - out->Reset( - new ShuffleReader(underlying_reader.Get(), - static_cast(Attr("buffer_size")))); + out->Reset(framework::MakeDecoratedReader( + underlying_reader, static_cast(Attr("buffer_size")))); } }; diff --git a/paddle/fluid/operators/reader/create_threaded_reader_op.cc b/paddle/fluid/operators/reader/create_threaded_reader_op.cc deleted file mode 100644 index 3798015146f4ffb085aa82e23ca3f1fb3c5cf5a4..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/reader/create_threaded_reader_op.cc +++ /dev/null @@ -1,79 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/operators/detail/safe_ref.h" -#include "paddle/fluid/operators/reader/reader_op_registry.h" - -namespace paddle { -namespace operators { -namespace reader { - -class ThreadedReader : public framework::DecoratedReader { - public: - explicit ThreadedReader(const std::shared_ptr& reader) - : DecoratedReader(reader) {} - - void ReadNext(std::vector* out) override { - std::lock_guard lock(mutex_); - reader_->ReadNext(out); - } - - void ReInit() override { reader_->ReInit(); } - - private: - std::mutex mutex_; -}; - -class CreateThreadedReaderOp : public framework::OperatorBase { - public: - using framework::OperatorBase::OperatorBase; - - private: - void RunImpl(const framework::Scope& scope, - const platform::Place& dev_place) const override { - auto* out = detail::Ref(scope.FindVar(Output("Out"))) - .GetMutable(); - if (out->Get() != nullptr) { - return; - } - const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) - ->Get(); - out->Reset(new ThreadedReader(underlying_reader.Get())); - } -}; - -class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase { - protected: - void Apply() override { - AddComment(R"DOC( - CreateThreadedReader Operator - - This operator creates a threaded reader. A threaded reader's - 'ReadNext()' can be invoked by several threads at the same - time. - When the attribute 'safe_mode' is true, the threaded reader's - 'ReInit()' is disabled to avoid unexpected bugs in multi-thread - environment. - )DOC"); - } -}; - -} // namespace reader -} // namespace operators -} // namespace paddle - -namespace reader = paddle::operators::reader; -REGISTER_DECORATED_READER_OPERATOR(create_threaded_reader, - reader::CreateThreadedReaderOp, - reader::CreateThreadedReaderOpMaker); diff --git a/paddle/fluid/operators/reader/open_files_op.cc b/paddle/fluid/operators/reader/open_files_op.cc index 31e5d81e55ed9703eb3a9ef2595fa2a280f1a734..9a8d203672fa2d560440d063d93fa5f8523690ef 100644 --- a/paddle/fluid/operators/reader/open_files_op.cc +++ b/paddle/fluid/operators/reader/open_files_op.cc @@ -23,24 +23,26 @@ namespace reader { class MultiFileReader : public framework::ReaderBase { public: - MultiFileReader(const std::vector& file_names, - const std::vector& dims, size_t thread_num, + MultiFileReader(const std::vector& file_names, size_t thread_num, size_t buffer_size) : buffer_size_(buffer_size) { readers_.reserve(file_names.size()); for (const std::string& f_name : file_names) { - readers_.emplace_back(CreateReaderByFileName(f_name, dims)); + readers_.emplace_back(CreateReaderByFileName(f_name)); } prefetchers_.resize(thread_num); StartNewScheduler(); } - void ReadNext(std::vector* out) override; - void ReInit() override; + void ReadNextImpl(std::vector* out) override; ~MultiFileReader() { EndScheduler(); } private: + void ShutdownImpl() override { EndScheduler(); } + + void StartImpl() override { StartNewScheduler(); } + void StartNewScheduler(); void EndScheduler(); void ScheduleThreadFunc(); @@ -55,17 +57,12 @@ class MultiFileReader : public framework::ReaderBase { reader::BlockingQueue>* buffer_; }; -void MultiFileReader::ReadNext(std::vector* out) { +void MultiFileReader::ReadNextImpl(std::vector* out) { if (!buffer_->Receive(out)) { out->clear(); } } -void MultiFileReader::ReInit() { - EndScheduler(); - StartNewScheduler(); -} - void MultiFileReader::StartNewScheduler() { size_t thread_num = prefetchers_.size(); waiting_reader_idx_ = new reader::BlockingQueue(readers_.size()); @@ -120,7 +117,7 @@ void MultiFileReader::ScheduleThreadFunc() { } } } - // If users invoke ReInit() when scheduler is running, it will close the + // If users invoke Shutdown() when scheduler is running, it will close the // 'avaiable_thread_idx_' and prefecther threads have no way to tell scheduler // to release their resource. So a check is needed before scheduler ends. for (auto& p : prefetchers_) { @@ -138,7 +135,8 @@ void MultiFileReader::PrefetchThreadFunc(size_t reader_idx, size_t thread_idx) { std::vector ins; reader->ReadNext(&ins); if (ins.empty()) { - reader->ReInit(); + reader->Shutdown(); + reader->Start(); break; } try { @@ -180,9 +178,8 @@ class OpenFilesOp : public framework::OperatorBase { auto* out = scope.FindVar(Output("Out")) ->template GetMutable(); - out->Reset(new MultiFileReader(file_names, - RestoreShapes(shape_concat, ranks), - thread_num, buffer_size)); + out->Reset( + std::make_shared(file_names, thread_num, buffer_size)); } }; diff --git a/paddle/fluid/operators/reader/reader_op_registry.cc b/paddle/fluid/operators/reader/reader_op_registry.cc index e11256a49ffa6adc9410376cc8a71fa017df7e9c..b82aab1214992be73d876a42424234e3cea46455 100644 --- a/paddle/fluid/operators/reader/reader_op_registry.cc +++ b/paddle/fluid/operators/reader/reader_op_registry.cc @@ -39,7 +39,7 @@ std::unordered_map& FileReaderRegistry() { } std::unique_ptr CreateReaderByFileName( - const std::string& file_name, const std::vector& dims) { + const std::string& file_name) { size_t separator_pos = file_name.find_last_of(kFileFormatSeparator); PADDLE_ENFORCE_NE(separator_pos, std::string::npos, "File name illegal! A legal file name should be like: " @@ -49,7 +49,7 @@ std::unique_ptr CreateReaderByFileName( auto itor = FileReaderRegistry().find(filetype); PADDLE_ENFORCE(itor != FileReaderRegistry().end(), "No file reader registered for '%s' format.", filetype); - framework::ReaderBase* reader = (itor->second)(file_name, dims); + framework::ReaderBase* reader = (itor->second)(file_name); return std::unique_ptr(reader); } diff --git a/paddle/fluid/operators/reader/reader_op_registry.h b/paddle/fluid/operators/reader/reader_op_registry.h index 244bf15f068a47efc29ee54492cdbdeb10025020..25c3e7d77b788d38daf6dee1fc79e5c1c97e8842 100644 --- a/paddle/fluid/operators/reader/reader_op_registry.h +++ b/paddle/fluid/operators/reader/reader_op_registry.h @@ -25,22 +25,21 @@ namespace reader { static constexpr char kFileFormatSeparator[] = "."; -using FileReaderCreator = std::function&)>; +using FileReaderCreator = + std::function; std::unordered_map& FileReaderRegistry(); template int RegisterFileReader(const std::string& filetype) { - FileReaderRegistry()[filetype] = []( - const std::string& fn, const std::vector& dims) { - return new Reader(fn, dims); + FileReaderRegistry()[filetype] = [](const std::string& fn) { + return new Reader(fn); }; return 0; } std::unique_ptr CreateReaderByFileName( - const std::string& file_name, const std::vector& dims); + const std::string& file_name); extern std::vector RestoreShapes( const std::vector& shape_concat, const std::vector& ranks); diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 7a8bb712452538b7e2a349d56a15de3284f82b39..0c523b6f176345c0407b8541c04fb8c3b27f7c60 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -296,7 +296,7 @@ All parameter, weight, gradient are variables in Paddle. py::return_value_policy::reference); py::class_(m, "Reader", "") - .def("reset", &framework::ReaderHolder::ReInit); + .def("reset", &framework::ReaderHolder::ResetAll); using LoDTensorBlockingQueue = ::paddle::operators::reader::LoDTensorBlockingQueue; diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index f33ae76aea95ceeca73c5bae6e4e490cdff29bf3..977abde21f38a0d25a90bc14426fd817df2c8508 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -375,9 +375,6 @@ def open_recordio_file(filename, if pass_num > 1: main_prog_var = multi_pass(reader=main_prog_var, pass_num=pass_num) - if for_parallel: - main_prog_var = parallel(reader=main_prog_var) - return monkey_patch_reader_methods(main_prog_var) @@ -529,9 +526,6 @@ def open_files(filenames, main_prog_reader = multi_pass( reader=main_prog_reader, pass_num=pass_num) - if for_parallel: - main_prog_reader = parallel(reader=main_prog_reader) - return monkey_patch_reader_methods(main_prog_reader) @@ -647,11 +641,6 @@ def multi_pass(reader, pass_num): 'create_multi_pass_reader', reader, {'pass_num': int(pass_num)}) -def parallel(reader): - return __create_shared_decorated_reader__('create_threaded_reader', reader, - {}) - - def read_file(reader): """ Execute the given reader and get data via it.