未验证 提交 3db61dc0 编写于 作者: X xujiaqi01 提交者: GitHub

cherry-pick1.6 simplify master+patch,remove ins when size != merge_size or has...

cherry-pick1.6 simplify master+patch,remove ins when size != merge_size or has conflict slot  (#20941)

* simplify master+patch,remove ins when size != merge_size or has conflict slot
* test=develop
上级 5c3656bb
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <algorithm> #include <algorithm>
#include <random> #include <random>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include "google/protobuf/io/zero_copy_stream_impl.h" #include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h" #include "google/protobuf/message.h"
#include "google/protobuf/text_format.h" #include "google/protobuf/text_format.h"
...@@ -45,9 +46,7 @@ DatasetImpl<T>::DatasetImpl() { ...@@ -45,9 +46,7 @@ DatasetImpl<T>::DatasetImpl() {
fleet_send_batch_size_ = 1024; fleet_send_batch_size_ = 1024;
fleet_send_sleep_seconds_ = 0; fleet_send_sleep_seconds_ = 0;
merge_by_insid_ = false; merge_by_insid_ = false;
erase_duplicate_feas_ = true; merge_size_ = 2;
keep_unmerged_ins_ = true;
min_merge_size_ = 2;
parse_ins_id_ = false; parse_ins_id_ = false;
parse_content_ = false; parse_content_ = false;
preload_thread_num_ = 0; preload_thread_num_ = 0;
...@@ -118,15 +117,10 @@ void DatasetImpl<T>::SetParseContent(bool parse_content) { ...@@ -118,15 +117,10 @@ void DatasetImpl<T>::SetParseContent(bool parse_content) {
} }
template <typename T> template <typename T>
void DatasetImpl<T>::SetMergeByInsId( void DatasetImpl<T>::SetMergeByInsId(int merge_size) {
const std::vector<std::string>& merge_slot_list, bool erase_duplicate_feas,
int min_merge_size, bool keep_unmerged_ins) {
merge_by_insid_ = true; merge_by_insid_ = true;
parse_ins_id_ = true; parse_ins_id_ = true;
merge_slots_list_ = merge_slot_list; merge_size_ = merge_size;
erase_duplicate_feas_ = erase_duplicate_feas;
min_merge_size_ = min_merge_size;
keep_unmerged_ins_ = keep_unmerged_ins;
} }
template <typename T> template <typename T>
...@@ -643,22 +637,11 @@ void MultiSlotDataset::MergeByInsId() { ...@@ -643,22 +637,11 @@ void MultiSlotDataset::MergeByInsId() {
return; return;
} }
auto multi_slot_desc = data_feed_desc_.multi_slot_desc(); auto multi_slot_desc = data_feed_desc_.multi_slot_desc();
std::unordered_map<int, bool> merge_slots;
std::vector<std::string> use_slots; std::vector<std::string> use_slots;
std::vector<bool> use_slots_is_dense;
for (size_t i = 0; i < multi_slot_desc.slots_size(); ++i) { for (size_t i = 0; i < multi_slot_desc.slots_size(); ++i) {
const auto& slot = multi_slot_desc.slots(i); const auto& slot = multi_slot_desc.slots(i);
if (slot.is_used()) { if (slot.is_used()) {
use_slots.push_back(slot.name()); use_slots.push_back(slot.name());
use_slots_is_dense.push_back(slot.is_dense());
}
}
for (size_t i = 0; i < use_slots.size(); ++i) {
// currently, we don't merge dense slots
if (std::find(merge_slots_list_.begin(), merge_slots_list_.end(),
use_slots[i]) != merge_slots_list_.end() &&
!use_slots_is_dense[i]) {
merge_slots[i] = true;
} }
} }
CHECK(multi_output_channel_.size() != 0); // NOLINT CHECK(multi_output_channel_.size() != 0); // NOLINT
...@@ -682,134 +665,82 @@ void MultiSlotDataset::MergeByInsId() { ...@@ -682,134 +665,82 @@ void MultiSlotDataset::MergeByInsId() {
return a.ins_id_ < b.ins_id_; return a.ins_id_ < b.ins_id_;
}); });
auto sort_cmp_uint64 = [&merge_slots](const FeatureItem& a,
const FeatureItem& b) {
auto& a_sign = a.sign().uint64_feasign_;
auto& b_sign = b.sign().uint64_feasign_;
return a_sign < b_sign || (a_sign == b_sign && a.slot() < b.slot());
};
auto sort_cmp_float = [&merge_slots](const FeatureItem& a,
const FeatureItem& b) {
auto& a_sign = a.sign().float_feasign_;
auto& b_sign = b.sign().float_feasign_;
return a_sign < b_sign || (a_sign == b_sign && a.slot() < b.slot());
};
auto unique_eq_uint64 = [&merge_slots](const FeatureItem& a,
const FeatureItem& b) {
if (a.slot() == b.slot() &&
merge_slots.find(a.slot()) == merge_slots.end()) {
return true;
}
auto& a_sign = a.sign().uint64_feasign_;
auto& b_sign = b.sign().uint64_feasign_;
return a_sign == b_sign && a.slot() == b.slot();
};
auto unique_eq_float = [&merge_slots](const FeatureItem& a,
const FeatureItem& b) {
if (a.slot() == b.slot() &&
merge_slots.find(a.slot()) == merge_slots.end()) {
return true;
}
auto& a_sign = a.sign().float_feasign_;
auto& b_sign = b.sign().float_feasign_;
return a_sign == b_sign && a.slot() == b.slot();
};
std::vector<Record> results; std::vector<Record> results;
uint64_t drop_ins_num = 0;
std::unordered_set<uint16_t> all_int64;
std::unordered_set<uint16_t> all_float;
std::unordered_set<uint16_t> local_uint64;
std::unordered_set<uint16_t> local_float;
VLOG(3) << "recs.size() " << recs.size(); VLOG(3) << "recs.size() " << recs.size();
for (size_t i = 0; i < recs.size();) { for (size_t i = 0; i < recs.size();) {
size_t j = i + 1; size_t j = i + 1;
while (j < recs.size() && recs[j].ins_id_ == recs[i].ins_id_) { while (j < recs.size() && recs[j].ins_id_ == recs[i].ins_id_) {
j++; j++;
} }
if (j - i < min_merge_size_) { if (merge_size_ > 0 && j - i != merge_size_) {
if (keep_unmerged_ins_) { drop_ins_num += j - i;
for (size_t k = i; k < j; ++k) { LOG(WARNING) << "drop ins " << recs[i].ins_id_ << " size=" << j - i
results.push_back(std::move(recs[k])); << ", because merge_size=" << merge_size_;
}
}
i = j; i = j;
continue; continue;
} }
std::vector<FeatureItem> merge_uint64_feasigns; all_int64.clear();
std::vector<FeatureItem> merge_float_feasigns; all_float.clear();
Record rec = std::move(recs[i]); bool has_conflict_slot = false;
uint16_t conflict_slot = 0;
Record rec;
rec.ins_id_ = recs[i].ins_id_;
rec.content_ = recs[i].content_;
for (size_t k = i + 1; k < j; k++) { for (size_t k = i; k < j; k++) {
local_uint64.clear();
local_float.clear();
for (auto& feature : recs[k].uint64_feasigns_) { for (auto& feature : recs[k].uint64_feasigns_) {
if (merge_slots.find(feature.slot()) != merge_slots.end()) { uint16_t slot = feature.slot();
merge_uint64_feasigns.push_back(std::move(feature)); if (all_int64.find(slot) != all_int64.end()) {
has_conflict_slot = true;
conflict_slot = slot;
break;
} }
local_uint64.insert(slot);
rec.uint64_feasigns_.push_back(std::move(feature));
}
if (has_conflict_slot) {
break;
} }
all_int64.insert(local_uint64.begin(), local_uint64.end());
for (auto& feature : recs[k].float_feasigns_) { for (auto& feature : recs[k].float_feasigns_) {
if (merge_slots.find(feature.slot()) != merge_slots.end()) { uint16_t slot = feature.slot();
merge_float_feasigns.push_back(std::move(feature)); if (all_float.find(slot) != all_float.end()) {
has_conflict_slot = true;
conflict_slot = slot;
break;
} }
local_float.insert(slot);
rec.float_feasigns_.push_back(std::move(feature));
}
if (has_conflict_slot) {
break;
} }
recs[k] = Record(); all_float.insert(local_float.begin(), local_float.end());
} }
i = j;
if (!erase_duplicate_feas_) { if (has_conflict_slot) {
rec.uint64_feasigns_.insert(rec.uint64_feasigns_.end(), LOG(WARNING) << "drop ins " << recs[i].ins_id_ << " size=" << j - i
merge_uint64_feasigns.begin(), << ", because conflict_slot=" << use_slots[conflict_slot];
merge_uint64_feasigns.end()); drop_ins_num += j - i;
rec.float_feasigns_.insert(rec.float_feasigns_.end(),
merge_float_feasigns.begin(),
merge_float_feasigns.end());
} else { } else {
std::vector<FeatureItem> not_merge_uint64_feasigns; results.push_back(std::move(rec));
std::vector<FeatureItem> not_merge_float_feasigns;
for (auto& feature : rec.uint64_feasigns_) {
if (merge_slots.find(feature.slot()) != merge_slots.end()) {
merge_uint64_feasigns.push_back(std::move(feature));
} else {
not_merge_uint64_feasigns.push_back(std::move(feature));
}
}
for (auto& feature : rec.float_feasigns_) {
if (merge_slots.find(feature.slot()) != merge_slots.end()) {
merge_float_feasigns.push_back(std::move(feature));
} else {
not_merge_float_feasigns.push_back(std::move(feature));
}
}
rec.uint64_feasigns_.clear();
rec.float_feasigns_.clear();
// erase duplicate uint64 feasigns
std::sort(merge_uint64_feasigns.begin(), merge_uint64_feasigns.end(),
sort_cmp_uint64);
merge_uint64_feasigns.erase(
std::unique(merge_uint64_feasigns.begin(),
merge_uint64_feasigns.end(), unique_eq_uint64),
merge_uint64_feasigns.end());
rec.uint64_feasigns_.insert(rec.uint64_feasigns_.end(),
merge_uint64_feasigns.begin(),
merge_uint64_feasigns.end());
rec.uint64_feasigns_.insert(rec.uint64_feasigns_.end(),
not_merge_uint64_feasigns.begin(),
not_merge_uint64_feasigns.end());
// erase duplicate float feasigns
std::sort(merge_float_feasigns.begin(), merge_float_feasigns.end(),
sort_cmp_float);
merge_float_feasigns.erase(
std::unique(merge_float_feasigns.begin(), merge_float_feasigns.end(),
unique_eq_float),
merge_float_feasigns.end());
rec.float_feasigns_.insert(rec.float_feasigns_.end(),
merge_float_feasigns.begin(),
merge_float_feasigns.end());
rec.float_feasigns_.insert(rec.float_feasigns_.end(),
not_merge_float_feasigns.begin(),
not_merge_float_feasigns.end());
} }
results.push_back(rec); i = j;
} }
std::vector<Record>().swap(recs);
VLOG(3) << "results size " << results.size(); VLOG(3) << "results size " << results.size();
LOG(WARNING) << "total drop ins num: " << drop_ins_num;
results.shrink_to_fit(); results.shrink_to_fit();
auto fleet_ptr = FleetWrapper::GetInstance(); auto fleet_ptr = FleetWrapper::GetInstance();
......
...@@ -62,9 +62,7 @@ class Dataset { ...@@ -62,9 +62,7 @@ class Dataset {
virtual void SetParseInsId(bool parse_ins_id) = 0; virtual void SetParseInsId(bool parse_ins_id) = 0;
virtual void SetParseContent(bool parse_content) = 0; virtual void SetParseContent(bool parse_content) = 0;
// set merge by ins id // set merge by ins id
virtual void SetMergeByInsId(const std::vector<std::string>& merge_slot_list, virtual void SetMergeByInsId(int merge_size) = 0;
bool erase_duplicate_feas, int min_merge_size,
bool keep_unmerged_ins) = 0;
// set fea eval mode // set fea eval mode
virtual void SetFeaEval(bool fea_eval, int record_candidate_size) = 0; virtual void SetFeaEval(bool fea_eval, int record_candidate_size) = 0;
// get file list // get file list
...@@ -149,9 +147,7 @@ class DatasetImpl : public Dataset { ...@@ -149,9 +147,7 @@ class DatasetImpl : public Dataset {
virtual void SetChannelNum(int channel_num); virtual void SetChannelNum(int channel_num);
virtual void SetParseInsId(bool parse_ins_id); virtual void SetParseInsId(bool parse_ins_id);
virtual void SetParseContent(bool parse_content); virtual void SetParseContent(bool parse_content);
virtual void SetMergeByInsId(const std::vector<std::string>& merge_slot_list, virtual void SetMergeByInsId(int merge_size);
bool erase_duplicate_feas, int min_merge_size,
bool keep_unmerged_ins);
virtual void SetFeaEval(bool fea_eval, int record_candidate_size); virtual void SetFeaEval(bool fea_eval, int record_candidate_size);
virtual const std::vector<std::string>& GetFileList() { return filelist_; } virtual const std::vector<std::string>& GetFileList() { return filelist_; }
...@@ -219,10 +215,7 @@ class DatasetImpl : public Dataset { ...@@ -219,10 +215,7 @@ class DatasetImpl : public Dataset {
bool merge_by_insid_; bool merge_by_insid_;
bool parse_ins_id_; bool parse_ins_id_;
bool parse_content_; bool parse_content_;
bool erase_duplicate_feas_; int merge_size_;
bool keep_unmerged_ins_;
int min_merge_size_;
std::vector<std::string> merge_slots_list_;
bool slots_shuffle_fea_eval_ = false; bool slots_shuffle_fea_eval_ = false;
int preload_thread_num_; int preload_thread_num_;
std::mutex global_index_mutex_; std::mutex global_index_mutex_;
......
...@@ -408,26 +408,13 @@ class InMemoryDataset(DatasetBase): ...@@ -408,26 +408,13 @@ class InMemoryDataset(DatasetBase):
""" """
self.fleet_send_sleep_seconds = fleet_send_sleep_seconds self.fleet_send_sleep_seconds = fleet_send_sleep_seconds
def set_merge_by_lineid(self, def set_merge_by_lineid(self, merge_size=2):
var_list,
erase_duplicate_feas=True,
min_merge_size=2,
keep_unmerged_ins=True):
""" """
Set merge by line id, instances of same line id will be merged after Set merge by line id, instances of same line id will be merged after
shuffle, you should parse line id in data generator. shuffle, you should parse line id in data generator.
Args: Args:
var_list(list): slots that can be merge. each element in var_list merge_size(int): ins size to merge. default is 2.
is Variable. some slots such as show and click, we
usually don't merge them for same line id, so user
should specify which slot can be merged.
erase_duplicate_feas(bool): whether erase duplicate feasigns when
merge. default is True.
min_merge_size(int): minimal size to merge. default is 2.
keep_unmerged_ins(bool): whether to keep unmerged ins, such as
ins with unique id or the num of ins with
same id is less than min_merge_size.
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -437,10 +424,9 @@ class InMemoryDataset(DatasetBase): ...@@ -437,10 +424,9 @@ class InMemoryDataset(DatasetBase):
dataset.set_merge_by_lineid() dataset.set_merge_by_lineid()
""" """
var_name_list = [i.name for i in var_list] self.dataset.set_merge_by_lineid(merge_size)
self.dataset.set_merge_by_lineid(var_name_list, erase_duplicate_feas,
min_merge_size, keep_unmerged_ins)
self.merge_by_lineid = True self.merge_by_lineid = True
self.parse_ins_id = True
def load_into_memory(self): def load_into_memory(self):
""" """
......
...@@ -272,7 +272,8 @@ class TestDataset(unittest.TestCase): ...@@ -272,7 +272,8 @@ class TestDataset(unittest.TestCase):
except Exception as e: except Exception as e:
self.assertTrue(False) self.assertTrue(False)
dataset.set_merge_by_lineid(slots_vars) dataset.set_merge_by_lineid(2)
dataset.set_parse_ins_id(False)
dataset.set_fleet_send_sleep_seconds(2) dataset.set_fleet_send_sleep_seconds(2)
dataset.preload_into_memory() dataset.preload_into_memory()
dataset.wait_preload_done() dataset.wait_preload_done()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册