diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 397c9f739452e5130dad28a763b92cf76720ec61..ec252929d5584c211cea7fa52004ecdfdf586a85 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -27,6 +27,7 @@ cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory) nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor) cc_library(reader SRCS reader.cc DEPS lod_tensor ddim) +cc_test(reader_test SRCS reader_test.cc DEPS reader) cc_test(variable_test SRCS variable_test.cc) diff --git a/paddle/fluid/framework/reader.cc b/paddle/fluid/framework/reader.cc index 9939797f4c6fb4c7fa6dc8851d830cb4180790a9..f8877e5cb0e1463fb4e57877b35c713461251e00 100644 --- a/paddle/fluid/framework/reader.cc +++ b/paddle/fluid/framework/reader.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/framework/reader.h" +#include namespace paddle { namespace framework { @@ -23,6 +24,33 @@ void ReaderBase::ReadNext(std::vector *out) { ReadNextImpl(out); } +void ReaderBase::InsertDecoratedReader( + const std::shared_ptr &decorated_reader) { + std::lock_guard guard(mu_)); + decorated_readers_.emplace_back(decorated_reader); +} + +std::unordered_set ReaderBase::GetEndPoints() { + std::unordered_set result; + std::deque queue; + queue.emplace_back(this); + while (!queue.empty()) { // BFS search + auto *front = queue.front(); + queue.pop_front(); + if (front->decorated_readers_.empty()) { + result.emplace(front); + } else { + for (auto &reader : front->decorated_readers_) { + if (auto *reader_ptr = reader.lock().get()) { + queue.emplace_back(reader_ptr); + } + } + } + } + + return result; +} + void ReaderBase::Shutdown() { std::lock_guard lock(mu_); if (status_ != ReaderStatus::kStopped) { diff --git a/paddle/fluid/framework/reader.h b/paddle/fluid/framework/reader.h index 95af7b12562d15430988b34923a968f8c4c23eb4..93cd6243ff2137292e69964a8b23a79315e27b18 100644 --- a/paddle/fluid/framework/reader.h +++ b/paddle/fluid/framework/reader.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include "paddle/fluid/framework/ddim.h" @@ -34,6 +35,10 @@ class ReaderBase { void Start(); + // Return the readers which are the end of decorating chain. Basically + // they are readers just before read op. + std::unordered_set GetEndPoints(); + virtual ~ReaderBase(); protected: @@ -46,15 +51,29 @@ class ReaderBase { ReaderStatus status_{kRunning}; mutable std::mutex mu_; + + private: + friend class DecoratedReader; + // These methods can be only invoked inside DecoratedReader to record the + // decorating chain. + void InsertDecoratedReader( + const std::shared_ptr& decorated_reader); + // A set of which readers that decorated this reader. + std::vector> decorated_readers_; }; -class DecoratedReader : public ReaderBase { +class DecoratedReader : public ReaderBase, + public std::enable_shared_from_this { public: explicit DecoratedReader(const std::shared_ptr& reader) : ReaderBase(), reader_(reader) { PADDLE_ENFORCE_NOT_NULL(reader_); } + void RegisterDecorateChain() { + reader_->InsertDecoratedReader(shared_from_this()); + } + protected: void ShutdownImpl() override { reader_->Shutdown(); } @@ -70,9 +89,14 @@ class FileReader : public ReaderBase {}; // making it easier to access different type reader in Variables. class ReaderHolder { public: - void Reset(ReaderBase* reader) { reader_.reset(reader); } + template + void Reset(const std::shared_ptr& reader) { + auto reader_base = std::dynamic_pointer_cast(reader); + PADDLE_ENFORCE_NOT_NULL(reader_base); + reader_ = reader_base; + } - std::shared_ptr Get() const { return reader_; } + const std::shared_ptr& Get() const { return reader_; } void ReadNext(std::vector* out) { PADDLE_ENFORCE_NOT_NULL(reader_); @@ -93,9 +117,18 @@ class ReaderHolder { reader_->Start(); } + operator const std::shared_ptr&() const { return this->reader_; } + private: std::shared_ptr reader_; }; +template +inline std::shared_ptr MakeDecoratedReader(ARGS&&... args) { + std::shared_ptr reader(new T(std::forward(args)...)); + reader->RegisterDecorateChain(); + return reader; +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/reader_test.cc b/paddle/fluid/framework/reader_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c05be86706f4c8ac804cab149a9510c1102c36cb --- /dev/null +++ b/paddle/fluid/framework/reader_test.cc @@ -0,0 +1,53 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/reader.h" +#include +#include "gtest/gtest.h" + +class StubDecoratedReader : public paddle::framework::DecoratedReader { + public: + explicit StubDecoratedReader(const std::shared_ptr &reader) + : DecoratedReader(reader) {} + + void ReadNext(std::vector *out) override {} +}; + +class StubRootReader : public paddle::framework::ReaderBase { + public: + void ReadNext(std::vector *out) override {} + void ReInit() override {} +}; + +TEST(READER, decorate_chain) { + auto root = std::make_shared(); + auto end_point1 = + paddle::framework::MakeDecoratedReader(root); + auto end_point2 = + paddle::framework::MakeDecoratedReader(root); + + { + auto endpoints = root->GetEndPoints(); + ASSERT_EQ(endpoints.size(), 2U); + ASSERT_NE(endpoints.count(end_point1.get()), 0); + ASSERT_NE(endpoints.count(end_point2.get()), 0); + } + + { + auto end_point3 = + paddle::framework::MakeDecoratedReader(root); + ASSERT_EQ(root->GetEndPoints().size(), 3U); + } + { ASSERT_EQ(root->GetEndPoints().size(), 2U); } +} diff --git a/paddle/fluid/operators/reader/create_batch_reader_op.cc b/paddle/fluid/operators/reader/create_batch_reader_op.cc index e5b69dabcbd864cc8cc6804e392ea36355e4215f..1dbafd23e92732bdaf0d263a01e267227786d839 100644 --- a/paddle/fluid/operators/reader/create_batch_reader_op.cc +++ b/paddle/fluid/operators/reader/create_batch_reader_op.cc @@ -50,8 +50,9 @@ class CreateBatchReaderOp : public framework::OperatorBase { } const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) ->Get(); - out->Reset(new BatchReader(underlying_reader.Get(), Attr("batch_size"), - Attr("discard_leftover"))); + out->Reset(framework::MakeDecoratedReader( + underlying_reader, Attr("batch_size"), + Attr("discard_leftover"))); } }; diff --git a/paddle/fluid/operators/reader/create_custom_reader_op.cc b/paddle/fluid/operators/reader/create_custom_reader_op.cc index a53dceced323a9cc10f7e1b195b579669f5a9e6d..85394b336fc967fc6973131fbedda4c796825185 100644 --- a/paddle/fluid/operators/reader/create_custom_reader_op.cc +++ b/paddle/fluid/operators/reader/create_custom_reader_op.cc @@ -60,10 +60,10 @@ class CreateCustomReaderOp : public framework::OperatorBase { } const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) ->Get(); - out->Reset( - new CustomReader(underlying_reader.Get(), *sub_block, - Attr>("source_var_names"), - Attr>("sink_var_names"))); + out->Reset(framework::MakeDecoratedReader( + underlying_reader, *sub_block, + Attr>("source_var_names"), + Attr>("sink_var_names"))); } }; diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc index 1bf6a86a5ad6bd308f53dacd6bab4be8723aeccd..7b14370f4fd64e8fd5b8d9038006494b88d671dc 100644 --- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc +++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc @@ -118,7 +118,8 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase { place = platform::CUDAPlace(static_cast(num)); } - out->Reset(new DoubleBufferReader(underlying_reader.Get(), place)); + out->Reset(framework::MakeDecoratedReader( + underlying_reader, place)); } }; diff --git a/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc b/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc index f26a470c258b3d737e9d7c653750115bc7991857..0a225597d34f43c7fb82aeae2552cdf16c8ba566 100644 --- a/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc +++ b/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc @@ -59,7 +59,8 @@ class CreateMultiPassReaderOp : public framework::OperatorBase { const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) ->Get(); int pass_num = Attr("pass_num"); - out->Reset(new MultiPassReader(underlying_reader.Get(), pass_num)); + out->Reset(framework::MakeDecoratedReader( + underlying_reader, pass_num)); } }; diff --git a/paddle/fluid/operators/reader/create_py_reader_op.cc b/paddle/fluid/operators/reader/create_py_reader_op.cc index fb7a96cb73d12a37e98fd064b3daf719e2c26bb2..d41124279930e92138e7e6a5ab045659a415eb6d 100644 --- a/paddle/fluid/operators/reader/create_py_reader_op.cc +++ b/paddle/fluid/operators/reader/create_py_reader_op.cc @@ -63,7 +63,7 @@ class CreatePyReaderOp : public framework::OperatorBase { auto* queue_holder = queue_holder_var->template GetMutable(); - out->Reset(new PyReader(queue_holder->GetQueue())); + out->Reset(std::make_shared(queue_holder->GetQueue())); } }; diff --git a/paddle/fluid/operators/reader/create_random_data_generator_op.cc b/paddle/fluid/operators/reader/create_random_data_generator_op.cc index 9f7e3fd2d8718c1c4b25e4f95a134b6c7de37299..e5c116dfcd71ef40597ca19d1da0b51038baaad1 100644 --- a/paddle/fluid/operators/reader/create_random_data_generator_op.cc +++ b/paddle/fluid/operators/reader/create_random_data_generator_op.cc @@ -77,8 +77,8 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase { std::vector shapes = RestoreShapes(shape_concat, ranks); auto* out = scope.FindVar(Output("Out")) ->template GetMutable(); - out->Reset(new RandomDataGenerator(shapes, Attr("low"), - Attr("high"))); + out->Reset(std::make_shared>( + shapes, Attr("low"), Attr("high"))); } }; diff --git a/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc index 61d94cfbd38bc0e43997f053de7a5d6fe4065a23..b32f09b22524c8b67ce57cc6022ef46efc2e828d 100644 --- a/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc +++ b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc @@ -59,7 +59,8 @@ class CreateRecordIOReaderOp : public framework::OperatorBase { std::string filename = Attr("filename"); auto* out = scope.FindVar(Output("Out")) ->template GetMutable(); - out->Reset(new RecordIOFileReader(filename)); + + out->Reset(std::make_shared>(filename)); } }; diff --git a/paddle/fluid/operators/reader/create_shuffle_reader_op.cc b/paddle/fluid/operators/reader/create_shuffle_reader_op.cc index 1d3d85b9e4e1a5974e69a516ebf26d3aa1376b10..4b308abc290c10a8a5846672e719b503dfc79b21 100644 --- a/paddle/fluid/operators/reader/create_shuffle_reader_op.cc +++ b/paddle/fluid/operators/reader/create_shuffle_reader_op.cc @@ -97,9 +97,8 @@ class CreateShuffleReaderOp : public framework::OperatorBase { } const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) ->Get(); - out->Reset( - new ShuffleReader(underlying_reader.Get(), - static_cast(Attr("buffer_size")))); + out->Reset(framework::MakeDecoratedReader( + underlying_reader, static_cast(Attr("buffer_size")))); } }; diff --git a/paddle/fluid/operators/reader/open_files_op.cc b/paddle/fluid/operators/reader/open_files_op.cc index c657ffc5359605b0513498f4a92944f797ea3a9f..9a8d203672fa2d560440d063d93fa5f8523690ef 100644 --- a/paddle/fluid/operators/reader/open_files_op.cc +++ b/paddle/fluid/operators/reader/open_files_op.cc @@ -178,7 +178,8 @@ class OpenFilesOp : public framework::OperatorBase { auto* out = scope.FindVar(Output("Out")) ->template GetMutable(); - out->Reset(new MultiFileReader(file_names, thread_num, buffer_size)); + out->Reset( + std::make_shared(file_names, thread_num, buffer_size)); } };