Commit 92a98ca7, authored by barrierye

add MultiSlotDataFeed

Parent 78c3380b
@@ -34,221 +34,241 @@ limitations under the License. */
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/framework/data_feed.h" #include "paddle/fluid/framework/data_feed.h"
DEFINE_bool(is_text_feed, false, "is_text_feed");
namespace paddle {
namespace framework {

std::vector<std::string> DataFeed::filelist_;
size_t DataFeed::file_idx_;
std::mutex DataFeed::mutex_for_pick_file_;

void DataFeed::AddFeedVar(Variable* var, const std::string& name) {
  if (CheckInit() == false) { return; }
  for (size_t i = 0; i < use_slots_.size(); ++i) {
    if (name == use_slots_[i]) {
      if (use_slots_is_dense_[i]) {
        feed_vec_[i] = MixTensor(var->GetMutable<Tensor>());
      } else {
        feed_vec_[i] = MixTensor(var->GetMutable<LoDTensor>());
      }
    }
  }
}

bool DataFeed::SetFileList(const std::vector<std::string>& files) {
  if (CheckInit() == false) { return false; }
  if (files.size() == 0) {
    LOG(ERROR) << "error: you have set an empty filelist";
    return false;
  }
  filelist_.assign(files.begin(), files.end());
  file_idx_ = 0;

  finish_set_filelist_ = true;
  return true;
}

bool DataFeed::PickOneFile(std::string& filename) {
  std::unique_lock<std::mutex> lock(mutex_for_pick_file_);
  if (file_idx_ == filelist_.size()) {
    return false;
  }
  filename = filelist_[file_idx_++];
  return true;
}

bool DataFeed::CheckInit() {
  if (finish_init_) { return true; }
  LOG(ERROR) << "error: initialization did not succeed";
  return false;
}

bool DataFeed::CheckSetFileList() {
  if (finish_set_filelist_) { return true; }
  LOG(ERROR) << "error: set filelist did not succeed";
  return false;
}

bool DataFeed::CheckStart() {
  if (finish_start_) { return true; }
  LOG(ERROR) << "error: DataFeed has not started running yet";
  return false;
}

template <typename T>
void PrivateQueueDataFeed<T>::SetQueueSize(int queue_size) {
  if (!CheckInit()) { return; }
  if (queue_size <= 0) {
    LOG(ERROR) << "error: illegal queue size: " << queue_size;
    return;
  }
  queue_size_ = queue_size;
  queue_.ReCap(queue_size_);
}

template <typename T>
bool PrivateQueueDataFeed<T>::Start() {
  if (!CheckSetFileList()) { return false; }
  read_thread_ = std::thread(&PrivateQueueDataFeed::ReadThread, this);
  read_thread_.detach();

  finish_start_ = true;
  return true;
}

template <typename T>
void PrivateQueueDataFeed<T>::ReadThread() {
  std::string filename;
  while (PickOneFile(filename)) {
    file_.open(filename.c_str());  // is_text_feed
    if (!file_.is_open()) {
      LOG(ERROR) << "error: open file<" << filename << "> fail";
    }
    T instance;
    while (ParseOneInstance(instance)) {
      queue_.Send(instance);
    }
    file_.close();
  }
  queue_.Close();
}

template <typename T>
bool PrivateQueueDataFeed<T>::Next() {
  if (!CheckStart()) { return false; }
  int index = 0;
  T instance;
  T ins_vec(use_slots_.size());
  while (index < default_batch_size_) {
    if (!queue_.Receive(&instance)) {
      break;
    }
    AddInstanceToInsVec(ins_vec, instance, index++);
  }
  batch_size_ = index;
  PutToFeedVec(ins_vec);
  return batch_size_ != 0;
}

void MultiSlotDataFeed::Init(paddle::DataFeedDesc& data_feed_desc) {
  finish_init_ = false;
  finish_set_filelist_ = false;
  finish_start_ = false;
  if (!data_feed_desc.has_multi_slot_desc()) {
    LOG(ERROR) << "error: multi_slot_desc has not been set";
    return;
  }
  paddle::MultiSlotDesc multi_slot_desc = data_feed_desc.multi_slot_desc();
  size_t all_slot_num = multi_slot_desc.slots_size();
  all_slots_.resize(all_slot_num);
  all_slots_type_.resize(all_slot_num);
  use_slots_index_.resize(all_slot_num);
  use_slots_.clear();
  use_slots_is_dense_.clear();
  for (size_t i = 0; i < all_slot_num; ++i) {
    auto& slot = multi_slot_desc.slots(i);
    all_slots_[i] = slot.name();
    all_slots_type_[i] = slot.type();
    use_slots_index_[i] = slot.use() ? use_slots_.size() : -1;
    if (slot.use()) {
      use_slots_.push_back(all_slots_[i]);
      use_slots_is_dense_.push_back(slot.dense());
    }
  }
  feed_vec_.resize(use_slots_.size());

  finish_init_ = true;
}

bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>& instance) {
  std::string line;
  if (getline(file_, line)) {
    int use_slots_num = use_slots_.size();
    instance.resize(use_slots_num);
    // parse line
    const char* str = line.c_str();
    char* endptr = (char*)str;
    int pos = 0;
    for (size_t i = 0; i < use_slots_index_.size(); ++i) {
      int idx = use_slots_index_[i];
      int num = (int)strtol(&str[pos], &endptr, 10);
      if (num == 0) {
        LOG(ERROR) << "error: the number of ids can not be zero, "
                      "you need padding it";
        exit(-1);
      }
      if (idx != -1) {
        instance[idx].SetType(all_slots_type_[i]);
        if (instance[idx].GetType()[0] == 'f') {  // float
          for (int j = 0; j < num; ++j) {
            float feasign = (float)strtof(endptr, &endptr);
            instance[idx].AddValue(feasign);
          }
        } else if (instance[idx].GetType()[0] == 'u') {  // uint64
          for (int j = 0; j < num; ++j) {
            uint64_t feasign = (uint64_t)strtoull(endptr, &endptr, 10);
            instance[idx].AddValue(feasign);
          }
        }
        pos = endptr - str;
      } else {
        // skip the values of an unused slot
        for (int j = 0; j <= num; ++j) {
          pos = line.find_first_of(' ', pos + 1);
        }
      }
    }
  } else {
    return false;
  }
  return true;
}

void MultiSlotDataFeed::AddInstanceToInsVec(std::vector<MultiSlotType>& ins_vec,
                                            std::vector<MultiSlotType>& instance,
                                            int index) {
  if (index == 0) {
    for (size_t i = 0; i < instance.size(); ++i) {
      ins_vec[i].SetType(instance[i].GetType());
    }
  }
  for (size_t i = 0; i < instance.size(); ++i) {
    ins_vec[i].AddIns(instance[i]);
  }
}

void MultiSlotDataFeed::PutToFeedVec(std::vector<MultiSlotType>& ins_vec) {
  for (size_t i = 0; i < use_slots_.size(); ++i) {
    auto& type = ins_vec[i].GetType();
    auto& offset = ins_vec[i].GetOffset();
    int total_instance = static_cast<int>(offset.back());
    if (type[0] == 'f') {  // float
      auto& feasign = ins_vec[i].GetFloatData();
      if (feed_vec_[i].IsDense()) {
        int size_in_each_batch = total_instance / batch_size_;
        float* tensor_ptr = feed_vec_[i].GetTensor()->mutable_data<float>(
            {batch_size_, size_in_each_batch}, platform::CPUPlace());
        memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(float));
      } else {
        float* tensor_ptr = feed_vec_[i].GetLoDTensor()->mutable_data<float>(
            {total_instance, 1}, platform::CPUPlace());
        memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(float));
        LoD data_lod{offset};
        feed_vec_[i].GetLoDTensor()->set_lod(data_lod);
      }
    } else if (type[0] == 'u') {  // uint64
      // there is no uint64_t tensor type, so the data is stored as int64_t
      auto& feasign = ins_vec[i].GetUint64Data();
      if (feed_vec_[i].IsDense()) {
        int size_in_each_batch = total_instance / batch_size_;
        int64_t* tensor_ptr = feed_vec_[i].GetTensor()->mutable_data<int64_t>(
            {batch_size_, size_in_each_batch}, platform::CPUPlace());
        memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(int64_t));
      } else {
        int64_t* tensor_ptr = feed_vec_[i].GetLoDTensor()->mutable_data<int64_t>(
            {total_instance, 1}, platform::CPUPlace());
        memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(uint64_t));
        LoD data_lod{offset};
        feed_vec_[i].GetLoDTensor()->set_lod(data_lod);
      }
    }
  }
}

}  // namespace framework
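For reference, a minimal sketch of how this reader is meant to be driven. Judging from ParseOneInstance, each input line carries, for every slot in DataFeedDesc order, a count followed by that many values, and a zero count is rejected. The function name ExampleReadLoop and the slot name "words" below are invented for illustration; the call order follows the Init -> SetFileList -> Start -> Next check chain implemented above, and this is only an assumption-laden sketch, not code from this commit.

// Sketch only: drive a MultiSlotDataFeed end to end.
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/data_feed.pb.h"

void ExampleReadLoop(const std::vector<std::string>& files) {
  paddle::DataFeedDesc desc;
  desc.set_batch(32);
  auto* slot = desc.mutable_multi_slot_desc()->add_slots();
  slot->set_name("words");   // hypothetical slot name
  slot->set_type("uint64");

  paddle::framework::MultiSlotDataFeed reader;
  reader.Init(desc);
  reader.SetBatchSize(desc.batch());
  reader.SetFileList(files);
  // Every used slot must also be bound to a Variable via AddFeedVar()
  // before Next() is called; that step is sketched after the header below.
  reader.Start();            // spawns the background ReadThread
  while (reader.Next()) {    // assembles up to one batch from the queue
    // reader.GetFeedVec() now holds the tensors for the current batch.
  }
}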
......
@@ -27,136 +27,335 @@ limitations under the License. */
#include <unordered_set>
#include <condition_variable>  // NOLINT
#include <fstream>
#include <deque>
#include <atomic>

#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/data_feed.pb.h"

namespace paddle {
namespace framework {

class MixTensor {
 public:
  MixTensor() {}
  MixTensor(LoDTensor* lodtensor) {
    is_dense_ = false;
    lodtensor_ = lodtensor;
  }
  MixTensor(Tensor* tensor) {
    is_dense_ = true;
    tensor_ = tensor;
  }
  bool IsDense() { return is_dense_; }
  LoDTensor* GetLoDTensor() {
    if (is_dense_) {
      LOG(ERROR) << "error: let a dense var return a LoDTensor ptr";
      return NULL;
    }
    return lodtensor_;
  }
  Tensor* GetTensor() {
    if (!is_dense_) {
      LOG(ERROR) << "error: let a sparse var return a Tensor ptr";
      return NULL;
    }
    return tensor_;
  }

 private:
  bool is_dense_;
  LoDTensor* lodtensor_;
  Tensor* tensor_;
};

template <typename T>
class BlockingQueue {
 public:
  explicit BlockingQueue(size_t capacity = 32)
      : capacity_(capacity), closed_(false) {
    size_.store(0);
  }

  void ReCap(size_t capacity) {
    capacity_ = capacity;
  }

  bool Send(const T& elem) {
    int c = -1;
    {
      std::unique_lock<std::mutex> lock(send_mutex_);
      send_cv_.wait(lock, [&] { return size_.load() < capacity_ || closed_; });
      if (closed_) {
        VLOG(5)
            << "WARNING: Sending an element to a closed reader::BlockingQueue.";
        return false;
      }
      queue_.push_back(elem);
      c = size_.load();
      size_.fetch_add(1);
    }
    if (c + 1 < capacity_) {
      send_cv_.notify_one();
    }
    if (c == 0) {
      std::unique_lock<std::mutex> lock(receive_mutex_);
      receive_cv_.notify_one();
    }
    return true;
  }

  bool Receive(T* elem) {
    int c = -1;
    {
      std::unique_lock<std::mutex> lock(receive_mutex_);
      receive_cv_.wait(lock, [&] { return size_.load() != 0 || closed_; });
      if (size_.load() != 0) {
        *elem = queue_.front();
        queue_.pop_front();
        c = size_.load();
        size_.fetch_sub(1);
      } else {
        return false;
      }
    }
    if (c > 1) {
      receive_cv_.notify_one();
    }
    if (c == capacity_) {
      std::unique_lock<std::mutex> lock(send_mutex_);
      send_cv_.notify_one();
    }
    return true;
  }

  void Close() {
    std::lock_guard<std::mutex> lock1(send_mutex_);
    std::lock_guard<std::mutex> lock2(receive_mutex_);
    closed_ = true;
    send_cv_.notify_all();
    receive_cv_.notify_all();
  }

 private:
  size_t capacity_;
  std::atomic_size_t size_;
  bool closed_;
  std::deque<T> queue_;

  mutable std::mutex send_mutex_;
  mutable std::mutex receive_mutex_;
  mutable std::condition_variable send_cv_;
  mutable std::condition_variable receive_cv_;
};

class DataFeed {
 public:
  DataFeed() {}
  virtual ~DataFeed() {}
  virtual void Init(paddle::DataFeedDesc& data_feed_desc) = 0;
  // some datafeeds may not be able to implement this interface
  virtual bool CheckFile(const char* filename) {
    LOG(ERROR) << "error: The function CheckFile is not implemented";
    return false;
  }
  virtual bool SetFileList(const std::vector<std::string>& files);
  virtual bool Start() = 0;
  virtual bool Next() = 0;
  virtual void SetBatchSize(int batch) { default_batch_size_ = batch; }
  virtual int GetBatchSize() { return batch_size_; }
  // for subclasses that use a queue
  virtual void SetQueueSize(int queue_size) {
    LOG(ERROR) << "error: The function SetQueueSize is not implemented";
  }
  // for subclasses that use a buffer
  virtual void SetBufferSize(int buffer_size) {
    LOG(ERROR) << "error: The function SetBufferSize is not implemented";
  }
  virtual const std::vector<std::string>& GetAllSlots() { return all_slots_; }
  virtual const std::vector<std::string>& GetUseSlots() { return use_slots_; }
  std::vector<MixTensor>& GetFeedVec() { return feed_vec_; }
  virtual void AddFeedVar(Variable* var, const std::string& name);

 protected:
  // Check whether the methods are executed in this order:
  // Init -> SetFileList/BindingMemory -> Start -> Next
  virtual bool CheckInit();
  virtual bool CheckSetFileList();
  virtual bool CheckStart();
  virtual bool PickOneFile(std::string& filename);

  static std::vector<std::string> filelist_;
  static size_t file_idx_;
  static std::mutex mutex_for_pick_file_;

  std::vector<std::string> use_slots_;
  std::vector<bool> use_slots_is_dense_;
  std::vector<std::string> all_slots_;
  std::vector<std::string> all_slots_type_;
  std::vector<int> use_slots_index_;  // -1: not used; >=0: the index in use_slots_

  std::vector<MixTensor> feed_vec_;

  int default_batch_size_;
  int batch_size_;

  bool finish_init_;
  bool finish_set_filelist_;
  bool finish_binding_memory_;
  bool finish_start_;
};

template <typename T>
class PrivateQueueDataFeed : public DataFeed {
 public:
  PrivateQueueDataFeed() {}
  virtual ~PrivateQueueDataFeed() {}
  virtual void Init(paddle::DataFeedDesc& data_feed_desc) = 0;
  virtual bool Start();
  virtual bool Next();  // no buffer
  virtual void SetQueueSize(int queue_size);

 protected:
  virtual void ReadThread();
  virtual bool ParseOneInstance(T& instance) = 0;
  virtual void AddInstanceToInsVec(T& vec_ins, T& instance, int index) = 0;
  virtual void PutToFeedVec(T& ins_vec) = 0;

  std::thread read_thread_;  // the thread that reads the files
  /* Reading with ifstream and parsing one line at a time is faster
   * than reading with fread and parsing one buffer at a time.
   * For 601M of JingPai data:
   *   ifstream, one line at a time:  6034 ms
   *   fread, one buffer at a time:   7097 ms */
  std::ifstream file_;
  size_t queue_size_;
  // Each element in the queue is one instance,
  // and each instance holds multiple fields (slots)
  BlockingQueue<T> queue_;
};

class MultiSlotType {
 public:
  MultiSlotType() {
    float_feasign_.clear();
    uint64_feasign_.clear();
    offset_.resize(1);
    offset_[0] = 0;
  }
  ~MultiSlotType() {}
  void SetType(std::string& type) {
    if (!CheckType(type)) { return; }
    type_ = type;
  }
  std::vector<size_t>& GetOffset() {
    return offset_;
  }
  void AddValue(float v) {
    if (!CheckFloat()) { return; }
    float_feasign_.push_back(v);
  }
  void AddValue(uint64_t v) {
    if (!CheckUint64()) { return; }
    uint64_feasign_.push_back(v);
  }
  void AddIns(MultiSlotType& ins) {
    if (ins.GetType()[0] == 'f') {  // float
      if (!CheckFloat()) { return; }
      auto& vec = ins.GetFloatData();
      offset_.push_back(offset_.back() + vec.size());
      float_feasign_.insert(float_feasign_.end(), vec.begin(), vec.end());
    } else if (ins.GetType()[0] == 'u') {  // uint64
      if (!CheckUint64()) { return; }
      auto& vec = ins.GetUint64Data();
      offset_.push_back(offset_.back() + vec.size());
      uint64_feasign_.insert(uint64_feasign_.end(), vec.begin(), vec.end());
    }
  }
  std::vector<float>& GetFloatData() {
    return float_feasign_;
  }
  std::vector<uint64_t>& GetUint64Data() {
    return uint64_feasign_;
  }
  std::string& GetType() {
    return type_;
  }

 private:
  bool CheckType(std::string& type) {
    if (type != "uint64" && type != "float") {
      LOG(ERROR) << "error: there is no such type";
      return false;
    }
    return true;
  }
  bool CheckFloat() {
    if (type_[0] != 'f') {  // float
      LOG(ERROR) << "error: add " << type_ << " value to float slot";
      return false;
    }
    return true;
  }
  bool CheckUint64() {
    if (type_[0] != 'u') {  // uint64
      LOG(ERROR) << "error: add " << type_ << " value to uint64 slot";
      return false;
    }
    return true;
  }

  std::vector<float> float_feasign_;
  std::vector<uint64_t> uint64_feasign_;
  std::string type_;
  std::vector<size_t> offset_;
};

class MultiSlotDataFeed
    : public PrivateQueueDataFeed<std::vector<MultiSlotType>> {
 public:
  MultiSlotDataFeed() {}
  virtual ~MultiSlotDataFeed() {}
  virtual void Init(paddle::DataFeedDesc& data_feed_desc);
  // TODO: virtual bool CheckFile();

 protected:
  virtual void AddInstanceToInsVec(std::vector<MultiSlotType>& vec_ins,
                                   std::vector<MultiSlotType>& instance,
                                   int index);
  virtual bool ParseOneInstance(std::vector<MultiSlotType>& instance);
  virtual void PutToFeedVec(std::vector<MultiSlotType>& ins_vec);
};

// TODO: to be deleted
class TextClassDataFeed : public DataFeed {
 public:
  virtual ~TextClassDataFeed() {}
  virtual void Init(paddle::DataFeedDesc& data_feed_desc) {}
  virtual bool Start() { return false; }  // TODO
  virtual bool Next() { return false; }   // TODO
  virtual bool ReadBatch() { return false; }
  virtual void AddFeedVar(Variable* feed, const std::string& name) {}
  virtual void BindScope(Scope* scope) {}
  virtual bool SetFile(const char* filename) { return false; }
  virtual bool CheckFile(const char* filename) {
    // TODO(xxx)
    return false;
  }
  void SetBatchSize(int batch) { batch_size_ = batch; }

 private:
  int ReadWholeFile(const std::string& filename, char* buffer) { return -1; }
  char* file_content_buffer_;
  char* file_content_buffer_ptr_;
  int* batch_id_buffer_;
  int* label_ptr_;
  int file_size_;
  std::vector<std::string> names_;
  std::shared_ptr<char> file_content_buffer_host_;
  std::shared_ptr<int> batch_id_host_;
  std::shared_ptr<int> label_host_;
};

}  // namespace framework
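The MixTensor wrapper above exists so that AddFeedVar can bind a dense slot to a plain Tensor and a sparse slot to a LoDTensor. A small sketch of that binding step, assuming the variables come from a Scope; BindSlots is an invented helper for illustration, not part of this commit.

// Sketch only: bind every used slot to a Variable before Start()/Next().
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/scope.h"

void BindSlots(paddle::framework::Scope* scope,
               paddle::framework::DataFeed* reader) {
  for (const auto& name : reader->GetUseSlots()) {
    auto* var = scope->Var(name);   // creates the Variable if it does not exist
    reader->AddFeedVar(var, name);  // stored as a MixTensor (Tensor or LoDTensor)
  }
}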
......
@@ -17,5 +17,16 @@ package paddle;
message DataFeedDesc {
  optional string name = 1;
  optional int32 batch = 2 [default = 32];
  optional MultiSlotDesc multi_slot_desc = 3;
}

message MultiSlotDesc {
  repeated Slot slots = 1;
}

message Slot {
  required string name = 1;
  required string type = 2;
  optional bool dense = 3 [default = false];
  optional bool use = 4 [default = true];
}
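One way to fill this message is protobuf text format. The snippet below is only an illustration (MakeExampleDesc and the three slot names are made up): it declares a sparse uint64 slot, a dense float slot, and a slot that is parsed but not fed.

// Sketch only: build a DataFeedDesc from protobuf text format.
#include <string>
#include <google/protobuf/text_format.h>
#include "paddle/fluid/framework/data_feed.pb.h"

paddle::DataFeedDesc MakeExampleDesc() {
  paddle::DataFeedDesc desc;
  std::string text =
      "name: \"multi_slot_feed\" "
      "batch: 32 "
      "multi_slot_desc { "
      "  slots { name: \"words\" type: \"uint64\" } "
      "  slots { name: \"ctr\" type: \"float\" dense: true } "
      "  slots { name: \"extra\" type: \"uint64\" use: false } "
      "} ";
  // "type" must be "uint64" or "float", matching MultiSlotType::CheckType.
  google::protobuf::TextFormat::ParseFromString(text, &desc);
  return desc;
}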