Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
c9d78710
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c9d78710
编写于
6月 13, 2017
作者:
G
gongweibao
提交者:
GitHub
6月 13, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #2407 from gongweibao/convert
Add convert function
上级
1b8d2e65
46ccfc01
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
87 addition
and
1 deletion
+87
-1
paddle/parameter/tests/test_argument.cpp
paddle/parameter/tests/test_argument.cpp
+1
-1
python/paddle/v2/dataset/common.py
python/paddle/v2/dataset/common.py
+54
-0
python/paddle/v2/dataset/tests/common_test.py
python/paddle/v2/dataset/tests/common_test.py
+32
-0
未找到文件。
paddle/parameter/tests/test_argument.cpp
浏览文件 @
c9d78710
...
@@ -42,7 +42,7 @@ TEST(Argument, poolSequenceWithStride) {
...
@@ -42,7 +42,7 @@ TEST(Argument, poolSequenceWithStride) {
CHECK_EQ
(
outStart
[
3
],
4
);
CHECK_EQ
(
outStart
[
3
],
4
);
CHECK_EQ
(
outStart
[
4
],
7
);
CHECK_EQ
(
outStart
[
4
],
7
);
CHECK_EQ
(
stridePositions
->
getSize
(),
8
);
CHECK_EQ
(
stridePositions
->
getSize
(),
8
UL
);
auto
result
=
reversed
?
strideResultReversed
:
strideResult
;
auto
result
=
reversed
?
strideResultReversed
:
strideResult
;
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
CHECK_EQ
(
stridePositions
->
getData
()[
i
],
result
[
i
]);
CHECK_EQ
(
stridePositions
->
getData
()[
i
],
result
[
i
]);
...
...
python/paddle/v2/dataset/common.py
浏览文件 @
c9d78710
...
@@ -149,3 +149,57 @@ def cluster_files_reader(files_pattern,
...
@@ -149,3 +149,57 @@ def cluster_files_reader(files_pattern,
yield
line
yield
line
return
reader
return
reader
def
convert
(
output_path
,
reader
,
num_shards
,
name_prefix
,
max_lines_to_shuffle
=
1000
):
import
recordio
import
cPickle
as
pickle
import
random
"""
Convert data from reader to recordio format files.
:param output_path: directory in which output files will be saved.
:param reader: a data reader, from which the convert program will read data instances.
:param num_shards: the number of shards that the dataset will be partitioned into.
:param name_prefix: the name prefix of generated files.
:param max_lines_to_shuffle: the max lines numbers to shuffle before writing.
"""
assert
num_shards
>=
1
assert
max_lines_to_shuffle
>=
1
def
open_writers
():
w
=
[]
for
i
in
range
(
0
,
num_shards
):
n
=
"%s/%s-%05d-of-%05d"
%
(
output_path
,
name_prefix
,
i
,
num_shards
-
1
)
w
.
append
(
recordio
.
writer
(
n
))
return
w
def
close_writers
(
w
):
for
i
in
range
(
0
,
num_shards
):
w
[
i
].
close
()
def
write_data
(
w
,
lines
):
random
.
shuffle
(
lines
)
for
i
,
d
in
enumerate
(
lines
):
d
=
pickle
.
dumps
(
d
,
pickle
.
HIGHEST_PROTOCOL
)
w
[
i
%
num_shards
].
write
(
d
)
w
=
open_writers
()
lines
=
[]
for
i
,
d
in
enumerate
(
reader
()):
lines
.
append
(
d
)
if
i
%
max_lines_to_shuffle
==
0
and
i
>=
max_lines_to_shuffle
:
write_data
(
w
,
lines
)
lines
=
[]
continue
write_data
(
w
,
lines
)
close_writers
(
w
)
python/paddle/v2/dataset/tests/common_test.py
浏览文件 @
c9d78710
...
@@ -57,6 +57,38 @@ class TestCommon(unittest.TestCase):
...
@@ -57,6 +57,38 @@ class TestCommon(unittest.TestCase):
for
idx
,
e
in
enumerate
(
reader
()):
for
idx
,
e
in
enumerate
(
reader
()):
self
.
assertEqual
(
e
,
str
(
"0"
))
self
.
assertEqual
(
e
,
str
(
"0"
))
def
test_convert
(
self
):
record_num
=
10
num_shards
=
4
def
test_reader
():
def
reader
():
for
x
in
xrange
(
record_num
):
yield
x
return
reader
path
=
tempfile
.
mkdtemp
()
paddle
.
v2
.
dataset
.
common
.
convert
(
path
,
test_reader
(),
num_shards
,
'random_images'
)
files
=
glob
.
glob
(
path
+
'/random_images-*'
)
self
.
assertEqual
(
len
(
files
),
num_shards
)
recs
=
[]
for
i
in
range
(
0
,
num_shards
):
n
=
"%s/random_images-%05d-of-%05d"
%
(
path
,
i
,
num_shards
-
1
)
r
=
recordio
.
reader
(
n
)
while
True
:
d
=
r
.
read
()
if
d
is
None
:
break
recs
.
append
(
d
)
recs
.
sort
()
self
.
assertEqual
(
total
,
record_num
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录