Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
adc23f61
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
adc23f61
编写于
12月 07, 2016
作者:
G
gangliao
提交者:
GitHub
12月 07, 2016
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request
#752
from reyoung/feature/fix_data_loss_in_pydp2
Add unittest related #653
上级
b6d036ab
15393353
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
52 addition
and
19 deletion
+52
-19
paddle/gserver/tests/test_PyDataProvider2.cpp
paddle/gserver/tests/test_PyDataProvider2.cpp
+42
-19
paddle/gserver/tests/test_PyDataProvider2.py
paddle/gserver/tests/test_PyDataProvider2.py
+10
-0
未找到文件。
paddle/gserver/tests/test_PyDataProvider2.cpp
浏览文件 @
adc23f61
...
...
@@ -15,16 +15,16 @@ limitations under the License. */
#ifndef PADDLE_NO_PYTHON
#include <gtest/gtest.h>
#include <fstream>
#include "paddle/utils/Util.h"
#include "paddle/utils/PythonUtil.h"
#include "paddle/gserver/dataproviders/DataProvider.h"
#include "paddle/utils/PythonUtil.h"
#include "paddle/utils/Util.h"
P_DEFINE_string
(
train_list
,
"unittest.list"
,
"file list for unittest"
);
namespace
paddle
{
namespace
unittest
{
namespace
pydp2
{
extern
void
setOnPoolFilledHook
(
const
std
::
function
<
void
(
size_t
)
>
&
func
);
extern
void
setOnPoolFilledHook
(
const
std
::
function
<
void
(
size_t
)
>
&
func
);
extern
void
clearOnPoolFilledHook
();
}
// namespace pydp2
...
...
@@ -33,8 +33,8 @@ extern void clearOnPoolFilledHook();
const
paddle
::
real
epsilon
=
1e-5
;
static
inline
int64_t
readDataBatch
(
paddle
::
DataBatch
*
batch
,
const
std
::
string
&
funcName
,
static
inline
int64_t
readDataBatch
(
paddle
::
DataBatch
*
batch
,
const
std
::
string
&
funcName
,
int64_t
batchSize
=
65535
)
{
paddle
::
DataConfig
config
;
config
.
set_type
(
"py2"
);
...
...
@@ -143,7 +143,7 @@ TEST(PyDataProvider2, init_hook) {
paddle
::
DataBatch
batch
;
int64_t
num
=
provider
->
getNextBatchInternal
(
100000
,
&
batch
);
ASSERT_EQ
(
num
,
200
);
auto
&
mat
=
batch
.
getStreams
()[
0
].
value
;
auto
&
mat
=
batch
.
getStreams
()[
0
].
value
;
ASSERT_EQ
((
size_t
)
mat
->
getWidth
(),
(
size_t
)
20
);
for
(
size_t
i
=
0
;
i
<
200
;
++
i
)
{
for
(
size_t
j
=
0
;
j
<
20
;
++
j
)
{
...
...
@@ -170,7 +170,7 @@ TEST(PyDataProvider2, sparse_no_value_no_seq) {
CHECK
(
csm
!=
nullptr
);
for
(
int
i
=
0
;
i
<
200
;
++
i
)
{
CHECK_EQ
(
csm
->
getColNum
(
i
),
(
size_t
)
10
);
int
*
cols
=
csm
->
getRowCols
(
i
);
int
*
cols
=
csm
->
getRowCols
(
i
);
for
(
int
j
=
0
;
j
<
10
;
++
j
)
{
CHECK_EQ
(
cols
[
j
],
(
i
+
1
)
*
(
j
+
1
));
}
...
...
@@ -185,8 +185,8 @@ TEST(PyDataProvider2, sparse_value_no_seq) {
CHECK
(
csm
!=
nullptr
);
for
(
int
i
=
0
;
i
<
200
;
++
i
)
{
CHECK_EQ
(
csm
->
getColNum
(
i
),
(
size_t
)
10
);
int
*
cols
=
csm
->
getRowCols
(
i
);
real
*
dat
=
csm
->
getRowValues
(
i
);
int
*
cols
=
csm
->
getRowCols
(
i
);
real
*
dat
=
csm
->
getRowValues
(
i
);
for
(
int
j
=
0
;
j
<
10
;
++
j
)
{
EXPECT_EQ
(
cols
[
j
],
(
i
+
1
)
*
(
j
+
1
));
EXPECT_EQ
(
dat
[
j
],
real
(
j
)
/
real
(
i
+
1
));
...
...
@@ -197,7 +197,7 @@ TEST(PyDataProvider2, sparse_value_no_seq) {
TEST
(
PyDataProvider2
,
index_seq
)
{
paddle
::
DataBatch
batch
;
CHECK_EQ
(
readDataBatch
(
&
batch
,
"test_index_seq"
),
200
);
auto
&
arg
=
batch
.
getStreams
()[
0
];
auto
&
arg
=
batch
.
getStreams
()[
0
];
CHECK_EQ
((
int
)
arg
.
ids
->
getSize
(),
(
200
+
1
)
*
200
/
2
);
size_t
tmp
=
0
;
for
(
size_t
i
=
0
;
i
<
200
;
++
i
)
{
// CHECK DATA CORRECT
...
...
@@ -219,7 +219,7 @@ TEST(PyDataProvider2, index_seq) {
TEST
(
PyDataProvider2
,
index_sub_seq
)
{
paddle
::
DataBatch
batch
;
ASSERT_EQ
(
readDataBatch
(
&
batch
,
"test_index_sub_seq"
),
200
);
auto
&
arg
=
batch
.
getStreams
()[
0
];
auto
&
arg
=
batch
.
getStreams
()[
0
];
size_t
tmp
=
0
;
for
(
size_t
i
=
0
;
i
<
200
;
++
i
)
{
for
(
size_t
j
=
0
;
j
<
i
+
1
;
++
j
)
{
...
...
@@ -268,7 +268,7 @@ TEST(PyDataProvider2, min_pool_size) {
}
});
while
(
true
)
{
size
_t
realBatchSize
=
provider
->
getNextBatchInternal
(
batchSize
,
&
batch
);
int64
_t
realBatchSize
=
provider
->
getNextBatchInternal
(
batchSize
,
&
batch
);
if
(
realBatchSize
)
{
totalData
-=
realBatchSize
;
}
else
{
...
...
@@ -291,7 +291,7 @@ TEST(PyDataProvider2, can_over_batch_size) {
provider
->
reset
();
constexpr
size_t
batchSize
=
100
;
while
(
true
)
{
size
_t
realBatchSize
=
provider
->
getNextBatchInternal
(
batchSize
,
&
batch
);
int64
_t
realBatchSize
=
provider
->
getNextBatchInternal
(
batchSize
,
&
batch
);
if
(
realBatchSize
)
{
CHECK_LE
(
realBatchSize
,
batchSize
);
}
else
{
...
...
@@ -317,12 +317,12 @@ TEST(PyDataProvider2, input_order) {
provider
->
reset
();
constexpr
size_t
batchSize
=
100
;
while
(
true
)
{
size
_t
realBatchSize
=
provider
->
getNextBatchInternal
(
batchSize
,
&
batch
);
int64
_t
realBatchSize
=
provider
->
getNextBatchInternal
(
batchSize
,
&
batch
);
if
(
!
realBatchSize
)
{
break
;
}
ASSERT_EQ
(
batch
.
getStreams
().
size
(),
(
size_t
)
2
);
for
(
size
_t
i
=
0
;
i
<
realBatchSize
;
++
i
)
{
ASSERT_EQ
(
batch
.
getStreams
().
size
(),
static_cast
<
size_t
>
(
2
)
);
for
(
int64
_t
i
=
0
;
i
<
realBatchSize
;
++
i
)
{
ASSERT_EQ
(
batch
.
getStream
(
0
).
ids
->
getData
()[
i
],
0
);
ASSERT_EQ
(
batch
.
getStream
(
1
).
ids
->
getData
()[
i
],
1
);
}
...
...
@@ -341,11 +341,11 @@ TEST(PyDataProvider2, test_check) {
paddle
::
DataProvider
::
create
(
config
,
false
));
provider
->
reset
();
while
(
true
)
{
size
_t
realBatchSize
=
provider
->
getNextBatchInternal
(
100
,
&
batch
);
int64
_t
realBatchSize
=
provider
->
getNextBatchInternal
(
100
,
&
batch
);
if
(
!
realBatchSize
)
{
break
;
}
else
{
auto
&
ivec
=
batch
.
getStream
(
0
).
ids
;
auto
&
ivec
=
batch
.
getStream
(
0
).
ids
;
for
(
size_t
i
=
0
;
i
<
ivec
->
getSize
();
++
i
)
{
CHECK_LT
(
ivec
->
getData
()[
i
],
10
);
}
...
...
@@ -370,7 +370,30 @@ TEST(PyDataProvider2, multiThread) {
provider
.
reset
();
}
int
main
(
int
argc
,
char
**
argv
)
{
TEST
(
PyDataProvider2
,
minPoolSizeWithCache
)
{
paddle
::
DataConfig
config
;
config
.
set_type
(
"py2"
);
config
.
set_files
(
FLAGS_train_list
.
c_str
());
config
.
set_load_data_module
(
"test_PyDataProvider2"
);
config
.
set_load_data_object
(
"test_min_pool_size_with_cache"
);
config
.
set_async_load_data
(
true
);
std
::
unique_ptr
<
paddle
::
DataProvider
>
provider
(
paddle
::
DataProvider
::
create
(
config
,
false
));
paddle
::
DataBatch
batch
;
for
(
int
i
=
0
;
i
<
10
;
++
i
)
{
provider
->
reset
();
int64_t
sum
=
0
;
while
(
int64_t
actualNum
=
provider
->
getNextBatch
(
100
,
&
batch
))
{
sum
+=
actualNum
;
}
ASSERT_EQ
(
1
<<
20
,
sum
);
}
}
int
main
(
int
argc
,
char
**
argv
)
{
testing
::
InitGoogleTest
(
&
argc
,
argv
);
paddle
::
initMain
(
argc
,
argv
);
paddle
::
initPython
(
argc
,
argv
);
...
...
paddle/gserver/tests/test_PyDataProvider2.py
浏览文件 @
adc23f61
...
...
@@ -111,3 +111,13 @@ def test_check(settings, filename):
if
i
<
10
:
yield_good_value
=
True
yield
i
@
provider
(
input_types
=
[
index_slot
(
10
)],
min_pool_size
=
1000
,
cache
=
CacheType
.
CACHE_PASS_IN_MEM
,
)
def
test_min_pool_size_with_cache
(
settings
,
filename
):
import
random
for
_
in
xrange
(
2
**
20
):
yield
random
.
randint
(
0
,
9
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录