未验证 提交 56a8b3e3 编写于 作者: Y yaoxuefeng 提交者: GitHub

add dymf accessor support (#42881)

上级 5efc4146
......@@ -120,6 +120,24 @@ void DatasetImpl<T>::SetDataFeedDesc(const std::string& data_feed_desc_str) {
&data_feed_desc_);
}
template <typename T>
std::vector<std::string> DatasetImpl<T>::GetSlots() {
  // Collect the names of the integer-typed ("uint64"/"uint32") slots declared
  // in the data-feed descriptor. Non-integer (e.g. float) slots are excluded.
  // The result is cached in use_slots_ and also returned to the caller.
  auto multi_slot_desc = data_feed_desc_.multi_slot_desc();
  use_slots_.clear();
  for (int i = 0; i < multi_slot_desc.slots_size(); ++i) {
    const auto& slot = multi_slot_desc.slots(i);
    if (slot.type() == "uint64" || slot.type() == "uint32") {
      use_slots_.push_back(slot.name());
    }
  }
  // Debug trace of the slots selected above.
  std::cout << "dataset use slots: ";
  // const ref: avoid copying each std::string on every iteration
  // (clang-tidy performance-for-range-copy).
  for (const auto& s : use_slots_) {
    std::cout << s << " | ";
  }
  std::cout << " end " << std::endl;
  return use_slots_;
}
template <typename T>
void DatasetImpl<T>::SetChannelNum(int channel_num) {
channel_num_ = channel_num;
......
......@@ -152,13 +152,15 @@ class Dataset {
virtual void DestroyPreLoadReaders() = 0;
// set preload thread num
virtual void SetPreLoadThreadNum(int thread_num) = 0;
// separate train thread and dataset thread
// separate train thread and dataset thread
virtual void DynamicAdjustChannelNum(int channel_num,
bool discard_remaining_ins = false) = 0;
virtual void DynamicAdjustReadersNum(int thread_num) = 0;
// set fleet send sleep seconds
virtual void SetFleetSendSleepSeconds(int seconds) = 0;
virtual std::vector<std::string> GetSlots() = 0;
protected:
virtual int ReceiveFromClient(int msg_type, int client_id,
const std::string& msg) = 0;
......@@ -246,6 +248,7 @@ class DatasetImpl : public Dataset {
bool discard_remaining_ins = false);
virtual void DynamicAdjustReadersNum(int thread_num);
virtual void SetFleetSendSleepSeconds(int seconds);
virtual std::vector<std::string> GetSlots();
/* for enable_heterps_
virtual void EnableHeterps(bool enable_heterps) {
enable_heterps_ = enable_heterps;
......@@ -321,6 +324,7 @@ class DatasetImpl : public Dataset {
int64_t global_index_ = 0;
std::vector<std::shared_ptr<ThreadPool>> consume_task_pool_;
std::vector<T> input_records_; // only for paddleboxdatafeed
std::vector<std::string> use_slots_;
bool enable_heterps_ = false;
};
......
......@@ -69,7 +69,7 @@ void FleetWrapper::InitWorker(const std::string& dist_desc,
int node_num, int index) {
#ifdef PADDLE_WITH_PSLIB
if (!is_initialized_) {
VLOG(3) << "Going to init worker";
VLOG(0) << "Going to init worker";
pslib_ptr_ = std::shared_ptr<paddle::distributed::PSlib>(
new paddle::distributed::PSlib());
pslib_ptr_->init_worker(dist_desc,
......@@ -126,7 +126,7 @@ void FleetWrapper::GatherServers(const std::vector<uint64_t>& host_sign_list,
void FleetWrapper::GatherClients(const std::vector<uint64_t>& host_sign_list) {
#ifdef PADDLE_WITH_PSLIB
VLOG(3) << "Going to gather client ips";
VLOG(0) << "Going to gather client ips";
size_t len = host_sign_list.size();
pslib_ptr_->gather_clients(const_cast<uint64_t*>(host_sign_list.data()), len);
#endif
......@@ -142,7 +142,7 @@ std::vector<uint64_t> FleetWrapper::GetClientsInfo() {
void FleetWrapper::CreateClient2ClientConnection() {
#ifdef PADDLE_WITH_PSLIB
VLOG(3) << "Going to create client2client connection";
VLOG(0) << "Going to create client2client connection";
pslib_ptr_->create_client2client_connection(client2client_request_timeout_ms_,
client2client_connect_timeout_ms_,
client2client_max_retry_);
......@@ -1054,7 +1054,8 @@ void FleetWrapper::PushSparseFromTensorWithLabelAsync(
int slot_offset = 0;
int grad_dim = 0;
// don't worry, user do not have to care about all these flags
if (accesor == "DownpourCtrAccessor") {
if (accesor == "DownpourCtrAccessor" ||
accesor == "DownpourCtrDymfAccessor") {
dump_slot = true;
slot_offset = 1;
grad_dim = fea_dim - 2;
......
......@@ -720,6 +720,7 @@ void PSGPUWrapper::LoadIntoMemory(bool is_shuffle) {
if (is_shuffle) {
dataset_->LocalShuffle();
}
InitSlotInfo();
std::shared_ptr<HeterContext> gpu_task = gpu_task_pool_.Get();
gpu_task->Reset();
data_ready_channel_->Put(gpu_task);
......
......@@ -339,7 +339,29 @@ class PSGPUWrapper {
void SetSlotDimVector(const std::vector<int>& slot_mf_dim_vector) {
slot_mf_dim_vector_ = slot_mf_dim_vector;
assert(slot_mf_dim_vector_.size() == slot_vector_.size());
for (size_t i = 0; i < slot_mf_dim_vector.size(); i++) {
}
void InitSlotInfo() {
if (slot_info_initialized_) {
return;
}
SlotRecordDataset* dataset = dynamic_cast<SlotRecordDataset*>(dataset_);
auto slots_vec = dataset->GetSlots();
slot_offset_vector_.clear();
for (auto& slot : slot_vector_) {
for (size_t i = 0; i < slots_vec.size(); ++i) {
if (std::to_string(slot) == slots_vec[i]) {
slot_offset_vector_.push_back(i);
break;
}
}
}
std::cout << "psgpu wrapper use slots: ";
for (auto s : slot_offset_vector_) {
std::cout << s << " | ";
}
std::cout << " end " << std::endl;
for (size_t i = 0; i < slot_mf_dim_vector_.size(); i++) {
slot_dim_map_[slot_vector_[i]] = slot_mf_dim_vector_[i];
}
......@@ -368,6 +390,7 @@ class PSGPUWrapper {
TYPEALIGN(8, sizeof(FeatureValue) + sizeof(float) * (max_mf_dim_ + 1));
grad_type_size_ =
TYPEALIGN(8, sizeof(FeaturePushValue) + (max_mf_dim_ * sizeof(float)));
slot_info_initialized_ = true;
}
#endif
......@@ -428,6 +451,7 @@ class PSGPUWrapper {
int year_;
int month_;
int day_;
bool slot_info_initialized_ = false;
int use_afs_api_ = 0;
#ifdef PADDLE_WITH_CUDA
......
......@@ -103,9 +103,9 @@ class PSLib(Fleet):
# prepare for client to client communication
if self._role_maker.is_worker():
info = self._fleet_ptr.get_clients_info()
print("IIIIFO: {}".format(info))
print("Client Info: {}".format(info))
all_info = self._role_maker._worker_gather(info[0])
print("ALL info: {}".format(all_info))
print("All Client Info: {}".format(all_info))
self._fleet_ptr.gather_clients(all_info)
self._fleet_ptr.set_client2client_config(
self._client2client_request_timeout_ms,
......
......@@ -124,14 +124,15 @@ class DownpourServer(Server):
support_accessor_class = [
'DownpourFeatureValueAccessor', 'DownpourCtrAccessor',
'DownpourSparseValueAccessor', 'DownpourCtrDoubleAccessor',
'DownpourUnitAccessor', 'DownpourDoubleUnitAccessor'
'DownpourCtrDymfAccessor', 'DownpourSparseValueAccessor',
'DownpourCtrDoubleAccessor', 'DownpourUnitAccessor',
'DownpourDoubleUnitAccessor'
]
if strategy.get('sparse_accessor_class') is not None:
accessor_class = strategy.get('sparse_accessor_class')
if accessor_class not in support_accessor_class:
raise ValueError(
"support sparse_accessor_class: ['DownpourFeatureValueAccessor', 'DownpourCtrAccessor', \
"support sparse_accessor_class: ['DownpourFeatureValueAccessor', 'DownpourCtrAccessor', 'DownpourCtrDymfAccessor', \
'DownpourSparseValueAccessor', 'DownpourCtrDoubleAccessor'], \
but actual %s" % (accessor_class))
else:
......@@ -141,6 +142,7 @@ class DownpourServer(Server):
if accessor_class == 'DownpourFeatureValueAccessor' \
or accessor_class == 'DownpourCtrAccessor' \
or accessor_class == 'DownpourCtrDymfAccessor' \
or accessor_class == 'DownpourCtrDoubleAccessor':
table.accessor.sparse_sgd_param.learning_rate = strategy.get(
'sparse_learning_rate', 0.05)
......
......@@ -339,6 +339,7 @@ class DistributedAdam(DistributedOptimizerImplBase):
# set sparse_embedx_dim in the strategy according to accessor and use_cvm config
if accessor == "DownpourFeatureValueAccessor" \
or accessor == "DownpourCtrAccessor" \
or accessor == "DownpourCtrDymfAccessor" \
or accessor == "DownpourDoubleUnitAccessor" \
or accessor == "DownpourUnitAccessor":
if st.get("sparse_embedx_dim") is not None \
......@@ -586,6 +587,7 @@ class DistributedAdam(DistributedOptimizerImplBase):
# set sparse_embedx_dim in strategy,
# user do not have to set it in config_fleet
if accessor == "DownpourFeatureValueAccessor" \
or accessor == "DownpourCtrDymfAccessor" \
or accessor == "DownpourCtrAccessor" \
or accessor == "DownpourDoubleUnitAccessor" \
or accessor == "DownpourUnitAccessor":
......@@ -873,7 +875,8 @@ class DistributedAdam(DistributedOptimizerImplBase):
if server._server.downpour_server_param.downpour_table_param[
0].accessor.accessor_class in [
"DownpourCtrAccessor", "DownpourCtrDoubleAccessor",
"DownpourUnitAccessor", "DownpourDoubleUnitAccessor"
"DownpourUnitAccessor", "DownpourDoubleUnitAccessor",
"DownpourCtrDymfAccessor"
]:
opt_info["dump_slot"] = True
elif server._server.downpour_server_param.downpour_table_param[
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册