提交 8a7b6118 编写于 作者: B barrierye

update datafeed and async_executor to run bow_net demo

上级 664be756
......@@ -137,6 +137,18 @@ void AsyncExecutor::SetModelPrefix(const std::string& model_prefix) {
model_prefix_ = model_prefix;
}
void PrepareReaders(std::vector<std::shared_ptr<DataFeed> >& readers,
const int thread_num, DataFeedDesc& data_feed_desc,
const std::vector<std::string>& filelist) {
readers.resize(thread_num);
for (size_t i = 0; i < readers.size(); ++i) {
readers[i] = DataFeedFactory::CreateDataFeed(data_feed_desc.name());
readers[i]->Init(data_feed_desc); // set batch size here
//readers[i]->SetQueueSize(32); // default is 32
}
readers[0]->SetFileList(filelist);
}
std::vector<float> AsyncExecutor::RunFromFile(
const ProgramDesc& main_program,
const std::string& data_feed_desc_str,
......@@ -159,11 +171,8 @@ std::vector<float> AsyncExecutor::RunFromFile(
*/
// todo: should be factory method for creating datafeed
std::vector<std::shared_ptr<DataFeed> > readers;
readers.resize(thread_num);
for (unsigned int i = 0; i < readers.size(); ++i) {
readers[i] = DataFeedFactory::CreateDataFeed(data_feed_desc.name());
}
PrepareReaders(readers, thread_num, data_feed_desc, filelist);
std::vector<std::shared_ptr<ExecutorThreadWorker> > workers;
workers.resize(thread_num);
for (auto& worker : workers) {
......
......@@ -135,11 +135,11 @@ void PrivateQueueDataFeed<T>::ReadThread(){
}
template<typename T>
bool PrivateQueueDataFeed<T>::Next(){
int PrivateQueueDataFeed<T>::Next(){
CheckStart();
int index = 0;
T instance;
T ins_vec(use_slots_.size());
T ins_vec;
while (index < default_batch_size_) {
if (!queue_.Receive(&instance)) {
break;
......@@ -147,8 +147,10 @@ bool PrivateQueueDataFeed<T>::Next(){
AddInstanceToInsVec(ins_vec, instance, index++);
}
batch_size_ = index;
PutToFeedVec(ins_vec);
return batch_size_ != 0;
if (batch_size_ != 0) {
PutToFeedVec(ins_vec);
}
return batch_size_;
}
void MultiSlotDataFeed::Init(paddle::framework::DataFeedDesc& data_feed_desc) {
......@@ -161,6 +163,7 @@ void MultiSlotDataFeed::Init(paddle::framework::DataFeedDesc& data_feed_desc) {
exit(-1);
}
paddle::framework::MultiSlotDesc multi_slot_desc = data_feed_desc.multi_slot_desc();
SetBatchSize(data_feed_desc.batch());
size_t all_slot_num = multi_slot_desc.slots_size();
all_slots_.resize(all_slot_num);
all_slots_type_.resize(all_slot_num);
......@@ -178,7 +181,7 @@ void MultiSlotDataFeed::Init(paddle::framework::DataFeedDesc& data_feed_desc) {
}
}
feed_vec_.resize(use_slots_.size());
finish_init_ = true;
}
......@@ -205,7 +208,7 @@ bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>& instance) {
exit(-1);
}
if (idx != -1) {
instance[idx].SetType(all_slots_type_[i]);
instance[idx].Init(all_slots_type_[i]);
if (instance[idx].GetType()[0] == 'f') { // float
for (int j = 0; j < num; ++j) {
float feasign = (float)strtof(endptr, &endptr);
......@@ -233,8 +236,10 @@ bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>& instance) {
void MultiSlotDataFeed::AddInstanceToInsVec(std::vector<MultiSlotType>& ins_vec,
std::vector<MultiSlotType>& instance, int index) {
if (index == 0) {
ins_vec.resize(instance.size());
for (size_t i = 0; i < instance.size(); ++i) {
ins_vec[i].SetType(instance[i].GetType());
ins_vec[i].Init(instance[i].GetType());
ins_vec[i].InitOffset();
}
}
for (size_t i = 0; i < instance.size(); ++i){
......
......@@ -163,7 +163,7 @@ class DataFeed {
}
virtual bool SetFileList(const std::vector<std::string>& files);
virtual bool Start() = 0;
virtual bool Next() = 0;
virtual int Next() = 0;
virtual void SetBatchSize(int batch) { default_batch_size_ = batch; }
virtual int GetBatchSize() { return batch_size_; }
// for subclass with queue
......@@ -217,7 +217,7 @@ class PrivateQueueDataFeed : public DataFeed {
virtual ~PrivateQueueDataFeed() {}
virtual void Init(paddle::framework::DataFeedDesc& data_feed_desc) = 0;
virtual bool Start();
virtual bool Next(); // no buffer
virtual int Next(); // no buffer
virtual void SetQueueSize(int queue_size);
protected:
......@@ -234,24 +234,28 @@ class PrivateQueueDataFeed : public DataFeed {
* fread one buffer and one buffer parse: 7097 ms */
std::ifstream file_;
size_t queue_size_;
// The elements in the queue are one piece of data,
// with multiple fields in each piece of data
BlockingQueue<T> queue_;
};
class MultiSlotType {
public:
MultiSlotType() {
float_feasign_.clear();
uint64_feasign_.clear();
offset_.resize(1);
offset_[0] = 0;
}
MultiSlotType() {}
~MultiSlotType() {}
void SetType(std::string& type) {
void Init(std::string& type) {
CheckType(type);
if (type_[0] == 'f') {
float_feasign_.clear();
} else if (type_[0] == 'u') {
uint64_feasign_.clear();
}
type_ = type;
}
void InitOffset() {
offset_.resize(1);
// LoDTensor' lod is counted from 0, the size of lod
// is one size larger than the size of data.
offset_[0] = 0;
}
std::vector<size_t>& GetOffset() {
return offset_;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册