未验证 提交 d18aabb4 编写于 作者: J jiaqi 提交者: GitHub

support patch data, add load_one_table, fix bug (#18509)

(1)support patch data (merge slots of instances of same line id, modify dense layer which
changes its size)
(2)add fleet load_one_table interface, support load from paddle model and load from pslib model
(3)fix push sparse bug which cause push sparse cost more time(about 10% in my testcase)
(4)when some slots are not in one of your network (join/update, etc.),data feed、collect label info、push/pull sparse will skip these slots, instead of throw error.
(5)add more debug info in TrainFilesWithProfiler
上级 fd3aad6c
...@@ -42,7 +42,11 @@ void DataFeed::AddFeedVar(Variable* var, const std::string& name) { ...@@ -42,7 +42,11 @@ void DataFeed::AddFeedVar(Variable* var, const std::string& name) {
CheckInit(); CheckInit();
for (size_t i = 0; i < use_slots_.size(); ++i) { for (size_t i = 0; i < use_slots_.size(); ++i) {
if (name == use_slots_[i]) { if (name == use_slots_[i]) {
feed_vec_[i] = var->GetMutable<LoDTensor>(); if (var == nullptr) {
feed_vec_[i] = nullptr;
} else {
feed_vec_[i] = var->GetMutable<LoDTensor>();
}
} }
} }
} }
...@@ -164,6 +168,7 @@ InMemoryDataFeed<T>::InMemoryDataFeed() { ...@@ -164,6 +168,7 @@ InMemoryDataFeed<T>::InMemoryDataFeed() {
this->fp_ = nullptr; this->fp_ = nullptr;
this->thread_id_ = 0; this->thread_id_ = 0;
this->thread_num_ = 1; this->thread_num_ = 1;
this->parse_ins_id_ = false;
this->input_channel_ = nullptr; this->input_channel_ = nullptr;
this->output_channel_ = nullptr; this->output_channel_ = nullptr;
this->consume_channel_ = nullptr; this->consume_channel_ = nullptr;
...@@ -247,6 +252,11 @@ void InMemoryDataFeed<T>::SetThreadNum(int thread_num) { ...@@ -247,6 +252,11 @@ void InMemoryDataFeed<T>::SetThreadNum(int thread_num) {
thread_num_ = thread_num; thread_num_ = thread_num;
} }
template <typename T>
void InMemoryDataFeed<T>::SetParseInsId(bool parse_ins_id) {
parse_ins_id_ = parse_ins_id;
}
template <typename T> template <typename T>
void InMemoryDataFeed<T>::LoadIntoMemory() { void InMemoryDataFeed<T>::LoadIntoMemory() {
#ifdef _LINUX #ifdef _LINUX
...@@ -591,6 +601,9 @@ void MultiSlotDataFeed::PutToFeedVec( ...@@ -591,6 +601,9 @@ void MultiSlotDataFeed::PutToFeedVec(
const std::vector<MultiSlotType>& ins_vec) { const std::vector<MultiSlotType>& ins_vec) {
#ifdef _LINUX #ifdef _LINUX
for (size_t i = 0; i < use_slots_.size(); ++i) { for (size_t i = 0; i < use_slots_.size(); ++i) {
if (feed_vec_[i] == nullptr) {
continue;
}
const auto& type = ins_vec[i].GetType(); const auto& type = ins_vec[i].GetType();
const auto& offset = ins_vec[i].GetOffset(); const auto& offset = ins_vec[i].GetOffset();
int total_instance = static_cast<int>(offset.back()); int total_instance = static_cast<int>(offset.back());
...@@ -684,6 +697,18 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) { ...@@ -684,6 +697,18 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
// VLOG(3) << line; // VLOG(3) << line;
char* endptr = const_cast<char*>(str); char* endptr = const_cast<char*>(str);
int pos = 0; int pos = 0;
if (parse_ins_id_) {
int num = strtol(&str[pos], &endptr, 10);
CHECK(num == 1); // NOLINT
pos = endptr - str + 1;
size_t len = 0;
while (str[pos + len] != ' ') {
++len;
}
instance->ins_id_ = std::string(str + pos, len);
pos += len + 1;
VLOG(3) << "ins_id " << instance->ins_id_;
}
for (size_t i = 0; i < use_slots_index_.size(); ++i) { for (size_t i = 0; i < use_slots_index_.size(); ++i) {
int idx = use_slots_index_[i]; int idx = use_slots_index_[i];
int num = strtol(&str[pos], &endptr, 10); int num = strtol(&str[pos], &endptr, 10);
...@@ -699,7 +724,8 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) { ...@@ -699,7 +724,8 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
for (int j = 0; j < num; ++j) { for (int j = 0; j < num; ++j) {
float feasign = strtof(endptr, &endptr); float feasign = strtof(endptr, &endptr);
// if float feasign is equal to zero, ignore it // if float feasign is equal to zero, ignore it
if (fabs(feasign) < 1e-6) { // except when slot is dense
if (fabs(feasign) < 1e-6 && !use_slots_is_dense_[i]) {
continue; continue;
} }
FeatureKey f; FeatureKey f;
...@@ -710,7 +736,8 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) { ...@@ -710,7 +736,8 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
for (int j = 0; j < num; ++j) { for (int j = 0; j < num; ++j) {
uint64_t feasign = (uint64_t)strtoull(endptr, &endptr, 10); uint64_t feasign = (uint64_t)strtoull(endptr, &endptr, 10);
// if uint64 feasign is equal to zero, ignore it // if uint64 feasign is equal to zero, ignore it
if (feasign == 0) { // except when slot is dense
if (feasign == 0 && !use_slots_is_dense_[i]) {
continue; continue;
} }
FeatureKey f; FeatureKey f;
...@@ -838,6 +865,9 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec( ...@@ -838,6 +865,9 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec(
} }
for (size_t i = 0; i < use_slots_.size(); ++i) { for (size_t i = 0; i < use_slots_.size(); ++i) {
if (feed_vec_[i] == nullptr) {
continue;
}
int total_instance = offset[i].back(); int total_instance = offset[i].back();
const auto& type = all_slots_type_[i]; const auto& type = all_slots_type_[i];
if (type[0] == 'f') { // float if (type[0] == 'f') { // float
......
...@@ -102,6 +102,8 @@ class DataFeed { ...@@ -102,6 +102,8 @@ class DataFeed {
virtual void SetThreadId(int thread_id) {} virtual void SetThreadId(int thread_id) {}
// This function will do nothing at default // This function will do nothing at default
virtual void SetThreadNum(int thread_num) {} virtual void SetThreadNum(int thread_num) {}
// This function will do nothing at default
virtual void SetParseInsId(bool parse_ins_id) {}
virtual void SetFileListMutex(std::mutex* mutex) { virtual void SetFileListMutex(std::mutex* mutex) {
mutex_for_pick_file_ = mutex; mutex_for_pick_file_ = mutex;
} }
...@@ -212,6 +214,7 @@ class InMemoryDataFeed : public DataFeed { ...@@ -212,6 +214,7 @@ class InMemoryDataFeed : public DataFeed {
virtual void SetConsumeChannel(void* channel); virtual void SetConsumeChannel(void* channel);
virtual void SetThreadId(int thread_id); virtual void SetThreadId(int thread_id);
virtual void SetThreadNum(int thread_num); virtual void SetThreadNum(int thread_num);
virtual void SetParseInsId(bool parse_ins_id);
virtual void LoadIntoMemory(); virtual void LoadIntoMemory();
protected: protected:
...@@ -221,6 +224,7 @@ class InMemoryDataFeed : public DataFeed { ...@@ -221,6 +224,7 @@ class InMemoryDataFeed : public DataFeed {
int thread_id_; int thread_id_;
int thread_num_; int thread_num_;
bool parse_ins_id_;
std::ifstream file_; std::ifstream file_;
std::shared_ptr<FILE> fp_; std::shared_ptr<FILE> fp_;
paddle::framework::ChannelObject<T>* input_channel_; paddle::framework::ChannelObject<T>* input_channel_;
......
...@@ -13,7 +13,9 @@ ...@@ -13,7 +13,9 @@
* limitations under the License. */ * limitations under the License. */
#include "paddle/fluid/framework/data_set.h" #include "paddle/fluid/framework/data_set.h"
#include <algorithm>
#include <random> #include <random>
#include <unordered_map>
#include "google/protobuf/io/zero_copy_stream_impl.h" #include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h" #include "google/protobuf/message.h"
#include "google/protobuf/text_format.h" #include "google/protobuf/text_format.h"
...@@ -21,6 +23,7 @@ ...@@ -21,6 +23,7 @@
#include "paddle/fluid/framework/fleet/fleet_wrapper.h" #include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/io/fs.h" #include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/platform/timer.h" #include "paddle/fluid/platform/timer.h"
#include "xxhash.h" // NOLINT
#if defined _WIN32 || defined __APPLE__ #if defined _WIN32 || defined __APPLE__
#else #else
...@@ -41,6 +44,10 @@ DatasetImpl<T>::DatasetImpl() { ...@@ -41,6 +44,10 @@ DatasetImpl<T>::DatasetImpl() {
cur_channel_ = 0; cur_channel_ = 0;
fleet_send_batch_size_ = 80000; fleet_send_batch_size_ = 80000;
fleet_send_sleep_seconds_ = 2; fleet_send_sleep_seconds_ = 2;
merge_by_insid_ = false;
erase_duplicate_feas_ = true;
keep_unmerged_ins_ = true;
min_merge_size_ = 2;
} }
// set filelist, file_idx_ will reset to zero. // set filelist, file_idx_ will reset to zero.
...@@ -96,6 +103,17 @@ void DatasetImpl<T>::SetChannelNum(int channel_num) { ...@@ -96,6 +103,17 @@ void DatasetImpl<T>::SetChannelNum(int channel_num) {
channel_num_ = channel_num; channel_num_ = channel_num;
} }
template <typename T>
void DatasetImpl<T>::SetMergeByInsId(
const std::vector<std::string>& merge_slot_list, bool erase_duplicate_feas,
int min_merge_size, bool keep_unmerged_ins) {
merge_by_insid_ = true;
merge_slots_list_ = merge_slot_list;
erase_duplicate_feas_ = erase_duplicate_feas;
min_merge_size_ = min_merge_size;
keep_unmerged_ins_ = keep_unmerged_ins;
}
template <typename T> template <typename T>
std::vector<paddle::framework::DataFeed*> DatasetImpl<T>::GetReaders() { std::vector<paddle::framework::DataFeed*> DatasetImpl<T>::GetReaders() {
std::vector<paddle::framework::DataFeed*> ret; std::vector<paddle::framework::DataFeed*> ret;
...@@ -266,13 +284,22 @@ void DatasetImpl<T>::GlobalShuffle() { ...@@ -266,13 +284,22 @@ void DatasetImpl<T>::GlobalShuffle() {
VLOG(3) << "DatasetImpl<T>::GlobalShuffle() input_channel_ size " VLOG(3) << "DatasetImpl<T>::GlobalShuffle() input_channel_ size "
<< input_channel_->Size(); << input_channel_->Size();
auto global_shuffle_func = [this]() { auto get_client_id = [this, fleet_ptr](const T& data) -> size_t {
if (!this->merge_by_insid_) {
return fleet_ptr->LocalRandomEngine()() % this->trainer_num_;
} else {
return XXH64(data.ins_id_.data(), data.ins_id_.length(), 0) %
this->trainer_num_;
}
};
auto global_shuffle_func = [this, get_client_id]() {
auto fleet_ptr = FleetWrapper::GetInstance(); auto fleet_ptr = FleetWrapper::GetInstance();
std::vector<T> data; std::vector<T> data;
while (this->input_channel_->Read(data)) { while (this->input_channel_->Read(data)) {
std::vector<paddle::framework::BinaryArchive> ars(this->trainer_num_); std::vector<paddle::framework::BinaryArchive> ars(this->trainer_num_);
for (auto& t : data) { for (auto& t : data) {
auto client_id = fleet_ptr->LocalRandomEngine()() % this->trainer_num_; auto client_id = get_client_id(t);
ars[client_id] << t; ars[client_id] << t;
} }
std::vector<std::future<int32_t>> total_status; std::vector<std::future<int32_t>> total_status;
...@@ -345,6 +372,7 @@ void DatasetImpl<T>::CreateReaders() { ...@@ -345,6 +372,7 @@ void DatasetImpl<T>::CreateReaders() {
readers_[i]->SetFileListMutex(&mutex_for_pick_file_); readers_[i]->SetFileListMutex(&mutex_for_pick_file_);
readers_[i]->SetFileListIndex(&file_idx_); readers_[i]->SetFileListIndex(&file_idx_);
readers_[i]->SetFileList(filelist_); readers_[i]->SetFileList(filelist_);
readers_[i]->SetParseInsId(merge_by_insid_);
if (input_channel_ != nullptr) { if (input_channel_ != nullptr) {
readers_[i]->SetInputChannel(input_channel_.get()); readers_[i]->SetInputChannel(input_channel_.get());
} }
...@@ -366,6 +394,7 @@ void DatasetImpl<T>::CreateReaders() { ...@@ -366,6 +394,7 @@ void DatasetImpl<T>::CreateReaders() {
template <typename T> template <typename T>
void DatasetImpl<T>::DestroyReaders() { void DatasetImpl<T>::DestroyReaders() {
VLOG(3) << "Calling DestroyReaders()"; VLOG(3) << "Calling DestroyReaders()";
VLOG(3) << "readers size1: " << readers_.size();
std::vector<std::shared_ptr<paddle::framework::DataFeed>>().swap(readers_); std::vector<std::shared_ptr<paddle::framework::DataFeed>>().swap(readers_);
VLOG(3) << "readers size: " << readers_.size(); VLOG(3) << "readers size: " << readers_.size();
file_idx_ = 0; file_idx_ = 0;
...@@ -418,8 +447,206 @@ int DatasetImpl<T>::ReceiveFromClient(int msg_type, int client_id, ...@@ -418,8 +447,206 @@ int DatasetImpl<T>::ReceiveFromClient(int msg_type, int client_id,
} }
// explicit instantiation // explicit instantiation
template class DatasetImpl<std::vector<MultiSlotType>>;
template class DatasetImpl<Record>; template class DatasetImpl<Record>;
void MultiSlotDataset::MergeByInsId() {
VLOG(3) << "MultiSlotDataset::MergeByInsId begin";
if (!merge_by_insid_) {
VLOG(3) << "merge_by_insid=false, will not MergeByInsId";
return;
}
auto multi_slot_desc = data_feed_desc_.multi_slot_desc();
std::unordered_map<int, bool> merge_slots;
std::vector<std::string> use_slots;
std::vector<bool> use_slots_is_dense;
for (size_t i = 0; i < multi_slot_desc.slots_size(); ++i) {
const auto& slot = multi_slot_desc.slots(i);
if (slot.is_used()) {
use_slots.push_back(slot.name());
use_slots_is_dense.push_back(slot.is_dense());
}
}
for (size_t i = 0; i < use_slots.size(); ++i) {
// currently, we don't merge dense slots
if (std::find(merge_slots_list_.begin(), merge_slots_list_.end(),
use_slots[i]) != merge_slots_list_.end() &&
!use_slots_is_dense[i]) {
merge_slots[i] = true;
}
}
CHECK(multi_output_channel_.size() != 0); // NOLINT
auto channel_data = paddle::framework::MakeChannel<Record>();
VLOG(3) << "multi_output_channel_.size() " << multi_output_channel_.size();
for (size_t i = 0; i < multi_output_channel_.size(); ++i) {
std::vector<Record> vec_data;
multi_output_channel_[i]->Close();
multi_output_channel_[i]->ReadAll(vec_data);
channel_data->Write(std::move(vec_data));
vec_data.clear();
vec_data.shrink_to_fit();
multi_output_channel_[i]->Clear();
}
channel_data->Close();
std::vector<Record> recs;
recs.reserve(channel_data->Size());
channel_data->ReadAll(recs);
channel_data->Clear();
std::sort(recs.begin(), recs.end(), [](const Record& a, const Record& b) {
return a.ins_id_ < b.ins_id_;
});
auto sort_cmp_uint64 = [&merge_slots](const FeatureItem& a,
const FeatureItem& b) {
auto& a_sign = a.sign().uint64_feasign_;
auto& b_sign = b.sign().uint64_feasign_;
return a_sign < b_sign || (a_sign == b_sign && a.slot() < b.slot());
};
auto sort_cmp_float = [&merge_slots](const FeatureItem& a,
const FeatureItem& b) {
auto& a_sign = a.sign().float_feasign_;
auto& b_sign = b.sign().float_feasign_;
return a_sign < b_sign || (a_sign == b_sign && a.slot() < b.slot());
};
auto unique_eq_uint64 = [&merge_slots](const FeatureItem& a,
const FeatureItem& b) {
if (a.slot() == b.slot() &&
merge_slots.find(a.slot()) == merge_slots.end()) {
return true;
}
auto& a_sign = a.sign().uint64_feasign_;
auto& b_sign = b.sign().uint64_feasign_;
return a_sign == b_sign && a.slot() == b.slot();
};
auto unique_eq_float = [&merge_slots](const FeatureItem& a,
const FeatureItem& b) {
if (a.slot() == b.slot() &&
merge_slots.find(a.slot()) == merge_slots.end()) {
return true;
}
auto& a_sign = a.sign().float_feasign_;
auto& b_sign = b.sign().float_feasign_;
return a_sign == b_sign && a.slot() == b.slot();
};
std::vector<Record> results;
VLOG(3) << "recs.size() " << recs.size();
for (size_t i = 0; i < recs.size();) {
size_t j = i + 1;
while (j < recs.size() && recs[j].ins_id_ == recs[i].ins_id_) {
j++;
}
if (j - i < min_merge_size_) {
if (keep_unmerged_ins_) {
for (size_t k = i; k < j; ++k) {
results.push_back(std::move(recs[k]));
}
}
i = j;
continue;
}
std::vector<FeatureItem> merge_uint64_feasigns;
std::vector<FeatureItem> merge_float_feasigns;
Record rec = std::move(recs[i]);
for (size_t k = i + 1; k < j; k++) {
for (auto& feature : recs[k].uint64_feasigns_) {
if (merge_slots.find(feature.slot()) != merge_slots.end()) {
merge_uint64_feasigns.push_back(std::move(feature));
}
}
for (auto& feature : recs[k].float_feasigns_) {
if (merge_slots.find(feature.slot()) != merge_slots.end()) {
merge_float_feasigns.push_back(std::move(feature));
}
}
recs[k] = Record();
}
i = j;
if (!erase_duplicate_feas_) {
rec.uint64_feasigns_.insert(rec.uint64_feasigns_.end(),
merge_uint64_feasigns.begin(),
merge_uint64_feasigns.end());
rec.float_feasigns_.insert(rec.float_feasigns_.end(),
merge_float_feasigns.begin(),
merge_float_feasigns.end());
} else {
std::vector<FeatureItem> not_merge_uint64_feasigns;
std::vector<FeatureItem> not_merge_float_feasigns;
for (auto& feature : rec.uint64_feasigns_) {
if (merge_slots.find(feature.slot()) != merge_slots.end()) {
merge_uint64_feasigns.push_back(std::move(feature));
} else {
not_merge_uint64_feasigns.push_back(std::move(feature));
}
}
for (auto& feature : rec.float_feasigns_) {
if (merge_slots.find(feature.slot()) != merge_slots.end()) {
merge_float_feasigns.push_back(std::move(feature));
} else {
not_merge_float_feasigns.push_back(std::move(feature));
}
}
rec.uint64_feasigns_.clear();
rec.float_feasigns_.clear();
// erase duplicate uint64 feasigns
std::sort(merge_uint64_feasigns.begin(), merge_uint64_feasigns.end(),
sort_cmp_uint64);
merge_uint64_feasigns.erase(
std::unique(merge_uint64_feasigns.begin(),
merge_uint64_feasigns.end(), unique_eq_uint64),
merge_uint64_feasigns.end());
rec.uint64_feasigns_.insert(rec.uint64_feasigns_.end(),
merge_uint64_feasigns.begin(),
merge_uint64_feasigns.end());
rec.uint64_feasigns_.insert(rec.uint64_feasigns_.end(),
not_merge_uint64_feasigns.begin(),
not_merge_uint64_feasigns.end());
// erase duplicate float feasigns
std::sort(merge_float_feasigns.begin(), merge_float_feasigns.end(),
sort_cmp_float);
merge_float_feasigns.erase(
std::unique(merge_float_feasigns.begin(), merge_float_feasigns.end(),
unique_eq_float),
merge_float_feasigns.end());
rec.float_feasigns_.insert(rec.float_feasigns_.end(),
merge_float_feasigns.begin(),
merge_float_feasigns.end());
rec.float_feasigns_.insert(rec.float_feasigns_.end(),
not_merge_float_feasigns.begin(),
not_merge_float_feasigns.end());
}
results.push_back(rec);
}
VLOG(3) << "results size " << results.size();
results.shrink_to_fit();
auto fleet_ptr = FleetWrapper::GetInstance();
std::shuffle(results.begin(), results.end(), fleet_ptr->LocalRandomEngine());
channel_data->Open();
channel_data->Write(std::move(results));
channel_data->Close();
results.clear();
results.shrink_to_fit();
VLOG(3) << "channel data size " << channel_data->Size();
channel_data->SetBlockSize(channel_data->Size() / channel_num_ + 1);
VLOG(3) << "channel data block size " << channel_data->BlockSize();
for (size_t i = 0; i < multi_output_channel_.size(); ++i) {
std::vector<Record> vec_data;
channel_data->Read(vec_data);
multi_output_channel_[i]->Open();
multi_output_channel_[i]->Write(std::move(vec_data));
vec_data.clear();
vec_data.shrink_to_fit();
}
CHECK(channel_data->Size() == 0); // NOLINT
channel_data->Clear();
VLOG(3) << "MultiSlotDataset::MergeByInsId end";
}
} // end namespace framework } // end namespace framework
} // end namespace paddle } // end namespace paddle
...@@ -57,6 +57,10 @@ class Dataset { ...@@ -57,6 +57,10 @@ class Dataset {
virtual void SetDataFeedDesc(const std::string& data_feed_desc_str) = 0; virtual void SetDataFeedDesc(const std::string& data_feed_desc_str) = 0;
// set channel num // set channel num
virtual void SetChannelNum(int channel_num) = 0; virtual void SetChannelNum(int channel_num) = 0;
// set merge by ins id
virtual void SetMergeByInsId(const std::vector<std::string>& merge_slot_list,
bool erase_duplicate_feas, int min_merge_size,
bool keep_unmerged_ins) = 0;
// get file list // get file list
virtual const std::vector<std::string>& GetFileList() = 0; virtual const std::vector<std::string>& GetFileList() = 0;
// get thread num // get thread num
...@@ -98,6 +102,8 @@ class Dataset { ...@@ -98,6 +102,8 @@ class Dataset {
virtual int64_t GetMemoryDataSize() = 0; virtual int64_t GetMemoryDataSize() = 0;
// get shuffle data size // get shuffle data size
virtual int64_t GetShuffleDataSize() = 0; virtual int64_t GetShuffleDataSize() = 0;
// merge by ins id
virtual void MergeByInsId() = 0;
protected: protected:
virtual int ReceiveFromClient(int msg_type, int client_id, virtual int ReceiveFromClient(int msg_type, int client_id,
...@@ -120,6 +126,9 @@ class DatasetImpl : public Dataset { ...@@ -120,6 +126,9 @@ class DatasetImpl : public Dataset {
const std::string& fs_ugi); const std::string& fs_ugi);
virtual void SetDataFeedDesc(const std::string& data_feed_desc_str); virtual void SetDataFeedDesc(const std::string& data_feed_desc_str);
virtual void SetChannelNum(int channel_num); virtual void SetChannelNum(int channel_num);
virtual void SetMergeByInsId(const std::vector<std::string>& merge_slot_list,
bool erase_duplicate_feas, int min_merge_size,
bool keep_unmerged_ins);
virtual const std::vector<std::string>& GetFileList() { return filelist_; } virtual const std::vector<std::string>& GetFileList() { return filelist_; }
virtual int GetThreadNum() { return thread_num_; } virtual int GetThreadNum() { return thread_num_; }
...@@ -145,6 +154,7 @@ class DatasetImpl : public Dataset { ...@@ -145,6 +154,7 @@ class DatasetImpl : public Dataset {
virtual void DestroyReaders(); virtual void DestroyReaders();
virtual int64_t GetMemoryDataSize(); virtual int64_t GetMemoryDataSize();
virtual int64_t GetShuffleDataSize(); virtual int64_t GetShuffleDataSize();
virtual void MergeByInsId() {}
protected: protected:
virtual int ReceiveFromClient(int msg_type, int client_id, virtual int ReceiveFromClient(int msg_type, int client_id,
...@@ -169,12 +179,18 @@ class DatasetImpl : public Dataset { ...@@ -169,12 +179,18 @@ class DatasetImpl : public Dataset {
int64_t fleet_send_batch_size_; int64_t fleet_send_batch_size_;
int64_t fleet_send_sleep_seconds_; int64_t fleet_send_sleep_seconds_;
std::vector<std::thread> preload_threads_; std::vector<std::thread> preload_threads_;
bool merge_by_insid_;
bool erase_duplicate_feas_;
bool keep_unmerged_ins_;
int min_merge_size_;
std::vector<std::string> merge_slots_list_;
}; };
// use std::vector<MultiSlotType> as data type // use std::vector<MultiSlotType> or Record as data type
class MultiSlotDataset : public DatasetImpl<Record> { class MultiSlotDataset : public DatasetImpl<Record> {
public: public:
MultiSlotDataset() {} MultiSlotDataset() {}
virtual void MergeByInsId();
virtual ~MultiSlotDataset() {} virtual ~MultiSlotDataset() {}
}; };
......
...@@ -89,7 +89,12 @@ void DownpourWorker::CollectLabelInfo(size_t table_idx) { ...@@ -89,7 +89,12 @@ void DownpourWorker::CollectLabelInfo(size_t table_idx) {
VLOG(3) << "sparse_key_names_[" << i VLOG(3) << "sparse_key_names_[" << i
<< "]: " << sparse_key_names_[table_id][i]; << "]: " << sparse_key_names_[table_id][i];
Variable* fea_var = thread_scope_->FindVar(sparse_key_names_[table_id][i]); Variable* fea_var = thread_scope_->FindVar(sparse_key_names_[table_id][i]);
if (fea_var == nullptr) {
continue;
}
LoDTensor* tensor = fea_var->GetMutable<LoDTensor>(); LoDTensor* tensor = fea_var->GetMutable<LoDTensor>();
CHECK(tensor != nullptr) << "tensor of var "
<< sparse_key_names_[table_id][i] << " is null";
int64_t* ids = tensor->data<int64_t>(); int64_t* ids = tensor->data<int64_t>();
size_t fea_idx = 0; size_t fea_idx = 0;
// tensor->lod()[0].size() == batch_size + 1 // tensor->lod()[0].size() == batch_size + 1
...@@ -128,7 +133,11 @@ void DownpourWorker::FillSparseValue(size_t table_idx) { ...@@ -128,7 +133,11 @@ void DownpourWorker::FillSparseValue(size_t table_idx) {
std::string slot_name = sparse_key_names_[table_id][i]; std::string slot_name = sparse_key_names_[table_id][i];
std::string emb_slot_name = sparse_value_names_[table_id][i]; std::string emb_slot_name = sparse_value_names_[table_id][i];
Variable* var = thread_scope_->FindVar(slot_name); Variable* var = thread_scope_->FindVar(slot_name);
if (var == nullptr) {
continue;
}
LoDTensor* tensor = var->GetMutable<LoDTensor>(); LoDTensor* tensor = var->GetMutable<LoDTensor>();
CHECK(tensor != nullptr) << "tensor of var " << slot_name << " is null";
int64_t* ids = tensor->data<int64_t>(); int64_t* ids = tensor->data<int64_t>();
int len = tensor->numel(); int len = tensor->numel();
Variable* var_emb = thread_scope_->FindVar(emb_slot_name); Variable* var_emb = thread_scope_->FindVar(emb_slot_name);
...@@ -198,6 +207,8 @@ void DownpourWorker::TrainFilesWithProfiler() { ...@@ -198,6 +207,8 @@ void DownpourWorker::TrainFilesWithProfiler() {
int cur_batch; int cur_batch;
int batch_cnt = 0; int batch_cnt = 0;
uint64_t total_inst = 0; uint64_t total_inst = 0;
double op_sum_time = 0;
std::unordered_map<std::string, double> op_to_time;
timeline.Start(); timeline.Start();
while ((cur_batch = device_reader_->Next()) > 0) { while ((cur_batch = device_reader_->Next()) > 0) {
timeline.Pause(); timeline.Pause();
...@@ -346,7 +357,27 @@ void DownpourWorker::TrainFilesWithProfiler() { ...@@ -346,7 +357,27 @@ void DownpourWorker::TrainFilesWithProfiler() {
for (size_t i = 0; i < op_total_time.size(); ++i) { for (size_t i = 0; i < op_total_time.size(); ++i) {
fprintf(stderr, "op_name:[%zu][%s], op_mean_time:[%fs]\n", i, fprintf(stderr, "op_name:[%zu][%s], op_mean_time:[%fs]\n", i,
op_name[i].c_str(), op_total_time[i] / batch_cnt); op_name[i].c_str(), op_total_time[i] / batch_cnt);
if (op_to_time.find(op_name[i]) == op_to_time.end()) {
op_to_time[op_name[i]] = 0.0;
}
op_to_time[op_name[i]] += op_total_time[i];
op_sum_time += op_total_time[i];
}
for (auto& i : op_to_time) {
fprintf(stderr, "op [%s] run total time: [%f]ms\n", i.first.c_str(),
i.second / batch_cnt);
} }
fprintf(stderr, "op run total time: %fs\n", op_sum_time / batch_cnt);
fprintf(stderr, "train total time: %fs\n", total_time / batch_cnt);
fprintf(stderr, "pull sparse time: %fs\n",
pull_sparse_time / batch_cnt);
fprintf(stderr, "fill sparse time: %fs\n",
fill_sparse_time / batch_cnt);
fprintf(stderr, "push sparse time: %fs\n",
push_sparse_time / batch_cnt);
fprintf(stderr, "push dense time: %fs\n", push_dense_time / batch_cnt);
fprintf(stderr, "collect label time: %fs\n",
collect_label_time / batch_cnt);
fprintf(stderr, "mean read time: %fs\n", read_time / batch_cnt); fprintf(stderr, "mean read time: %fs\n", read_time / batch_cnt);
fprintf(stderr, "IO percent: %f\n", read_time / total_time * 100); fprintf(stderr, "IO percent: %f\n", read_time / total_time * 100);
fprintf(stderr, "pull sparse time percent: %f\n", fprintf(stderr, "pull sparse time percent: %f\n",
......
...@@ -27,8 +27,10 @@ See the License for the specific language governing permissions and ...@@ -27,8 +27,10 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/fleet/fleet_wrapper.h" #include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include <algorithm>
#include <utility> #include <utility>
#include "paddle/fluid/framework/data_feed.h" #include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
namespace paddle { namespace paddle {
...@@ -156,7 +158,11 @@ void FleetWrapper::PullSparseVarsSync( ...@@ -156,7 +158,11 @@ void FleetWrapper::PullSparseVarsSync(
fea_keys->reserve(MAX_FEASIGN_NUM); fea_keys->reserve(MAX_FEASIGN_NUM);
for (auto name : var_names) { for (auto name : var_names) {
Variable* var = scope.FindVar(name); Variable* var = scope.FindVar(name);
if (var == nullptr) {
continue;
}
LoDTensor* tensor = var->GetMutable<LoDTensor>(); LoDTensor* tensor = var->GetMutable<LoDTensor>();
CHECK(tensor != nullptr) << "tensor of var " << name << " is null";
int64_t* ids = tensor->data<int64_t>(); int64_t* ids = tensor->data<int64_t>();
int len = tensor->numel(); int len = tensor->numel();
for (auto i = 0u; i < len; ++i) { for (auto i = 0u; i < len; ++i) {
...@@ -291,29 +297,34 @@ void FleetWrapper::PushSparseVarsWithLabelAsync( ...@@ -291,29 +297,34 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
grad_dim = emb_dim - 2; grad_dim = emb_dim - 2;
} }
CHECK_GE(grad_dim, 0); CHECK_GE(grad_dim, 0);
push_values->resize(fea_keys.size() + 1);
for (auto& t : *push_values) {
t.resize(emb_dim + offset);
}
uint64_t fea_idx = 0u; uint64_t fea_idx = 0u;
for (size_t i = 0; i < sparse_key_names.size(); ++i) { for (size_t i = 0; i < sparse_key_names.size(); ++i) {
Variable* g_var = scope.FindVar(sparse_grad_names[i]);
CHECK(g_var != nullptr) << "var[" << sparse_grad_names[i] << "] not found";
LoDTensor* g_tensor = g_var->GetMutable<LoDTensor>();
if (g_tensor == NULL) {
LOG(ERROR) << "var[" << sparse_key_names[i] << "] not found";
exit(-1);
}
float* g = g_tensor->data<float>();
Variable* var = scope.FindVar(sparse_key_names[i]); Variable* var = scope.FindVar(sparse_key_names[i]);
CHECK(var != nullptr) << "var[" << sparse_key_names[i] << "] not found"; if (var == nullptr) {
continue;
}
LoDTensor* tensor = var->GetMutable<LoDTensor>(); LoDTensor* tensor = var->GetMutable<LoDTensor>();
if (tensor == NULL) { if (tensor == nullptr) {
LOG(ERROR) << "var[" << sparse_key_names[i] << "] not found"; LOG(ERROR) << "tensor of var[" << sparse_key_names[i] << "] is null";
exit(-1); exit(-1);
} }
int len = tensor->numel(); int len = tensor->numel();
int64_t* ids = tensor->data<int64_t>(); int64_t* ids = tensor->data<int64_t>();
push_values->resize(fea_keys.size() + 1);
for (auto& t : *push_values) { Variable* g_var = scope.FindVar(sparse_grad_names[i]);
t.resize(emb_dim + offset); CHECK(g_var != nullptr) << "var[" << sparse_grad_names[i] << "] not found";
LoDTensor* g_tensor = g_var->GetMutable<LoDTensor>();
if (g_tensor == nullptr) {
LOG(ERROR) << "tensor of var[" << sparse_key_names[i] << "] is null";
exit(-1);
} }
float* g = g_tensor->data<float>();
if (scale_sparse_gradient_with_batch_size_ && grad_dim > 0) { if (scale_sparse_gradient_with_batch_size_ && grad_dim > 0) {
int dim = emb_dim + offset; int dim = emb_dim + offset;
Eigen::Map< Eigen::Map<
...@@ -355,6 +366,79 @@ void FleetWrapper::PushSparseVarsWithLabelAsync( ...@@ -355,6 +366,79 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
#endif #endif
} }
void FleetWrapper::LoadFromPaddleModel(Scope& scope, const uint64_t table_id,
std::vector<std::string> var_list,
std::string model_path,
std::string model_proto_file,
bool load_combine) {
// load ProgramDesc from model file
auto read_proto_func = [](const std::string& filename) -> ProgramDesc {
std::string contents;
std::ifstream fin(filename, std::ios::in | std::ios::binary);
fin.seekg(0, std::ios::end);
contents.resize(fin.tellg());
fin.seekg(0, std::ios::beg);
fin.read(&contents[0], contents.size());
fin.close();
ProgramDesc program_desc(contents);
return program_desc;
};
const ProgramDesc old_program = read_proto_func(model_proto_file);
Scope* old_scope = new Scope();
auto& old_block = old_program.Block(0);
auto place = platform::CPUPlace();
std::vector<std::string> old_param_list;
for (auto& t : var_list) {
VarDesc* old_var_desc = old_block.FindVar(t);
if (old_var_desc == nullptr) {
continue;
}
// init variable in scope
Variable* old_var = old_scope->Var(old_var_desc->Name());
InitializeVariable(old_var, old_var_desc->GetType());
old_param_list.push_back(t);
if (load_combine) {
continue;
}
// load variable from model
paddle::framework::AttributeMap attrs;
attrs.insert({"file_path", model_path + "/" + old_var_desc->Name()});
auto load_op = paddle::framework::OpRegistry::CreateOp(
"load", {}, {{"Out", {old_var_desc->Name()}}}, attrs);
load_op->Run(*old_scope, place);
}
if (load_combine) {
std::sort(old_param_list.begin(), old_param_list.end());
paddle::framework::AttributeMap attrs;
attrs.insert({"file_path", model_path});
auto load_op = paddle::framework::OpRegistry::CreateOp(
"load_combine", {}, {{"Out", old_param_list}}, attrs);
load_op->Run(*old_scope, place);
}
for (auto& t : old_param_list) {
Variable* old_var = old_scope->Var(t);
// old model data, here we assume data type is float
LoDTensor* old_tensor = old_var->GetMutable<LoDTensor>();
float* old_data = old_tensor->data<float>();
// new model data, here we assume data type is float
Variable* var = scope.FindVar(t);
CHECK(var != nullptr) << "var[" << t << "] not found";
LoDTensor* tensor = var->GetMutable<LoDTensor>();
float* data = tensor->data<float>();
// copy from old data to new data
if (old_tensor->numel() > tensor->numel()) {
memcpy(data, old_data, tensor->numel() * sizeof(float));
} else {
memcpy(data, old_data, old_tensor->numel() * sizeof(float));
}
}
delete old_scope;
PushDenseParamSync(scope, table_id, old_param_list);
}
void FleetWrapper::LoadModel(const std::string& path, const int mode) { void FleetWrapper::LoadModel(const std::string& path, const int mode) {
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
auto ret = pslib_ptr_->_worker_ptr->load(path, std::to_string(mode)); auto ret = pslib_ptr_->_worker_ptr->load(path, std::to_string(mode));
...@@ -368,6 +452,21 @@ void FleetWrapper::LoadModel(const std::string& path, const int mode) { ...@@ -368,6 +452,21 @@ void FleetWrapper::LoadModel(const std::string& path, const int mode) {
#endif #endif
} }
void FleetWrapper::LoadModelOneTable(const uint64_t table_id,
const std::string& path, const int mode) {
#ifdef PADDLE_WITH_PSLIB
auto ret =
pslib_ptr_->_worker_ptr->load(table_id, path, std::to_string(mode));
ret.wait();
if (ret.get() != 0) {
LOG(ERROR) << "load model of table id: " << table_id
<< ", from path: " << path << " failed";
}
#else
VLOG(0) << "FleetWrapper::LoadModel does nothing when no pslib";
#endif
}
void FleetWrapper::SaveModel(const std::string& path, const int mode) { void FleetWrapper::SaveModel(const std::string& path, const int mode) {
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
auto ret = pslib_ptr_->_worker_ptr->save(path, std::to_string(mode)); auto ret = pslib_ptr_->_worker_ptr->save(path, std::to_string(mode));
......
...@@ -131,9 +131,18 @@ class FleetWrapper { ...@@ -131,9 +131,18 @@ class FleetWrapper {
// flush all push requests // flush all push requests
void ClientFlush(); void ClientFlush();
// load from paddle model
void LoadFromPaddleModel(Scope& scope, const uint64_t table_id, // NOLINT
std::vector<std::string> var_list,
std::string model_path, std::string model_proto_file,
bool load_combine);
// mode = 0, load all feature // mode = 0, load all feature
// mode = 1, laod delta feature, which means load diff // mode = 1, laod delta feature, which means load diff
void LoadModel(const std::string& path, const int mode); void LoadModel(const std::string& path, const int mode);
// mode = 0, load all feature
// mode = 1, laod delta feature, which means load diff
void LoadModelOneTable(const uint64_t table_id, const std::string& path,
const int mode);
// mode = 0, save all feature // mode = 0, save all feature
// mode = 1, save delta feature, which means save diff // mode = 1, save delta feature, which means save diff
void SaveModel(const std::string& path, const int mode); void SaveModel(const std::string& path, const int mode);
......
...@@ -64,7 +64,7 @@ void HogwildWorker::BindingDataFeedMemory() { ...@@ -64,7 +64,7 @@ void HogwildWorker::BindingDataFeedMemory() {
const std::vector<std::string>& input_feed = const std::vector<std::string>& input_feed =
device_reader_->GetUseSlotAlias(); device_reader_->GetUseSlotAlias();
for (auto name : input_feed) { for (auto name : input_feed) {
device_reader_->AddFeedVar(thread_scope_->Var(name), name); device_reader_->AddFeedVar(thread_scope_->FindVar(name), name);
} }
} }
......
...@@ -99,6 +99,10 @@ void BindDataset(py::module* m) { ...@@ -99,6 +99,10 @@ void BindDataset(py::module* m) {
.def("get_shuffle_data_size", &framework::Dataset::GetShuffleDataSize, .def("get_shuffle_data_size", &framework::Dataset::GetShuffleDataSize,
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("set_queue_num", &framework::Dataset::SetChannelNum, .def("set_queue_num", &framework::Dataset::SetChannelNum,
py::call_guard<py::gil_scoped_release>())
.def("set_merge_by_lineid", &framework::Dataset::SetMergeByInsId,
py::call_guard<py::gil_scoped_release>())
.def("merge_by_lineid", &framework::Dataset::MergeByInsId,
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());
} }
......
...@@ -57,7 +57,10 @@ void BindFleetWrapper(py::module* m) { ...@@ -57,7 +57,10 @@ void BindFleetWrapper(py::module* m) {
&framework::FleetWrapper::CreateClient2ClientConnection) &framework::FleetWrapper::CreateClient2ClientConnection)
.def("shrink_sparse_table", &framework::FleetWrapper::ShrinkSparseTable) .def("shrink_sparse_table", &framework::FleetWrapper::ShrinkSparseTable)
.def("shrink_dense_table", &framework::FleetWrapper::ShrinkDenseTable) .def("shrink_dense_table", &framework::FleetWrapper::ShrinkDenseTable)
.def("client_flush", &framework::FleetWrapper::ClientFlush); .def("client_flush", &framework::FleetWrapper::ClientFlush)
.def("load_from_paddle_model",
&framework::FleetWrapper::LoadFromPaddleModel)
.def("load_model_one_table", &framework::FleetWrapper::LoadModelOneTable);
} // end FleetWrapper } // end FleetWrapper
} // end namespace pybind } // end namespace pybind
} // end namespace paddle } // end namespace paddle
...@@ -237,6 +237,7 @@ class InMemoryDataset(DatasetBase): ...@@ -237,6 +237,7 @@ class InMemoryDataset(DatasetBase):
self.proto_desc.name = "MultiSlotInMemoryDataFeed" self.proto_desc.name = "MultiSlotInMemoryDataFeed"
self.fleet_send_batch_size = 80000 self.fleet_send_batch_size = 80000
self.queue_num = None self.queue_num = None
self.merge_by_lineid = False
def _prepare_to_run(self): def _prepare_to_run(self):
""" """
...@@ -258,7 +259,7 @@ class InMemoryDataset(DatasetBase): ...@@ -258,7 +259,7 @@ class InMemoryDataset(DatasetBase):
Set Dataset output queue num, training threads get data from queues Set Dataset output queue num, training threads get data from queues
Args: Args:
set_queue_num(int): dataset output queue num queue_num(int): dataset output queue num
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -287,6 +288,40 @@ class InMemoryDataset(DatasetBase): ...@@ -287,6 +288,40 @@ class InMemoryDataset(DatasetBase):
""" """
self.fleet_send_batch_size = fleet_send_batch_size self.fleet_send_batch_size = fleet_send_batch_size
def set_merge_by_lineid(self,
var_list,
erase_duplicate_feas=True,
min_merge_size=2,
keep_unmerged_ins=True):
"""
Set merge by line id, instances of same line id will be merged after
shuffle, you should parse line id in data generator.
Args:
var_list(list): slots that can be merge. each element in var_list
is Variable. some slots such as show and click, we
usually don't merge them for same line id, so user
should specify which slot can be merged.
erase_duplicate_feas(bool): whether erase duplicate feasigns when
merge. default is True.
min_merge_size(int): minimal size to merge. default is 2.
keep_unmerged_ins(bool): whether to keep unmerged ins, such as
ins with unique id or the num of ins with
same id is less than min_merge_size.
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_merge_by_lineid()
"""
var_name_list = [i.name for i in var_list]
self.dataset.set_merge_by_lineid(var_name_list, erase_duplicate_feas,
min_merge_size, keep_unmerged_ins)
self.merge_by_lineid = True
def load_into_memory(self): def load_into_memory(self):
""" """
Load data into memory Load data into memory
...@@ -386,6 +421,10 @@ class InMemoryDataset(DatasetBase): ...@@ -386,6 +421,10 @@ class InMemoryDataset(DatasetBase):
self.dataset.global_shuffle() self.dataset.global_shuffle()
if fleet is not None: if fleet is not None:
fleet._role_maker._barrier_worker() fleet._role_maker._barrier_worker()
if self.merge_by_lineid:
self.dataset.merge_by_lineid()
if fleet is not None:
fleet._role_maker._barrier_worker()
def release_memory(self): def release_memory(self):
""" """
...@@ -530,6 +569,9 @@ class QueueDataset(DatasetBase): ...@@ -530,6 +569,9 @@ class QueueDataset(DatasetBase):
Global shuffle is not supported in QueueDataset Global shuffle is not supported in QueueDataset
NotImplementedError will be raised NotImplementedError will be raised
Args:
fleet(Fleet): fleet singleton. Default None.
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -547,9 +589,12 @@ class QueueDataset(DatasetBase): ...@@ -547,9 +589,12 @@ class QueueDataset(DatasetBase):
class FileInstantDataset(DatasetBase): class FileInstantDataset(DatasetBase):
""" """
FileInstantDataset, it will process data streamly. FileInstantDataset, it will process data streamly.
Example:
import paddle.fluid as fluid Examples:
dataset = fluid.DatasetFactory.create_dataset("FileInstantDataset") .. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory.create_dataset("FileInstantDataset")
""" """
def __init__(self): def __init__(self):
...@@ -561,8 +606,7 @@ class FileInstantDataset(DatasetBase): ...@@ -561,8 +606,7 @@ class FileInstantDataset(DatasetBase):
def local_shuffle(self): def local_shuffle(self):
""" """
Local shuffle Local shuffle, FileInstantDataset does not support local shuffle
FileInstantDataset does not support local shuffle
""" """
raise NotImplementedError( raise NotImplementedError(
"FileInstantDataset does not support local shuffle, " "FileInstantDataset does not support local shuffle, "
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
import os
import sys import sys
from optimizer_factory import * from optimizer_factory import *
from google.protobuf import text_format from google.protobuf import text_format
...@@ -191,6 +192,7 @@ class PSLib(Fleet): ...@@ -191,6 +192,7 @@ class PSLib(Fleet):
when using fleet, it will save sparse and dense feature when using fleet, it will save sparse and dense feature
Args: Args:
executor(Executor): fluid executor
dirname(str): save path. It can be hdfs/afs path or local path dirname(str): save path. It can be hdfs/afs path or local path
main_program(Program): fluid program, default None main_program(Program): fluid program, default None
kwargs: use define property, current support following kwargs: use define property, current support following
...@@ -261,6 +263,115 @@ class PSLib(Fleet): ...@@ -261,6 +263,115 @@ class PSLib(Fleet):
decay) decay)
self._role_maker._barrier_worker() self._role_maker._barrier_worker()
def load_one_table(self, table_id, model_path, **kwargs):
"""
load pslib model for one table or load params from paddle model
Args:
table_id(int): load table id
model_path(str): load model path, can be local or hdfs/afs path
kwargs(dict): user defined params, currently support following:
only for load pslib model for one table:
mode(int): load model mode. 0 is for load whole model, 1 is
for load delta model (load diff), default is 0.
only for load params from paddle model:
scope(Scope): Scope object
model_proto_file(str): path of program desc proto binary
file, can be local or hdfs/afs file
load_combine(bool): load from a file or splited param files
default False.
Examples:
.. code-block:: python
# load pslib model for one table
fleet.load_one_table(0, "hdfs:/my_fleet_model/20190714/0/")
fleet.load_one_table(1, "hdfs:/xx/xxx", mode = 0)
# load params from paddle model
fleet.load_one_table(2, "hdfs:/my_paddle_model/",
scope = my_scope,
model_proto_file = "./my_program.bin",
load_combine = False)
# below is how to save proto binary file
with open("my_program.bin", "wb") as fout:
my_program = fluid.default_main_program()
fout.write(my_program.desc.serialize_to_string())
"""
mode = kwargs.get("mode", 0)
scope = kwargs.get("scope", None)
model_proto_file = kwargs.get("model_proto_file", None)
load_combine = kwargs.get("load_combine", False)
if scope is not None and model_proto_file is not None:
self._load_one_table_from_paddle_model(
scope, table_id, model_path, model_proto_file, load_combine)
else:
self._fleet_ptr.load_model_one_table(table_id, model_path, mode)
def _load_one_table_from_paddle_model(self,
scope,
table_id,
model_path,
model_proto_file,
load_combine=False):
"""
load params from paddle model, and push params to pserver
Args:
scope(Scope): Scope object
table_id(int): the id of table to load
model_path(str): path of paddle model, can be local or hdfs/afs file
model_proto_file(str): path of program desc proto binary file,
can be local or hdfs/afs file
load_combine(bool): load from a file or splited param files
"""
self._role_maker._barrier_worker()
if self._role_maker.is_first_worker():
# get fs config from fleet_desc
fs_name = self._opt_info["fleet_desc"].fs_client_param.uri
fs_ugi = self._opt_info["fleet_desc"].fs_client_param.user + "," + \
self._opt_info["fleet_desc"].fs_client_param.passwd
hadoop_bin = self._opt_info["fleet_desc"].fs_client_param.hadoop_bin
# download model_path if it's hdfs/afs
if model_path.startswith("hdfs:") or model_path.startswith("afs:"):
dest = "./model_for_load_table_%s" % table_id
cmd = hadoop_bin + " fs -D fs.default.name=" + fs_name + \
" -D hadoop.job.ugi=" + fs_ugi + " -get " + model_path + \
" " + dest
ret = os.system(cmd)
if ret != 0:
raise RuntimeError("download model failed")
model_path = dest
# download model_proto_file if it's hdfs/afs
if model_proto_file.startswith("hdfs:") or \
model_proto_file.startswith("afs:"):
dest = "./model_proto_file_for_load_table_%s" % table_id
cmd = hadoop_bin + " fs -D fs.default.name=" + fs_name + \
" -D hadoop.job.ugi=" + fs_ugi + " -get " + \
model_proto_file + " " + dest
ret = os.system(cmd)
if ret != 0:
raise RuntimeError("download model proto file failed")
model_proto_file = dest
for i in self._opt_info["fleet_desc"].trainer_param.dense_table:
if table_id is not None and table_id != i.table_id:
continue
var_list = [var for var in i.dense_variable_name]
skip = False
for var in var_list:
if scope.find_var(var) is None:
skip = True
break
if skip:
continue
self._fleet_ptr.load_from_paddle_model(
scope, table_id, var_list, model_path, model_proto_file,
load_combine)
self._role_maker._barrier_worker()
def _set_opt_info(self, opt_info): def _set_opt_info(self, opt_info):
""" """
this function saves the result from DistributedOptimizer.minimize() this function saves the result from DistributedOptimizer.minimize()
......
...@@ -76,7 +76,8 @@ class TrainerDesc(object): ...@@ -76,7 +76,8 @@ class TrainerDesc(object):
return self.proto_desc.SerializeToString() return self.proto_desc.SerializeToString()
def __str__(self): def __str__(self):
return str(self.proto_desc) from google.protobuf import text_format
return text_format.MessageToString(self.proto_desc)
class MultiTrainer(TrainerDesc): class MultiTrainer(TrainerDesc):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册