diff --git a/paddle/fluid/operators/reader/ctr_reader_test.cc b/paddle/fluid/operators/reader/ctr_reader_test.cc index 190182f45c5935c1e074950cef5ddb5e295d1a1f..731122e3c16ce2d7664365a6acc443fbf9e3aff7 100644 --- a/paddle/fluid/operators/reader/ctr_reader_test.cc +++ b/paddle/fluid/operators/reader/ctr_reader_test.cc @@ -55,6 +55,38 @@ static void generatedata(const std::vector& data, PADDLE_ENFORCE(out.good(), "save file %s failed!", file_name); } +static inline void check_all_data( + const std::vector& ctr_data, + const std::vector& slots, const std::vector& label_dims, + const std::vector& label_value, + const std::vector>>& data_slot_6002, + const std::vector>>& data_slot_6003, + size_t batch_num, size_t batch_size, + std::shared_ptr queue, CTRReader* reader) { + std::vector out; + for (size_t i = 0; i < batch_num; ++i) { + reader->ReadNext(&out); + ASSERT_EQ(out.size(), slots.size() + 1); + auto& label_tensor = out.back(); + ASSERT_EQ(label_tensor.dims(), label_dims[i]); + for (size_t j = 0; j < batch_size && i * batch_num + j < ctr_data.size(); + ++j) { + auto& label = label_tensor.data()[j]; + ASSERT_TRUE(label == 0 || label == 1); + ASSERT_EQ(label, label_value[i * batch_size + j]); + } + auto& tensor_6002 = out[0]; + ASSERT_EQ(std::get<0>(data_slot_6002[i]), tensor_6002.lod()); + ASSERT_EQ(std::memcmp(std::get<1>(data_slot_6002[i]).data(), + tensor_6002.data(), + tensor_6002.dims()[1] * sizeof(int64_t)), + 0); + } + reader->ReadNext(&out); + ASSERT_EQ(out.size(), 0); + ASSERT_EQ(queue->Size(), 0); +} + TEST(CTR_READER, read_data) { const std::vector ctr_data = { "aaaa 1 0 0:6002 1:6003 2:6004 3:6005 4:6006 -1\n", @@ -103,35 +135,15 @@ TEST(CTR_READER, read_data) { CTRReader reader(queue, batch_size, thread_num, slots, file_list); reader.Start(); - size_t batch_num = std::ceil(static_cast(ctr_data.size()) / batch_size) * thread_num; + check_all_data(ctr_data, slots, label_dims, label_value, data_slot_6002, + data_slot_6003, batch_num, batch_size, queue, &reader); - std::vector out; - for (size_t i = 0; i < batch_num; ++i) { - reader.ReadNext(&out); - ASSERT_EQ(out.size(), slots.size() + 1); - auto& label_tensor = out.back(); - ASSERT_EQ(label_tensor.dims(), label_dims[i]); - for (size_t j = 0; j < batch_size && i * batch_num + j < ctr_data.size(); - ++j) { - auto& label = label_tensor.data()[j]; - ASSERT_TRUE(label == 0 || label == 1); - ASSERT_EQ(label, label_value[i * batch_size + j]); - } - auto& tensor_6002 = out[0]; - ASSERT_EQ(std::get<0>(data_slot_6002[i]), tensor_6002.lod()); - ASSERT_EQ(std::memcmp(std::get<1>(data_slot_6002[i]).data(), - tensor_6002.data(), - tensor_6002.dims()[1] * sizeof(int64_t)), - 0); - } - reader.ReadNext(&out); - ASSERT_EQ(out.size(), 0); - ASSERT_EQ(queue->Size(), 0); reader.Shutdown(); reader.Start(); + check_all_data(ctr_data, slots, label_dims, label_value, data_slot_6002, + data_slot_6003, batch_num, batch_size, queue, &reader); reader.Shutdown(); - ASSERT_EQ(queue->Size(), 5); }