diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index afcae6cdcdcf5a3fcec5891f9303f21dd7a58b68..670df672f885a31adcf0ed05d3b02e9068353ccf 100755 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -69,6 +69,14 @@ bool DataFeed::SetFileList(const std::vector& files) { return true; } +void DataFeed::SetBatchSize(int batch_size) { + if (batch_size <= 0) { + LOG(ERROR) << "error: illegal batch size: " << batch_size; + exit(-1); + } + default_batch_size_ = batch_size; +} + bool DataFeed::PickOneFile(std::string& filename) { std::unique_lock lock(mutex_for_pick_file_); if (file_idx_ == filelist_.size()) { diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h index f40c16f5c33bd0ac5ee335d14d4e5b6d494d1b67..db3b1e1af66877c20e994af68a71e34d45a1db80 100755 --- a/paddle/fluid/framework/data_feed.h +++ b/paddle/fluid/framework/data_feed.h @@ -164,7 +164,7 @@ class DataFeed { virtual bool SetFileList(const std::vector& files); virtual bool Start() = 0; virtual int Next() = 0; - virtual void SetBatchSize(int batch) { default_batch_size_ = batch; } + virtual void SetBatchSize(int batch); virtual int GetBatchSize() { return batch_size_; } // for subclass with queue virtual void SetQueueSize(int queue_size) {