未验证 提交 f4041572 编写于 作者: X xujiaqi01 提交者: GitHub

fix master patch when slot is dense (#21580)

* fix master patch when slot is dense
* test=develop
上级 c05706fe
......@@ -646,10 +646,12 @@ void MultiSlotDataset::MergeByInsId() {
}
auto multi_slot_desc = data_feed_desc_.multi_slot_desc();
std::vector<std::string> use_slots;
std::vector<bool> use_slots_is_dense;
for (int i = 0; i < multi_slot_desc.slots_size(); ++i) {
const auto& slot = multi_slot_desc.slots(i);
if (slot.is_used()) {
use_slots.push_back(slot.name());
use_slots_is_dense.push_back(slot.is_dense());
}
}
CHECK(multi_output_channel_.size() != 0); // NOLINT
......@@ -679,6 +681,11 @@ void MultiSlotDataset::MergeByInsId() {
std::unordered_set<uint16_t> all_float;
std::unordered_set<uint16_t> local_uint64;
std::unordered_set<uint16_t> local_float;
std::unordered_map<uint16_t, std::vector<FeatureItem>> all_dense_uint64;
std::unordered_map<uint16_t, std::vector<FeatureItem>> all_dense_float;
std::unordered_map<uint16_t, std::vector<FeatureItem>> local_dense_uint64;
std::unordered_map<uint16_t, std::vector<FeatureItem>> local_dense_float;
std::unordered_map<uint16_t, bool> dense_empty;
VLOG(3) << "recs.size() " << recs.size();
for (size_t i = 0; i < recs.size();) {
......@@ -696,6 +703,8 @@ void MultiSlotDataset::MergeByInsId() {
all_int64.clear();
all_float.clear();
all_dense_uint64.clear();
all_dense_float.clear();
bool has_conflict_slot = false;
uint16_t conflict_slot = 0;
......@@ -703,12 +712,61 @@ void MultiSlotDataset::MergeByInsId() {
rec.ins_id_ = recs[i].ins_id_;
rec.content_ = recs[i].content_;
for (size_t k = i; k < j; k++) {
dense_empty.clear();
local_dense_uint64.clear();
local_dense_float.clear();
for (auto& feature : recs[k].uint64_feasigns_) {
uint16_t slot = feature.slot();
if (!use_slots_is_dense[slot]) {
continue;
}
local_dense_uint64[slot].push_back(feature);
if (feature.sign().uint64_feasign_ != 0) {
dense_empty[slot] = false;
} else if (dense_empty.find(slot) == dense_empty.end() &&
all_dense_uint64.find(slot) == all_dense_uint64.end()) {
dense_empty[slot] = true;
}
}
for (auto& feature : recs[k].float_feasigns_) {
uint16_t slot = feature.slot();
if (!use_slots_is_dense[slot]) {
continue;
}
local_dense_float[slot].push_back(feature);
if (fabs(feature.sign().float_feasign_) >= 1e-6) {
dense_empty[slot] = false;
} else if (dense_empty.find(slot) == dense_empty.end() &&
all_dense_float.find(slot) == all_dense_float.end()) {
dense_empty[slot] = true;
}
}
for (auto& p : dense_empty) {
if (local_dense_uint64.find(p.first) != local_dense_uint64.end()) {
all_dense_uint64[p.first] = std::move(local_dense_uint64[p.first]);
} else if (local_dense_float.find(p.first) != local_dense_float.end()) {
all_dense_float[p.first] = std::move(local_dense_float[p.first]);
}
}
}
for (auto& f : all_dense_uint64) {
rec.uint64_feasigns_.insert(rec.uint64_feasigns_.end(), f.second.begin(),
f.second.end());
}
for (auto& f : all_dense_float) {
rec.float_feasigns_.insert(rec.float_feasigns_.end(), f.second.begin(),
f.second.end());
}
for (size_t k = i; k < j; k++) {
local_uint64.clear();
local_float.clear();
for (auto& feature : recs[k].uint64_feasigns_) {
uint16_t slot = feature.slot();
if (all_int64.find(slot) != all_int64.end()) {
if (use_slots_is_dense[slot]) {
continue;
} else if (all_int64.find(slot) != all_int64.end()) {
has_conflict_slot = true;
conflict_slot = slot;
break;
......@@ -723,7 +781,9 @@ void MultiSlotDataset::MergeByInsId() {
for (auto& feature : recs[k].float_feasigns_) {
uint16_t slot = feature.slot();
if (all_float.find(slot) != all_float.end()) {
if (use_slots_is_dense[slot]) {
continue;
} else if (all_float.find(slot) != all_float.end()) {
has_conflict_slot = true;
conflict_slot = slot;
break;
......
......@@ -264,6 +264,72 @@ class TestDataset(unittest.TestCase):
os.remove("./test_in_memory_dataset_masterpatch_a.txt")
os.remove("./test_in_memory_dataset_masterpatch_b.txt")
def test_in_memory_dataset_masterpatch1(self):
"""
Testcase for InMemoryDataset from create to run.
"""
with open("test_in_memory_dataset_masterpatch1_a.txt", "w") as f:
data = "1 id1 1 1 2 3 3 4 5 5 5 5 1 1\n"
data += "1 id1 1 2 2 3 4 4 6 6 6 6 1 2\n"
data += "1 id2 1 1 1 1 1 0 1 0\n"
data += "1 id3 1 0 1 0 1 1 1 1\n"
data += "1 id3 1 1 1 1 1 0 1 0\n"
data += "1 id4 1 0 1 0 1 1 1 1\n"
data += "1 id4 1 0 1 0 1 1 1 1\n"
data += "1 id5 1 1 1 1 1 0 1 0\n"
data += "1 id5 1 1 1 1 1 0 1 0\n"
f.write(data)
with open("test_in_memory_dataset_masterpatch1_b.txt", "w") as f:
data = "1 id6 1 4 2 3 3 4 5 5 5 5 1 4\n"
data += "1 id6 1 1 2 3 4 4 6 6 6 6 1 5\n"
data += "1 id6 1 6 2 3 5 4 7 7 7 7 1 6\n"
data += "1 id6 1 7 2 3 6 4 8 8 8 8 1 7\n"
f.write(data)
slots_vars = []
train_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
var1 = fluid.layers.data(
name="slot1", shape=[1], dtype="int64", lod_level=0)
var2 = fluid.layers.data(
name="slot2", shape=[1], dtype="int64", lod_level=0)
var3 = fluid.layers.data(
name="slot3", shape=[1], dtype="float32", lod_level=0)
var4 = fluid.layers.data(
name="slot4", shape=[1], dtype="float32", lod_level=0)
slots_vars = [var1, var2, var3, var4]
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_batch_size(32)
dataset.set_thread(1)
dataset.set_parse_ins_id(True)
dataset.set_filelist([
"test_in_memory_dataset_masterpatch1_a.txt",
"test_in_memory_dataset_masterpatch1_b.txt"
])
dataset.set_pipe_command("cat")
dataset.set_use_var(slots_vars)
dataset.load_into_memory()
dataset.local_shuffle()
exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_program)
for i in range(2):
try:
exe.train_from_dataset(train_program, dataset)
except ImportError as e:
pass
except Exception as e:
self.assertTrue(False)
dataset.set_merge_by_lineid(2)
dataset.dataset.merge_by_lineid()
os.remove("./test_in_memory_dataset_masterpatch1_a.txt")
os.remove("./test_in_memory_dataset_masterpatch1_b.txt")
def test_in_memory_dataset_run_2(self):
"""
Testcase for InMemoryDataset from create to run.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册