From 2bbe5f77e79d33bf5cc9fb907abc3230c31ea0bb Mon Sep 17 00:00:00 2001 From: yuyang18 Date: Sat, 7 Jul 2018 14:20:22 +0800 Subject: [PATCH] Add GetEndPoints of Reader. We can get endpoints of a reader chain. --- paddle/fluid/framework/CMakeLists.txt | 1 + paddle/fluid/framework/reader.cc | 30 ++++++++++++++++ paddle/fluid/framework/reader.h | 17 +++++++++ paddle/fluid/framework/reader_test.cc | 50 +++++++++++++++++++++++++++ 4 files changed, 98 insertions(+) create mode 100644 paddle/fluid/framework/reader_test.cc diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 397c9f739..ec252929d 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -27,6 +27,7 @@ cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory) nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor) cc_library(reader SRCS reader.cc DEPS lod_tensor ddim) +cc_test(reader_test SRCS reader_test.cc DEPS reader) cc_test(variable_test SRCS variable_test.cc) diff --git a/paddle/fluid/framework/reader.cc b/paddle/fluid/framework/reader.cc index 0b36f1116..2e2aa1cba 100644 --- a/paddle/fluid/framework/reader.cc +++ b/paddle/fluid/framework/reader.cc @@ -13,11 +13,40 @@ // limitations under the License. #include "paddle/fluid/framework/reader.h" +#include namespace paddle { namespace framework { ReaderBase::~ReaderBase() {} +void ReaderBase::InsertDecoratedReader(ReaderBase *decorated_reader) { + decorated_readers_.emplace(decorated_reader); +} +void ReaderBase::EraseDecoratedReader(ReaderBase *decorated_reader) { + auto it = decorated_readers_.find(decorated_reader); + PADDLE_ENFORCE(it != decorated_readers_.end(), + "Cannot find the decorated reader to erase"); + decorated_readers_.erase(it); +} +std::unordered_set ReaderBase::GetEndPoints() { + std::unordered_set result; + std::deque queue; + queue.emplace_back(this); + while (!queue.empty()) { // BFS search + auto *front = queue.front(); + queue.pop_front(); + if (front->decorated_readers_.empty()) { + result.emplace(front); + } else { + for (ReaderBase *reader : front->decorated_readers_) { + queue.emplace_back(reader); + } + } + } + + return result; +} + FileReader::FileReader(const std::vector &dims) : dims_(dims) {} void FileReader::ReadNext(std::vector *out) { @@ -37,5 +66,6 @@ void FileReader::ReadNext(std::vector *out) { } } } +DecoratedReader::~DecoratedReader() { reader_->EraseDecoratedReader(this); } } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/reader.h b/paddle/fluid/framework/reader.h index 64d4ceab6..2a65c58e3 100644 --- a/paddle/fluid/framework/reader.h +++ b/paddle/fluid/framework/reader.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include "paddle/fluid/framework/ddim.h" @@ -31,6 +32,19 @@ class ReaderBase { virtual void ReInit() = 0; virtual ~ReaderBase(); + + // Return the readers which are the end of decorating chain. Basically + // they are readers just before read op. + std::unordered_set GetEndPoints(); + + private: + friend class DecoratedReader; + // These methods can be only invoked inside DecoratedReader to record the + // decorating chain. + void InsertDecoratedReader(ReaderBase* decorated_reader); + void EraseDecoratedReader(ReaderBase* decorated_reader); + // A set of which readers that decorated this reader. + std::unordered_set decorated_readers_; }; class DecoratedReader : public ReaderBase { @@ -38,8 +52,11 @@ class DecoratedReader : public ReaderBase { explicit DecoratedReader(const std::shared_ptr& reader) : ReaderBase(), reader_(reader) { PADDLE_ENFORCE_NOT_NULL(reader_); + reader_->InsertDecoratedReader(this); } + ~DecoratedReader(); + void ReInit() override { reader_->ReInit(); } protected: diff --git a/paddle/fluid/framework/reader_test.cc b/paddle/fluid/framework/reader_test.cc new file mode 100644 index 000000000..c763fe18d --- /dev/null +++ b/paddle/fluid/framework/reader_test.cc @@ -0,0 +1,50 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/reader.h" +#include +#include "gtest/gtest.h" + +class StubDecoratedReader : public paddle::framework::DecoratedReader { + public: + explicit StubDecoratedReader(const std::shared_ptr &reader) + : DecoratedReader(reader) {} + + void ReadNext(std::vector *out) override {} +}; + +class StubRootReader : public paddle::framework::ReaderBase { + public: + void ReadNext(std::vector *out) override {} + void ReInit() override {} +}; + +TEST(READER, decorate_chain) { + auto root = std::make_shared(); + auto end_point1 = StubDecoratedReader(root); + auto end_point2 = StubDecoratedReader(root); + + { + auto endpoints = root->GetEndPoints(); + ASSERT_EQ(endpoints.size(), 2U); + ASSERT_NE(endpoints.count(&end_point1), 0); + ASSERT_NE(endpoints.count(&end_point2), 0); + } + + { + auto end_point3 = StubDecoratedReader(root); + ASSERT_EQ(root->GetEndPoints().size(), 3U); + } + { ASSERT_EQ(root->GetEndPoints().size(), 2U); } +} -- GitLab