Commit 8a7b6118 authored by barrierye

update datafeed and async_executor to run bow_net demo

Parent 664be756
@@ -137,6 +137,18 @@ void AsyncExecutor::SetModelPrefix(const std::string& model_prefix) {
   model_prefix_ = model_prefix;
 }
 
+void PrepareReaders(std::vector<std::shared_ptr<DataFeed> >& readers,
+                    const int thread_num, DataFeedDesc& data_feed_desc,
+                    const std::vector<std::string>& filelist) {
+  readers.resize(thread_num);
+  for (size_t i = 0; i < readers.size(); ++i) {
+    readers[i] = DataFeedFactory::CreateDataFeed(data_feed_desc.name());
+    readers[i]->Init(data_feed_desc);  // set batch size here
+    // readers[i]->SetQueueSize(32);  // default is 32
+  }
+  readers[0]->SetFileList(filelist);
+}
+
 std::vector<float> AsyncExecutor::RunFromFile(
     const ProgramDesc& main_program,
     const std::string& data_feed_desc_str,
@@ -159,11 +171,8 @@ std::vector<float> AsyncExecutor::RunFromFile(
   */
   // todo: should be factory method for creating datafeed
   std::vector<std::shared_ptr<DataFeed> > readers;
-  readers.resize(thread_num);
-  for (unsigned int i = 0; i < readers.size(); ++i) {
-    readers[i] = DataFeedFactory::CreateDataFeed(data_feed_desc.name());
-  }
+  PrepareReaders(readers, thread_num, data_feed_desc, filelist);
 
   std::vector<std::shared_ptr<ExecutorThreadWorker> > workers;
   workers.resize(thread_num);
   for (auto& worker : workers) {
...
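With the reader setup factored into PrepareReaders, each DataFeed is now Init()-ed from the DataFeedDesc (so the batch size in the proto takes effect) and the file list is attached before RunFromFile builds its workers. Below is a minimal sketch of how a single reader set up this way would be driven; the header paths are assumptions, and only DataFeed calls that appear in this diff (Init, SetFileList, Start, Next, GetBatchSize) are used. It is an illustration, not code from this commit.

#include <memory>
#include <string>
#include <vector>

#include "paddle/fluid/framework/data_feed.h"          // assumed header path
#include "paddle/fluid/framework/data_feed_factory.h"  // assumed header path

using paddle::framework::DataFeed;
using paddle::framework::DataFeedDesc;
using paddle::framework::DataFeedFactory;

// Count how many batches one reader produces; mirrors what PrepareReaders
// does for a single thread.
int CountBatches(DataFeedDesc& data_feed_desc,
                 const std::vector<std::string>& filelist) {
  std::shared_ptr<DataFeed> reader =
      DataFeedFactory::CreateDataFeed(data_feed_desc.name());
  reader->Init(data_feed_desc);   // batch size now comes from the proto
  reader->SetFileList(filelist);

  int num_batches = 0;
  reader->Start();
  // After this commit, Next() returns the batch size; 0 means no data left.
  while (reader->Next() > 0) {
    ++num_batches;  // GetBatchSize() would report the size of this batch
  }
  return num_batches;
}

Returning the batch size instead of a bool lets a caller handle a short final batch and the empty case with a single check.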
@@ -135,11 +135,11 @@ void PrivateQueueDataFeed<T>::ReadThread(){
 }
 
 template<typename T>
-bool PrivateQueueDataFeed<T>::Next(){
+int PrivateQueueDataFeed<T>::Next(){
   CheckStart();
   int index = 0;
   T instance;
-  T ins_vec(use_slots_.size());
+  T ins_vec;
   while (index < default_batch_size_) {
     if (!queue_.Receive(&instance)) {
       break;
@@ -147,8 +147,10 @@ bool PrivateQueueDataFeed<T>::Next(){
     AddInstanceToInsVec(ins_vec, instance, index++);
   }
   batch_size_ = index;
-  PutToFeedVec(ins_vec);
-  return batch_size_ != 0;
+  if (batch_size_ != 0) {
+    PutToFeedVec(ins_vec);
+  }
+  return batch_size_;
 }
 
 void MultiSlotDataFeed::Init(paddle::framework::DataFeedDesc& data_feed_desc) {
@@ -161,6 +163,7 @@ void MultiSlotDataFeed::Init(paddle::framework::DataFeedDesc& data_feed_desc) {
     exit(-1);
   }
   paddle::framework::MultiSlotDesc multi_slot_desc = data_feed_desc.multi_slot_desc();
+  SetBatchSize(data_feed_desc.batch());
   size_t all_slot_num = multi_slot_desc.slots_size();
   all_slots_.resize(all_slot_num);
   all_slots_type_.resize(all_slot_num);
@@ -178,7 +181,7 @@ void MultiSlotDataFeed::Init(paddle::framework::DataFeedDesc& data_feed_desc) {
     }
   }
   feed_vec_.resize(use_slots_.size());
 
   finish_init_ = true;
 }
@@ -205,7 +208,7 @@ bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>& instance) {
       exit(-1);
     }
     if (idx != -1) {
-      instance[idx].SetType(all_slots_type_[i]);
+      instance[idx].Init(all_slots_type_[i]);
       if (instance[idx].GetType()[0] == 'f') {  // float
         for (int j = 0; j < num; ++j) {
           float feasign = (float)strtof(endptr, &endptr);
@@ -233,8 +236,10 @@ bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>& instance) {
 void MultiSlotDataFeed::AddInstanceToInsVec(std::vector<MultiSlotType>& ins_vec,
     std::vector<MultiSlotType>& instance, int index) {
   if (index == 0) {
+    ins_vec.resize(instance.size());
     for (size_t i = 0; i < instance.size(); ++i) {
-      ins_vec[i].SetType(instance[i].GetType());
+      ins_vec[i].Init(instance[i].GetType());
+      ins_vec[i].InitOffset();
     }
   }
   for (size_t i = 0; i < instance.size(); ++i){
...
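MultiSlotDataFeed::Init() now also picks up the batch size from the descriptor via SetBatchSize(data_feed_desc.batch()), so the DataFeedDesc proto fully describes the feed. The sketch below shows how such a descriptor could be assembled in code; name(), batch(), multi_slot_desc() and the repeated slots field come from this diff, while the per-slot set_name/set_type fields, the string literals, and the generated header name are assumptions, not verified against data_feed.proto.

#include <string>
#include <vector>

#include "paddle/fluid/framework/data_feed.pb.h"  // assumed generated header

// Build a descriptor for a bag-of-words style demo with sparse uint64 slots.
paddle::framework::DataFeedDesc MakeBowNetFeedDesc(
    const std::vector<std::string>& sparse_slots) {
  paddle::framework::DataFeedDesc desc;
  desc.set_name("MultiSlotDataFeed");  // factory key; literal is an assumption
  desc.set_batch(32);                  // read back via SetBatchSize(desc.batch())
  auto* multi_slot = desc.mutable_multi_slot_desc();
  for (const std::string& slot : sparse_slots) {
    auto* s = multi_slot->add_slots();
    s->set_name(slot);      // assumed slot field
    s->set_type("uint64");  // assumed slot field; hits the 'u' branch in the parser
  }
  return desc;
}

In RunFromFile the same message presumably arrives serialized as data_feed_desc_str, so in practice it would be parsed from the demo's protobuf text file rather than built programmatically.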
@@ -163,7 +163,7 @@ class DataFeed {
   }
   virtual bool SetFileList(const std::vector<std::string>& files);
   virtual bool Start() = 0;
-  virtual bool Next() = 0;
+  virtual int Next() = 0;
   virtual void SetBatchSize(int batch) { default_batch_size_ = batch; }
   virtual int GetBatchSize() { return batch_size_; }
   // for subclass with queue
@@ -217,7 +217,7 @@ class PrivateQueueDataFeed : public DataFeed {
   virtual ~PrivateQueueDataFeed() {}
   virtual void Init(paddle::framework::DataFeedDesc& data_feed_desc) = 0;
   virtual bool Start();
-  virtual bool Next();  // no buffer
+  virtual int Next();  // no buffer
   virtual void SetQueueSize(int queue_size);
 
  protected:
@@ -234,24 +234,28 @@ class PrivateQueueDataFeed : public DataFeed {
   * fread one buffer and one buffer parse: 7097 ms */
   std::ifstream file_;
   size_t queue_size_;
+  // The elements in the queue are one piece of data,
+  // with multiple fields in each piece of data
   BlockingQueue<T> queue_;
 };
 
 class MultiSlotType {
  public:
-  MultiSlotType() {
-    float_feasign_.clear();
-    uint64_feasign_.clear();
-    offset_.resize(1);
-    offset_[0] = 0;
-  }
+  MultiSlotType() {}
   ~MultiSlotType() {}
-  void SetType(std::string& type) {
+  void Init(std::string& type) {
     CheckType(type);
+    if (type_[0] == 'f') {
+      float_feasign_.clear();
+    } else if (type_[0] == 'u') {
+      uint64_feasign_.clear();
+    }
     type_ = type;
   }
+  void InitOffset() {
+    offset_.resize(1);
+    // LoDTensor' lod is counted from 0, the size of lod
+    // is one size larger than the size of data.
+    offset_[0] = 0;
+  }
   std::vector<size_t>& GetOffset() {
     return offset_;
   }
...
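The InitOffset() comment notes that the lod is one entry larger than the data it indexes, which is why AddInstanceToInsVec calls Init() plus InitOffset() on each batch-level slot before appending instances. Here is a toy, self-contained analogue of that offsets bookkeeping; ToySlot and AppendInstance are invented names for illustration, not PaddlePaddle API.

#include <cstdint>
#include <iostream>
#include <vector>

// A batch-level slot: feasigns of every instance, plus LoD-style offsets
// recording where each instance begins (one more entry than instances).
struct ToySlot {
  std::vector<uint64_t> feasigns;
  std::vector<size_t> offsets{0};  // like InitOffset(): lod starts at {0}

  void AppendInstance(const std::vector<uint64_t>& ins) {
    feasigns.insert(feasigns.end(), ins.begin(), ins.end());
    offsets.push_back(offsets.back() + ins.size());
  }
};

int main() {
  ToySlot slot;                     // like ins_vec[i] after Init()+InitOffset()
  slot.AppendInstance({101, 102});  // instance 0 has 2 feasigns
  slot.AppendInstance({103});       // instance 1 has 1 feasign
  slot.AppendInstance({});          // instance 2 has none for this slot

  for (size_t o : slot.offsets) std::cout << o << " ";
  std::cout << "\n";  // prints: 0 2 3 3
  return 0;
}

With three instances holding 2, 1 and 0 feasigns, the offsets end up as {0, 2, 3, 3}: the number of instances plus one, which is the shape PutToFeedVec needs when handing a LoDTensor its lod.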