未验证 提交 d1e1d858 编写于 作者: W wawltor 提交者: GitHub

add the graph batch reader for pslib mode (#24178)

Add the pslib graph batch reader mode and a test case for this change
上级 80355949
......@@ -813,6 +813,7 @@ void MultiSlotInMemoryDataFeed::Init(
visit_.resize(all_slot_num, false);
pipe_command_ = data_feed_desc.pipe_command();
finish_init_ = true;
input_type_ = data_feed_desc.input_type();
}
void MultiSlotInMemoryDataFeed::GetMsgFromLogKey(const std::string& log_key,
......@@ -1065,8 +1066,27 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec(
CopyToFeedTensor(tensor_ptr, feasign, total_instance * sizeof(int64_t));
}
auto& slot_offset = offset_[i];
LoD data_lod{slot_offset};
feed_vec_[i]->set_lod(data_lod);
if (this->input_type_ == 0) {
  // Per-sample input (default): the accumulated per-slot offsets already
  // form the LoD, use them directly.
  LoD data_lod{slot_offset};
  feed_vec_[i]->set_lod(data_lod);
} else if (this->input_type_ == 1) {
  // Batch input (pslib graph batch reader): each row of a sparse slot is a
  // single feasign, so the LoD is the identity sequence [0, 1, ..., n].
  if (!use_slots_is_dense_[i]) {
    PADDLE_ENFORCE_EQ(slot_offset.size(), 2,
                      platform::errors::InvalidArgument(
                          "In batch reader, the sparse tensor lod size "
                          "must be 2, but received %d",
                          slot_offset.size()));
    // Copy the element count out by value before resizing: a reference
    // into slot_offset would be invalidated by the resize below.
    const size_t total_instance_num = slot_offset[1];
    // Fill slot_offset in place instead of building a temporary vector
    // and copy-assigning it (saves one allocation and one copy).
    slot_offset.resize(total_instance_num + 1);
    for (size_t k = 0; k <= total_instance_num; ++k) {
      slot_offset[k] = k;
    }
    LoD data_lod{slot_offset};
    feed_vec_[i]->set_lod(data_lod);
  }
}
if (use_slots_is_dense_[i]) {
if (inductive_shape_index_[i] != -1) {
use_slots_shape_[i][inductive_shape_index_[i]] =
......
......@@ -232,6 +232,9 @@ class DataFeed {
std::vector<std::string> ins_id_vec_;
std::vector<std::string> ins_content_vec_;
platform::Place place_;
// The input type of pipe reader, 0 for one sample, 1 for one batch
int input_type_;
};
// PrivateQueueDataFeed is the base virtual class for other DataFeeds.
......
......@@ -32,4 +32,5 @@ message DataFeedDesc {
optional int32 thread_num = 5;
optional string rank_offset = 6;
optional int32 pv_batch_size = 7 [ default = 32 ];
optional int32 input_type = 8 [ default = 0 ];
}
......@@ -221,6 +221,9 @@ class DatasetBase(object):
self.dataset.set_filelist(filelist)
self.filelist = filelist
def set_input_type(self, input_type):
    """
    Set the input type of the pipe reader.

    Args:
        input_type(int): 0 means one pipe record is one sample (the proto
            default), 1 means one pipe record is one batch. Stored into
            the ``input_type`` field of the DataFeedDesc proto.
    """
    self.proto_desc.input_type = input_type
def set_use_var(self, var_list):
"""
Set Variables which you will use.
......
......@@ -601,6 +601,63 @@ class TestDataset(unittest.TestCase):
os.remove("./test_queue_dataset_run_a.txt")
os.remove("./test_queue_dataset_run_b.txt")
def test_queue_dataset_run_3(self):
    """
    Testcase for QueueDataset from create to run.
    Use CUDAPlace when available, otherwise CPUPlace.
    Exercises the batch-input reader mode (input_type=1).
    """
    # Both input files carry identical content; write them in a loop
    # instead of duplicating the literal block twice.
    data = "2 1 2 2 5 4 2 2 7 2 1 3\n"
    data += "2 6 2 2 1 4 2 2 4 2 2 3\n"
    data += "2 5 2 2 9 9 2 2 7 2 1 3\n"
    data += "2 7 2 2 1 9 2 3 7 2 5 3\n"
    filelist = [
        "test_queue_dataset_run_a.txt", "test_queue_dataset_run_b.txt"
    ]
    for filename in filelist:
        with open(filename, "w") as f:
            f.write(data)
    try:
        slots = ["slot1", "slot2", "slot3", "slot4"]
        slots_vars = []
        for slot in slots:
            var = fluid.data(
                name=slot, shape=[None, 1], dtype="int64", lod_level=1)
            slots_vars.append(var)

        dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
        dataset.set_input_type(1)  # one pipe record == one batch
        dataset.set_batch_size(1)
        dataset.set_thread(2)
        dataset.set_filelist(filelist)
        dataset.set_pipe_command("cat")
        dataset.set_use_var(slots_vars)
        dataset.load_into_memory()

        place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
        ) else fluid.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())
        if self.use_data_loader:
            data_loader = fluid.io.DataLoader.from_dataset(
                dataset, fluid.cpu_places(), self.drop_last)
            for i in range(self.epoch_num):
                for batch in data_loader():
                    exe.run(fluid.default_main_program(), feed=batch)
        else:
            for i in range(self.epoch_num):
                try:
                    exe.train_from_dataset(fluid.default_main_program(),
                                           dataset)
                except Exception as e:
                    # Report the real error; a bare assertTrue(False)
                    # would hide the traceback from the test log.
                    self.fail("train_from_dataset failed: {}".format(e))
    finally:
        # Remove the temp files even when the test body raises, so a
        # failing run does not leave artifacts behind.
        for filename in filelist:
            os.remove("./" + filename)
class TestDatasetWithDataLoader(TestDataset):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册