未验证 提交 0a3dbe8a 编写于 作者: Y yaoxuefeng 提交者: GitHub

add slotrecord datafeed (#36099)

上级 c12176e8
......@@ -28,6 +28,7 @@ limitations under the License. */
#include "paddle/fluid/platform/timer.h"
USE_INT_STAT(STAT_total_feasign_num_in_mem);
DECLARE_bool(enable_ins_parser_file);
namespace paddle {
namespace framework {
......@@ -1929,5 +1930,646 @@ void PaddleBoxDataFeed::PutToFeedVec(const std::vector<Record*>& ins_vec) {
#endif
}
template class InMemoryDataFeed<SlotRecord>;
void SlotRecordInMemoryDataFeed::Init(const DataFeedDesc& data_feed_desc) {
finish_init_ = false;
finish_set_filelist_ = false;
finish_start_ = false;
PADDLE_ENFORCE(data_feed_desc.has_multi_slot_desc(),
platform::errors::PreconditionNotMet(
"Multi_slot_desc has not been set in data_feed_desc"));
paddle::framework::MultiSlotDesc multi_slot_desc =
data_feed_desc.multi_slot_desc();
SetBatchSize(data_feed_desc.batch_size());
size_t all_slot_num = multi_slot_desc.slots_size();
all_slots_.resize(all_slot_num);
all_slots_info_.resize(all_slot_num);
used_slots_info_.resize(all_slot_num);
use_slot_size_ = 0;
use_slots_.clear();
float_total_dims_size_ = 0;
float_total_dims_without_inductives_.clear();
for (size_t i = 0; i < all_slot_num; ++i) {
const auto& slot = multi_slot_desc.slots(i);
all_slots_[i] = slot.name();
AllSlotInfo& all_slot = all_slots_info_[i];
all_slot.slot = slot.name();
all_slot.type = slot.type();
all_slot.used_idx = slot.is_used() ? use_slot_size_ : -1;
all_slot.slot_value_idx = -1;
if (slot.is_used()) {
UsedSlotInfo& info = used_slots_info_[use_slot_size_];
info.idx = i;
info.slot = slot.name();
info.type = slot.type();
info.dense = slot.is_dense();
info.total_dims_without_inductive = 1;
info.inductive_shape_index = -1;
// record float value and uint64_t value pos
if (info.type[0] == 'u') {
info.slot_value_idx = uint64_use_slot_size_;
all_slot.slot_value_idx = uint64_use_slot_size_;
++uint64_use_slot_size_;
} else if (info.type[0] == 'f') {
info.slot_value_idx = float_use_slot_size_;
all_slot.slot_value_idx = float_use_slot_size_;
++float_use_slot_size_;
}
use_slots_.push_back(slot.name());
if (slot.is_dense()) {
for (int j = 0; j < slot.shape_size(); ++j) {
if (slot.shape(j) > 0) {
info.total_dims_without_inductive *= slot.shape(j);
}
if (slot.shape(j) == -1) {
info.inductive_shape_index = j;
}
}
}
if (info.type[0] == 'f') {
float_total_dims_without_inductives_.push_back(
info.total_dims_without_inductive);
float_total_dims_size_ += info.total_dims_without_inductive;
}
info.local_shape.clear();
for (int j = 0; j < slot.shape_size(); ++j) {
info.local_shape.push_back(slot.shape(j));
}
++use_slot_size_;
}
}
used_slots_info_.resize(use_slot_size_);
feed_vec_.resize(used_slots_info_.size());
const int kEstimatedFeasignNumPerSlot = 5; // Magic Number
for (size_t i = 0; i < all_slot_num; i++) {
batch_float_feasigns_.push_back(std::vector<float>());
batch_uint64_feasigns_.push_back(std::vector<uint64_t>());
batch_float_feasigns_[i].reserve(default_batch_size_ *
kEstimatedFeasignNumPerSlot);
batch_uint64_feasigns_[i].reserve(default_batch_size_ *
kEstimatedFeasignNumPerSlot);
offset_.push_back(std::vector<size_t>());
offset_[i].reserve(default_batch_size_ +
1); // Each lod info will prepend a zero
}
visit_.resize(all_slot_num, false);
pipe_command_ = data_feed_desc.pipe_command();
finish_init_ = true;
input_type_ = data_feed_desc.input_type();
size_t pos = pipe_command_.find(".so");
if (pos != std::string::npos) {
pos = pipe_command_.rfind('|');
if (pos == std::string::npos) {
so_parser_name_ = pipe_command_;
pipe_command_.clear();
} else {
so_parser_name_ = pipe_command_.substr(pos + 1);
pipe_command_ = pipe_command_.substr(0, pos);
}
so_parser_name_ = paddle::string::erase_spaces(so_parser_name_);
} else {
so_parser_name_.clear();
}
}
void SlotRecordInMemoryDataFeed::LoadIntoMemory() {
VLOG(3) << "SlotRecord LoadIntoMemory() begin, thread_id=" << thread_id_;
if (!so_parser_name_.empty()) {
LoadIntoMemoryByLib();
} else {
LoadIntoMemoryByCommand();
}
}
void SlotRecordInMemoryDataFeed::LoadIntoMemoryByLib(void) {
if (true) {
// user defined file format analysis
LoadIntoMemoryByFile();
} else {
LoadIntoMemoryByLine();
}
}
void SlotRecordInMemoryDataFeed::LoadIntoMemoryByFile(void) {
#ifdef _LINUX
paddle::framework::CustomParser* parser =
global_dlmanager_pool().Load(so_parser_name_, all_slots_info_);
CHECK(parser != nullptr);
// get slotrecord object
auto pull_record_func = [this](std::vector<SlotRecord>& record_vec,
int max_fetch_num, int offset) {
if (offset > 0) {
input_channel_->WriteMove(offset, &record_vec[0]);
if (max_fetch_num > 0) {
SlotRecordPool().get(&record_vec[0], offset);
} else { // free all
max_fetch_num = static_cast<int>(record_vec.size());
if (max_fetch_num > offset) {
SlotRecordPool().put(&record_vec[offset], (max_fetch_num - offset));
}
}
} else if (max_fetch_num > 0) {
SlotRecordPool().get(&record_vec, max_fetch_num);
} else {
SlotRecordPool().put(&record_vec);
}
};
std::string filename;
while (this->PickOneFile(&filename)) {
VLOG(3) << "PickOneFile, filename=" << filename
<< ", thread_id=" << thread_id_;
platform::Timer timeline;
timeline.Start();
int lines = 0;
bool is_ok = true;
do {
int err_no = 0;
this->fp_ = fs_open_read(filename, &err_no, this->pipe_command_);
CHECK(this->fp_ != nullptr);
__fsetlocking(&*(this->fp_), FSETLOCKING_BYCALLER);
is_ok = parser->ParseFileInstance(
[this](char* buf, int len) {
return fread(buf, sizeof(char), len, this->fp_.get());
},
pull_record_func, lines);
if (!is_ok) {
LOG(WARNING) << "parser error, filename=" << filename
<< ", lines=" << lines;
}
} while (!is_ok);
timeline.Pause();
VLOG(3) << "LoadIntoMemoryByLib() read all file, file=" << filename
<< ", cost time=" << timeline.ElapsedSec()
<< " seconds, thread_id=" << thread_id_ << ", lines=" << lines;
}
#endif
}
void SlotRecordInMemoryDataFeed::LoadIntoMemoryByLine(void) {
#ifdef _LINUX
paddle::framework::CustomParser* parser =
global_dlmanager_pool().Load(so_parser_name_, all_slots_info_);
std::string filename;
BufferedLineFileReader line_reader;
line_reader.set_sample_rate(sample_rate_);
BufferedLineFileReader::LineFunc line_func = nullptr;
while (this->PickOneFile(&filename)) {
VLOG(3) << "PickOneFile, filename=" << filename
<< ", thread_id=" << thread_id_;
std::vector<SlotRecord> record_vec;
platform::Timer timeline;
timeline.Start();
int offset = 0;
int old_offset = 0;
SlotRecordPool().get(&record_vec, OBJPOOL_BLOCK_SIZE);
// get slotrecord object function
auto record_func = [this, &offset, &record_vec, &old_offset](
std::vector<SlotRecord>& vec, int num) {
vec.resize(num);
if (offset + num > OBJPOOL_BLOCK_SIZE) {
input_channel_->WriteMove(offset, &record_vec[0]);
SlotRecordPool().get(&record_vec[0], offset);
record_vec.resize(OBJPOOL_BLOCK_SIZE);
offset = 0;
old_offset = 0;
}
for (int i = 0; i < num; ++i) {
auto& ins = record_vec[offset + i];
ins->reset();
vec[i] = ins;
}
offset = offset + num;
};
line_func = [this, &parser, &record_vec, &offset, &filename, &record_func,
&old_offset](const std::string& line) {
old_offset = offset;
if (!parser->ParseOneInstance(line, record_func)) {
offset = old_offset;
LOG(WARNING) << "read file:[" << filename << "] item error, line:["
<< line << "]";
return false;
}
if (offset >= OBJPOOL_BLOCK_SIZE) {
input_channel_->Write(std::move(record_vec));
record_vec.clear();
SlotRecordPool().get(&record_vec, OBJPOOL_BLOCK_SIZE);
offset = 0;
}
return true;
};
int lines = 0;
do {
int err_no = 0;
this->fp_ = fs_open_read(filename, &err_no, this->pipe_command_);
CHECK(this->fp_ != nullptr);
__fsetlocking(&*(this->fp_), FSETLOCKING_BYCALLER);
lines = line_reader.read_file(this->fp_.get(), line_func, lines);
} while (line_reader.is_error());
if (offset > 0) {
input_channel_->WriteMove(offset, &record_vec[0]);
if (offset < OBJPOOL_BLOCK_SIZE) {
SlotRecordPool().put(&record_vec[offset],
(OBJPOOL_BLOCK_SIZE - offset));
}
} else {
SlotRecordPool().put(&record_vec);
}
record_vec.clear();
record_vec.shrink_to_fit();
timeline.Pause();
VLOG(3) << "LoadIntoMemoryByLib() read all lines, file=" << filename
<< ", cost time=" << timeline.ElapsedSec()
<< " seconds, thread_id=" << thread_id_ << ", lines=" << lines
<< ", sample lines=" << line_reader.get_sample_line()
<< ", filesize=" << line_reader.file_size() / 1024.0 / 1024.0
<< "MB";
}
VLOG(3) << "LoadIntoMemoryByLib() end, thread_id=" << thread_id_
<< ", total size: " << line_reader.file_size();
#endif
}
void SlotRecordInMemoryDataFeed::LoadIntoMemoryByCommand(void) {
#ifdef _LINUX
std::string filename;
BufferedLineFileReader line_reader;
line_reader.set_sample_rate(sample_rate_);
while (this->PickOneFile(&filename)) {
VLOG(3) << "PickOneFile, filename=" << filename
<< ", thread_id=" << thread_id_;
int lines = 0;
std::vector<SlotRecord> record_vec;
platform::Timer timeline;
timeline.Start();
SlotRecordPool().get(&record_vec, OBJPOOL_BLOCK_SIZE);
int offset = 0;
do {
int err_no = 0;
this->fp_ = fs_open_read(filename, &err_no, this->pipe_command_);
CHECK(this->fp_ != nullptr);
__fsetlocking(&*(this->fp_), FSETLOCKING_BYCALLER);
lines = line_reader.read_file(
this->fp_.get(),
[this, &record_vec, &offset, &filename](const std::string& line) {
if (ParseOneInstance(line, &record_vec[offset])) {
++offset;
} else {
LOG(WARNING) << "read file:[" << filename
<< "] item error, line:[" << line << "]";
return false;
}
if (offset >= OBJPOOL_BLOCK_SIZE) {
input_channel_->Write(std::move(record_vec));
record_vec.clear();
SlotRecordPool().get(&record_vec, OBJPOOL_BLOCK_SIZE);
offset = 0;
}
return true;
},
lines);
} while (line_reader.is_error());
if (offset > 0) {
input_channel_->WriteMove(offset, &record_vec[0]);
if (offset < OBJPOOL_BLOCK_SIZE) {
SlotRecordPool().put(&record_vec[offset],
(OBJPOOL_BLOCK_SIZE - offset));
}
} else {
SlotRecordPool().put(&record_vec);
}
record_vec.clear();
record_vec.shrink_to_fit();
timeline.Pause();
VLOG(3) << "LoadIntoMemory() read all lines, file=" << filename
<< ", lines=" << lines
<< ", sample lines=" << line_reader.get_sample_line()
<< ", cost time=" << timeline.ElapsedSec()
<< " seconds, thread_id=" << thread_id_;
}
VLOG(3) << "LoadIntoMemory() end, thread_id=" << thread_id_
<< ", total size: " << line_reader.file_size();
#endif
}
static void parser_log_key(const std::string& log_key, uint64_t* search_id,
uint32_t* cmatch, uint32_t* rank) {
std::string searchid_str = log_key.substr(16, 16);
*search_id = static_cast<uint64_t>(strtoull(searchid_str.c_str(), NULL, 16));
std::string cmatch_str = log_key.substr(11, 3);
*cmatch = static_cast<uint32_t>(strtoul(cmatch_str.c_str(), NULL, 16));
std::string rank_str = log_key.substr(14, 2);
*rank = static_cast<uint32_t>(strtoul(rank_str.c_str(), NULL, 16));
}
bool SlotRecordInMemoryDataFeed::ParseOneInstance(const std::string& line,
SlotRecord* ins) {
SlotRecord& rec = (*ins);
// parse line
const char* str = line.c_str();
char* endptr = const_cast<char*>(str);
int pos = 0;
thread_local std::vector<std::vector<float>> slot_float_feasigns;
thread_local std::vector<std::vector<uint64_t>> slot_uint64_feasigns;
slot_float_feasigns.resize(float_use_slot_size_);
slot_uint64_feasigns.resize(uint64_use_slot_size_);
if (parse_ins_id_) {
int num = strtol(&str[pos], &endptr, 10);
CHECK(num == 1); // NOLINT
pos = endptr - str + 1;
size_t len = 0;
while (str[pos + len] != ' ') {
++len;
}
rec->ins_id_ = std::string(str + pos, len);
pos += len + 1;
}
if (parse_logkey_) {
int num = strtol(&str[pos], &endptr, 10);
CHECK(num == 1); // NOLINT
pos = endptr - str + 1;
size_t len = 0;
while (str[pos + len] != ' ') {
++len;
}
// parse_logkey
std::string log_key = std::string(str + pos, len);
uint64_t search_id;
uint32_t cmatch;
uint32_t rank;
parser_log_key(log_key, &search_id, &cmatch, &rank);
rec->ins_id_ = log_key;
rec->search_id = search_id;
rec->cmatch = cmatch;
rec->rank = rank;
pos += len + 1;
}
int float_total_slot_num = 0;
int uint64_total_slot_num = 0;
for (size_t i = 0; i < all_slots_info_.size(); ++i) {
auto& info = all_slots_info_[i];
int num = strtol(&str[pos], &endptr, 10);
PADDLE_ENFORCE(num,
"The number of ids can not be zero, you need padding "
"it in data generator; or if there is something wrong with "
"the data, please check if the data contains unresolvable "
"characters.\nplease check this error line: %s",
str);
if (info.used_idx != -1) {
if (info.type[0] == 'f') { // float
auto& slot_fea = slot_float_feasigns[info.slot_value_idx];
slot_fea.clear();
for (int j = 0; j < num; ++j) {
float feasign = strtof(endptr, &endptr);
if (fabs(feasign) < 1e-6 && !used_slots_info_[info.used_idx].dense) {
continue;
}
slot_fea.push_back(feasign);
++float_total_slot_num;
}
} else if (info.type[0] == 'u') { // uint64
auto& slot_fea = slot_uint64_feasigns[info.slot_value_idx];
slot_fea.clear();
for (int j = 0; j < num; ++j) {
uint64_t feasign =
static_cast<uint64_t>(strtoull(endptr, &endptr, 10));
if (feasign == 0 && !used_slots_info_[info.used_idx].dense) {
continue;
}
slot_fea.push_back(feasign);
++uint64_total_slot_num;
}
}
pos = endptr - str;
} else {
for (int j = 0; j <= num; ++j) {
// pos = line.find_first_of(' ', pos + 1);
while (line[pos + 1] != ' ') {
pos++;
}
}
}
}
rec->slot_float_feasigns_.add_slot_feasigns(slot_float_feasigns,
float_total_slot_num);
rec->slot_uint64_feasigns_.add_slot_feasigns(slot_uint64_feasigns,
uint64_total_slot_num);
return (uint64_total_slot_num > 0);
}
void SlotRecordInMemoryDataFeed::PutToFeedVec(const SlotRecord* ins_vec,
int num) {
for (int j = 0; j < use_slot_size_; ++j) {
auto& feed = feed_vec_[j];
if (feed == nullptr) {
continue;
}
auto& slot_offset = offset_[j];
slot_offset.clear();
slot_offset.reserve(num + 1);
slot_offset.push_back(0);
int total_instance = 0;
auto& info = used_slots_info_[j];
// fill slot value with default value 0
if (info.type[0] == 'f') { // float
auto& batch_fea = batch_float_feasigns_[j];
batch_fea.clear();
for (int i = 0; i < num; ++i) {
auto r = ins_vec[i];
size_t fea_num = 0;
float* slot_values =
r->slot_float_feasigns_.get_values(info.slot_value_idx, &fea_num);
batch_fea.resize(total_instance + fea_num);
memcpy(&batch_fea[total_instance], slot_values,
sizeof(float) * fea_num);
total_instance += fea_num;
slot_offset.push_back(total_instance);
}
float* feasign = batch_fea.data();
float* tensor_ptr =
feed->mutable_data<float>({total_instance, 1}, this->place_);
CopyToFeedTensor(tensor_ptr, feasign, total_instance * sizeof(float));
} else if (info.type[0] == 'u') { // uint64
auto& batch_fea = batch_uint64_feasigns_[j];
batch_fea.clear();
for (int i = 0; i < num; ++i) {
auto r = ins_vec[i];
size_t fea_num = 0;
uint64_t* slot_values =
r->slot_uint64_feasigns_.get_values(info.slot_value_idx, &fea_num);
if (fea_num > 0) {
batch_fea.resize(total_instance + fea_num);
memcpy(&batch_fea[total_instance], slot_values,
sizeof(uint64_t) * fea_num);
total_instance += fea_num;
}
if (fea_num == 0) {
batch_fea.resize(total_instance + fea_num);
batch_fea[total_instance] = 0;
total_instance += 1;
}
slot_offset.push_back(total_instance);
}
// no uint64_t type in paddlepaddle
uint64_t* feasign = batch_fea.data();
int64_t* tensor_ptr =
feed->mutable_data<int64_t>({total_instance, 1}, this->place_);
CopyToFeedTensor(tensor_ptr, feasign, total_instance * sizeof(int64_t));
}
if (info.dense) {
if (info.inductive_shape_index != -1) {
info.local_shape[info.inductive_shape_index] =
total_instance / info.total_dims_without_inductive;
}
feed->Resize(framework::make_ddim(info.local_shape));
} else {
LoD data_lod{slot_offset};
feed_vec_[j]->set_lod(data_lod);
}
}
}
void SlotRecordInMemoryDataFeed::ExpandSlotRecord(SlotRecord* rec) {
SlotRecord& ins = (*rec);
if (ins->slot_float_feasigns_.slot_offsets.empty()) {
return;
}
size_t total_value_size = ins->slot_float_feasigns_.slot_values.size();
if (float_total_dims_size_ == total_value_size) {
return;
}
int float_slot_num =
static_cast<int>(float_total_dims_without_inductives_.size());
CHECK(float_slot_num == float_use_slot_size_);
std::vector<float> old_values;
std::vector<uint32_t> old_offsets;
old_values.swap(ins->slot_float_feasigns_.slot_values);
old_offsets.swap(ins->slot_float_feasigns_.slot_offsets);
ins->slot_float_feasigns_.slot_values.resize(float_total_dims_size_);
ins->slot_float_feasigns_.slot_offsets.assign(float_slot_num + 1, 0);
auto& slot_offsets = ins->slot_float_feasigns_.slot_offsets;
auto& slot_values = ins->slot_float_feasigns_.slot_values;
uint32_t offset = 0;
int num = 0;
uint32_t old_off = 0;
int dim = 0;
for (int i = 0; i < float_slot_num; ++i) {
dim = float_total_dims_without_inductives_[i];
old_off = old_offsets[i];
num = static_cast<int>(old_offsets[i + 1] - old_off);
if (num == 0) {
// fill slot value with default value 0
for (int k = 0; k < dim; ++k) {
slot_values[k + offset] = 0.0;
}
} else {
if (num == dim) {
memcpy(&slot_values[offset], &old_values[old_off], dim * sizeof(float));
} else {
// position fea
// record position index need fix values
int pos_idx = static_cast<int>(old_values[old_off]);
for (int k = 0; k < dim; ++k) {
if (k == pos_idx) {
slot_values[k + offset] = 1.0;
} else {
slot_values[k + offset] = 0.0;
}
}
}
}
slot_offsets[i] = offset;
offset += dim;
}
slot_offsets[float_slot_num] = offset;
CHECK(float_total_dims_size_ == static_cast<size_t>(offset));
}
bool SlotRecordInMemoryDataFeed::Start() {
#ifdef _LINUX
this->CheckSetFileList();
if (input_channel_->Size() != 0) {
std::vector<SlotRecord> data;
input_channel_->Read(data);
}
#endif
if (batch_offsets_.size() > 0) {
VLOG(3) << "batch_size offsets: " << batch_offsets_.size();
enable_heterps_ = true;
this->offset_index_ = 0;
}
this->finish_start_ = true;
return true;
}
int SlotRecordInMemoryDataFeed::Next() {
#ifdef _LINUX
this->CheckStart();
VLOG(3) << "enable heter next: " << offset_index_
<< " batch_offsets: " << batch_offsets_.size();
if (offset_index_ >= batch_offsets_.size()) {
VLOG(3) << "offset_index: " << offset_index_
<< " batch_offsets: " << batch_offsets_.size();
return 0;
}
auto& batch = batch_offsets_[offset_index_++];
this->batch_size_ = batch.second;
VLOG(3) << "batch_size_=" << this->batch_size_
<< ", thread_id=" << thread_id_;
if (this->batch_size_ != 0) {
PutToFeedVec(&records_[batch.first], this->batch_size_);
} else {
VLOG(3) << "finish reading for heterps, batch size zero, thread_id="
<< thread_id_;
}
VLOG(3) << "enable heter next: " << offset_index_
<< " batch_offsets: " << batch_offsets_.size()
<< " baych_size: " << this->batch_size_;
return this->batch_size_;
#else
return 0;
#endif
}
} // namespace framework
} // namespace paddle
......@@ -384,7 +384,7 @@ class CustomParser {
CustomParser() {}
virtual ~CustomParser() {}
virtual void Init(const std::vector<SlotConf>& slots) = 0;
virtual bool Init(const std::vector<AllSlotInfo>& slots) = 0;
virtual bool Init(const std::vector<AllSlotInfo>& slots);
virtual void ParseOneInstance(const char* str, Record* instance) = 0;
virtual bool ParseOneInstance(
const std::string& line,
......@@ -1103,6 +1103,42 @@ class MultiSlotInMemoryDataFeed : public InMemoryDataFeed<Record> {
virtual void PutToFeedVec(const Record* ins_vec, int num);
};
class SlotRecordInMemoryDataFeed : public InMemoryDataFeed<SlotRecord> {
public:
SlotRecordInMemoryDataFeed() {}
virtual ~SlotRecordInMemoryDataFeed() {}
virtual void Init(const DataFeedDesc& data_feed_desc);
virtual void LoadIntoMemory();
void ExpandSlotRecord(SlotRecord* ins);
protected:
virtual bool Start();
virtual int Next();
virtual bool ParseOneInstance(SlotRecord* instance) { return false; }
virtual bool ParseOneInstanceFromPipe(SlotRecord* instance) { return false; }
// virtual void ParseOneInstanceFromSo(const char* str, T* instance,
// CustomParser* parser) {}
virtual void PutToFeedVec(const std::vector<SlotRecord>& ins_vec) {}
virtual void LoadIntoMemoryByCommand(void);
virtual void LoadIntoMemoryByLib(void);
virtual void LoadIntoMemoryByLine(void);
virtual void LoadIntoMemoryByFile(void);
virtual void SetInputChannel(void* channel) {
input_channel_ = static_cast<ChannelObject<SlotRecord>*>(channel);
}
bool ParseOneInstance(const std::string& line, SlotRecord* rec);
virtual void PutToFeedVec(const SlotRecord* ins_vec, int num);
float sample_rate_ = 1.0f;
int use_slot_size_ = 0;
int float_use_slot_size_ = 0;
int uint64_use_slot_size_ = 0;
std::vector<AllSlotInfo> all_slots_info_;
std::vector<UsedSlotInfo> used_slots_info_;
size_t float_total_dims_size_ = 0;
std::vector<int> float_total_dims_without_inductives_;
};
class PaddleBoxDataFeed : public MultiSlotInMemoryDataFeed {
public:
PaddleBoxDataFeed() {}
......
......@@ -58,8 +58,8 @@ std::shared_ptr<DataFeed> DataFeedFactory::CreateDataFeed(
std::string data_feed_class) {
if (g_data_feed_map.count(data_feed_class) < 1) {
LOG(WARNING) << "Your DataFeed " << data_feed_class
<< "is not supported currently";
LOG(WARNING) << "Supported DataFeed: " << DataFeedTypeList();
<< " is not supported currently";
LOG(WARNING) << " Supported DataFeed: " << DataFeedTypeList();
exit(-1);
}
return g_data_feed_map[data_feed_class]();
......@@ -68,6 +68,7 @@ std::shared_ptr<DataFeed> DataFeedFactory::CreateDataFeed(
REGISTER_DATAFEED_CLASS(MultiSlotDataFeed);
REGISTER_DATAFEED_CLASS(MultiSlotInMemoryDataFeed);
REGISTER_DATAFEED_CLASS(PaddleBoxDataFeed);
REGISTER_DATAFEED_CLASS(SlotRecordInMemoryDataFeed);
#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && !defined(_WIN32)
REGISTER_DATAFEED_CLASS(MultiSlotFileInstantDataFeed);
#endif
......
......@@ -1609,7 +1609,35 @@ void SlotRecordDataset::DynamicAdjustChannelNum(int channel_num,
void SlotRecordDataset::PrepareTrain() {
#ifdef PADDLE_WITH_GLOO
return;
if (enable_heterps_) {
if (input_records_.size() == 0 && input_channel_ != nullptr &&
input_channel_->Size() != 0) {
input_channel_->ReadAll(input_records_);
VLOG(3) << "read from channel to records with records size: "
<< input_records_.size();
}
VLOG(3) << "input records size: " << input_records_.size();
int64_t total_ins_num = input_records_.size();
std::vector<std::pair<int, int>> offset;
int default_batch_size =
reinterpret_cast<SlotRecordInMemoryDataFeed*>(readers_[0].get())
->GetDefaultBatchSize();
VLOG(3) << "thread_num: " << thread_num_
<< " memory size: " << total_ins_num
<< " default batch_size: " << default_batch_size;
compute_thread_batch_nccl(thread_num_, total_ins_num, default_batch_size,
&offset);
VLOG(3) << "offset size: " << offset.size();
for (int i = 0; i < thread_num_; i++) {
reinterpret_cast<SlotRecordInMemoryDataFeed*>(readers_[i].get())
->SetRecord(&input_records_[0]);
}
for (size_t i = 0; i < offset.size(); i++) {
reinterpret_cast<SlotRecordInMemoryDataFeed*>(
readers_[i % thread_num_].get())
->AddBatchOffset(offset[i]);
}
}
#else
PADDLE_THROW(platform::errors::Unavailable(
"dataset set heterps need compile with GLOO"));
......
......@@ -45,9 +45,7 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task) {
platform::Timer timeline;
timeline.Start();
int device_num = heter_devices_.size();
MultiSlotDataset* dataset = dynamic_cast<MultiSlotDataset*>(dataset_);
gpu_task->init(thread_keys_shard_num_, device_num);
auto input_channel = dataset->GetInputChannel();
auto& local_keys = gpu_task->feature_keys_;
auto& local_ptr = gpu_task->value_ptr_;
......@@ -68,13 +66,60 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task) {
for (int i = 0; i < thread_keys_thread_num_; i++) {
thread_keys_[i].resize(thread_keys_shard_num_);
}
const std::deque<Record>& vec_data = input_channel->GetData();
size_t total_len = vec_data.size();
size_t len_per_thread = total_len / thread_keys_thread_num_;
int remain = total_len % thread_keys_thread_num_;
size_t total_len = 0;
size_t len_per_thread = 0;
int remain = 0;
size_t begin = 0;
auto gen_func = [this](const std::deque<Record>& total_data, int begin_index,
int end_index, int i) {
std::string data_set_name = std::string(typeid(*dataset_).name());
if (data_set_name.find("SlotRecordDataset") != std::string::npos) {
VLOG(0) << "ps_gpu_wrapper use SlotRecordDataset";
SlotRecordDataset* dataset = dynamic_cast<SlotRecordDataset*>(dataset_);
auto input_channel = dataset->GetInputChannel();
VLOG(0) << "yxf::buildtask::inputslotchannle size: "
<< input_channel->Size();
const std::deque<SlotRecord>& vec_data = input_channel->GetData();
total_len = vec_data.size();
len_per_thread = total_len / thread_keys_thread_num_;
remain = total_len % thread_keys_thread_num_;
VLOG(0) << "total len: " << total_len;
auto gen_func = [this](const std::deque<SlotRecord>& total_data,
int begin_index, int end_index, int i) {
for (auto iter = total_data.begin() + begin_index;
iter != total_data.begin() + end_index; iter++) {
const auto& ins = *iter;
const auto& feasign_v = ins->slot_uint64_feasigns_.slot_values;
for (const auto feasign : feasign_v) {
int shard_id = feasign % thread_keys_shard_num_;
this->thread_keys_[i][shard_id].insert(feasign);
}
}
};
for (int i = 0; i < thread_keys_thread_num_; i++) {
threads.push_back(
std::thread(gen_func, std::ref(vec_data), begin,
begin + len_per_thread + (i < remain ? 1 : 0), i));
begin += len_per_thread + (i < remain ? 1 : 0);
}
for (std::thread& t : threads) {
t.join();
}
timeline.Pause();
VLOG(1) << "GpuPs build task cost " << timeline.ElapsedSec() << " seconds.";
} else {
CHECK(data_set_name.find("MultiSlotDataset") != std::string::npos);
VLOG(0) << "ps_gpu_wrapper use MultiSlotDataset";
MultiSlotDataset* dataset = dynamic_cast<MultiSlotDataset*>(dataset_);
auto input_channel = dataset->GetInputChannel();
const std::deque<Record>& vec_data = input_channel->GetData();
total_len = vec_data.size();
len_per_thread = total_len / thread_keys_thread_num_;
remain = total_len % thread_keys_thread_num_;
auto gen_func = [this](const std::deque<Record>& total_data,
int begin_index, int end_index, int i) {
for (auto iter = total_data.begin() + begin_index;
iter != total_data.begin() + end_index; iter++) {
const auto& ins = *iter;
......@@ -87,9 +132,9 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task) {
}
};
for (int i = 0; i < thread_keys_thread_num_; i++) {
threads.push_back(std::thread(gen_func, std::ref(vec_data), begin,
begin + len_per_thread + (i < remain ? 1 : 0),
i));
threads.push_back(
std::thread(gen_func, std::ref(vec_data), begin,
begin + len_per_thread + (i < remain ? 1 : 0), i));
begin += len_per_thread + (i < remain ? 1 : 0);
}
for (std::thread& t : threads) {
......@@ -97,6 +142,7 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task) {
}
timeline.Pause();
VLOG(1) << "GpuPs build task cost " << timeline.ElapsedSec() << " seconds.";
}
timeline.Start();
......
......@@ -688,3 +688,5 @@ DEFINE_bool(enable_slotpool_wait_release, false,
"enable slotrecord obejct wait release, default false");
DEFINE_bool(enable_slotrecord_reset_shrink, false,
"enable slotrecord obejct reset shrink memory, default false");
DEFINE_bool(enable_ins_parser_file, false,
"enable parser ins file , default false");
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册