未验证 提交 0a3dbe8a 编写于 作者: Y yaoxuefeng 提交者: GitHub

add slotrecord datafeed (#36099)

上级 c12176e8
此差异已折叠。
......@@ -384,7 +384,7 @@ class CustomParser {
CustomParser() {}
virtual ~CustomParser() {}
virtual void Init(const std::vector<SlotConf>& slots) = 0;
virtual bool Init(const std::vector<AllSlotInfo>& slots) = 0;
virtual bool Init(const std::vector<AllSlotInfo>& slots);
virtual void ParseOneInstance(const char* str, Record* instance) = 0;
virtual bool ParseOneInstance(
const std::string& line,
......@@ -1103,6 +1103,42 @@ class MultiSlotInMemoryDataFeed : public InMemoryDataFeed<Record> {
virtual void PutToFeedVec(const Record* ins_vec, int num);
};
class SlotRecordInMemoryDataFeed : public InMemoryDataFeed<SlotRecord> {
public:
SlotRecordInMemoryDataFeed() {}
virtual ~SlotRecordInMemoryDataFeed() {}
virtual void Init(const DataFeedDesc& data_feed_desc);
virtual void LoadIntoMemory();
void ExpandSlotRecord(SlotRecord* ins);
protected:
virtual bool Start();
virtual int Next();
virtual bool ParseOneInstance(SlotRecord* instance) { return false; }
virtual bool ParseOneInstanceFromPipe(SlotRecord* instance) { return false; }
// virtual void ParseOneInstanceFromSo(const char* str, T* instance,
// CustomParser* parser) {}
virtual void PutToFeedVec(const std::vector<SlotRecord>& ins_vec) {}
virtual void LoadIntoMemoryByCommand(void);
virtual void LoadIntoMemoryByLib(void);
virtual void LoadIntoMemoryByLine(void);
virtual void LoadIntoMemoryByFile(void);
virtual void SetInputChannel(void* channel) {
input_channel_ = static_cast<ChannelObject<SlotRecord>*>(channel);
}
bool ParseOneInstance(const std::string& line, SlotRecord* rec);
virtual void PutToFeedVec(const SlotRecord* ins_vec, int num);
float sample_rate_ = 1.0f;
int use_slot_size_ = 0;
int float_use_slot_size_ = 0;
int uint64_use_slot_size_ = 0;
std::vector<AllSlotInfo> all_slots_info_;
std::vector<UsedSlotInfo> used_slots_info_;
size_t float_total_dims_size_ = 0;
std::vector<int> float_total_dims_without_inductives_;
};
class PaddleBoxDataFeed : public MultiSlotInMemoryDataFeed {
public:
PaddleBoxDataFeed() {}
......
......@@ -58,8 +58,8 @@ std::shared_ptr<DataFeed> DataFeedFactory::CreateDataFeed(
std::string data_feed_class) {
if (g_data_feed_map.count(data_feed_class) < 1) {
LOG(WARNING) << "Your DataFeed " << data_feed_class
<< "is not supported currently";
LOG(WARNING) << "Supported DataFeed: " << DataFeedTypeList();
<< " is not supported currently";
LOG(WARNING) << " Supported DataFeed: " << DataFeedTypeList();
exit(-1);
}
return g_data_feed_map[data_feed_class]();
......@@ -68,6 +68,7 @@ std::shared_ptr<DataFeed> DataFeedFactory::CreateDataFeed(
REGISTER_DATAFEED_CLASS(MultiSlotDataFeed);
REGISTER_DATAFEED_CLASS(MultiSlotInMemoryDataFeed);
REGISTER_DATAFEED_CLASS(PaddleBoxDataFeed);
REGISTER_DATAFEED_CLASS(SlotRecordInMemoryDataFeed);
#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && !defined(_WIN32)
REGISTER_DATAFEED_CLASS(MultiSlotFileInstantDataFeed);
#endif
......
......@@ -1609,7 +1609,35 @@ void SlotRecordDataset::DynamicAdjustChannelNum(int channel_num,
void SlotRecordDataset::PrepareTrain() {
#ifdef PADDLE_WITH_GLOO
return;
if (enable_heterps_) {
if (input_records_.size() == 0 && input_channel_ != nullptr &&
input_channel_->Size() != 0) {
input_channel_->ReadAll(input_records_);
VLOG(3) << "read from channel to records with records size: "
<< input_records_.size();
}
VLOG(3) << "input records size: " << input_records_.size();
int64_t total_ins_num = input_records_.size();
std::vector<std::pair<int, int>> offset;
int default_batch_size =
reinterpret_cast<SlotRecordInMemoryDataFeed*>(readers_[0].get())
->GetDefaultBatchSize();
VLOG(3) << "thread_num: " << thread_num_
<< " memory size: " << total_ins_num
<< " default batch_size: " << default_batch_size;
compute_thread_batch_nccl(thread_num_, total_ins_num, default_batch_size,
&offset);
VLOG(3) << "offset size: " << offset.size();
for (int i = 0; i < thread_num_; i++) {
reinterpret_cast<SlotRecordInMemoryDataFeed*>(readers_[i].get())
->SetRecord(&input_records_[0]);
}
for (size_t i = 0; i < offset.size(); i++) {
reinterpret_cast<SlotRecordInMemoryDataFeed*>(
readers_[i % thread_num_].get())
->AddBatchOffset(offset[i]);
}
}
#else
PADDLE_THROW(platform::errors::Unavailable(
"dataset set heterps need compile with GLOO"));
......
......@@ -45,9 +45,7 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task) {
platform::Timer timeline;
timeline.Start();
int device_num = heter_devices_.size();
MultiSlotDataset* dataset = dynamic_cast<MultiSlotDataset*>(dataset_);
gpu_task->init(thread_keys_shard_num_, device_num);
auto input_channel = dataset->GetInputChannel();
auto& local_keys = gpu_task->feature_keys_;
auto& local_ptr = gpu_task->value_ptr_;
......@@ -68,13 +66,60 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task) {
for (int i = 0; i < thread_keys_thread_num_; i++) {
thread_keys_[i].resize(thread_keys_shard_num_);
}
const std::deque<Record>& vec_data = input_channel->GetData();
size_t total_len = vec_data.size();
size_t len_per_thread = total_len / thread_keys_thread_num_;
int remain = total_len % thread_keys_thread_num_;
size_t total_len = 0;
size_t len_per_thread = 0;
int remain = 0;
size_t begin = 0;
auto gen_func = [this](const std::deque<Record>& total_data, int begin_index,
int end_index, int i) {
std::string data_set_name = std::string(typeid(*dataset_).name());
if (data_set_name.find("SlotRecordDataset") != std::string::npos) {
VLOG(0) << "ps_gpu_wrapper use SlotRecordDataset";
SlotRecordDataset* dataset = dynamic_cast<SlotRecordDataset*>(dataset_);
auto input_channel = dataset->GetInputChannel();
VLOG(0) << "yxf::buildtask::inputslotchannle size: "
<< input_channel->Size();
const std::deque<SlotRecord>& vec_data = input_channel->GetData();
total_len = vec_data.size();
len_per_thread = total_len / thread_keys_thread_num_;
remain = total_len % thread_keys_thread_num_;
VLOG(0) << "total len: " << total_len;
auto gen_func = [this](const std::deque<SlotRecord>& total_data,
int begin_index, int end_index, int i) {
for (auto iter = total_data.begin() + begin_index;
iter != total_data.begin() + end_index; iter++) {
const auto& ins = *iter;
const auto& feasign_v = ins->slot_uint64_feasigns_.slot_values;
for (const auto feasign : feasign_v) {
int shard_id = feasign % thread_keys_shard_num_;
this->thread_keys_[i][shard_id].insert(feasign);
}
}
};
for (int i = 0; i < thread_keys_thread_num_; i++) {
threads.push_back(
std::thread(gen_func, std::ref(vec_data), begin,
begin + len_per_thread + (i < remain ? 1 : 0), i));
begin += len_per_thread + (i < remain ? 1 : 0);
}
for (std::thread& t : threads) {
t.join();
}
timeline.Pause();
VLOG(1) << "GpuPs build task cost " << timeline.ElapsedSec() << " seconds.";
} else {
CHECK(data_set_name.find("MultiSlotDataset") != std::string::npos);
VLOG(0) << "ps_gpu_wrapper use MultiSlotDataset";
MultiSlotDataset* dataset = dynamic_cast<MultiSlotDataset*>(dataset_);
auto input_channel = dataset->GetInputChannel();
const std::deque<Record>& vec_data = input_channel->GetData();
total_len = vec_data.size();
len_per_thread = total_len / thread_keys_thread_num_;
remain = total_len % thread_keys_thread_num_;
auto gen_func = [this](const std::deque<Record>& total_data,
int begin_index, int end_index, int i) {
for (auto iter = total_data.begin() + begin_index;
iter != total_data.begin() + end_index; iter++) {
const auto& ins = *iter;
......@@ -87,9 +132,9 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task) {
}
};
for (int i = 0; i < thread_keys_thread_num_; i++) {
threads.push_back(std::thread(gen_func, std::ref(vec_data), begin,
begin + len_per_thread + (i < remain ? 1 : 0),
i));
threads.push_back(
std::thread(gen_func, std::ref(vec_data), begin,
begin + len_per_thread + (i < remain ? 1 : 0), i));
begin += len_per_thread + (i < remain ? 1 : 0);
}
for (std::thread& t : threads) {
......@@ -97,6 +142,7 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task) {
}
timeline.Pause();
VLOG(1) << "GpuPs build task cost " << timeline.ElapsedSec() << " seconds.";
}
timeline.Start();
......
......@@ -688,3 +688,5 @@ DEFINE_bool(enable_slotpool_wait_release, false,
"enable slotrecord obejct wait release, default false");
DEFINE_bool(enable_slotrecord_reset_shrink, false,
"enable slotrecord obejct reset shrink memory, default false");
DEFINE_bool(enable_ins_parser_file, false,
"enable parser ins file , default false");
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册