提交 6fc6cc2f 编写于 作者: F fengjiayi

Some updates on readers

1. Shrink DoubleBufferReader's buffer size to 3.
2. Add an option to BatchReader to discard leftover instances.
3. Fix a MultiPassReader bug on pass count.
上级 5528f599
...@@ -20,8 +20,11 @@ namespace reader { ...@@ -20,8 +20,11 @@ namespace reader {
class BatchReader : public framework::DecoratedReader { class BatchReader : public framework::DecoratedReader {
public: public:
BatchReader(const std::shared_ptr<ReaderBase>& reader, int batch_size) BatchReader(const std::shared_ptr<ReaderBase>& reader, int batch_size,
: DecoratedReader(reader), batch_size_(batch_size) { bool discard_leftover)
: DecoratedReader(reader),
batch_size_(batch_size),
discard_leftover_(discard_leftover) {
buffer_.reserve(batch_size_); buffer_.reserve(batch_size_);
Start(); Start();
} }
...@@ -30,6 +33,7 @@ class BatchReader : public framework::DecoratedReader { ...@@ -30,6 +33,7 @@ class BatchReader : public framework::DecoratedReader {
private: private:
int batch_size_; int batch_size_;
bool discard_leftover_;
std::vector<std::vector<framework::LoDTensor>> buffer_; std::vector<std::vector<framework::LoDTensor>> buffer_;
}; };
...@@ -47,8 +51,8 @@ class CreateBatchReaderOp : public framework::OperatorBase { ...@@ -47,8 +51,8 @@ class CreateBatchReaderOp : public framework::OperatorBase {
} }
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>(); ->Get<framework::ReaderHolder>();
out->Reset( out->Reset(new BatchReader(underlying_reader.Get(), Attr<int>("batch_size"),
new BatchReader(underlying_reader.Get(), Attr<int>("batch_size"))); Attr<bool>("discard_leftover")));
} }
}; };
...@@ -58,6 +62,10 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase { ...@@ -58,6 +62,10 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase {
AddAttr<int>("batch_size", AddAttr<int>("batch_size",
"How many instances the batch reader yields each time.") "How many instances the batch reader yields each time.")
.GreaterThan(0); .GreaterThan(0);
AddAttr<bool>("discard_leftover",
"If true, the leftover instances that are not enough for a "
"new batch will be discarded.")
.SetDefault(true);
AddComment(R"DOC( AddComment(R"DOC(
CreateBatchReader Operator CreateBatchReader Operator
...@@ -78,6 +86,9 @@ void BatchReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) { ...@@ -78,6 +86,9 @@ void BatchReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
break; break;
} }
} }
if (discard_leftover_ && buffer_.size() < batch_size_) {
buffer_.clear();
}
// Concat instances // Concat instances
out->clear(); out->clear();
if (buffer_.empty()) { if (buffer_.empty()) {
......
...@@ -23,13 +23,13 @@ namespace reader { ...@@ -23,13 +23,13 @@ namespace reader {
// 'Double buffer' means we shall maintain two batches of input data at the same // 'Double buffer' means we shall maintain two batches of input data at the same
// time. So the kCacheSize should be at least 2. // time. So the kCacheSize should be at least 2.
static constexpr size_t kCacheSize = 5; static constexpr size_t kCacheSize = 3;
// There will be two batches out of the channel during training: // There will be two batches out of the channel during training:
// 1. the one waiting to be sent to the channel // 1. the one waiting to be sent to the channel
// 2. the one just received from the channel, which is also being used by // 2. the one just received from the channel, which is also being used by
// subsequent operators. // subsequent operators.
// So the channel size should be kCacheSize - 2 // So the channel size should be kCacheSize - 2
static constexpr size_t kChannelSize = 3; // kCacheSize - 2 static constexpr size_t kChannelSize = 1; // kCacheSize - 2
class DoubleBufferReader : public framework::DecoratedReader { class DoubleBufferReader : public framework::DecoratedReader {
public: public:
......
...@@ -28,13 +28,11 @@ class MultiPassReader : public framework::DecoratedReader { ...@@ -28,13 +28,11 @@ class MultiPassReader : public framework::DecoratedReader {
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override { void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
reader_->ReadNext(out); reader_->ReadNext(out);
if (out->empty()) { if (out->empty() && pass_count_ < pass_num_ - 1) {
++pass_count_;
if (pass_count_ < pass_num_) {
reader_->Shutdown(); reader_->Shutdown();
reader_->Start(); reader_->Start();
reader_->ReadNext(out); reader_->ReadNext(out);
} ++pass_count_;
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册