Commit 92a98ca7 authored by barrierye

add MultiSlotDataFeed

Parent: 78c3380b
paddle/fluid/framework/data_feed.cc
@@ -34,221 +34,241 @@ limitations under the License. */
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/framework/data_feed.h"
namespace paddle {
namespace framework {

std::vector<std::string> DataFeed::filelist_;
size_t DataFeed::file_idx_;
std::mutex DataFeed::mutex_for_pick_file_;

void DataFeed::AddFeedVar(Variable* var, const std::string& name) {
  if (CheckInit() == false) {return;}
  for (size_t i = 0; i < use_slots_.size(); ++i) {
    if (name == use_slots_[i]) {
      if (use_slots_is_dense_[i]) {
        feed_vec_[i] = MixTensor(var->GetMutable<Tensor>());
      } else {
        feed_vec_[i] = MixTensor(var->GetMutable<LoDTensor>());
      }
    }
  }
}

bool DataFeed::SetFileList(const std::vector<std::string>& files) {
  if (CheckInit() == false) {return false;}
  if (files.size() == 0) {
    LOG(ERROR) << "error: you have set an empty filelist";
    return false;
  }
  filelist_.assign(files.begin(), files.end());
  file_idx_ = 0;

  finish_set_filelist_ = true;
  return true;
}

bool DataFeed::PickOneFile(std::string& filename) {
  std::unique_lock<std::mutex> lock(mutex_for_pick_file_);
  if (file_idx_ == filelist_.size()) {
    return false;
  }
  filename = filelist_[file_idx_++];
  return true;
}

bool DataFeed::CheckInit() {
  if (finish_init_) {return true;}
  LOG(ERROR) << "error: initialization did not succeed";
  return false;
}

bool DataFeed::CheckSetFileList() {
  if (finish_set_filelist_) {return true;}
  LOG(ERROR) << "error: set filelist did not succeed";
  return false;
}

bool DataFeed::CheckStart() {
  if (finish_start_) {return true;}
  LOG(ERROR) << "error: Datafeed has not started running yet";
  return false;
}
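// Illustrative sketch (not part of this file): the Check* guards above encode
// the intended call order for any DataFeed subclass,
//   Init -> SetFileList -> Start -> Next,
// shown here with the MultiSlotDataFeed added by this commit; the DataFeedDesc
// and file names are assumed to be prepared elsewhere:
//
//   MultiSlotDataFeed reader;
//   reader.Init(data_feed_desc);                   // sets finish_init_
//   reader.SetFileList({"part-000", "part-001"});  // sets finish_set_filelist_
//   reader.SetBatchSize(32);
//   reader.Start();                                // sets finish_start_, spawns ReadThread
//   while (reader.Next()) {
//     // one batch is now materialized in the bound feed variables
//   }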
template<typename T>
void PrivateQueueDataFeed<T>::SetQueueSize(int queue_size) {
  if (!CheckInit()) {return;}
  if (queue_size <= 0) {
    LOG(ERROR) << "error: illegal queue size: " << queue_size;
    return;
  }
  queue_size_ = queue_size;
  queue_.ReCap(queue_size_);
}

template<typename T>
bool PrivateQueueDataFeed<T>::Start() {
  if (!(CheckSetFileList())) {return false;}
  read_thread_ = std::thread(&PrivateQueueDataFeed::ReadThread, this);
  read_thread_.detach();

  finish_start_ = true;
  return true;
}

template<typename T>
void PrivateQueueDataFeed<T>::ReadThread() {
  std::string filename;
  while (PickOneFile(filename)) {
    file_.open(filename.c_str());  // is_text_feed
    if (!file_.is_open()) {
      LOG(ERROR) << "error: open file<" << filename << "> fail";
    }
    T instance;
    while (ParseOneInstance(instance)) {
      queue_.Send(instance);
    }
    file_.close();
  }
  queue_.Close();
}

template<typename T>
bool PrivateQueueDataFeed<T>::Next() {
  if (!CheckStart()) {return false;}
  int index = 0;
  T instance;
  T ins_vec(use_slots_.size());
  while (index < default_batch_size_) {
    if (!queue_.Receive(&instance)) {
      break;
    }
    AddInstanceToInsVec(ins_vec, instance, index++);
  }
  batch_size_ = index;
  PutToFeedVec(ins_vec);
  return batch_size_ != 0;
}

void MultiSlotDataFeed::Init(paddle::DataFeedDesc& data_feed_desc) {
  finish_init_ = false;
  finish_set_filelist_ = false;
  finish_start_ = false;
  if (!data_feed_desc.has_multi_slot_desc()) {
    LOG(ERROR) << "error: multi_slot_desc has not been set";
    return;
  }
  paddle::MultiSlotDesc multi_slot_desc = data_feed_desc.multi_slot_desc();
  size_t all_slot_num = multi_slot_desc.slots_size();
  all_slots_.resize(all_slot_num);
  all_slots_type_.resize(all_slot_num);
  use_slots_index_.resize(all_slot_num);
  use_slots_.clear();
  use_slots_is_dense_.clear();
  for (size_t i = 0; i < all_slot_num; ++i) {
    auto& slot = multi_slot_desc.slots(i);
    all_slots_[i] = slot.name();
    all_slots_type_[i] = slot.type();
    use_slots_index_[i] = slot.use() ? use_slots_.size() : -1;
    if (slot.use()) {
      use_slots_.push_back(all_slots_[i]);
      use_slots_is_dense_.push_back(slot.dense());
    }
  }
  feed_vec_.resize(use_slots_.size());

  finish_init_ = true;
}
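// Example of the mapping Init() builds, for a hypothetical slot configuration
// {"words": use=true, "ad_id": use=false, "label": use=true}:
//   all_slots_       = {"words", "ad_id", "label"}
//   use_slots_       = {"words", "label"}
//   use_slots_index_ = {0, -1, 1}   // -1 marks a slot that is parsed but skipped
// so ParseOneInstance below can skip unused slots while still consuming their
// tokens from the input line.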
bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>& instance) {
  std::string line;
  if (getline(file_, line)) {
    int use_slots_num = use_slots_.size();
    instance.resize(use_slots_num);
    // parse line
    const char* str = line.c_str();
    char* endptr = (char*)str;
    int pos = 0;
    for (size_t i = 0; i < use_slots_index_.size(); ++i) {
      int idx = use_slots_index_[i];
      int num = (int)strtol(&str[pos], &endptr, 10);
      if (num == 0) {
        LOG(ERROR) << "error: the number of ids can not be zero, you need padding it";
        exit(-1);
      }
      if (idx != -1) {
        instance[idx].SetType(all_slots_type_[i]);
        if (instance[idx].GetType()[0] == 'f') {  // float
          for (int j = 0; j < num; ++j) {
            float feasign = (float)strtof(endptr, &endptr);
            instance[idx].AddValue(feasign);
          }
        } else if (instance[idx].GetType()[0] == 'u') {  // uint64
          for (int j = 0; j < num; ++j) {
            uint64_t feasign = (uint64_t)strtoull(endptr, &endptr, 10);
            instance[idx].AddValue(feasign);
          }
        }
        pos = endptr - str;
      } else {
        for (int j = 0; j <= num; ++j) {
          pos = line.find_first_of(' ', pos + 1);
        }
      }
    }
  } else {
    return false;
  }
  return true;
}
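// A sketch of the text format ParseOneInstance expects (the values are
// invented for illustration). Each line is one instance; for every slot, in
// the order given by the DataFeedDesc, it reads <num> followed by <num>
// feasigns:
//
//   3 1001 1002 1003 2 0.5 0.25 1 7
//
// With slots of type uint64 / float / uint64 this parses to {1001, 1002, 1003},
// {0.5, 0.25} and {7}. A count of 0 is rejected with an error, so empty slots
// must be padded before feeding.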
void MultiSlotDataFeed::AddInstanceToInsVec(std::vector<MultiSlotType>& ins_vec,
                                            std::vector<MultiSlotType>& instance,
                                            int index) {
  if (index == 0) {
    for (size_t i = 0; i < instance.size(); ++i) {
      ins_vec[i].SetType(instance[i].GetType());
    }
  }
  for (size_t i = 0; i < instance.size(); ++i) {
    ins_vec[i].AddIns(instance[i]);
  }
}

void MultiSlotDataFeed::PutToFeedVec(std::vector<MultiSlotType>& ins_vec) {
  for (size_t i = 0; i < use_slots_.size(); ++i) {
    auto& type = ins_vec[i].GetType();
    auto& offset = ins_vec[i].GetOffset();
    int total_instance = static_cast<int>(offset.back());
    if (type[0] == 'f') {  // float
      auto& feasign = ins_vec[i].GetFloatData();
      if (feed_vec_[i].IsDense()) {
        int size_in_each_batch = total_instance / batch_size_;
        float* tensor_ptr = feed_vec_[i].GetTensor()->mutable_data<float>(
            {batch_size_, size_in_each_batch}, platform::CPUPlace());
        memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(float));
      } else {
        float* tensor_ptr = feed_vec_[i].GetLoDTensor()->mutable_data<float>(
            {total_instance, 1}, platform::CPUPlace());
        memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(float));
        LoD data_lod{offset};
        feed_vec_[i].GetLoDTensor()->set_lod(data_lod);
      }
    } else if (type[0] == 'u') {  // uint64
      // Tensor has no uint64_t data type, so the feasigns are stored as int64_t
      auto& feasign = ins_vec[i].GetUint64Data();
      if (feed_vec_[i].IsDense()) {
        int size_in_each_batch = total_instance / batch_size_;
        int64_t* tensor_ptr = feed_vec_[i].GetTensor()->mutable_data<int64_t>(
            {batch_size_, size_in_each_batch}, platform::CPUPlace());
        memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(int64_t));
      } else {
        int64_t* tensor_ptr = feed_vec_[i].GetLoDTensor()->mutable_data<int64_t>(
            {total_instance, 1}, platform::CPUPlace());
        memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(uint64_t));
        LoD data_lod{offset};
        feed_vec_[i].GetLoDTensor()->set_lod(data_lod);
      }
    }
  }
}

}  // namespace framework
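To make the sparse branch of PutToFeedVec concrete, here is a small worked example; the batch contents are invented and only illustrate how the per-instance offsets turn into a LoD:

// Hypothetical batch of 3 instances for one sparse uint64 slot:
//   instance 0: {1001, 1002}   instance 1: {7}   instance 2: {42, 43, 44}
// AddInstanceToInsVec leaves the slot with offset = {0, 2, 3, 6}, so
// PutToFeedVec allocates a LoDTensor of shape {6, 1} with
//   data = [1001, 1002, 7, 42, 43, 44]   and   lod = {{0, 2, 3, 6}}.
// The dense branch instead reshapes the same buffer to
//   {batch_size_, total_instance / batch_size_}.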
paddle/fluid/framework/data_feed.h
@@ -27,136 +27,335 @@ limitations under the License. */
#include <unordered_set>
#include <condition_variable>  // NOLINT
#include <fstream>
#include <deque>
#include <atomic>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/data_feed.pb.h"
namespace paddle {
namespace framework {

class MixTensor {
 public:
  MixTensor() {}
  MixTensor(LoDTensor* lodtensor) {
    is_dense_ = false;
    lodtensor_ = lodtensor;
  }
  MixTensor(Tensor* tensor) {
    is_dense_ = true;
    tensor_ = tensor;
  }
  bool IsDense() {return is_dense_;}
  LoDTensor* GetLoDTensor() {
    if (is_dense_) {
      LOG(ERROR) << "error: let a dense var return a LoDTensor ptr";
      return NULL;
    }
    return lodtensor_;
  }
  Tensor* GetTensor() {
    if (!is_dense_) {
      LOG(ERROR) << "error: let a sparse var return a Tensor ptr";
      return NULL;
    }
    return tensor_;
  }

 private:
  bool is_dense_;
  LoDTensor* lodtensor_;
  Tensor* tensor_;
};

template<typename T>
class BlockingQueue {
 public:
  explicit BlockingQueue(size_t capacity = 32)
      : capacity_(capacity), closed_(false) {
    size_.store(0);
  }
  void ReCap(size_t capacity) {
    capacity_ = capacity;
  }
  bool Send(const T& elem) {
    int c = -1;
    {
      std::unique_lock<std::mutex> lock(send_mutex_);
      send_cv_.wait(lock, [&] {return size_.load() < capacity_ || closed_;});
      if (closed_) {
        VLOG(5)
            << "WARNING: Sending an element to a closed BlockingQueue.";
        return false;
      }
      queue_.push_back(elem);
      c = size_.load();
      size_.fetch_add(1);
    }
    if (c + 1 < capacity_) {
      send_cv_.notify_one();
    }
    if (c == 0) {
      std::unique_lock<std::mutex> lock(receive_mutex_);
      receive_cv_.notify_one();
    }
    return true;
  }
  bool Receive(T* elem) {
    int c = -1;
    {
      std::unique_lock<std::mutex> lock(receive_mutex_);
      receive_cv_.wait(lock, [&] {return size_.load() != 0 || closed_;});
      if (size_.load() != 0) {
        *elem = queue_.front();
        queue_.pop_front();
        c = size_.load();
        size_.fetch_sub(1);
      } else {
        return false;
      }
    }
    if (c > 1) {
      receive_cv_.notify_one();
    }
    if (c == capacity_) {
      std::unique_lock<std::mutex> lock(send_mutex_);
      send_cv_.notify_one();
    }
    return true;
  }
  void Close() {
    std::lock_guard<std::mutex> lock1(send_mutex_);
    std::lock_guard<std::mutex> lock2(receive_mutex_);
    closed_ = true;
    send_cv_.notify_all();
    receive_cv_.notify_all();
  }

 private:
  size_t capacity_;
  std::atomic_size_t size_;
  bool closed_;
  std::deque<T> queue_;

  mutable std::mutex send_mutex_;
  mutable std::mutex receive_mutex_;
  mutable std::condition_variable send_cv_;
  mutable std::condition_variable receive_cv_;
};
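// Minimal usage sketch for the BlockingQueue above (illustrative, not part of
// the header). Send() blocks while the queue is at capacity, Receive() blocks
// while it is empty, and Close() wakes both sides so blocked callers return
// false:
//
//   BlockingQueue<int> q(4);
//   std::thread producer([&q] {
//     for (int i = 0; i < 8; ++i) q.Send(i);
//     q.Close();
//   });
//   int v;
//   while (q.Receive(&v)) { /* consume v */ }
//   producer.join();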
class DataFeed {
 public:
  DataFeed() {}
  virtual ~DataFeed() {}
  virtual void Init(paddle::DataFeedDesc& data_feed_desc) = 0;
  // Some datafeeds may not be able to implement this interface
  virtual bool CheckFile(const char* filename) {
    LOG(ERROR) << "error: The function CheckFile is not implemented";
    return false;
  }
  virtual bool SetFileList(const std::vector<std::string>& files);
  virtual bool Start() = 0;
  virtual bool Next() = 0;
  virtual void SetBatchSize(int batch) { default_batch_size_ = batch; }
  virtual int GetBatchSize() { return batch_size_; }
  // for subclasses that use a queue
  virtual void SetQueueSize(int queue_size) {
    LOG(ERROR) << "error: The function SetQueueSize is not implemented";
  }
  // for subclasses that use a buffer
  virtual void SetBufferSize(int buffer_size) {
    LOG(ERROR) << "error: The function SetBufferSize is not implemented";
  }
  virtual const std::vector<std::string>& GetAllSlots() {return all_slots_;}
  virtual const std::vector<std::string>& GetUseSlots() {return use_slots_;}
  std::vector<MixTensor>& GetFeedVec() {return feed_vec_;}
  virtual void AddFeedVar(Variable* var, const std::string& name);

 protected:
  // Check that the methods are executed in this order:
  //   Init -> SetFileList/BindingMemory -> Start -> Next
  virtual bool CheckInit();
  virtual bool CheckSetFileList();
  virtual bool CheckStart();
  virtual bool PickOneFile(std::string& filename);

  static std::vector<std::string> filelist_;
  static size_t file_idx_;
  static std::mutex mutex_for_pick_file_;

  std::vector<std::string> use_slots_;
  std::vector<bool> use_slots_is_dense_;
  std::vector<std::string> all_slots_;
  std::vector<std::string> all_slots_type_;
  std::vector<int> use_slots_index_;  // -1: not used; >=0: the index of use_slots_
  std::vector<MixTensor> feed_vec_;

  int default_batch_size_;
  int batch_size_;

  bool finish_init_;
  bool finish_set_filelist_;
  bool finish_binding_memory_;
  bool finish_start_;
};
template<typename T>
class PrivateQueueDataFeed : public DataFeed {
 public:
  PrivateQueueDataFeed() {}
  virtual ~PrivateQueueDataFeed() {}
  virtual void Init(paddle::DataFeedDesc& data_feed_desc) = 0;
  virtual bool Start();
  virtual bool Next();  // no buffer
  virtual void SetQueueSize(int queue_size);

 protected:
  virtual void ReadThread();
  virtual bool ParseOneInstance(T& instance) = 0;
  virtual void AddInstanceToInsVec(T& vec_ins, T& instance, int index) = 0;
  virtual void PutToFeedVec(T& ins_vec) = 0;

  std::thread read_thread_;  // the thread for read files
  /* using ifstream one line and one line parse is faster
   * than using fread one buffer and one buffer parse.
   * for 601M JingPai data:
   *   ifstream one line and one line parse: 6034 ms
   *   fread one buffer and one buffer parse: 7097 ms */
  std::ifstream file_;
  size_t queue_size_;
  // The elements in the queue are one piece of data,
  // with multiple fields in each piece of data
  BlockingQueue<T> queue_;
};
class MultiSlotType {
 public:
  MultiSlotType() {
    float_feasign_.clear();
    uint64_feasign_.clear();
    offset_.resize(1);
    offset_[0] = 0;
  }
  ~MultiSlotType() {}
  void SetType(std::string& type) {
    if (!CheckType(type)) {return;}
    type_ = type;
  }
  std::vector<size_t>& GetOffset() {
    return offset_;
  }
  void AddValue(float v) {
    if (!CheckFloat()) {return;}
    float_feasign_.push_back(v);
  }
  void AddValue(uint64_t v) {
    if (!CheckUint64()) {return;}
    uint64_feasign_.push_back(v);
  }
  void AddIns(MultiSlotType& ins) {
    if (ins.GetType()[0] == 'f') {  // float
      if (!CheckFloat()) {return;}
      auto& vec = ins.GetFloatData();
      offset_.push_back(offset_.back() + vec.size());
      float_feasign_.insert(float_feasign_.end(), vec.begin(), vec.end());
    } else if (ins.GetType()[0] == 'u') {  // uint64
      if (!CheckUint64()) {return;}
      auto& vec = ins.GetUint64Data();
      offset_.push_back(offset_.back() + vec.size());
      uint64_feasign_.insert(uint64_feasign_.end(), vec.begin(), vec.end());
    }
  }
  std::vector<float>& GetFloatData() {
    return float_feasign_;
  }
  std::vector<uint64_t>& GetUint64Data() {
    return uint64_feasign_;
  }
  std::string& GetType() {
    return type_;
  }

 private:
  bool CheckType(std::string& type) {
    if (type != "uint64" && type != "float") {
      LOG(ERROR) << "error: there is no such type";
      return false;
    }
    return true;
  }
  bool CheckFloat() {
    if (type_[0] != 'f') {  // float
      LOG(ERROR) << "error: add " << type_ << " value to float slot";
      return false;
    }
    return true;
  }
  bool CheckUint64() {
    if (type_[0] != 'u') {  // uint64
      LOG(ERROR) << "error: add " << type_ << " value to uint64 slot";
      return false;
    }
    return true;
  }

  std::vector<float> float_feasign_;
  std::vector<uint64_t> uint64_feasign_;
  std::string type_;
  std::vector<size_t> offset_;
};
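// Sketch of how MultiSlotType accumulates a batch for one slot (values are
// invented, not part of the header). Each AddIns() call appends one instance's
// feasigns and extends the offset vector, which later becomes the LoD of the
// feed tensor:
//
//   MultiSlotType batch_slot, ins_a, ins_b;
//   std::string type = "uint64";
//   batch_slot.SetType(type); ins_a.SetType(type); ins_b.SetType(type);
//   ins_a.AddValue(uint64_t(1001)); ins_a.AddValue(uint64_t(1002));
//   ins_b.AddValue(uint64_t(7));
//   batch_slot.AddIns(ins_a);  // offset_ = {0, 2}
//   batch_slot.AddIns(ins_b);  // offset_ = {0, 2, 3}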
class MultiSlotDataFeed
    : public PrivateQueueDataFeed<std::vector<MultiSlotType>> {
 public:
  MultiSlotDataFeed() {}
  virtual ~MultiSlotDataFeed() {}
  virtual void Init(paddle::DataFeedDesc& data_feed_desc);
  // TODO: virtual bool CheckFile();

 protected:
  virtual void AddInstanceToInsVec(std::vector<MultiSlotType>& vec_ins,
                                   std::vector<MultiSlotType>& instance,
                                   int index);
  virtual bool ParseOneInstance(std::vector<MultiSlotType>& instance);
  virtual void PutToFeedVec(std::vector<MultiSlotType>& ins_vec);
};
// TODO: to be deleted
class TextClassDataFeed : public DataFeed {
 public:
  virtual ~TextClassDataFeed() {}
  virtual void Init(paddle::DataFeedDesc& data_feed_desc) {}
  virtual bool Start() {return false;}  // TODO
  virtual bool Next() {return false;}   // TODO
  virtual bool ReadBatch() {return false;}
  virtual void AddFeedVar(Variable* feed, const std::string& name) {}
  virtual void BindScope(Scope* scope) {}
  virtual bool SetFile(const char* filename) {return false;}
  virtual bool CheckFile(const char* filename) {
    // TODO(xxx)
    return false;
  }
  void SetBatchSize(int batch) {batch_size_ = batch;}

 private:
  int ReadWholeFile(const std::string& filename, char* buffer) {return -1;}
  char* file_content_buffer_;
  char* file_content_buffer_ptr_;
  int* batch_id_buffer_;
  int* label_ptr_;
  int file_size_;
  std::vector<std::string> names_;
  std::shared_ptr<char> file_content_buffer_host_;
  std::shared_ptr<int> batch_id_host_;
  std::shared_ptr<int> label_host_;
};

}  // namespace framework
paddle/fluid/framework/data_feed.proto
@@ -17,5 +17,16 @@ package paddle;
message DataFeedDesc {
  optional string name = 1;
  optional int32 batch = 2 [default = 32];
  optional MultiSlotDesc multi_slot_desc = 3;
}

message MultiSlotDesc {
  repeated Slot slots = 1;
}

message Slot {
  required string name = 1;
  required string type = 2;
  optional bool dense = 3 [default = false];
  optional bool use = 4 [default = true];
}
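A DataFeedDesc that MultiSlotDataFeed::Init can consume might be filled in as follows; this snippet is only a sketch using the standard protobuf-generated C++ setters, and the slot names are made up:

paddle::DataFeedDesc desc;
desc.set_name("multi_slot_reader");
desc.set_batch(32);
paddle::MultiSlotDesc* multi_slot = desc.mutable_multi_slot_desc();
paddle::Slot* words = multi_slot->add_slots();
words->set_name("words");
words->set_type("uint64");   // sparse slot -> fed as a LoDTensor
paddle::Slot* label = multi_slot->add_slots();
label->set_name("label");
label->set_type("uint64");
label->set_dense(true);      // dense slot -> fed as a Tensor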