From 5c65eff6ef3faed880d356a94c4c914a21dd9a35 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Sun, 21 Oct 2018 20:46:03 +0800 Subject: [PATCH] update test for ctr data --- paddle/fluid/operators/reader/ctr_reader.cc | 9 +- paddle/fluid/operators/reader/ctr_reader.h | 6 +- .../fluid/operators/reader/ctr_reader_test.cc | 174 +++++++++--------- 3 files changed, 96 insertions(+), 93 deletions(-) diff --git a/paddle/fluid/operators/reader/ctr_reader.cc b/paddle/fluid/operators/reader/ctr_reader.cc index cb86f4c613c..47f2c56c64a 100644 --- a/paddle/fluid/operators/reader/ctr_reader.cc +++ b/paddle/fluid/operators/reader/ctr_reader.cc @@ -168,7 +168,10 @@ void ReadThread(const std::vector& file_list, while (reader.HasNext()) { batch_data.clear(); + batch_data.reserve(batch_size); + batch_label.clear(); + batch_label.reserve(batch_size); // read batch_size data for (int i = 0; i < batch_size; ++i) { @@ -205,7 +208,8 @@ void ReadThread(const std::vector& file_list, int64_t* tensor_data = lod_tensor.mutable_data( framework::make_ddim({1, static_cast(batch_feasign.size())}), platform::CPUPlace()); - memcpy(tensor_data, batch_feasign.data(), batch_feasign.size()); + memcpy(tensor_data, batch_feasign.data(), + batch_feasign.size() * sizeof(int64_t)); lod_datas.push_back(lod_tensor); } @@ -214,7 +218,8 @@ void ReadThread(const std::vector& file_list, int64_t* label_tensor_data = label_tensor.mutable_data( framework::make_ddim({1, static_cast(batch_label.size())}), platform::CPUPlace()); - memcpy(label_tensor_data, batch_label.data(), batch_label.size()); + memcpy(label_tensor_data, batch_label.data(), + batch_label.size() * sizeof(int64_t)); lod_datas.push_back(label_tensor); queue->Push(lod_datas); diff --git a/paddle/fluid/operators/reader/ctr_reader.h b/paddle/fluid/operators/reader/ctr_reader.h index 89f63364c8d..d87f81402fc 100644 --- a/paddle/fluid/operators/reader/ctr_reader.h +++ b/paddle/fluid/operators/reader/ctr_reader.h @@ -55,13 +55,14 @@ class CTRReader : public framework::FileReader { const std::vector& slots, const std::vector& file_list) : batch_size_(batch_size), slots_(slots), file_list_(file_list) { + PADDLE_ENFORCE_GT(thread_num, 0, "thread num should be larger then 0!"); PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null"); PADDLE_ENFORCE_GT(file_list.size(), 0, "file list should not be empty"); thread_num_ = file_list_.size() > thread_num ? thread_num : file_list_.size(); queue_ = queue; SplitFiles(); - for (int i = 0; i < thread_num_; ++i) { + for (size_t i = 0; i < thread_num_; ++i) { read_thread_status_.push_back(Stopped); } } @@ -76,6 +77,7 @@ class CTRReader : public framework::FileReader { void Shutdown() override { VLOG(3) << "Shutdown reader"; + // shutdown should stop all the reader thread for (auto& read_thread : read_threads_) { read_thread->join(); } @@ -108,7 +110,7 @@ class CTRReader : public framework::FileReader { } private: - int thread_num_; + size_t thread_num_; const int batch_size_; const std::vector slots_; const std::vector file_list_; diff --git a/paddle/fluid/operators/reader/ctr_reader_test.cc b/paddle/fluid/operators/reader/ctr_reader_test.cc index 51fbdf2d079..a73d54385e6 100644 --- a/paddle/fluid/operators/reader/ctr_reader_test.cc +++ b/paddle/fluid/operators/reader/ctr_reader_test.cc @@ -14,8 +14,15 @@ #include "paddle/fluid/operators/reader/ctr_reader.h" +#include #include +#include +#include +#include +#include +#include + #include "gtest/gtest.h" #include "paddle/fluid/framework/lod_tensor.h" @@ -25,109 +32,98 @@ using paddle::operators::reader::LoDTensorBlockingQueue; using paddle::operators::reader::LoDTensorBlockingQueueHolder; using paddle::operators::reader::CTRReader; using paddle::framework::LoDTensor; -using paddle::operators::reader::GetTimeInSec; +using paddle::framework::LoD; +using paddle::platform::CPUPlace; + +static void generatedata(const std::vector& data, + const std::string& file_name) { + std::ifstream in(file_name.c_str()); + if (in.good()) { + VLOG(3) << "file " << file_name << " exist, delete it first!"; + remove(file_name.c_str()); + } else { + in.close(); + } + + ogzstream out(file_name.c_str()); + PADDLE_ENFORCE(out.good(), "open file %s failed!", file_name); + for (auto& c : data) { + out << c; + } + out.close(); + PADDLE_ENFORCE(out.good(), "save file %s failed!", file_name); +} TEST(CTR_READER, read_data) { + const std::vector ctr_data = { + "aaaa 1 0 0:6002 1:6003 2:6004 3:6005 4:6006 -1\n", + "bbbb 1 0 5:6003 6:6003 7:6003 8:6004 9:6004 -1\n", + "cccc 1 1 10:6002 11:6002 12:6002 13:6002 14:6002 -2\n", + "dddd 1 0 15:6003 16:6003 17:6003 18:6003 19:6004 -3\n", + "1111 1 1 20:6001 21:6001 22:6001 23:6001 24:6001 12\n", + "2222 1 1 25:6004 26:6004 27:6004 28:6005 29:6005 aa\n", + "3333 1 0 30:6002 31:6003 32:6004 33:6004 34:6005 er\n", + "eeee 1 1 35:6003 36:6003 37:6005 38:6005 39:6005 dd\n", + "ffff 1 1 40:6002 41:6003 42:6004 43:6004 44:6005 66\n", + "gggg 1 1 46:6006 45:6006 47:6003 48:6003 49:6003 ba\n", + }; + std::string gz_file_name = "test_ctr_reader_data.gz"; + generatedata(ctr_data, gz_file_name); + + std::vector label_value = {0, 0, 1, 0, 1, 1, 0, 1, 1, 1}; + + std::vector>> data_slot_6002{ + {{{0, 1, 2}}, {0, 0}}, + {{{0, 5, 6}}, {10, 11, 12, 13, 14, 0}}, + {{{0, 1, 2}}, {0, 0}}, + {{{0, 1, 2}}, {30, 0}}, + {{{0, 1, 2}}, {40, 0}}}; + std::vector>> data_slot_6003{ + {{{0, 1, 4}}, {1, 5, 6, 7}}, + {{{0, 1, 5}}, {0, 15, 16, 17, 18}}, + {{{0, 1, 2}}, {0, 0}}, + {{{0, 1, 3}}, {31, 35, 36}}, + {{{0, 1, 4}}, {41, 47, 48, 49}}}; + LoDTensorBlockingQueueHolder queue_holder; int capacity = 64; queue_holder.InitOnce(capacity, {}, false); std::shared_ptr queue = queue_holder.GetQueue(); - int batch_size = 10; - int thread_num = 3; - std::vector slots = { - "6002", "6003", "6004", "6005", "6006", "6007", "6008", "6009", "6010", - "6011", "6012", "6013", "6014", "6015", "6016", "6017", "6018", "6019", - "6020", "6021", "6023", "6024", "6025", "6026", "6027", "6028", "6029", - "6030", "6031", "6032", "6033", "6034", "6035", "6036", "6037", "6038", - "6039", "6040", "6041", "6042", "6043", "6044", "6045", "6046", "6047", - "6048", "6050", "6051", "6052", "6054", "6055", "6056", "6057", "6058", - "6059", "6060", "6061", "6062", "6063", "6064", "6065", "6066", "6067", - "6068", "6069", "6070", "6071", "6072", "6073", "6074", "6075", "6076", - "6077", "6078", "6079", "6080", "6081", "6082", "6083", "6084", "6085", - "6086", "6087", "6088", "6089", "6090", "6091", "6092", "6093", "6094", - "6095", "6096", "6097", "6098", "6099", "6100", "6101", "6102", "6103", - "6104", "6105", "6106", "6107", "6108", "6109", "6110", "6111", "6112", - "6113", "6114", "6115", "6116", "6117", "6118", "6119", "6120", "6121", - "6122", "6123", "6124", "6125", "6126", "6127", "6128", "6129", "6130", - "6131", "6132", "6133", "6134", "6135", "6136", "6137", "6138", "6139", - "6140", "6141", "6142", "6143", "6144", "6145", "6146", "6147", "6148", - "6149", "6150", "6151", "6152", "6153", "6155", "6156", "6157", "6158", - "6160", "6161", "6162", "6163", "6164", "6165", "6166", "6167", "6168", - "6169", "6170", "6171", "6172", "6173", "6174", "6175", "6176", "6177", - "6178", "6181", "6182", "6183", "6184", "6185", "6186", "6188", "6189", - "6190", "6191", "6192", "6194", "6195", "6196", "6197", "6198", "6199", - "6200", "6201", "6202", "6203", "6204", "6205", "6206", "6207", "6208", - "6209", "6210", "6211", "6212", "6213", "6214", "6215", "6216", "6217", - "6218", "6220", "6222", "6223", "6224", "6225", "6226", "6227", "6228", - "6229", "6230", "6231", "6232", "6233", "6234", "6235", "6236", "6237", - "6238", "6239", "6240", "6241", "6242", "6243", "6244", "6245", "6247", - "6248", "6250", "6251", "6253", "6254", "6255", "6256", "6257", "6258", - "6259", "6260", "6261", "6262", "6263", "6264", "6265", "6350", "6351", - "6352", "6353", "6354", "6355", "6356", "6738", "6739", "6740", "6741", - "6751", "6753", "6754", "6755", "6756", "6757", "6759", "6760", "6763", - "6764", "6765", "6766", "6767", "6768", "6769", "6770", "6806", "6807", - "6808", "6809", "6810", "6811", "6812", "6813", "6814", "6815", "6816", - "6817", "6818", "6819", "6820", "6821", "6822", "6823", "6824", "6825", - "6826", "6827", "6828", "6829", "6830", "6831", "6832", "6833", "6834", - "6835", "6836", "6837", "6838", "6839", "6840", "6841", "6842", "6843", - "6844", "6845", "6846", "6847", "6848", "6849", "6850", "6851", "6852", - "6853", "6854", "6855", "6856", "6857", "6858", "6859", "6860", "6861", - "6862", "6863", "6864", "6865", "6866", "6867", "6868", "6869", "6870", - "6871", "6872", "6873", "6874", "6875", "6876", "6877", "6878", "6879", - "6880", "6881", "6882", "6883", "6884", "6885", "6886", "6887", "6888", - "6889", "6890", "6891", "6892", "6893", "6894", "6895", "6896", "6897", - "6898", "6899", "6900", "6901", "6902", "6903", "6904", "6905", "6906", - "6907", "6908", "6909", "6910", "6911", "6912", "6913", "6914", "6915", - "6916", "6917", "6918", "6919", "6920", "6921", "6922", "6923", "6924", - "6925", "6926", "6927", "6928", "6929", "6930", "6931", "6932", "6933", - "6934", "6935", "6936", "6937", "6938", "6939", "6940", "6941", "6942", - "6943", "6944", "6945", "6946", "6947", "6948", "6949", "6950", "6951", - "6952", "6953", "6954", "6955", "6956", "6957", "6958", "6959", "6960", - "6961", "6962", "6963", "7001", "7002", "7003", "7004", "7005", "7006", - "7007", "7008", "7009", "7010", "7011", "7012", "7013", "7014", "7015", - "7016", "7017", "7018", "7019", "7020", "7021", "7022", "7023", "7024", - "7025", "7026", "7027", "7028", "7029", "7030", "7031", "7032", "7033", - "7034", "7035", "7036", "7037", "7038", "7039", "7040", "7041", "7042", - "7043", "7044", "7045", "7046", "7047", "7048", "7049", "7050", "7051", - "7052", "7053", "7054", "7055", "7056", "7057", "7058", "7060", "7062", - "7063", "7064", "7065", "7066", "7067", "7068", "7069", "7070", "7071", - "7072", "7073", "7074", "7075", "7076", "7077", "7078", "7079", "7080", - "7081", "7082", "7083", "7084", "7085", "7086", "7087", "7088", "7089", - "7090", "7091", "7092", "7093", "7094", "7095", "7096", "7097", "7098", - "7099", "7100", "7101", "7102", "7103", "7104", "7105", "7106", "7107", - "7108", "7109", "7110", "7120", "7122", "7123", "7124", "7125", "7126", - "7127", "7128", "7129", "7131", "7133", "7134", "7135", "7136", "7137", - "7138", "7139", "7140", "7141", "7142", "7143", "7144", "7145", "7146", - "7147", "7148", "7149", "7150", "7151", "7152", "7153", "7154", "7155", - "7156", "7157", "7158", "7159", "7160", "7161", "7162", "7163", "7164", - "7165", "7166", "7167", "7168", "7169", "7170", "7171", "7172", "7173", - "7174", "7175", "7176", "7177", "7178", "7179", "7180", "7181", "7182", - "7183", "7184", "7185", "7186", "7187", "7500", "7501", "7502", "7503", - "7504", "7505", "7506", "7507", "7508", "7509", "7510", "7511", "7512", - "7513", "7514", "7515", "7516", "7517", "7750"}; - std::vector file_list = { - "/Users/qiaolongfei/project/gzip_test/part-00000-A.gz", - "/Users/qiaolongfei/project/gzip_test/part-00001-A.gz", - "/Users/qiaolongfei/project/gzip_test/part-00002-A.gz", - "/Users/qiaolongfei/project/gzip_test/part-00003-A.gz"}; + int batch_size = 2; + int thread_num = 1; + std::vector slots = {"6002", "6003"}; + std::vector file_list; + for (int i = 0; i < thread_num; ++i) { + file_list.push_back(gz_file_name); + } CTRReader reader(queue, batch_size, thread_num, slots, file_list); reader.Start(); - std::cout << "start to reader data" << std::endl; - std::vector out; - int read_batch = 10000; - uint64_t t0 = GetTimeInSec(); - for (int i = 0; i < read_batch; ++i) { + size_t batch_num = std::ceil(ctr_data.size() / batch_size) * thread_num; + + for (size_t i = 0; i < batch_num; ++i) { + std::vector out; reader.ReadNext(&out); - if (i != 0 && i % 100 == 0) { - uint64_t t1 = GetTimeInSec(); - float line_per_s = 100 * batch_size * 1000000 / (t1 - t0); - VLOG(3) << "line_per_second = " << line_per_s; - t0 = GetTimeInSec(); + ASSERT_EQ(out.size(), slots.size() + 1); + auto& label_tensor = out.back(); + ASSERT_EQ(label_tensor.dims(), + paddle::framework::make_ddim({1, batch_size})); + for (size_t j = 0; j < batch_size && i * batch_num + j < ctr_data.size(); + ++j) { + auto& label = label_tensor.data()[j]; + ASSERT_TRUE(label == 0 || label == 1); + ASSERT_EQ(label, label_value[i * batch_size + j]); } + auto& tensor_6002 = out[0]; + ASSERT_EQ(std::get<0>(data_slot_6002[i]), tensor_6002.lod()); + ASSERT_EQ(std::memcmp(std::get<1>(data_slot_6002[i]).data(), + tensor_6002.data(), + tensor_6002.dims()[1] * sizeof(int64_t)), + 0); } + ASSERT_EQ(queue->Size(), 0); } -- GitLab