Unverified commit 53480c9c, authored by yaoxuefeng, committed by GitHub

add slot record support for GpuPS (#36723)

* add slotrecord datafeed (#36099)

* fix multi-node (#36329)
Parent 32fe5a49
(This file's diff is collapsed.)
@@ -384,7 +384,7 @@ class CustomParser {
   CustomParser() {}
   virtual ~CustomParser() {}
   virtual void Init(const std::vector<SlotConf>& slots) = 0;
-  virtual bool Init(const std::vector<AllSlotInfo>& slots) = 0;
+  virtual bool Init(const std::vector<AllSlotInfo>& slots);
   virtual void ParseOneInstance(const char* str, Record* instance) = 0;
   virtual bool ParseOneInstance(
       const std::string& line,
...
@@ -1103,6 +1103,42 @@ class MultiSlotInMemoryDataFeed : public InMemoryDataFeed<Record> {
   virtual void PutToFeedVec(const Record* ins_vec, int num);
 };
 
+class SlotRecordInMemoryDataFeed : public InMemoryDataFeed<SlotRecord> {
+ public:
+  SlotRecordInMemoryDataFeed() {}
+  virtual ~SlotRecordInMemoryDataFeed() {}
+  virtual void Init(const DataFeedDesc& data_feed_desc);
+  virtual void LoadIntoMemory();
+  void ExpandSlotRecord(SlotRecord* ins);
+
+ protected:
+  virtual bool Start();
+  virtual int Next();
+  virtual bool ParseOneInstance(SlotRecord* instance) { return false; }
+  virtual bool ParseOneInstanceFromPipe(SlotRecord* instance) { return false; }
+  // virtual void ParseOneInstanceFromSo(const char* str, T* instance,
+  //                                     CustomParser* parser) {}
+  virtual void PutToFeedVec(const std::vector<SlotRecord>& ins_vec) {}
+
+  virtual void LoadIntoMemoryByCommand(void);
+  virtual void LoadIntoMemoryByLib(void);
+  virtual void LoadIntoMemoryByLine(void);
+  virtual void LoadIntoMemoryByFile(void);
+  virtual void SetInputChannel(void* channel) {
+    input_channel_ = static_cast<ChannelObject<SlotRecord>*>(channel);
+  }
+  bool ParseOneInstance(const std::string& line, SlotRecord* rec);
+  virtual void PutToFeedVec(const SlotRecord* ins_vec, int num);
+
+  float sample_rate_ = 1.0f;
+  int use_slot_size_ = 0;
+  int float_use_slot_size_ = 0;
+  int uint64_use_slot_size_ = 0;
+  std::vector<AllSlotInfo> all_slots_info_;
+  std::vector<UsedSlotInfo> used_slots_info_;
+  size_t float_total_dims_size_ = 0;
+  std::vector<int> float_total_dims_without_inductives_;
+};
+
 class PaddleBoxDataFeed : public MultiSlotInMemoryDataFeed {
  public:
   PaddleBoxDataFeed() {}
...
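The point of the new feed is that a SlotRecord packs all of an instance's uint64 feature signs into one flattened buffer instead of one vector per slot (the BuildTask change further down reads them via ins->slot_uint64_feasigns_.slot_values). Below is a minimal, self-contained sketch of that packing idea; SlotValues and its fields here are illustrative stand-ins, not Paddle's actual definitions.

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for packed slot storage: the values of every slot
// live in one contiguous vector, and slot_offsets marks where each slot's
// values begin (slot i owns [slot_offsets[i], slot_offsets[i + 1])).
struct SlotValues {
  std::vector<uint64_t> slot_values;
  std::vector<size_t> slot_offsets;
};

int main() {
  // Two slots: slot 0 has keys {11, 12}, slot 1 has key {42}.
  SlotValues uint64_feasigns;
  uint64_feasigns.slot_values = {11, 12, 42};
  uint64_feasigns.slot_offsets = {0, 2, 3};

  // Iterating one slot only touches its sub-range of the packed buffer;
  // iterating all keys is a single linear scan of slot_values.
  for (size_t slot = 0; slot + 1 < uint64_feasigns.slot_offsets.size(); ++slot) {
    std::cout << "slot " << slot << ":";
    for (size_t j = uint64_feasigns.slot_offsets[slot];
         j < uint64_feasigns.slot_offsets[slot + 1]; ++j) {
      std::cout << ' ' << uint64_feasigns.slot_values[j];
    }
    std::cout << '\n';
  }
  return 0;
}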
@@ -58,8 +58,8 @@ std::shared_ptr<DataFeed> DataFeedFactory::CreateDataFeed(
     std::string data_feed_class) {
   if (g_data_feed_map.count(data_feed_class) < 1) {
     LOG(WARNING) << "Your DataFeed " << data_feed_class
-                 << "is not supported currently";
-    LOG(WARNING) << "Supported DataFeed: " << DataFeedTypeList();
+                 << " is not supported currently";
+    LOG(WARNING) << " Supported DataFeed: " << DataFeedTypeList();
     exit(-1);
   }
   return g_data_feed_map[data_feed_class]();
@@ -68,6 +68,7 @@ std::shared_ptr<DataFeed> DataFeedFactory::CreateDataFeed(
 REGISTER_DATAFEED_CLASS(MultiSlotDataFeed);
 REGISTER_DATAFEED_CLASS(MultiSlotInMemoryDataFeed);
 REGISTER_DATAFEED_CLASS(PaddleBoxDataFeed);
+REGISTER_DATAFEED_CLASS(SlotRecordInMemoryDataFeed);
 #if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && !defined(_WIN32)
 REGISTER_DATAFEED_CLASS(MultiSlotFileInstantDataFeed);
 #endif
...
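For context, the registration line follows the usual string-to-factory pattern implied by g_data_feed_map in CreateDataFeed above: registering a class name maps it to a creator so the feed can be instantiated from the string in the DataFeedDesc. A minimal sketch of such a registry follows; the macro body is an assumption for illustration, not Paddle's actual REGISTER_DATAFEED_CLASS definition.

#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

struct DataFeed {  // simplified stand-in for the real base class
  virtual ~DataFeed() = default;
};
struct SlotRecordInMemoryDataFeed : DataFeed {};

// Name -> creator map, analogous to g_data_feed_map in the factory above.
static std::map<std::string, std::function<std::shared_ptr<DataFeed>()>>
    g_data_feed_map;

// Illustrative expansion of a registration macro: store a creator under the
// class name at static-initialization time so lookup by string works later.
#define REGISTER_DATAFEED_CLASS(feed_class)               \
  static const bool feed_class##_registered = [] {        \
    g_data_feed_map[#feed_class] = [] {                   \
      return std::shared_ptr<DataFeed>(new feed_class);   \
    };                                                    \
    return true;                                          \
  }()

REGISTER_DATAFEED_CLASS(SlotRecordInMemoryDataFeed);

int main() {
  auto feed = g_data_feed_map["SlotRecordInMemoryDataFeed"]();
  std::cout << (feed != nullptr) << '\n';  // prints 1
  return 0;
}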
@@ -1609,7 +1609,35 @@ void SlotRecordDataset::DynamicAdjustChannelNum(int channel_num,
 void SlotRecordDataset::PrepareTrain() {
 #ifdef PADDLE_WITH_GLOO
-  return;
+  if (enable_heterps_) {
+    if (input_records_.size() == 0 && input_channel_ != nullptr &&
+        input_channel_->Size() != 0) {
+      input_channel_->ReadAll(input_records_);
+      VLOG(3) << "read from channel to records with records size: "
+              << input_records_.size();
+    }
+    VLOG(3) << "input records size: " << input_records_.size();
+    int64_t total_ins_num = input_records_.size();
+    std::vector<std::pair<int, int>> offset;
+    int default_batch_size =
+        reinterpret_cast<SlotRecordInMemoryDataFeed*>(readers_[0].get())
+            ->GetDefaultBatchSize();
+    VLOG(3) << "thread_num: " << thread_num_
+            << " memory size: " << total_ins_num
+            << " default batch_size: " << default_batch_size;
+    compute_thread_batch_nccl(thread_num_, total_ins_num, default_batch_size,
+                              &offset);
+    VLOG(3) << "offset size: " << offset.size();
+    for (int i = 0; i < thread_num_; i++) {
+      reinterpret_cast<SlotRecordInMemoryDataFeed*>(readers_[i].get())
+          ->SetRecord(&input_records_[0]);
+    }
+    for (size_t i = 0; i < offset.size(); i++) {
+      reinterpret_cast<SlotRecordInMemoryDataFeed*>(
+          readers_[i % thread_num_].get())
+          ->AddBatchOffset(offset[i]);
+    }
+  }
 #else
   PADDLE_THROW(platform::errors::Unavailable(
       "dataset set heterps need compile with GLOO"));
...
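The implementation of compute_thread_batch_nccl is not part of this diff; what PrepareTrain needs from it is a list of (begin, length) pairs that carve total_ins_num records into batches, which the loop above then deals round-robin to the readers via AddBatchOffset. A plausible sketch of that carving, under the simplifying assumption of a plain sequential split (the real function presumably also balances work across NCCL ranks):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Split total_ins_num records into consecutive (begin, length) batches of at
// most default_batch_size; the caller hands batch i to reader i % thread_num.
void compute_batch_offsets(int64_t total_ins_num, int default_batch_size,
                           std::vector<std::pair<int, int>>* offset) {
  for (int64_t begin = 0; begin < total_ins_num; begin += default_batch_size) {
    int len = static_cast<int>(
        std::min<int64_t>(default_batch_size, total_ins_num - begin));
    offset->emplace_back(static_cast<int>(begin), len);
  }
}

int main() {
  std::vector<std::pair<int, int>> offset;
  compute_batch_offsets(/*total_ins_num=*/10, /*default_batch_size=*/4, &offset);
  for (auto& p : offset) {
    std::cout << "begin=" << p.first << " len=" << p.second << '\n';
  }
  // Prints: begin=0 len=4, begin=4 len=4, begin=8 len=2.
  return 0;
}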
@@ -45,9 +45,7 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task) {
   platform::Timer timeline;
   timeline.Start();
   int device_num = heter_devices_.size();
-  MultiSlotDataset* dataset = dynamic_cast<MultiSlotDataset*>(dataset_);
   gpu_task->init(thread_keys_shard_num_, device_num);
-  auto input_channel = dataset->GetInputChannel();
   auto& local_keys = gpu_task->feature_keys_;
   auto& local_ptr = gpu_task->value_ptr_;
@@ -68,35 +66,83 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task) {
   for (int i = 0; i < thread_keys_thread_num_; i++) {
     thread_keys_[i].resize(thread_keys_shard_num_);
   }
-  const std::deque<Record>& vec_data = input_channel->GetData();
-  size_t total_len = vec_data.size();
-  size_t len_per_thread = total_len / thread_keys_thread_num_;
-  int remain = total_len % thread_keys_thread_num_;
-  size_t begin = 0;
-  auto gen_func = [this](const std::deque<Record>& total_data, int begin_index,
-                         int end_index, int i) {
-    for (auto iter = total_data.begin() + begin_index;
-         iter != total_data.begin() + end_index; iter++) {
-      const auto& ins = *iter;
-      const auto& feasign_v = ins.uint64_feasigns_;
-      for (const auto feasign : feasign_v) {
-        uint64_t cur_key = feasign.sign().uint64_feasign_;
-        int shard_id = cur_key % thread_keys_shard_num_;
-        this->thread_keys_[i][shard_id].insert(cur_key);
-      }
-    }
-  };
-  for (int i = 0; i < thread_keys_thread_num_; i++) {
-    threads.push_back(std::thread(gen_func, std::ref(vec_data), begin,
-                                  begin + len_per_thread + (i < remain ? 1 : 0),
-                                  i));
-    begin += len_per_thread + (i < remain ? 1 : 0);
-  }
-  for (std::thread& t : threads) {
-    t.join();
-  }
-  timeline.Pause();
-  VLOG(1) << "GpuPs build task cost " << timeline.ElapsedSec() << " seconds.";
+  size_t total_len = 0;
+  size_t len_per_thread = 0;
+  int remain = 0;
+  size_t begin = 0;
+
+  std::string data_set_name = std::string(typeid(*dataset_).name());
+
+  if (data_set_name.find("SlotRecordDataset") != std::string::npos) {
+    VLOG(0) << "ps_gpu_wrapper use SlotRecordDataset";
+    SlotRecordDataset* dataset = dynamic_cast<SlotRecordDataset*>(dataset_);
+    auto input_channel = dataset->GetInputChannel();
+    VLOG(0) << "yxf::buildtask::inputslotchannle size: "
+            << input_channel->Size();
+    const std::deque<SlotRecord>& vec_data = input_channel->GetData();
+    total_len = vec_data.size();
+    len_per_thread = total_len / thread_keys_thread_num_;
+    remain = total_len % thread_keys_thread_num_;
+    VLOG(0) << "total len: " << total_len;
+    auto gen_func = [this](const std::deque<SlotRecord>& total_data,
+                           int begin_index, int end_index, int i) {
+      for (auto iter = total_data.begin() + begin_index;
+           iter != total_data.begin() + end_index; iter++) {
+        const auto& ins = *iter;
+        const auto& feasign_v = ins->slot_uint64_feasigns_.slot_values;
+        for (const auto feasign : feasign_v) {
+          int shard_id = feasign % thread_keys_shard_num_;
+          this->thread_keys_[i][shard_id].insert(feasign);
+        }
+      }
+    };
+    for (int i = 0; i < thread_keys_thread_num_; i++) {
+      threads.push_back(
+          std::thread(gen_func, std::ref(vec_data), begin,
+                      begin + len_per_thread + (i < remain ? 1 : 0), i));
+      begin += len_per_thread + (i < remain ? 1 : 0);
+    }
+    for (std::thread& t : threads) {
+      t.join();
+    }
+    timeline.Pause();
+    VLOG(1) << "GpuPs build task cost " << timeline.ElapsedSec() << " seconds.";
+  } else {
+    CHECK(data_set_name.find("MultiSlotDataset") != std::string::npos);
+    VLOG(0) << "ps_gpu_wrapper use MultiSlotDataset";
+    MultiSlotDataset* dataset = dynamic_cast<MultiSlotDataset*>(dataset_);
+    auto input_channel = dataset->GetInputChannel();
+
+    const std::deque<Record>& vec_data = input_channel->GetData();
+    total_len = vec_data.size();
+    len_per_thread = total_len / thread_keys_thread_num_;
+    remain = total_len % thread_keys_thread_num_;
+    auto gen_func = [this](const std::deque<Record>& total_data,
+                           int begin_index, int end_index, int i) {
+      for (auto iter = total_data.begin() + begin_index;
+           iter != total_data.begin() + end_index; iter++) {
+        const auto& ins = *iter;
+        const auto& feasign_v = ins.uint64_feasigns_;
+        for (const auto feasign : feasign_v) {
+          uint64_t cur_key = feasign.sign().uint64_feasign_;
+          int shard_id = cur_key % thread_keys_shard_num_;
+          this->thread_keys_[i][shard_id].insert(cur_key);
+        }
+      }
+    };
+    for (int i = 0; i < thread_keys_thread_num_; i++) {
+      threads.push_back(
+          std::thread(gen_func, std::ref(vec_data), begin,
+                      begin + len_per_thread + (i < remain ? 1 : 0), i));
+      begin += len_per_thread + (i < remain ? 1 : 0);
+    }
+    for (std::thread& t : threads) {
+      t.join();
+    }
+    timeline.Pause();
+    VLOG(1) << "GpuPs build task cost " << timeline.ElapsedSec() << " seconds.";
+  }
   timeline.Start();
...
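Both branches of BuildTask use the same pattern: each of thread_keys_thread_num_ worker threads scans a contiguous slice of the input and inserts every key into a per-thread, per-shard std::set selected by key % thread_keys_shard_num_, so deduplication happens lock-free and the shards can later be merged per shard id. A self-contained sketch of that pattern (names simplified from the members used above):

#include <cstdint>
#include <iostream>
#include <set>
#include <thread>
#include <vector>

int main() {
  const int thread_num = 4;  // stands in for thread_keys_thread_num_
  const int shard_num = 8;   // stands in for thread_keys_shard_num_
  std::vector<uint64_t> keys = {3, 17, 3, 42, 99, 17, 8, 8, 1024};

  // thread_keys[t][s]: keys seen by thread t that map to shard s.
  // Each thread writes only to its own row, so no locking is needed.
  std::vector<std::vector<std::set<uint64_t>>> thread_keys(
      thread_num, std::vector<std::set<uint64_t>>(shard_num));

  // Slice the input like BuildTask does: an even split, with the first
  // `remain` threads taking one extra element.
  size_t len_per_thread = keys.size() / thread_num;
  int remain = keys.size() % thread_num;
  size_t begin = 0;
  std::vector<std::thread> threads;
  for (int i = 0; i < thread_num; i++) {
    size_t end = begin + len_per_thread + (i < remain ? 1 : 0);
    threads.emplace_back([&, begin, end, i] {
      for (size_t j = begin; j < end; j++) {
        thread_keys[i][keys[j] % shard_num].insert(keys[j]);
      }
    });
    begin = end;
  }
  for (auto& t : threads) t.join();

  // Merge per shard: sets from different threads may still overlap, so a
  // per-shard merge set dedups across threads too.
  for (int s = 0; s < shard_num; s++) {
    std::set<uint64_t> merged;
    for (int t = 0; t < thread_num; t++) {
      merged.insert(thread_keys[t][s].begin(), thread_keys[t][s].end());
    }
    if (!merged.empty()) {
      std::cout << "shard " << s << " unique keys: " << merged.size() << '\n';
    }
  }
  return 0;
}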
@@ -117,6 +117,15 @@ class PSGPUWrapper {
     resource_ = std::make_shared<HeterPsResource>(dev_ids);
     resource_->enable_p2p();
     keys_tensor.resize(resource_->total_gpu());
+#ifdef PADDLE_WITH_GLOO
+    auto gloo = paddle::framework::GlooWrapper::GetInstance();
+    if (gloo->Size() > 1) {
+      multi_node_ = 1;
+    }
+#else
+    PADDLE_THROW(
+        platform::errors::Unavailable("heter ps need compile with GLOO"));
+#endif
     if (multi_node_) {
       int dev_size = dev_ids.size();
       // init inner comm
@@ -127,7 +136,6 @@ class PSGPUWrapper {
       // init inter comm
 #ifdef PADDLE_WITH_GLOO
       inter_comms_.resize(dev_size);
-      auto gloo = paddle::framework::GlooWrapper::GetInstance();
       if (gloo->Rank() == 0) {
         for (int i = 0; i < dev_size; ++i) {
           platform::dynload::ncclGetUniqueId(&inter_ncclids_[i]);
...
@@ -148,7 +148,7 @@ void NCCLCommContext::CreateNCCLCommMultiTrainer(
       paddle::platform::errors::InvalidArgument(
           "dev ids = [%d], it should greater than 0.", dev_ids.size()));
   const int kDevices = dev_ids.size();
-  VLOG(3) << "Begin CreateNCCLCommMultiTrainer. device number: " << kDevices
+  VLOG(1) << "Begin CreateNCCLCommMultiTrainer. device number: " << kDevices
           << ", ntrainers: " << ntrainers << ", train_id: " << train_id
           << ", rind_id: " << ring_id;
   ncclComm_t comms[kDevices];
@@ -162,10 +162,10 @@ void NCCLCommContext::CreateNCCLCommMultiTrainer(
 #endif
     platform::dynload::ncclCommInitRank(comms + i, kDevices * ntrainers,
                                         *nccl_id, train_id * kDevices + i);
-    VLOG(3) << "ncclCommInitRank: " << i;
+    VLOG(1) << "ncclCommInitRank: " << i;
   }
   PADDLE_ENFORCE_CUDA_SUCCESS(dynload::ncclGroupEnd());
-  VLOG(3) << "nccl group end seccessss";
+  VLOG(1) << "nccl group end seccessss";
 }
 PADDLE_ENFORCE_EQ(comm_map_.count(ring_id), 0,
                   platform::errors::InvalidArgument(
@@ -174,7 +174,7 @@ void NCCLCommContext::CreateNCCLCommMultiTrainer(
   for (int i = 0; i < kDevices; ++i) {
     AssignNCCLComm(comms[i], kDevices * ntrainers, train_id * kDevices + i,
                    dev_ids[i], ring_id);
-    VLOG(3) << "nccl communicator of train_id " << train_id * kDevices + i
+    VLOG(1) << "nccl communicator of train_id " << train_id * kDevices + i
             << " in ring " << ring_id << " has been created on device "
             << dev_ids[i];
   }
...
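The rank arithmetic in this function places local device i of trainer train_id at global rank train_id * kDevices + i inside a communicator of kDevices * ntrainers ranks, which is what ncclCommInitRank is called with above. A small sketch of that layout (pure arithmetic, no NCCL needed to run it):

#include <iostream>

int main() {
  const int ntrainers = 2;  // nodes joining the communicator
  const int kDevices = 4;   // GPUs per node
  // Global rank of device i on trainer train_id, matching the
  // ncclCommInitRank(..., kDevices * ntrainers, nccl_id, rank) call above.
  for (int train_id = 0; train_id < ntrainers; ++train_id) {
    for (int i = 0; i < kDevices; ++i) {
      int rank = train_id * kDevices + i;
      std::cout << "trainer " << train_id << " device " << i
                << " -> global rank " << rank << " of "
                << kDevices * ntrainers << '\n';
    }
  }
  return 0;
}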
@@ -680,4 +680,6 @@ DEFINE_int32(slotpool_thread_num, 1, "SlotRecordDataset slot pool thread num");
 DEFINE_bool(enable_slotpool_wait_release, false,
             "enable slotrecord obejct wait release, default false");
 DEFINE_bool(enable_slotrecord_reset_shrink, false,
-            "enable slotrecord obejct reset shrink memory, default false");
\ No newline at end of file
+            "enable slotrecord obejct reset shrink memory, default false");
+DEFINE_bool(enable_ins_parser_file, false,
+            "enable parser ins file , default false");
@@ -396,6 +396,8 @@ class InMemoryDataset(DatasetBase):
         Set data_feed_desc
         """
         self.proto_desc.name = data_feed_type
+        if (self.proto_desc.name == "SlotRecordInMemoryDataFeed"):
+            self.dataset = core.Dataset("SlotRecordDataset")
 
     @deprecated(
         since="2.0.0",
...