diff --git a/paddle/gserver/tests/test_PyDataProvider2.cpp b/paddle/gserver/tests/test_PyDataProvider2.cpp index 24aa73910f254e636dfb88182552fe47c12c8543..6674e6b87c6acdda78416eb1f8bf015e642b967f 100644 --- a/paddle/gserver/tests/test_PyDataProvider2.cpp +++ b/paddle/gserver/tests/test_PyDataProvider2.cpp @@ -15,16 +15,16 @@ limitations under the License. */ #ifndef PADDLE_NO_PYTHON #include #include -#include "paddle/utils/Util.h" -#include "paddle/utils/PythonUtil.h" #include "paddle/gserver/dataproviders/DataProvider.h" +#include "paddle/utils/PythonUtil.h" +#include "paddle/utils/Util.h" P_DEFINE_string(train_list, "unittest.list", "file list for unittest"); namespace paddle { namespace unittest { namespace pydp2 { -extern void setOnPoolFilledHook(const std::function& func); +extern void setOnPoolFilledHook(const std::function &func); extern void clearOnPoolFilledHook(); } // namespace pydp2 @@ -33,8 +33,8 @@ extern void clearOnPoolFilledHook(); const paddle::real epsilon = 1e-5; -static inline int64_t readDataBatch(paddle::DataBatch* batch, - const std::string& funcName, +static inline int64_t readDataBatch(paddle::DataBatch *batch, + const std::string &funcName, int64_t batchSize = 65535) { paddle::DataConfig config; config.set_type("py2"); @@ -143,7 +143,7 @@ TEST(PyDataProvider2, init_hook) { paddle::DataBatch batch; int64_t num = provider->getNextBatchInternal(100000, &batch); ASSERT_EQ(num, 200); - auto& mat = batch.getStreams()[0].value; + auto &mat = batch.getStreams()[0].value; ASSERT_EQ((size_t)mat->getWidth(), (size_t)20); for (size_t i = 0; i < 200; ++i) { for (size_t j = 0; j < 20; ++j) { @@ -170,7 +170,7 @@ TEST(PyDataProvider2, sparse_no_value_no_seq) { CHECK(csm != nullptr); for (int i = 0; i < 200; ++i) { CHECK_EQ(csm->getColNum(i), (size_t)10); - int* cols = csm->getRowCols(i); + int *cols = csm->getRowCols(i); for (int j = 0; j < 10; ++j) { CHECK_EQ(cols[j], (i + 1) * (j + 1)); } @@ -185,8 +185,8 @@ TEST(PyDataProvider2, sparse_value_no_seq) { CHECK(csm != nullptr); for (int i = 0; i < 200; ++i) { CHECK_EQ(csm->getColNum(i), (size_t)10); - int* cols = csm->getRowCols(i); - real* dat = csm->getRowValues(i); + int *cols = csm->getRowCols(i); + real *dat = csm->getRowValues(i); for (int j = 0; j < 10; ++j) { EXPECT_EQ(cols[j], (i + 1) * (j + 1)); EXPECT_EQ(dat[j], real(j) / real(i + 1)); @@ -197,7 +197,7 @@ TEST(PyDataProvider2, sparse_value_no_seq) { TEST(PyDataProvider2, index_seq) { paddle::DataBatch batch; CHECK_EQ(readDataBatch(&batch, "test_index_seq"), 200); - auto& arg = batch.getStreams()[0]; + auto &arg = batch.getStreams()[0]; CHECK_EQ((int)arg.ids->getSize(), (200 + 1) * 200 / 2); size_t tmp = 0; for (size_t i = 0; i < 200; ++i) { // CHECK DATA CORRECT @@ -219,7 +219,7 @@ TEST(PyDataProvider2, index_seq) { TEST(PyDataProvider2, index_sub_seq) { paddle::DataBatch batch; ASSERT_EQ(readDataBatch(&batch, "test_index_sub_seq"), 200); - auto& arg = batch.getStreams()[0]; + auto &arg = batch.getStreams()[0]; size_t tmp = 0; for (size_t i = 0; i < 200; ++i) { for (size_t j = 0; j < i + 1; ++j) { @@ -268,7 +268,7 @@ TEST(PyDataProvider2, min_pool_size) { } }); while (true) { - size_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch); + int64_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch); if (realBatchSize) { totalData -= realBatchSize; } else { @@ -291,7 +291,7 @@ TEST(PyDataProvider2, can_over_batch_size) { provider->reset(); constexpr size_t batchSize = 100; while (true) { - size_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch); + int64_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch); if (realBatchSize) { CHECK_LE(realBatchSize, batchSize); } else { @@ -317,12 +317,12 @@ TEST(PyDataProvider2, input_order) { provider->reset(); constexpr size_t batchSize = 100; while (true) { - size_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch); + int64_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch); if (!realBatchSize) { break; } - ASSERT_EQ(batch.getStreams().size(), (size_t)2); - for (size_t i = 0; i < realBatchSize; ++i) { + ASSERT_EQ(batch.getStreams().size(), static_cast(2)); + for (int64_t i = 0; i < realBatchSize; ++i) { ASSERT_EQ(batch.getStream(0).ids->getData()[i], 0); ASSERT_EQ(batch.getStream(1).ids->getData()[i], 1); } @@ -341,11 +341,11 @@ TEST(PyDataProvider2, test_check) { paddle::DataProvider::create(config, false)); provider->reset(); while (true) { - size_t realBatchSize = provider->getNextBatchInternal(100, &batch); + int64_t realBatchSize = provider->getNextBatchInternal(100, &batch); if (!realBatchSize) { break; } else { - auto& ivec = batch.getStream(0).ids; + auto &ivec = batch.getStream(0).ids; for (size_t i = 0; i < ivec->getSize(); ++i) { CHECK_LT(ivec->getData()[i], 10); } @@ -370,7 +370,30 @@ TEST(PyDataProvider2, multiThread) { provider.reset(); } -int main(int argc, char** argv) { +TEST(PyDataProvider2, minPoolSizeWithCache) { + paddle::DataConfig config; + config.set_type("py2"); + config.set_files(FLAGS_train_list.c_str()); + config.set_load_data_module("test_PyDataProvider2"); + config.set_load_data_object("test_min_pool_size_with_cache"); + config.set_async_load_data(true); + + std::unique_ptr provider( + paddle::DataProvider::create(config, false)); + + paddle::DataBatch batch; + + for (int i = 0; i < 10; ++i) { + provider->reset(); + int64_t sum = 0; + while (int64_t actualNum = provider->getNextBatch(100, &batch)) { + sum += actualNum; + } + ASSERT_EQ(1 << 20, sum); + } +} + +int main(int argc, char **argv) { testing::InitGoogleTest(&argc, argv); paddle::initMain(argc, argv); paddle::initPython(argc, argv); diff --git a/paddle/gserver/tests/test_PyDataProvider2.py b/paddle/gserver/tests/test_PyDataProvider2.py index 7ca30198fb1d0e7384db2c28524c7898dcd27e50..bf23c52fd78455c8ca7e480aa87438ee04ab2a74 100644 --- a/paddle/gserver/tests/test_PyDataProvider2.py +++ b/paddle/gserver/tests/test_PyDataProvider2.py @@ -111,3 +111,13 @@ def test_check(settings, filename): if i < 10: yield_good_value = True yield i + + +@provider( + input_types=[index_slot(10)], + min_pool_size=1000, + cache=CacheType.CACHE_PASS_IN_MEM, ) +def test_min_pool_size_with_cache(settings, filename): + import random + for _ in xrange(2**20): + yield random.randint(0, 9)