Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
97a594e7
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
97a594e7
编写于
6月 02, 2017
作者:
Y
Yancey
提交者:
GitHub
6月 02, 2017
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Split dataset into multiple files (#2320)
cluster dataset split and reader
上级
4ac9a6fa
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
101 addition
and
1 deletion
+101
-1
python/paddle/v2/dataset/common.py
python/paddle/v2/dataset/common.py
+76
-1
python/paddle/v2/dataset/tests/common_test.py
python/paddle/v2/dataset/tests/common_test.py
+25
-0
未找到文件。
python/paddle/v2/dataset/common.py
浏览文件 @
97a594e7
...
...
@@ -19,8 +19,10 @@ import shutil
import
sys
import
importlib
import
paddle.v2.dataset
import
cPickle
import
glob
__all__
=
[
'DATA_HOME'
,
'download'
,
'md5file'
]
__all__
=
[
'DATA_HOME'
,
'download'
,
'md5file'
,
'split'
,
'cluster_files_reader'
]
DATA_HOME
=
os
.
path
.
expanduser
(
'~/.cache/paddle/dataset'
)
...
...
@@ -74,3 +76,76 @@ def fetch_all():
getattr
(
importlib
.
import_module
(
"paddle.v2.dataset.%s"
%
module_name
),
"fetch"
)()
def
split
(
reader
,
line_count
,
suffix
=
"%05d.pickle"
,
dumper
=
cPickle
.
dump
):
"""
you can call the function as:
split(paddle.v2.dataset.cifar.train10(), line_count=1000,
suffix="imikolov-train-%05d.pickle")
the output files as:
|-imikolov-train-00000.pickle
|-imikolov-train-00001.pickle
|- ...
|-imikolov-train-00480.pickle
:param reader: is a reader creator
:param line_count: line count for each file
:param suffix: the suffix for the output files, should contain "%d"
means the id for each file. Default is "%05d.pickle"
:param dumper: is a callable function that dump object to file, this
function will be called as dumper(obj, f) and obj is the object
will be dumped, f is a file object. Default is cPickle.dump.
"""
if
not
callable
(
dumper
):
raise
TypeError
(
"dumper should be callable."
)
lines
=
[]
indx_f
=
0
for
i
,
d
in
enumerate
(
reader
()):
lines
.
append
(
d
)
if
i
>=
line_count
and
i
%
line_count
==
0
:
with
open
(
suffix
%
indx_f
,
"w"
)
as
f
:
dumper
(
lines
,
f
)
lines
=
[]
indx_f
+=
1
if
lines
:
with
open
(
suffix
%
indx_f
,
"w"
)
as
f
:
dumper
(
lines
,
f
)
def
cluster_files_reader
(
files_pattern
,
trainer_count
,
trainer_id
,
loader
=
cPickle
.
load
):
"""
Create a reader that yield element from the given files, select
a file set according trainer count and trainer_id
:param files_pattern: the files which generating by split(...)
:param trainer_count: total trainer count
:param trainer_id: the trainer rank id
:param loader: is a callable function that load object from file, this
function will be called as loader(f) and f is a file object.
Default is cPickle.load
"""
def
reader
():
if
not
callable
(
loader
):
raise
TypeError
(
"loader should be callable."
)
file_list
=
glob
.
glob
(
files_pattern
)
file_list
.
sort
()
my_file_list
=
[]
for
idx
,
fn
in
enumerate
(
file_list
):
if
idx
%
trainer_count
==
trainer_id
:
print
"append file: %s"
%
fn
my_file_list
.
append
(
fn
)
for
fn
in
my_file_list
:
with
open
(
fn
,
"r"
)
as
f
:
lines
=
loader
(
f
)
for
line
in
lines
:
yield
line
return
reader
python/paddle/v2/dataset/tests/common_test.py
浏览文件 @
97a594e7
...
...
@@ -15,6 +15,7 @@
import
paddle.v2.dataset.common
import
unittest
import
tempfile
import
glob
class
TestCommon
(
unittest
.
TestCase
):
...
...
@@ -32,6 +33,30 @@ class TestCommon(unittest.TestCase):
paddle
.
v2
.
dataset
.
common
.
download
(
yi_avatar
,
'test'
,
'f75287202d6622414c706c36c16f8e0d'
))
def
test_split
(
self
):
def
test_reader
():
def
reader
():
for
x
in
xrange
(
10
):
yield
x
return
reader
_
,
temp_path
=
tempfile
.
mkstemp
()
paddle
.
v2
.
dataset
.
common
.
split
(
test_reader
(),
4
,
suffix
=
temp_path
+
'/test-%05d.pickle'
)
files
=
glob
.
glob
(
temp_path
+
'/test-%05d.pickle'
)
self
.
assertEqual
(
len
(
files
),
3
)
def
test_cluster_file_reader
(
self
):
_
,
temp_path
=
tempfile
.
mkstemp
()
for
x
in
xrange
(
5
):
with
open
(
temp_path
+
'/%05d.test'
%
x
)
as
f
:
f
.
write
(
'%d
\n
'
%
x
)
reader
=
paddle
.
v2
.
dataset
.
common
.
cluster_files_reader
(
temp_path
+
'/*.test'
,
5
,
0
)
for
idx
,
e
in
enumerate
(
reader
()):
self
.
assertEqual
(
e
,
str
(
"0"
))
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录