提交 b4f0e579 编写于 作者: F fengjiayi

fix errors

上级 6fc6cc2f
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <atomic>
#include <memory> #include <memory>
#include <vector> #include <vector>
...@@ -43,7 +44,7 @@ class ReaderBase { ...@@ -43,7 +44,7 @@ class ReaderBase {
virtual void StartImpl() = 0; virtual void StartImpl() = 0;
std::atomic<ReaderStatus> status_{kStopped}; std::atomic<ReaderStatus> status_{kRunning};
}; };
class DecoratedReader : public ReaderBase { class DecoratedReader : public ReaderBase {
......
...@@ -26,7 +26,6 @@ class BatchReader : public framework::DecoratedReader { ...@@ -26,7 +26,6 @@ class BatchReader : public framework::DecoratedReader {
batch_size_(batch_size), batch_size_(batch_size),
discard_leftover_(discard_leftover) { discard_leftover_(discard_leftover) {
buffer_.reserve(batch_size_); buffer_.reserve(batch_size_);
Start();
} }
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override; void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
......
...@@ -31,9 +31,7 @@ class CustomReader : public framework::DecoratedReader { ...@@ -31,9 +31,7 @@ class CustomReader : public framework::DecoratedReader {
sub_block_id_(sub_block.ID()), sub_block_id_(sub_block.ID()),
exe_(framework::Executor(platform::CPUPlace())), exe_(framework::Executor(platform::CPUPlace())),
source_var_names_(source_var_names), source_var_names_(source_var_names),
sink_var_names_(sink_var_names) { sink_var_names_(sink_var_names) {}
Start();
}
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override; void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
......
...@@ -47,12 +47,12 @@ class DoubleBufferReader : public framework::DecoratedReader { ...@@ -47,12 +47,12 @@ class DoubleBufferReader : public framework::DecoratedReader {
} }
} }
#endif #endif
Start(); StartPrefetcher();
} }
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override; void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
~DoubleBufferReader() { Shutdown(); } ~DoubleBufferReader() { EndPrefetcher(); }
private: private:
void ShutdownImpl() override { void ShutdownImpl() override {
......
...@@ -22,9 +22,7 @@ namespace reader { ...@@ -22,9 +22,7 @@ namespace reader {
class MultiPassReader : public framework::DecoratedReader { class MultiPassReader : public framework::DecoratedReader {
public: public:
MultiPassReader(const std::shared_ptr<ReaderBase>& reader, int pass_num) MultiPassReader(const std::shared_ptr<ReaderBase>& reader, int pass_num)
: DecoratedReader(reader), pass_num_(pass_num) { : DecoratedReader(reader), pass_num_(pass_num), pass_count_(0) {}
Start();
}
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override { void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
reader_->ReadNext(out); reader_->ReadNext(out);
......
...@@ -30,7 +30,6 @@ class RandomDataGenerator : public framework::FileReader { ...@@ -30,7 +30,6 @@ class RandomDataGenerator : public framework::FileReader {
unsigned int seed = std::random_device()(); unsigned int seed = std::random_device()();
engine_.seed(seed); engine_.seed(seed);
dist_ = std::uniform_real_distribution<float>(low_, high_); dist_ = std::uniform_real_distribution<float>(low_, high_);
Start();
} }
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override { void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
......
...@@ -30,7 +30,6 @@ class RecordIOFileReader : public framework::FileReader { ...@@ -30,7 +30,6 @@ class RecordIOFileReader : public framework::FileReader {
mutex_.reset(new std::mutex()); mutex_.reset(new std::mutex());
} }
LOG(INFO) << "Creating file reader" << filename; LOG(INFO) << "Creating file reader" << filename;
Start();
} }
protected: protected:
......
...@@ -31,7 +31,7 @@ class ShuffleReader : public framework::DecoratedReader { ...@@ -31,7 +31,7 @@ class ShuffleReader : public framework::DecoratedReader {
std::random_device device; std::random_device device;
seed_ = device(); seed_ = device();
} }
Start(); ReloadBuffer();
} }
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override { void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
......
...@@ -22,9 +22,7 @@ namespace reader { ...@@ -22,9 +22,7 @@ namespace reader {
class ThreadedReader : public framework::DecoratedReader { class ThreadedReader : public framework::DecoratedReader {
public: public:
explicit ThreadedReader(const std::shared_ptr<ReaderBase>& reader) explicit ThreadedReader(const std::shared_ptr<ReaderBase>& reader)
: DecoratedReader(reader) { : DecoratedReader(reader) {}
Start();
}
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override { void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
......
...@@ -31,16 +31,20 @@ class MultiFileReader : public framework::ReaderBase { ...@@ -31,16 +31,20 @@ class MultiFileReader : public framework::ReaderBase {
readers_.emplace_back(CreateReaderByFileName(f_name)); readers_.emplace_back(CreateReaderByFileName(f_name));
} }
prefetchers_.resize(thread_num); prefetchers_.resize(thread_num);
Start(); StartNewScheduler();
} }
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override; void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
~MultiFileReader() { Shutdown(); } ~MultiFileReader() { EndScheduler(); }
private: private:
void StartImpl() override; void ShutdownImpl() override { EndScheduler(); }
void ShutdownImpl() override;
void StartImpl() override { StartNewScheduler(); }
void StartNewScheduler();
void EndScheduler();
void ScheduleThreadFunc(); void ScheduleThreadFunc();
void PrefetchThreadFunc(size_t reader_idx, size_t thread_idx); void PrefetchThreadFunc(size_t reader_idx, size_t thread_idx);
...@@ -59,7 +63,7 @@ void MultiFileReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) { ...@@ -59,7 +63,7 @@ void MultiFileReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
} }
} }
void MultiFileReader::StartImpl() { void MultiFileReader::StartNewScheduler() {
size_t thread_num = prefetchers_.size(); size_t thread_num = prefetchers_.size();
waiting_reader_idx_ = new reader::BlockingQueue<size_t>(readers_.size()); waiting_reader_idx_ = new reader::BlockingQueue<size_t>(readers_.size());
available_thread_idx_ = new reader::BlockingQueue<size_t>(thread_num); available_thread_idx_ = new reader::BlockingQueue<size_t>(thread_num);
...@@ -77,7 +81,7 @@ void MultiFileReader::StartImpl() { ...@@ -77,7 +81,7 @@ void MultiFileReader::StartImpl() {
scheduler_ = std::thread([this] { ScheduleThreadFunc(); }); scheduler_ = std::thread([this] { ScheduleThreadFunc(); });
} }
void MultiFileReader::ShutdownImpl() { void MultiFileReader::EndScheduler() {
available_thread_idx_->Close(); available_thread_idx_->Close();
buffer_->Close(); buffer_->Close();
waiting_reader_idx_->Close(); waiting_reader_idx_->Close();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册