Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
dfe6a2c5
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
dfe6a2c5
编写于
2月 27, 2017
作者:
W
wangkuiyi
提交者:
GitHub
2月 27, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1466 from reyoung/feature/movielens_data
Add MovieLens DataSet
上级
283e82f4
de9012a5
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
153 addition
and
32 deletion
+153
-32
python/paddle/v2/dataset/cifar.py
python/paddle/v2/dataset/cifar.py
+4
-31
python/paddle/v2/dataset/config.py
python/paddle/v2/dataset/config.py
+29
-1
python/paddle/v2/dataset/movielens.py
python/paddle/v2/dataset/movielens.py
+120
-0
未找到文件。
python/paddle/v2/dataset/cifar.py
浏览文件 @
dfe6a2c5
...
...
@@ -5,16 +5,14 @@ URL: https://www.cs.toronto.edu/~kriz/cifar.html
the default train_creator, test_creator used for CIFAR-10 dataset.
"""
from
config
import
DATA_HOME
import
os
import
hashlib
import
urllib2
import
shutil
import
tarfile
import
cPickle
import
itertools
import
tarfile
import
numpy
from
config
import
download
__all__
=
[
'cifar_100_train_creator'
,
'cifar_100_test_creator'
,
'train_creator'
,
'test_creator'
...
...
@@ -47,31 +45,6 @@ def __read_batch__(filename, sub_name):
return
reader
def
download
(
url
,
md5
):
filename
=
os
.
path
.
split
(
url
)[
-
1
]
assert
DATA_HOME
is
not
None
filepath
=
os
.
path
.
join
(
DATA_HOME
,
md5
)
if
not
os
.
path
.
exists
(
filepath
):
os
.
makedirs
(
filepath
)
__full_file__
=
os
.
path
.
join
(
filepath
,
filename
)
def
__file_ok__
():
if
not
os
.
path
.
exists
(
__full_file__
):
return
False
md5_hash
=
hashlib
.
md5
()
with
open
(
__full_file__
,
'rb'
)
as
f
:
for
chunk
in
iter
(
lambda
:
f
.
read
(
4096
),
b
""
):
md5_hash
.
update
(
chunk
)
return
md5_hash
.
hexdigest
()
==
md5
while
not
__file_ok__
():
response
=
urllib2
.
urlopen
(
url
)
with
open
(
__full_file__
,
mode
=
'wb'
)
as
of
:
shutil
.
copyfileobj
(
fsrc
=
response
,
fdst
=
of
)
return
__full_file__
def
cifar_100_train_creator
():
fn
=
download
(
url
=
CIFAR100_URL
,
md5
=
CIFAR100_MD5
)
return
__read_batch__
(
fn
,
'train'
)
...
...
python/paddle/v2/dataset/config.py
浏览文件 @
dfe6a2c5
import
hashlib
import
os
import
shutil
import
urllib2
__all__
=
[
'DATA_HOME'
]
__all__
=
[
'DATA_HOME'
,
'download'
]
DATA_HOME
=
os
.
path
.
expanduser
(
'~/.cache/paddle_data_set'
)
if
not
os
.
path
.
exists
(
DATA_HOME
):
os
.
makedirs
(
DATA_HOME
)
def
download
(
url
,
md5
):
filename
=
os
.
path
.
split
(
url
)[
-
1
]
assert
DATA_HOME
is
not
None
filepath
=
os
.
path
.
join
(
DATA_HOME
,
md5
)
if
not
os
.
path
.
exists
(
filepath
):
os
.
makedirs
(
filepath
)
__full_file__
=
os
.
path
.
join
(
filepath
,
filename
)
def
__file_ok__
():
if
not
os
.
path
.
exists
(
__full_file__
):
return
False
md5_hash
=
hashlib
.
md5
()
with
open
(
__full_file__
,
'rb'
)
as
f
:
for
chunk
in
iter
(
lambda
:
f
.
read
(
4096
),
b
""
):
md5_hash
.
update
(
chunk
)
return
md5_hash
.
hexdigest
()
==
md5
while
not
__file_ok__
():
response
=
urllib2
.
urlopen
(
url
)
with
open
(
__full_file__
,
mode
=
'wb'
)
as
of
:
shutil
.
copyfileobj
(
fsrc
=
response
,
fdst
=
of
)
return
__full_file__
python/paddle/v2/dataset/movielens.py
0 → 100644
浏览文件 @
dfe6a2c5
import
zipfile
from
config
import
download
import
re
import
random
import
functools
__all__
=
[
'train_creator'
,
'test_creator'
]
class
MovieInfo
(
object
):
def
__init__
(
self
,
index
,
categories
,
title
):
self
.
index
=
int
(
index
)
self
.
categories
=
categories
self
.
title
=
title
def
value
(
self
):
return
[
self
.
index
,
[
CATEGORIES_DICT
[
c
]
for
c
in
self
.
categories
],
[
MOVIE_TITLE_DICT
[
w
.
lower
()]
for
w
in
self
.
title
.
split
()]
]
class
UserInfo
(
object
):
def
__init__
(
self
,
index
,
gender
,
age
,
job_id
):
self
.
index
=
int
(
index
)
self
.
is_male
=
gender
==
'M'
self
.
age
=
[
1
,
18
,
25
,
35
,
45
,
50
,
56
].
index
(
int
(
age
))
self
.
job_id
=
int
(
job_id
)
def
value
(
self
):
return
[
self
.
index
,
0
if
self
.
is_male
else
1
,
self
.
age
,
self
.
job_id
]
MOVIE_INFO
=
None
MOVIE_TITLE_DICT
=
None
CATEGORIES_DICT
=
None
USER_INFO
=
None
def
__initialize_meta_info__
():
fn
=
download
(
url
=
'http://files.grouplens.org/datasets/movielens/ml-1m.zip'
,
md5
=
'c4d9eecfca2ab87c1945afe126590906'
)
global
MOVIE_INFO
if
MOVIE_INFO
is
None
:
pattern
=
re
.
compile
(
r
'^(.*)\((\d+)\)$'
)
with
zipfile
.
ZipFile
(
file
=
fn
)
as
package
:
for
info
in
package
.
infolist
():
assert
isinstance
(
info
,
zipfile
.
ZipInfo
)
MOVIE_INFO
=
dict
()
title_word_set
=
set
()
categories_set
=
set
()
with
package
.
open
(
'ml-1m/movies.dat'
)
as
movie_file
:
for
i
,
line
in
enumerate
(
movie_file
):
movie_id
,
title
,
categories
=
line
.
strip
().
split
(
'::'
)
categories
=
categories
.
split
(
'|'
)
for
c
in
categories
:
categories_set
.
add
(
c
)
title
=
pattern
.
match
(
title
).
group
(
1
)
MOVIE_INFO
[
int
(
movie_id
)]
=
MovieInfo
(
index
=
movie_id
,
categories
=
categories
,
title
=
title
)
for
w
in
title
.
split
():
title_word_set
.
add
(
w
.
lower
())
global
MOVIE_TITLE_DICT
MOVIE_TITLE_DICT
=
dict
()
for
i
,
w
in
enumerate
(
title_word_set
):
MOVIE_TITLE_DICT
[
w
]
=
i
global
CATEGORIES_DICT
CATEGORIES_DICT
=
dict
()
for
i
,
c
in
enumerate
(
categories_set
):
CATEGORIES_DICT
[
c
]
=
i
global
USER_INFO
USER_INFO
=
dict
()
with
package
.
open
(
'ml-1m/users.dat'
)
as
user_file
:
for
line
in
user_file
:
uid
,
gender
,
age
,
job
,
_
=
line
.
strip
().
split
(
"::"
)
USER_INFO
[
int
(
uid
)]
=
UserInfo
(
index
=
uid
,
gender
=
gender
,
age
=
age
,
job_id
=
job
)
return
fn
def
__reader__
(
rand_seed
=
0
,
test_ratio
=
0.1
,
is_test
=
False
):
fn
=
__initialize_meta_info__
()
rand
=
random
.
Random
(
x
=
rand_seed
)
with
zipfile
.
ZipFile
(
file
=
fn
)
as
package
:
with
package
.
open
(
'ml-1m/ratings.dat'
)
as
rating
:
for
line
in
rating
:
if
(
rand
.
random
()
<
test_ratio
)
==
is_test
:
uid
,
mov_id
,
rating
,
_
=
line
.
strip
().
split
(
"::"
)
uid
=
int
(
uid
)
mov_id
=
int
(
mov_id
)
rating
=
float
(
rating
)
*
2
-
5.0
mov
=
MOVIE_INFO
[
mov_id
]
usr
=
USER_INFO
[
uid
]
yield
usr
.
value
()
+
mov
.
value
()
+
[[
rating
]]
def
__reader_creator__
(
**
kwargs
):
return
lambda
:
__reader__
(
**
kwargs
)
train_creator
=
functools
.
partial
(
__reader_creator__
,
is_test
=
False
)
test_creator
=
functools
.
partial
(
__reader_creator__
,
is_test
=
True
)
def
unittest
():
for
train_count
,
_
in
enumerate
(
train_creator
()()):
pass
for
test_count
,
_
in
enumerate
(
test_creator
()()):
pass
print
train_count
,
test_count
if
__name__
==
'__main__'
:
unittest
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录