diff --git a/paddle/fluid/framework/async_executor.cc b/paddle/fluid/framework/async_executor.cc
old mode 100644
new mode 100755
index 639b546ff394c3394f6c2d27396a208fde573eea..31335fa7231e28fb8872da754f33607c0304848c
--- a/paddle/fluid/framework/async_executor.cc
+++ b/paddle/fluid/framework/async_executor.cc
@@ -137,6 +137,18 @@ void AsyncExecutor::SetModelPrefix(const std::string& model_prefix) {
   model_prefix_ = model_prefix;
 }
 
+void PrepareReaders(std::vector<std::shared_ptr<DataFeed> >& readers,
+                    const int thread_num, DataFeedDesc& data_feed_desc,
+                    const std::vector<std::string>& filelist) {
+  readers.resize(thread_num);
+  for (size_t i = 0; i < readers.size(); ++i) {
+    readers[i] = DataFeedFactory::CreateDataFeed(data_feed_desc.name());
+    readers[i]->Init(data_feed_desc);  // set batch size here
+    // readers[i]->SetQueueSize(32);  // default is 32
+  }
+  readers[0]->SetFileList(filelist);
+}
+
 std::vector AsyncExecutor::RunFromFile(
     const ProgramDesc& main_program,
     const std::string& data_feed_desc_str,
@@ -159,11 +171,8 @@ std::vector AsyncExecutor::RunFromFile(
   */
   // todo: should be factory method for creating datafeed
   std::vector<std::shared_ptr<DataFeed> > readers;
-  readers.resize(thread_num);
-  for (unsigned int i = 0; i < readers.size(); ++i) {
-    readers[i] = DataFeedFactory::CreateDataFeed(data_feed_desc.name());
-  }
-
+  PrepareReaders(readers, thread_num, data_feed_desc, filelist);
+
   std::vector<std::shared_ptr<ExecutorThreadWorker> > workers;
   workers.resize(thread_num);
   for (auto& worker : workers) {
diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc
index e4255101dca5a8e56538a2f5fde37c87a288dfba..556591acdd2092d189edf790ef663e66ada17bac 100755
--- a/paddle/fluid/framework/data_feed.cc
+++ b/paddle/fluid/framework/data_feed.cc
@@ -135,11 +135,11 @@ void PrivateQueueDataFeed<T>::ReadThread(){
 }
 
 template <typename T>
-bool PrivateQueueDataFeed<T>::Next(){
+int PrivateQueueDataFeed<T>::Next(){
   CheckStart();
   int index = 0;
   T instance;
-  T ins_vec(use_slots_.size());
+  T ins_vec;
   while (index < default_batch_size_) {
     if (!queue_.Receive(&instance)) {
       break;
@@ -147,8 +147,10 @@
     AddInstanceToInsVec(ins_vec, instance, index++);
   }
   batch_size_ = index;
-  PutToFeedVec(ins_vec);
-  return batch_size_ != 0;
+  if (batch_size_ != 0) {
+    PutToFeedVec(ins_vec);
+  }
+  return batch_size_;
 }
 
 void MultiSlotDataFeed::Init(paddle::framework::DataFeedDesc& data_feed_desc) {
@@ -161,6 +163,7 @@ void MultiSlotDataFeed::Init(paddle::framework::DataFeedDesc& data_feed_desc) {
     exit(-1);
   }
   paddle::framework::MultiSlotDesc multi_slot_desc = data_feed_desc.multi_slot_desc();
+  SetBatchSize(data_feed_desc.batch());
   size_t all_slot_num = multi_slot_desc.slots_size();
   all_slots_.resize(all_slot_num);
   all_slots_type_.resize(all_slot_num);
@@ -178,7 +181,7 @@ void MultiSlotDataFeed::Init(paddle::framework::DataFeedDesc& data_feed_desc) {
     }
   }
   feed_vec_.resize(use_slots_.size());
-
+
   finish_init_ = true;
 }
 
@@ -205,7 +208,7 @@ bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>& instance) {
       exit(-1);
     }
     if (idx != -1) {
-      instance[idx].SetType(all_slots_type_[i]);
+      instance[idx].Init(all_slots_type_[i]);
      if (instance[idx].GetType()[0] == 'f') {  // float
        for (int j = 0; j < num; ++j) {
          float feasign = (float)strtof(endptr, &endptr);
@@ -233,8 +236,10 @@
 void MultiSlotDataFeed::AddInstanceToInsVec(std::vector<MultiSlotType>& ins_vec,
                                             std::vector<MultiSlotType>& instance, int index) {
   if (index == 0) {
+    ins_vec.resize(instance.size());
     for (size_t i = 0; i < instance.size(); ++i) {
-      ins_vec[i].SetType(instance[i].GetType());
+      ins_vec[i].Init(instance[i].GetType());
+      ins_vec[i].InitOffset();
     }
   }
   for (size_t i = 0; i < instance.size(); ++i){
diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h
index 388925b5af4297d982ae0569005a2b3a8be6a7d2..41cc9659273d31e871afa11c82f594032b74cdf5 100755
--- a/paddle/fluid/framework/data_feed.h
+++ b/paddle/fluid/framework/data_feed.h
@@ -163,7 +163,7 @@ class DataFeed {
   }
   virtual bool SetFileList(const std::vector<std::string>& files);
   virtual bool Start() = 0;
-  virtual bool Next() = 0;
+  virtual int Next() = 0;
   virtual void SetBatchSize(int batch) { default_batch_size_ = batch; }
   virtual int GetBatchSize() { return batch_size_; }
   // for subclass with queue
@@ -217,7 +217,7 @@ class PrivateQueueDataFeed : public DataFeed {
   virtual ~PrivateQueueDataFeed() {}
   virtual void Init(paddle::framework::DataFeedDesc& data_feed_desc) = 0;
   virtual bool Start();
-  virtual bool Next();  // no buffer
+  virtual int Next();  // no buffer
   virtual void SetQueueSize(int queue_size);
 
  protected:
@@ -234,24 +234,28 @@
   * fread one buffer and one buffer parse: 7097 ms
   */
   std::ifstream file_;
   size_t queue_size_;
-  // The elements in the queue are one piece of data,
-  // with multiple fields in each piece of data
   BlockingQueue<T> queue_;
 };
 
 class MultiSlotType {
  public:
-  MultiSlotType() {
-    float_feasign_.clear();
-    uint64_feasign_.clear();
-    offset_.resize(1);
-    offset_[0] = 0;
-  }
+  MultiSlotType() {}
   ~MultiSlotType() {}
-  void SetType(std::string& type) {
+  void Init(std::string& type) {
     CheckType(type);
+    if (type_[0] == 'f') {
+      float_feasign_.clear();
+    } else if (type_[0] == 'u') {
+      uint64_feasign_.clear();
+    }
     type_ = type;
   }
+  void InitOffset() {
+    offset_.resize(1);
+    // LoDTensor's lod is counted from 0, so the size of lod
+    // is one larger than the size of the data.
+    offset_[0] = 0;
+  }
   std::vector<size_t>& GetOffset() { return offset_; }
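
A note on the interface change above: DataFeed::Next() now returns the batch size (0 when the feed is exhausted) instead of a bool, and PutToFeedVec() is only invoked for non-empty batches. The sketch below is illustrative only and is not part of the patch; it assumes a reader created by DataFeedFactory::CreateDataFeed and held as std::shared_ptr<DataFeed> (as in the PrepareReaders helper above), and the function and variable names (ConsumeFeed, reader) are made up for the example.

// Illustrative consumer loop, not part of the patch. It relies only on the
// DataFeed interface as modified above: Start() before reading, Next()
// returning the number of instances in the current batch, 0 when no data
// is left.
#include <memory>
#include "paddle/fluid/framework/data_feed.h"

void ConsumeFeed(const std::shared_ptr<paddle::framework::DataFeed>& reader) {
  reader->Start();
  // Next() fills the feed tensors via PutToFeedVec() only when the batch is
  // non-empty, so a zero return is a safe termination condition.
  while (int batch_size = reader->Next()) {
    // run the training pass on `batch_size` instances here
    (void)batch_size;
  }
}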