未验证 提交 3db61dc0 编写于 作者: X xujiaqi01 提交者: GitHub

cherry-pick1.6 simplify master+patch,remove ins when size != merge_size or has...

cherry-pick1.6 simplify master+patch,remove ins when size != merge_size or has conflict slot  (#20941)

* simplify master+patch,remove ins when size != merge_size or has conflict slot
* test=develop
上级 5c3656bb
......@@ -16,6 +16,7 @@
#include <algorithm>
#include <random>
#include <unordered_map>
#include <unordered_set>
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
......@@ -45,9 +46,7 @@ DatasetImpl<T>::DatasetImpl() {
fleet_send_batch_size_ = 1024;
fleet_send_sleep_seconds_ = 0;
merge_by_insid_ = false;
erase_duplicate_feas_ = true;
keep_unmerged_ins_ = true;
min_merge_size_ = 2;
merge_size_ = 2;
parse_ins_id_ = false;
parse_content_ = false;
preload_thread_num_ = 0;
......@@ -118,15 +117,10 @@ void DatasetImpl<T>::SetParseContent(bool parse_content) {
}
template <typename T>
void DatasetImpl<T>::SetMergeByInsId(
const std::vector<std::string>& merge_slot_list, bool erase_duplicate_feas,
int min_merge_size, bool keep_unmerged_ins) {
void DatasetImpl<T>::SetMergeByInsId(int merge_size) {
merge_by_insid_ = true;
parse_ins_id_ = true;
merge_slots_list_ = merge_slot_list;
erase_duplicate_feas_ = erase_duplicate_feas;
min_merge_size_ = min_merge_size;
keep_unmerged_ins_ = keep_unmerged_ins;
merge_size_ = merge_size;
}
template <typename T>
......@@ -643,22 +637,11 @@ void MultiSlotDataset::MergeByInsId() {
return;
}
auto multi_slot_desc = data_feed_desc_.multi_slot_desc();
std::unordered_map<int, bool> merge_slots;
std::vector<std::string> use_slots;
std::vector<bool> use_slots_is_dense;
for (size_t i = 0; i < multi_slot_desc.slots_size(); ++i) {
const auto& slot = multi_slot_desc.slots(i);
if (slot.is_used()) {
use_slots.push_back(slot.name());
use_slots_is_dense.push_back(slot.is_dense());
}
}
for (size_t i = 0; i < use_slots.size(); ++i) {
// currently, we don't merge dense slots
if (std::find(merge_slots_list_.begin(), merge_slots_list_.end(),
use_slots[i]) != merge_slots_list_.end() &&
!use_slots_is_dense[i]) {
merge_slots[i] = true;
}
}
CHECK(multi_output_channel_.size() != 0); // NOLINT
......@@ -682,134 +665,82 @@ void MultiSlotDataset::MergeByInsId() {
return a.ins_id_ < b.ins_id_;
});
auto sort_cmp_uint64 = [&merge_slots](const FeatureItem& a,
const FeatureItem& b) {
auto& a_sign = a.sign().uint64_feasign_;
auto& b_sign = b.sign().uint64_feasign_;
return a_sign < b_sign || (a_sign == b_sign && a.slot() < b.slot());
};
auto sort_cmp_float = [&merge_slots](const FeatureItem& a,
const FeatureItem& b) {
auto& a_sign = a.sign().float_feasign_;
auto& b_sign = b.sign().float_feasign_;
return a_sign < b_sign || (a_sign == b_sign && a.slot() < b.slot());
};
auto unique_eq_uint64 = [&merge_slots](const FeatureItem& a,
const FeatureItem& b) {
if (a.slot() == b.slot() &&
merge_slots.find(a.slot()) == merge_slots.end()) {
return true;
}
auto& a_sign = a.sign().uint64_feasign_;
auto& b_sign = b.sign().uint64_feasign_;
return a_sign == b_sign && a.slot() == b.slot();
};
auto unique_eq_float = [&merge_slots](const FeatureItem& a,
const FeatureItem& b) {
if (a.slot() == b.slot() &&
merge_slots.find(a.slot()) == merge_slots.end()) {
return true;
}
auto& a_sign = a.sign().float_feasign_;
auto& b_sign = b.sign().float_feasign_;
return a_sign == b_sign && a.slot() == b.slot();
};
std::vector<Record> results;
uint64_t drop_ins_num = 0;
std::unordered_set<uint16_t> all_int64;
std::unordered_set<uint16_t> all_float;
std::unordered_set<uint16_t> local_uint64;
std::unordered_set<uint16_t> local_float;
VLOG(3) << "recs.size() " << recs.size();
for (size_t i = 0; i < recs.size();) {
size_t j = i + 1;
while (j < recs.size() && recs[j].ins_id_ == recs[i].ins_id_) {
j++;
}
if (j - i < min_merge_size_) {
if (keep_unmerged_ins_) {
for (size_t k = i; k < j; ++k) {
results.push_back(std::move(recs[k]));
}
}
if (merge_size_ > 0 && j - i != merge_size_) {
drop_ins_num += j - i;
LOG(WARNING) << "drop ins " << recs[i].ins_id_ << " size=" << j - i
<< ", because merge_size=" << merge_size_;
i = j;
continue;
}
std::vector<FeatureItem> merge_uint64_feasigns;
std::vector<FeatureItem> merge_float_feasigns;
Record rec = std::move(recs[i]);
all_int64.clear();
all_float.clear();
bool has_conflict_slot = false;
uint16_t conflict_slot = 0;
Record rec;
rec.ins_id_ = recs[i].ins_id_;
rec.content_ = recs[i].content_;
for (size_t k = i + 1; k < j; k++) {
for (size_t k = i; k < j; k++) {
local_uint64.clear();
local_float.clear();
for (auto& feature : recs[k].uint64_feasigns_) {
if (merge_slots.find(feature.slot()) != merge_slots.end()) {
merge_uint64_feasigns.push_back(std::move(feature));
uint16_t slot = feature.slot();
if (all_int64.find(slot) != all_int64.end()) {
has_conflict_slot = true;
conflict_slot = slot;
break;
}
local_uint64.insert(slot);
rec.uint64_feasigns_.push_back(std::move(feature));
}
if (has_conflict_slot) {
break;
}
all_int64.insert(local_uint64.begin(), local_uint64.end());
for (auto& feature : recs[k].float_feasigns_) {
if (merge_slots.find(feature.slot()) != merge_slots.end()) {
merge_float_feasigns.push_back(std::move(feature));
uint16_t slot = feature.slot();
if (all_float.find(slot) != all_float.end()) {
has_conflict_slot = true;
conflict_slot = slot;
break;
}
local_float.insert(slot);
rec.float_feasigns_.push_back(std::move(feature));
}
if (has_conflict_slot) {
break;
}
recs[k] = Record();
all_float.insert(local_float.begin(), local_float.end());
}
i = j;
if (!erase_duplicate_feas_) {
rec.uint64_feasigns_.insert(rec.uint64_feasigns_.end(),
merge_uint64_feasigns.begin(),
merge_uint64_feasigns.end());
rec.float_feasigns_.insert(rec.float_feasigns_.end(),
merge_float_feasigns.begin(),
merge_float_feasigns.end());
if (has_conflict_slot) {
LOG(WARNING) << "drop ins " << recs[i].ins_id_ << " size=" << j - i
<< ", because conflict_slot=" << use_slots[conflict_slot];
drop_ins_num += j - i;
} else {
std::vector<FeatureItem> not_merge_uint64_feasigns;
std::vector<FeatureItem> not_merge_float_feasigns;
for (auto& feature : rec.uint64_feasigns_) {
if (merge_slots.find(feature.slot()) != merge_slots.end()) {
merge_uint64_feasigns.push_back(std::move(feature));
} else {
not_merge_uint64_feasigns.push_back(std::move(feature));
}
}
for (auto& feature : rec.float_feasigns_) {
if (merge_slots.find(feature.slot()) != merge_slots.end()) {
merge_float_feasigns.push_back(std::move(feature));
} else {
not_merge_float_feasigns.push_back(std::move(feature));
}
}
rec.uint64_feasigns_.clear();
rec.float_feasigns_.clear();
// erase duplicate uint64 feasigns
std::sort(merge_uint64_feasigns.begin(), merge_uint64_feasigns.end(),
sort_cmp_uint64);
merge_uint64_feasigns.erase(
std::unique(merge_uint64_feasigns.begin(),
merge_uint64_feasigns.end(), unique_eq_uint64),
merge_uint64_feasigns.end());
rec.uint64_feasigns_.insert(rec.uint64_feasigns_.end(),
merge_uint64_feasigns.begin(),
merge_uint64_feasigns.end());
rec.uint64_feasigns_.insert(rec.uint64_feasigns_.end(),
not_merge_uint64_feasigns.begin(),
not_merge_uint64_feasigns.end());
// erase duplicate float feasigns
std::sort(merge_float_feasigns.begin(), merge_float_feasigns.end(),
sort_cmp_float);
merge_float_feasigns.erase(
std::unique(merge_float_feasigns.begin(), merge_float_feasigns.end(),
unique_eq_float),
merge_float_feasigns.end());
rec.float_feasigns_.insert(rec.float_feasigns_.end(),
merge_float_feasigns.begin(),
merge_float_feasigns.end());
rec.float_feasigns_.insert(rec.float_feasigns_.end(),
not_merge_float_feasigns.begin(),
not_merge_float_feasigns.end());
results.push_back(std::move(rec));
}
results.push_back(rec);
i = j;
}
std::vector<Record>().swap(recs);
VLOG(3) << "results size " << results.size();
LOG(WARNING) << "total drop ins num: " << drop_ins_num;
results.shrink_to_fit();
auto fleet_ptr = FleetWrapper::GetInstance();
......
......@@ -62,9 +62,7 @@ class Dataset {
virtual void SetParseInsId(bool parse_ins_id) = 0;
virtual void SetParseContent(bool parse_content) = 0;
// set merge by ins id
virtual void SetMergeByInsId(const std::vector<std::string>& merge_slot_list,
bool erase_duplicate_feas, int min_merge_size,
bool keep_unmerged_ins) = 0;
virtual void SetMergeByInsId(int merge_size) = 0;
// set fea eval mode
virtual void SetFeaEval(bool fea_eval, int record_candidate_size) = 0;
// get file list
......@@ -149,9 +147,7 @@ class DatasetImpl : public Dataset {
virtual void SetChannelNum(int channel_num);
virtual void SetParseInsId(bool parse_ins_id);
virtual void SetParseContent(bool parse_content);
virtual void SetMergeByInsId(const std::vector<std::string>& merge_slot_list,
bool erase_duplicate_feas, int min_merge_size,
bool keep_unmerged_ins);
virtual void SetMergeByInsId(int merge_size);
virtual void SetFeaEval(bool fea_eval, int record_candidate_size);
virtual const std::vector<std::string>& GetFileList() { return filelist_; }
......@@ -219,10 +215,7 @@ class DatasetImpl : public Dataset {
bool merge_by_insid_;
bool parse_ins_id_;
bool parse_content_;
bool erase_duplicate_feas_;
bool keep_unmerged_ins_;
int min_merge_size_;
std::vector<std::string> merge_slots_list_;
int merge_size_;
bool slots_shuffle_fea_eval_ = false;
int preload_thread_num_;
std::mutex global_index_mutex_;
......
......@@ -408,26 +408,13 @@ class InMemoryDataset(DatasetBase):
"""
self.fleet_send_sleep_seconds = fleet_send_sleep_seconds
def set_merge_by_lineid(self,
var_list,
erase_duplicate_feas=True,
min_merge_size=2,
keep_unmerged_ins=True):
def set_merge_by_lineid(self, merge_size=2):
"""
Set merge by line id, instances of same line id will be merged after
shuffle, you should parse line id in data generator.
Args:
var_list(list): slots that can be merge. each element in var_list
is Variable. some slots such as show and click, we
usually don't merge them for same line id, so user
should specify which slot can be merged.
erase_duplicate_feas(bool): whether erase duplicate feasigns when
merge. default is True.
min_merge_size(int): minimal size to merge. default is 2.
keep_unmerged_ins(bool): whether to keep unmerged ins, such as
ins with unique id or the num of ins with
same id is less than min_merge_size.
merge_size(int): ins size to merge. default is 2.
Examples:
.. code-block:: python
......@@ -437,10 +424,9 @@ class InMemoryDataset(DatasetBase):
dataset.set_merge_by_lineid()
"""
var_name_list = [i.name for i in var_list]
self.dataset.set_merge_by_lineid(var_name_list, erase_duplicate_feas,
min_merge_size, keep_unmerged_ins)
self.dataset.set_merge_by_lineid(merge_size)
self.merge_by_lineid = True
self.parse_ins_id = True
def load_into_memory(self):
"""
......
......@@ -272,7 +272,8 @@ class TestDataset(unittest.TestCase):
except Exception as e:
self.assertTrue(False)
dataset.set_merge_by_lineid(slots_vars)
dataset.set_merge_by_lineid(2)
dataset.set_parse_ins_id(False)
dataset.set_fleet_send_sleep_seconds(2)
dataset.preload_into_memory()
dataset.wait_preload_done()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册