diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc
index ebe7f6caca567b8f7bb32eb90fbda4cd0c3a3564..514bc8ba9ebf1f1a394588378f2ef66c05013599 100644
--- a/paddle/fluid/framework/data_set.cc
+++ b/paddle/fluid/framework/data_set.cc
@@ -646,10 +646,12 @@ void MultiSlotDataset::MergeByInsId() {
   }
   auto multi_slot_desc = data_feed_desc_.multi_slot_desc();
   std::vector<std::string> use_slots;
+  std::vector<bool> use_slots_is_dense;
   for (int i = 0; i < multi_slot_desc.slots_size(); ++i) {
     const auto& slot = multi_slot_desc.slots(i);
     if (slot.is_used()) {
       use_slots.push_back(slot.name());
+      use_slots_is_dense.push_back(slot.is_dense());
     }
   }
   CHECK(multi_output_channel_.size() != 0);  // NOLINT
@@ -679,6 +681,11 @@
   std::unordered_set<uint16_t> all_float;
   std::unordered_set<uint16_t> local_uint64;
   std::unordered_set<uint16_t> local_float;
+  std::unordered_map<uint16_t, std::vector<FeatureItem>> all_dense_uint64;
+  std::unordered_map<uint16_t, std::vector<FeatureItem>> all_dense_float;
+  std::unordered_map<uint16_t, std::vector<FeatureItem>> local_dense_uint64;
+  std::unordered_map<uint16_t, std::vector<FeatureItem>> local_dense_float;
+  std::unordered_map<uint16_t, bool> dense_empty;
 
   VLOG(3) << "recs.size() " << recs.size();
   for (size_t i = 0; i < recs.size();) {
@@ -696,6 +703,8 @@
 
     all_int64.clear();
     all_float.clear();
+    all_dense_uint64.clear();
+    all_dense_float.clear();
     bool has_conflict_slot = false;
     uint16_t conflict_slot = 0;
 
@@ -703,12 +712,61 @@
     rec.ins_id_ = recs[i].ins_id_;
     rec.content_ = recs[i].content_;
 
+    for (size_t k = i; k < j; k++) {
+      dense_empty.clear();
+      local_dense_uint64.clear();
+      local_dense_float.clear();
+      for (auto& feature : recs[k].uint64_feasigns_) {
+        uint16_t slot = feature.slot();
+        if (!use_slots_is_dense[slot]) {
+          continue;
+        }
+        local_dense_uint64[slot].push_back(feature);
+        if (feature.sign().uint64_feasign_ != 0) {
+          dense_empty[slot] = false;
+        } else if (dense_empty.find(slot) == dense_empty.end() &&
+                   all_dense_uint64.find(slot) == all_dense_uint64.end()) {
+          dense_empty[slot] = true;
+        }
+      }
+      for (auto& feature : recs[k].float_feasigns_) {
+        uint16_t slot = feature.slot();
+        if (!use_slots_is_dense[slot]) {
+          continue;
+        }
+        local_dense_float[slot].push_back(feature);
+        if (fabs(feature.sign().float_feasign_) >= 1e-6) {
+          dense_empty[slot] = false;
+        } else if (dense_empty.find(slot) == dense_empty.end() &&
+                   all_dense_float.find(slot) == all_dense_float.end()) {
+          dense_empty[slot] = true;
+        }
+      }
+      for (auto& p : dense_empty) {
+        if (local_dense_uint64.find(p.first) != local_dense_uint64.end()) {
+          all_dense_uint64[p.first] = std::move(local_dense_uint64[p.first]);
+        } else if (local_dense_float.find(p.first) != local_dense_float.end()) {
+          all_dense_float[p.first] = std::move(local_dense_float[p.first]);
+        }
+      }
+    }
+    for (auto& f : all_dense_uint64) {
+      rec.uint64_feasigns_.insert(rec.uint64_feasigns_.end(), f.second.begin(),
+                                  f.second.end());
+    }
+    for (auto& f : all_dense_float) {
+      rec.float_feasigns_.insert(rec.float_feasigns_.end(), f.second.begin(),
+                                 f.second.end());
+    }
+
     for (size_t k = i; k < j; k++) {
       local_uint64.clear();
       local_float.clear();
       for (auto& feature : recs[k].uint64_feasigns_) {
         uint16_t slot = feature.slot();
-        if (all_int64.find(slot) != all_int64.end()) {
+        if (use_slots_is_dense[slot]) {
+          continue;
+        } else if (all_int64.find(slot) != all_int64.end()) {
           has_conflict_slot = true;
           conflict_slot = slot;
           break;
@@ -723,7 +781,9 @@
 
       for (auto& feature : recs[k].float_feasigns_) {
         uint16_t slot = feature.slot();
-        if (all_float.find(slot) != all_float.end()) {
+        if (use_slots_is_dense[slot]) {
+          continue;
+        } else if (all_float.find(slot) != all_float.end()) {
           has_conflict_slot = true;
           conflict_slot = slot;
           break;
diff --git a/python/paddle/fluid/tests/unittests/test_dataset.py b/python/paddle/fluid/tests/unittests/test_dataset.py
index 6dc28b3e783558964c77687306896c7562e093eb..d2b7e508a589fb23407dc598d21a9678cbbdee18 100644
--- a/python/paddle/fluid/tests/unittests/test_dataset.py
+++ b/python/paddle/fluid/tests/unittests/test_dataset.py
@@ -264,6 +264,72 @@ class TestDataset(unittest.TestCase):
         os.remove("./test_in_memory_dataset_masterpatch_a.txt")
         os.remove("./test_in_memory_dataset_masterpatch_b.txt")
 
+    def test_in_memory_dataset_masterpatch1(self):
+        """
+        Testcase for InMemoryDataset from create to run.
+        """
+        with open("test_in_memory_dataset_masterpatch1_a.txt", "w") as f:
+            data = "1 id1 1 1 2 3 3 4 5 5 5 5 1 1\n"
+            data += "1 id1 1 2 2 3 4 4 6 6 6 6 1 2\n"
+            data += "1 id2 1 1 1 1 1 0 1 0\n"
+            data += "1 id3 1 0 1 0 1 1 1 1\n"
+            data += "1 id3 1 1 1 1 1 0 1 0\n"
+            data += "1 id4 1 0 1 0 1 1 1 1\n"
+            data += "1 id4 1 0 1 0 1 1 1 1\n"
+            data += "1 id5 1 1 1 1 1 0 1 0\n"
+            data += "1 id5 1 1 1 1 1 0 1 0\n"
+            f.write(data)
+        with open("test_in_memory_dataset_masterpatch1_b.txt", "w") as f:
+            data = "1 id6 1 4 2 3 3 4 5 5 5 5 1 4\n"
+            data += "1 id6 1 1 2 3 4 4 6 6 6 6 1 5\n"
+            data += "1 id6 1 6 2 3 5 4 7 7 7 7 1 6\n"
+            data += "1 id6 1 7 2 3 6 4 8 8 8 8 1 7\n"
+            f.write(data)
+
+        slots_vars = []
+        train_program = fluid.Program()
+        startup_program = fluid.Program()
+        with fluid.program_guard(train_program, startup_program):
+            var1 = fluid.layers.data(
+                name="slot1", shape=[1], dtype="int64", lod_level=0)
+            var2 = fluid.layers.data(
+                name="slot2", shape=[1], dtype="int64", lod_level=0)
+            var3 = fluid.layers.data(
+                name="slot3", shape=[1], dtype="float32", lod_level=0)
+            var4 = fluid.layers.data(
+                name="slot4", shape=[1], dtype="float32", lod_level=0)
+            slots_vars = [var1, var2, var3, var4]
+
+        dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
+        dataset.set_batch_size(32)
+        dataset.set_thread(1)
+        dataset.set_parse_ins_id(True)
+        dataset.set_filelist([
+            "test_in_memory_dataset_masterpatch1_a.txt",
+            "test_in_memory_dataset_masterpatch1_b.txt"
+        ])
+        dataset.set_pipe_command("cat")
+        dataset.set_use_var(slots_vars)
+        dataset.load_into_memory()
+        dataset.local_shuffle()
+
+        exe = fluid.Executor(fluid.CPUPlace())
+        exe.run(startup_program)
+
+        for i in range(2):
+            try:
+                exe.train_from_dataset(train_program, dataset)
+            except ImportError as e:
+                pass
+            except Exception as e:
+                self.assertTrue(False)
+
+        dataset.set_merge_by_lineid(2)
+        dataset.dataset.merge_by_lineid()
+
+        os.remove("./test_in_memory_dataset_masterpatch1_a.txt")
+        os.remove("./test_in_memory_dataset_masterpatch1_b.txt")
+
     def test_in_memory_dataset_run_2(self):
         """
         Testcase for InMemoryDataset from create to run.
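
Note on the dense-slot logic added to MergeByInsId above: when several records share an ins_id, each dense slot keeps the feature values from the latest record that has at least one non-zero entry for that slot; a slot that stays all-zero in every record keeps its first occurrence, and dense slots are then skipped by the existing sparse-slot conflict check. The following is a minimal standalone C++ sketch of that selection rule only, written against simplified types; it is not part of the patch, and the names Group and pick_dense are hypothetical.

// Sketch (assumption: simplified types, not Paddle's FeatureItem/Record) of the
// "latest non-empty record wins, otherwise first occurrence" rule used for
// dense slots in MergeByInsId.
#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <vector>

// One record = dense slot id -> its values; a group = records sharing an ins_id.
using Record = std::unordered_map<uint16_t, std::vector<uint64_t>>;
using Group = std::vector<Record>;

Record pick_dense(const Group& group) {
  Record all_dense;
  for (const auto& rec : group) {  // records in arrival order
    for (const auto& kv : rec) {
      uint16_t slot = kv.first;
      bool non_empty = false;
      for (uint64_t v : kv.second) {
        if (v != 0) {
          non_empty = true;
          break;
        }
      }
      // Overwrite when this record has real values for the slot, or when the
      // slot has never been seen before (an all-zero first occurrence counts).
      if (non_empty || all_dense.find(slot) == all_dense.end()) {
        all_dense[slot] = kv.second;
      }
    }
  }
  return all_dense;
}

int main() {
  Group group = {
      {{7, {0, 0}}},  // record 0: slot 7 empty, kept only as a fallback
      {{7, {3, 5}}},  // record 1: slot 7 has values -> wins
      {{7, {0, 0}}},  // record 2: empty again -> does not overwrite
  };
  for (const auto& kv : pick_dense(group)) {
    std::cout << "slot " << kv.first << ":";
    for (uint64_t v : kv.second) std::cout << " " << v;
    std::cout << "\n";
  }
  return 0;
}

Running the sketch prints "slot 7: 3 5", i.e. the values from the last record in the group that had non-zero entries for that slot, mirroring what the patch appends into the merged record before the sparse slots are handled.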