未验证 提交 56a8b3e3 编写于 作者: Y yaoxuefeng 提交者: GitHub

add dymf accessor support (#42881)

上级 5efc4146
...@@ -120,6 +120,24 @@ void DatasetImpl<T>::SetDataFeedDesc(const std::string& data_feed_desc_str) { ...@@ -120,6 +120,24 @@ void DatasetImpl<T>::SetDataFeedDesc(const std::string& data_feed_desc_str) {
&data_feed_desc_); &data_feed_desc_);
} }
template <typename T>
std::vector<std::string> DatasetImpl<T>::GetSlots() {
  // Collect the names of the sparse (id-type) slots declared in the data
  // feed description and cache them in use_slots_.  Only slots whose type
  // is "uint64" or "uint32" are kept; other slot types are skipped.
  // Returns a copy of the cached slot-name list.
  auto multi_slot_desc = data_feed_desc_.multi_slot_desc();
  use_slots_.clear();
  for (int i = 0; i < multi_slot_desc.slots_size(); ++i) {
    const auto& slot = multi_slot_desc.slots(i);
    if (slot.type() == "uint64" || slot.type() == "uint32") {
      use_slots_.push_back(slot.name());
    }
  }
  // Debug dump of the selected slots.
  // NOTE(review): consider VLOG instead of std::cout so this obeys the
  // logging verbosity configuration used elsewhere in this file.
  std::cout << "dataset use slots: ";
  for (const auto& s : use_slots_) {  // const ref: avoid copying each string
    std::cout << s << " | ";
  }
  std::cout << " end " << std::endl;
  return use_slots_;
}
template <typename T> template <typename T>
void DatasetImpl<T>::SetChannelNum(int channel_num) { void DatasetImpl<T>::SetChannelNum(int channel_num) {
channel_num_ = channel_num; channel_num_ = channel_num;
......
...@@ -152,13 +152,15 @@ class Dataset { ...@@ -152,13 +152,15 @@ class Dataset {
virtual void DestroyPreLoadReaders() = 0; virtual void DestroyPreLoadReaders() = 0;
// set preload thread num // set preload thread num
virtual void SetPreLoadThreadNum(int thread_num) = 0; virtual void SetPreLoadThreadNum(int thread_num) = 0;
// separate train thread and dataset thread // separate train thread and dataset thread
virtual void DynamicAdjustChannelNum(int channel_num, virtual void DynamicAdjustChannelNum(int channel_num,
bool discard_remaining_ins = false) = 0; bool discard_remaining_ins = false) = 0;
virtual void DynamicAdjustReadersNum(int thread_num) = 0; virtual void DynamicAdjustReadersNum(int thread_num) = 0;
// set fleet send sleep seconds // set fleet send sleep seconds
virtual void SetFleetSendSleepSeconds(int seconds) = 0; virtual void SetFleetSendSleepSeconds(int seconds) = 0;
virtual std::vector<std::string> GetSlots() = 0;
protected: protected:
virtual int ReceiveFromClient(int msg_type, int client_id, virtual int ReceiveFromClient(int msg_type, int client_id,
const std::string& msg) = 0; const std::string& msg) = 0;
...@@ -246,6 +248,7 @@ class DatasetImpl : public Dataset { ...@@ -246,6 +248,7 @@ class DatasetImpl : public Dataset {
bool discard_remaining_ins = false); bool discard_remaining_ins = false);
virtual void DynamicAdjustReadersNum(int thread_num); virtual void DynamicAdjustReadersNum(int thread_num);
virtual void SetFleetSendSleepSeconds(int seconds); virtual void SetFleetSendSleepSeconds(int seconds);
virtual std::vector<std::string> GetSlots();
/* for enable_heterps_ /* for enable_heterps_
virtual void EnableHeterps(bool enable_heterps) { virtual void EnableHeterps(bool enable_heterps) {
enable_heterps_ = enable_heterps; enable_heterps_ = enable_heterps;
...@@ -321,6 +324,7 @@ class DatasetImpl : public Dataset { ...@@ -321,6 +324,7 @@ class DatasetImpl : public Dataset {
int64_t global_index_ = 0; int64_t global_index_ = 0;
std::vector<std::shared_ptr<ThreadPool>> consume_task_pool_; std::vector<std::shared_ptr<ThreadPool>> consume_task_pool_;
std::vector<T> input_records_; // only for paddleboxdatafeed std::vector<T> input_records_; // only for paddleboxdatafeed
std::vector<std::string> use_slots_;
bool enable_heterps_ = false; bool enable_heterps_ = false;
}; };
......
...@@ -69,7 +69,7 @@ void FleetWrapper::InitWorker(const std::string& dist_desc, ...@@ -69,7 +69,7 @@ void FleetWrapper::InitWorker(const std::string& dist_desc,
int node_num, int index) { int node_num, int index) {
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
if (!is_initialized_) { if (!is_initialized_) {
VLOG(3) << "Going to init worker"; VLOG(0) << "Going to init worker";
pslib_ptr_ = std::shared_ptr<paddle::distributed::PSlib>( pslib_ptr_ = std::shared_ptr<paddle::distributed::PSlib>(
new paddle::distributed::PSlib()); new paddle::distributed::PSlib());
pslib_ptr_->init_worker(dist_desc, pslib_ptr_->init_worker(dist_desc,
...@@ -126,7 +126,7 @@ void FleetWrapper::GatherServers(const std::vector<uint64_t>& host_sign_list, ...@@ -126,7 +126,7 @@ void FleetWrapper::GatherServers(const std::vector<uint64_t>& host_sign_list,
void FleetWrapper::GatherClients(const std::vector<uint64_t>& host_sign_list) { void FleetWrapper::GatherClients(const std::vector<uint64_t>& host_sign_list) {
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
VLOG(3) << "Going to gather client ips"; VLOG(0) << "Going to gather client ips";
size_t len = host_sign_list.size(); size_t len = host_sign_list.size();
pslib_ptr_->gather_clients(const_cast<uint64_t*>(host_sign_list.data()), len); pslib_ptr_->gather_clients(const_cast<uint64_t*>(host_sign_list.data()), len);
#endif #endif
...@@ -142,7 +142,7 @@ std::vector<uint64_t> FleetWrapper::GetClientsInfo() { ...@@ -142,7 +142,7 @@ std::vector<uint64_t> FleetWrapper::GetClientsInfo() {
void FleetWrapper::CreateClient2ClientConnection() { void FleetWrapper::CreateClient2ClientConnection() {
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
VLOG(3) << "Going to create client2client connection"; VLOG(0) << "Going to create client2client connection";
pslib_ptr_->create_client2client_connection(client2client_request_timeout_ms_, pslib_ptr_->create_client2client_connection(client2client_request_timeout_ms_,
client2client_connect_timeout_ms_, client2client_connect_timeout_ms_,
client2client_max_retry_); client2client_max_retry_);
...@@ -1054,7 +1054,8 @@ void FleetWrapper::PushSparseFromTensorWithLabelAsync( ...@@ -1054,7 +1054,8 @@ void FleetWrapper::PushSparseFromTensorWithLabelAsync(
int slot_offset = 0; int slot_offset = 0;
int grad_dim = 0; int grad_dim = 0;
// don't worry, user do not have to care about all these flags // don't worry, user do not have to care about all these flags
if (accesor == "DownpourCtrAccessor") { if (accesor == "DownpourCtrAccessor" ||
accesor == "DownpourCtrDymfAccessor") {
dump_slot = true; dump_slot = true;
slot_offset = 1; slot_offset = 1;
grad_dim = fea_dim - 2; grad_dim = fea_dim - 2;
......
...@@ -720,6 +720,7 @@ void PSGPUWrapper::LoadIntoMemory(bool is_shuffle) { ...@@ -720,6 +720,7 @@ void PSGPUWrapper::LoadIntoMemory(bool is_shuffle) {
if (is_shuffle) { if (is_shuffle) {
dataset_->LocalShuffle(); dataset_->LocalShuffle();
} }
InitSlotInfo();
std::shared_ptr<HeterContext> gpu_task = gpu_task_pool_.Get(); std::shared_ptr<HeterContext> gpu_task = gpu_task_pool_.Get();
gpu_task->Reset(); gpu_task->Reset();
data_ready_channel_->Put(gpu_task); data_ready_channel_->Put(gpu_task);
......
...@@ -339,7 +339,29 @@ class PSGPUWrapper { ...@@ -339,7 +339,29 @@ class PSGPUWrapper {
void SetSlotDimVector(const std::vector<int>& slot_mf_dim_vector) { void SetSlotDimVector(const std::vector<int>& slot_mf_dim_vector) {
slot_mf_dim_vector_ = slot_mf_dim_vector; slot_mf_dim_vector_ = slot_mf_dim_vector;
assert(slot_mf_dim_vector_.size() == slot_vector_.size()); assert(slot_mf_dim_vector_.size() == slot_vector_.size());
for (size_t i = 0; i < slot_mf_dim_vector.size(); i++) { }
void InitSlotInfo() {
if (slot_info_initialized_) {
return;
}
SlotRecordDataset* dataset = dynamic_cast<SlotRecordDataset*>(dataset_);
auto slots_vec = dataset->GetSlots();
slot_offset_vector_.clear();
for (auto& slot : slot_vector_) {
for (size_t i = 0; i < slots_vec.size(); ++i) {
if (std::to_string(slot) == slots_vec[i]) {
slot_offset_vector_.push_back(i);
break;
}
}
}
std::cout << "psgpu wrapper use slots: ";
for (auto s : slot_offset_vector_) {
std::cout << s << " | ";
}
std::cout << " end " << std::endl;
for (size_t i = 0; i < slot_mf_dim_vector_.size(); i++) {
slot_dim_map_[slot_vector_[i]] = slot_mf_dim_vector_[i]; slot_dim_map_[slot_vector_[i]] = slot_mf_dim_vector_[i];
} }
...@@ -368,6 +390,7 @@ class PSGPUWrapper { ...@@ -368,6 +390,7 @@ class PSGPUWrapper {
TYPEALIGN(8, sizeof(FeatureValue) + sizeof(float) * (max_mf_dim_ + 1)); TYPEALIGN(8, sizeof(FeatureValue) + sizeof(float) * (max_mf_dim_ + 1));
grad_type_size_ = grad_type_size_ =
TYPEALIGN(8, sizeof(FeaturePushValue) + (max_mf_dim_ * sizeof(float))); TYPEALIGN(8, sizeof(FeaturePushValue) + (max_mf_dim_ * sizeof(float)));
slot_info_initialized_ = true;
} }
#endif #endif
...@@ -428,6 +451,7 @@ class PSGPUWrapper { ...@@ -428,6 +451,7 @@ class PSGPUWrapper {
int year_; int year_;
int month_; int month_;
int day_; int day_;
bool slot_info_initialized_ = false;
int use_afs_api_ = 0; int use_afs_api_ = 0;
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
......
...@@ -103,9 +103,9 @@ class PSLib(Fleet): ...@@ -103,9 +103,9 @@ class PSLib(Fleet):
# prepare for client to client communication # prepare for client to client communication
if self._role_maker.is_worker(): if self._role_maker.is_worker():
info = self._fleet_ptr.get_clients_info() info = self._fleet_ptr.get_clients_info()
print("IIIIFO: {}".format(info)) print("Client Info: {}".format(info))
all_info = self._role_maker._worker_gather(info[0]) all_info = self._role_maker._worker_gather(info[0])
print("ALL info: {}".format(all_info)) print("All Client Info: {}".format(all_info))
self._fleet_ptr.gather_clients(all_info) self._fleet_ptr.gather_clients(all_info)
self._fleet_ptr.set_client2client_config( self._fleet_ptr.set_client2client_config(
self._client2client_request_timeout_ms, self._client2client_request_timeout_ms,
......
...@@ -124,14 +124,15 @@ class DownpourServer(Server): ...@@ -124,14 +124,15 @@ class DownpourServer(Server):
support_accessor_class = [ support_accessor_class = [
'DownpourFeatureValueAccessor', 'DownpourCtrAccessor', 'DownpourFeatureValueAccessor', 'DownpourCtrAccessor',
'DownpourSparseValueAccessor', 'DownpourCtrDoubleAccessor', 'DownpourCtrDymfAccessor', 'DownpourSparseValueAccessor',
'DownpourUnitAccessor', 'DownpourDoubleUnitAccessor' 'DownpourCtrDoubleAccessor', 'DownpourUnitAccessor',
'DownpourDoubleUnitAccessor'
] ]
if strategy.get('sparse_accessor_class') is not None: if strategy.get('sparse_accessor_class') is not None:
accessor_class = strategy.get('sparse_accessor_class') accessor_class = strategy.get('sparse_accessor_class')
if accessor_class not in support_accessor_class: if accessor_class not in support_accessor_class:
raise ValueError( raise ValueError(
"support sparse_accessor_class: ['DownpourFeatureValueAccessor', 'DownpourCtrAccessor', \ "support sparse_accessor_class: ['DownpourFeatureValueAccessor', 'DownpourCtrAccessor', 'DownpourCtrDymfAccessor', \
'DownpourSparseValueAccessor', 'DownpourCtrDoubleAccessor'], \ 'DownpourSparseValueAccessor', 'DownpourCtrDoubleAccessor'], \
but actual %s" % (accessor_class)) but actual %s" % (accessor_class))
else: else:
...@@ -141,6 +142,7 @@ class DownpourServer(Server): ...@@ -141,6 +142,7 @@ class DownpourServer(Server):
if accessor_class == 'DownpourFeatureValueAccessor' \ if accessor_class == 'DownpourFeatureValueAccessor' \
or accessor_class == 'DownpourCtrAccessor' \ or accessor_class == 'DownpourCtrAccessor' \
or accessor_class == 'DownpourCtrDymfAccessor' \
or accessor_class == 'DownpourCtrDoubleAccessor': or accessor_class == 'DownpourCtrDoubleAccessor':
table.accessor.sparse_sgd_param.learning_rate = strategy.get( table.accessor.sparse_sgd_param.learning_rate = strategy.get(
'sparse_learning_rate', 0.05) 'sparse_learning_rate', 0.05)
......
...@@ -339,6 +339,7 @@ class DistributedAdam(DistributedOptimizerImplBase): ...@@ -339,6 +339,7 @@ class DistributedAdam(DistributedOptimizerImplBase):
# set sparse_embedx_dim in the strategy according to accessor and use_cvm config # set sparse_embedx_dim in the strategy according to accessor and use_cvm config
if accessor == "DownpourFeatureValueAccessor" \ if accessor == "DownpourFeatureValueAccessor" \
or accessor == "DownpourCtrAccessor" \ or accessor == "DownpourCtrAccessor" \
or accessor == "DownpourCtrDymfAccessor" \
or accessor == "DownpourDoubleUnitAccessor" \ or accessor == "DownpourDoubleUnitAccessor" \
or accessor == "DownpourUnitAccessor": or accessor == "DownpourUnitAccessor":
if st.get("sparse_embedx_dim") is not None \ if st.get("sparse_embedx_dim") is not None \
...@@ -586,6 +587,7 @@ class DistributedAdam(DistributedOptimizerImplBase): ...@@ -586,6 +587,7 @@ class DistributedAdam(DistributedOptimizerImplBase):
# set sparse_embedx_dim in strategy, # set sparse_embedx_dim in strategy,
# user do not have to set it in config_fleet # user do not have to set it in config_fleet
if accessor == "DownpourFeatureValueAccessor" \ if accessor == "DownpourFeatureValueAccessor" \
or accessor == "DownpourCtrDymfAccessor" \
or accessor == "DownpourCtrAccessor" \ or accessor == "DownpourCtrAccessor" \
or accessor == "DownpourDoubleUnitAccessor" \ or accessor == "DownpourDoubleUnitAccessor" \
or accessor == "DownpourUnitAccessor": or accessor == "DownpourUnitAccessor":
...@@ -873,7 +875,8 @@ class DistributedAdam(DistributedOptimizerImplBase): ...@@ -873,7 +875,8 @@ class DistributedAdam(DistributedOptimizerImplBase):
if server._server.downpour_server_param.downpour_table_param[ if server._server.downpour_server_param.downpour_table_param[
0].accessor.accessor_class in [ 0].accessor.accessor_class in [
"DownpourCtrAccessor", "DownpourCtrDoubleAccessor", "DownpourCtrAccessor", "DownpourCtrDoubleAccessor",
"DownpourUnitAccessor", "DownpourDoubleUnitAccessor" "DownpourUnitAccessor", "DownpourDoubleUnitAccessor",
"DownpourCtrDymfAccessor"
]: ]:
opt_info["dump_slot"] = True opt_info["dump_slot"] = True
elif server._server.downpour_server_param.downpour_table_param[ elif server._server.downpour_server_param.downpour_table_param[
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册