未验证 提交 26ae6111 编写于 作者: F fengjiayi 提交者: GitHub

Merge pull request #12051 from JiayiFeng/dev_reader_ResetAll

[WIP] Dev reader reset all
......@@ -27,6 +27,7 @@ cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory)
nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)
cc_library(reader SRCS reader.cc DEPS lod_tensor ddim)
cc_test(reader_test SRCS reader_test.cc DEPS reader)
cc_test(variable_test SRCS variable_test.cc)
......
......@@ -13,29 +13,61 @@
// limitations under the License.
#include "paddle/fluid/framework/reader.h"
#include <deque>
namespace paddle {
namespace framework {
ReaderBase::~ReaderBase() {}
FileReader::FileReader(const std::vector<DDim> &dims) : dims_(dims) {}
void FileReader::ReadNext(std::vector<LoDTensor> *out) {
void ReaderBase::ReadNext(std::vector<LoDTensor> *out) {
std::lock_guard<std::mutex> lock(mu_);
PADDLE_ENFORCE_EQ(status_, ReaderStatus::kRunning);
ReadNextImpl(out);
if (out->empty()) {
return;
}
}
PADDLE_ENFORCE_EQ(out->size(), dims_.size());
for (size_t i = 0; i < dims_.size(); ++i) {
auto &actual = (*out)[i].dims();
auto &expect = dims_[i];
void ReaderBase::InsertDecoratedReader(
const std::shared_ptr<ReaderBase> &decorated_reader) {
std::lock_guard<std::mutex> guard(mu_);
decorated_readers_.emplace_back(decorated_reader);
}
PADDLE_ENFORCE_EQ(actual.size(), expect.size());
for (int j = 0; j < actual.size(); ++j) {
// PADDLE_ENFORCE(actual[i] == expect[i] || expect[i] == -1);
std::unordered_set<ReaderBase *> ReaderBase::GetEndPoints() {
std::unordered_set<ReaderBase *> result;
std::deque<ReaderBase *> queue;
queue.emplace_back(this);
while (!queue.empty()) { // BFS search
auto *front = queue.front();
queue.pop_front();
if (front->decorated_readers_.empty()) {
result.emplace(front);
} else {
for (auto &reader : front->decorated_readers_) {
if (auto *reader_ptr = reader.lock().get()) {
queue.emplace_back(reader_ptr);
}
}
}
}
return result;
}
void ReaderBase::Shutdown() {
std::lock_guard<std::mutex> lock(mu_);
if (status_ != ReaderStatus::kStopped) {
ShutdownImpl();
status_ = ReaderStatus::kStopped;
}
}
void ReaderBase::Start() {
std::lock_guard<std::mutex> lock(mu_);
if (status_ != ReaderStatus::kRunning) {
StartImpl();
status_ = ReaderStatus::kRunning;
}
}
ReaderBase::~ReaderBase() { Shutdown(); }
} // namespace framework
} // namespace paddle
......@@ -15,6 +15,7 @@
#pragma once
#include <memory>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
......@@ -24,61 +25,116 @@
namespace paddle {
namespace framework {
enum ReaderStatus { kRunning, kStopped };
class ReaderBase {
public:
virtual void ReadNext(std::vector<LoDTensor>* out) = 0;
void ReadNext(std::vector<LoDTensor>* out);
void Shutdown();
virtual void ReInit() = 0;
void Start();
// Return the readers which are the end of decorating chain. Basically
// they are readers just before read op.
std::unordered_set<ReaderBase*> GetEndPoints();
virtual ~ReaderBase();
protected:
virtual void ReadNextImpl(std::vector<LoDTensor>* out) = 0;
virtual void ShutdownImpl() {}
virtual void StartImpl() {}
ReaderStatus status_{kRunning};
mutable std::mutex mu_;
private:
friend class DecoratedReader;
// These methods can be only invoked inside DecoratedReader to record the
// decorating chain.
void InsertDecoratedReader(
const std::shared_ptr<ReaderBase>& decorated_reader);
// A set of which readers that decorated this reader.
std::vector<std::weak_ptr<ReaderBase>> decorated_readers_;
};
class DecoratedReader : public ReaderBase {
class DecoratedReader : public ReaderBase,
public std::enable_shared_from_this<DecoratedReader> {
public:
explicit DecoratedReader(const std::shared_ptr<ReaderBase>& reader)
: ReaderBase(), reader_(reader) {
PADDLE_ENFORCE_NOT_NULL(reader_);
}
void ReInit() override { reader_->ReInit(); }
void RegisterDecorateChain() {
reader_->InsertDecoratedReader(shared_from_this());
}
protected:
std::shared_ptr<ReaderBase> reader_;
};
class FileReader : public ReaderBase {
public:
explicit FileReader(const std::vector<DDim>& dims);
void ReadNext(std::vector<LoDTensor>* out) override;
void ShutdownImpl() override { reader_->Shutdown(); }
protected:
virtual void ReadNextImpl(std::vector<LoDTensor>* out) = 0;
void StartImpl() override { reader_->Start(); }
private:
std::vector<DDim> dims_;
std::shared_ptr<ReaderBase> reader_;
};
// FileReader is just a conceptual class.
class FileReader : public ReaderBase {};
// The ReaderHolder is used as reader' unified wrapper,
// making it easier to access different type reader in Variables.
class ReaderHolder {
public:
void Reset(ReaderBase* reader) { reader_.reset(reader); }
template <typename T>
void Reset(const std::shared_ptr<T>& reader) {
auto reader_base = std::dynamic_pointer_cast<ReaderBase>(reader);
PADDLE_ENFORCE_NOT_NULL(reader_base);
reader_ = reader_base;
}
std::shared_ptr<ReaderBase> Get() const { return reader_; }
const std::shared_ptr<ReaderBase>& Get() const { return reader_; }
void ReadNext(std::vector<LoDTensor>* out) {
PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->ReadNext(out);
}
void ReInit() {
void ResetAll() {
auto end_readers = reader_->GetEndPoints();
for (auto* reader : end_readers) {
reader->Shutdown();
}
for (auto* reader : end_readers) {
reader->Start();
}
}
void Shutdown() {
PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->Shutdown();
}
void Start() {
PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->ReInit();
reader_->Start();
}
operator const std::shared_ptr<ReaderBase>&() const { return this->reader_; }
private:
std::shared_ptr<ReaderBase> reader_;
};
template <typename T, typename... ARGS>
inline std::shared_ptr<DecoratedReader> MakeDecoratedReader(ARGS&&... args) {
std::shared_ptr<DecoratedReader> reader(new T(std::forward<ARGS>(args)...));
reader->RegisterDecorateChain();
return reader;
}
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/reader.h"
#include <memory>
#include "gtest/gtest.h"
class StubDecoratedReader : public paddle::framework::DecoratedReader {
public:
explicit StubDecoratedReader(const std::shared_ptr<ReaderBase> &reader)
: DecoratedReader(reader) {}
void ReadNextImpl(std::vector<paddle::framework::LoDTensor> *out) override {}
};
class StubRootReader : public paddle::framework::ReaderBase {
public:
void ReadNextImpl(std::vector<paddle::framework::LoDTensor> *out) override {}
};
TEST(READER, decorate_chain) {
auto root = std::make_shared<StubRootReader>();
auto end_point1 =
paddle::framework::MakeDecoratedReader<StubDecoratedReader>(root);
auto end_point2 =
paddle::framework::MakeDecoratedReader<StubDecoratedReader>(root);
{
auto endpoints = root->GetEndPoints();
ASSERT_EQ(endpoints.size(), 2U);
ASSERT_NE(endpoints.count(end_point1.get()), 0);
ASSERT_NE(endpoints.count(end_point2.get()), 0);
}
{
auto end_point3 =
paddle::framework::MakeDecoratedReader<StubDecoratedReader>(root);
ASSERT_EQ(root->GetEndPoints().size(), 3U);
}
{ ASSERT_EQ(root->GetEndPoints().size(), 2U); }
}
......@@ -22,7 +22,6 @@ reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc)
reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc)
reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc)
reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc)
reader_library(create_threaded_reader_op SRCS create_threaded_reader_op.cc)
reader_library(create_custom_reader_op SRCS create_custom_reader_op.cc)
reader_library(create_py_reader_op SRCS create_py_reader_op.cc)
......
......@@ -20,15 +20,19 @@ namespace reader {
class BatchReader : public framework::DecoratedReader {
public:
BatchReader(const std::shared_ptr<ReaderBase>& reader, int batch_size)
: DecoratedReader(reader), batch_size_(batch_size) {
BatchReader(const std::shared_ptr<ReaderBase>& reader, int batch_size,
bool discard_leftover)
: DecoratedReader(reader),
batch_size_(batch_size),
discard_leftover_(discard_leftover) {
buffer_.reserve(batch_size_);
}
void ReadNext(std::vector<framework::LoDTensor>* out) override;
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
private:
int batch_size_;
bool discard_leftover_;
std::vector<std::vector<framework::LoDTensor>> buffer_;
};
......@@ -46,8 +50,9 @@ class CreateBatchReaderOp : public framework::OperatorBase {
}
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
out->Reset(
new BatchReader(underlying_reader.Get(), Attr<int>("batch_size")));
out->Reset(framework::MakeDecoratedReader<BatchReader>(
underlying_reader, Attr<int>("batch_size"),
Attr<bool>("discard_leftover")));
}
};
......@@ -57,6 +62,10 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase {
AddAttr<int>("batch_size",
"How many instances the batch reader yields each time.")
.GreaterThan(0);
AddAttr<bool>("discard_leftover",
"If true, the leftover instances that are not enough for a "
"new batch will be discarded.")
.SetDefault(true);
AddComment(R"DOC(
CreateBatchReader Operator
......@@ -66,7 +75,7 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase {
}
};
void BatchReader::ReadNext(std::vector<framework::LoDTensor>* out) {
void BatchReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
buffer_.clear();
buffer_.reserve(batch_size_);
for (int i = 0; i < batch_size_; ++i) {
......@@ -77,6 +86,9 @@ void BatchReader::ReadNext(std::vector<framework::LoDTensor>* out) {
break;
}
}
if (discard_leftover_ && buffer_.size() < batch_size_) {
buffer_.clear();
}
// Concat instances
out->clear();
if (buffer_.empty()) {
......
......@@ -33,7 +33,7 @@ class CustomReader : public framework::DecoratedReader {
source_var_names_(source_var_names),
sink_var_names_(sink_var_names) {}
void ReadNext(std::vector<framework::LoDTensor>* out) override;
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
private:
const framework::ProgramDesc program_;
......@@ -60,10 +60,10 @@ class CreateCustomReaderOp : public framework::OperatorBase {
}
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
out->Reset(
new CustomReader(underlying_reader.Get(), *sub_block,
Attr<std::vector<std::string>>("source_var_names"),
Attr<std::vector<std::string>>("sink_var_names")));
out->Reset(framework::MakeDecoratedReader<CustomReader>(
underlying_reader, *sub_block,
Attr<std::vector<std::string>>("source_var_names"),
Attr<std::vector<std::string>>("sink_var_names")));
}
};
......@@ -143,7 +143,7 @@ class CustomReaderInferVarType : public framework::VarTypeInference {
}
};
void CustomReader::ReadNext(std::vector<framework::LoDTensor>* out) {
void CustomReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
out->clear();
std::vector<framework::LoDTensor> underlying_outs;
reader_->ReadNext(&underlying_outs);
......
......@@ -23,13 +23,13 @@ namespace reader {
// 'Double buffer' means we shall maintain two batches of input data at the same
// time. So the kCacheSize shoul be at least 2.
static constexpr size_t kCacheSize = 5;
static constexpr size_t kCacheSize = 3;
// There will be two bacthes out of the channel during training:
// 1. the one waiting to be sent to the channel
// 2. the one just be received from the channel, which is also being used by
// subsequent operators.
// So the channel size should be kChacheSize - 2
static constexpr size_t kChannelSize = 3; // kCacheSize - 2
static constexpr size_t kChannelSize = 1; // kCacheSize - 2
class DoubleBufferReader : public framework::DecoratedReader {
public:
......@@ -50,12 +50,21 @@ class DoubleBufferReader : public framework::DecoratedReader {
StartPrefetcher();
}
void ReadNext(std::vector<framework::LoDTensor>* out) override;
void ReInit() override;
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
~DoubleBufferReader() { EndPrefetcher(); }
private:
void ShutdownImpl() override {
EndPrefetcher();
reader_->Shutdown();
}
void StartImpl() override {
reader_->Start();
StartPrefetcher();
}
void StartPrefetcher() {
channel_ = new reader::BlockingQueue<size_t>(kChannelSize);
prefetcher_ = std::thread([this] { PrefetchThreadFunc(); });
......@@ -109,7 +118,8 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
place = platform::CUDAPlace(static_cast<int>(num));
}
out->Reset(new DoubleBufferReader(underlying_reader.Get(), place));
out->Reset(framework::MakeDecoratedReader<DoubleBufferReader>(
underlying_reader, place));
}
};
......@@ -136,7 +146,7 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
}
};
void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
void DoubleBufferReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
size_t cached_tensor_id;
if (channel_->Receive(&cached_tensor_id)) {
if (platform::is_gpu_place(place_)) {
......@@ -150,12 +160,6 @@ void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
}
}
void DoubleBufferReader::ReInit() {
reader_->ReInit();
EndPrefetcher();
StartPrefetcher();
}
void DoubleBufferReader::PrefetchThreadFunc() {
VLOG(5) << "A new prefetch thread starts.";
size_t cached_tensor_id = 0;
......
......@@ -24,23 +24,22 @@ class MultiPassReader : public framework::DecoratedReader {
MultiPassReader(const std::shared_ptr<ReaderBase>& reader, int pass_num)
: DecoratedReader(reader), pass_num_(pass_num), pass_count_(0) {}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
reader_->ReadNext(out);
if (out->empty()) {
if (out->empty() && pass_count_ < pass_num_ - 1) {
reader_->Shutdown();
reader_->Start();
reader_->ReadNext(out);
++pass_count_;
if (pass_count_ < pass_num_) {
reader_->ReInit();
reader_->ReadNext(out);
}
}
}
void ReInit() override {
private:
void StartImpl() override {
pass_count_ = 0;
reader_->ReInit();
reader_->Start();
}
private:
int pass_num_;
mutable int pass_count_;
};
......@@ -60,7 +59,8 @@ class CreateMultiPassReaderOp : public framework::OperatorBase {
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
int pass_num = Attr<int>("pass_num");
out->Reset(new MultiPassReader(underlying_reader.Get(), pass_num));
out->Reset(framework::MakeDecoratedReader<MultiPassReader>(
underlying_reader, pass_num));
}
};
......
......@@ -19,22 +19,27 @@ namespace paddle {
namespace operators {
namespace reader {
class PyReader : public framework::ReaderBase {
class PyReader : public framework::FileReader {
public:
explicit PyReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue) {
explicit PyReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue)
: framework::FileReader() {
PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null");
queue_ = queue;
}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
bool success;
*out = queue_->Pop(&success);
if (!success) out->clear();
}
void ReInit() override {}
private:
void ShutdownImpl() override { /* TODO */
}
void StartImpl() override { /* TODO */
}
std::shared_ptr<LoDTensorBlockingQueue> queue_;
};
......@@ -51,14 +56,14 @@ class CreatePyReaderOp : public framework::OperatorBase {
const std::string& queue_name = Input("blocking_queue");
auto* queue_holder_var = scope.FindVar(queue_name);
PADDLE_ENFORCE(
queue_holder_var != nullptr,
PADDLE_ENFORCE_NOT_NULL(
queue_holder_var,
"No LoDTensorBlockingQueueHolder variable with name %s found",
queue_name);
auto* queue_holder =
queue_holder_var->template GetMutable<LoDTensorBlockingQueueHolder>();
out->Reset(new PyReader(queue_holder->GetQueue()));
out->Reset(std::make_shared<PyReader>(queue_holder->GetQueue()));
}
};
......
......@@ -19,11 +19,11 @@ namespace operators {
namespace reader {
template <typename T>
class RandomDataGenerator : public framework::ReaderBase {
class RandomDataGenerator : public framework::FileReader {
public:
RandomDataGenerator(const std::vector<framework::DDim>& shapes, float low,
float high)
: framework::ReaderBase(), low_(low), high_(high), shapes_(shapes) {
: framework::FileReader(), low_(low), high_(high), shapes_(shapes) {
PADDLE_ENFORCE_LE(low, high,
"'low' shouldn't be greater than 'high'.(%f vs %f)", low,
high);
......@@ -32,7 +32,7 @@ class RandomDataGenerator : public framework::ReaderBase {
dist_ = std::uniform_real_distribution<float>(low_, high_);
}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
out->clear();
out->reserve(shapes_.size());
for (const framework::DDim& shape : shapes_) {
......@@ -51,8 +51,6 @@ class RandomDataGenerator : public framework::ReaderBase {
}
}
void ReInit() override { return; }
private:
float low_;
float high_;
......@@ -79,8 +77,8 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase {
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new RandomDataGenerator<T>(shapes, Attr<float>("low"),
Attr<float>("high")));
out->Reset(std::make_shared<RandomDataGenerator<T>>(
shapes, Attr<float>("low"), Attr<float>("high")));
}
};
......
......@@ -21,10 +21,8 @@ namespace reader {
template <bool ThreadSafe>
class RecordIOFileReader : public framework::FileReader {
public:
explicit RecordIOFileReader(const std::string& filename,
const std::vector<framework::DDim>& dims)
: FileReader(dims),
scanner_(filename),
explicit RecordIOFileReader(const std::string& filename)
: scanner_(filename),
dev_ctx_(*platform::DeviceContextPool::Instance().Get(
platform::CPUPlace())) {
if (ThreadSafe) {
......@@ -33,8 +31,6 @@ class RecordIOFileReader : public framework::FileReader {
LOG(INFO) << "Creating file reader" << filename;
}
void ReInit() override { scanner_.Reset(); }
protected:
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
if (ThreadSafe) {
......@@ -45,6 +41,8 @@ class RecordIOFileReader : public framework::FileReader {
}
}
void StartImpl() override { scanner_.Reset(); }
private:
std::unique_ptr<std::mutex> mutex_;
recordio::Scanner scanner_;
......@@ -58,20 +56,11 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
const auto& shape_concat = Attr<std::vector<int>>("shape_concat");
const auto& ranks = Attr<std::vector<int>>("ranks");
PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
static_cast<int>(shape_concat.size()),
"The accumulate of all ranks should be equal to the "
"shape concat's length.");
std::string filename = Attr<std::string>("filename");
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new RecordIOFileReader<true>(
filename, RestoreShapes(shape_concat, ranks)));
out->Reset(std::make_shared<RecordIOFileReader<true>>(filename));
}
};
......
......@@ -34,7 +34,7 @@ class ShuffleReader : public framework::DecoratedReader {
ReloadBuffer();
}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
out->clear();
if (iteration_pos_ >= buffer_.size()) {
VLOG(10) << "Resetting shuffle buffer";
......@@ -47,6 +47,17 @@ class ShuffleReader : public framework::DecoratedReader {
}
private:
void ShutdownImpl() override {
buffer_.clear();
iteration_pos_ = 0;
reader_->Shutdown();
}
void StartImpl() override {
reader_->Start();
ReloadBuffer();
}
void ReloadBuffer() {
buffer_.clear();
buffer_.reserve(buffer_size_);
......@@ -86,9 +97,8 @@ class CreateShuffleReaderOp : public framework::OperatorBase {
}
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
out->Reset(
new ShuffleReader(underlying_reader.Get(),
static_cast<size_t>(Attr<int>("buffer_size"))));
out->Reset(framework::MakeDecoratedReader<ShuffleReader>(
underlying_reader, static_cast<size_t>(Attr<int>("buffer_size"))));
}
};
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace paddle {
namespace operators {
namespace reader {
class ThreadedReader : public framework::DecoratedReader {
public:
explicit ThreadedReader(const std::shared_ptr<ReaderBase>& reader)
: DecoratedReader(reader) {}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
std::lock_guard<std::mutex> lock(mutex_);
reader_->ReadNext(out);
}
void ReInit() override { reader_->ReInit(); }
private:
std::mutex mutex_;
};
class CreateThreadedReaderOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
auto* out = detail::Ref(scope.FindVar(Output("Out")))
.GetMutable<framework::ReaderHolder>();
if (out->Get() != nullptr) {
return;
}
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
out->Reset(new ThreadedReader(underlying_reader.Get()));
}
};
class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase {
protected:
void Apply() override {
AddComment(R"DOC(
CreateThreadedReader Operator
This operator creates a threaded reader. A threaded reader's
'ReadNext()' can be invoked by several threads at the same
time.
When the attribute 'safe_mode' is true, the threaded reader's
'ReInit()' is disabled to avoid unexpected bugs in multi-thread
environment.
)DOC");
}
};
} // namespace reader
} // namespace operators
} // namespace paddle
namespace reader = paddle::operators::reader;
REGISTER_DECORATED_READER_OPERATOR(create_threaded_reader,
reader::CreateThreadedReaderOp,
reader::CreateThreadedReaderOpMaker);
......@@ -23,24 +23,26 @@ namespace reader {
class MultiFileReader : public framework::ReaderBase {
public:
MultiFileReader(const std::vector<std::string>& file_names,
const std::vector<framework::DDim>& dims, size_t thread_num,
MultiFileReader(const std::vector<std::string>& file_names, size_t thread_num,
size_t buffer_size)
: buffer_size_(buffer_size) {
readers_.reserve(file_names.size());
for (const std::string& f_name : file_names) {
readers_.emplace_back(CreateReaderByFileName(f_name, dims));
readers_.emplace_back(CreateReaderByFileName(f_name));
}
prefetchers_.resize(thread_num);
StartNewScheduler();
}
void ReadNext(std::vector<framework::LoDTensor>* out) override;
void ReInit() override;
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
~MultiFileReader() { EndScheduler(); }
private:
void ShutdownImpl() override { EndScheduler(); }
void StartImpl() override { StartNewScheduler(); }
void StartNewScheduler();
void EndScheduler();
void ScheduleThreadFunc();
......@@ -55,17 +57,12 @@ class MultiFileReader : public framework::ReaderBase {
reader::BlockingQueue<std::vector<framework::LoDTensor>>* buffer_;
};
void MultiFileReader::ReadNext(std::vector<framework::LoDTensor>* out) {
void MultiFileReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
if (!buffer_->Receive(out)) {
out->clear();
}
}
void MultiFileReader::ReInit() {
EndScheduler();
StartNewScheduler();
}
void MultiFileReader::StartNewScheduler() {
size_t thread_num = prefetchers_.size();
waiting_reader_idx_ = new reader::BlockingQueue<size_t>(readers_.size());
......@@ -120,7 +117,7 @@ void MultiFileReader::ScheduleThreadFunc() {
}
}
}
// If users invoke ReInit() when scheduler is running, it will close the
// If users invoke Shutdown() when scheduler is running, it will close the
// 'avaiable_thread_idx_' and prefecther threads have no way to tell scheduler
// to release their resource. So a check is needed before scheduler ends.
for (auto& p : prefetchers_) {
......@@ -138,7 +135,8 @@ void MultiFileReader::PrefetchThreadFunc(size_t reader_idx, size_t thread_idx) {
std::vector<framework::LoDTensor> ins;
reader->ReadNext(&ins);
if (ins.empty()) {
reader->ReInit();
reader->Shutdown();
reader->Start();
break;
}
try {
......@@ -180,9 +178,8 @@ class OpenFilesOp : public framework::OperatorBase {
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new MultiFileReader(file_names,
RestoreShapes(shape_concat, ranks),
thread_num, buffer_size));
out->Reset(
std::make_shared<MultiFileReader>(file_names, thread_num, buffer_size));
}
};
......
......@@ -39,7 +39,7 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry() {
}
std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
const std::string& file_name, const std::vector<framework::DDim>& dims) {
const std::string& file_name) {
size_t separator_pos = file_name.find_last_of(kFileFormatSeparator);
PADDLE_ENFORCE_NE(separator_pos, std::string::npos,
"File name illegal! A legal file name should be like: "
......@@ -49,7 +49,7 @@ std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
auto itor = FileReaderRegistry().find(filetype);
PADDLE_ENFORCE(itor != FileReaderRegistry().end(),
"No file reader registered for '%s' format.", filetype);
framework::ReaderBase* reader = (itor->second)(file_name, dims);
framework::ReaderBase* reader = (itor->second)(file_name);
return std::unique_ptr<framework::ReaderBase>(reader);
}
......
......@@ -25,22 +25,21 @@ namespace reader {
static constexpr char kFileFormatSeparator[] = ".";
using FileReaderCreator = std::function<framework::ReaderBase*(
const std::string&, const std::vector<framework::DDim>&)>;
using FileReaderCreator =
std::function<framework::ReaderBase*(const std::string&)>;
std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry();
template <typename Reader>
int RegisterFileReader(const std::string& filetype) {
FileReaderRegistry()[filetype] = [](
const std::string& fn, const std::vector<framework::DDim>& dims) {
return new Reader(fn, dims);
FileReaderRegistry()[filetype] = [](const std::string& fn) {
return new Reader(fn);
};
return 0;
}
std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
const std::string& file_name, const std::vector<framework::DDim>& dims);
const std::string& file_name);
extern std::vector<framework::DDim> RestoreShapes(
const std::vector<int>& shape_concat, const std::vector<int>& ranks);
......
......@@ -296,7 +296,7 @@ All parameter, weight, gradient are variables in Paddle.
py::return_value_policy::reference);
py::class_<framework::ReaderHolder>(m, "Reader", "")
.def("reset", &framework::ReaderHolder::ReInit);
.def("reset", &framework::ReaderHolder::ResetAll);
using LoDTensorBlockingQueue =
::paddle::operators::reader::LoDTensorBlockingQueue;
......
......@@ -375,9 +375,6 @@ def open_recordio_file(filename,
if pass_num > 1:
main_prog_var = multi_pass(reader=main_prog_var, pass_num=pass_num)
if for_parallel:
main_prog_var = parallel(reader=main_prog_var)
return monkey_patch_reader_methods(main_prog_var)
......@@ -529,9 +526,6 @@ def open_files(filenames,
main_prog_reader = multi_pass(
reader=main_prog_reader, pass_num=pass_num)
if for_parallel:
main_prog_reader = parallel(reader=main_prog_reader)
return monkey_patch_reader_methods(main_prog_reader)
......@@ -647,11 +641,6 @@ def multi_pass(reader, pass_num):
'create_multi_pass_reader', reader, {'pass_num': int(pass_num)})
def parallel(reader):
return __create_shared_decorated_reader__('create_threaded_reader', reader,
{})
def read_file(reader):
"""
Execute the given reader and get data via it.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册