提交 5528f599 编写于 作者: F fengjiayi

Split ReInit() to Shutdown() and Start()

上级 de9a411f
......@@ -16,8 +16,29 @@
namespace paddle {
namespace framework {
void ReaderBase::ReadNext(std::vector<LoDTensor> *out) {
if (status_ != ReaderStatus::kRunning) {
PADDLE_THROW("The reader is not at the status of 'running'.");
}
ReadNextImpl(out);
}
void ReaderBase::Shutdown() {
if (status_ != ReaderStatus::kStopped) {
ShutdownImpl();
status_ = ReaderStatus::kStopped;
}
}
void ReaderBase::Start() {
if (status_ != ReaderStatus::kRunning) {
StartImpl();
status_ = ReaderStatus::kRunning;
}
}
ReaderBase::~ReaderBase() {}
void FileReader::ReadNext(std::vector<LoDTensor> *out) { ReadNextImpl(out); }
} // namespace framework
} // namespace paddle
......@@ -24,13 +24,26 @@
namespace paddle {
namespace framework {
enum ReaderStatus { kRunning, kStopped };
class ReaderBase {
public:
virtual void ReadNext(std::vector<LoDTensor>* out) = 0;
void ReadNext(std::vector<LoDTensor>* out);
void Shutdown();
virtual void ReInit() = 0;
void Start();
virtual ~ReaderBase();
protected:
virtual void ReadNextImpl(std::vector<LoDTensor>* out) = 0;
virtual void ShutdownImpl() = 0;
virtual void StartImpl() = 0;
std::atomic<ReaderStatus> status_{kStopped};
};
class DecoratedReader : public ReaderBase {
......@@ -40,9 +53,11 @@ class DecoratedReader : public ReaderBase {
PADDLE_ENFORCE_NOT_NULL(reader_);
}
void ReInit() override { reader_->ReInit(); }
protected:
void ShutdownImpl() override { reader_->Shutdown(); }
void StartImpl() override { reader_->Start(); }
std::shared_ptr<ReaderBase> reader_;
};
......@@ -50,10 +65,10 @@ class FileReader : public ReaderBase {
public:
FileReader() : ReaderBase() {}
void ReadNext(std::vector<LoDTensor>* out) override;
protected:
virtual void ReadNextImpl(std::vector<LoDTensor>* out) = 0;
void ShutdownImpl() override {}
void StartImpl() override {}
};
// The ReaderHolder is used as reader' unified wrapper,
......@@ -68,9 +83,19 @@ class ReaderHolder {
PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->ReadNext(out);
}
void ReInit() {
void ResetAll() {
// TODO(fengjiayi): The interface of reseting all.
}
void Shutdown() {
PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->Shutdown();
}
void Start() {
PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->ReInit();
reader_->Start();
}
private:
......
......@@ -23,9 +23,10 @@ class BatchReader : public framework::DecoratedReader {
BatchReader(const std::shared_ptr<ReaderBase>& reader, int batch_size)
: DecoratedReader(reader), batch_size_(batch_size) {
buffer_.reserve(batch_size_);
Start();
}
void ReadNext(std::vector<framework::LoDTensor>* out) override;
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
private:
int batch_size_;
......@@ -66,7 +67,7 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase {
}
};
void BatchReader::ReadNext(std::vector<framework::LoDTensor>* out) {
void BatchReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
buffer_.clear();
buffer_.reserve(batch_size_);
for (int i = 0; i < batch_size_; ++i) {
......
......@@ -31,9 +31,11 @@ class CustomReader : public framework::DecoratedReader {
sub_block_id_(sub_block.ID()),
exe_(framework::Executor(platform::CPUPlace())),
source_var_names_(source_var_names),
sink_var_names_(sink_var_names) {}
sink_var_names_(sink_var_names) {
Start();
}
void ReadNext(std::vector<framework::LoDTensor>* out) override;
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
private:
const framework::ProgramDesc program_;
......@@ -143,7 +145,7 @@ class CustomReaderInferVarType : public framework::VarTypeInference {
}
};
void CustomReader::ReadNext(std::vector<framework::LoDTensor>* out) {
void CustomReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
out->clear();
std::vector<framework::LoDTensor> underlying_outs;
reader_->ReadNext(&underlying_outs);
......
......@@ -47,15 +47,24 @@ class DoubleBufferReader : public framework::DecoratedReader {
}
}
#endif
StartPrefetcher();
Start();
}
void ReadNext(std::vector<framework::LoDTensor>* out) override;
void ReInit() override;
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
~DoubleBufferReader() { EndPrefetcher(); }
~DoubleBufferReader() { Shutdown(); }
private:
void ShutdownImpl() override {
EndPrefetcher();
reader_->Shutdown();
}
void StartImpl() override {
reader_->Start();
StartPrefetcher();
}
void StartPrefetcher() {
channel_ = new reader::BlockingQueue<size_t>(kChannelSize);
prefetcher_ = std::thread([this] { PrefetchThreadFunc(); });
......@@ -136,7 +145,7 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
}
};
void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
void DoubleBufferReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
size_t cached_tensor_id;
if (channel_->Receive(&cached_tensor_id)) {
if (platform::is_gpu_place(place_)) {
......@@ -150,12 +159,6 @@ void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
}
}
void DoubleBufferReader::ReInit() {
EndPrefetcher();
reader_->ReInit();
StartPrefetcher();
}
void DoubleBufferReader::PrefetchThreadFunc() {
VLOG(5) << "A new prefetch thread starts.";
size_t cached_tensor_id = 0;
......
......@@ -22,25 +22,28 @@ namespace reader {
class MultiPassReader : public framework::DecoratedReader {
public:
MultiPassReader(const std::shared_ptr<ReaderBase>& reader, int pass_num)
: DecoratedReader(reader), pass_num_(pass_num), pass_count_(0) {}
: DecoratedReader(reader), pass_num_(pass_num) {
Start();
}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
reader_->ReadNext(out);
if (out->empty()) {
++pass_count_;
if (pass_count_ < pass_num_) {
reader_->ReInit();
reader_->Shutdown();
reader_->Start();
reader_->ReadNext(out);
}
}
}
void ReInit() override {
private:
void StartImpl() override {
pass_count_ = 0;
reader_->ReInit();
reader_->Start();
}
private:
int pass_num_;
mutable int pass_count_;
};
......
......@@ -33,9 +33,13 @@ class PyReader : public framework::FileReader {
if (!success) out->clear();
}
void ReInit() override {}
private:
void ShutdownImpl() override { /* TODO */
}
void StartImpl() override { /* TODO */
}
std::shared_ptr<LoDTensorBlockingQueue> queue_;
};
......
......@@ -30,6 +30,7 @@ class RandomDataGenerator : public framework::FileReader {
unsigned int seed = std::random_device()();
engine_.seed(seed);
dist_ = std::uniform_real_distribution<float>(low_, high_);
Start();
}
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
......@@ -51,8 +52,6 @@ class RandomDataGenerator : public framework::FileReader {
}
}
void ReInit() override { return; }
private:
float low_;
float high_;
......
......@@ -30,10 +30,9 @@ class RecordIOFileReader : public framework::FileReader {
mutex_.reset(new std::mutex());
}
LOG(INFO) << "Creating file reader" << filename;
Start();
}
void ReInit() override { scanner_.Reset(); }
protected:
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
if (ThreadSafe) {
......@@ -44,6 +43,8 @@ class RecordIOFileReader : public framework::FileReader {
}
}
void ShutdownImpl() override { scanner_.Reset(); }
private:
std::unique_ptr<std::mutex> mutex_;
recordio::Scanner scanner_;
......
......@@ -31,10 +31,10 @@ class ShuffleReader : public framework::DecoratedReader {
std::random_device device;
seed_ = device();
}
ReloadBuffer();
Start();
}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
out->clear();
if (iteration_pos_ >= buffer_.size()) {
VLOG(10) << "Resetting shuffle buffer";
......@@ -47,6 +47,17 @@ class ShuffleReader : public framework::DecoratedReader {
}
private:
void ShutdownImpl() override {
buffer_.clear();
iteration_pos_ = 0;
reader_->Shutdown();
}
void StartImpl() override {
reader_->Start();
ReloadBuffer();
}
void ReloadBuffer() {
buffer_.clear();
buffer_.reserve(buffer_size_);
......
......@@ -22,16 +22,26 @@ namespace reader {
class ThreadedReader : public framework::DecoratedReader {
public:
explicit ThreadedReader(const std::shared_ptr<ReaderBase>& reader)
: DecoratedReader(reader) {}
: DecoratedReader(reader) {
Start();
}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
std::lock_guard<std::mutex> lock(mutex_);
reader_->ReadNext(out);
}
void ReInit() override { reader_->ReInit(); }
private:
void ShutdownImpl() override {
std::lock_guard<std::mutex> lock(mutex_);
reader_->Shutdown();
}
void StartImpl() override {
std::lock_guard<std::mutex> lock(mutex_);
reader_->Start();
}
std::mutex mutex_;
};
......@@ -62,9 +72,6 @@ class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase {
This operator creates a threaded reader. A threaded reader's
'ReadNext()' can be invoked by several threads at the same
time.
When the attribute 'safe_mode' is true, the threaded reader's
'ReInit()' is disabled to avoid unexpected bugs in multi-thread
environment.
)DOC");
}
};
......
......@@ -31,17 +31,16 @@ class MultiFileReader : public framework::ReaderBase {
readers_.emplace_back(CreateReaderByFileName(f_name));
}
prefetchers_.resize(thread_num);
StartNewScheduler();
Start();
}
void ReadNext(std::vector<framework::LoDTensor>* out) override;
void ReInit() override;
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
~MultiFileReader() { EndScheduler(); }
~MultiFileReader() { Shutdown(); }
private:
void StartNewScheduler();
void EndScheduler();
void StartImpl() override;
void ShutdownImpl() override;
void ScheduleThreadFunc();
void PrefetchThreadFunc(size_t reader_idx, size_t thread_idx);
......@@ -54,18 +53,13 @@ class MultiFileReader : public framework::ReaderBase {
reader::BlockingQueue<std::vector<framework::LoDTensor>>* buffer_;
};
void MultiFileReader::ReadNext(std::vector<framework::LoDTensor>* out) {
void MultiFileReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
if (!buffer_->Receive(out)) {
out->clear();
}
}
void MultiFileReader::ReInit() {
EndScheduler();
StartNewScheduler();
}
void MultiFileReader::StartNewScheduler() {
void MultiFileReader::StartImpl() {
size_t thread_num = prefetchers_.size();
waiting_reader_idx_ = new reader::BlockingQueue<size_t>(readers_.size());
available_thread_idx_ = new reader::BlockingQueue<size_t>(thread_num);
......@@ -83,7 +77,7 @@ void MultiFileReader::StartNewScheduler() {
scheduler_ = std::thread([this] { ScheduleThreadFunc(); });
}
void MultiFileReader::EndScheduler() {
void MultiFileReader::ShutdownImpl() {
available_thread_idx_->Close();
buffer_->Close();
waiting_reader_idx_->Close();
......@@ -119,7 +113,7 @@ void MultiFileReader::ScheduleThreadFunc() {
}
}
}
// If users invoke ReInit() when scheduler is running, it will close the
// If users invoke Shutdown() when scheduler is running, it will close the
// 'avaiable_thread_idx_' and prefecther threads have no way to tell scheduler
// to release their resource. So a check is needed before scheduler ends.
for (auto& p : prefetchers_) {
......@@ -137,7 +131,8 @@ void MultiFileReader::PrefetchThreadFunc(size_t reader_idx, size_t thread_idx) {
std::vector<framework::LoDTensor> ins;
reader->ReadNext(&ins);
if (ins.empty()) {
reader->ReInit();
reader->Shutdown();
reader->Start();
break;
}
try {
......
......@@ -296,7 +296,7 @@ All parameter, weight, gradient are variables in Paddle.
py::return_value_policy::reference);
py::class_<framework::ReaderHolder>(m, "Reader", "")
.def("reset", &framework::ReaderHolder::ReInit);
.def("reset", &framework::ReaderHolder::ResetAll);
using LoDTensorBlockingQueue =
::paddle::operators::reader::LoDTensorBlockingQueue;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册