未验证 提交 2bbe5f77 编写于 作者: Y yuyang18

Add GetEndPoints of Reader.

We can get endpoints of a reader chain.
上级 d3a48484
...@@ -27,6 +27,7 @@ cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory) ...@@ -27,6 +27,7 @@ cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory)
nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor) nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)
cc_library(reader SRCS reader.cc DEPS lod_tensor ddim) cc_library(reader SRCS reader.cc DEPS lod_tensor ddim)
cc_test(reader_test SRCS reader_test.cc DEPS reader)
cc_test(variable_test SRCS variable_test.cc) cc_test(variable_test SRCS variable_test.cc)
......
...@@ -13,11 +13,40 @@ ...@@ -13,11 +13,40 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/reader.h"
#include <deque>
namespace paddle { namespace paddle {
namespace framework { namespace framework {
ReaderBase::~ReaderBase() {} ReaderBase::~ReaderBase() {}
void ReaderBase::InsertDecoratedReader(ReaderBase *decorated_reader) {
decorated_readers_.emplace(decorated_reader);
}
void ReaderBase::EraseDecoratedReader(ReaderBase *decorated_reader) {
auto it = decorated_readers_.find(decorated_reader);
PADDLE_ENFORCE(it != decorated_readers_.end(),
"Cannot find the decorated reader to erase");
decorated_readers_.erase(it);
}
std::unordered_set<ReaderBase *> ReaderBase::GetEndPoints() {
std::unordered_set<ReaderBase *> result;
std::deque<ReaderBase *> queue;
queue.emplace_back(this);
while (!queue.empty()) { // BFS search
auto *front = queue.front();
queue.pop_front();
if (front->decorated_readers_.empty()) {
result.emplace(front);
} else {
for (ReaderBase *reader : front->decorated_readers_) {
queue.emplace_back(reader);
}
}
}
return result;
}
FileReader::FileReader(const std::vector<DDim> &dims) : dims_(dims) {} FileReader::FileReader(const std::vector<DDim> &dims) : dims_(dims) {}
void FileReader::ReadNext(std::vector<LoDTensor> *out) { void FileReader::ReadNext(std::vector<LoDTensor> *out) {
...@@ -37,5 +66,6 @@ void FileReader::ReadNext(std::vector<LoDTensor> *out) { ...@@ -37,5 +66,6 @@ void FileReader::ReadNext(std::vector<LoDTensor> *out) {
} }
} }
} }
DecoratedReader::~DecoratedReader() { reader_->EraseDecoratedReader(this); }
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include <memory> #include <memory>
#include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/ddim.h"
...@@ -31,6 +32,19 @@ class ReaderBase { ...@@ -31,6 +32,19 @@ class ReaderBase {
virtual void ReInit() = 0; virtual void ReInit() = 0;
virtual ~ReaderBase(); virtual ~ReaderBase();
// Return the readers which are the end of decorating chain. Basically
// they are readers just before read op.
std::unordered_set<ReaderBase*> GetEndPoints();
private:
friend class DecoratedReader;
// These methods can be only invoked inside DecoratedReader to record the
// decorating chain.
void InsertDecoratedReader(ReaderBase* decorated_reader);
void EraseDecoratedReader(ReaderBase* decorated_reader);
// A set of which readers that decorated this reader.
std::unordered_set<ReaderBase*> decorated_readers_;
}; };
class DecoratedReader : public ReaderBase { class DecoratedReader : public ReaderBase {
...@@ -38,8 +52,11 @@ class DecoratedReader : public ReaderBase { ...@@ -38,8 +52,11 @@ class DecoratedReader : public ReaderBase {
explicit DecoratedReader(const std::shared_ptr<ReaderBase>& reader) explicit DecoratedReader(const std::shared_ptr<ReaderBase>& reader)
: ReaderBase(), reader_(reader) { : ReaderBase(), reader_(reader) {
PADDLE_ENFORCE_NOT_NULL(reader_); PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->InsertDecoratedReader(this);
} }
~DecoratedReader();
void ReInit() override { reader_->ReInit(); } void ReInit() override { reader_->ReInit(); }
protected: protected:
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/reader.h"
#include <memory>
#include "gtest/gtest.h"
class StubDecoratedReader : public paddle::framework::DecoratedReader {
public:
explicit StubDecoratedReader(const std::shared_ptr<ReaderBase> &reader)
: DecoratedReader(reader) {}
void ReadNext(std::vector<paddle::framework::LoDTensor> *out) override {}
};
class StubRootReader : public paddle::framework::ReaderBase {
public:
void ReadNext(std::vector<paddle::framework::LoDTensor> *out) override {}
void ReInit() override {}
};
TEST(READER, decorate_chain) {
auto root = std::make_shared<StubRootReader>();
auto end_point1 = StubDecoratedReader(root);
auto end_point2 = StubDecoratedReader(root);
{
auto endpoints = root->GetEndPoints();
ASSERT_EQ(endpoints.size(), 2U);
ASSERT_NE(endpoints.count(&end_point1), 0);
ASSERT_NE(endpoints.count(&end_point2), 0);
}
{
auto end_point3 = StubDecoratedReader(root);
ASSERT_EQ(root->GetEndPoints().size(), 3U);
}
{ ASSERT_EQ(root->GetEndPoints().size(), 2U); }
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册