Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
6115fcc5
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6115fcc5
编写于
3月 02, 2017
作者:
W
wen-bo-yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
format by yapf
上级
812e21f3
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
69 addition
and
34 deletion
+69
-34
python/paddle/v2/dataset/sentiment.py
python/paddle/v2/dataset/sentiment.py
+17
-34
python/paddle/v2/dataset/tests/test_sentiment.py
python/paddle/v2/dataset/tests/test_sentiment.py
+52
-0
未找到文件。
python/paddle/v2/dataset/sentiment.py
浏览文件 @
6115fcc5
...
...
@@ -20,38 +20,30 @@ The script fetch and preprocess movie_reviews data set
that provided by NLTK
"""
import
paddle.v2.dataset.common
as
common
import
collections
import
nltk
import
numpy
as
np
from
itertools
import
chain
from
nltk.corpus
import
movie_reviews
from
config
import
DATA_HOME
__all__
=
[
'train'
,
'test'
,
'get_
label_dict'
,
'get_
word_dict'
]
__all__
=
[
'train'
,
'test'
,
'get_word_dict'
]
NUM_TRAINING_INSTANCES
=
1600
NUM_TOTAL_INSTANCES
=
2000
def
get_label_dict
():
"""
Define the labels dict for dataset
"""
label_dict
=
{
'neg'
:
0
,
'pos'
:
1
}
return
label_dict
def
download_data_if_not_yet
():
"""
Download the data set, if the data set is not download.
"""
try
:
# make sure that nltk can find the data
nltk
.
data
.
path
.
append
(
DATA_HOME
)
if
common
.
DATA_HOME
not
in
nltk
.
data
.
path
:
nltk
.
data
.
path
.
append
(
common
.
DATA_HOME
)
movie_reviews
.
categories
()
except
LookupError
:
print
"Downloading movie_reviews data set, please wait....."
nltk
.
download
(
'movie_reviews'
,
download_dir
=
DATA_HOME
)
# make sure that nltk can find the data
nltk
.
data
.
path
.
append
(
DATA_HOME
)
nltk
.
download
(
'movie_reviews'
,
download_dir
=
common
.
DATA_HOME
)
print
"Download data set success....."
print
"Path is "
+
nltk
.
data
.
find
(
'corpora/movie_reviews'
).
path
...
...
@@ -63,12 +55,17 @@ def get_word_dict():
words_freq_sorted
"""
words_freq_sorted
=
list
()
word_freq_dict
=
collections
.
defaultdict
(
int
)
download_data_if_not_yet
()
words_freq
=
nltk
.
FreqDist
(
w
.
lower
()
for
w
in
movie_reviews
.
words
())
words_sort_list
=
words_freq
.
items
()
for
category
in
movie_reviews
.
categories
():
for
field
in
movie_reviews
.
fileids
(
category
):
for
words
in
movie_reviews
.
words
(
field
):
word_freq_dict
[
words
]
+=
1
words_sort_list
=
word_freq_dict
.
items
()
words_sort_list
.
sort
(
cmp
=
lambda
a
,
b
:
b
[
1
]
-
a
[
1
])
for
index
,
word
in
enumerate
(
words_sort_list
):
words_freq_sorted
.
append
((
word
[
0
],
index
+
1
))
words_freq_sorted
.
append
((
word
[
0
],
index
))
return
words_freq_sorted
...
...
@@ -79,7 +76,6 @@ def sort_files():
files_list
"""
files_list
=
list
()
download_data_if_not_yet
()
neg_file_list
=
movie_reviews
.
fileids
(
'neg'
)
pos_file_list
=
movie_reviews
.
fileids
(
'pos'
)
files_list
=
list
(
chain
.
from_iterable
(
zip
(
neg_file_list
,
pos_file_list
)))
...
...
@@ -104,9 +100,6 @@ def load_sentiment_data():
return
data_set
data_set
=
load_sentiment_data
()
def
reader_creator
(
data
):
"""
Reader creator, it format data set to numpy
...
...
@@ -114,15 +107,14 @@ def reader_creator(data):
train data set or test data set
"""
for
each
in
data
:
list_of_int
=
np
.
array
(
each
[
0
],
dtype
=
np
.
int32
)
label
=
each
[
1
]
yield
list_of_int
,
label
yield
each
[
0
],
each
[
1
]
def
train
():
"""
Default train set reader creator
"""
data_set
=
load_sentiment_data
()
return
reader_creator
(
data_set
[
0
:
NUM_TRAINING_INSTANCES
])
...
...
@@ -130,14 +122,5 @@ def test():
"""
Default test set reader creator
"""
data_set
=
load_sentiment_data
()
return
reader_creator
(
data_set
[
NUM_TRAINING_INSTANCES
:])
def
unittest
():
assert
len
(
data_set
)
==
NUM_TOTAL_INSTANCES
assert
len
(
list
(
train
()))
==
NUM_TRAINING_INSTANCES
assert
len
(
list
(
test
()))
==
NUM_TOTAL_INSTANCES
-
NUM_TRAINING_INSTANCES
if
__name__
==
'__main__'
:
unittest
()
python/paddle/v2/dataset/tests/test_sentiment.py
0 → 100644
浏览文件 @
6115fcc5
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
nltk
import
paddle.v2.dataset.sentiment
as
st
from
nltk.corpus
import
movie_reviews
class TestSentimentMethods(unittest.TestCase):
    """Tests for the NLTK movie_reviews helpers in paddle.v2.dataset.sentiment."""

    def test_get_word_dict(self):
        """The ten most frequent tokens map to indices 0..9, in that order."""
        expected_head = [(u',', 0), (u'the', 1), (u'.', 2), (u'a', 3),
                         (u'and', 4), (u'of', 5), (u'to', 6), (u"'", 7),
                         (u'is', 8), (u'in', 9)]
        # Compare the whole slice at once: an element-wise loop over the
        # result would pass silently if fewer than ten entries came back.
        self.assertEqual(list(st.get_word_dict()[0:10]), expected_head)
        # The sentiment module registers common.DATA_HOME with NLTK; check
        # that value rather than a hard-coded "/root/..." path, which only
        # holds when the tests run as root with the default cache location.
        self.assertIn(st.common.DATA_HOME, nltk.data.path)

    def test_sort_files(self):
        """Consecutive file ids from sort_files() never share a category."""
        last_label = ''
        for sample_file in st.sort_files():
            # The leading path component of a file id is taken as its label
            # (presumably 'neg' or 'pos' -- confirm against the NLTK corpus).
            current_label = sample_file.split("/")[0]
            self.assertNotEqual(current_label, last_label)
            last_label = current_label

    def test_data_set(self):
        """Check label alternation in the test split and the split sizes."""
        data_set = st.load_sentiment_data()
        # Materialise the test reader once; it is needed both for the
        # alternation check and for the size check below.
        test_samples = list(st.test())

        last_label = -1
        for sample in test_samples:
            self.assertNotEqual(sample[1], last_label)
            last_label = sample[1]

        self.assertEqual(len(data_set), st.NUM_TOTAL_INSTANCES)
        self.assertEqual(len(list(st.train())), st.NUM_TRAINING_INSTANCES)
        self.assertEqual(
            len(test_samples),
            st.NUM_TOTAL_INSTANCES - st.NUM_TRAINING_INSTANCES)
# Run every TestCase in this module when the file is executed directly.
if __name__ == '__main__':
    unittest.main()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录