From 3db61dc068f41c4f4fbc6e0e724e0c46ab403b98 Mon Sep 17 00:00:00 2001
From: xujiaqi01 <173596896@qq.com>
Date: Fri, 1 Nov 2019 16:24:10 +0800
Subject: [PATCH] =?UTF-8?q?cherry-pick1.6=20simplify=20master+patch?=
 =?UTF-8?q?=EF=BC=8Cremove=20ins=20when=20size=20!=3D=20merge=5Fsize=20or?=
 =?UTF-8?q?=20has=20conflict=20slot=20=20(#20941)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* simplify master+patch, remove ins when size != merge_size or has conflict slot

* test=develop
---
 paddle/fluid/framework/data_set.cc            | 181 ++++++------------
 paddle/fluid/framework/data_set.h             |  13 +-
 python/paddle/fluid/dataset.py                |  22 +--
 .../fluid/tests/unittests/test_dataset.py     |   3 +-
 4 files changed, 65 insertions(+), 154 deletions(-)

diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc
index 471db585cef..3e1f494a749 100644
--- a/paddle/fluid/framework/data_set.cc
+++ b/paddle/fluid/framework/data_set.cc
@@ -16,6 +16,7 @@
 #include
 #include
 #include
+#include <unordered_set>
 #include "google/protobuf/io/zero_copy_stream_impl.h"
 #include "google/protobuf/message.h"
 #include "google/protobuf/text_format.h"
@@ -45,9 +46,7 @@ DatasetImpl<T>::DatasetImpl() {
   fleet_send_batch_size_ = 1024;
   fleet_send_sleep_seconds_ = 0;
   merge_by_insid_ = false;
-  erase_duplicate_feas_ = true;
-  keep_unmerged_ins_ = true;
-  min_merge_size_ = 2;
+  merge_size_ = 2;
   parse_ins_id_ = false;
   parse_content_ = false;
   preload_thread_num_ = 0;
@@ -118,15 +117,10 @@ void DatasetImpl<T>::SetParseContent(bool parse_content) {
 }
 
 template <typename T>
-void DatasetImpl<T>::SetMergeByInsId(
-    const std::vector<std::string>& merge_slot_list, bool erase_duplicate_feas,
-    int min_merge_size, bool keep_unmerged_ins) {
+void DatasetImpl<T>::SetMergeByInsId(int merge_size) {
   merge_by_insid_ = true;
   parse_ins_id_ = true;
-  merge_slots_list_ = merge_slot_list;
-  erase_duplicate_feas_ = erase_duplicate_feas;
-  min_merge_size_ = min_merge_size;
-  keep_unmerged_ins_ = keep_unmerged_ins;
+  merge_size_ = merge_size;
 }
 
 template <typename T>
@@ -643,22 +637,11 @@ void MultiSlotDataset::MergeByInsId() {
     return;
   }
   auto multi_slot_desc = data_feed_desc_.multi_slot_desc();
-  std::unordered_map<int, bool> merge_slots;
   std::vector<std::string> use_slots;
-  std::vector<bool> use_slots_is_dense;
   for (size_t i = 0; i < multi_slot_desc.slots_size(); ++i) {
     const auto& slot = multi_slot_desc.slots(i);
     if (slot.is_used()) {
       use_slots.push_back(slot.name());
-      use_slots_is_dense.push_back(slot.is_dense());
-    }
-  }
-  for (size_t i = 0; i < use_slots.size(); ++i) {
-    // currently, we don't merge dense slots
-    if (std::find(merge_slots_list_.begin(), merge_slots_list_.end(),
-                  use_slots[i]) != merge_slots_list_.end() &&
-        !use_slots_is_dense[i]) {
-      merge_slots[i] = true;
     }
   }
   CHECK(multi_output_channel_.size() != 0);  // NOLINT
@@ -682,134 +665,82 @@ void MultiSlotDataset::MergeByInsId() {
     return a.ins_id_ < b.ins_id_;
   });
 
-  auto sort_cmp_uint64 = [&merge_slots](const FeatureItem& a,
-                                        const FeatureItem& b) {
-    auto& a_sign = a.sign().uint64_feasign_;
-    auto& b_sign = b.sign().uint64_feasign_;
-    return a_sign < b_sign || (a_sign == b_sign && a.slot() < b.slot());
-  };
-  auto sort_cmp_float = [&merge_slots](const FeatureItem& a,
-                                       const FeatureItem& b) {
-    auto& a_sign = a.sign().float_feasign_;
-    auto& b_sign = b.sign().float_feasign_;
-    return a_sign < b_sign || (a_sign == b_sign && a.slot() < b.slot());
-  };
-  auto unique_eq_uint64 = [&merge_slots](const FeatureItem& a,
-                                         const FeatureItem& b) {
-    if (a.slot() == b.slot() &&
-        merge_slots.find(a.slot()) == merge_slots.end()) {
-      return true;
-    }
-    auto& a_sign = a.sign().uint64_feasign_;
-    auto& b_sign = b.sign().uint64_feasign_;
-    return a_sign == b_sign && a.slot() == b.slot();
-  };
-  auto unique_eq_float = [&merge_slots](const FeatureItem& a,
-                                        const FeatureItem& b) {
-    if (a.slot() == b.slot() &&
-        merge_slots.find(a.slot()) == merge_slots.end()) {
-      return true;
-    }
-    auto& a_sign = a.sign().float_feasign_;
-    auto& b_sign = b.sign().float_feasign_;
-    return a_sign == b_sign && a.slot() == b.slot();
-  };
-
   std::vector<Record> results;
+  uint64_t drop_ins_num = 0;
+  std::unordered_set<uint16_t> all_int64;
+  std::unordered_set<uint16_t> all_float;
+  std::unordered_set<uint16_t> local_uint64;
+  std::unordered_set<uint16_t> local_float;
+  VLOG(3) << "recs.size() " << recs.size();
   for (size_t i = 0; i < recs.size();) {
     size_t j = i + 1;
     while (j < recs.size() && recs[j].ins_id_ == recs[i].ins_id_) {
       j++;
     }
-    if (j - i < min_merge_size_) {
-      if (keep_unmerged_ins_) {
-        for (size_t k = i; k < j; ++k) {
-          results.push_back(std::move(recs[k]));
-        }
-      }
+    if (merge_size_ > 0 && j - i != merge_size_) {
+      drop_ins_num += j - i;
+      LOG(WARNING) << "drop ins " << recs[i].ins_id_ << " size=" << j - i
+                   << ", because merge_size=" << merge_size_;
       i = j;
       continue;
     }
-    std::vector<FeatureItem> merge_uint64_feasigns;
-    std::vector<FeatureItem> merge_float_feasigns;
-    Record rec = std::move(recs[i]);
+    all_int64.clear();
+    all_float.clear();
+    bool has_conflict_slot = false;
+    uint16_t conflict_slot = 0;
+
+    Record rec;
+    rec.ins_id_ = recs[i].ins_id_;
+    rec.content_ = recs[i].content_;
 
-    for (size_t k = i + 1; k < j; k++) {
+    for (size_t k = i; k < j; k++) {
+      local_uint64.clear();
+      local_float.clear();
       for (auto& feature : recs[k].uint64_feasigns_) {
-        if (merge_slots.find(feature.slot()) != merge_slots.end()) {
-          merge_uint64_feasigns.push_back(std::move(feature));
+        uint16_t slot = feature.slot();
+        if (all_int64.find(slot) != all_int64.end()) {
+          has_conflict_slot = true;
+          conflict_slot = slot;
+          break;
         }
+        local_uint64.insert(slot);
+        rec.uint64_feasigns_.push_back(std::move(feature));
+      }
+      if (has_conflict_slot) {
+        break;
       }
+      all_int64.insert(local_uint64.begin(), local_uint64.end());
+
       for (auto& feature : recs[k].float_feasigns_) {
-        if (merge_slots.find(feature.slot()) != merge_slots.end()) {
-          merge_float_feasigns.push_back(std::move(feature));
+        uint16_t slot = feature.slot();
+        if (all_float.find(slot) != all_float.end()) {
+          has_conflict_slot = true;
+          conflict_slot = slot;
+          break;
         }
+        local_float.insert(slot);
+        rec.float_feasigns_.push_back(std::move(feature));
+      }
+      if (has_conflict_slot) {
+        break;
       }
-      recs[k] = Record();
+      all_float.insert(local_float.begin(), local_float.end());
     }
-    i = j;
-    if (!erase_duplicate_feas_) {
-      rec.uint64_feasigns_.insert(rec.uint64_feasigns_.end(),
-                                  merge_uint64_feasigns.begin(),
-                                  merge_uint64_feasigns.end());
-      rec.float_feasigns_.insert(rec.float_feasigns_.end(),
-                                 merge_float_feasigns.begin(),
-                                 merge_float_feasigns.end());
+    if (has_conflict_slot) {
+      LOG(WARNING) << "drop ins " << recs[i].ins_id_ << " size=" << j - i
+                   << ", because conflict_slot=" << use_slots[conflict_slot];
+      drop_ins_num += j - i;
     } else {
-      std::vector<FeatureItem> not_merge_uint64_feasigns;
-      std::vector<FeatureItem> not_merge_float_feasigns;
-
-      for (auto& feature : rec.uint64_feasigns_) {
-        if (merge_slots.find(feature.slot()) != merge_slots.end()) {
-          merge_uint64_feasigns.push_back(std::move(feature));
-        } else {
-          not_merge_uint64_feasigns.push_back(std::move(feature));
-        }
-      }
-      for (auto& feature : rec.float_feasigns_) {
-        if (merge_slots.find(feature.slot()) != merge_slots.end()) {
-          merge_float_feasigns.push_back(std::move(feature));
-        } else {
-          not_merge_float_feasigns.push_back(std::move(feature));
-        }
-      }
-      rec.uint64_feasigns_.clear();
-      rec.float_feasigns_.clear();
-
-      // erase duplicate uint64 feasigns
-      std::sort(merge_uint64_feasigns.begin(), merge_uint64_feasigns.end(),
-                sort_cmp_uint64);
-      merge_uint64_feasigns.erase(
-          std::unique(merge_uint64_feasigns.begin(),
-                      merge_uint64_feasigns.end(), unique_eq_uint64),
-          merge_uint64_feasigns.end());
-      rec.uint64_feasigns_.insert(rec.uint64_feasigns_.end(),
-                                  merge_uint64_feasigns.begin(),
-                                  merge_uint64_feasigns.end());
-      rec.uint64_feasigns_.insert(rec.uint64_feasigns_.end(),
-                                  not_merge_uint64_feasigns.begin(),
-                                  not_merge_uint64_feasigns.end());
-
-      // erase duplicate float feasigns
-      std::sort(merge_float_feasigns.begin(), merge_float_feasigns.end(),
-                sort_cmp_float);
-      merge_float_feasigns.erase(
-          std::unique(merge_float_feasigns.begin(), merge_float_feasigns.end(),
-                      unique_eq_float),
-          merge_float_feasigns.end());
-      rec.float_feasigns_.insert(rec.float_feasigns_.end(),
-                                 merge_float_feasigns.begin(),
-                                 merge_float_feasigns.end());
-      rec.float_feasigns_.insert(rec.float_feasigns_.end(),
-                                 not_merge_float_feasigns.begin(),
-                                 not_merge_float_feasigns.end());
+      results.push_back(std::move(rec));
     }
-    results.push_back(rec);
+    i = j;
   }
+  std::vector<Record>().swap(recs);
   VLOG(3) << "results size " << results.size();
+  LOG(WARNING) << "total drop ins num: " << drop_ins_num;
   results.shrink_to_fit();
 
   auto fleet_ptr = FleetWrapper::GetInstance();
diff --git a/paddle/fluid/framework/data_set.h b/paddle/fluid/framework/data_set.h
index bcf344d23a4..7c8fa461550 100644
--- a/paddle/fluid/framework/data_set.h
+++ b/paddle/fluid/framework/data_set.h
@@ -62,9 +62,7 @@ class Dataset {
   virtual void SetParseInsId(bool parse_ins_id) = 0;
   virtual void SetParseContent(bool parse_content) = 0;
   // set merge by ins id
-  virtual void SetMergeByInsId(const std::vector<std::string>& merge_slot_list,
-                               bool erase_duplicate_feas, int min_merge_size,
-                               bool keep_unmerged_ins) = 0;
+  virtual void SetMergeByInsId(int merge_size) = 0;
   // set fea eval mode
   virtual void SetFeaEval(bool fea_eval, int record_candidate_size) = 0;
   // get file list
@@ -149,9 +147,7 @@ class DatasetImpl : public Dataset {
   virtual void SetChannelNum(int channel_num);
   virtual void SetParseInsId(bool parse_ins_id);
   virtual void SetParseContent(bool parse_content);
-  virtual void SetMergeByInsId(const std::vector<std::string>& merge_slot_list,
-                               bool erase_duplicate_feas, int min_merge_size,
-                               bool keep_unmerged_ins);
+  virtual void SetMergeByInsId(int merge_size);
   virtual void SetFeaEval(bool fea_eval, int record_candidate_size);
   virtual const std::vector<std::string>& GetFileList() { return filelist_; }
@@ -219,10 +215,7 @@ class DatasetImpl : public Dataset {
   bool merge_by_insid_;
   bool parse_ins_id_;
   bool parse_content_;
-  bool erase_duplicate_feas_;
-  bool keep_unmerged_ins_;
-  int min_merge_size_;
-  std::vector<std::string> merge_slots_list_;
+  int merge_size_;
   bool slots_shuffle_fea_eval_ = false;
   int preload_thread_num_;
   std::mutex global_index_mutex_;
diff --git a/python/paddle/fluid/dataset.py b/python/paddle/fluid/dataset.py
index 1ae2d056e85..1bf1c8b8b37 100644
--- a/python/paddle/fluid/dataset.py
+++ b/python/paddle/fluid/dataset.py
@@ -408,26 +408,13 @@ class InMemoryDataset(DatasetBase):
         """
         self.fleet_send_sleep_seconds = fleet_send_sleep_seconds
 
-    def set_merge_by_lineid(self,
-                            var_list,
-                            erase_duplicate_feas=True,
-                            min_merge_size=2,
-                            keep_unmerged_ins=True):
+    def set_merge_by_lineid(self, merge_size=2):
""" Set merge by line id, instances of same line id will be merged after shuffle, you should parse line id in data generator. Args: - var_list(list): slots that can be merge. each element in var_list - is Variable. some slots such as show and click, we - usually don't merge them for same line id, so user - should specify which slot can be merged. - erase_duplicate_feas(bool): whether erase duplicate feasigns when - merge. default is True. - min_merge_size(int): minimal size to merge. default is 2. - keep_unmerged_ins(bool): whether to keep unmerged ins, such as - ins with unique id or the num of ins with - same id is less than min_merge_size. + merge_size(int): ins size to merge. default is 2. Examples: .. code-block:: python @@ -437,10 +424,9 @@ class InMemoryDataset(DatasetBase): dataset.set_merge_by_lineid() """ - var_name_list = [i.name for i in var_list] - self.dataset.set_merge_by_lineid(var_name_list, erase_duplicate_feas, - min_merge_size, keep_unmerged_ins) + self.dataset.set_merge_by_lineid(merge_size) self.merge_by_lineid = True + self.parse_ins_id = True def load_into_memory(self): """ diff --git a/python/paddle/fluid/tests/unittests/test_dataset.py b/python/paddle/fluid/tests/unittests/test_dataset.py index 737e16b3a1a..a9b46273bcb 100644 --- a/python/paddle/fluid/tests/unittests/test_dataset.py +++ b/python/paddle/fluid/tests/unittests/test_dataset.py @@ -272,7 +272,8 @@ class TestDataset(unittest.TestCase): except Exception as e: self.assertTrue(False) - dataset.set_merge_by_lineid(slots_vars) + dataset.set_merge_by_lineid(2) + dataset.set_parse_ins_id(False) dataset.set_fleet_send_sleep_seconds(2) dataset.preload_into_memory() dataset.wait_preload_done() -- GitLab