未验证 提交 c48c586a 编写于 作者: Y yuyang18

Use weak_ptr to implement DecoratedReaderChain

上级 2bbe5f77
...@@ -19,15 +19,11 @@ namespace paddle { ...@@ -19,15 +19,11 @@ namespace paddle {
namespace framework { namespace framework {
ReaderBase::~ReaderBase() {} ReaderBase::~ReaderBase() {}
void ReaderBase::InsertDecoratedReader(ReaderBase *decorated_reader) { void ReaderBase::InsertDecoratedReader(
decorated_readers_.emplace(decorated_reader); const std::shared_ptr<ReaderBase> &decorated_reader) {
} decorated_readers_.emplace_back(decorated_reader);
void ReaderBase::EraseDecoratedReader(ReaderBase *decorated_reader) {
auto it = decorated_readers_.find(decorated_reader);
PADDLE_ENFORCE(it != decorated_readers_.end(),
"Cannot find the decorated reader to erase");
decorated_readers_.erase(it);
} }
std::unordered_set<ReaderBase *> ReaderBase::GetEndPoints() { std::unordered_set<ReaderBase *> ReaderBase::GetEndPoints() {
std::unordered_set<ReaderBase *> result; std::unordered_set<ReaderBase *> result;
std::deque<ReaderBase *> queue; std::deque<ReaderBase *> queue;
...@@ -38,8 +34,10 @@ std::unordered_set<ReaderBase *> ReaderBase::GetEndPoints() { ...@@ -38,8 +34,10 @@ std::unordered_set<ReaderBase *> ReaderBase::GetEndPoints() {
if (front->decorated_readers_.empty()) { if (front->decorated_readers_.empty()) {
result.emplace(front); result.emplace(front);
} else { } else {
for (ReaderBase *reader : front->decorated_readers_) { for (auto &reader : front->decorated_readers_) {
queue.emplace_back(reader); if (auto *reader_ptr = reader.lock().get()) {
queue.emplace_back(reader_ptr);
}
} }
} }
} }
...@@ -66,6 +64,5 @@ void FileReader::ReadNext(std::vector<LoDTensor> *out) { ...@@ -66,6 +64,5 @@ void FileReader::ReadNext(std::vector<LoDTensor> *out) {
} }
} }
} }
DecoratedReader::~DecoratedReader() { reader_->EraseDecoratedReader(this); }
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -41,24 +41,26 @@ class ReaderBase { ...@@ -41,24 +41,26 @@ class ReaderBase {
friend class DecoratedReader; friend class DecoratedReader;
// These methods can be only invoked inside DecoratedReader to record the // These methods can be only invoked inside DecoratedReader to record the
// decorating chain. // decorating chain.
void InsertDecoratedReader(ReaderBase* decorated_reader); void InsertDecoratedReader(
void EraseDecoratedReader(ReaderBase* decorated_reader); const std::shared_ptr<ReaderBase>& decorated_reader);
// A set of which readers that decorated this reader. // A set of which readers that decorated this reader.
std::unordered_set<ReaderBase*> decorated_readers_; std::vector<std::weak_ptr<ReaderBase>> decorated_readers_;
}; };
class DecoratedReader : public ReaderBase { class DecoratedReader : public ReaderBase,
public std::enable_shared_from_this<DecoratedReader> {
public: public:
explicit DecoratedReader(const std::shared_ptr<ReaderBase>& reader) explicit DecoratedReader(const std::shared_ptr<ReaderBase>& reader)
: ReaderBase(), reader_(reader) { : ReaderBase(), reader_(reader) {
PADDLE_ENFORCE_NOT_NULL(reader_); PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->InsertDecoratedReader(this);
} }
~DecoratedReader();
void ReInit() override { reader_->ReInit(); } void ReInit() override { reader_->ReInit(); }
void RegisterDecorateChain() {
reader_->InsertDecoratedReader(shared_from_this());
}
protected: protected:
std::shared_ptr<ReaderBase> reader_; std::shared_ptr<ReaderBase> reader_;
}; };
...@@ -80,9 +82,14 @@ class FileReader : public ReaderBase { ...@@ -80,9 +82,14 @@ class FileReader : public ReaderBase {
// making it easier to access different type reader in Variables. // making it easier to access different type reader in Variables.
class ReaderHolder { class ReaderHolder {
public: public:
void Reset(ReaderBase* reader) { reader_.reset(reader); } template <typename T>
void Reset(const std::shared_ptr<T>& reader) {
auto reader_base = std::dynamic_pointer_cast<ReaderBase>(reader);
PADDLE_ENFORCE_NOT_NULL(reader_base);
reader_ = reader_base;
}
std::shared_ptr<ReaderBase> Get() const { return reader_; } const std::shared_ptr<ReaderBase>& Get() const { return reader_; }
void ReadNext(std::vector<LoDTensor>* out) { void ReadNext(std::vector<LoDTensor>* out) {
PADDLE_ENFORCE_NOT_NULL(reader_); PADDLE_ENFORCE_NOT_NULL(reader_);
...@@ -93,9 +100,18 @@ class ReaderHolder { ...@@ -93,9 +100,18 @@ class ReaderHolder {
reader_->ReInit(); reader_->ReInit();
} }
operator const std::shared_ptr<ReaderBase>&() const { return this->reader_; }
private: private:
std::shared_ptr<ReaderBase> reader_; std::shared_ptr<ReaderBase> reader_;
}; };
template <typename T, typename... ARGS>
inline std::shared_ptr<DecoratedReader> MakeDecoratedReader(ARGS&&... args) {
std::shared_ptr<DecoratedReader> reader(new T(std::forward<ARGS>(args)...));
reader->RegisterDecorateChain();
return reader;
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -32,18 +32,21 @@ class StubRootReader : public paddle::framework::ReaderBase { ...@@ -32,18 +32,21 @@ class StubRootReader : public paddle::framework::ReaderBase {
TEST(READER, decorate_chain) { TEST(READER, decorate_chain) {
auto root = std::make_shared<StubRootReader>(); auto root = std::make_shared<StubRootReader>();
auto end_point1 = StubDecoratedReader(root); auto end_point1 =
auto end_point2 = StubDecoratedReader(root); paddle::framework::MakeDecoratedReader<StubDecoratedReader>(root);
auto end_point2 =
paddle::framework::MakeDecoratedReader<StubDecoratedReader>(root);
{ {
auto endpoints = root->GetEndPoints(); auto endpoints = root->GetEndPoints();
ASSERT_EQ(endpoints.size(), 2U); ASSERT_EQ(endpoints.size(), 2U);
ASSERT_NE(endpoints.count(&end_point1), 0); ASSERT_NE(endpoints.count(end_point1.get()), 0);
ASSERT_NE(endpoints.count(&end_point2), 0); ASSERT_NE(endpoints.count(end_point2.get()), 0);
} }
{ {
auto end_point3 = StubDecoratedReader(root); auto end_point3 =
paddle::framework::MakeDecoratedReader<StubDecoratedReader>(root);
ASSERT_EQ(root->GetEndPoints().size(), 3U); ASSERT_EQ(root->GetEndPoints().size(), 3U);
} }
{ ASSERT_EQ(root->GetEndPoints().size(), 2U); } { ASSERT_EQ(root->GetEndPoints().size(), 2U); }
......
...@@ -46,8 +46,8 @@ class CreateBatchReaderOp : public framework::OperatorBase { ...@@ -46,8 +46,8 @@ class CreateBatchReaderOp : public framework::OperatorBase {
} }
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>(); ->Get<framework::ReaderHolder>();
out->Reset( out->Reset(framework::MakeDecoratedReader<BatchReader>(
new BatchReader(underlying_reader.Get(), Attr<int>("batch_size"))); underlying_reader, Attr<int>("batch_size")));
} }
}; };
......
...@@ -60,8 +60,8 @@ class CreateCustomReaderOp : public framework::OperatorBase { ...@@ -60,8 +60,8 @@ class CreateCustomReaderOp : public framework::OperatorBase {
} }
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>(); ->Get<framework::ReaderHolder>();
out->Reset( out->Reset(framework::MakeDecoratedReader<CustomReader>(
new CustomReader(underlying_reader.Get(), *sub_block, underlying_reader, *sub_block,
Attr<std::vector<std::string>>("source_var_names"), Attr<std::vector<std::string>>("source_var_names"),
Attr<std::vector<std::string>>("sink_var_names"))); Attr<std::vector<std::string>>("sink_var_names")));
} }
......
...@@ -109,7 +109,8 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase { ...@@ -109,7 +109,8 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
place = platform::CUDAPlace(static_cast<int>(num)); place = platform::CUDAPlace(static_cast<int>(num));
} }
out->Reset(new DoubleBufferReader(underlying_reader.Get(), place)); out->Reset(framework::MakeDecoratedReader<DoubleBufferReader>(
underlying_reader, place));
} }
}; };
......
...@@ -60,7 +60,8 @@ class CreateMultiPassReaderOp : public framework::OperatorBase { ...@@ -60,7 +60,8 @@ class CreateMultiPassReaderOp : public framework::OperatorBase {
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>(); ->Get<framework::ReaderHolder>();
int pass_num = Attr<int>("pass_num"); int pass_num = Attr<int>("pass_num");
out->Reset(new MultiPassReader(underlying_reader.Get(), pass_num)); out->Reset(framework::MakeDecoratedReader<MultiPassReader>(
underlying_reader, pass_num));
} }
}; };
......
...@@ -58,7 +58,7 @@ class CreatePyReaderOp : public framework::OperatorBase { ...@@ -58,7 +58,7 @@ class CreatePyReaderOp : public framework::OperatorBase {
auto* queue_holder = auto* queue_holder =
queue_holder_var->template GetMutable<LoDTensorBlockingQueueHolder>(); queue_holder_var->template GetMutable<LoDTensorBlockingQueueHolder>();
out->Reset(new PyReader(queue_holder->GetQueue())); out->Reset(std::make_shared<PyReader>(queue_holder->GetQueue()));
} }
}; };
......
...@@ -79,8 +79,8 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase { ...@@ -79,8 +79,8 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase {
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks); std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
auto* out = scope.FindVar(Output("Out")) auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>(); ->template GetMutable<framework::ReaderHolder>();
out->Reset(new RandomDataGenerator<T>(shapes, Attr<float>("low"), out->Reset(std::make_shared<RandomDataGenerator<T>>(
Attr<float>("high"))); shapes, Attr<float>("low"), Attr<float>("high")));
} }
}; };
......
...@@ -70,7 +70,7 @@ class CreateRecordIOReaderOp : public framework::OperatorBase { ...@@ -70,7 +70,7 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
auto* out = scope.FindVar(Output("Out")) auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>(); ->template GetMutable<framework::ReaderHolder>();
out->Reset(new RecordIOFileReader<true>( out->Reset(std::make_shared<RecordIOFileReader<true>>(
filename, RestoreShapes(shape_concat, ranks))); filename, RestoreShapes(shape_concat, ranks)));
} }
}; };
......
...@@ -86,9 +86,8 @@ class CreateShuffleReaderOp : public framework::OperatorBase { ...@@ -86,9 +86,8 @@ class CreateShuffleReaderOp : public framework::OperatorBase {
} }
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>(); ->Get<framework::ReaderHolder>();
out->Reset( out->Reset(framework::MakeDecoratedReader<ShuffleReader>(
new ShuffleReader(underlying_reader.Get(), underlying_reader, static_cast<size_t>(Attr<int>("buffer_size"))));
static_cast<size_t>(Attr<int>("buffer_size"))));
} }
}; };
......
...@@ -49,7 +49,8 @@ class CreateThreadedReaderOp : public framework::OperatorBase { ...@@ -49,7 +49,8 @@ class CreateThreadedReaderOp : public framework::OperatorBase {
} }
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>(); ->Get<framework::ReaderHolder>();
out->Reset(new ThreadedReader(underlying_reader.Get())); out->Reset(
framework::MakeDecoratedReader<ThreadedReader>(underlying_reader));
} }
}; };
......
...@@ -180,9 +180,9 @@ class OpenFilesOp : public framework::OperatorBase { ...@@ -180,9 +180,9 @@ class OpenFilesOp : public framework::OperatorBase {
auto* out = scope.FindVar(Output("Out")) auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>(); ->template GetMutable<framework::ReaderHolder>();
out->Reset(new MultiFileReader(file_names, out->Reset(std::make_shared<MultiFileReader>(
RestoreShapes(shape_concat, ranks), file_names, RestoreShapes(shape_concat, ranks), thread_num,
thread_num, buffer_size)); buffer_size));
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册