提交 d37b9797 编写于 作者: Q Qiao Longfei

update test

上级 4051fb36
......@@ -55,6 +55,38 @@ static void generatedata(const std::vector<std::string>& data,
PADDLE_ENFORCE(out.good(), "save file %s failed!", file_name);
}
static inline void check_all_data(
const std::vector<std::string>& ctr_data,
const std::vector<std::string>& slots, const std::vector<DDim>& label_dims,
const std::vector<int64_t>& label_value,
const std::vector<std::tuple<LoD, std::vector<int64_t>>>& data_slot_6002,
const std::vector<std::tuple<LoD, std::vector<int64_t>>>& data_slot_6003,
size_t batch_num, size_t batch_size,
std::shared_ptr<LoDTensorBlockingQueue> queue, CTRReader* reader) {
std::vector<LoDTensor> out;
for (size_t i = 0; i < batch_num; ++i) {
reader->ReadNext(&out);
ASSERT_EQ(out.size(), slots.size() + 1);
auto& label_tensor = out.back();
ASSERT_EQ(label_tensor.dims(), label_dims[i]);
for (size_t j = 0; j < batch_size && i * batch_num + j < ctr_data.size();
++j) {
auto& label = label_tensor.data<int64_t>()[j];
ASSERT_TRUE(label == 0 || label == 1);
ASSERT_EQ(label, label_value[i * batch_size + j]);
}
auto& tensor_6002 = out[0];
ASSERT_EQ(std::get<0>(data_slot_6002[i]), tensor_6002.lod());
ASSERT_EQ(std::memcmp(std::get<1>(data_slot_6002[i]).data(),
tensor_6002.data<int64_t>(),
tensor_6002.dims()[1] * sizeof(int64_t)),
0);
}
reader->ReadNext(&out);
ASSERT_EQ(out.size(), 0);
ASSERT_EQ(queue->Size(), 0);
}
TEST(CTR_READER, read_data) {
const std::vector<std::string> ctr_data = {
"aaaa 1 0 0:6002 1:6003 2:6004 3:6005 4:6006 -1\n",
......@@ -103,35 +135,15 @@ TEST(CTR_READER, read_data) {
CTRReader reader(queue, batch_size, thread_num, slots, file_list);
reader.Start();
size_t batch_num =
std::ceil(static_cast<float>(ctr_data.size()) / batch_size) * thread_num;
check_all_data(ctr_data, slots, label_dims, label_value, data_slot_6002,
data_slot_6003, batch_num, batch_size, queue, &reader);
std::vector<LoDTensor> out;
for (size_t i = 0; i < batch_num; ++i) {
reader.ReadNext(&out);
ASSERT_EQ(out.size(), slots.size() + 1);
auto& label_tensor = out.back();
ASSERT_EQ(label_tensor.dims(), label_dims[i]);
for (size_t j = 0; j < batch_size && i * batch_num + j < ctr_data.size();
++j) {
auto& label = label_tensor.data<int64_t>()[j];
ASSERT_TRUE(label == 0 || label == 1);
ASSERT_EQ(label, label_value[i * batch_size + j]);
}
auto& tensor_6002 = out[0];
ASSERT_EQ(std::get<0>(data_slot_6002[i]), tensor_6002.lod());
ASSERT_EQ(std::memcmp(std::get<1>(data_slot_6002[i]).data(),
tensor_6002.data<int64_t>(),
tensor_6002.dims()[1] * sizeof(int64_t)),
0);
}
reader.ReadNext(&out);
ASSERT_EQ(out.size(), 0);
ASSERT_EQ(queue->Size(), 0);
reader.Shutdown();
reader.Start();
check_all_data(ctr_data, slots, label_dims, label_value, data_slot_6002,
data_slot_6003, batch_num, batch_size, queue, &reader);
reader.Shutdown();
ASSERT_EQ(queue->Size(), 5);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册