Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
cb602fce
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
cb602fce
编写于
9月 21, 2020
作者:
Y
yaoxuefeng6
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add data_generator ut
上级
dfbe4488
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
125 addition
and
12 deletion
+125
-12
python/paddle/distributed/fleet/data_generator/data_generator.py
...paddle/distributed/fleet/data_generator/data_generator.py
+0
-8
python/paddle/fluid/tests/unittests/test_data_generator.py
python/paddle/fluid/tests/unittests/test_data_generator.py
+125
-4
未找到文件。
python/paddle/distributed/fleet/data_generator/data_generator.py
浏览文件 @
cb602fce
...
@@ -27,14 +27,6 @@ class DataGenerator(object):
...
@@ -27,14 +27,6 @@ class DataGenerator(object):
self
.
_proto_info
=
None
self
.
_proto_info
=
None
self
.
batch_size_
=
32
self
.
batch_size_
=
32
def
_set_line_limit
(
self
,
line_limit
):
if
not
isinstance
(
line_limit
,
int
):
raise
ValueError
(
"line_limit%s must be in int type"
%
type
(
line_limit
))
if
line_limit
<
1
:
raise
ValueError
(
"line_limit can not less than 1"
)
self
.
_line_limit
=
line_limit
def
set_batch
(
self
,
batch_size
):
def
set_batch
(
self
,
batch_size
):
'''
'''
Set batch size of current DataGenerator
Set batch size of current DataGenerator
...
...
python/paddle/fluid/tests/unittests/test_data_generator.py
浏览文件 @
cb602fce
...
@@ -13,12 +13,15 @@
...
@@ -13,12 +13,15 @@
import
paddle
import
paddle
import
unittest
import
unittest
import
paddle.distributed.fleet
as
fleet
import
paddle.distributed.fleet
as
fleet
import
os
class
MyMultiSlotDataGenerator
(
fleet
.
MultiSlotDataGenerator
):
class
MyMultiSlotDataGenerator
(
fleet
.
MultiSlotDataGenerator
):
def
generate_sample
(
self
,
line
):
def
generate_sample
(
self
,
line
):
def
data_iter
():
def
data_iter
():
for
i
in
range
(
100
):
for
i
in
range
(
40
):
if
i
==
1
:
yield
None
yield
(
"words"
,
[
1
,
2
,
3
,
4
]),
(
"label"
,
[
0
])
yield
(
"words"
,
[
1
,
2
,
3
,
4
]),
(
"label"
,
[
0
])
return
data_iter
return
data_iter
...
@@ -27,22 +30,140 @@ class MyMultiSlotDataGenerator(fleet.MultiSlotDataGenerator):
...
@@ -27,22 +30,140 @@ class MyMultiSlotDataGenerator(fleet.MultiSlotDataGenerator):
class
MyMultiSlotStringDataGenerator
(
fleet
.
MultiSlotStringDataGenerator
):
class
MyMultiSlotStringDataGenerator
(
fleet
.
MultiSlotStringDataGenerator
):
def
generate_sample
(
self
,
line
):
def
generate_sample
(
self
,
line
):
def
data_iter
():
def
data_iter
():
for
i
in
range
(
100
):
for
i
in
range
(
40
):
if
i
==
1
:
yield
None
yield
(
"words"
,
[
"1"
,
"2"
,
"3"
,
"4"
]),
(
"label"
,
[
"0"
])
yield
(
"words"
,
[
"1"
,
"2"
,
"3"
,
"4"
]),
(
"label"
,
[
"0"
])
return
data_iter
return
data_iter
class
MyMultiSlotDataGenerator_error
(
fleet
.
MultiSlotDataGenerator
):
def
generate_sample
(
self
,
line
):
def
data_iter
():
for
i
in
range
(
40
):
if
i
==
1
:
yield
None
yield
"words"
return
data_iter
class
MyMultiSlotDataGenerator_error_2
(
fleet
.
MultiSlotStringDataGenerator
):
def
generate_sample
(
self
,
line
):
def
data_iter
():
for
i
in
range
(
40
):
if
i
==
1
:
yield
None
yield
"words"
return
data_iter
class
MyMultiSlotDataGenerator_error_3
(
fleet
.
MultiSlotDataGenerator
):
def
generate_sample
(
self
,
line
):
def
data_iter
():
for
i
in
range
(
40
):
if
i
==
1
:
yield
None
yield
(
1
,
[
"1"
,
"2"
,
"3"
,
"4"
]),
(
2
,
[
"0"
])
return
data_iter
class
MyMultiSlotDataGenerator_error_4
(
fleet
.
MultiSlotDataGenerator
):
def
generate_sample
(
self
,
line
):
def
data_iter
():
for
i
in
range
(
40
):
if
i
==
1
:
yield
None
yield
(
"words"
,
"1"
),
(
"label"
,
"0"
)
return
data_iter
class
MyMultiSlotDataGenerator_error_5
(
fleet
.
MultiSlotDataGenerator
):
def
generate_sample
(
self
,
line
):
def
data_iter
():
for
i
in
range
(
40
):
if
i
==
1
:
yield
None
yield
(
"words"
,
[]),
(
"label"
,
[])
return
data_iter
class
TestMultiSlotDataGenerator
(
unittest
.
TestCase
):
class
TestMultiSlotDataGenerator
(
unittest
.
TestCase
):
def
test_MultiSlotDataGenerator_basic
(
self
):
def
test_MultiSlotDataGenerator_basic
(
self
):
my_ms_dg
=
MyMultiSlotDataGenerator
()
my_ms_dg
=
MyMultiSlotDataGenerator
()
my_ms_dg
.
set_batch
(
1
)
my_ms_dg
.
run_from_memory
()
my_ms_dg
.
run_from_memory
()
class
TestMultiSlotStringDataGenerator
(
unittest
.
TestCase
):
class
TestMultiSlotStringDataGenerator
(
unittest
.
TestCase
):
def
test_MyMultiSlotStringDataGenerator_basic
(
self
):
def
test_MyMultiSlotStringDataGenerator_basic
(
self
):
my_mss_dg
=
MyMultiSlotStringDataGenerator
()
my_ms_dg
=
MyMultiSlotStringDataGenerator
()
my_mss_dg
.
run_from_memory
()
my_ms_dg
.
set_batch
(
1
)
my_ms_dg
.
run_from_memory
()
class
TestMultiSlotStringDataGenerator_2
(
unittest
.
TestCase
):
def
test_MyMultiSlotStringDataGenerator_stdin
(
self
):
with
open
(
"test_queue_dataset_run_a.txt"
,
"w"
)
as
f
:
data
=
"2 1 2
\n
"
data
+=
"2 6 2
\n
"
data
+=
"2 5 2
\n
"
data
+=
"2 7 2
\n
"
f
.
write
(
data
)
tmp
=
os
.
popen
(
"cat test_queue_dataset_run_a.txt | python my_data_generator.py"
).
readlines
()
expected_res
=
[
'1 2 1 1 1 2
\n
'
,
'1 2 1 6 1 2
\n
'
,
'1 2 1 5 1 2
\n
'
,
'1 2 1 7 1 2
\n
'
]
self
.
assertEqual
(
tmp
,
expected_res
)
os
.
remove
(
"./test_queue_dataset_run_a.txt"
)
class
TestMultiSlotDataGenerator_error
(
unittest
.
TestCase
):
def
test_MultiSlotDataGenerator_error
(
self
):
with
self
.
assertRaises
(
ValueError
):
my_ms_dg
=
MyMultiSlotDataGenerator_error
()
my_ms_dg
.
set_batch
(
1
)
my_ms_dg
.
run_from_memory
()
class
TestMultiSlotDataGenerator_error_2
(
unittest
.
TestCase
):
def
test_MultiSlotDataGenerator_error
(
self
):
with
self
.
assertRaises
(
ValueError
):
my_ms_dg
=
MyMultiSlotDataGenerator_error_2
()
my_ms_dg
.
set_batch
(
1
)
my_ms_dg
.
run_from_memory
()
class
TestMultiSlotDataGenerator_error_3
(
unittest
.
TestCase
):
def
test_MultiSlotDataGenerator_error
(
self
):
with
self
.
assertRaises
(
ValueError
):
my_ms_dg
=
MyMultiSlotDataGenerator_error_3
()
my_ms_dg
.
set_batch
(
1
)
my_ms_dg
.
run_from_memory
()
class
TestMultiSlotDataGenerator_error_4
(
unittest
.
TestCase
):
def
test_MultiSlotDataGenerator_error
(
self
):
with
self
.
assertRaises
(
ValueError
):
my_ms_dg
=
MyMultiSlotDataGenerator_error_4
()
my_ms_dg
.
set_batch
(
1
)
my_ms_dg
.
run_from_memory
()
class
TestMultiSlotDataGenerator_error_5
(
unittest
.
TestCase
):
def
test_MultiSlotDataGenerator_error
(
self
):
with
self
.
assertRaises
(
ValueError
):
my_ms_dg
=
MyMultiSlotDataGenerator_error_5
()
my_ms_dg
.
set_batch
(
1
)
my_ms_dg
.
run_from_memory
()
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录