提交 6d6f49cd 编写于 作者: F fengjiayi

Merge remote-tracking branch 'yuyang/feature/decorated_reader_chain' into dev_reader_ResetAll

...@@ -27,6 +27,7 @@ cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory) ...@@ -27,6 +27,7 @@ cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory)
nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor) nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)
cc_library(reader SRCS reader.cc DEPS lod_tensor ddim) cc_library(reader SRCS reader.cc DEPS lod_tensor ddim)
cc_test(reader_test SRCS reader_test.cc DEPS reader)
cc_test(variable_test SRCS variable_test.cc) cc_test(variable_test SRCS variable_test.cc)
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/reader.h"
#include <deque>
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -23,6 +24,33 @@ void ReaderBase::ReadNext(std::vector<LoDTensor> *out) { ...@@ -23,6 +24,33 @@ void ReaderBase::ReadNext(std::vector<LoDTensor> *out) {
ReadNextImpl(out); ReadNextImpl(out);
} }
void ReaderBase::InsertDecoratedReader(
const std::shared_ptr<ReaderBase> &decorated_reader) {
std::lock_guard<std::mutex> guard(mu_));
decorated_readers_.emplace_back(decorated_reader);
}
std::unordered_set<ReaderBase *> ReaderBase::GetEndPoints() {
std::unordered_set<ReaderBase *> result;
std::deque<ReaderBase *> queue;
queue.emplace_back(this);
while (!queue.empty()) { // BFS search
auto *front = queue.front();
queue.pop_front();
if (front->decorated_readers_.empty()) {
result.emplace(front);
} else {
for (auto &reader : front->decorated_readers_) {
if (auto *reader_ptr = reader.lock().get()) {
queue.emplace_back(reader_ptr);
}
}
}
}
return result;
}
void ReaderBase::Shutdown() { void ReaderBase::Shutdown() {
std::lock_guard<std::mutex> lock(mu_); std::lock_guard<std::mutex> lock(mu_);
if (status_ != ReaderStatus::kStopped) { if (status_ != ReaderStatus::kStopped) {
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include <memory> #include <memory>
#include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/ddim.h"
...@@ -34,6 +35,10 @@ class ReaderBase { ...@@ -34,6 +35,10 @@ class ReaderBase {
void Start(); void Start();
// Return the readers which are the end of decorating chain. Basically
// they are readers just before read op.
std::unordered_set<ReaderBase*> GetEndPoints();
virtual ~ReaderBase(); virtual ~ReaderBase();
protected: protected:
...@@ -46,15 +51,29 @@ class ReaderBase { ...@@ -46,15 +51,29 @@ class ReaderBase {
ReaderStatus status_{kRunning}; ReaderStatus status_{kRunning};
mutable std::mutex mu_; mutable std::mutex mu_;
private:
friend class DecoratedReader;
// These methods can be only invoked inside DecoratedReader to record the
// decorating chain.
void InsertDecoratedReader(
const std::shared_ptr<ReaderBase>& decorated_reader);
// A set of which readers that decorated this reader.
std::vector<std::weak_ptr<ReaderBase>> decorated_readers_;
}; };
class DecoratedReader : public ReaderBase { class DecoratedReader : public ReaderBase,
public std::enable_shared_from_this<DecoratedReader> {
public: public:
explicit DecoratedReader(const std::shared_ptr<ReaderBase>& reader) explicit DecoratedReader(const std::shared_ptr<ReaderBase>& reader)
: ReaderBase(), reader_(reader) { : ReaderBase(), reader_(reader) {
PADDLE_ENFORCE_NOT_NULL(reader_); PADDLE_ENFORCE_NOT_NULL(reader_);
} }
void RegisterDecorateChain() {
reader_->InsertDecoratedReader(shared_from_this());
}
protected: protected:
void ShutdownImpl() override { reader_->Shutdown(); } void ShutdownImpl() override { reader_->Shutdown(); }
...@@ -70,9 +89,14 @@ class FileReader : public ReaderBase {}; ...@@ -70,9 +89,14 @@ class FileReader : public ReaderBase {};
// making it easier to access different type reader in Variables. // making it easier to access different type reader in Variables.
class ReaderHolder { class ReaderHolder {
public: public:
void Reset(ReaderBase* reader) { reader_.reset(reader); } template <typename T>
void Reset(const std::shared_ptr<T>& reader) {
auto reader_base = std::dynamic_pointer_cast<ReaderBase>(reader);
PADDLE_ENFORCE_NOT_NULL(reader_base);
reader_ = reader_base;
}
std::shared_ptr<ReaderBase> Get() const { return reader_; } const std::shared_ptr<ReaderBase>& Get() const { return reader_; }
void ReadNext(std::vector<LoDTensor>* out) { void ReadNext(std::vector<LoDTensor>* out) {
PADDLE_ENFORCE_NOT_NULL(reader_); PADDLE_ENFORCE_NOT_NULL(reader_);
...@@ -93,9 +117,18 @@ class ReaderHolder { ...@@ -93,9 +117,18 @@ class ReaderHolder {
reader_->Start(); reader_->Start();
} }
operator const std::shared_ptr<ReaderBase>&() const { return this->reader_; }
private: private:
std::shared_ptr<ReaderBase> reader_; std::shared_ptr<ReaderBase> reader_;
}; };
template <typename T, typename... ARGS>
inline std::shared_ptr<DecoratedReader> MakeDecoratedReader(ARGS&&... args) {
std::shared_ptr<DecoratedReader> reader(new T(std::forward<ARGS>(args)...));
reader->RegisterDecorateChain();
return reader;
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/reader.h"
#include <memory>
#include "gtest/gtest.h"
class StubDecoratedReader : public paddle::framework::DecoratedReader {
public:
explicit StubDecoratedReader(const std::shared_ptr<ReaderBase> &reader)
: DecoratedReader(reader) {}
void ReadNext(std::vector<paddle::framework::LoDTensor> *out) override {}
};
class StubRootReader : public paddle::framework::ReaderBase {
public:
void ReadNext(std::vector<paddle::framework::LoDTensor> *out) override {}
void ReInit() override {}
};
TEST(READER, decorate_chain) {
auto root = std::make_shared<StubRootReader>();
auto end_point1 =
paddle::framework::MakeDecoratedReader<StubDecoratedReader>(root);
auto end_point2 =
paddle::framework::MakeDecoratedReader<StubDecoratedReader>(root);
{
auto endpoints = root->GetEndPoints();
ASSERT_EQ(endpoints.size(), 2U);
ASSERT_NE(endpoints.count(end_point1.get()), 0);
ASSERT_NE(endpoints.count(end_point2.get()), 0);
}
{
auto end_point3 =
paddle::framework::MakeDecoratedReader<StubDecoratedReader>(root);
ASSERT_EQ(root->GetEndPoints().size(), 3U);
}
{ ASSERT_EQ(root->GetEndPoints().size(), 2U); }
}
...@@ -50,8 +50,9 @@ class CreateBatchReaderOp : public framework::OperatorBase { ...@@ -50,8 +50,9 @@ class CreateBatchReaderOp : public framework::OperatorBase {
} }
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>(); ->Get<framework::ReaderHolder>();
out->Reset(new BatchReader(underlying_reader.Get(), Attr<int>("batch_size"), out->Reset(framework::MakeDecoratedReader<BatchReader>(
Attr<bool>("discard_leftover"))); underlying_reader, Attr<int>("batch_size"),
Attr<bool>("discard_leftover")));
} }
}; };
......
...@@ -60,10 +60,10 @@ class CreateCustomReaderOp : public framework::OperatorBase { ...@@ -60,10 +60,10 @@ class CreateCustomReaderOp : public framework::OperatorBase {
} }
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>(); ->Get<framework::ReaderHolder>();
out->Reset( out->Reset(framework::MakeDecoratedReader<CustomReader>(
new CustomReader(underlying_reader.Get(), *sub_block, underlying_reader, *sub_block,
Attr<std::vector<std::string>>("source_var_names"), Attr<std::vector<std::string>>("source_var_names"),
Attr<std::vector<std::string>>("sink_var_names"))); Attr<std::vector<std::string>>("sink_var_names")));
} }
}; };
......
...@@ -118,7 +118,8 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase { ...@@ -118,7 +118,8 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
place = platform::CUDAPlace(static_cast<int>(num)); place = platform::CUDAPlace(static_cast<int>(num));
} }
out->Reset(new DoubleBufferReader(underlying_reader.Get(), place)); out->Reset(framework::MakeDecoratedReader<DoubleBufferReader>(
underlying_reader, place));
} }
}; };
......
...@@ -59,7 +59,8 @@ class CreateMultiPassReaderOp : public framework::OperatorBase { ...@@ -59,7 +59,8 @@ class CreateMultiPassReaderOp : public framework::OperatorBase {
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>(); ->Get<framework::ReaderHolder>();
int pass_num = Attr<int>("pass_num"); int pass_num = Attr<int>("pass_num");
out->Reset(new MultiPassReader(underlying_reader.Get(), pass_num)); out->Reset(framework::MakeDecoratedReader<MultiPassReader>(
underlying_reader, pass_num));
} }
}; };
......
...@@ -63,7 +63,7 @@ class CreatePyReaderOp : public framework::OperatorBase { ...@@ -63,7 +63,7 @@ class CreatePyReaderOp : public framework::OperatorBase {
auto* queue_holder = auto* queue_holder =
queue_holder_var->template GetMutable<LoDTensorBlockingQueueHolder>(); queue_holder_var->template GetMutable<LoDTensorBlockingQueueHolder>();
out->Reset(new PyReader(queue_holder->GetQueue())); out->Reset(std::make_shared<PyReader>(queue_holder->GetQueue()));
} }
}; };
......
...@@ -77,8 +77,8 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase { ...@@ -77,8 +77,8 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase {
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks); std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
auto* out = scope.FindVar(Output("Out")) auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>(); ->template GetMutable<framework::ReaderHolder>();
out->Reset(new RandomDataGenerator<T>(shapes, Attr<float>("low"), out->Reset(std::make_shared<RandomDataGenerator<T>>(
Attr<float>("high"))); shapes, Attr<float>("low"), Attr<float>("high")));
} }
}; };
......
...@@ -59,7 +59,8 @@ class CreateRecordIOReaderOp : public framework::OperatorBase { ...@@ -59,7 +59,8 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
std::string filename = Attr<std::string>("filename"); std::string filename = Attr<std::string>("filename");
auto* out = scope.FindVar(Output("Out")) auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>(); ->template GetMutable<framework::ReaderHolder>();
out->Reset(new RecordIOFileReader<true>(filename));
out->Reset(std::make_shared<RecordIOFileReader<true>>(filename));
} }
}; };
......
...@@ -97,9 +97,8 @@ class CreateShuffleReaderOp : public framework::OperatorBase { ...@@ -97,9 +97,8 @@ class CreateShuffleReaderOp : public framework::OperatorBase {
} }
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>(); ->Get<framework::ReaderHolder>();
out->Reset( out->Reset(framework::MakeDecoratedReader<ShuffleReader>(
new ShuffleReader(underlying_reader.Get(), underlying_reader, static_cast<size_t>(Attr<int>("buffer_size"))));
static_cast<size_t>(Attr<int>("buffer_size"))));
} }
}; };
......
...@@ -178,7 +178,8 @@ class OpenFilesOp : public framework::OperatorBase { ...@@ -178,7 +178,8 @@ class OpenFilesOp : public framework::OperatorBase {
auto* out = scope.FindVar(Output("Out")) auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>(); ->template GetMutable<framework::ReaderHolder>();
out->Reset(new MultiFileReader(file_names, thread_num, buffer_size)); out->Reset(
std::make_shared<MultiFileReader>(file_names, thread_num, buffer_size));
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册