未验证 提交 e6b87b31 编写于 作者: H hutuxian 提交者: GitHub

Support AucRunner in PaddleBox (#22884)

* Support AucRunner in PaddleBox
* update some code style
上级 c417f991
...@@ -41,44 +41,44 @@ namespace paddle { ...@@ -41,44 +41,44 @@ namespace paddle {
namespace framework { namespace framework {
void RecordCandidateList::ReSize(size_t length) { void RecordCandidateList::ReSize(size_t length) {
_mutex.lock(); mutex_.lock();
_capacity = length; capacity_ = length;
CHECK(_capacity > 0); // NOLINT CHECK(capacity_ > 0); // NOLINT
_candidate_list.clear(); candidate_list_.clear();
_candidate_list.resize(_capacity); candidate_list_.resize(capacity_);
_full = false; full_ = false;
_cur_size = 0; cur_size_ = 0;
_total_size = 0; total_size_ = 0;
_mutex.unlock(); mutex_.unlock();
} }
void RecordCandidateList::ReInit() { void RecordCandidateList::ReInit() {
_mutex.lock(); mutex_.lock();
_full = false; full_ = false;
_cur_size = 0; cur_size_ = 0;
_total_size = 0; total_size_ = 0;
_mutex.unlock(); mutex_.unlock();
} }
void RecordCandidateList::AddAndGet(const Record& record, void RecordCandidateList::AddAndGet(const Record& record,
RecordCandidate* result) { RecordCandidate* result) {
_mutex.lock(); mutex_.lock();
size_t index = 0; size_t index = 0;
++_total_size; ++total_size_;
auto fleet_ptr = FleetWrapper::GetInstance(); auto fleet_ptr = FleetWrapper::GetInstance();
if (!_full) { if (!full_) {
_candidate_list[_cur_size++] = record; candidate_list_[cur_size_++] = record;
_full = (_cur_size == _capacity); full_ = (cur_size_ == capacity_);
} else { } else {
CHECK(_cur_size == _capacity); CHECK(cur_size_ == capacity_);
index = fleet_ptr->LocalRandomEngine()() % _total_size; index = fleet_ptr->LocalRandomEngine()() % total_size_;
if (index < _capacity) { if (index < capacity_) {
_candidate_list[index] = record; candidate_list_[index] = record;
} }
} }
index = fleet_ptr->LocalRandomEngine()() % _cur_size; index = fleet_ptr->LocalRandomEngine()() % cur_size_;
*result = _candidate_list[index]; *result = candidate_list_[index];
_mutex.unlock(); mutex_.unlock();
} }
void DataFeed::AddFeedVar(Variable* var, const std::string& name) { void DataFeed::AddFeedVar(Variable* var, const std::string& name) {
...@@ -1452,7 +1452,11 @@ void PaddleBoxDataFeed::PutToFeedVec(const std::vector<PvInstance>& pv_vec) { ...@@ -1452,7 +1452,11 @@ void PaddleBoxDataFeed::PutToFeedVec(const std::vector<PvInstance>& pv_vec) {
int PaddleBoxDataFeed::GetCurrentPhase() { int PaddleBoxDataFeed::GetCurrentPhase() {
#ifdef PADDLE_WITH_BOX_PS #ifdef PADDLE_WITH_BOX_PS
auto box_ptr = paddle::framework::BoxWrapper::GetInstance(); auto box_ptr = paddle::framework::BoxWrapper::GetInstance();
return box_ptr->PassFlag(); // join: 1, update: 0 if (box_ptr->Mode() == 1) { // For AucRunner
return 1;
} else {
return box_ptr->Phase();
}
#else #else
LOG(WARNING) << "It should be complied with BOX_PS..."; LOG(WARNING) << "It should be complied with BOX_PS...";
return current_phase_; return current_phase_;
......
...@@ -27,6 +27,7 @@ limitations under the License. */ ...@@ -27,6 +27,7 @@ limitations under the License. */
#include <string> #include <string>
#include <thread> // NOLINT #include <thread> // NOLINT
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include <utility> #include <utility>
#include <vector> #include <vector>
...@@ -34,6 +35,7 @@ limitations under the License. */ ...@@ -34,6 +35,7 @@ limitations under the License. */
#include "paddle/fluid/framework/blocking_queue.h" #include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/channel.h" #include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/data_feed.pb.h" #include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
...@@ -484,13 +486,25 @@ paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar, ...@@ -484,13 +486,25 @@ paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar,
struct RecordCandidate { struct RecordCandidate {
std::string ins_id_; std::string ins_id_;
std::unordered_multimap<uint16_t, FeatureKey> feas; std::unordered_multimap<uint16_t, FeatureKey> feas_;
size_t shadow_index_ = -1; // Optimization for Reservoir Sample
RecordCandidate() {}
RecordCandidate(const Record& rec,
const std::unordered_set<uint16_t>& slot_index_to_replace) {
for (const auto& fea : rec.uint64_feasigns_) {
if (slot_index_to_replace.find(fea.slot()) !=
slot_index_to_replace.end()) {
feas_.insert({fea.slot(), fea.sign()});
}
}
}
RecordCandidate& operator=(const Record& rec) { RecordCandidate& operator=(const Record& rec) {
feas.clear(); feas_.clear();
ins_id_ = rec.ins_id_; ins_id_ = rec.ins_id_;
for (auto& fea : rec.uint64_feasigns_) { for (auto& fea : rec.uint64_feasigns_) {
feas.insert({fea.slot(), fea.sign()}); feas_.insert({fea.slot(), fea.sign()});
} }
return *this; return *this;
} }
...@@ -499,22 +513,67 @@ struct RecordCandidate { ...@@ -499,22 +513,67 @@ struct RecordCandidate {
class RecordCandidateList { class RecordCandidateList {
public: public:
RecordCandidateList() = default; RecordCandidateList() = default;
RecordCandidateList(const RecordCandidateList&) = delete; RecordCandidateList(const RecordCandidateList&) {}
RecordCandidateList& operator=(const RecordCandidateList&) = delete;
size_t Size() { return cur_size_; }
void ReSize(size_t length); void ReSize(size_t length);
void ReInit(); void ReInit();
void ReInitPass() {
for (size_t i = 0; i < cur_size_; ++i) {
if (candidate_list_[i].shadow_index_ != i) {
candidate_list_[i].ins_id_ =
candidate_list_[candidate_list_[i].shadow_index_].ins_id_;
candidate_list_[i].feas_.swap(
candidate_list_[candidate_list_[i].shadow_index_].feas_);
candidate_list_[i].shadow_index_ = i;
}
}
candidate_list_.resize(cur_size_);
}
void AddAndGet(const Record& record, RecordCandidate* result); void AddAndGet(const Record& record, RecordCandidate* result);
void AddAndGet(const Record& record, size_t& index_result) { // NOLINT
// std::unique_lock<std::mutex> lock(mutex_);
size_t index = 0;
++total_size_;
auto fleet_ptr = FleetWrapper::GetInstance();
if (!full_) {
candidate_list_.emplace_back(record, slot_index_to_replace_);
candidate_list_.back().shadow_index_ = cur_size_;
++cur_size_;
full_ = (cur_size_ == capacity_);
} else {
index = fleet_ptr->LocalRandomEngine()() % total_size_;
if (index < capacity_) {
candidate_list_.emplace_back(record, slot_index_to_replace_);
candidate_list_[index].shadow_index_ = candidate_list_.size() - 1;
}
}
index = fleet_ptr->LocalRandomEngine()() % cur_size_;
index_result = candidate_list_[index].shadow_index_;
}
const RecordCandidate& Get(size_t index) const {
PADDLE_ENFORCE_LT(
index, candidate_list_.size(),
platform::errors::OutOfRange("Your index [%lu] exceeds the number of "
"elements in candidate_list[%lu].",
index, candidate_list_.size()));
return candidate_list_[index];
}
void SetSlotIndexToReplace(
const std::unordered_set<uint16_t>& slot_index_to_replace) {
slot_index_to_replace_ = slot_index_to_replace;
}
private: private:
size_t _capacity = 0; size_t capacity_ = 0;
std::mutex _mutex; std::mutex mutex_;
bool _full = false; bool full_ = false;
size_t _cur_size = 0; size_t cur_size_ = 0;
size_t _total_size = 0; size_t total_size_ = 0;
std::vector<RecordCandidate> _candidate_list; std::vector<RecordCandidate> candidate_list_;
std::unordered_set<uint16_t> slot_index_to_replace_;
}; };
template <class AR> template <class AR>
......
...@@ -1141,13 +1141,15 @@ void MultiSlotDataset::MergeByInsId() { ...@@ -1141,13 +1141,15 @@ void MultiSlotDataset::MergeByInsId() {
VLOG(3) << "MultiSlotDataset::MergeByInsId end"; VLOG(3) << "MultiSlotDataset::MergeByInsId end";
} }
void MultiSlotDataset::GetRandomData(const std::set<uint16_t>& slots_to_replace, void MultiSlotDataset::GetRandomData(
std::vector<Record>* result) { const std::unordered_set<uint16_t>& slots_to_replace,
std::vector<Record>* result) {
int debug_erase_cnt = 0; int debug_erase_cnt = 0;
int debug_push_cnt = 0; int debug_push_cnt = 0;
auto multi_slot_desc = data_feed_desc_.multi_slot_desc(); auto multi_slot_desc = data_feed_desc_.multi_slot_desc();
slots_shuffle_rclist_.ReInit(); slots_shuffle_rclist_.ReInit();
for (const auto& rec : slots_shuffle_original_data_) { const auto& slots_shuffle_original_data = GetSlotsOriginalData();
for (const auto& rec : slots_shuffle_original_data) {
RecordCandidate rand_rec; RecordCandidate rand_rec;
Record new_rec = rec; Record new_rec = rec;
slots_shuffle_rclist_.AddAndGet(rec, &rand_rec); slots_shuffle_rclist_.AddAndGet(rec, &rand_rec);
...@@ -1161,7 +1163,7 @@ void MultiSlotDataset::GetRandomData(const std::set<uint16_t>& slots_to_replace, ...@@ -1161,7 +1163,7 @@ void MultiSlotDataset::GetRandomData(const std::set<uint16_t>& slots_to_replace,
} }
} }
for (auto slot : slots_to_replace) { for (auto slot : slots_to_replace) {
auto range = rand_rec.feas.equal_range(slot); auto range = rand_rec.feas_.equal_range(slot);
for (auto it = range.first; it != range.second; ++it) { for (auto it = range.first; it != range.second; ++it) {
new_rec.uint64_feasigns_.push_back({it->second, it->first}); new_rec.uint64_feasigns_.push_back({it->second, it->first});
debug_push_cnt += 1; debug_push_cnt += 1;
...@@ -1173,9 +1175,9 @@ void MultiSlotDataset::GetRandomData(const std::set<uint16_t>& slots_to_replace, ...@@ -1173,9 +1175,9 @@ void MultiSlotDataset::GetRandomData(const std::set<uint16_t>& slots_to_replace,
<< " repush feasign num: " << debug_push_cnt; << " repush feasign num: " << debug_push_cnt;
} }
// slots shuffle to input_channel_ with needed-shuffle slots void MultiSlotDataset::PreprocessChannel(
void MultiSlotDataset::SlotsShuffle( const std::set<std::string>& slots_to_replace,
const std::set<std::string>& slots_to_replace) { std::unordered_set<uint16_t>& index_slots) { // NOLINT
int out_channel_size = 0; int out_channel_size = 0;
if (cur_channel_ == 0) { if (cur_channel_ == 0) {
for (size_t i = 0; i < multi_output_channel_.size(); ++i) { for (size_t i = 0; i < multi_output_channel_.size(); ++i) {
...@@ -1189,20 +1191,14 @@ void MultiSlotDataset::SlotsShuffle( ...@@ -1189,20 +1191,14 @@ void MultiSlotDataset::SlotsShuffle(
VLOG(2) << "DatasetImpl<T>::SlotsShuffle() begin with input channel size: " VLOG(2) << "DatasetImpl<T>::SlotsShuffle() begin with input channel size: "
<< input_channel_->Size() << input_channel_->Size()
<< " output channel size: " << out_channel_size; << " output channel size: " << out_channel_size;
if (!slots_shuffle_fea_eval_) {
VLOG(3) << "DatasetImpl<T>::SlotsShuffle() end,"
"fea eval mode off, need to set on for slots shuffle";
return;
}
if ((!input_channel_ || input_channel_->Size() == 0) && if ((!input_channel_ || input_channel_->Size() == 0) &&
slots_shuffle_original_data_.size() == 0 && out_channel_size == 0) { slots_shuffle_original_data_.size() == 0 && out_channel_size == 0) {
VLOG(3) << "DatasetImpl<T>::SlotsShuffle() end, no data to slots shuffle"; VLOG(3) << "DatasetImpl<T>::SlotsShuffle() end, no data to slots shuffle";
return; return;
} }
platform::Timer timeline;
timeline.Start();
auto multi_slot_desc = data_feed_desc_.multi_slot_desc(); auto multi_slot_desc = data_feed_desc_.multi_slot_desc();
std::set<uint16_t> index_slots;
for (int i = 0; i < multi_slot_desc.slots_size(); ++i) { for (int i = 0; i < multi_slot_desc.slots_size(); ++i) {
std::string cur_slot = multi_slot_desc.slots(i).name(); std::string cur_slot = multi_slot_desc.slots(i).name();
if (slots_to_replace.find(cur_slot) != slots_to_replace.end()) { if (slots_to_replace.find(cur_slot) != slots_to_replace.end()) {
...@@ -1287,6 +1283,19 @@ void MultiSlotDataset::SlotsShuffle( ...@@ -1287,6 +1283,19 @@ void MultiSlotDataset::SlotsShuffle(
} }
CHECK(input_channel_->Size() == 0) CHECK(input_channel_->Size() == 0)
<< "input channel should be empty before slots shuffle"; << "input channel should be empty before slots shuffle";
}
// slots shuffle to input_channel_ with needed-shuffle slots
void MultiSlotDataset::SlotsShuffle(
const std::set<std::string>& slots_to_replace) {
PADDLE_ENFORCE_EQ(slots_shuffle_fea_eval_, true,
platform::errors::PreconditionNotMet(
"fea eval mode off, need to set on for slots shuffle"));
platform::Timer timeline;
timeline.Start();
std::unordered_set<uint16_t> index_slots;
PreprocessChannel(slots_to_replace, index_slots);
std::vector<Record> random_data; std::vector<Record> random_data;
random_data.clear(); random_data.clear();
// get slots shuffled random_data // get slots shuffled random_data
......
...@@ -67,6 +67,7 @@ class Dataset { ...@@ -67,6 +67,7 @@ class Dataset {
virtual void SetParseContent(bool parse_content) = 0; virtual void SetParseContent(bool parse_content) = 0;
virtual void SetParseLogKey(bool parse_logkey) = 0; virtual void SetParseLogKey(bool parse_logkey) = 0;
virtual void SetEnablePvMerge(bool enable_pv_merge) = 0; virtual void SetEnablePvMerge(bool enable_pv_merge) = 0;
virtual bool EnablePvMerge() = 0;
virtual void SetMergeBySid(bool is_merge) = 0; virtual void SetMergeBySid(bool is_merge) = 0;
// set merge by ins id // set merge by ins id
virtual void SetMergeByInsId(int merge_size) = 0; virtual void SetMergeByInsId(int merge_size) = 0;
...@@ -108,10 +109,7 @@ class Dataset { ...@@ -108,10 +109,7 @@ class Dataset {
virtual void LocalShuffle() = 0; virtual void LocalShuffle() = 0;
// global shuffle data // global shuffle data
virtual void GlobalShuffle(int thread_num = -1) = 0; virtual void GlobalShuffle(int thread_num = -1) = 0;
// for slots shuffle
virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace) = 0; virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace) = 0;
virtual void GetRandomData(const std::set<uint16_t>& slots_to_replace,
std::vector<Record>* result) = 0;
// create readers // create readers
virtual void CreateReaders() = 0; virtual void CreateReaders() = 0;
// destroy readers // destroy readers
...@@ -183,6 +181,9 @@ class DatasetImpl : public Dataset { ...@@ -183,6 +181,9 @@ class DatasetImpl : public Dataset {
virtual int GetThreadNum() { return thread_num_; } virtual int GetThreadNum() { return thread_num_; }
virtual int GetTrainerNum() { return trainer_num_; } virtual int GetTrainerNum() { return trainer_num_; }
virtual Channel<T> GetInputChannel() { return input_channel_; } virtual Channel<T> GetInputChannel() { return input_channel_; }
virtual void SetInputChannel(const Channel<T>& input_channel) {
input_channel_ = input_channel;
}
virtual int64_t GetFleetSendBatchSize() { return fleet_send_batch_size_; } virtual int64_t GetFleetSendBatchSize() { return fleet_send_batch_size_; }
virtual std::pair<std::string, std::string> GetHdfsConfig() { virtual std::pair<std::string, std::string> GetHdfsConfig() {
return std::make_pair(fs_name_, fs_ugi_); return std::make_pair(fs_name_, fs_ugi_);
...@@ -192,6 +193,7 @@ class DatasetImpl : public Dataset { ...@@ -192,6 +193,7 @@ class DatasetImpl : public Dataset {
return data_feed_desc_; return data_feed_desc_;
} }
virtual int GetChannelNum() { return channel_num_; } virtual int GetChannelNum() { return channel_num_; }
virtual bool EnablePvMerge() { return enable_pv_merge_; }
virtual std::vector<paddle::framework::DataFeed*> GetReaders(); virtual std::vector<paddle::framework::DataFeed*> GetReaders();
virtual void CreateChannel(); virtual void CreateChannel();
virtual void RegisterClientToClientMsgHandler(); virtual void RegisterClientToClientMsgHandler();
...@@ -202,8 +204,9 @@ class DatasetImpl : public Dataset { ...@@ -202,8 +204,9 @@ class DatasetImpl : public Dataset {
virtual void LocalShuffle(); virtual void LocalShuffle();
virtual void GlobalShuffle(int thread_num = -1); virtual void GlobalShuffle(int thread_num = -1);
virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace) {} virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace) {}
virtual void GetRandomData(const std::set<uint16_t>& slots_to_replace, virtual const std::vector<T>& GetSlotsOriginalData() {
std::vector<Record>* result) {} return slots_shuffle_original_data_;
}
virtual void CreateReaders(); virtual void CreateReaders();
virtual void DestroyReaders(); virtual void DestroyReaders();
virtual int64_t GetMemoryDataSize(); virtual int64_t GetMemoryDataSize();
...@@ -293,9 +296,13 @@ class MultiSlotDataset : public DatasetImpl<Record> { ...@@ -293,9 +296,13 @@ class MultiSlotDataset : public DatasetImpl<Record> {
} }
std::vector<std::unordered_set<uint64_t>>().swap(local_tables_); std::vector<std::unordered_set<uint64_t>>().swap(local_tables_);
} }
virtual void PreprocessChannel(
const std::set<std::string>& slots_to_replace,
std::unordered_set<uint16_t>& index_slot); // NOLINT
virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace); virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace);
virtual void GetRandomData(const std::set<uint16_t>& slots_to_replace, virtual void GetRandomData(
std::vector<Record>* result); const std::unordered_set<uint16_t>& slots_to_replace,
std::vector<Record>* result);
virtual ~MultiSlotDataset() {} virtual ~MultiSlotDataset() {}
}; };
......
...@@ -255,6 +255,113 @@ void BoxWrapper::PushSparseGrad(const paddle::platform::Place& place, ...@@ -255,6 +255,113 @@ void BoxWrapper::PushSparseGrad(const paddle::platform::Place& place,
<< " s"; << " s";
VLOG(3) << "End PushSparseGrad"; VLOG(3) << "End PushSparseGrad";
} }
void BoxWrapper::GetRandomReplace(const std::vector<Record>& pass_data) {
VLOG(0) << "Begin GetRandomReplace";
size_t ins_num = pass_data.size();
replace_idx_.resize(ins_num);
for (auto& cand_list : random_ins_pool_list) {
cand_list.ReInitPass();
}
std::vector<std::thread> threads;
for (int tid = 0; tid < auc_runner_thread_num_; ++tid) {
threads.push_back(std::thread([this, &pass_data, tid, ins_num]() {
int start = tid * ins_num / auc_runner_thread_num_;
int end = (tid + 1) * ins_num / auc_runner_thread_num_;
VLOG(3) << "GetRandomReplace begin for thread[" << tid
<< "], and process [" << start << ", " << end
<< "), total ins: " << ins_num;
auto& random_pool = random_ins_pool_list[tid];
for (int i = start; i < end; ++i) {
const auto& ins = pass_data[i];
random_pool.AddAndGet(ins, replace_idx_[i]);
}
}));
}
for (int tid = 0; tid < auc_runner_thread_num_; ++tid) {
threads[tid].join();
}
pass_done_semi_->Put(1);
VLOG(0) << "End GetRandomReplace";
}
void BoxWrapper::GetRandomData(
const std::vector<Record>& pass_data,
const std::unordered_set<uint16_t>& slots_to_replace,
std::vector<Record>* result) {
VLOG(0) << "Begin GetRandomData";
std::vector<std::thread> threads;
for (int tid = 0; tid < auc_runner_thread_num_; ++tid) {
threads.push_back(std::thread([this, &pass_data, tid, &slots_to_replace,
result]() {
int debug_erase_cnt = 0;
int debug_push_cnt = 0;
size_t ins_num = pass_data.size();
int start = tid * ins_num / auc_runner_thread_num_;
int end = (tid + 1) * ins_num / auc_runner_thread_num_;
VLOG(3) << "GetRandomData begin for thread[" << tid << "], and process ["
<< start << ", " << end << "), total ins: " << ins_num;
const auto& random_pool = random_ins_pool_list[tid];
for (int i = start; i < end; ++i) {
const auto& ins = pass_data[i];
const RecordCandidate& rand_rec = random_pool.Get(replace_idx_[i]);
Record new_rec = ins;
for (auto it = new_rec.uint64_feasigns_.begin();
it != new_rec.uint64_feasigns_.end();) {
if (slots_to_replace.find(it->slot()) != slots_to_replace.end()) {
it = new_rec.uint64_feasigns_.erase(it);
debug_erase_cnt += 1;
} else {
++it;
}
}
for (auto slot : slots_to_replace) {
auto range = rand_rec.feas_.equal_range(slot);
for (auto it = range.first; it != range.second; ++it) {
new_rec.uint64_feasigns_.push_back({it->second, it->first});
debug_push_cnt += 1;
}
}
(*result)[i] = std::move(new_rec);
}
VLOG(3) << "thread[" << tid << "]: erase feasign num: " << debug_erase_cnt
<< " repush feasign num: " << debug_push_cnt;
}));
}
for (int tid = 0; tid < auc_runner_thread_num_; ++tid) {
threads[tid].join();
}
VLOG(0) << "End GetRandomData";
}
void BoxWrapper::AddReplaceFeasign(boxps::PSAgentBase* p_agent,
int feed_pass_thread_num) {
VLOG(0) << "Enter AddReplaceFeasign Function";
int semi;
pass_done_semi_->Get(semi);
VLOG(0) << "Last Pass had updated random pool done. Begin AddReplaceFeasign";
std::vector<std::thread> threads;
for (int tid = 0; tid < feed_pass_thread_num; ++tid) {
threads.push_back(std::thread([this, tid, p_agent, feed_pass_thread_num]() {
VLOG(3) << "AddReplaceFeasign begin for thread[" << tid << "]";
for (size_t pool_id = tid; pool_id < random_ins_pool_list.size();
pool_id += feed_pass_thread_num) {
auto& random_pool = random_ins_pool_list[pool_id];
for (size_t i = 0; i < random_pool.Size(); ++i) {
auto& ins_candidate = random_pool.Get(i);
for (const auto& pair : ins_candidate.feas_) {
p_agent->AddKey(pair.second.uint64_feasign_, tid);
}
}
}
}));
}
for (int tid = 0; tid < feed_pass_thread_num; ++tid) {
threads[tid].join();
}
VLOG(0) << "End AddReplaceFeasign";
}
} // end namespace framework } // end namespace framework
} // end namespace paddle } // end namespace paddle
#endif #endif
...@@ -31,10 +31,12 @@ limitations under the License. */ ...@@ -31,10 +31,12 @@ limitations under the License. */
#include <map> #include <map>
#include <memory> #include <memory>
#include <mutex> // NOLINT #include <mutex> // NOLINT
#include <set>
#include <string> #include <string>
#include <unordered_set> #include <unordered_set>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/data_set.h" #include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
...@@ -469,16 +471,16 @@ class BoxWrapper { ...@@ -469,16 +471,16 @@ class BoxWrapper {
public: public:
MetricMsg() {} MetricMsg() {}
MetricMsg(const std::string& label_varname, const std::string& pred_varname, MetricMsg(const std::string& label_varname, const std::string& pred_varname,
int is_join, int bucket_size = 1000000) int metric_phase, int bucket_size = 1000000)
: label_varname_(label_varname), : label_varname_(label_varname),
pred_varname_(pred_varname), pred_varname_(pred_varname),
is_join_(is_join) { metric_phase_(metric_phase) {
calculator = new BasicAucCalculator(); calculator = new BasicAucCalculator();
calculator->init(bucket_size); calculator->init(bucket_size);
} }
virtual ~MetricMsg() {} virtual ~MetricMsg() {}
int IsJoin() const { return is_join_; } int MetricPhase() const { return metric_phase_; }
BasicAucCalculator* GetCalculator() { return calculator; } BasicAucCalculator* GetCalculator() { return calculator; }
virtual void add_data(const Scope* exe_scope) { virtual void add_data(const Scope* exe_scope) {
std::vector<int64_t> label_data; std::vector<int64_t> label_data;
...@@ -514,20 +516,20 @@ class BoxWrapper { ...@@ -514,20 +516,20 @@ class BoxWrapper {
protected: protected:
std::string label_varname_; std::string label_varname_;
std::string pred_varname_; std::string pred_varname_;
int is_join_; int metric_phase_;
BasicAucCalculator* calculator; BasicAucCalculator* calculator;
}; };
class MultiTaskMetricMsg : public MetricMsg { class MultiTaskMetricMsg : public MetricMsg {
public: public:
MultiTaskMetricMsg(const std::string& label_varname, MultiTaskMetricMsg(const std::string& label_varname,
const std::string& pred_varname_list, int is_join, const std::string& pred_varname_list, int metric_phase,
const std::string& cmatch_rank_group, const std::string& cmatch_rank_group,
const std::string& cmatch_rank_varname, const std::string& cmatch_rank_varname,
int bucket_size = 1000000) { int bucket_size = 1000000) {
label_varname_ = label_varname; label_varname_ = label_varname;
cmatch_rank_varname_ = cmatch_rank_varname; cmatch_rank_varname_ = cmatch_rank_varname;
is_join_ = is_join; metric_phase_ = metric_phase;
calculator = new BasicAucCalculator(); calculator = new BasicAucCalculator();
calculator->init(bucket_size); calculator->init(bucket_size);
for (auto& cmatch_rank : string::split_string(cmatch_rank_group)) { for (auto& cmatch_rank : string::split_string(cmatch_rank_group)) {
...@@ -594,14 +596,14 @@ class BoxWrapper { ...@@ -594,14 +596,14 @@ class BoxWrapper {
class CmatchRankMetricMsg : public MetricMsg { class CmatchRankMetricMsg : public MetricMsg {
public: public:
CmatchRankMetricMsg(const std::string& label_varname, CmatchRankMetricMsg(const std::string& label_varname,
const std::string& pred_varname, int is_join, const std::string& pred_varname, int metric_phase,
const std::string& cmatch_rank_group, const std::string& cmatch_rank_group,
const std::string& cmatch_rank_varname, const std::string& cmatch_rank_varname,
int bucket_size = 1000000) { int bucket_size = 1000000) {
label_varname_ = label_varname; label_varname_ = label_varname;
pred_varname_ = pred_varname; pred_varname_ = pred_varname;
cmatch_rank_varname_ = cmatch_rank_varname; cmatch_rank_varname_ = cmatch_rank_varname;
is_join_ = is_join; metric_phase_ = metric_phase;
calculator = new BasicAucCalculator(); calculator = new BasicAucCalculator();
calculator->init(bucket_size); calculator->init(bucket_size);
for (auto& cmatch_rank : string::split_string(cmatch_rank_group)) { for (auto& cmatch_rank : string::split_string(cmatch_rank_group)) {
...@@ -653,12 +655,12 @@ class BoxWrapper { ...@@ -653,12 +655,12 @@ class BoxWrapper {
class MaskMetricMsg : public MetricMsg { class MaskMetricMsg : public MetricMsg {
public: public:
MaskMetricMsg(const std::string& label_varname, MaskMetricMsg(const std::string& label_varname,
const std::string& pred_varname, int is_join, const std::string& pred_varname, int metric_phase,
const std::string& mask_varname, int bucket_size = 1000000) { const std::string& mask_varname, int bucket_size = 1000000) {
label_varname_ = label_varname; label_varname_ = label_varname;
pred_varname_ = pred_varname; pred_varname_ = pred_varname;
mask_varname_ = mask_varname; mask_varname_ = mask_varname;
is_join_ = is_join; metric_phase_ = metric_phase;
calculator = new BasicAucCalculator(); calculator = new BasicAucCalculator();
calculator->init(bucket_size); calculator->init(bucket_size);
} }
...@@ -682,36 +684,59 @@ class BoxWrapper { ...@@ -682,36 +684,59 @@ class BoxWrapper {
protected: protected:
std::string mask_varname_; std::string mask_varname_;
}; };
const std::vector<std::string>& GetMetricNameList() const { const std::vector<std::string> GetMetricNameList(
return metric_name_list_; int metric_phase = -1) const {
VLOG(0) << "Want to Get metric phase: " << metric_phase;
if (metric_phase == -1) {
return metric_name_list_;
} else {
std::vector<std::string> ret;
for (const auto& name : metric_name_list_) {
const auto iter = metric_lists_.find(name);
PADDLE_ENFORCE_NE(
iter, metric_lists_.end(),
platform::errors::InvalidArgument(
"The metric name you provided is not registered."));
if (iter->second->MetricPhase() == metric_phase) {
VLOG(0) << name << "'s phase is " << iter->second->MetricPhase()
<< ", we want";
ret.push_back(name);
} else {
VLOG(0) << name << "'s phase is " << iter->second->MetricPhase()
<< ", not we want";
}
}
return ret;
}
} }
int PassFlag() const { return pass_flag_; } int Phase() const { return phase_; }
void FlipPassFlag() { pass_flag_ = 1 - pass_flag_; } void FlipPhase() { phase_ = (phase_ + 1) % phase_num_; }
std::map<std::string, MetricMsg*>& GetMetricList() { return metric_lists_; } std::map<std::string, MetricMsg*>& GetMetricList() { return metric_lists_; }
void InitMetric(const std::string& method, const std::string& name, void InitMetric(const std::string& method, const std::string& name,
const std::string& label_varname, const std::string& label_varname,
const std::string& pred_varname, const std::string& pred_varname,
const std::string& cmatch_rank_varname, const std::string& cmatch_rank_varname,
const std::string& mask_varname, bool is_join, const std::string& mask_varname, int metric_phase,
const std::string& cmatch_rank_group, const std::string& cmatch_rank_group,
int bucket_size = 1000000) { int bucket_size = 1000000) {
if (method == "AucCalculator") { if (method == "AucCalculator") {
metric_lists_.emplace(name, new MetricMsg(label_varname, pred_varname, metric_lists_.emplace(name, new MetricMsg(label_varname, pred_varname,
is_join ? 1 : 0, bucket_size)); metric_phase, bucket_size));
} else if (method == "MultiTaskAucCalculator") { } else if (method == "MultiTaskAucCalculator") {
metric_lists_.emplace( metric_lists_.emplace(
name, new MultiTaskMetricMsg(label_varname, pred_varname, name, new MultiTaskMetricMsg(label_varname, pred_varname,
is_join ? 1 : 0, cmatch_rank_group, metric_phase, cmatch_rank_group,
cmatch_rank_varname, bucket_size)); cmatch_rank_varname, bucket_size));
} else if (method == "CmatchRankAucCalculator") { } else if (method == "CmatchRankAucCalculator") {
metric_lists_.emplace( metric_lists_.emplace(
name, new CmatchRankMetricMsg(label_varname, pred_varname, name, new CmatchRankMetricMsg(label_varname, pred_varname,
is_join ? 1 : 0, cmatch_rank_group, metric_phase, cmatch_rank_group,
cmatch_rank_varname, bucket_size)); cmatch_rank_varname, bucket_size));
} else if (method == "MaskAucCalculator") { } else if (method == "MaskAucCalculator") {
metric_lists_.emplace( metric_lists_.emplace(
name, new MaskMetricMsg(label_varname, pred_varname, is_join ? 1 : 0, name, new MaskMetricMsg(label_varname, pred_varname, metric_phase,
mask_varname, bucket_size)); mask_varname, bucket_size));
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
...@@ -753,7 +778,8 @@ class BoxWrapper { ...@@ -753,7 +778,8 @@ class BoxWrapper {
std::unordered_set<std::string> slot_name_omited_in_feedpass_; std::unordered_set<std::string> slot_name_omited_in_feedpass_;
// Metric Related // Metric Related
int pass_flag_ = 1; // join: 1, update: 0 int phase_ = 1;
int phase_num_ = 2;
std::map<std::string, MetricMsg*> metric_lists_; std::map<std::string, MetricMsg*> metric_lists_;
std::vector<std::string> metric_name_list_; std::vector<std::string> metric_name_list_;
std::vector<int> slot_vector_; std::vector<int> slot_vector_;
...@@ -762,6 +788,57 @@ class BoxWrapper { ...@@ -762,6 +788,57 @@ class BoxWrapper {
public: public:
static AfsManager* afs_manager; static AfsManager* afs_manager;
// Auc Runner
public:
void InitializeAucRunner(std::vector<std::vector<std::string>> slot_eval,
int thread_num, int pool_size,
std::vector<std::string> slot_list) {
mode_ = 1;
phase_num_ = static_cast<int>(slot_eval.size());
phase_ = phase_num_ - 1;
auc_runner_thread_num_ = thread_num;
pass_done_semi_ = paddle::framework::MakeChannel<int>();
pass_done_semi_->Put(1); // Note: At most 1 pipeline in AucRunner
random_ins_pool_list.resize(thread_num);
std::unordered_set<std::string> slot_set;
for (size_t i = 0; i < slot_eval.size(); ++i) {
for (const auto& slot : slot_eval[i]) {
slot_set.insert(slot);
}
}
for (size_t i = 0; i < slot_list.size(); ++i) {
if (slot_set.find(slot_list[i]) != slot_set.end()) {
slot_index_to_replace_.insert(static_cast<int16_t>(i));
}
}
for (int i = 0; i < auc_runner_thread_num_; ++i) {
random_ins_pool_list[i].SetSlotIndexToReplace(slot_index_to_replace_);
}
VLOG(0) << "AucRunner configuration: thread number[" << thread_num
<< "], pool size[" << pool_size << "], runner_group[" << phase_num_
<< "]";
VLOG(0) << "Slots that need to be evaluated:";
for (auto e : slot_index_to_replace_) {
VLOG(0) << e << ": " << slot_list[e];
}
}
void GetRandomReplace(const std::vector<Record>& pass_data);
void AddReplaceFeasign(boxps::PSAgentBase* p_agent, int feed_pass_thread_num);
void GetRandomData(const std::vector<Record>& pass_data,
const std::unordered_set<uint16_t>& slots_to_replace,
std::vector<Record>* result);
int Mode() const { return mode_; }
private:
int mode_ = 0; // 0 means train/test 1 means auc_runner
int auc_runner_thread_num_ = 1;
bool init_done_ = false;
paddle::framework::Channel<int> pass_done_semi_;
std::unordered_set<uint16_t> slot_index_to_replace_;
std::vector<RecordCandidateList> random_ins_pool_list;
std::vector<size_t> replace_idx_;
}; };
#endif #endif
...@@ -810,7 +887,38 @@ class BoxHelper { ...@@ -810,7 +887,38 @@ class BoxHelper {
VLOG(3) << "After PreLoadIntoMemory()"; VLOG(3) << "After PreLoadIntoMemory()";
} }
void WaitFeedPassDone() { feed_data_thread_->join(); } void WaitFeedPassDone() { feed_data_thread_->join(); }
void SlotsShuffle(const std::set<std::string>& slots_to_replace) {
#ifdef PADDLE_WITH_BOX_PS
auto box_ptr = BoxWrapper::GetInstance();
PADDLE_ENFORCE_EQ(box_ptr->Mode(), 1,
platform::errors::PreconditionNotMet(
"Should call InitForAucRunner first."));
box_ptr->FlipPhase();
std::unordered_set<uint16_t> index_slots;
dynamic_cast<MultiSlotDataset*>(dataset_)->PreprocessChannel(
slots_to_replace, index_slots);
const std::vector<Record>& pass_data =
dynamic_cast<MultiSlotDataset*>(dataset_)->GetSlotsOriginalData();
if (!get_random_replace_done_) {
box_ptr->GetRandomReplace(pass_data);
get_random_replace_done_ = true;
}
std::vector<Record> random_data;
random_data.resize(pass_data.size());
box_ptr->GetRandomData(pass_data, index_slots, &random_data);
auto new_input_channel = paddle::framework::MakeChannel<Record>();
new_input_channel->Open();
new_input_channel->Write(std::move(random_data));
new_input_channel->Close();
dynamic_cast<MultiSlotDataset*>(dataset_)->SetInputChannel(
new_input_channel);
if (dataset_->EnablePvMerge()) {
dataset_->PreprocessInstance();
}
#endif
}
#ifdef PADDLE_WITH_BOX_PS #ifdef PADDLE_WITH_BOX_PS
// notify boxps to feed this pass feasigns from SSD to memory // notify boxps to feed this pass feasigns from SSD to memory
static void FeedPassThread(const std::deque<Record>& t, int begin_index, static void FeedPassThread(const std::deque<Record>& t, int begin_index,
...@@ -881,6 +989,10 @@ class BoxHelper { ...@@ -881,6 +989,10 @@ class BoxHelper {
for (size_t i = 0; i < tnum; ++i) { for (size_t i = 0; i < tnum; ++i) {
threads[i].join(); threads[i].join();
} }
if (box_ptr->Mode() == 1) {
box_ptr->AddReplaceFeasign(p_agent, tnum);
}
VLOG(3) << "Begin call EndFeedPass in BoxPS"; VLOG(3) << "Begin call EndFeedPass in BoxPS";
box_ptr->EndFeedPass(p_agent); box_ptr->EndFeedPass(p_agent);
#endif #endif
...@@ -892,6 +1004,7 @@ class BoxHelper { ...@@ -892,6 +1004,7 @@ class BoxHelper {
int year_; int year_;
int month_; int month_;
int day_; int day_;
bool get_random_replace_done_ = false;
}; };
} // end namespace framework } // end namespace framework
......
...@@ -211,7 +211,7 @@ void SectionWorker::TrainFiles() { ...@@ -211,7 +211,7 @@ void SectionWorker::TrainFiles() {
auto& metric_list = box_ptr->GetMetricList(); auto& metric_list = box_ptr->GetMetricList();
for (auto iter = metric_list.begin(); iter != metric_list.end(); iter++) { for (auto iter = metric_list.begin(); iter != metric_list.end(); iter++) {
auto* metric_msg = iter->second; auto* metric_msg = iter->second;
if (metric_msg->IsJoin() != box_ptr->PassFlag()) { if (box_ptr->Phase() != metric_msg->MetricPhase()) {
continue; continue;
} }
metric_msg->add_data(exe_scope); metric_msg->add_data(exe_scope);
...@@ -367,7 +367,7 @@ void SectionWorker::TrainFilesWithProfiler() { ...@@ -367,7 +367,7 @@ void SectionWorker::TrainFilesWithProfiler() {
auto& metric_list = box_ptr->GetMetricList(); auto& metric_list = box_ptr->GetMetricList();
for (auto iter = metric_list.begin(); iter != metric_list.end(); iter++) { for (auto iter = metric_list.begin(); iter != metric_list.end(); iter++) {
auto* metric_msg = iter->second; auto* metric_msg = iter->second;
if (metric_msg->IsJoin() != box_ptr->PassFlag()) { if (box_ptr->Phase() != metric_msg->MetricPhase()) {
continue; continue;
} }
metric_msg->add_data(exe_scope); metric_msg->add_data(exe_scope);
......
...@@ -54,6 +54,8 @@ void BindBoxHelper(py::module* m) { ...@@ -54,6 +54,8 @@ void BindBoxHelper(py::module* m) {
.def("preload_into_memory", &framework::BoxHelper::PreLoadIntoMemory, .def("preload_into_memory", &framework::BoxHelper::PreLoadIntoMemory,
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("load_into_memory", &framework::BoxHelper::LoadIntoMemory, .def("load_into_memory", &framework::BoxHelper::LoadIntoMemory,
py::call_guard<py::gil_scoped_release>())
.def("slots_shuffle", &framework::BoxHelper::SlotsShuffle,
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());
} // end BoxHelper } // end BoxHelper
...@@ -76,13 +78,15 @@ void BindBoxWrapper(py::module* m) { ...@@ -76,13 +78,15 @@ void BindBoxWrapper(py::module* m) {
.def("initialize_gpu_and_load_model", .def("initialize_gpu_and_load_model",
&framework::BoxWrapper::InitializeGPUAndLoadModel, &framework::BoxWrapper::InitializeGPUAndLoadModel,
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("initialize_auc_runner", &framework::BoxWrapper::InitializeAucRunner,
py::call_guard<py::gil_scoped_release>())
.def("init_metric", &framework::BoxWrapper::InitMetric, .def("init_metric", &framework::BoxWrapper::InitMetric,
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("get_metric_msg", &framework::BoxWrapper::GetMetricMsg, .def("get_metric_msg", &framework::BoxWrapper::GetMetricMsg,
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("get_metric_name_list", &framework::BoxWrapper::GetMetricNameList, .def("get_metric_name_list", &framework::BoxWrapper::GetMetricNameList,
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("flip_pass_flag", &framework::BoxWrapper::FlipPassFlag, .def("flip_phase", &framework::BoxWrapper::FlipPhase,
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("init_afs_api", &framework::BoxWrapper::InitAfsAPI, .def("init_afs_api", &framework::BoxWrapper::InitAfsAPI,
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
......
...@@ -291,6 +291,8 @@ void BindDataset(py::module *m) { ...@@ -291,6 +291,8 @@ void BindDataset(py::module *m) {
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("set_fleet_send_sleep_seconds", .def("set_fleet_send_sleep_seconds",
&framework::Dataset::SetFleetSendSleepSeconds, &framework::Dataset::SetFleetSendSleepSeconds,
py::call_guard<py::gil_scoped_release>())
.def("enable_pv_merge", &framework::Dataset::EnablePvMerge,
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());
py::class_<IterableDatasetWrapper>(*m, "IterableDatasetWrapper") py::class_<IterableDatasetWrapper>(*m, "IterableDatasetWrapper")
......
...@@ -1079,3 +1079,24 @@ class BoxPSDataset(InMemoryDataset): ...@@ -1079,3 +1079,24 @@ class BoxPSDataset(InMemoryDataset):
def _dynamic_adjust_after_train(self): def _dynamic_adjust_after_train(self):
pass pass
def slots_shuffle(self, slots):
"""
Slots Shuffle
Slots Shuffle is a shuffle method in slots level, which is usually used
in sparse feature with large scale of instances. To compare the metric, i.e.
auc while doing slots shuffle on one or several slots with baseline to
evaluate the importance level of slots(features).
Args:
slots(list[string]): the set of slots(string) to do slots shuffle.
Examples:
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_merge_by_lineid()
#suppose there is a slot 0
dataset.slots_shuffle(['0'])
"""
slots_set = set(slots)
self.boxps.slots_shuffle(slots_set)
...@@ -172,6 +172,7 @@ class TestBoxPSPreload(unittest.TestCase): ...@@ -172,6 +172,7 @@ class TestBoxPSPreload(unittest.TestCase):
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
datasets[0].load_into_memory() datasets[0].load_into_memory()
datasets[0].begin_pass() datasets[0].begin_pass()
datasets[0].slots_shuffle([])
datasets[1].preload_into_memory() datasets[1].preload_into_memory()
exe.train_from_dataset( exe.train_from_dataset(
program=fluid.default_main_program(), program=fluid.default_main_program(),
......
...@@ -125,6 +125,7 @@ class TestDataset(unittest.TestCase): ...@@ -125,6 +125,7 @@ class TestDataset(unittest.TestCase):
dataset.set_trainer_num(4) dataset.set_trainer_num(4)
dataset.set_hdfs_config("my_fs_name", "my_fs_ugi") dataset.set_hdfs_config("my_fs_name", "my_fs_ugi")
dataset.set_download_cmd("./read_from_afs my_fs_name my_fs_ugi") dataset.set_download_cmd("./read_from_afs my_fs_name my_fs_ugi")
dataset.enable_pv_merge()
thread_num = dataset.get_thread_num() thread_num = dataset.get_thread_num()
self.assertEqual(thread_num, 12) self.assertEqual(thread_num, 12)
...@@ -231,7 +232,7 @@ class TestDataset(unittest.TestCase): ...@@ -231,7 +232,7 @@ class TestDataset(unittest.TestCase):
dataset.set_pipe_command("cat") dataset.set_pipe_command("cat")
dataset.set_use_var(slots_vars) dataset.set_use_var(slots_vars)
dataset.load_into_memory() dataset.load_into_memory()
dataset.set_fea_eval(10000, True) dataset.set_fea_eval(1, True)
dataset.slots_shuffle(["slot1"]) dataset.slots_shuffle(["slot1"])
dataset.local_shuffle() dataset.local_shuffle()
dataset.set_generate_unique_feasigns(True, 15) dataset.set_generate_unique_feasigns(True, 15)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册