From 8ee8133ab841196925a2812b76f18d2812a6701d Mon Sep 17 00:00:00 2001 From: dongdaxiang Date: Fri, 16 Nov 2018 20:55:17 +0800 Subject: [PATCH] add some files about datafeed --- paddle/fluid/framework/data_feed.cc.yebaiwei | 411 ++++++++++++++++++ paddle/fluid/framework/data_feed.h.yebaiwei | 368 ++++++++++++++++ .../fluid/framework/data_feed.proto.yebaiwei | 32 ++ 3 files changed, 811 insertions(+) create mode 100644 paddle/fluid/framework/data_feed.cc.yebaiwei create mode 100644 paddle/fluid/framework/data_feed.h.yebaiwei create mode 100644 paddle/fluid/framework/data_feed.proto.yebaiwei diff --git a/paddle/fluid/framework/data_feed.cc.yebaiwei b/paddle/fluid/framework/data_feed.cc.yebaiwei new file mode 100644 index 00000000000..03c5a2dc32d --- /dev/null +++ b/paddle/fluid/framework/data_feed.cc.yebaiwei @@ -0,0 +1,411 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include +#include +#include +#include +#include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" +#include "google/protobuf/io/zero_copy_stream_impl.h" + +#include "gflags/gflags.h" +#include "paddle/fluid/framework/feed_fetch_method.h" +#include "paddle/fluid/framework/feed_fetch_type.h" +#include "paddle/fluid/framework/lod_rank_table.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/reader.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/platform/profiler.h" +#include "paddle/fluid/framework/data_feed.h" + +DEFINE_bool(is_text_feed, false, "is_text_feed"); + +namespace paddle { +namespace framework { + +std::vector DataFeed::filelist_; +size_t DataFeed::file_idx_; +std::mutex DataFeed::mutex_for_pick_file_; + +void DataFeed::AddFeedVar(Variable* var, const std::string& name) { + if (CheckInit() == false) {return;} + for (size_t i = 0; i < use_slots_.size(); ++i) { + if (name == use_slots_[i]) { + if (use_slot_is_dense_[i]) { + feed_vec[i]_ = MixTensor(var->GetMutable()); + } else { + feed_vec[i]_ = MixTensor(var->GetMutable()); + } + } + } +} + +bool DataFeed::SetFileList(const std::vector& files) { + if (CheckInit() == false) {return false;} + if (files.size() == 0) { + LOG(ERROR) << "error: you have set an empty filelist"; + return false; + } + filelist_.assign(files.begin(), files.end()); + file_idx_ = 0; + + finish_set_filelist_ = true; + return true; +} + +bool DataFeed::PickOneFile(std::string& filename) { + std::unique_lock lock(mutex_for_pick_file_); + if (file_idx_ == filelist_.size()) { + return false; + } + filename = filelist_[file_idx++]; + return true; +} + +bool DataFeed::CheckInit() { + if (finish_init_) {return true;} + LOG(ERROR) << "error: initialization did not succeed"; + return false; +} + +bool DataFeed::CheckSetFileList() { + if (finish_set_filelist_) {return true;} + LOG(ERROR) << "error: set filelist did not succeed"; + return false; +} + +bool DataFeed::CheckStart() { + if (finish_start_) {return true;} + LOG(ERROR) << "error: Datafeed has not started running yet"; + return false; +} + +template +void PrivateQueueDataFeed::SetQueueSize(int queue_size) { + if (!CheckInit()) {return false;} + if (queue_size <= 0) { + LOG(ERROR) << "error: illegal queue size: " << queue_size; + return; + } + queue_ = BlockingQueue>(queue_size_); +} + +template +bool PrivateQueueDataFeed::Start() { + if (!(CheckSetFileList())) {return false;} + read_thread_ = std::thread(&PrivateQueueDataFeed::ReadThread, this); + read_thread_.detach(); + + finish_start_ = true; +} + +template +void PrivateQueueDataFeed::ReadThread(){ + std::string filename; + while (PickOneFile(filename)) { + if (is_text_fees) { + file_.open(filename.c_str()); + } else { + LOG(ERROR) << "error: binary DataFeed is not implemented"; + } + if (!file_.is_open()) { + LOG(ERROR) << "error: open file<" << filename << "> fail"; + } + std::vector instance; + while (ParseOneInstance(instance)) { + queue_.Send(instance); + } + file_.close(); + } + queue_.Close(); +} + +template +bool PrivateQueueDataFeed::Next(){ + if (!CheckStart()) {return false;} + int index = 0; + std::vector instance; + std::vector ins_vec(use_slots_.size()); + while (index < default_batch_size_) { + if (!queue_.Receive(&instance)) { + break; + } + if (index == 0) { + for (auto& slot : ins_vec) { + ins_vec.SetType(instance.GetType()); + } + } + for (auto& slot : ins_vec) { + ins_vec.AddIns(instance); + } + ++index; + } + batch_size_ = index; + PutToFeedVec(ins_vec); + return batch_size_ != 0; +} + +void MultiSlotDataFeed::Init(paddle::DataFeedDesc& data_feed_desc) { + finish_init_ = false; + finish_set_filelist_ = false; + finish_start_ = false; + if (!data_feed_decs.has_multi_slot_desc()){ + LOG(ERROR) << "error: multi_slot_desc has not been set"; + return ; + } + paddle::MultiSlotDesc multi_slot_desc = data_feed_desc.multi_slot_desc(); + size_t all_slot_num = multi_slot_desc.slots_size(); + all_slots_.resize(all_slot_num); + all_slots_type_.resize(all_slot_num); + use_slots_index_.resize(all_slot_num); + use_slots_.clear(); + use_slots_is_dense_.clear(); + for (size_t i = 0; i < all_slot_num; ++i) { + auto& slot = multi_slot_desc.slots(i); + all_slots_[i] = slot.name(i); + all_slots_type_[i] = slot.type(i); + use_slots_index_[i] = slot.use(i) ? use_slots_.size() : -1; + if (is_used_[i]) { + use_slots_.push_back(all_slots_[i]); + use_slots_is_dense_.push_back(slot.dense(i)): + } + } + feed_vec_.resize(use_slots_.size()); + + finish_init_ = true; +} + +bool MultiSlotDataFeed::ParseOneInstance(std::vector& instance) { + std::string line; + if (getline(fin, line)) { + int use_slots_num = use_slots_.size(); + instance.resize(use_slots_num); + //parse line + int len = line.length(); + const char* str = line.c_str(); + char* endptr = str; + int pos = 0; + for (size_t i = 0; i < use_slots_index_.size(); ++i) { + int idx = use_slots_index_[i]; + int num = (int)strtol(&str[pos], &endptr, 10); + if (num == 0) { + LOG(ERROR) << "error: the number of ids can not be zero, you need padding it"; + exit(-1); + } + if (idx != -1) { + instance[idx].SetType(all_slots_type_[i]); + if (instance[idx].GetType()[0] == 'f') { // float + for (int j = 0; j < num; ++j) { + float feasign = (float)strtof(endptr, &endptr); + instance[idx].AddValue(feasign); + } + } else if (instance[idx].GetType()[0] == 'u'){ // uint64 + for (int j = 0; j < num; ++j) { + uint64_t feasign = (uint64_t)strtoull(endptr, &endptr, 10); + instance[idx].AddValue(feasign); + } + } + pos = endptr - str; + } else { + for (int j = 0; j <= num; ++j) { + pos = line.find_first_of(' ', pos + 1); + } + } + } + } else { + return false; + } +} + +void MultiSlotDataFeed::PutToFeedVec(std::vector& ins_vec) { + for (size_t i = 0; i < use_slots_.size(); ++i) { + auto& type = ins_vec[i].GetType(); + if (type[0] == 'f') { // float + auto& feasign = ins_vec[i].GetFloatData(); + if (_feed_vec[i].IsDense()) { + float* tensor_ptr = _feed_vec[i].GetTensor()-> + mutable_data({batch_size_, offset.back() / batch_size_}, + platform::CPUPlace(), offset.back() * sizeof(float)); + memcpy(tensor_ptr, &feasign[0], offset.back() * sizeof(float)); + } else { + float* tensor_ptr = _feed_vec[i].GetLoDTensor()-> + mutable_data({offset.back(), 1}, platform::CPUPlace()); + memcpy(tensor_ptr, &feasign[0], offset.back() * sizeof(float)); + auto& offset = ins_vec[i].GetOffset(); + LoD data_lod{offset}; + _feed_vec[i].GetLoDTensor()->set_lod(data_lod); + } + } else if (type[0] == 'u') { // uint64 + auto& feasign = ins_vec[i].GetUint64Data(); + if (_feed_vec[i].IsDense()) { + // no uint64_t type + int64_t* tensor_ptr = _feed_vec[i].GetTensor()-> + mutable_data({batch_size_, offset.back() / batch_size_}, + platform::CPUPlace(), offset.back() * sizeof(uint64_t)); + memcpy(tensor_ptr, &feasign[0], offset.back() * sizeof(uint64_t)); + } else { + int64_t* tensor_ptr = _feed_vec[i].GetLoDTensor()-> + mutable_data({offset.back(), 1}, platform::CPUPlace()); + memcpy(tensor_ptr, &feasign[0], offset.back() * sizeof(uint64_t)); + auto& offset = ins_vec[i].GetOffset(); + LoD data_lod{offset}; + _feed_vec[i].GetLoDTensor()->set_lod(data_lod); + } + } + } +} + + + + + + + + + + + + + + +void TextClassDataFeed::Init() { + // hard coding for a specific datafeed + feed_vec_.resize(2); + // feed_vec_[0].reset(new LoDTensor); + // feed_vec_[1].reset(new LoDTensor); + all_slot_ids_ = {0, 1}; + use_slot_ids_ = {0, 1}; + use_slot_alias_ = {"words", "label"}; + + file_content_buffer_host_.reset(new char[200*1024*1024], + [](char *p) {delete[] p;}); + file_content_buffer_ = file_content_buffer_host_.get(); + file_content_buffer_ptr_ = file_content_buffer_; + + batch_id_host_.reset(new int[10240*1024], + [](int *p) {delete[] p;}); // max word num in a batch + batch_id_buffer_ = batch_id_host_.get(); + + label_host_.reset(new int[10240], + [](int *p) {delete[] p;}); // max label in a batch + label_ptr_ = label_host_.get(); +} + + // todo: use elegant implemention for this function +bool TextClassDataFeed::ReadBatch() { + paddle::framework::Vector offset; + int tlen = 0; + int llen = 0; + int inst_idx = 0; + offset.resize(batch_size_ + 1); + offset[0] = 0; + while (inst_idx < batch_size_) { + int ptr_offset = 0; + if (file_content_buffer_ptr_ - file_content_buffer_ >= file_size_) { + break; + } + + memcpy(reinterpret_cast(&llen), + file_content_buffer_ptr_ + ptr_offset, + sizeof(int)); + ptr_offset += sizeof(int); + + memcpy(reinterpret_cast(batch_id_buffer_ + tlen), + file_content_buffer_ptr_ + ptr_offset, + llen * sizeof(int)); + tlen += llen; + + offset[inst_idx + 1] = offset[inst_idx] + llen; + ptr_offset += sizeof(int) * llen; + + memcpy(reinterpret_cast(label_ptr_ + inst_idx), + file_content_buffer_ptr_ + ptr_offset, + sizeof(int)); + ptr_offset += sizeof(int); + + file_content_buffer_ptr_ += ptr_offset; + inst_idx++; + } + + if (inst_idx != batch_size_) { + return false; + } + + LoD input_lod{offset}; + paddle::framework::Vector label_offset; + label_offset.resize(batch_size_ + 1); + for (int i = 0; i <= batch_size_; ++i) { + label_offset[i] = i; + } + + LoD label_lod{label_offset}; + int64_t* input_ptr = feed_vec_[0]->mutable_data( + {static_cast(offset.back()), 1}, + platform::CPUPlace()); + int64_t* label_ptr = feed_vec_[1]->mutable_data({batch_size_, 1}, + platform::CPUPlace()); + for (unsigned int i = 0; i < offset.back(); ++i) { + input_ptr[i] = static_cast(batch_id_buffer_[i]); + } + for (int i = 0; i < batch_size_; ++i) { + label_ptr[i] = static_cast(label_ptr_[i]); + } + feed_vec_[0]->set_lod(input_lod); + feed_vec_[1]->set_lod(label_lod); + return true; +} + +void TextClassDataFeed::AddFeedVar(Variable* feed, const std::string& name) { + for (unsigned int i = 0; i < use_slot_alias_.size(); ++i) { + if (name == use_slot_alias_[i]) { + feed_vec_[i] = feed->GetMutable(); + } + } +} + +bool TextClassDataFeed::SetFile(const char* filename) { + // termnum termid termid ... termid label + int filesize = ReadWholeFile(filename, file_content_buffer_); + // todo , remove magic number + if (filesize < 0 || filesize >= 1024 * 1024 * 1024) { + return false; + } + file_content_buffer_ptr_ = file_content_buffer_; + file_size_ = filesize; + return true; +} + +int TextClassDataFeed::ReadWholeFile(const std::string& filename, + char* buffer) { + std::ifstream ifs(filename.c_str(), std::ios::binary); + if (ifs.fail()) { + return -1; + } + + ifs.seekg(0, std::ios::end); + int file_size = ifs.tellg(); + ifs.seekg(0, std::ios::beg); + ifs.read(buffer, file_size); + return file_size; +} + +} // namespace framework +} // namespace paddle +/* vim: set expandtab ts=2 sw=2 sts=2 tw=100: */ + diff --git a/paddle/fluid/framework/data_feed.h.yebaiwei b/paddle/fluid/framework/data_feed.h.yebaiwei new file mode 100644 index 00000000000..47051360df1 --- /dev/null +++ b/paddle/fluid/framework/data_feed.h.yebaiwei @@ -0,0 +1,368 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifndef PADDLE_FLUID_FRAMEWORK_DATA_FEED_H_ +#define PADDLE_FLUID_FRAMEWORK_DATA_FEED_H_ + +#include +#include +#include +#include +#include // NOLINT +#include +#include +#include // NOLINT +#include +#include +#include // NOLINT +#include +#include +#include + +#include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/scope.h" + +namespace paddle { +namespace framework { + +class MixTensor { + public: + MixTensor(LoDTensor* lodtensor) { + is_dense_ = false; + lodtensor_ = lodtensor; + } + MixTensor(Tensor* tensor) { + is_dense_ = true; + tensor_ = tensor; + } + bool IsDense() {return is_dense_;} + LoDTensor* GetLoDTensor(){ + if (is_dense_) { + LOG(ERROR) << "error: let a dense var return a LoDTensor ptr"; + return NULL; + } + return lodtensor_; + } + Tensor* GetTensor(){ + if (!is_dense_) { + LOG(ERROR) << "error: let a sparse var return a Tensor ptr"; + return NULL; + } + return tensor_; + } + private: + bool is_dense_; + LoDTensor* lodtensor_; + Tensor* tensor_; +}; + +template +class BlockingQueue { + public: + BlockingQueue() : capacity_(32) {} + explicit BlockingQueue(size_t capacity) + : capacity_(capacity), closed_(false) { + size_.store(0); + } + + bool Send(const T& elem) { + int c = -1; + { + std::unique_lock lock(send_mutex_); + send_cv_.wait(lock, [&] {return size_.load() < capacity_ || closed_;}); + if (closed_) { + VLOG(5) + << "WARNING: Sending an element to a closed reader::BlokcingQueue."; + return false; + } + queue_.push_back(elem); + c = size_.load(); + size_.fetch_add(1); + } + if (c + 1 < capacity_) { + send_cv_.notify_one(); + } + + if (c == 0) { + std::unique_lock lock(receive_mutex_); + receive_cv_.notify_one(); + } + return true; + } + + bool Receive(T* elem) { + int c = -1; + { + std::unique_lock lock(receive_mutex_); + receive_cv_.wait(lock, [&] {return size_.load() != 0 || closed_;}); + if (size_.load() != 0) { + *elem = queue_.front(); + queue_.pop_front(); + c = size_.load(); + size_.fetch_sub(1); + } else { + return false; + } + } + if (c > 1) { + receive_cv_.notify_one(); + } + if (c == capacity_) { + std::unique_lock lock(send_mutex_); + send_cv_.notify_one(); + } + return true; + } + + void Close() { + { + std::lock_guard lock1(send_mutex_); + std::lock_guard lock2(receive_mutex_); + closed_ = true; + } + send_cv_.notify_all(); + receive_cv_.notify_all(); + } + + bool IsClosed() const { + std::lock_guard lock1(send_mutex_); + std::lock_guard lock2(receive_mutex_); + return closed_; + } + + size_t Cap() const { + return capacity_; + } + + size_t Size() const { + return size_.load(); + } + + private: + size_t capacity_; + std::atomic_size_t size_; + bool closed_; + std::deque queue_; + + mutable std::mutex send_mutex_; + mutable std::mutex receive_mutex_; + mutable std::condition_variable send_cv_; + mutable std::condition_variable receive_cv_; +}; + +class DataFeed { + public: + DataFeed() {} + virtual ~DataFeed() {} + virtual void Init() = 0; + // for some datafeeds may not be able to implement this interface + virtual bool CheckFile(const char* filename) { + LOG(ERROR) << "error: The function CheckFile is not implemented"; + return false; + } + virtual bool SetFileList(const std::vector& files); + virtual bool Start() = 0; + virtual bool Next() = 0; + virtual void SetBatchSize(int batch) { default_batch_size_ = batch; } + virtual int GetBatchSize() { return batch_size_; } + // for subclass with queue + virtual void SetQueueSize(int queue_size) { + LOG(ERROR) << "error: The function SetQueueSize is not implemented"; + } + // for subclass with buffer + virtual void SetBufferSize(int buffer_size) { + LOG(ERROR) << "error: The function SetBufferSize is not implemented"; + } + virtual const std::vector& GetAllSlots() {return all_slots_;} + virtual const std::vector& GetUseSlots() {return use_slots_;} + std::vector& GetFeedVec() {return feed_vec_;} + virtual void AddFeedVar(Variable* var, const std::string& name); + protected: + // Check if it is executed in this order: + // Init -> SetFileList/BindingMemory -> Start -> Next + virtual bool CheckInit(); + virtual bool CheckSetFileList(); + virtual bool CheckStart(); + virtual bool PickOneFile(std::string& filename); + + static std::vector filelist_; + static size_t file_idx_; + static std::mutex mutex_for_pick_file_; + + std::vector use_slots_; + std::vector use_slots_is_dense_; + + std::vector all_slots_; + std::vector all_slots_type_; + std::vector use_slots_index_; // -1: not used; >=0: the index of use_slots_ + + std::vector feed_vec_; + + int default_batch_size_; + int batch_size_; + + bool finish_init_; + bool finish_set_filelist_; + bool finish_binding_memory_; + bool finish_start_; +}; + +template +class PrivateQueueDataFeed : public DataFeed { + public: + PrivateQueueDataFeed() {} + virtual ~PrivateQueueDataFeed() {} + virtual void Init() = 0; + virtual bool Start(); + virtual bool Next(); // no buffer + virtual void SetQueueSize(int queue_size) {queue_size_ = queue_size;} + + protected: + virtual void ReadThread(); + virtual bool ParseOneInstance(std::vector& instance) = 0; + virtual void PutToFeedVec(std::vector& ins_vec) = 0; + + std::thread read_thread_; // the thread for read files + /* using ifstream one line and one line parse is faster + * than using fread one buffer and one buffer parse. + * for 601M JingPai data: + * ifstream one line and one line parse: 6034 ms + * fread one buffer and one buffer parse: 7097 ms */ + std::ifstream file_; + size_t queue_size_; + // The elements in the queue are one piece of data, + // with multiple fields in each piece of data + BlockingQueue> queue_; +}; + +class MultiSlotType { + public: + MultiSlotType() { + float_feasign_.clear(); + uint64_feasign_.clear(); + offset_.resize(1); + offset_[0] = 0; + } + void SetType(std::string& type) { + if (type != "uint64" && type != "float") { + // check in this + LOG(ERROR) << "error: here is no this type"; + exit(0); + } + type_ = type; + } + void AddValue(float v) { + if (!CheckFloat()) {return;} + float_feasign_.push_back(v); + } + void AddValue(uint64_t v) { + if (!CheckUint64()) {return;} + uint64_feasign_.push_back(v); + } + void AddIns(MultiSlotType& ins) { + if (ins.GetType()[0] == 'f') { //float + if (!CheckFloat()) {return;} + auto& vec = ins.GetFloatData(); + offset_.push_back(offset_.back() + vec.size()); + float_feasign_.insert(float_feasign_.end(), vec.begin(), vec.end()); + } else if (ins.GetType()[0] == 'u') { //uint64 + if (!CheckUint64()) {return;} + auto& vec = ins.GetUint64Data(); + offset_.push_back(offset_.back() + vec.size()); + uint64_feasign_.insert(uint64_feasign_.end(), vec.begin(), vec.end()); + } + } + std::string& GetType() { + return type_; + } + std::vector& GetFloatData() { + return float_feasign_; + } + std::vector& GetUint64Data() { + return uint64_feasign_; + } + std::vector& GetOffset() { + return offset_; + } + private: + bool CheckFloat() { + if (type_[0] != 'f') { //float + LOG(ERROR) << "error: add " << type_ << " value to float slot"; + return false; + } + return true; + } + bool CheckUint64() { + if (type_[0] != 'u') { //uint64 + LOG(ERROR) << "error: add " << type_ << " value to uint64 slot"; + return false; + } + return true; + } + std::string type_; + std::vector float_feasign_; + std::vector uint64_feasign_; + std::vector offset_; +}; + +class MultiSlotDataFeed : public PrivateQueueDataFeed> { + public: + MultiSlotDataFeed() {} + virtual ~MultiSlotDataFeed() {} + virtual void Init(); + //TODO: virtual bool CheckFile(); + protected: + virtual bool ParseOneInstance(std::vector& instance); + virtual void PutToFeedVec(std::vector& ins_vec); +}; + + +//TODO: to be deleted +class TextClassDataFeed : public DataFeed { + public: + virtual ~TextClassDataFeed() {} + virtual void Init(); + virtual bool Start() {return false;}; //TODO + virtual bool Next() {return false;}; //TODO + virtual bool ReadBatch(); + virtual void AddFeedVar(Variable* feed, const std::string& name); + virtual void BindScope(Scope* scope) {} + virtual bool SetFile(const char* filename); + + virtual bool CheckFile(const char* filename) { + // TODO(xxx) + return false; + } + + void SetBatchSize(int batch) {batch_size_ = batch;} + + private: + int ReadWholeFile(const std::string& filename, char* buffer); + char* file_content_buffer_; + char* file_content_buffer_ptr_; + int* batch_id_buffer_; + int* label_ptr_; + int file_size_; + std::vector names_; + std::shared_ptr file_content_buffer_host_; + std::shared_ptr batch_id_host_; + std::shared_ptr label_host_; +}; + +} // namespace framework +} // namespace paddle + +#endif // PADDLE_FLUID_FRAMEWORK_DATA_FEED_H_ +/* vim: set expandtab ts=2 sw=2 sts=2 tw=100: */ diff --git a/paddle/fluid/framework/data_feed.proto.yebaiwei b/paddle/fluid/framework/data_feed.proto.yebaiwei new file mode 100644 index 00000000000..4b87f850ec1 --- /dev/null +++ b/paddle/fluid/framework/data_feed.proto.yebaiwei @@ -0,0 +1,32 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +syntax = "proto2"; +package paddle; + +message DataFeedDesc { + optional string name = 1; + optional int32 batch = 2 [default = 32]; + optional MultiSlotDesc multi_slot_desc = 3; +} + +message MultiSlotDesc { + repeated Slot slots = 1; +} + +message Slot { + required string name = 1; + required string type = 2; + optional bool dense = 3 [default = 0]; + optional bool use = 4 [default = 1]; +} -- GitLab