未验证 提交 c48c586a 编写于 作者: Y yuyang18

Use weak_ptr to implement DecoratedReaderChain

上级 2bbe5f77
......@@ -19,15 +19,11 @@ namespace paddle {
namespace framework {
ReaderBase::~ReaderBase() {}
void ReaderBase::InsertDecoratedReader(ReaderBase *decorated_reader) {
decorated_readers_.emplace(decorated_reader);
}
void ReaderBase::EraseDecoratedReader(ReaderBase *decorated_reader) {
auto it = decorated_readers_.find(decorated_reader);
PADDLE_ENFORCE(it != decorated_readers_.end(),
"Cannot find the decorated reader to erase");
decorated_readers_.erase(it);
void ReaderBase::InsertDecoratedReader(
const std::shared_ptr<ReaderBase> &decorated_reader) {
decorated_readers_.emplace_back(decorated_reader);
}
std::unordered_set<ReaderBase *> ReaderBase::GetEndPoints() {
std::unordered_set<ReaderBase *> result;
std::deque<ReaderBase *> queue;
......@@ -38,8 +34,10 @@ std::unordered_set<ReaderBase *> ReaderBase::GetEndPoints() {
if (front->decorated_readers_.empty()) {
result.emplace(front);
} else {
for (ReaderBase *reader : front->decorated_readers_) {
queue.emplace_back(reader);
for (auto &reader : front->decorated_readers_) {
if (auto *reader_ptr = reader.lock().get()) {
queue.emplace_back(reader_ptr);
}
}
}
}
......@@ -66,6 +64,5 @@ void FileReader::ReadNext(std::vector<LoDTensor> *out) {
}
}
}
DecoratedReader::~DecoratedReader() { reader_->EraseDecoratedReader(this); }
} // namespace framework
} // namespace paddle
......@@ -41,24 +41,26 @@ class ReaderBase {
friend class DecoratedReader;
// These methods can be only invoked inside DecoratedReader to record the
// decorating chain.
void InsertDecoratedReader(ReaderBase* decorated_reader);
void EraseDecoratedReader(ReaderBase* decorated_reader);
void InsertDecoratedReader(
const std::shared_ptr<ReaderBase>& decorated_reader);
// A set of which readers that decorated this reader.
std::unordered_set<ReaderBase*> decorated_readers_;
std::vector<std::weak_ptr<ReaderBase>> decorated_readers_;
};
class DecoratedReader : public ReaderBase {
class DecoratedReader : public ReaderBase,
public std::enable_shared_from_this<DecoratedReader> {
public:
explicit DecoratedReader(const std::shared_ptr<ReaderBase>& reader)
: ReaderBase(), reader_(reader) {
PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->InsertDecoratedReader(this);
}
~DecoratedReader();
void ReInit() override { reader_->ReInit(); }
void RegisterDecorateChain() {
reader_->InsertDecoratedReader(shared_from_this());
}
protected:
std::shared_ptr<ReaderBase> reader_;
};
......@@ -80,9 +82,14 @@ class FileReader : public ReaderBase {
// making it easier to access different type reader in Variables.
class ReaderHolder {
public:
void Reset(ReaderBase* reader) { reader_.reset(reader); }
template <typename T>
void Reset(const std::shared_ptr<T>& reader) {
auto reader_base = std::dynamic_pointer_cast<ReaderBase>(reader);
PADDLE_ENFORCE_NOT_NULL(reader_base);
reader_ = reader_base;
}
std::shared_ptr<ReaderBase> Get() const { return reader_; }
const std::shared_ptr<ReaderBase>& Get() const { return reader_; }
void ReadNext(std::vector<LoDTensor>* out) {
PADDLE_ENFORCE_NOT_NULL(reader_);
......@@ -93,9 +100,18 @@ class ReaderHolder {
reader_->ReInit();
}
operator const std::shared_ptr<ReaderBase>&() const { return this->reader_; }
private:
std::shared_ptr<ReaderBase> reader_;
};
template <typename T, typename... ARGS>
inline std::shared_ptr<DecoratedReader> MakeDecoratedReader(ARGS&&... args) {
std::shared_ptr<DecoratedReader> reader(new T(std::forward<ARGS>(args)...));
reader->RegisterDecorateChain();
return reader;
}
} // namespace framework
} // namespace paddle
......@@ -32,18 +32,21 @@ class StubRootReader : public paddle::framework::ReaderBase {
TEST(READER, decorate_chain) {
auto root = std::make_shared<StubRootReader>();
auto end_point1 = StubDecoratedReader(root);
auto end_point2 = StubDecoratedReader(root);
auto end_point1 =
paddle::framework::MakeDecoratedReader<StubDecoratedReader>(root);
auto end_point2 =
paddle::framework::MakeDecoratedReader<StubDecoratedReader>(root);
{
auto endpoints = root->GetEndPoints();
ASSERT_EQ(endpoints.size(), 2U);
ASSERT_NE(endpoints.count(&end_point1), 0);
ASSERT_NE(endpoints.count(&end_point2), 0);
ASSERT_NE(endpoints.count(end_point1.get()), 0);
ASSERT_NE(endpoints.count(end_point2.get()), 0);
}
{
auto end_point3 = StubDecoratedReader(root);
auto end_point3 =
paddle::framework::MakeDecoratedReader<StubDecoratedReader>(root);
ASSERT_EQ(root->GetEndPoints().size(), 3U);
}
{ ASSERT_EQ(root->GetEndPoints().size(), 2U); }
......
......@@ -46,8 +46,8 @@ class CreateBatchReaderOp : public framework::OperatorBase {
}
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
out->Reset(
new BatchReader(underlying_reader.Get(), Attr<int>("batch_size")));
out->Reset(framework::MakeDecoratedReader<BatchReader>(
underlying_reader, Attr<int>("batch_size")));
}
};
......
......@@ -60,10 +60,10 @@ class CreateCustomReaderOp : public framework::OperatorBase {
}
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
out->Reset(
new CustomReader(underlying_reader.Get(), *sub_block,
Attr<std::vector<std::string>>("source_var_names"),
Attr<std::vector<std::string>>("sink_var_names")));
out->Reset(framework::MakeDecoratedReader<CustomReader>(
underlying_reader, *sub_block,
Attr<std::vector<std::string>>("source_var_names"),
Attr<std::vector<std::string>>("sink_var_names")));
}
};
......
......@@ -109,7 +109,8 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
place = platform::CUDAPlace(static_cast<int>(num));
}
out->Reset(new DoubleBufferReader(underlying_reader.Get(), place));
out->Reset(framework::MakeDecoratedReader<DoubleBufferReader>(
underlying_reader, place));
}
};
......
......@@ -60,7 +60,8 @@ class CreateMultiPassReaderOp : public framework::OperatorBase {
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
int pass_num = Attr<int>("pass_num");
out->Reset(new MultiPassReader(underlying_reader.Get(), pass_num));
out->Reset(framework::MakeDecoratedReader<MultiPassReader>(
underlying_reader, pass_num));
}
};
......
......@@ -58,7 +58,7 @@ class CreatePyReaderOp : public framework::OperatorBase {
auto* queue_holder =
queue_holder_var->template GetMutable<LoDTensorBlockingQueueHolder>();
out->Reset(new PyReader(queue_holder->GetQueue()));
out->Reset(std::make_shared<PyReader>(queue_holder->GetQueue()));
}
};
......
......@@ -79,8 +79,8 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase {
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new RandomDataGenerator<T>(shapes, Attr<float>("low"),
Attr<float>("high")));
out->Reset(std::make_shared<RandomDataGenerator<T>>(
shapes, Attr<float>("low"), Attr<float>("high")));
}
};
......
......@@ -70,7 +70,7 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new RecordIOFileReader<true>(
out->Reset(std::make_shared<RecordIOFileReader<true>>(
filename, RestoreShapes(shape_concat, ranks)));
}
};
......
......@@ -86,9 +86,8 @@ class CreateShuffleReaderOp : public framework::OperatorBase {
}
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
out->Reset(
new ShuffleReader(underlying_reader.Get(),
static_cast<size_t>(Attr<int>("buffer_size"))));
out->Reset(framework::MakeDecoratedReader<ShuffleReader>(
underlying_reader, static_cast<size_t>(Attr<int>("buffer_size"))));
}
};
......
......@@ -49,7 +49,8 @@ class CreateThreadedReaderOp : public framework::OperatorBase {
}
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
out->Reset(new ThreadedReader(underlying_reader.Get()));
out->Reset(
framework::MakeDecoratedReader<ThreadedReader>(underlying_reader));
}
};
......
......@@ -180,9 +180,9 @@ class OpenFilesOp : public framework::OperatorBase {
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new MultiFileReader(file_names,
RestoreShapes(shape_concat, ranks),
thread_num, buffer_size));
out->Reset(std::make_shared<MultiFileReader>(
file_names, RestoreShapes(shape_concat, ranks), thread_num,
buffer_size));
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册