Commit 92a98ca7 authored by barrierye

add MultiSlotDataFeed

Parent 78c3380b
......@@ -34,221 +34,241 @@ limitations under the License. */
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/framework/data_feed.h"
DEFINE_bool(is_text_feed, false, "is_text_feed");
namespace paddle {
namespace framework {
std::vector<std::string> TextClassDataFeed::s_filelist_;
std::mutex TextClassDataFeed::s_locker_for_pick_file_;
unsigned int TextClassDataFeed::s_current_file_idx_ = 0;
size_t TextClassDataFeed::s_current_finished_file_cnt_ = 0;
unsigned int TextClassDataFeed::s_current_epoch_ = 0;
int TextClassDataFeed::s_current_save_epoch_ = 0;
std::mutex TextClassDataFeed::s_locker_epoch_start_;
std::condition_variable TextClassDataFeed::s_condition_epoch_start_;
bool TextClassDataFeed::s_epoch_start_flag_ = false;
void TextClassDataFeed::Init() {
// hard coding for a specific datafeed
feed_vec_.resize(2);
// feed_vec_[0].reset(new LoDTensor);
// feed_vec_[1].reset(new LoDTensor);
all_slot_ids_ = {0, 1};
use_slot_ids_ = {0, 1};
use_slot_alias_ = {"words", "label"};
file_content_buffer_host_.reset(new char[200*1024*1024],
[](char *p) {delete[] p;});
file_content_buffer_ = file_content_buffer_host_.get();
file_content_buffer_ptr_ = file_content_buffer_;
batch_id_host_.reset(new int[10240*1024],
[](int *p) {delete[] p;}); // max word num in a batch
batch_id_buffer_ = batch_id_host_.get();
label_host_.reset(new int[10240],
[](int *p) {delete[] p;}); // max label in a batch
label_ptr_ = label_host_.get();
field_names_.clear();
}
TextClassDataFeed::TextClassDataFeed() {
Init();
}
// TODO: use a more elegant implementation for this function
bool TextClassDataFeed::ReadBatch() {
paddle::framework::Vector<size_t> offset;
int tlen = 0;
int llen = 0;
int inst_idx = 0;
offset.resize(batch_size_ + 1);
offset[0] = 0;
while (inst_idx < batch_size_) {
int ptr_offset = 0;
if (file_content_buffer_ptr_ - file_content_buffer_ >= file_size_) {
break;
std::vector<std::string> DataFeed::filelist_;
size_t DataFeed::file_idx_;
std::mutex DataFeed::mutex_for_pick_file_;
void DataFeed::AddFeedVar(Variable* var, const std::string& name) {
if (CheckInit() == false) {return;}
for (size_t i = 0; i < use_slots_.size(); ++i) {
if (name == use_slots_[i]) {
if (use_slots_is_dense_[i]) {
feed_vec_[i] = MixTensor(var->GetMutable<Tensor>());
} else {
feed_vec_[i] = MixTensor(var->GetMutable<LoDTensor>());
}
}
memcpy(reinterpret_cast<char *>(&llen),
file_content_buffer_ptr_ + ptr_offset,
sizeof(int));
ptr_offset += sizeof(int);
memcpy(reinterpret_cast<char *>(batch_id_buffer_ + tlen),
file_content_buffer_ptr_ + ptr_offset,
llen * sizeof(int));
tlen += llen;
offset[inst_idx + 1] = offset[inst_idx] + llen;
ptr_offset += sizeof(int) * llen;
memcpy(reinterpret_cast<char *>(label_ptr_ + inst_idx),
file_content_buffer_ptr_ + ptr_offset,
sizeof(int));
ptr_offset += sizeof(int);
file_content_buffer_ptr_ += ptr_offset;
inst_idx++;
}
}
if (inst_idx != batch_size_) {
bool DataFeed::SetFileList(const std::vector<std::string>& files) {
if (CheckInit() == false) {return false;}
if (files.size() == 0) {
LOG(ERROR) << "error: you have set an empty filelist";
return false;
}
filelist_.assign(files.begin(), files.end());
file_idx_ = 0;
LoD input_lod{offset};
paddle::framework::Vector<size_t> label_offset;
label_offset.resize(batch_size_ + 1);
for (int i = 0; i <= batch_size_; ++i) {
label_offset[i] = i;
}
finish_set_filelist_ = true;
return true;
}
LoD label_lod{label_offset};
int64_t* input_ptr = feed_vec_[0]->mutable_data<int64_t>(
{static_cast<int64_t>(offset.back()), 1},
platform::CPUPlace());
int64_t* label_ptr = feed_vec_[1]->mutable_data<int64_t>({batch_size_, 1},
platform::CPUPlace());
for (unsigned int i = 0; i < offset.back(); ++i) {
input_ptr[i] = static_cast<int64_t>(batch_id_buffer_[i]);
}
for (int i = 0; i < batch_size_; ++i) {
label_ptr[i] = static_cast<int64_t>(label_ptr_[i]);
bool DataFeed::PickOneFile(std::string& filename) {
std::unique_lock<std::mutex> lock(mutex_for_pick_file_);
if (file_idx_ == filelist_.size()) {
return false;
}
feed_vec_[0]->set_lod(input_lod);
feed_vec_[1]->set_lod(label_lod);
filename = filelist_[file_idx_++];
return true;
}
TextClassDataFeed::TextClassDataFeed(const TextClassDataFeed& data_feed) {
Init();
SetBatchSize(data_feed.batch_size_);
SetFieldNames(data_feed.field_names_);
bool DataFeed::CheckInit() {
if (finish_init_) {return true;}
LOG(ERROR) << "error: initialization did not succeed";
return false;
}
void TextClassDataFeed::AddFeedVar(Variable* feed, const std::string& name) {
for (unsigned int i = 0; i < use_slot_alias_.size(); ++i) {
if (name == use_slot_alias_[i]) {
feed_vec_[i] = feed->GetMutable<LoDTensor>();
}
bool DataFeed::CheckSetFileList() {
if (finish_set_filelist_) {return true;}
LOG(ERROR) << "error: set filelist did not succeed";
return false;
}
bool DataFeed::CheckStart() {
if (finish_start_) {return true;}
LOG(ERROR) << "error: Datafeed has not started running yet";
return false;
}
template<typename T>
void PrivateQueueDataFeed<T>::SetQueueSize(int queue_size) {
if (!CheckInit()) {return;}
if (queue_size <= 0) {
LOG(ERROR) << "error: illegal queue size: " << queue_size;
return;
}
queue_size_ = queue_size;
queue_.ReCap(queue_size_);
}
void TextClassDataFeed::SetFileList(const char* filelist) {
s_filelist_.clear();
std::ifstream fin(filelist);
PADDLE_ENFORCE(fin.good(),
"Opening file %s fail",
filelist);
template<typename T>
bool PrivateQueueDataFeed<T>::Start() {
if (!(CheckSetFileList())) {return false;}
read_thread_ = std::thread(&PrivateQueueDataFeed::ReadThread, this);
read_thread_.detach();
finish_start_ = true;
return true;
}
template<typename T>
void PrivateQueueDataFeed<T>::ReadThread(){
std::string filename;
while (fin >> filename) {
LOG(ERROR) << "add " << filename.c_str() << " to filelist";
s_filelist_.push_back(filename);
while (PickOneFile(filename)) {
file_.open(filename.c_str()); // is_text_feed
if (!file_.is_open()) {
LOG(ERROR) << "error: open file<" << filename << "> fail";
}
T instance;
while (ParseOneInstance(instance)) {
queue_.Send(instance);
}
file_.close();
}
fin.close();
queue_.Close();
}
void TextClassDataFeed::SetFieldNames(
const std::vector<std::string>& field_names) {
field_names_.clear();
field_names_.insert(field_names_.end(), field_names.begin(),
field_names.end());
template<typename T>
bool PrivateQueueDataFeed<T>::Next(){
if (!CheckStart()) {return false;}
int index = 0;
T instance;
T ins_vec(use_slots_.size());
while (index < default_batch_size_) {
if (!queue_.Receive(&instance)) {
break;
}
AddInstanceToInsVec(ins_vec, instance, index++);
}
batch_size_ = index;
PutToFeedVec(ins_vec);
return batch_size_ != 0;
}
bool TextClassDataFeed::SetFile(const char* filename) {
// termnum termid termid ... termid label
std::ifstream ifs(filename, std::ios::binary);
if (ifs.fail()) {
return false;
void MultiSlotDataFeed::Init(paddle::DataFeedDesc& data_feed_desc) {
finish_init_ = false;
finish_set_filelist_ = false;
finish_start_ = false;
if (!data_feed_desc.has_multi_slot_desc()){
LOG(ERROR) << "error: multi_slot_desc has not been set";
return ;
}
ifs.seekg(0, std::ios::end);
int filesize = ifs.tellg();
ifs.seekg(0, std::ios::beg);
ifs.read(file_content_buffer_, filesize);
if (filesize < 0 || filesize >= 1024 * 1024 * 1024) {
return false;
paddle::MultiSlotDesc multi_slot_desc = data_feed_desc.multi_slot_desc();
size_t all_slot_num = multi_slot_desc.slots_size();
all_slots_.resize(all_slot_num);
all_slots_type_.resize(all_slot_num);
use_slots_index_.resize(all_slot_num);
use_slots_.clear();
use_slots_is_dense_.clear();
for (size_t i = 0; i < all_slot_num; ++i) {
auto& slot = multi_slot_desc.slots(i);
all_slots_[i] = slot.name();
all_slots_type_[i] = slot.type();
use_slots_index_[i] = slot.use() ? use_slots_.size() : -1;
if (slot.use()) {
use_slots_.push_back(all_slots_[i]);
use_slots_is_dense_.push_back(slot.dense());
}
}
file_content_buffer_ptr_ = file_content_buffer_;
file_size_ = filesize;
// TODO: remove magic number
feed_vec_.resize(use_slots_.size());
return true;
finish_init_ = true;
}
void TextClassDataFeed::UpdateEpochNum() {
s_current_finished_file_cnt_++;
if (s_current_finished_file_cnt_ >= s_filelist_.size()) {
s_current_finished_file_cnt_ = 0;
s_current_epoch_++;
#if 1
LOG(WARNING) << "UpdateEpochNum: epoch = " << s_current_epoch_;
#endif
{
std::lock_guard<std::mutex> lock(s_locker_epoch_start_);
s_epoch_start_flag_ = false;
bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>& instance) {
std::string line;
if (getline(file_, line)) {
int use_slots_num = use_slots_.size();
instance.resize(use_slots_num);
//parse line
const char* str = line.c_str();
char* endptr = (char*)str;
int pos = 0;
for (size_t i = 0; i < use_slots_index_.size(); ++i) {
int idx = use_slots_index_[i];
int num = (int)strtol(&str[pos], &endptr, 10);
if (num == 0) {
LOG(ERROR) << "error: the number of ids can not be zero, you need padding it";
exit(-1);
}
if (idx != -1) {
instance[idx].SetType(all_slots_type_[i]);
if (instance[idx].GetType()[0] == 'f') { // float
for (int j = 0; j < num; ++j) {
float feasign = (float)strtof(endptr, &endptr);
instance[idx].AddValue(feasign);
}
} else if (instance[idx].GetType()[0] == 'u'){ // uint64
for (int j = 0; j < num; ++j) {
uint64_t feasign = (uint64_t)strtoull(endptr, &endptr, 10);
instance[idx].AddValue(feasign);
}
}
pos = endptr - str;
} else {
for (int j = 0; j <= num; ++j) {
pos = line.find_first_of(' ', pos + 1);
}
}
}
} else {
return false;
}
return true;
}
void TextClassDataFeed::StartOneEpoch() {
std::lock_guard<std::mutex> lock(s_locker_for_pick_file_);
std::random_shuffle(s_filelist_.begin(), s_filelist_.end());
s_current_file_idx_ = 0;
LOG(INFO) << "Beginning epoch " << s_current_epoch_;
{
std::lock_guard<std::mutex> lock(s_locker_epoch_start_);
s_epoch_start_flag_ = true;
void MultiSlotDataFeed::AddInstanceToInsVec(std::vector<MultiSlotType>& ins_vec,
std::vector<MultiSlotType>& instance, int index) {
if (index == 0) {
for (size_t i = 0; i < instance.size(); ++i) {
ins_vec[i].SetType(instance[i].GetType());
}
}
for (size_t i = 0; i < instance.size(); ++i){
ins_vec[i].AddIns(instance[i]);
}
s_condition_epoch_start_.notify_all();
}
void TextClassDataFeed::WaitNextEpoch() {
std::unique_lock<std::mutex> lock(s_locker_epoch_start_);
s_condition_epoch_start_.wait(lock, []{return s_epoch_start_flag_;});
}
const char* TextClassDataFeed::PickOneFile() {
std::string file_to_be_processed;
std::lock_guard<std::mutex> lock(s_locker_for_pick_file_);
// One epoch has run over
// Wait for next epoch
if (s_current_file_idx_ >= s_filelist_.size()) {
LOG(ERROR) << "thread " << thread_id_
<< ": finish traing for epoch " << s_current_epoch_ + 1;
return NULL;
void MultiSlotDataFeed::PutToFeedVec(std::vector<MultiSlotType>& ins_vec) {
for (size_t i = 0; i < use_slots_.size(); ++i) {
auto& type = ins_vec[i].GetType();
auto& offset = ins_vec[i].GetOffset();
int total_instance = static_cast<int>(offset.back());
if (type[0] == 'f') { // float
auto& feasign = ins_vec[i].GetFloatData();
if (feed_vec_[i].IsDense()) {
int size_in_each_batch = total_instance / batch_size_;
float* tensor_ptr = feed_vec_[i].GetTensor()->
mutable_data<float>({batch_size_, size_in_each_batch}, platform::CPUPlace());
memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(float));
} else {
float* tensor_ptr = feed_vec_[i].GetLoDTensor()->
mutable_data<float>({total_instance, 1}, platform::CPUPlace());
memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(float));
LoD data_lod{offset};
feed_vec_[i].GetLoDTensor()->set_lod(data_lod);
}
} else if (type[0] == 'u') { // uint64
// there is no uint64_t tensor data type, so int64_t is used here
auto& feasign = ins_vec[i].GetUint64Data();
if (feed_vec_[i].IsDense()) {
int size_in_each_batch = total_instance / batch_size_;
int64_t* tensor_ptr = feed_vec_[i].GetTensor()->
mutable_data<int64_t>({batch_size_, size_in_each_batch}, platform::CPUPlace());
memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(int64_t));
} else {
int64_t* tensor_ptr = feed_vec_[i].GetLoDTensor()->
mutable_data<int64_t>({total_instance, 1}, platform::CPUPlace());
memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(uint64_t));
LoD data_lod{offset};
feed_vec_[i].GetLoDTensor()->set_lod(data_lod);
}
}
}
file_to_be_processed = s_filelist_[s_current_file_idx_];
s_current_file_idx_++;
return file_to_be_processed.c_str();
}
} // namespace framework
......
......@@ -27,136 +27,335 @@ limitations under the License. */
#include <unordered_set>
#include <condition_variable> // NOLINT
#include <fstream>
#include <deque>
#include <atomic>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/data_feed.pb.h"
namespace paddle {
namespace framework {
struct Gauc {
int show, click;
uint64_t fea;
std::string lineid;
};
struct Instance {
std::vector<std::vector<uint64_t>> feed_vec_buffer;
std::vector<std::vector<int>> feed_vec_lod;
std::vector<float> other_label;
std::vector<Gauc> gauc_vec;
class MixTensor {
public:
MixTensor(){}
MixTensor(LoDTensor* lodtensor) {
is_dense_ = false;
lodtensor_ = lodtensor;
}
MixTensor(Tensor* tensor) {
is_dense_ = true;
tensor_ = tensor;
}
bool IsDense() {return is_dense_;}
LoDTensor* GetLoDTensor(){
if (is_dense_) {
LOG(ERROR) << "error: let a dense var return a LoDTensor ptr";
return NULL;
}
return lodtensor_;
}
Tensor* GetTensor(){
if (!is_dense_) {
LOG(ERROR) << "error: let a sparse var return a Tensor ptr";
return NULL;
}
return tensor_;
}
private:
bool is_dense_;
LoDTensor* lodtensor_;
Tensor* tensor_;
};
class DataFeed {
template<typename T>
class BlockingQueue {
public:
DataFeed() : default_batch_size_(1), batch_size_(0), thread_id_(0) {}
virtual ~DataFeed() {}
virtual void Init() = 0;
/*
* This function will be used to check file format.
* Considering that this function may be used alone,
* it does not check anything.
* */
virtual bool CheckFile(const char* filename) = 0;
virtual bool SetFile(const char* filename) = 0;
virtual bool ReadBatch() = 0;
virtual const std::vector<uint16_t>& GetAllSlotIds() {
return all_slot_ids_;
explicit BlockingQueue(size_t capacity = 32)
: capacity_(capacity), closed_(false) {
size_.store(0);
}
virtual const std::vector<uint16_t>& GetUseSlotIds() {
return use_slot_ids_;
void ReCap(size_t capacity) {
capacity_ = capacity;
}
virtual const std::vector<std::string>& GetUseSlotAlias() {
return use_slot_alias_;
}
bool Send(const T& elem) {
int c = -1;
{
std::unique_lock<std::mutex> lock(send_mutex_);
send_cv_.wait(lock, [&] {return size_.load() < capacity_ || closed_;});
if (closed_) {
VLOG(5)
<< "WARNING: Sending an element to a closed reader::BlokcingQueue.";
return false;
}
queue_.push_back(elem);
c = size_.load();
size_.fetch_add(1);
}
if (c + 1 < capacity_) {
send_cv_.notify_one();
}
virtual void AddFeedVar(Variable* var,
const std::string& name) = 0;
virtual void BindScope(Scope* scope) = 0;
virtual void SetBatchSize(int batch) { default_batch_size_ = batch; }
virtual int GetBatchSize() { return batch_size_; }
virtual void SetBufferSize(int buffer_size) {}
virtual unsigned int GetCurrentEpoch() = 0;
virtual const char *PickOneFile() = 0;
virtual void UpdateEpochNum() = 0;
virtual void StartOneEpoch() = 0;
virtual void WaitNextEpoch() = 0;
if (c == 0) {
std::unique_lock<std::mutex> lock(receive_mutex_);
receive_cv_.notify_one();
}
return true;
}
std::vector<LoDTensor*>& GetFeedVec() {
return feed_vec_;
bool Receive(T* elem) {
int c = -1;
{
std::unique_lock<std::mutex> lock(receive_mutex_);
receive_cv_.wait(lock, [&] {return size_.load() != 0 || closed_;});
if (size_.load() != 0) {
*elem = queue_.front();
queue_.pop_front();
c = size_.load();
size_.fetch_sub(1);
} else {
return false;
}
}
if (c > 1) {
receive_cv_.notify_one();
}
if (c == capacity_) {
std::unique_lock<std::mutex> lock(send_mutex_);
send_cv_.notify_one();
}
return true;
}
virtual std::vector<LoDTensor*>& GetFeedVec(const Instance& ins) {
LOG(ERROR) << "use defalut get_feed_vec";
return feed_vec_;
void Close() {
std::lock_guard<std::mutex> lock1(send_mutex_);
std::lock_guard<std::mutex> lock2(receive_mutex_);
closed_ = true;
send_cv_.notify_all();
receive_cv_.notify_all();
}
int GetThreadId() {return thread_id_;}
void SetThreadId(int thread_id) {thread_id_ = thread_id;}
private:
size_t capacity_;
std::atomic_size_t size_;
bool closed_;
std::deque<T> queue_;
mutable std::mutex send_mutex_;
mutable std::mutex receive_mutex_;
mutable std::condition_variable send_cv_;
mutable std::condition_variable receive_cv_;
};
class DataFeed {
public:
DataFeed() {}
virtual ~DataFeed() {}
virtual void Init(paddle::DataFeedDesc& data_feed_desc) = 0;
// some datafeeds may not be able to implement this interface
virtual bool CheckFile(const char* filename) {
LOG(ERROR) << "error: The function CheckFile is not implemented";
return false;
}
virtual bool SetFileList(const std::vector<std::string>& files);
virtual bool Start() = 0;
virtual bool Next() = 0;
virtual void SetBatchSize(int batch) { default_batch_size_ = batch; }
virtual int GetBatchSize() { return batch_size_; }
// for subclass with queue
virtual void SetQueueSize(int queue_size) {
LOG(ERROR) << "error: The function SetQueueSize is not implemented";
}
// for subclass with buffer
virtual void SetBufferSize(int buffer_size) {
LOG(ERROR) << "error: The function SetBufferSize is not implemented";
}
virtual const std::vector<std::string>& GetAllSlots() {return all_slots_;}
virtual const std::vector<std::string>& GetUseSlots() {return use_slots_;}
std::vector<MixTensor>& GetFeedVec() {return feed_vec_;}
virtual void AddFeedVar(Variable* var, const std::string& name);
protected:
std::vector<uint16_t> all_slot_ids_;
std::vector<uint16_t> use_slot_ids_;
std::vector<std::string> use_slot_alias_;
std::vector<LoDTensor*> feed_vec_;
// Check that the calls are executed in this order (a usage sketch follows this class):
// Init -> SetFileList/BindingMemory -> Start -> Next
virtual bool CheckInit();
virtual bool CheckSetFileList();
virtual bool CheckStart();
virtual bool PickOneFile(std::string& filename);
static std::vector<std::string> filelist_;
static size_t file_idx_;
static std::mutex mutex_for_pick_file_;
std::vector<std::string> use_slots_;
std::vector<bool> use_slots_is_dense_;
std::vector<std::string> all_slots_;
std::vector<std::string> all_slots_type_;
std::vector<int> use_slots_index_; // -1: not used; >=0: the index of use_slots_
std::vector<MixTensor> feed_vec_;
int default_batch_size_;
int batch_size_;
int thread_id_;
bool finish_init_;
bool finish_set_filelist_;
bool finish_binding_memory_;
bool finish_start_;
};
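The check helpers above enforce the call order Init -> SetFileList -> Start -> Next. A minimal driver sketch of that contract, for illustration only: the feed object, the DataFeedDesc `desc`, and the file names below are assumptions, not part of this commit.

// Usage sketch (not part of data_feed.h): assumes `desc` was built as in the
// example after data_feed.proto below, and that the listed files exist.
paddle::framework::MultiSlotDataFeed feed;
feed.Init(desc);                 // parse slot names/types from the proto desc
feed.SetBatchSize(32);           // batch size used when assembling a batch in Next()
feed.SetFileList({"part-00000.txt", "part-00001.txt"});
// In a real run, feed.AddFeedVar(var, slot_name) is called here for every
// used slot so that PutToFeedVec can write into the scope's tensors.
feed.Start();                    // spawns ReadThread, which fills the blocking queue
while (feed.Next()) {            // pops up to batch_size_ instances per call
  // the bound tensors now hold the current batch
}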
class TextClassDataFeed : public DataFeed {
template<typename T>
class PrivateQueueDataFeed : public DataFeed {
public:
TextClassDataFeed();
TextClassDataFeed(const TextClassDataFeed& data_feed);
PrivateQueueDataFeed() {}
virtual ~PrivateQueueDataFeed() {}
virtual void Init(paddle::DataFeedDesc& data_feed_desc) = 0;
virtual bool Start();
virtual bool Next(); // no buffer
virtual void SetQueueSize(int queue_size);
protected:
virtual void ReadThread();
virtual bool ParseOneInstance(T& instance) = 0;
virtual void AddInstanceToInsVec(T& vec_ins, T& instance, int index) = 0;
virtual void PutToFeedVec(T& ins_vec) = 0;
std::thread read_thread_; // the thread for read files
/* Reading one line at a time with ifstream and parsing each line is faster
* than reading a whole buffer with fread and parsing that buffer.
* For 601MB of JingPai data:
* ifstream line-by-line read and parse: 6034 ms
* fread whole-buffer read and parse: 7097 ms */
std::ifstream file_;
size_t queue_size_;
// Each element in the queue is one instance of data,
// and each instance contains multiple fields (slots)
BlockingQueue<T> queue_;
};
class MultiSlotType {
public:
MultiSlotType() {
float_feasign_.clear();
uint64_feasign_.clear();
offset_.resize(1);
offset_[0] = 0;
}
~MultiSlotType() {}
void SetType(std::string& type) {
if (!CheckType(type)) {return;}
type_ = type;
}
std::vector<size_t>& GetOffset() {
return offset_;
}
void AddValue(float v) {
if (!CheckFloat()) {return;}
float_feasign_.push_back(v);
}
void AddValue(uint64_t v) {
if (!CheckUint64()) {return;}
uint64_feasign_.push_back(v);
}
void AddIns(MultiSlotType& ins) {
if (ins.GetType()[0] == 'f') { //float
if (!CheckFloat()) {return;}
auto& vec = ins.GetFloatData();
offset_.push_back(offset_.back() + vec.size());
float_feasign_.insert(float_feasign_.end(), vec.begin(), vec.end());
} else if (ins.GetType()[0] == 'u') { //uint64
if (!CheckUint64()) {return;}
auto& vec = ins.GetUint64Data();
offset_.push_back(offset_.back() + vec.size());
uint64_feasign_.insert(uint64_feasign_.end(), vec.begin(), vec.end());
}
}
std::vector<float>& GetFloatData() {
return float_feasign_;
}
std::vector<uint64_t>& GetUint64Data() {
return uint64_feasign_;
}
std::string& GetType() {
return type_;
}
private:
bool CheckType(std::string& type) {
if (type != "uint64" && type != "float") {
LOG(ERROR) << "error: unknown slot type: " << type;
return false;
}
return true;
}
bool CheckFloat() {
if (type_[0] != 'f') { //float
LOG(ERROR) << "error: add " << type_ << " value to float slot";
return false;
}
return true;
}
bool CheckUint64() {
if (type_[0] != 'u') { //uint64
LOG(ERROR) << "error: add " << type_ << " value to uint64 slot";
return false;
}
return true;
}
std::vector<float> float_feasign_;
std::vector<uint64_t> uint64_feasign_;
std::string type_;
std::vector<size_t> offset_;
};
class MultiSlotDataFeed : public PrivateQueueDataFeed<std::vector<MultiSlotType>> {
public:
MultiSlotDataFeed() {}
virtual ~MultiSlotDataFeed() {}
virtual void Init(paddle::DataFeedDesc& data_feed_desc);
//TODO: virtual bool CheckFile();
protected:
virtual void AddInstanceToInsVec(std::vector<MultiSlotType>& vec_ins,
std::vector<MultiSlotType>& instance, int index);
virtual bool ParseOneInstance(std::vector<MultiSlotType>& instance);
virtual void PutToFeedVec(std::vector<MultiSlotType>& ins_vec);
};
//TODO: to be deleted
class TextClassDataFeed : public DataFeed {
public:
virtual ~TextClassDataFeed() {}
virtual void Init();
virtual bool ReadBatch();
virtual void AddFeedVar(Variable* feed, const std::string& name);
virtual void Init(paddle::DataFeedDesc& data_feed_desc) {}
virtual bool Start() {return false;}  // TODO
virtual bool Next() {return false;}  // TODO
virtual bool ReadBatch() {return false;}
virtual void AddFeedVar(Variable* feed, const std::string& name) {}
virtual void BindScope(Scope* scope) {}
virtual bool SetFile(const char* filename);
virtual bool SetFile(const char* filename) {return false;}
virtual bool CheckFile(const char* filename) {
// TODO(xxx)
return false;
}
void SetBatchSize(int batch) {batch_size_ = batch;}
unsigned int GetCurrentEpoch() {return s_current_epoch_;}
void UpdateEpochNum();
void StartOneEpoch();
void WaitNextEpoch();
public:
void SetFieldNames(const std::vector<std::string>& field_names);
public:
static void SetFileList(const char* filelist);
private:
const char* PickOneFile();
void SetBatchSize(int batch) {batch_size_ = batch;}
private:
int ReadWholeFile(const std::string& filename, char* buffer) {return -1;}
char* file_content_buffer_;
char* file_content_buffer_ptr_;
int* batch_id_buffer_;
int* label_ptr_;
int file_size_;
std::vector<std::string> field_names_;
std::vector<std::string> names_;
std::shared_ptr<char> file_content_buffer_host_;
std::shared_ptr<int> batch_id_host_;
std::shared_ptr<int> label_host_;
static std::vector<std::string> s_filelist_;
static std::mutex s_locker_for_pick_file_;
static unsigned int s_current_file_idx_;
static size_t s_current_finished_file_cnt_;
static unsigned int s_current_epoch_;
static int s_current_save_epoch_;
static std::mutex s_locker_epoch_start_;
static std::condition_variable s_condition_epoch_start_;
static bool s_epoch_start_flag_;
};
} // namespace framework
......
......@@ -17,5 +17,16 @@ package paddle;
message DataFeedDesc {
optional string name = 1;
optional int32 batch = 2 [default = 32];
optional MultiSlotDesc multi_slot_desc = 3;
}
message MultiSlotDesc {
repeated Slot slots = 1;
}
message Slot {
required string name = 1;
required string type = 2;
optional bool dense = 3 [default = false];
optional bool use = 4 [default = true];
}
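For illustration, a DataFeedDesc matching this schema can be filled in through the generated protobuf setters; the slot names, types, and the sample data line below are assumptions, not taken from this commit. Based on MultiSlotDataFeed::ParseOneInstance above, each line of an input file encodes every slot in declaration order as "<num> <value> ... <value>".

// Illustration only: hypothetical slot layout for this DataFeedDesc schema.
paddle::DataFeedDesc desc;
desc.set_batch(32);
paddle::MultiSlotDesc* slots = desc.mutable_multi_slot_desc();
paddle::Slot* words = slots->add_slots();
words->set_name("words");
words->set_type("uint64");   // sparse slot: fed through a LoDTensor
paddle::Slot* label = slots->add_slots();
label->set_name("label");
label->set_type("uint64");
label->set_dense(true);      // dense slot: fed through a plain Tensor
// A matching input line ("<num> <values...>" per slot, in declaration order):
// 3 1001 1002 1003 1 0   ->  words = {1001, 1002, 1003}, label = {0}

With `use` left at its default of true, both slots end up in use_slots_ and are bound to feed variables through AddFeedVar.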