Commit 283bdc50

Authored June 12, 2017 by gongweibao

fix by helin's comments

Parent: 96a56b96
Showing 3 changed files with 56 additions and 30 deletions (+56 −30)
paddle/parameter/tests/test_argument.cpp (+1 −1)
python/paddle/v2/dataset/common.py (+34 −24)
python/paddle/v2/dataset/tests/common_test.py (+21 −5)
paddle/parameter/tests/test_argument.cpp
@@ -42,7 +42,7 @@ TEST(Argument, poolSequenceWithStride) {
   CHECK_EQ(outStart[3], 4);
   CHECK_EQ(outStart[4], 7);
-  CHECK_EQ(stridePositions->getSize(), 8);
+  CHECK_EQ(stridePositions->getSize(), 8UL);
   auto result = reversed ? strideResultReversed : strideResult;
   for (int i = 0; i < 8; i++) {
     CHECK_EQ(stridePositions->getData()[i], result[i]);
python/paddle/v2/dataset/common.py
@@ -151,9 +151,14 @@ def cluster_files_reader(files_pattern,
     return reader


-def convert(output_path, eader, num_shards, name_prefix):
+def convert(output_path,
+            reader,
+            num_shards,
+            name_prefix,
+            max_lines_to_shuffle=10000):
     import recordio
     import cPickle as pickle
+    import random
     """
     Convert data from reader to recordio format files.

@@ -161,35 +166,40 @@ def convert(output_path, eader, num_shards, name_prefix):
     :param reader: a data reader, from which the convert program will read data instances.
     :param num_shards: the number of shards that the dataset will be partitioned into.
     :param name_prefix: the name prefix of generated files.
+    :param max_lines_to_shuffle: the max lines numbers to shuffle before writing.
     """

-    def open_needs(idx):
-        n = "%s/%s-%05d" % (output_path, name_prefix, idx)
-        w = recordio.writer(n)
-        f = open(n, "w")
-        idx += 1
-
-        return w, f, idx
+    assert num_shards >= 1
+    assert max_lines_to_shuffle >= 1

-    def close_needs(w, f):
-        if w is not None:
-            w.close()
-
-        if f is not None:
-            f.close()
+    def open_writers():
+        w = []
+        for i in range(0, num_shards):
+            n = "%s/%s-%05d-of-%05d" % (output_path, name_prefix, i,
+                                        num_shards - 1)
+            w.append(recordio.writer(n))
+
+        return w

-    idx = 0
-    w = None
-    f = None
+    def close_writers(w):
+        for i in range(0, num_shards):
+            w[i].close()

-    for i, d in enumerate(reader()):
-        if w is None:
-            w, f, idx = open_needs(idx)
-
-        w.write(pickle.dumps(d, pickle.HIGHEST_PROTOCOL))
-
-        if i % num_shards == 0 and i >= num_shards:
-            close_needs(w, f)
-            w, f, idx = open_needs(idx)
-
-    close_needs(w, f)
+    def write_data(w, lines):
+        random.shuffle(lines)
+        for i, d in enumerate(lines):
+            d = pickle.dumps(d, pickle.HIGHEST_PROTOCOL)
+            w[i % num_shards].write(d)
+
+    w = open_writers()
+    lines = []
+
+    for i, d in enumerate(reader()):
+        lines.append(d)
+        if i % max_lines_to_shuffle == 0 and i >= max_lines_to_shuffle:
+            write_data(w, lines)
+            lines = []
+            continue
+
+    write_data(w, lines)
+    close_writers(w)
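The rewritten convert() buffers up to max_lines_to_shuffle records from the reader, shuffles each full buffer, and writes the records round-robin into a fixed set of num_shards RecordIO files named with the "%s-%05d-of-%05d" pattern, closing all writers at the end. Below is a minimal, self-contained sketch of that buffer-shuffle-and-shard pattern; shard_records is a hypothetical stand-in, and it writes pickled records to plain files rather than using the recordio package, so the on-disk format differs from the real convert().

    # A minimal sketch of the buffer/shuffle/round-robin sharding idea used by
    # convert().  shard_records is a hypothetical helper; it writes pickled
    # records to plain files (not RecordIO), so the format is illustrative only.
    import os
    import pickle
    import random
    import tempfile


    def shard_records(output_path, reader, num_shards, name_prefix,
                      max_lines_to_shuffle=10000):
        assert num_shards >= 1
        assert max_lines_to_shuffle >= 1

        # One output file per shard, named like convert()'s "%s-%05d-of-%05d".
        names = [
            "%s/%s-%05d-of-%05d" % (output_path, name_prefix, i, num_shards - 1)
            for i in range(num_shards)
        ]
        writers = [open(n, "wb") for n in names]

        def flush(lines):
            # Shuffle the buffered records, then spread them across the shards.
            random.shuffle(lines)
            for i, d in enumerate(lines):
                pickle.dump(d, writers[i % num_shards], pickle.HIGHEST_PROTOCOL)

        lines = []
        for d in reader():
            lines.append(d)
            if len(lines) >= max_lines_to_shuffle:
                flush(lines)
                lines = []
        flush(lines)  # write whatever is left in the buffer

        for w in writers:
            w.close()
        return names


    if __name__ == "__main__":
        path = tempfile.mkdtemp()
        files = shard_records(path, lambda: iter(range(10)), 4, "random_images")
        print([os.path.basename(f) for f in files])

Compared with the old code, which rotated a single writer every num_shards records, this keeps all shard writers open for the whole run, so every shard receives a roughly equal, shuffled slice of the dataset.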
python/paddle/v2/dataset/tests/common_test.py
@@ -58,20 +58,36 @@ class TestCommon(unittest.TestCase):
                 self.assertEqual(e, str("0"))

     def test_convert(self):
+        record_num = 10
+        num_shards = 4
+
         def test_reader():
             def reader():
-                for x in xrange(10):
+                for x in xrange(record_num):
                     yield x

             return reader

         path = tempfile.mkdtemp()
         paddle.v2.dataset.common.convert(path,
-                                         test_reader(), 4, 'random_images')
+                                         test_reader(), num_shards,
+                                         'random_images')

-        files = glob.glob(temp_path + '/random_images-*')
-        self.assertEqual(len(files), 3)
+        files = glob.glob(path + '/random_images-*')
+        self.assertEqual(len(files), num_shards)
+
+        recs = []
+        for i in range(0, num_shards):
+            n = "%s/random_images-%05d-of-%05d" % (path, i, num_shards - 1)
+            r = recordio.reader(n)
+            while True:
+                d = r.read()
+                if d is None:
+                    break
+                recs.append(d)
+
+        recs.sort()
+        self.assertEqual(total, record_num)


 if __name__ == '__main__':
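The new test globs for the generated shard files, then drains each shard with recordio.reader until read() returns None. Note that the final assertEqual in the hunk compares an undefined `total` with `record_num`; counting the records actually read back appears to be the intent. Below is a self-contained sketch of that write-then-read-back count check, using plain pickled files as a stand-in for RecordIO shards, so file contents here are illustrative only.

    # Self-contained sketch of the read-back-and-count check the test performs.
    # Plain pickled files stand in for RecordIO shards.
    import pickle
    import tempfile

    record_num = 10
    num_shards = 4
    path = tempfile.mkdtemp()

    # Write: record i goes to shard i % num_shards, mirroring write_data().
    names = ["%s/random_images-%05d-of-%05d" % (path, i, num_shards - 1)
             for i in range(num_shards)]
    writers = [open(n, "wb") for n in names]
    for i in range(record_num):
        pickle.dump(i, writers[i % num_shards], pickle.HIGHEST_PROTOCOL)
    for w in writers:
        w.close()

    # Read back: drain each shard until EOF, then check the total record count.
    recs = []
    for n in names:
        with open(n, "rb") as f:
            while True:
                try:
                    recs.append(pickle.load(f))
                except EOFError:
                    break

    recs.sort()
    assert len(recs) == record_num
    print("read %d records from %d shards" % (len(recs), len(names)))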