diff --git a/paddle/fluid/framework/reader.cc b/paddle/fluid/framework/reader.cc index 2e2aa1cba14e00cb0e3774db25a7bc1a9ede069d..e1d2ac79cf0d9eef7bdd46fb0074480b891a6144 100644 --- a/paddle/fluid/framework/reader.cc +++ b/paddle/fluid/framework/reader.cc @@ -19,15 +19,11 @@ namespace paddle { namespace framework { ReaderBase::~ReaderBase() {} -void ReaderBase::InsertDecoratedReader(ReaderBase *decorated_reader) { - decorated_readers_.emplace(decorated_reader); -} -void ReaderBase::EraseDecoratedReader(ReaderBase *decorated_reader) { - auto it = decorated_readers_.find(decorated_reader); - PADDLE_ENFORCE(it != decorated_readers_.end(), - "Cannot find the decorated reader to erase"); - decorated_readers_.erase(it); +void ReaderBase::InsertDecoratedReader( + const std::shared_ptr &decorated_reader) { + decorated_readers_.emplace_back(decorated_reader); } + std::unordered_set ReaderBase::GetEndPoints() { std::unordered_set result; std::deque queue; @@ -38,8 +34,10 @@ std::unordered_set ReaderBase::GetEndPoints() { if (front->decorated_readers_.empty()) { result.emplace(front); } else { - for (ReaderBase *reader : front->decorated_readers_) { - queue.emplace_back(reader); + for (auto &reader : front->decorated_readers_) { + if (auto *reader_ptr = reader.lock().get()) { + queue.emplace_back(reader_ptr); + } } } } @@ -66,6 +64,5 @@ void FileReader::ReadNext(std::vector *out) { } } } -DecoratedReader::~DecoratedReader() { reader_->EraseDecoratedReader(this); } } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/reader.h b/paddle/fluid/framework/reader.h index 2a65c58e3fa6d8281c81a2e6206ec1a758d927b6..730e3faace1a92c071861e65262141a9edd41c3b 100644 --- a/paddle/fluid/framework/reader.h +++ b/paddle/fluid/framework/reader.h @@ -41,24 +41,26 @@ class ReaderBase { friend class DecoratedReader; // These methods can be only invoked inside DecoratedReader to record the // decorating chain. - void InsertDecoratedReader(ReaderBase* decorated_reader); - void EraseDecoratedReader(ReaderBase* decorated_reader); + void InsertDecoratedReader( + const std::shared_ptr& decorated_reader); // A set of which readers that decorated this reader. - std::unordered_set decorated_readers_; + std::vector> decorated_readers_; }; -class DecoratedReader : public ReaderBase { +class DecoratedReader : public ReaderBase, + public std::enable_shared_from_this { public: explicit DecoratedReader(const std::shared_ptr& reader) : ReaderBase(), reader_(reader) { PADDLE_ENFORCE_NOT_NULL(reader_); - reader_->InsertDecoratedReader(this); } - ~DecoratedReader(); - void ReInit() override { reader_->ReInit(); } + void RegisterDecorateChain() { + reader_->InsertDecoratedReader(shared_from_this()); + } + protected: std::shared_ptr reader_; }; @@ -80,9 +82,14 @@ class FileReader : public ReaderBase { // making it easier to access different type reader in Variables. class ReaderHolder { public: - void Reset(ReaderBase* reader) { reader_.reset(reader); } + template + void Reset(const std::shared_ptr& reader) { + auto reader_base = std::dynamic_pointer_cast(reader); + PADDLE_ENFORCE_NOT_NULL(reader_base); + reader_ = reader_base; + } - std::shared_ptr Get() const { return reader_; } + const std::shared_ptr& Get() const { return reader_; } void ReadNext(std::vector* out) { PADDLE_ENFORCE_NOT_NULL(reader_); @@ -93,9 +100,18 @@ class ReaderHolder { reader_->ReInit(); } + operator const std::shared_ptr&() const { return this->reader_; } + private: std::shared_ptr reader_; }; +template +inline std::shared_ptr MakeDecoratedReader(ARGS&&... args) { + std::shared_ptr reader(new T(std::forward(args)...)); + reader->RegisterDecorateChain(); + return reader; +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/reader_test.cc b/paddle/fluid/framework/reader_test.cc index c763fe18d6d3a8ba98532cb7660eb743f2c0166d..c05be86706f4c8ac804cab149a9510c1102c36cb 100644 --- a/paddle/fluid/framework/reader_test.cc +++ b/paddle/fluid/framework/reader_test.cc @@ -32,18 +32,21 @@ class StubRootReader : public paddle::framework::ReaderBase { TEST(READER, decorate_chain) { auto root = std::make_shared(); - auto end_point1 = StubDecoratedReader(root); - auto end_point2 = StubDecoratedReader(root); + auto end_point1 = + paddle::framework::MakeDecoratedReader(root); + auto end_point2 = + paddle::framework::MakeDecoratedReader(root); { auto endpoints = root->GetEndPoints(); ASSERT_EQ(endpoints.size(), 2U); - ASSERT_NE(endpoints.count(&end_point1), 0); - ASSERT_NE(endpoints.count(&end_point2), 0); + ASSERT_NE(endpoints.count(end_point1.get()), 0); + ASSERT_NE(endpoints.count(end_point2.get()), 0); } { - auto end_point3 = StubDecoratedReader(root); + auto end_point3 = + paddle::framework::MakeDecoratedReader(root); ASSERT_EQ(root->GetEndPoints().size(), 3U); } { ASSERT_EQ(root->GetEndPoints().size(), 2U); } diff --git a/paddle/fluid/operators/reader/create_batch_reader_op.cc b/paddle/fluid/operators/reader/create_batch_reader_op.cc index ecbae3894d551186f53625a6cc9cfdb36adc8d2d..41c3d379030a4aa758710601d084e03edb3564ce 100644 --- a/paddle/fluid/operators/reader/create_batch_reader_op.cc +++ b/paddle/fluid/operators/reader/create_batch_reader_op.cc @@ -46,8 +46,8 @@ class CreateBatchReaderOp : public framework::OperatorBase { } const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) ->Get(); - out->Reset( - new BatchReader(underlying_reader.Get(), Attr("batch_size"))); + out->Reset(framework::MakeDecoratedReader( + underlying_reader, Attr("batch_size"))); } }; diff --git a/paddle/fluid/operators/reader/create_custom_reader_op.cc b/paddle/fluid/operators/reader/create_custom_reader_op.cc index a75c6d4c567ac93f37b38070421133af305f20a3..81a1aa7f9c88bdc2683367af96aa235e07bd6c59 100644 --- a/paddle/fluid/operators/reader/create_custom_reader_op.cc +++ b/paddle/fluid/operators/reader/create_custom_reader_op.cc @@ -60,10 +60,10 @@ class CreateCustomReaderOp : public framework::OperatorBase { } const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) ->Get(); - out->Reset( - new CustomReader(underlying_reader.Get(), *sub_block, - Attr>("source_var_names"), - Attr>("sink_var_names"))); + out->Reset(framework::MakeDecoratedReader( + underlying_reader, *sub_block, + Attr>("source_var_names"), + Attr>("sink_var_names"))); } }; diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc index 5f734489a81764875988f440696682570ff4d1d7..93820469542b915b053d71e4e5226e2afdf721b3 100644 --- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc +++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc @@ -109,7 +109,8 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase { place = platform::CUDAPlace(static_cast(num)); } - out->Reset(new DoubleBufferReader(underlying_reader.Get(), place)); + out->Reset(framework::MakeDecoratedReader( + underlying_reader, place)); } }; diff --git a/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc b/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc index 19b54110b9aeece33b8d6c73612ae0e12dbfafbd..69b3400a84c0981019bd1796d622c92b113cf692 100644 --- a/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc +++ b/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc @@ -60,7 +60,8 @@ class CreateMultiPassReaderOp : public framework::OperatorBase { const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) ->Get(); int pass_num = Attr("pass_num"); - out->Reset(new MultiPassReader(underlying_reader.Get(), pass_num)); + out->Reset(framework::MakeDecoratedReader( + underlying_reader, pass_num)); } }; diff --git a/paddle/fluid/operators/reader/create_py_reader_op.cc b/paddle/fluid/operators/reader/create_py_reader_op.cc index 36587360f7347a10e01d4e994482027d9a9bb5d0..0b3578570bf5db4ad078d1fcd16705a0ab7f80ca 100644 --- a/paddle/fluid/operators/reader/create_py_reader_op.cc +++ b/paddle/fluid/operators/reader/create_py_reader_op.cc @@ -58,7 +58,7 @@ class CreatePyReaderOp : public framework::OperatorBase { auto* queue_holder = queue_holder_var->template GetMutable(); - out->Reset(new PyReader(queue_holder->GetQueue())); + out->Reset(std::make_shared(queue_holder->GetQueue())); } }; diff --git a/paddle/fluid/operators/reader/create_random_data_generator_op.cc b/paddle/fluid/operators/reader/create_random_data_generator_op.cc index 5b7e8a063a034f0be056065826fca0fe807bc9a7..1c3de3feab6e784e5fc43336e44a3c600aa38dea 100644 --- a/paddle/fluid/operators/reader/create_random_data_generator_op.cc +++ b/paddle/fluid/operators/reader/create_random_data_generator_op.cc @@ -79,8 +79,8 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase { std::vector shapes = RestoreShapes(shape_concat, ranks); auto* out = scope.FindVar(Output("Out")) ->template GetMutable(); - out->Reset(new RandomDataGenerator(shapes, Attr("low"), - Attr("high"))); + out->Reset(std::make_shared>( + shapes, Attr("low"), Attr("high"))); } }; diff --git a/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc index 559827f08494af6730aafa1e67c46a47c21dedf6..c457cb3fb49fb5159ab7098959be118f314bfc2c 100644 --- a/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc +++ b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc @@ -70,7 +70,7 @@ class CreateRecordIOReaderOp : public framework::OperatorBase { auto* out = scope.FindVar(Output("Out")) ->template GetMutable(); - out->Reset(new RecordIOFileReader( + out->Reset(std::make_shared>( filename, RestoreShapes(shape_concat, ranks))); } }; diff --git a/paddle/fluid/operators/reader/create_shuffle_reader_op.cc b/paddle/fluid/operators/reader/create_shuffle_reader_op.cc index 57e8e21214b7c99e52550fe51a67c9b5201cb46f..75adabdaa9053c441ab3c7a73c2de5e7d5032dcd 100644 --- a/paddle/fluid/operators/reader/create_shuffle_reader_op.cc +++ b/paddle/fluid/operators/reader/create_shuffle_reader_op.cc @@ -86,9 +86,8 @@ class CreateShuffleReaderOp : public framework::OperatorBase { } const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) ->Get(); - out->Reset( - new ShuffleReader(underlying_reader.Get(), - static_cast(Attr("buffer_size")))); + out->Reset(framework::MakeDecoratedReader( + underlying_reader, static_cast(Attr("buffer_size")))); } }; diff --git a/paddle/fluid/operators/reader/create_threaded_reader_op.cc b/paddle/fluid/operators/reader/create_threaded_reader_op.cc index 3798015146f4ffb085aa82e23ca3f1fb3c5cf5a4..81d75cdd334082cd787004ebaf256efa649869bc 100644 --- a/paddle/fluid/operators/reader/create_threaded_reader_op.cc +++ b/paddle/fluid/operators/reader/create_threaded_reader_op.cc @@ -49,7 +49,8 @@ class CreateThreadedReaderOp : public framework::OperatorBase { } const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) ->Get(); - out->Reset(new ThreadedReader(underlying_reader.Get())); + out->Reset( + framework::MakeDecoratedReader(underlying_reader)); } }; diff --git a/paddle/fluid/operators/reader/open_files_op.cc b/paddle/fluid/operators/reader/open_files_op.cc index 31e5d81e55ed9703eb3a9ef2595fa2a280f1a734..e382066be5aa54eeb8854c2ee3d152f0bb95224d 100644 --- a/paddle/fluid/operators/reader/open_files_op.cc +++ b/paddle/fluid/operators/reader/open_files_op.cc @@ -180,9 +180,9 @@ class OpenFilesOp : public framework::OperatorBase { auto* out = scope.FindVar(Output("Out")) ->template GetMutable(); - out->Reset(new MultiFileReader(file_names, - RestoreShapes(shape_concat, ranks), - thread_num, buffer_size)); + out->Reset(std::make_shared( + file_names, RestoreShapes(shape_concat, ranks), thread_num, + buffer_size)); } };