diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index 26c7a012f34f18295134912b399f0b92f3df7b7c..555acf70b625f82613967c8cb3ff287e601eaca4 100644 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -34,221 +34,241 @@ limitations under the License. */ #include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/framework/data_feed.h" -DEFINE_bool(is_text_feed, false, "is_text_feed"); namespace paddle { namespace framework { -std::vector TextClassDataFeed::s_filelist_; -std::mutex TextClassDataFeed::s_locker_for_pick_file_; -unsigned int TextClassDataFeed::s_current_file_idx_ = 0; -size_t TextClassDataFeed::s_current_finished_file_cnt_ = 0; -unsigned int TextClassDataFeed::s_current_epoch_ = 0; -int TextClassDataFeed::s_current_save_epoch_ = 0; -std::mutex TextClassDataFeed::s_locker_epoch_start_; -std::condition_variable TextClassDataFeed::s_condition_epoch_start_; -bool TextClassDataFeed::s_epoch_start_flag_ = false; -void TextClassDataFeed::Init() { - // hard coding for a specific datafeed - feed_vec_.resize(2); - // feed_vec_[0].reset(new LoDTensor); - // feed_vec_[1].reset(new LoDTensor); - all_slot_ids_ = {0, 1}; - use_slot_ids_ = {0, 1}; - use_slot_alias_ = {"words", "label"}; - - file_content_buffer_host_.reset(new char[200*1024*1024], - [](char *p) {delete[] p;}); - file_content_buffer_ = file_content_buffer_host_.get(); - file_content_buffer_ptr_ = file_content_buffer_; - - batch_id_host_.reset(new int[10240*1024], - [](int *p) {delete[] p;}); // max word num in a batch - batch_id_buffer_ = batch_id_host_.get(); - - label_host_.reset(new int[10240], - [](int *p) {delete[] p;}); // max label in a batch - label_ptr_ = label_host_.get(); - - field_names_.clear(); -} - -TextClassDataFeed::TextClassDataFeed() { - Init(); -} - - // todo: use elegant implemention for this function -bool TextClassDataFeed::ReadBatch() { - paddle::framework::Vector offset; - int tlen = 0; - int llen = 0; - int inst_idx = 0; - offset.resize(batch_size_ + 1); - offset[0] = 0; - - while (inst_idx < batch_size_) { - int ptr_offset = 0; - if (file_content_buffer_ptr_ - file_content_buffer_ >= file_size_) { - break; +std::vector DataFeed::filelist_; +size_t DataFeed::file_idx_; +std::mutex DataFeed::mutex_for_pick_file_; + +void DataFeed::AddFeedVar(Variable* var, const std::string& name) { + if (CheckInit() == false) {return;} + for (size_t i = 0; i < use_slots_.size(); ++i) { + if (name == use_slots_[i]) { + if (use_slots_is_dense_[i]) { + feed_vec_[i] = MixTensor(var->GetMutable()); + } else { + feed_vec_[i] = MixTensor(var->GetMutable()); + } } - - memcpy(reinterpret_cast(&llen), - file_content_buffer_ptr_ + ptr_offset, - sizeof(int)); - ptr_offset += sizeof(int); - - memcpy(reinterpret_cast(batch_id_buffer_ + tlen), - file_content_buffer_ptr_ + ptr_offset, - llen * sizeof(int)); - tlen += llen; - - offset[inst_idx + 1] = offset[inst_idx] + llen; - ptr_offset += sizeof(int) * llen; - - memcpy(reinterpret_cast(label_ptr_ + inst_idx), - file_content_buffer_ptr_ + ptr_offset, - sizeof(int)); - ptr_offset += sizeof(int); - - file_content_buffer_ptr_ += ptr_offset; - inst_idx++; } +} - if (inst_idx != batch_size_) { +bool DataFeed::SetFileList(const std::vector& files) { + if (CheckInit() == false) {return false;} + if (files.size() == 0) { + LOG(ERROR) << "error: you have set an empty filelist"; return false; } + filelist_.assign(files.begin(), files.end()); + file_idx_ = 0; - LoD input_lod{offset}; - paddle::framework::Vector label_offset; - label_offset.resize(batch_size_ + 1); - for (int i = 0; i <= batch_size_; ++i) { - label_offset[i] = i; - } + finish_set_filelist_ = true; + return true; +} - LoD label_lod{label_offset}; - int64_t* input_ptr = feed_vec_[0]->mutable_data( - {static_cast(offset.back()), 1}, - platform::CPUPlace()); - int64_t* label_ptr = feed_vec_[1]->mutable_data({batch_size_, 1}, - platform::CPUPlace()); - for (unsigned int i = 0; i < offset.back(); ++i) { - input_ptr[i] = static_cast(batch_id_buffer_[i]); - } - for (int i = 0; i < batch_size_; ++i) { - label_ptr[i] = static_cast(label_ptr_[i]); +bool DataFeed::PickOneFile(std::string& filename) { + std::unique_lock lock(mutex_for_pick_file_); + if (file_idx_ == filelist_.size()) { + return false; } - feed_vec_[0]->set_lod(input_lod); - feed_vec_[1]->set_lod(label_lod); + filename = filelist_[file_idx_++]; return true; } -TextClassDataFeed::TextClassDataFeed(const TextClassDataFeed& data_feed) { - Init(); - SetBatchSize(data_feed.batch_size_); - SetFieldNames(data_feed.field_names_); +bool DataFeed::CheckInit() { + if (finish_init_) {return true;} + LOG(ERROR) << "error: initialization did not succeed"; + return false; } -void TextClassDataFeed::AddFeedVar(Variable* feed, const std::string& name) { - for (unsigned int i = 0; i < use_slot_alias_.size(); ++i) { - if (name == use_slot_alias_[i]) { - feed_vec_[i] = feed->GetMutable(); - } +bool DataFeed::CheckSetFileList() { + if (finish_set_filelist_) {return true;} + LOG(ERROR) << "error: set filelist did not succeed"; + return false; +} + +bool DataFeed::CheckStart() { + if (finish_start_) {return true;} + LOG(ERROR) << "error: Datafeed has not started running yet"; + return false; +} + +template +void PrivateQueueDataFeed::SetQueueSize(int queue_size) { + if (!CheckInit()) {return;} + if (queue_size <= 0) { + LOG(ERROR) << "error: illegal queue size: " << queue_size; + return; } + queue_size_ = queue_size; + queue_.ReCap(queue_size_); } -void TextClassDataFeed::SetFileList(const char* filelist) { - s_filelist_.clear(); - std::ifstream fin(filelist); - PADDLE_ENFORCE(fin.good(), - "Opening file %s fail", - filelist); +template +bool PrivateQueueDataFeed::Start() { + if (!(CheckSetFileList())) {return false;} + read_thread_ = std::thread(&PrivateQueueDataFeed::ReadThread, this); + read_thread_.detach(); + + finish_start_ = true; + return true; +} + +template +void PrivateQueueDataFeed::ReadThread(){ std::string filename; - while (fin >> filename) { - LOG(ERROR) << "add " << filename.c_str() << " to filelist"; - s_filelist_.push_back(filename); + while (PickOneFile(filename)) { + file_.open(filename.c_str()); // is_text_feed + if (!file_.is_open()) { + LOG(ERROR) << "error: open file<" << filename << "> fail"; + } + T instance; + while (ParseOneInstance(instance)) { + queue_.Send(instance); + } + file_.close(); } - fin.close(); + queue_.Close(); } -void TextClassDataFeed::SetFieldNames( - const std::vector& field_names) { - field_names_.clear(); - field_names_.insert(field_names_.end(), field_names.begin(), - field_names.end()); +template +bool PrivateQueueDataFeed::Next(){ + if (!CheckStart()) {return false;} + int index = 0; + T instance; + T ins_vec(use_slots_.size()); + while (index < default_batch_size_) { + if (!queue_.Receive(&instance)) { + break; + } + AddInstanceToInsVec(ins_vec, instance, index++); + } + batch_size_ = index; + PutToFeedVec(ins_vec); + return batch_size_ != 0; } -bool TextClassDataFeed::SetFile(const char* filename) { - // termnum termid termid ... termid label - std::ifstream ifs(filename, std::ios::binary); - if (ifs.fail()) { - return false; +void MultiSlotDataFeed::Init(paddle::DataFeedDesc& data_feed_desc) { + finish_init_ = false; + finish_set_filelist_ = false; + finish_start_ = false; + if (!data_feed_desc.has_multi_slot_desc()){ + LOG(ERROR) << "error: multi_slot_desc has not been set"; + return ; } - - ifs.seekg(0, std::ios::end); - int filesize = ifs.tellg(); - ifs.seekg(0, std::ios::beg); - ifs.read(file_content_buffer_, filesize); - if (filesize < 0 || filesize >= 1024 * 1024 * 1024) { - return false; + paddle::MultiSlotDesc multi_slot_desc = data_feed_desc.multi_slot_desc(); + size_t all_slot_num = multi_slot_desc.slots_size(); + all_slots_.resize(all_slot_num); + all_slots_type_.resize(all_slot_num); + use_slots_index_.resize(all_slot_num); + use_slots_.clear(); + use_slots_is_dense_.clear(); + for (size_t i = 0; i < all_slot_num; ++i) { + auto& slot = multi_slot_desc.slots(i); + all_slots_[i] = slot.name(); + all_slots_type_[i] = slot.type(); + use_slots_index_[i] = slot.use() ? use_slots_.size() : -1; + if (slot.use()) { + use_slots_.push_back(all_slots_[i]); + use_slots_is_dense_.push_back(slot.dense()); + } } - file_content_buffer_ptr_ = file_content_buffer_; - file_size_ = filesize; - // todo , remove magic number + feed_vec_.resize(use_slots_.size()); - return true; + finish_init_ = true; } -void TextClassDataFeed::UpdateEpochNum() { - s_current_finished_file_cnt_++; - - if (s_current_finished_file_cnt_ >= s_filelist_.size()) { - s_current_finished_file_cnt_ = 0; - s_current_epoch_++; -#if 1 - LOG(WARNING) << "UpdateEpochNum: epoch = " << s_current_epoch_; -#endif - { - std::lock_guard lock(s_locker_epoch_start_); - s_epoch_start_flag_ = false; +bool MultiSlotDataFeed::ParseOneInstance(std::vector& instance) { + std::string line; + if (getline(file_, line)) { + int use_slots_num = use_slots_.size(); + instance.resize(use_slots_num); + //parse line + const char* str = line.c_str(); + char* endptr = (char*)str; + int pos = 0; + for (size_t i = 0; i < use_slots_index_.size(); ++i) { + int idx = use_slots_index_[i]; + int num = (int)strtol(&str[pos], &endptr, 10); + if (num == 0) { + LOG(ERROR) << "error: the number of ids can not be zero, you need padding it"; + exit(-1); + } + if (idx != -1) { + instance[idx].SetType(all_slots_type_[i]); + if (instance[idx].GetType()[0] == 'f') { // float + for (int j = 0; j < num; ++j) { + float feasign = (float)strtof(endptr, &endptr); + instance[idx].AddValue(feasign); + } + } else if (instance[idx].GetType()[0] == 'u'){ // uint64 + for (int j = 0; j < num; ++j) { + uint64_t feasign = (uint64_t)strtoull(endptr, &endptr, 10); + instance[idx].AddValue(feasign); + } + } + pos = endptr - str; + } else { + for (int j = 0; j <= num; ++j) { + pos = line.find_first_of(' ', pos + 1); + } + } } + } else { + return false; } + return true; } -void TextClassDataFeed::StartOneEpoch() { - std::lock_guard lock(s_locker_for_pick_file_); - std::random_shuffle(s_filelist_.begin(), s_filelist_.end()); - s_current_file_idx_ = 0; - LOG(INFO) << "Beginning epoch " << s_current_epoch_; - - { - std::lock_guard lock(s_locker_epoch_start_); - s_epoch_start_flag_ = true; +void MultiSlotDataFeed::AddInstanceToInsVec(std::vector& ins_vec, + std::vector& instance, int index) { + if (index == 0) { + for (size_t i = 0; i < instance.size(); ++i) { + ins_vec[i].SetType(instance[i].GetType()); + } + } + for (size_t i = 0; i < instance.size(); ++i){ + ins_vec[i].AddIns(instance[i]); } - s_condition_epoch_start_.notify_all(); -} - -void TextClassDataFeed::WaitNextEpoch() { - std::unique_lock lock(s_locker_epoch_start_); - s_condition_epoch_start_.wait(lock, []{return s_epoch_start_flag_;}); } - -const char* TextClassDataFeed::PickOneFile() { - std::string file_to_be_processed; - std::lock_guard lock(s_locker_for_pick_file_); - - // One epoch has run over - // Wait for next epoch - if (s_current_file_idx_ >= s_filelist_.size()) { - LOG(ERROR) << "thread " << thread_id_ - << ": finish traing for epoch " << s_current_epoch_ + 1; - - return NULL; +void MultiSlotDataFeed::PutToFeedVec(std::vector& ins_vec) { + for (size_t i = 0; i < use_slots_.size(); ++i) { + auto& type = ins_vec[i].GetType(); + auto& offset = ins_vec[i].GetOffset(); + int total_instance = static_cast(offset.back()); + if (type[0] == 'f') { // float + auto& feasign = ins_vec[i].GetFloatData(); + if (feed_vec_[i].IsDense()) { + int size_in_each_batch = total_instance / batch_size_; + float* tensor_ptr = feed_vec_[i].GetTensor()-> + mutable_data({batch_size_, size_in_each_batch}, platform::CPUPlace()); + memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(float)); + } else { + float* tensor_ptr = feed_vec_[i].GetLoDTensor()-> + mutable_data({total_instance, 1}, platform::CPUPlace()); + memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(float)); + LoD data_lod{offset}; + feed_vec_[i].GetLoDTensor()->set_lod(data_lod); + } + } else if (type[0] == 'u') { // uint64 + // no uint64_t type + auto& feasign = ins_vec[i].GetUint64Data(); + if (feed_vec_[i].IsDense()) { + int size_in_each_batch = total_instance / batch_size_; + int64_t* tensor_ptr = feed_vec_[i].GetTensor()-> + mutable_data({batch_size_, size_in_each_batch}, platform::CPUPlace()); + memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(int64_t)); + } else { + int64_t* tensor_ptr = feed_vec_[i].GetLoDTensor()-> + mutable_data({total_instance, 1}, platform::CPUPlace()); + memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(uint64_t)); + LoD data_lod{offset}; + feed_vec_[i].GetLoDTensor()->set_lod(data_lod); + } + } } - - file_to_be_processed = s_filelist_[s_current_file_idx_]; - - s_current_file_idx_++; - return file_to_be_processed.c_str(); } } // namespace framework diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h index f5660357788eecc73f85a3cfba20b1b67d6da6c4..f67ead48bc70a8526c9eff237303b36bbd5d172c 100644 --- a/paddle/fluid/framework/data_feed.h +++ b/paddle/fluid/framework/data_feed.h @@ -27,136 +27,335 @@ limitations under the License. */ #include #include // NOLINT #include +#include +#include #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/data_feed.pb.h" namespace paddle { namespace framework { -struct Gauc { - int show, click; - uint64_t fea; - std::string lineid; -}; -struct Instance { - std::vector> feed_vec_buffer; - std::vector> feed_vec_lod; - std::vector other_label; - std::vector gauc_vec; +class MixTensor { + public: + MixTensor(){} + MixTensor(LoDTensor* lodtensor) { + is_dense_ = false; + lodtensor_ = lodtensor; + } + MixTensor(Tensor* tensor) { + is_dense_ = true; + tensor_ = tensor; + } + bool IsDense() {return is_dense_;} + LoDTensor* GetLoDTensor(){ + if (is_dense_) { + LOG(ERROR) << "error: let a dense var return a LoDTensor ptr"; + return NULL; + } + return lodtensor_; + } + Tensor* GetTensor(){ + if (!is_dense_) { + LOG(ERROR) << "error: let a sparse var return a Tensor ptr"; + return NULL; + } + return tensor_; + } + private: + bool is_dense_; + LoDTensor* lodtensor_; + Tensor* tensor_; }; -class DataFeed { +template +class BlockingQueue { public: - DataFeed() : default_batch_size_(1), batch_size_(0), thread_id_(0) {} - virtual ~DataFeed() {} - virtual void Init() = 0; - /* - * This function will be used to check file format. - * Considering that this function may be used alone, - * it does not check anything. - * */ - virtual bool CheckFile(const char* filename) = 0; - virtual bool SetFile(const char* filename) = 0; - virtual bool ReadBatch() = 0; - virtual const std::vector& GetAllSlotIds() { - return all_slot_ids_; + explicit BlockingQueue(size_t capacity = 32) + : capacity_(capacity), closed_(false) { + size_.store(0); } - - virtual const std::vector& GetUseSlotIds() { - return use_slot_ids_; + + void ReCap(size_t capacity) { + capacity_ = capacity; } - virtual const std::vector& GetUseSlotAlias() { - return use_slot_alias_; - } + bool Send(const T& elem) { + int c = -1; + { + std::unique_lock lock(send_mutex_); + send_cv_.wait(lock, [&] {return size_.load() < capacity_ || closed_;}); + if (closed_) { + VLOG(5) + << "WARNING: Sending an element to a closed reader::BlokcingQueue."; + return false; + } + queue_.push_back(elem); + c = size_.load(); + size_.fetch_add(1); + } + if (c + 1 < capacity_) { + send_cv_.notify_one(); + } - virtual void AddFeedVar(Variable* var, - const std::string& name) = 0; - virtual void BindScope(Scope* scope) = 0; - virtual void SetBatchSize(int batch) { default_batch_size_ = batch; } - virtual int GetBatchSize() { return batch_size_; } - virtual void SetBufferSize(int buffer_size) {} - virtual unsigned int GetCurrentEpoch() = 0; - virtual const char *PickOneFile() = 0; - virtual void UpdateEpochNum() = 0; - virtual void StartOneEpoch() = 0; - virtual void WaitNextEpoch() = 0; + if (c == 0) { + std::unique_lock lock(receive_mutex_); + receive_cv_.notify_one(); + } + return true; + } - std::vector& GetFeedVec() { - return feed_vec_; + bool Receive(T* elem) { + int c = -1; + { + std::unique_lock lock(receive_mutex_); + receive_cv_.wait(lock, [&] {return size_.load() != 0 || closed_;}); + if (size_.load() != 0) { + *elem = queue_.front(); + queue_.pop_front(); + c = size_.load(); + size_.fetch_sub(1); + } else { + return false; + } + } + if (c > 1) { + receive_cv_.notify_one(); + } + if (c == capacity_) { + std::unique_lock lock(send_mutex_); + send_cv_.notify_one(); + } + return true; } - virtual std::vector& GetFeedVec(const Instance& ins) { - LOG(ERROR) << "use defalut get_feed_vec"; - return feed_vec_; + void Close() { + std::lock_guard lock1(send_mutex_); + std::lock_guard lock2(receive_mutex_); + closed_ = true; + send_cv_.notify_all(); + receive_cv_.notify_all(); } - int GetThreadId() {return thread_id_;} - void SetThreadId(int thread_id) {thread_id_ = thread_id;} + private: + size_t capacity_; + std::atomic_size_t size_; + bool closed_; + std::deque queue_; + + mutable std::mutex send_mutex_; + mutable std::mutex receive_mutex_; + mutable std::condition_variable send_cv_; + mutable std::condition_variable receive_cv_; +}; +class DataFeed { + public: + DataFeed() {} + virtual ~DataFeed() {} + virtual void Init(paddle::DataFeedDesc& data_feed_desc) = 0; + // for some datafeeds may not be able to implement this interface + virtual bool CheckFile(const char* filename) { + LOG(ERROR) << "error: The function CheckFile is not implemented"; + return false; + } + virtual bool SetFileList(const std::vector& files); + virtual bool Start() = 0; + virtual bool Next() = 0; + virtual void SetBatchSize(int batch) { default_batch_size_ = batch; } + virtual int GetBatchSize() { return batch_size_; } + // for subclass with queue + virtual void SetQueueSize(int queue_size) { + LOG(ERROR) << "error: The function SetQueueSize is not implemented"; + } + // for subclass with buffer + virtual void SetBufferSize(int buffer_size) { + LOG(ERROR) << "error: The function SetBufferSize is not implemented"; + } + virtual const std::vector& GetAllSlots() {return all_slots_;} + virtual const std::vector& GetUseSlots() {return use_slots_;} + std::vector& GetFeedVec() {return feed_vec_;} + virtual void AddFeedVar(Variable* var, const std::string& name); protected: - std::vector all_slot_ids_; - std::vector use_slot_ids_; - std::vector use_slot_alias_; - std::vector feed_vec_; + // Check if it is executed in this order: + // Init -> SetFileList/BindingMemory -> Start -> Next + virtual bool CheckInit(); + virtual bool CheckSetFileList(); + virtual bool CheckStart(); + virtual bool PickOneFile(std::string& filename); + + static std::vector filelist_; + static size_t file_idx_; + static std::mutex mutex_for_pick_file_; + + std::vector use_slots_; + std::vector use_slots_is_dense_; + + std::vector all_slots_; + std::vector all_slots_type_; + std::vector use_slots_index_; // -1: not used; >=0: the index of use_slots_ + + std::vector feed_vec_; + int default_batch_size_; int batch_size_; - int thread_id_; + + bool finish_init_; + bool finish_set_filelist_; + bool finish_binding_memory_; + bool finish_start_; }; -class TextClassDataFeed : public DataFeed { +template +class PrivateQueueDataFeed : public DataFeed { public: - TextClassDataFeed(); - TextClassDataFeed(const TextClassDataFeed& data_feed); + PrivateQueueDataFeed() {} + virtual ~PrivateQueueDataFeed() {} + virtual void Init(paddle::DataFeedDesc& data_feed_desc) = 0; + virtual bool Start(); + virtual bool Next(); // no buffer + virtual void SetQueueSize(int queue_size); + + protected: + virtual void ReadThread(); + virtual bool ParseOneInstance(T& instance) = 0; + virtual void AddInstanceToInsVec(T& vec_ins, T& instance, int index) = 0; + virtual void PutToFeedVec(T& ins_vec) = 0; + std::thread read_thread_; // the thread for read files + /* using ifstream one line and one line parse is faster + * than using fread one buffer and one buffer parse. + * for 601M JingPai data: + * ifstream one line and one line parse: 6034 ms + * fread one buffer and one buffer parse: 7097 ms */ + std::ifstream file_; + size_t queue_size_; + // The elements in the queue are one piece of data, + // with multiple fields in each piece of data + BlockingQueue queue_; +}; + +class MultiSlotType { + public: + MultiSlotType() { + float_feasign_.clear(); + uint64_feasign_.clear(); + offset_.resize(1); + offset_[0] = 0; + } + ~MultiSlotType() {} + void SetType(std::string& type) { + if (!CheckType(type)) {return;} + type_ = type; + } + std::vector& GetOffset() { + return offset_; + } + void AddValue(float v) { + if (!CheckFloat()) {return;} + float_feasign_.push_back(v); + } + void AddValue(uint64_t v) { + if (!CheckUint64()) {return;} + uint64_feasign_.push_back(v); + } + void AddIns(MultiSlotType& ins) { + if (ins.GetType()[0] == 'f') { //float + if (!CheckFloat()) {return;} + auto& vec = ins.GetFloatData(); + offset_.push_back(offset_.back() + vec.size()); + float_feasign_.insert(float_feasign_.end(), vec.begin(), vec.end()); + } else if (ins.GetType()[0] == 'u') { //uint64 + if (!CheckUint64()) {return;} + auto& vec = ins.GetUint64Data(); + offset_.push_back(offset_.back() + vec.size()); + uint64_feasign_.insert(uint64_feasign_.end(), vec.begin(), vec.end()); + } + } + std::vector& GetFloatData() { + return float_feasign_; + } + std::vector& GetUint64Data() { + return uint64_feasign_; + } + std::string& GetType() { + return type_; + } + private: + bool CheckType(std::string& type) { + if (type != "uint64" && type != "float") { + // check in here + LOG(ERROR) << "error: here is no this type"; + return false; + } + return true; + } + bool CheckFloat() { + if (type_[0] != 'f') { //float + LOG(ERROR) << "error: add " << type_ << " value to float slot"; + return false; + } + return true; + } + bool CheckUint64() { + if (type_[0] != 'u') { //uint64 + LOG(ERROR) << "error: add " << type_ << " value to uint64 slot"; + return false; + } + return true; + } + std::vector float_feasign_; + std::vector uint64_feasign_; + std::string type_; + std::vector offset_; +}; + +class MultiSlotDataFeed : public PrivateQueueDataFeed> { + public: + MultiSlotDataFeed() {} + virtual ~MultiSlotDataFeed() {} + virtual void Init(paddle::DataFeedDesc& data_feed_desc); + //TODO: virtual bool CheckFile(); + protected: + virtual void AddInstanceToInsVec(std::vector& vec_ins, + std::vector& instance, int index); + virtual bool ParseOneInstance(std::vector& instance); + virtual void PutToFeedVec(std::vector& ins_vec); +}; + + +//TODO: to be deleted +class TextClassDataFeed : public DataFeed { public: virtual ~TextClassDataFeed() {} - virtual void Init(); - virtual bool ReadBatch(); - virtual void AddFeedVar(Variable* feed, const std::string& name); + virtual void Init(paddle::DataFeedDesc& data_feed_desc) {} + virtual bool Start() {return false;}; //TODO + virtual bool Next() {return false;}; //TODO + virtual bool ReadBatch() {return false;} + virtual void AddFeedVar(Variable* feed, const std::string& name) {} virtual void BindScope(Scope* scope) {} - virtual bool SetFile(const char* filename); + virtual bool SetFile(const char* filename) {return false;} + virtual bool CheckFile(const char* filename) { // TODO(xxx) return false; } - void SetBatchSize(int batch) {batch_size_ = batch;} - unsigned int GetCurrentEpoch() {return s_current_epoch_;} - void UpdateEpochNum(); - void StartOneEpoch(); - void WaitNextEpoch(); - - public: - void SetFieldNames(const std::vector& field_names); - - public: - static void SetFileList(const char* filelist); - private: - const char* PickOneFile(); + void SetBatchSize(int batch) {batch_size_ = batch;} private: + int ReadWholeFile(const std::string& filename, char* buffer) {return -1;} char* file_content_buffer_; char* file_content_buffer_ptr_; int* batch_id_buffer_; int* label_ptr_; int file_size_; - std::vector field_names_; + std::vector names_; std::shared_ptr file_content_buffer_host_; std::shared_ptr batch_id_host_; std::shared_ptr label_host_; - - static std::vector s_filelist_; - static std::mutex s_locker_for_pick_file_; - static unsigned int s_current_file_idx_; - static size_t s_current_finished_file_cnt_; - static unsigned int s_current_epoch_; - static int s_current_save_epoch_; - static std::mutex s_locker_epoch_start_; - static std::condition_variable s_condition_epoch_start_; - static bool s_epoch_start_flag_; }; } // namespace framework diff --git a/paddle/fluid/framework/data_feed.proto b/paddle/fluid/framework/data_feed.proto index 284627e3525b4f636de8f67555afab81d3ab6053..319ae27607471b4d6ce83fb2322fbe650d474fdc 100644 --- a/paddle/fluid/framework/data_feed.proto +++ b/paddle/fluid/framework/data_feed.proto @@ -17,5 +17,16 @@ package paddle; message DataFeedDesc { optional string name = 1; optional int32 batch = 2 [default = 32]; + optional MultiSlotDesc multi_slot_desc = 3; } +message MultiSlotDesc { + repeated Slot slots = 1; +} + +message Slot { + required string name = 1; + required string type = 2; + optional bool dense = 3 [default = false]; + optional bool use = 4 [default = true]; +}