/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #ifndef PADDLE_FLUID_FRAMEWORK_DATA_FEED_H_ #define PADDLE_FLUID_FRAMEWORK_DATA_FEED_H_ #include #include #include #include #include // NOLINT #include #include #include // NOLINT #include #include #include // NOLINT #include #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h" #include "proto/FeedDataParameter.pb.h" namespace paddle { namespace framework { struct Gauc { int show, click; uint64_t fea; std::string lineid; }; struct Instance { std::vector> feed_vec_buffer; std::vector> feed_vec_lod; std::vector other_label; std::vector gauc_vec; }; class DataFeed { public: DataFeed() {} virtual ~DataFeed() {} virtual void Init(const datafeed::DataFeedParameter& feed_param) = 0; /* * This function will be used to check file format. * Considering that this function may be used alone, * it does not check anything. * */ virtual bool CheckFile(const char* filename) = 0; virtual bool SetFile(const char* filename) = 0; virtual bool ReadBatch() = 0; virtual const std::vector& GetAllSlotIds() { return all_slot_ids_; } virtual const std::vector& GetUseSlotIds() { return use_slot_ids_; } virtual const std::vector& GetUseSlotAlias() { return use_slot_alias_; } virtual void AddFeedVar(Variable* var, const std::string& name) = 0; virtual void BindScope(Scope* scope) = 0; virtual void SetBatchSize(int batch) { default_batch_size_ = batch; } virtual int GetBatchSize() { return batch_size_; } virtual void SetBufferSize(int buffer_size) {} std::vector& GetFeedVec() { return feed_vec_; } virtual std::vector& GetFeedVec(const Instance& ins) { LOG(ERROR) << "use defalut get_feed_vec"; return feed_vec_; } protected: std::vector all_slot_ids_; std::vector use_slot_ids_; std::vector use_slot_alias_; std::vector feed_vec_; int default_batch_size_; int batch_size_; }; class TextClassDataFeed : public DataFeed { public: virtual ~TextClassDataFeed() {} virtual void Init(const datafeed::DataFeedParameter& feed_param); virtual bool ReadBatch(); virtual void AddFeedVar(Variable* feed, const std::string& name); virtual void BindScope(Scope* scope) {} virtual bool SetFile(const char* filename); virtual bool CheckFile(const char* filename) { // TODO(xxx) return false; } void SetBatchSize(int batch) {batch_size_ = batch;} private: int ReadWholeFile(const std::string& filename, char* buffer); char* file_content_buffer_; char* file_content_buffer_ptr_; int* batch_id_buffer_; int* label_ptr_; int file_size_; std::vector names_; std::shared_ptr file_content_buffer_host_; std::shared_ptr batch_id_host_; std::shared_ptr label_host_; }; } // namespace framework } // namespace paddle #endif // PADDLE_FLUID_FRAMEWORK_DATA_FEED_H_ /* vim: set expandtab ts=2 sw=2 sts=2 tw=100: */