提交 15393353 编写于 作者: Y Yu Yang

Add unittest related #653

* But not reproduce the problem.
上级 8cd59b8e
...@@ -15,16 +15,16 @@ limitations under the License. */ ...@@ -15,16 +15,16 @@ limitations under the License. */
#ifndef PADDLE_NO_PYTHON #ifndef PADDLE_NO_PYTHON
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <fstream> #include <fstream>
#include "paddle/utils/Util.h"
#include "paddle/utils/PythonUtil.h"
#include "paddle/gserver/dataproviders/DataProvider.h" #include "paddle/gserver/dataproviders/DataProvider.h"
#include "paddle/utils/PythonUtil.h"
#include "paddle/utils/Util.h"
P_DEFINE_string(train_list, "unittest.list", "file list for unittest"); P_DEFINE_string(train_list, "unittest.list", "file list for unittest");
namespace paddle { namespace paddle {
namespace unittest { namespace unittest {
namespace pydp2 { namespace pydp2 {
extern void setOnPoolFilledHook(const std::function<void(size_t)>& func); extern void setOnPoolFilledHook(const std::function<void(size_t)> &func);
extern void clearOnPoolFilledHook(); extern void clearOnPoolFilledHook();
} // namespace pydp2 } // namespace pydp2
...@@ -33,8 +33,8 @@ extern void clearOnPoolFilledHook(); ...@@ -33,8 +33,8 @@ extern void clearOnPoolFilledHook();
const paddle::real epsilon = 1e-5; const paddle::real epsilon = 1e-5;
static inline int64_t readDataBatch(paddle::DataBatch* batch, static inline int64_t readDataBatch(paddle::DataBatch *batch,
const std::string& funcName, const std::string &funcName,
int64_t batchSize = 65535) { int64_t batchSize = 65535) {
paddle::DataConfig config; paddle::DataConfig config;
config.set_type("py2"); config.set_type("py2");
...@@ -143,7 +143,7 @@ TEST(PyDataProvider2, init_hook) { ...@@ -143,7 +143,7 @@ TEST(PyDataProvider2, init_hook) {
paddle::DataBatch batch; paddle::DataBatch batch;
int64_t num = provider->getNextBatchInternal(100000, &batch); int64_t num = provider->getNextBatchInternal(100000, &batch);
ASSERT_EQ(num, 200); ASSERT_EQ(num, 200);
auto& mat = batch.getStreams()[0].value; auto &mat = batch.getStreams()[0].value;
ASSERT_EQ((size_t)mat->getWidth(), (size_t)20); ASSERT_EQ((size_t)mat->getWidth(), (size_t)20);
for (size_t i = 0; i < 200; ++i) { for (size_t i = 0; i < 200; ++i) {
for (size_t j = 0; j < 20; ++j) { for (size_t j = 0; j < 20; ++j) {
...@@ -170,7 +170,7 @@ TEST(PyDataProvider2, sparse_no_value_no_seq) { ...@@ -170,7 +170,7 @@ TEST(PyDataProvider2, sparse_no_value_no_seq) {
CHECK(csm != nullptr); CHECK(csm != nullptr);
for (int i = 0; i < 200; ++i) { for (int i = 0; i < 200; ++i) {
CHECK_EQ(csm->getColNum(i), (size_t)10); CHECK_EQ(csm->getColNum(i), (size_t)10);
int* cols = csm->getRowCols(i); int *cols = csm->getRowCols(i);
for (int j = 0; j < 10; ++j) { for (int j = 0; j < 10; ++j) {
CHECK_EQ(cols[j], (i + 1) * (j + 1)); CHECK_EQ(cols[j], (i + 1) * (j + 1));
} }
...@@ -185,8 +185,8 @@ TEST(PyDataProvider2, sparse_value_no_seq) { ...@@ -185,8 +185,8 @@ TEST(PyDataProvider2, sparse_value_no_seq) {
CHECK(csm != nullptr); CHECK(csm != nullptr);
for (int i = 0; i < 200; ++i) { for (int i = 0; i < 200; ++i) {
CHECK_EQ(csm->getColNum(i), (size_t)10); CHECK_EQ(csm->getColNum(i), (size_t)10);
int* cols = csm->getRowCols(i); int *cols = csm->getRowCols(i);
real* dat = csm->getRowValues(i); real *dat = csm->getRowValues(i);
for (int j = 0; j < 10; ++j) { for (int j = 0; j < 10; ++j) {
EXPECT_EQ(cols[j], (i + 1) * (j + 1)); EXPECT_EQ(cols[j], (i + 1) * (j + 1));
EXPECT_EQ(dat[j], real(j) / real(i + 1)); EXPECT_EQ(dat[j], real(j) / real(i + 1));
...@@ -197,7 +197,7 @@ TEST(PyDataProvider2, sparse_value_no_seq) { ...@@ -197,7 +197,7 @@ TEST(PyDataProvider2, sparse_value_no_seq) {
TEST(PyDataProvider2, index_seq) { TEST(PyDataProvider2, index_seq) {
paddle::DataBatch batch; paddle::DataBatch batch;
CHECK_EQ(readDataBatch(&batch, "test_index_seq"), 200); CHECK_EQ(readDataBatch(&batch, "test_index_seq"), 200);
auto& arg = batch.getStreams()[0]; auto &arg = batch.getStreams()[0];
CHECK_EQ((int)arg.ids->getSize(), (200 + 1) * 200 / 2); CHECK_EQ((int)arg.ids->getSize(), (200 + 1) * 200 / 2);
size_t tmp = 0; size_t tmp = 0;
for (size_t i = 0; i < 200; ++i) { // CHECK DATA CORRECT for (size_t i = 0; i < 200; ++i) { // CHECK DATA CORRECT
...@@ -219,7 +219,7 @@ TEST(PyDataProvider2, index_seq) { ...@@ -219,7 +219,7 @@ TEST(PyDataProvider2, index_seq) {
TEST(PyDataProvider2, index_sub_seq) { TEST(PyDataProvider2, index_sub_seq) {
paddle::DataBatch batch; paddle::DataBatch batch;
ASSERT_EQ(readDataBatch(&batch, "test_index_sub_seq"), 200); ASSERT_EQ(readDataBatch(&batch, "test_index_sub_seq"), 200);
auto& arg = batch.getStreams()[0]; auto &arg = batch.getStreams()[0];
size_t tmp = 0; size_t tmp = 0;
for (size_t i = 0; i < 200; ++i) { for (size_t i = 0; i < 200; ++i) {
for (size_t j = 0; j < i + 1; ++j) { for (size_t j = 0; j < i + 1; ++j) {
...@@ -268,7 +268,7 @@ TEST(PyDataProvider2, min_pool_size) { ...@@ -268,7 +268,7 @@ TEST(PyDataProvider2, min_pool_size) {
} }
}); });
while (true) { while (true) {
size_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch); int64_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch);
if (realBatchSize) { if (realBatchSize) {
totalData -= realBatchSize; totalData -= realBatchSize;
} else { } else {
...@@ -291,7 +291,7 @@ TEST(PyDataProvider2, can_over_batch_size) { ...@@ -291,7 +291,7 @@ TEST(PyDataProvider2, can_over_batch_size) {
provider->reset(); provider->reset();
constexpr size_t batchSize = 100; constexpr size_t batchSize = 100;
while (true) { while (true) {
size_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch); int64_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch);
if (realBatchSize) { if (realBatchSize) {
CHECK_LE(realBatchSize, batchSize); CHECK_LE(realBatchSize, batchSize);
} else { } else {
...@@ -317,12 +317,12 @@ TEST(PyDataProvider2, input_order) { ...@@ -317,12 +317,12 @@ TEST(PyDataProvider2, input_order) {
provider->reset(); provider->reset();
constexpr size_t batchSize = 100; constexpr size_t batchSize = 100;
while (true) { while (true) {
size_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch); int64_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch);
if (!realBatchSize) { if (!realBatchSize) {
break; break;
} }
ASSERT_EQ(batch.getStreams().size(), (size_t)2); ASSERT_EQ(batch.getStreams().size(), static_cast<size_t>(2));
for (size_t i = 0; i < realBatchSize; ++i) { for (int64_t i = 0; i < realBatchSize; ++i) {
ASSERT_EQ(batch.getStream(0).ids->getData()[i], 0); ASSERT_EQ(batch.getStream(0).ids->getData()[i], 0);
ASSERT_EQ(batch.getStream(1).ids->getData()[i], 1); ASSERT_EQ(batch.getStream(1).ids->getData()[i], 1);
} }
...@@ -341,11 +341,11 @@ TEST(PyDataProvider2, test_check) { ...@@ -341,11 +341,11 @@ TEST(PyDataProvider2, test_check) {
paddle::DataProvider::create(config, false)); paddle::DataProvider::create(config, false));
provider->reset(); provider->reset();
while (true) { while (true) {
size_t realBatchSize = provider->getNextBatchInternal(100, &batch); int64_t realBatchSize = provider->getNextBatchInternal(100, &batch);
if (!realBatchSize) { if (!realBatchSize) {
break; break;
} else { } else {
auto& ivec = batch.getStream(0).ids; auto &ivec = batch.getStream(0).ids;
for (size_t i = 0; i < ivec->getSize(); ++i) { for (size_t i = 0; i < ivec->getSize(); ++i) {
CHECK_LT(ivec->getData()[i], 10); CHECK_LT(ivec->getData()[i], 10);
} }
...@@ -370,7 +370,30 @@ TEST(PyDataProvider2, multiThread) { ...@@ -370,7 +370,30 @@ TEST(PyDataProvider2, multiThread) {
provider.reset(); provider.reset();
} }
int main(int argc, char** argv) { TEST(PyDataProvider2, minPoolSizeWithCache) {
paddle::DataConfig config;
config.set_type("py2");
config.set_files(FLAGS_train_list.c_str());
config.set_load_data_module("test_PyDataProvider2");
config.set_load_data_object("test_min_pool_size_with_cache");
config.set_async_load_data(true);
std::unique_ptr<paddle::DataProvider> provider(
paddle::DataProvider::create(config, false));
paddle::DataBatch batch;
for (int i = 0; i < 10; ++i) {
provider->reset();
int64_t sum = 0;
while (int64_t actualNum = provider->getNextBatch(100, &batch)) {
sum += actualNum;
}
ASSERT_EQ(1 << 20, sum);
}
}
int main(int argc, char **argv) {
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
paddle::initMain(argc, argv); paddle::initMain(argc, argv);
paddle::initPython(argc, argv); paddle::initPython(argc, argv);
......
...@@ -111,3 +111,13 @@ def test_check(settings, filename): ...@@ -111,3 +111,13 @@ def test_check(settings, filename):
if i < 10: if i < 10:
yield_good_value = True yield_good_value = True
yield i yield i
@provider(
input_types=[index_slot(10)],
min_pool_size=1000,
cache=CacheType.CACHE_PASS_IN_MEM, )
def test_min_pool_size_with_cache(settings, filename):
import random
for _ in xrange(2**20):
yield random.randint(0, 9)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册