提交 5528f599 编写于 作者: F fengjiayi

Split ReInit() to Shutdown() and Start()

上级 de9a411f
...@@ -16,8 +16,29 @@ ...@@ -16,8 +16,29 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
void ReaderBase::ReadNext(std::vector<LoDTensor> *out) {
if (status_ != ReaderStatus::kRunning) {
PADDLE_THROW("The reader is not at the status of 'running'.");
}
ReadNextImpl(out);
}
void ReaderBase::Shutdown() {
if (status_ != ReaderStatus::kStopped) {
ShutdownImpl();
status_ = ReaderStatus::kStopped;
}
}
void ReaderBase::Start() {
if (status_ != ReaderStatus::kRunning) {
StartImpl();
status_ = ReaderStatus::kRunning;
}
}
ReaderBase::~ReaderBase() {} ReaderBase::~ReaderBase() {}
void FileReader::ReadNext(std::vector<LoDTensor> *out) { ReadNextImpl(out); }
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -24,13 +24,26 @@ ...@@ -24,13 +24,26 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
enum ReaderStatus { kRunning, kStopped };
class ReaderBase { class ReaderBase {
public: public:
virtual void ReadNext(std::vector<LoDTensor>* out) = 0; void ReadNext(std::vector<LoDTensor>* out);
void Shutdown();
virtual void ReInit() = 0; void Start();
virtual ~ReaderBase(); virtual ~ReaderBase();
protected:
virtual void ReadNextImpl(std::vector<LoDTensor>* out) = 0;
virtual void ShutdownImpl() = 0;
virtual void StartImpl() = 0;
std::atomic<ReaderStatus> status_{kStopped};
}; };
class DecoratedReader : public ReaderBase { class DecoratedReader : public ReaderBase {
...@@ -40,9 +53,11 @@ class DecoratedReader : public ReaderBase { ...@@ -40,9 +53,11 @@ class DecoratedReader : public ReaderBase {
PADDLE_ENFORCE_NOT_NULL(reader_); PADDLE_ENFORCE_NOT_NULL(reader_);
} }
void ReInit() override { reader_->ReInit(); }
protected: protected:
void ShutdownImpl() override { reader_->Shutdown(); }
void StartImpl() override { reader_->Start(); }
std::shared_ptr<ReaderBase> reader_; std::shared_ptr<ReaderBase> reader_;
}; };
...@@ -50,10 +65,10 @@ class FileReader : public ReaderBase { ...@@ -50,10 +65,10 @@ class FileReader : public ReaderBase {
public: public:
FileReader() : ReaderBase() {} FileReader() : ReaderBase() {}
void ReadNext(std::vector<LoDTensor>* out) override;
protected: protected:
virtual void ReadNextImpl(std::vector<LoDTensor>* out) = 0; void ShutdownImpl() override {}
void StartImpl() override {}
}; };
// The ReaderHolder is used as reader' unified wrapper, // The ReaderHolder is used as reader' unified wrapper,
...@@ -68,9 +83,19 @@ class ReaderHolder { ...@@ -68,9 +83,19 @@ class ReaderHolder {
PADDLE_ENFORCE_NOT_NULL(reader_); PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->ReadNext(out); reader_->ReadNext(out);
} }
void ReInit() {
void ResetAll() {
// TODO(fengjiayi): The interface of reseting all.
}
void Shutdown() {
PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->Shutdown();
}
void Start() {
PADDLE_ENFORCE_NOT_NULL(reader_); PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->ReInit(); reader_->Start();
} }
private: private:
......
...@@ -23,9 +23,10 @@ class BatchReader : public framework::DecoratedReader { ...@@ -23,9 +23,10 @@ class BatchReader : public framework::DecoratedReader {
BatchReader(const std::shared_ptr<ReaderBase>& reader, int batch_size) BatchReader(const std::shared_ptr<ReaderBase>& reader, int batch_size)
: DecoratedReader(reader), batch_size_(batch_size) { : DecoratedReader(reader), batch_size_(batch_size) {
buffer_.reserve(batch_size_); buffer_.reserve(batch_size_);
Start();
} }
void ReadNext(std::vector<framework::LoDTensor>* out) override; void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
private: private:
int batch_size_; int batch_size_;
...@@ -66,7 +67,7 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase { ...@@ -66,7 +67,7 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase {
} }
}; };
void BatchReader::ReadNext(std::vector<framework::LoDTensor>* out) { void BatchReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
buffer_.clear(); buffer_.clear();
buffer_.reserve(batch_size_); buffer_.reserve(batch_size_);
for (int i = 0; i < batch_size_; ++i) { for (int i = 0; i < batch_size_; ++i) {
......
...@@ -31,9 +31,11 @@ class CustomReader : public framework::DecoratedReader { ...@@ -31,9 +31,11 @@ class CustomReader : public framework::DecoratedReader {
sub_block_id_(sub_block.ID()), sub_block_id_(sub_block.ID()),
exe_(framework::Executor(platform::CPUPlace())), exe_(framework::Executor(platform::CPUPlace())),
source_var_names_(source_var_names), source_var_names_(source_var_names),
sink_var_names_(sink_var_names) {} sink_var_names_(sink_var_names) {
Start();
}
void ReadNext(std::vector<framework::LoDTensor>* out) override; void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
private: private:
const framework::ProgramDesc program_; const framework::ProgramDesc program_;
...@@ -143,7 +145,7 @@ class CustomReaderInferVarType : public framework::VarTypeInference { ...@@ -143,7 +145,7 @@ class CustomReaderInferVarType : public framework::VarTypeInference {
} }
}; };
void CustomReader::ReadNext(std::vector<framework::LoDTensor>* out) { void CustomReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
out->clear(); out->clear();
std::vector<framework::LoDTensor> underlying_outs; std::vector<framework::LoDTensor> underlying_outs;
reader_->ReadNext(&underlying_outs); reader_->ReadNext(&underlying_outs);
......
...@@ -47,15 +47,24 @@ class DoubleBufferReader : public framework::DecoratedReader { ...@@ -47,15 +47,24 @@ class DoubleBufferReader : public framework::DecoratedReader {
} }
} }
#endif #endif
StartPrefetcher(); Start();
} }
void ReadNext(std::vector<framework::LoDTensor>* out) override; void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
void ReInit() override;
~DoubleBufferReader() { EndPrefetcher(); } ~DoubleBufferReader() { Shutdown(); }
private: private:
void ShutdownImpl() override {
EndPrefetcher();
reader_->Shutdown();
}
void StartImpl() override {
reader_->Start();
StartPrefetcher();
}
void StartPrefetcher() { void StartPrefetcher() {
channel_ = new reader::BlockingQueue<size_t>(kChannelSize); channel_ = new reader::BlockingQueue<size_t>(kChannelSize);
prefetcher_ = std::thread([this] { PrefetchThreadFunc(); }); prefetcher_ = std::thread([this] { PrefetchThreadFunc(); });
...@@ -136,7 +145,7 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase { ...@@ -136,7 +145,7 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
} }
}; };
void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) { void DoubleBufferReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
size_t cached_tensor_id; size_t cached_tensor_id;
if (channel_->Receive(&cached_tensor_id)) { if (channel_->Receive(&cached_tensor_id)) {
if (platform::is_gpu_place(place_)) { if (platform::is_gpu_place(place_)) {
...@@ -150,12 +159,6 @@ void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) { ...@@ -150,12 +159,6 @@ void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
} }
} }
void DoubleBufferReader::ReInit() {
EndPrefetcher();
reader_->ReInit();
StartPrefetcher();
}
void DoubleBufferReader::PrefetchThreadFunc() { void DoubleBufferReader::PrefetchThreadFunc() {
VLOG(5) << "A new prefetch thread starts."; VLOG(5) << "A new prefetch thread starts.";
size_t cached_tensor_id = 0; size_t cached_tensor_id = 0;
......
...@@ -22,25 +22,28 @@ namespace reader { ...@@ -22,25 +22,28 @@ namespace reader {
class MultiPassReader : public framework::DecoratedReader { class MultiPassReader : public framework::DecoratedReader {
public: public:
MultiPassReader(const std::shared_ptr<ReaderBase>& reader, int pass_num) MultiPassReader(const std::shared_ptr<ReaderBase>& reader, int pass_num)
: DecoratedReader(reader), pass_num_(pass_num), pass_count_(0) {} : DecoratedReader(reader), pass_num_(pass_num) {
Start();
}
void ReadNext(std::vector<framework::LoDTensor>* out) override { void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
reader_->ReadNext(out); reader_->ReadNext(out);
if (out->empty()) { if (out->empty()) {
++pass_count_; ++pass_count_;
if (pass_count_ < pass_num_) { if (pass_count_ < pass_num_) {
reader_->ReInit(); reader_->Shutdown();
reader_->Start();
reader_->ReadNext(out); reader_->ReadNext(out);
} }
} }
} }
void ReInit() override { private:
void StartImpl() override {
pass_count_ = 0; pass_count_ = 0;
reader_->ReInit(); reader_->Start();
} }
private:
int pass_num_; int pass_num_;
mutable int pass_count_; mutable int pass_count_;
}; };
......
...@@ -33,9 +33,13 @@ class PyReader : public framework::FileReader { ...@@ -33,9 +33,13 @@ class PyReader : public framework::FileReader {
if (!success) out->clear(); if (!success) out->clear();
} }
void ReInit() override {}
private: private:
void ShutdownImpl() override { /* TODO */
}
void StartImpl() override { /* TODO */
}
std::shared_ptr<LoDTensorBlockingQueue> queue_; std::shared_ptr<LoDTensorBlockingQueue> queue_;
}; };
......
...@@ -30,6 +30,7 @@ class RandomDataGenerator : public framework::FileReader { ...@@ -30,6 +30,7 @@ class RandomDataGenerator : public framework::FileReader {
unsigned int seed = std::random_device()(); unsigned int seed = std::random_device()();
engine_.seed(seed); engine_.seed(seed);
dist_ = std::uniform_real_distribution<float>(low_, high_); dist_ = std::uniform_real_distribution<float>(low_, high_);
Start();
} }
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override { void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
...@@ -51,8 +52,6 @@ class RandomDataGenerator : public framework::FileReader { ...@@ -51,8 +52,6 @@ class RandomDataGenerator : public framework::FileReader {
} }
} }
void ReInit() override { return; }
private: private:
float low_; float low_;
float high_; float high_;
......
...@@ -30,10 +30,9 @@ class RecordIOFileReader : public framework::FileReader { ...@@ -30,10 +30,9 @@ class RecordIOFileReader : public framework::FileReader {
mutex_.reset(new std::mutex()); mutex_.reset(new std::mutex());
} }
LOG(INFO) << "Creating file reader" << filename; LOG(INFO) << "Creating file reader" << filename;
Start();
} }
void ReInit() override { scanner_.Reset(); }
protected: protected:
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override { void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
if (ThreadSafe) { if (ThreadSafe) {
...@@ -44,6 +43,8 @@ class RecordIOFileReader : public framework::FileReader { ...@@ -44,6 +43,8 @@ class RecordIOFileReader : public framework::FileReader {
} }
} }
void ShutdownImpl() override { scanner_.Reset(); }
private: private:
std::unique_ptr<std::mutex> mutex_; std::unique_ptr<std::mutex> mutex_;
recordio::Scanner scanner_; recordio::Scanner scanner_;
......
...@@ -31,10 +31,10 @@ class ShuffleReader : public framework::DecoratedReader { ...@@ -31,10 +31,10 @@ class ShuffleReader : public framework::DecoratedReader {
std::random_device device; std::random_device device;
seed_ = device(); seed_ = device();
} }
ReloadBuffer(); Start();
} }
void ReadNext(std::vector<framework::LoDTensor>* out) override { void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
out->clear(); out->clear();
if (iteration_pos_ >= buffer_.size()) { if (iteration_pos_ >= buffer_.size()) {
VLOG(10) << "Resetting shuffle buffer"; VLOG(10) << "Resetting shuffle buffer";
...@@ -47,6 +47,17 @@ class ShuffleReader : public framework::DecoratedReader { ...@@ -47,6 +47,17 @@ class ShuffleReader : public framework::DecoratedReader {
} }
private: private:
void ShutdownImpl() override {
buffer_.clear();
iteration_pos_ = 0;
reader_->Shutdown();
}
void StartImpl() override {
reader_->Start();
ReloadBuffer();
}
void ReloadBuffer() { void ReloadBuffer() {
buffer_.clear(); buffer_.clear();
buffer_.reserve(buffer_size_); buffer_.reserve(buffer_size_);
......
...@@ -22,16 +22,26 @@ namespace reader { ...@@ -22,16 +22,26 @@ namespace reader {
class ThreadedReader : public framework::DecoratedReader { class ThreadedReader : public framework::DecoratedReader {
public: public:
explicit ThreadedReader(const std::shared_ptr<ReaderBase>& reader) explicit ThreadedReader(const std::shared_ptr<ReaderBase>& reader)
: DecoratedReader(reader) {} : DecoratedReader(reader) {
Start();
}
void ReadNext(std::vector<framework::LoDTensor>* out) override { void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
reader_->ReadNext(out); reader_->ReadNext(out);
} }
void ReInit() override { reader_->ReInit(); }
private: private:
void ShutdownImpl() override {
std::lock_guard<std::mutex> lock(mutex_);
reader_->Shutdown();
}
void StartImpl() override {
std::lock_guard<std::mutex> lock(mutex_);
reader_->Start();
}
std::mutex mutex_; std::mutex mutex_;
}; };
...@@ -62,9 +72,6 @@ class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase { ...@@ -62,9 +72,6 @@ class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase {
This operator creates a threaded reader. A threaded reader's This operator creates a threaded reader. A threaded reader's
'ReadNext()' can be invoked by several threads at the same 'ReadNext()' can be invoked by several threads at the same
time. time.
When the attribute 'safe_mode' is true, the threaded reader's
'ReInit()' is disabled to avoid unexpected bugs in multi-thread
environment.
)DOC"); )DOC");
} }
}; };
......
...@@ -31,17 +31,16 @@ class MultiFileReader : public framework::ReaderBase { ...@@ -31,17 +31,16 @@ class MultiFileReader : public framework::ReaderBase {
readers_.emplace_back(CreateReaderByFileName(f_name)); readers_.emplace_back(CreateReaderByFileName(f_name));
} }
prefetchers_.resize(thread_num); prefetchers_.resize(thread_num);
StartNewScheduler(); Start();
} }
void ReadNext(std::vector<framework::LoDTensor>* out) override; void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
void ReInit() override;
~MultiFileReader() { EndScheduler(); } ~MultiFileReader() { Shutdown(); }
private: private:
void StartNewScheduler(); void StartImpl() override;
void EndScheduler(); void ShutdownImpl() override;
void ScheduleThreadFunc(); void ScheduleThreadFunc();
void PrefetchThreadFunc(size_t reader_idx, size_t thread_idx); void PrefetchThreadFunc(size_t reader_idx, size_t thread_idx);
...@@ -54,18 +53,13 @@ class MultiFileReader : public framework::ReaderBase { ...@@ -54,18 +53,13 @@ class MultiFileReader : public framework::ReaderBase {
reader::BlockingQueue<std::vector<framework::LoDTensor>>* buffer_; reader::BlockingQueue<std::vector<framework::LoDTensor>>* buffer_;
}; };
void MultiFileReader::ReadNext(std::vector<framework::LoDTensor>* out) { void MultiFileReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
if (!buffer_->Receive(out)) { if (!buffer_->Receive(out)) {
out->clear(); out->clear();
} }
} }
void MultiFileReader::ReInit() { void MultiFileReader::StartImpl() {
EndScheduler();
StartNewScheduler();
}
void MultiFileReader::StartNewScheduler() {
size_t thread_num = prefetchers_.size(); size_t thread_num = prefetchers_.size();
waiting_reader_idx_ = new reader::BlockingQueue<size_t>(readers_.size()); waiting_reader_idx_ = new reader::BlockingQueue<size_t>(readers_.size());
available_thread_idx_ = new reader::BlockingQueue<size_t>(thread_num); available_thread_idx_ = new reader::BlockingQueue<size_t>(thread_num);
...@@ -83,7 +77,7 @@ void MultiFileReader::StartNewScheduler() { ...@@ -83,7 +77,7 @@ void MultiFileReader::StartNewScheduler() {
scheduler_ = std::thread([this] { ScheduleThreadFunc(); }); scheduler_ = std::thread([this] { ScheduleThreadFunc(); });
} }
void MultiFileReader::EndScheduler() { void MultiFileReader::ShutdownImpl() {
available_thread_idx_->Close(); available_thread_idx_->Close();
buffer_->Close(); buffer_->Close();
waiting_reader_idx_->Close(); waiting_reader_idx_->Close();
...@@ -119,7 +113,7 @@ void MultiFileReader::ScheduleThreadFunc() { ...@@ -119,7 +113,7 @@ void MultiFileReader::ScheduleThreadFunc() {
} }
} }
} }
// If users invoke ReInit() when scheduler is running, it will close the // If users invoke Shutdown() when scheduler is running, it will close the
// 'avaiable_thread_idx_' and prefecther threads have no way to tell scheduler // 'avaiable_thread_idx_' and prefecther threads have no way to tell scheduler
// to release their resource. So a check is needed before scheduler ends. // to release their resource. So a check is needed before scheduler ends.
for (auto& p : prefetchers_) { for (auto& p : prefetchers_) {
...@@ -137,7 +131,8 @@ void MultiFileReader::PrefetchThreadFunc(size_t reader_idx, size_t thread_idx) { ...@@ -137,7 +131,8 @@ void MultiFileReader::PrefetchThreadFunc(size_t reader_idx, size_t thread_idx) {
std::vector<framework::LoDTensor> ins; std::vector<framework::LoDTensor> ins;
reader->ReadNext(&ins); reader->ReadNext(&ins);
if (ins.empty()) { if (ins.empty()) {
reader->ReInit(); reader->Shutdown();
reader->Start();
break; break;
} }
try { try {
......
...@@ -296,7 +296,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -296,7 +296,7 @@ All parameter, weight, gradient are variables in Paddle.
py::return_value_policy::reference); py::return_value_policy::reference);
py::class_<framework::ReaderHolder>(m, "Reader", "") py::class_<framework::ReaderHolder>(m, "Reader", "")
.def("reset", &framework::ReaderHolder::ReInit); .def("reset", &framework::ReaderHolder::ResetAll);
using LoDTensorBlockingQueue = using LoDTensorBlockingQueue =
::paddle::operators::reader::LoDTensorBlockingQueue; ::paddle::operators::reader::LoDTensorBlockingQueue;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册