未验证 提交 e6b87b31 编写于 作者: H hutuxian 提交者: GitHub

Support AucRunner in PaddleBox (#22884)

* Support AucRunner in PaddleBox
* update some code style
上级 c417f991
......@@ -41,44 +41,44 @@ namespace paddle {
namespace framework {
void RecordCandidateList::ReSize(size_t length) {
_mutex.lock();
_capacity = length;
CHECK(_capacity > 0); // NOLINT
_candidate_list.clear();
_candidate_list.resize(_capacity);
_full = false;
_cur_size = 0;
_total_size = 0;
_mutex.unlock();
mutex_.lock();
capacity_ = length;
CHECK(capacity_ > 0); // NOLINT
candidate_list_.clear();
candidate_list_.resize(capacity_);
full_ = false;
cur_size_ = 0;
total_size_ = 0;
mutex_.unlock();
}
void RecordCandidateList::ReInit() {
_mutex.lock();
_full = false;
_cur_size = 0;
_total_size = 0;
_mutex.unlock();
mutex_.lock();
full_ = false;
cur_size_ = 0;
total_size_ = 0;
mutex_.unlock();
}
void RecordCandidateList::AddAndGet(const Record& record,
RecordCandidate* result) {
_mutex.lock();
mutex_.lock();
size_t index = 0;
++_total_size;
++total_size_;
auto fleet_ptr = FleetWrapper::GetInstance();
if (!_full) {
_candidate_list[_cur_size++] = record;
_full = (_cur_size == _capacity);
if (!full_) {
candidate_list_[cur_size_++] = record;
full_ = (cur_size_ == capacity_);
} else {
CHECK(_cur_size == _capacity);
index = fleet_ptr->LocalRandomEngine()() % _total_size;
if (index < _capacity) {
_candidate_list[index] = record;
CHECK(cur_size_ == capacity_);
index = fleet_ptr->LocalRandomEngine()() % total_size_;
if (index < capacity_) {
candidate_list_[index] = record;
}
}
index = fleet_ptr->LocalRandomEngine()() % _cur_size;
*result = _candidate_list[index];
_mutex.unlock();
index = fleet_ptr->LocalRandomEngine()() % cur_size_;
*result = candidate_list_[index];
mutex_.unlock();
}
void DataFeed::AddFeedVar(Variable* var, const std::string& name) {
......@@ -1452,7 +1452,11 @@ void PaddleBoxDataFeed::PutToFeedVec(const std::vector<PvInstance>& pv_vec) {
int PaddleBoxDataFeed::GetCurrentPhase() {
#ifdef PADDLE_WITH_BOX_PS
auto box_ptr = paddle::framework::BoxWrapper::GetInstance();
return box_ptr->PassFlag(); // join: 1, update: 0
if (box_ptr->Mode() == 1) { // For AucRunner
return 1;
} else {
return box_ptr->Phase();
}
#else
LOG(WARNING) << "It should be complied with BOX_PS...";
return current_phase_;
......
......@@ -27,6 +27,7 @@ limitations under the License. */
#include <string>
#include <thread> // NOLINT
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
......@@ -34,6 +35,7 @@ limitations under the License. */
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/variable.h"
......@@ -484,13 +486,25 @@ paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar,
struct RecordCandidate {
std::string ins_id_;
std::unordered_multimap<uint16_t, FeatureKey> feas;
std::unordered_multimap<uint16_t, FeatureKey> feas_;
size_t shadow_index_ = -1; // Optimization for Reservoir Sample
RecordCandidate() {}
RecordCandidate(const Record& rec,
const std::unordered_set<uint16_t>& slot_index_to_replace) {
for (const auto& fea : rec.uint64_feasigns_) {
if (slot_index_to_replace.find(fea.slot()) !=
slot_index_to_replace.end()) {
feas_.insert({fea.slot(), fea.sign()});
}
}
}
RecordCandidate& operator=(const Record& rec) {
feas.clear();
feas_.clear();
ins_id_ = rec.ins_id_;
for (auto& fea : rec.uint64_feasigns_) {
feas.insert({fea.slot(), fea.sign()});
feas_.insert({fea.slot(), fea.sign()});
}
return *this;
}
......@@ -499,22 +513,67 @@ struct RecordCandidate {
class RecordCandidateList {
public:
RecordCandidateList() = default;
RecordCandidateList(const RecordCandidateList&) = delete;
RecordCandidateList& operator=(const RecordCandidateList&) = delete;
RecordCandidateList(const RecordCandidateList&) {}
size_t Size() { return cur_size_; }
void ReSize(size_t length);
void ReInit();
void ReInitPass() {
for (size_t i = 0; i < cur_size_; ++i) {
if (candidate_list_[i].shadow_index_ != i) {
candidate_list_[i].ins_id_ =
candidate_list_[candidate_list_[i].shadow_index_].ins_id_;
candidate_list_[i].feas_.swap(
candidate_list_[candidate_list_[i].shadow_index_].feas_);
candidate_list_[i].shadow_index_ = i;
}
}
candidate_list_.resize(cur_size_);
}
void AddAndGet(const Record& record, RecordCandidate* result);
void AddAndGet(const Record& record, size_t& index_result) { // NOLINT
// std::unique_lock<std::mutex> lock(mutex_);
size_t index = 0;
++total_size_;
auto fleet_ptr = FleetWrapper::GetInstance();
if (!full_) {
candidate_list_.emplace_back(record, slot_index_to_replace_);
candidate_list_.back().shadow_index_ = cur_size_;
++cur_size_;
full_ = (cur_size_ == capacity_);
} else {
index = fleet_ptr->LocalRandomEngine()() % total_size_;
if (index < capacity_) {
candidate_list_.emplace_back(record, slot_index_to_replace_);
candidate_list_[index].shadow_index_ = candidate_list_.size() - 1;
}
}
index = fleet_ptr->LocalRandomEngine()() % cur_size_;
index_result = candidate_list_[index].shadow_index_;
}
const RecordCandidate& Get(size_t index) const {
PADDLE_ENFORCE_LT(
index, candidate_list_.size(),
platform::errors::OutOfRange("Your index [%lu] exceeds the number of "
"elements in candidate_list[%lu].",
index, candidate_list_.size()));
return candidate_list_[index];
}
void SetSlotIndexToReplace(
const std::unordered_set<uint16_t>& slot_index_to_replace) {
slot_index_to_replace_ = slot_index_to_replace;
}
private:
size_t _capacity = 0;
std::mutex _mutex;
bool _full = false;
size_t _cur_size = 0;
size_t _total_size = 0;
std::vector<RecordCandidate> _candidate_list;
size_t capacity_ = 0;
std::mutex mutex_;
bool full_ = false;
size_t cur_size_ = 0;
size_t total_size_ = 0;
std::vector<RecordCandidate> candidate_list_;
std::unordered_set<uint16_t> slot_index_to_replace_;
};
template <class AR>
......
......@@ -1141,13 +1141,15 @@ void MultiSlotDataset::MergeByInsId() {
VLOG(3) << "MultiSlotDataset::MergeByInsId end";
}
void MultiSlotDataset::GetRandomData(const std::set<uint16_t>& slots_to_replace,
void MultiSlotDataset::GetRandomData(
const std::unordered_set<uint16_t>& slots_to_replace,
std::vector<Record>* result) {
int debug_erase_cnt = 0;
int debug_push_cnt = 0;
auto multi_slot_desc = data_feed_desc_.multi_slot_desc();
slots_shuffle_rclist_.ReInit();
for (const auto& rec : slots_shuffle_original_data_) {
const auto& slots_shuffle_original_data = GetSlotsOriginalData();
for (const auto& rec : slots_shuffle_original_data) {
RecordCandidate rand_rec;
Record new_rec = rec;
slots_shuffle_rclist_.AddAndGet(rec, &rand_rec);
......@@ -1161,7 +1163,7 @@ void MultiSlotDataset::GetRandomData(const std::set<uint16_t>& slots_to_replace,
}
}
for (auto slot : slots_to_replace) {
auto range = rand_rec.feas.equal_range(slot);
auto range = rand_rec.feas_.equal_range(slot);
for (auto it = range.first; it != range.second; ++it) {
new_rec.uint64_feasigns_.push_back({it->second, it->first});
debug_push_cnt += 1;
......@@ -1173,9 +1175,9 @@ void MultiSlotDataset::GetRandomData(const std::set<uint16_t>& slots_to_replace,
<< " repush feasign num: " << debug_push_cnt;
}
// slots shuffle to input_channel_ with needed-shuffle slots
void MultiSlotDataset::SlotsShuffle(
const std::set<std::string>& slots_to_replace) {
void MultiSlotDataset::PreprocessChannel(
const std::set<std::string>& slots_to_replace,
std::unordered_set<uint16_t>& index_slots) { // NOLINT
int out_channel_size = 0;
if (cur_channel_ == 0) {
for (size_t i = 0; i < multi_output_channel_.size(); ++i) {
......@@ -1189,20 +1191,14 @@ void MultiSlotDataset::SlotsShuffle(
VLOG(2) << "DatasetImpl<T>::SlotsShuffle() begin with input channel size: "
<< input_channel_->Size()
<< " output channel size: " << out_channel_size;
if (!slots_shuffle_fea_eval_) {
VLOG(3) << "DatasetImpl<T>::SlotsShuffle() end,"
"fea eval mode off, need to set on for slots shuffle";
return;
}
if ((!input_channel_ || input_channel_->Size() == 0) &&
slots_shuffle_original_data_.size() == 0 && out_channel_size == 0) {
VLOG(3) << "DatasetImpl<T>::SlotsShuffle() end, no data to slots shuffle";
return;
}
platform::Timer timeline;
timeline.Start();
auto multi_slot_desc = data_feed_desc_.multi_slot_desc();
std::set<uint16_t> index_slots;
for (int i = 0; i < multi_slot_desc.slots_size(); ++i) {
std::string cur_slot = multi_slot_desc.slots(i).name();
if (slots_to_replace.find(cur_slot) != slots_to_replace.end()) {
......@@ -1287,6 +1283,19 @@ void MultiSlotDataset::SlotsShuffle(
}
CHECK(input_channel_->Size() == 0)
<< "input channel should be empty before slots shuffle";
}
// slots shuffle to input_channel_ with needed-shuffle slots
void MultiSlotDataset::SlotsShuffle(
const std::set<std::string>& slots_to_replace) {
PADDLE_ENFORCE_EQ(slots_shuffle_fea_eval_, true,
platform::errors::PreconditionNotMet(
"fea eval mode off, need to set on for slots shuffle"));
platform::Timer timeline;
timeline.Start();
std::unordered_set<uint16_t> index_slots;
PreprocessChannel(slots_to_replace, index_slots);
std::vector<Record> random_data;
random_data.clear();
// get slots shuffled random_data
......
......@@ -67,6 +67,7 @@ class Dataset {
virtual void SetParseContent(bool parse_content) = 0;
virtual void SetParseLogKey(bool parse_logkey) = 0;
virtual void SetEnablePvMerge(bool enable_pv_merge) = 0;
virtual bool EnablePvMerge() = 0;
virtual void SetMergeBySid(bool is_merge) = 0;
// set merge by ins id
virtual void SetMergeByInsId(int merge_size) = 0;
......@@ -108,10 +109,7 @@ class Dataset {
virtual void LocalShuffle() = 0;
// global shuffle data
virtual void GlobalShuffle(int thread_num = -1) = 0;
// for slots shuffle
virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace) = 0;
virtual void GetRandomData(const std::set<uint16_t>& slots_to_replace,
std::vector<Record>* result) = 0;
// create readers
virtual void CreateReaders() = 0;
// destroy readers
......@@ -183,6 +181,9 @@ class DatasetImpl : public Dataset {
virtual int GetThreadNum() { return thread_num_; }
virtual int GetTrainerNum() { return trainer_num_; }
virtual Channel<T> GetInputChannel() { return input_channel_; }
virtual void SetInputChannel(const Channel<T>& input_channel) {
input_channel_ = input_channel;
}
virtual int64_t GetFleetSendBatchSize() { return fleet_send_batch_size_; }
virtual std::pair<std::string, std::string> GetHdfsConfig() {
return std::make_pair(fs_name_, fs_ugi_);
......@@ -192,6 +193,7 @@ class DatasetImpl : public Dataset {
return data_feed_desc_;
}
virtual int GetChannelNum() { return channel_num_; }
virtual bool EnablePvMerge() { return enable_pv_merge_; }
virtual std::vector<paddle::framework::DataFeed*> GetReaders();
virtual void CreateChannel();
virtual void RegisterClientToClientMsgHandler();
......@@ -202,8 +204,9 @@ class DatasetImpl : public Dataset {
virtual void LocalShuffle();
virtual void GlobalShuffle(int thread_num = -1);
virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace) {}
virtual void GetRandomData(const std::set<uint16_t>& slots_to_replace,
std::vector<Record>* result) {}
virtual const std::vector<T>& GetSlotsOriginalData() {
return slots_shuffle_original_data_;
}
virtual void CreateReaders();
virtual void DestroyReaders();
virtual int64_t GetMemoryDataSize();
......@@ -293,8 +296,12 @@ class MultiSlotDataset : public DatasetImpl<Record> {
}
std::vector<std::unordered_set<uint64_t>>().swap(local_tables_);
}
virtual void PreprocessChannel(
const std::set<std::string>& slots_to_replace,
std::unordered_set<uint16_t>& index_slot); // NOLINT
virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace);
virtual void GetRandomData(const std::set<uint16_t>& slots_to_replace,
virtual void GetRandomData(
const std::unordered_set<uint16_t>& slots_to_replace,
std::vector<Record>* result);
virtual ~MultiSlotDataset() {}
};
......
......@@ -255,6 +255,113 @@ void BoxWrapper::PushSparseGrad(const paddle::platform::Place& place,
<< " s";
VLOG(3) << "End PushSparseGrad";
}
void BoxWrapper::GetRandomReplace(const std::vector<Record>& pass_data) {
VLOG(0) << "Begin GetRandomReplace";
size_t ins_num = pass_data.size();
replace_idx_.resize(ins_num);
for (auto& cand_list : random_ins_pool_list) {
cand_list.ReInitPass();
}
std::vector<std::thread> threads;
for (int tid = 0; tid < auc_runner_thread_num_; ++tid) {
threads.push_back(std::thread([this, &pass_data, tid, ins_num]() {
int start = tid * ins_num / auc_runner_thread_num_;
int end = (tid + 1) * ins_num / auc_runner_thread_num_;
VLOG(3) << "GetRandomReplace begin for thread[" << tid
<< "], and process [" << start << ", " << end
<< "), total ins: " << ins_num;
auto& random_pool = random_ins_pool_list[tid];
for (int i = start; i < end; ++i) {
const auto& ins = pass_data[i];
random_pool.AddAndGet(ins, replace_idx_[i]);
}
}));
}
for (int tid = 0; tid < auc_runner_thread_num_; ++tid) {
threads[tid].join();
}
pass_done_semi_->Put(1);
VLOG(0) << "End GetRandomReplace";
}
void BoxWrapper::GetRandomData(
const std::vector<Record>& pass_data,
const std::unordered_set<uint16_t>& slots_to_replace,
std::vector<Record>* result) {
VLOG(0) << "Begin GetRandomData";
std::vector<std::thread> threads;
for (int tid = 0; tid < auc_runner_thread_num_; ++tid) {
threads.push_back(std::thread([this, &pass_data, tid, &slots_to_replace,
result]() {
int debug_erase_cnt = 0;
int debug_push_cnt = 0;
size_t ins_num = pass_data.size();
int start = tid * ins_num / auc_runner_thread_num_;
int end = (tid + 1) * ins_num / auc_runner_thread_num_;
VLOG(3) << "GetRandomData begin for thread[" << tid << "], and process ["
<< start << ", " << end << "), total ins: " << ins_num;
const auto& random_pool = random_ins_pool_list[tid];
for (int i = start; i < end; ++i) {
const auto& ins = pass_data[i];
const RecordCandidate& rand_rec = random_pool.Get(replace_idx_[i]);
Record new_rec = ins;
for (auto it = new_rec.uint64_feasigns_.begin();
it != new_rec.uint64_feasigns_.end();) {
if (slots_to_replace.find(it->slot()) != slots_to_replace.end()) {
it = new_rec.uint64_feasigns_.erase(it);
debug_erase_cnt += 1;
} else {
++it;
}
}
for (auto slot : slots_to_replace) {
auto range = rand_rec.feas_.equal_range(slot);
for (auto it = range.first; it != range.second; ++it) {
new_rec.uint64_feasigns_.push_back({it->second, it->first});
debug_push_cnt += 1;
}
}
(*result)[i] = std::move(new_rec);
}
VLOG(3) << "thread[" << tid << "]: erase feasign num: " << debug_erase_cnt
<< " repush feasign num: " << debug_push_cnt;
}));
}
for (int tid = 0; tid < auc_runner_thread_num_; ++tid) {
threads[tid].join();
}
VLOG(0) << "End GetRandomData";
}
void BoxWrapper::AddReplaceFeasign(boxps::PSAgentBase* p_agent,
int feed_pass_thread_num) {
VLOG(0) << "Enter AddReplaceFeasign Function";
int semi;
pass_done_semi_->Get(semi);
VLOG(0) << "Last Pass had updated random pool done. Begin AddReplaceFeasign";
std::vector<std::thread> threads;
for (int tid = 0; tid < feed_pass_thread_num; ++tid) {
threads.push_back(std::thread([this, tid, p_agent, feed_pass_thread_num]() {
VLOG(3) << "AddReplaceFeasign begin for thread[" << tid << "]";
for (size_t pool_id = tid; pool_id < random_ins_pool_list.size();
pool_id += feed_pass_thread_num) {
auto& random_pool = random_ins_pool_list[pool_id];
for (size_t i = 0; i < random_pool.Size(); ++i) {
auto& ins_candidate = random_pool.Get(i);
for (const auto& pair : ins_candidate.feas_) {
p_agent->AddKey(pair.second.uint64_feasign_, tid);
}
}
}
}));
}
for (int tid = 0; tid < feed_pass_thread_num; ++tid) {
threads[tid].join();
}
VLOG(0) << "End AddReplaceFeasign";
}
} // end namespace framework
} // end namespace paddle
#endif
......@@ -31,10 +31,12 @@ limitations under the License. */
#include <map>
#include <memory>
#include <mutex> // NOLINT
#include <set>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
......@@ -469,16 +471,16 @@ class BoxWrapper {
public:
MetricMsg() {}
MetricMsg(const std::string& label_varname, const std::string& pred_varname,
int is_join, int bucket_size = 1000000)
int metric_phase, int bucket_size = 1000000)
: label_varname_(label_varname),
pred_varname_(pred_varname),
is_join_(is_join) {
metric_phase_(metric_phase) {
calculator = new BasicAucCalculator();
calculator->init(bucket_size);
}
virtual ~MetricMsg() {}
int IsJoin() const { return is_join_; }
int MetricPhase() const { return metric_phase_; }
BasicAucCalculator* GetCalculator() { return calculator; }
virtual void add_data(const Scope* exe_scope) {
std::vector<int64_t> label_data;
......@@ -514,20 +516,20 @@ class BoxWrapper {
protected:
std::string label_varname_;
std::string pred_varname_;
int is_join_;
int metric_phase_;
BasicAucCalculator* calculator;
};
class MultiTaskMetricMsg : public MetricMsg {
public:
MultiTaskMetricMsg(const std::string& label_varname,
const std::string& pred_varname_list, int is_join,
const std::string& pred_varname_list, int metric_phase,
const std::string& cmatch_rank_group,
const std::string& cmatch_rank_varname,
int bucket_size = 1000000) {
label_varname_ = label_varname;
cmatch_rank_varname_ = cmatch_rank_varname;
is_join_ = is_join;
metric_phase_ = metric_phase;
calculator = new BasicAucCalculator();
calculator->init(bucket_size);
for (auto& cmatch_rank : string::split_string(cmatch_rank_group)) {
......@@ -594,14 +596,14 @@ class BoxWrapper {
class CmatchRankMetricMsg : public MetricMsg {
public:
CmatchRankMetricMsg(const std::string& label_varname,
const std::string& pred_varname, int is_join,
const std::string& pred_varname, int metric_phase,
const std::string& cmatch_rank_group,
const std::string& cmatch_rank_varname,
int bucket_size = 1000000) {
label_varname_ = label_varname;
pred_varname_ = pred_varname;
cmatch_rank_varname_ = cmatch_rank_varname;
is_join_ = is_join;
metric_phase_ = metric_phase;
calculator = new BasicAucCalculator();
calculator->init(bucket_size);
for (auto& cmatch_rank : string::split_string(cmatch_rank_group)) {
......@@ -653,12 +655,12 @@ class BoxWrapper {
class MaskMetricMsg : public MetricMsg {
public:
MaskMetricMsg(const std::string& label_varname,
const std::string& pred_varname, int is_join,
const std::string& pred_varname, int metric_phase,
const std::string& mask_varname, int bucket_size = 1000000) {
label_varname_ = label_varname;
pred_varname_ = pred_varname;
mask_varname_ = mask_varname;
is_join_ = is_join;
metric_phase_ = metric_phase;
calculator = new BasicAucCalculator();
calculator->init(bucket_size);
}
......@@ -682,36 +684,59 @@ class BoxWrapper {
protected:
std::string mask_varname_;
};
const std::vector<std::string>& GetMetricNameList() const {
const std::vector<std::string> GetMetricNameList(
int metric_phase = -1) const {
VLOG(0) << "Want to Get metric phase: " << metric_phase;
if (metric_phase == -1) {
return metric_name_list_;
} else {
std::vector<std::string> ret;
for (const auto& name : metric_name_list_) {
const auto iter = metric_lists_.find(name);
PADDLE_ENFORCE_NE(
iter, metric_lists_.end(),
platform::errors::InvalidArgument(
"The metric name you provided is not registered."));
if (iter->second->MetricPhase() == metric_phase) {
VLOG(0) << name << "'s phase is " << iter->second->MetricPhase()
<< ", we want";
ret.push_back(name);
} else {
VLOG(0) << name << "'s phase is " << iter->second->MetricPhase()
<< ", not we want";
}
int PassFlag() const { return pass_flag_; }
void FlipPassFlag() { pass_flag_ = 1 - pass_flag_; }
}
return ret;
}
}
int Phase() const { return phase_; }
void FlipPhase() { phase_ = (phase_ + 1) % phase_num_; }
std::map<std::string, MetricMsg*>& GetMetricList() { return metric_lists_; }
void InitMetric(const std::string& method, const std::string& name,
const std::string& label_varname,
const std::string& pred_varname,
const std::string& cmatch_rank_varname,
const std::string& mask_varname, bool is_join,
const std::string& mask_varname, int metric_phase,
const std::string& cmatch_rank_group,
int bucket_size = 1000000) {
if (method == "AucCalculator") {
metric_lists_.emplace(name, new MetricMsg(label_varname, pred_varname,
is_join ? 1 : 0, bucket_size));
metric_phase, bucket_size));
} else if (method == "MultiTaskAucCalculator") {
metric_lists_.emplace(
name, new MultiTaskMetricMsg(label_varname, pred_varname,
is_join ? 1 : 0, cmatch_rank_group,
metric_phase, cmatch_rank_group,
cmatch_rank_varname, bucket_size));
} else if (method == "CmatchRankAucCalculator") {
metric_lists_.emplace(
name, new CmatchRankMetricMsg(label_varname, pred_varname,
is_join ? 1 : 0, cmatch_rank_group,
metric_phase, cmatch_rank_group,
cmatch_rank_varname, bucket_size));
} else if (method == "MaskAucCalculator") {
metric_lists_.emplace(
name, new MaskMetricMsg(label_varname, pred_varname, is_join ? 1 : 0,
name, new MaskMetricMsg(label_varname, pred_varname, metric_phase,
mask_varname, bucket_size));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
......@@ -753,7 +778,8 @@ class BoxWrapper {
std::unordered_set<std::string> slot_name_omited_in_feedpass_;
// Metric Related
int pass_flag_ = 1; // join: 1, update: 0
int phase_ = 1;
int phase_num_ = 2;
std::map<std::string, MetricMsg*> metric_lists_;
std::vector<std::string> metric_name_list_;
std::vector<int> slot_vector_;
......@@ -762,6 +788,57 @@ class BoxWrapper {
public:
static AfsManager* afs_manager;
// Auc Runner
public:
void InitializeAucRunner(std::vector<std::vector<std::string>> slot_eval,
int thread_num, int pool_size,
std::vector<std::string> slot_list) {
mode_ = 1;
phase_num_ = static_cast<int>(slot_eval.size());
phase_ = phase_num_ - 1;
auc_runner_thread_num_ = thread_num;
pass_done_semi_ = paddle::framework::MakeChannel<int>();
pass_done_semi_->Put(1); // Note: At most 1 pipeline in AucRunner
random_ins_pool_list.resize(thread_num);
std::unordered_set<std::string> slot_set;
for (size_t i = 0; i < slot_eval.size(); ++i) {
for (const auto& slot : slot_eval[i]) {
slot_set.insert(slot);
}
}
for (size_t i = 0; i < slot_list.size(); ++i) {
if (slot_set.find(slot_list[i]) != slot_set.end()) {
slot_index_to_replace_.insert(static_cast<int16_t>(i));
}
}
for (int i = 0; i < auc_runner_thread_num_; ++i) {
random_ins_pool_list[i].SetSlotIndexToReplace(slot_index_to_replace_);
}
VLOG(0) << "AucRunner configuration: thread number[" << thread_num
<< "], pool size[" << pool_size << "], runner_group[" << phase_num_
<< "]";
VLOG(0) << "Slots that need to be evaluated:";
for (auto e : slot_index_to_replace_) {
VLOG(0) << e << ": " << slot_list[e];
}
}
void GetRandomReplace(const std::vector<Record>& pass_data);
void AddReplaceFeasign(boxps::PSAgentBase* p_agent, int feed_pass_thread_num);
void GetRandomData(const std::vector<Record>& pass_data,
const std::unordered_set<uint16_t>& slots_to_replace,
std::vector<Record>* result);
int Mode() const { return mode_; }
private:
int mode_ = 0; // 0 means train/test 1 means auc_runner
int auc_runner_thread_num_ = 1;
bool init_done_ = false;
paddle::framework::Channel<int> pass_done_semi_;
std::unordered_set<uint16_t> slot_index_to_replace_;
std::vector<RecordCandidateList> random_ins_pool_list;
std::vector<size_t> replace_idx_;
};
#endif
......@@ -810,7 +887,38 @@ class BoxHelper {
VLOG(3) << "After PreLoadIntoMemory()";
}
void WaitFeedPassDone() { feed_data_thread_->join(); }
void SlotsShuffle(const std::set<std::string>& slots_to_replace) {
#ifdef PADDLE_WITH_BOX_PS
auto box_ptr = BoxWrapper::GetInstance();
PADDLE_ENFORCE_EQ(box_ptr->Mode(), 1,
platform::errors::PreconditionNotMet(
"Should call InitForAucRunner first."));
box_ptr->FlipPhase();
std::unordered_set<uint16_t> index_slots;
dynamic_cast<MultiSlotDataset*>(dataset_)->PreprocessChannel(
slots_to_replace, index_slots);
const std::vector<Record>& pass_data =
dynamic_cast<MultiSlotDataset*>(dataset_)->GetSlotsOriginalData();
if (!get_random_replace_done_) {
box_ptr->GetRandomReplace(pass_data);
get_random_replace_done_ = true;
}
std::vector<Record> random_data;
random_data.resize(pass_data.size());
box_ptr->GetRandomData(pass_data, index_slots, &random_data);
auto new_input_channel = paddle::framework::MakeChannel<Record>();
new_input_channel->Open();
new_input_channel->Write(std::move(random_data));
new_input_channel->Close();
dynamic_cast<MultiSlotDataset*>(dataset_)->SetInputChannel(
new_input_channel);
if (dataset_->EnablePvMerge()) {
dataset_->PreprocessInstance();
}
#endif
}
#ifdef PADDLE_WITH_BOX_PS
// notify boxps to feed this pass feasigns from SSD to memory
static void FeedPassThread(const std::deque<Record>& t, int begin_index,
......@@ -881,6 +989,10 @@ class BoxHelper {
for (size_t i = 0; i < tnum; ++i) {
threads[i].join();
}
if (box_ptr->Mode() == 1) {
box_ptr->AddReplaceFeasign(p_agent, tnum);
}
VLOG(3) << "Begin call EndFeedPass in BoxPS";
box_ptr->EndFeedPass(p_agent);
#endif
......@@ -892,6 +1004,7 @@ class BoxHelper {
int year_;
int month_;
int day_;
bool get_random_replace_done_ = false;
};
} // end namespace framework
......
......@@ -211,7 +211,7 @@ void SectionWorker::TrainFiles() {
auto& metric_list = box_ptr->GetMetricList();
for (auto iter = metric_list.begin(); iter != metric_list.end(); iter++) {
auto* metric_msg = iter->second;
if (metric_msg->IsJoin() != box_ptr->PassFlag()) {
if (box_ptr->Phase() != metric_msg->MetricPhase()) {
continue;
}
metric_msg->add_data(exe_scope);
......@@ -367,7 +367,7 @@ void SectionWorker::TrainFilesWithProfiler() {
auto& metric_list = box_ptr->GetMetricList();
for (auto iter = metric_list.begin(); iter != metric_list.end(); iter++) {
auto* metric_msg = iter->second;
if (metric_msg->IsJoin() != box_ptr->PassFlag()) {
if (box_ptr->Phase() != metric_msg->MetricPhase()) {
continue;
}
metric_msg->add_data(exe_scope);
......
......@@ -54,6 +54,8 @@ void BindBoxHelper(py::module* m) {
.def("preload_into_memory", &framework::BoxHelper::PreLoadIntoMemory,
py::call_guard<py::gil_scoped_release>())
.def("load_into_memory", &framework::BoxHelper::LoadIntoMemory,
py::call_guard<py::gil_scoped_release>())
.def("slots_shuffle", &framework::BoxHelper::SlotsShuffle,
py::call_guard<py::gil_scoped_release>());
} // end BoxHelper
......@@ -76,13 +78,15 @@ void BindBoxWrapper(py::module* m) {
.def("initialize_gpu_and_load_model",
&framework::BoxWrapper::InitializeGPUAndLoadModel,
py::call_guard<py::gil_scoped_release>())
.def("initialize_auc_runner", &framework::BoxWrapper::InitializeAucRunner,
py::call_guard<py::gil_scoped_release>())
.def("init_metric", &framework::BoxWrapper::InitMetric,
py::call_guard<py::gil_scoped_release>())
.def("get_metric_msg", &framework::BoxWrapper::GetMetricMsg,
py::call_guard<py::gil_scoped_release>())
.def("get_metric_name_list", &framework::BoxWrapper::GetMetricNameList,
py::call_guard<py::gil_scoped_release>())
.def("flip_pass_flag", &framework::BoxWrapper::FlipPassFlag,
.def("flip_phase", &framework::BoxWrapper::FlipPhase,
py::call_guard<py::gil_scoped_release>())
.def("init_afs_api", &framework::BoxWrapper::InitAfsAPI,
py::call_guard<py::gil_scoped_release>())
......
......@@ -291,6 +291,8 @@ void BindDataset(py::module *m) {
py::call_guard<py::gil_scoped_release>())
.def("set_fleet_send_sleep_seconds",
&framework::Dataset::SetFleetSendSleepSeconds,
py::call_guard<py::gil_scoped_release>())
.def("enable_pv_merge", &framework::Dataset::EnablePvMerge,
py::call_guard<py::gil_scoped_release>());
py::class_<IterableDatasetWrapper>(*m, "IterableDatasetWrapper")
......
......@@ -1079,3 +1079,24 @@ class BoxPSDataset(InMemoryDataset):
def _dynamic_adjust_after_train(self):
pass
def slots_shuffle(self, slots):
"""
Slots Shuffle
Slots Shuffle is a shuffle method in slots level, which is usually used
in sparse feature with large scale of instances. To compare the metric, i.e.
auc while doing slots shuffle on one or several slots with baseline to
evaluate the importance level of slots(features).
Args:
slots(list[string]): the set of slots(string) to do slots shuffle.
Examples:
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_merge_by_lineid()
#suppose there is a slot 0
dataset.slots_shuffle(['0'])
"""
slots_set = set(slots)
self.boxps.slots_shuffle(slots_set)
......@@ -172,6 +172,7 @@ class TestBoxPSPreload(unittest.TestCase):
exe.run(fluid.default_startup_program())
datasets[0].load_into_memory()
datasets[0].begin_pass()
datasets[0].slots_shuffle([])
datasets[1].preload_into_memory()
exe.train_from_dataset(
program=fluid.default_main_program(),
......
......@@ -125,6 +125,7 @@ class TestDataset(unittest.TestCase):
dataset.set_trainer_num(4)
dataset.set_hdfs_config("my_fs_name", "my_fs_ugi")
dataset.set_download_cmd("./read_from_afs my_fs_name my_fs_ugi")
dataset.enable_pv_merge()
thread_num = dataset.get_thread_num()
self.assertEqual(thread_num, 12)
......@@ -231,7 +232,7 @@ class TestDataset(unittest.TestCase):
dataset.set_pipe_command("cat")
dataset.set_use_var(slots_vars)
dataset.load_into_memory()
dataset.set_fea_eval(10000, True)
dataset.set_fea_eval(1, True)
dataset.slots_shuffle(["slot1"])
dataset.local_shuffle()
dataset.set_generate_unique_feasigns(True, 15)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册