提交 b4f0e579 编写于 作者: F fengjiayi

fix errors

上级 6fc6cc2f
......@@ -14,6 +14,7 @@
#pragma once
#include <atomic>
#include <memory>
#include <vector>
......@@ -43,7 +44,7 @@ class ReaderBase {
virtual void StartImpl() = 0;
std::atomic<ReaderStatus> status_{kStopped};
std::atomic<ReaderStatus> status_{kRunning};
};
class DecoratedReader : public ReaderBase {
......
......@@ -26,7 +26,6 @@ class BatchReader : public framework::DecoratedReader {
batch_size_(batch_size),
discard_leftover_(discard_leftover) {
buffer_.reserve(batch_size_);
Start();
}
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
......
......@@ -31,9 +31,7 @@ class CustomReader : public framework::DecoratedReader {
sub_block_id_(sub_block.ID()),
exe_(framework::Executor(platform::CPUPlace())),
source_var_names_(source_var_names),
sink_var_names_(sink_var_names) {
Start();
}
sink_var_names_(sink_var_names) {}
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
......
......@@ -47,12 +47,12 @@ class DoubleBufferReader : public framework::DecoratedReader {
}
}
#endif
Start();
StartPrefetcher();
}
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
~DoubleBufferReader() { Shutdown(); }
~DoubleBufferReader() { EndPrefetcher(); }
private:
void ShutdownImpl() override {
......
......@@ -22,9 +22,7 @@ namespace reader {
class MultiPassReader : public framework::DecoratedReader {
public:
MultiPassReader(const std::shared_ptr<ReaderBase>& reader, int pass_num)
: DecoratedReader(reader), pass_num_(pass_num) {
Start();
}
: DecoratedReader(reader), pass_num_(pass_num), pass_count_(0) {}
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
reader_->ReadNext(out);
......
......@@ -30,7 +30,6 @@ class RandomDataGenerator : public framework::FileReader {
unsigned int seed = std::random_device()();
engine_.seed(seed);
dist_ = std::uniform_real_distribution<float>(low_, high_);
Start();
}
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
......
......@@ -30,7 +30,6 @@ class RecordIOFileReader : public framework::FileReader {
mutex_.reset(new std::mutex());
}
LOG(INFO) << "Creating file reader" << filename;
Start();
}
protected:
......
......@@ -31,7 +31,7 @@ class ShuffleReader : public framework::DecoratedReader {
std::random_device device;
seed_ = device();
}
Start();
ReloadBuffer();
}
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
......
......@@ -22,9 +22,7 @@ namespace reader {
class ThreadedReader : public framework::DecoratedReader {
public:
explicit ThreadedReader(const std::shared_ptr<ReaderBase>& reader)
: DecoratedReader(reader) {
Start();
}
: DecoratedReader(reader) {}
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
std::lock_guard<std::mutex> lock(mutex_);
......
......@@ -31,16 +31,20 @@ class MultiFileReader : public framework::ReaderBase {
readers_.emplace_back(CreateReaderByFileName(f_name));
}
prefetchers_.resize(thread_num);
Start();
StartNewScheduler();
}
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
~MultiFileReader() { Shutdown(); }
~MultiFileReader() { EndScheduler(); }
private:
void StartImpl() override;
void ShutdownImpl() override;
void ShutdownImpl() override { EndScheduler(); }
void StartImpl() override { StartNewScheduler(); }
void StartNewScheduler();
void EndScheduler();
void ScheduleThreadFunc();
void PrefetchThreadFunc(size_t reader_idx, size_t thread_idx);
......@@ -59,7 +63,7 @@ void MultiFileReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
}
}
void MultiFileReader::StartImpl() {
void MultiFileReader::StartNewScheduler() {
size_t thread_num = prefetchers_.size();
waiting_reader_idx_ = new reader::BlockingQueue<size_t>(readers_.size());
available_thread_idx_ = new reader::BlockingQueue<size_t>(thread_num);
......@@ -77,7 +81,7 @@ void MultiFileReader::StartImpl() {
scheduler_ = std::thread([this] { ScheduleThreadFunc(); });
}
void MultiFileReader::ShutdownImpl() {
void MultiFileReader::EndScheduler() {
available_thread_idx_->Close();
buffer_->Close();
waiting_reader_idx_->Close();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册