From 0d698a38824b57ebcf63dba64bacb511f6d67402 Mon Sep 17 00:00:00 2001 From: barrierye Date: Wed, 21 Nov 2018 19:36:35 +0800 Subject: [PATCH] add batch_size check --- paddle/fluid/framework/data_feed.cc | 8 ++++++++ paddle/fluid/framework/data_feed.h | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index afcae6cdcdc..670df672f88 100755 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -69,6 +69,14 @@ bool DataFeed::SetFileList(const std::vector& files) { return true; } +void DataFeed::SetBatchSize(int batch_size) { + if (batch_size <= 0) { + LOG(ERROR) << "error: illegal batch size: " << batch_size; + exit(-1); + } + default_batch_size_ = batch_size; +} + bool DataFeed::PickOneFile(std::string& filename) { std::unique_lock lock(mutex_for_pick_file_); if (file_idx_ == filelist_.size()) { diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h index f40c16f5c33..db3b1e1af66 100755 --- a/paddle/fluid/framework/data_feed.h +++ b/paddle/fluid/framework/data_feed.h @@ -164,7 +164,7 @@ class DataFeed { virtual bool SetFileList(const std::vector& files); virtual bool Start() = 0; virtual int Next() = 0; - virtual void SetBatchSize(int batch) { default_batch_size_ = batch; } + virtual void SetBatchSize(int batch); virtual int GetBatchSize() { return batch_size_; } // for subclass with queue virtual void SetQueueSize(int queue_size) { -- GitLab