Commit 6115fcc5 authored by wen-bo-yang

format by yapf

Parent 812e21f3
Changes to the paddle.v2.dataset.sentiment module:

@@ -20,38 +20,30 @@
 The script fetch and preprocess movie_reviews data set
 that provided by NLTK
 """
+import paddle.v2.dataset.common as common
+import collections
 import nltk
 import numpy as np
 from itertools import chain
 from nltk.corpus import movie_reviews
-from config import DATA_HOME

-__all__ = ['train', 'test', 'get_label_dict', 'get_word_dict']
+__all__ = ['train', 'test', 'get_word_dict']
 NUM_TRAINING_INSTANCES = 1600
 NUM_TOTAL_INSTANCES = 2000


-def get_label_dict():
-    """
-    Define the labels dict for dataset
-    """
-    label_dict = {'neg': 0, 'pos': 1}
-    return label_dict
-
-
 def download_data_if_not_yet():
     """
     Download the data set, if the data set is not download.
     """
     try:
         # make sure that nltk can find the data
-        nltk.data.path.append(DATA_HOME)
+        if common.DATA_HOME not in nltk.data.path:
+            nltk.data.path.append(common.DATA_HOME)
         movie_reviews.categories()
     except LookupError:
         print "Downloading movie_reviews data set, please wait....."
-        nltk.download('movie_reviews', download_dir=DATA_HOME)
-        # make sure that nltk can find the data
-        nltk.data.path.append(DATA_HOME)
+        nltk.download('movie_reviews', download_dir=common.DATA_HOME)
         print "Download data set success....."
         print "Path is " + nltk.data.find('corpora/movie_reviews').path
@@ -63,12 +55,17 @@ def get_word_dict():
     words_freq_sorted
     """
     words_freq_sorted = list()
+    word_freq_dict = collections.defaultdict(int)
     download_data_if_not_yet()
-    words_freq = nltk.FreqDist(w.lower() for w in movie_reviews.words())
-    words_sort_list = words_freq.items()
+
+    for category in movie_reviews.categories():
+        for field in movie_reviews.fileids(category):
+            for words in movie_reviews.words(field):
+                word_freq_dict[words] += 1
+
+    words_sort_list = word_freq_dict.items()
     words_sort_list.sort(cmp=lambda a, b: b[1] - a[1])

     for index, word in enumerate(words_sort_list):
-        words_freq_sorted.append((word[0], index + 1))
+        words_freq_sorted.append((word[0], index))
     return words_freq_sorted
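For context, a minimal sketch of how the re-ranked word dictionary could be consumed; after this hunk the most frequent token maps to index 0 rather than 1. The module path is taken from the test file below, and the sample assertion mirrors its expected values:

import paddle.v2.dataset.sentiment as st

# get_word_dict() returns (word, rank) pairs ordered by descending
# corpus frequency; ranks now start at 0 instead of 1.
rank = dict(st.get_word_dict())
assert rank[u','] == 0  # ',' is the most frequent token in movie_reviews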
@@ -79,7 +76,6 @@ def sort_files():
     files_list
     """
     files_list = list()
-    download_data_if_not_yet()
     neg_file_list = movie_reviews.fileids('neg')
     pos_file_list = movie_reviews.fileids('pos')
     files_list = list(chain.from_iterable(zip(neg_file_list, pos_file_list)))
@@ -104,9 +100,6 @@ def load_sentiment_data():
     return data_set


-data_set = load_sentiment_data()
-
-
 def reader_creator(data):
     """
     Reader creator, it format data set to numpy
@@ -114,15 +107,14 @@ def reader_creator(data):
     train data set or test data set
     """
     for each in data:
-        list_of_int = np.array(each[0], dtype=np.int32)
-        label = each[1]
-        yield list_of_int, label
+        yield each[0], each[1]


 def train():
     """
     Default train set reader creator
     """
+    data_set = load_sentiment_data()
     return reader_creator(data_set[0:NUM_TRAINING_INSTANCES])
@@ -130,14 +122,5 @@ def test():
     """
     Default test set reader creator
     """
+    data_set = load_sentiment_data()
     return reader_creator(data_set[NUM_TRAINING_INSTANCES:])
-
-
-def unittest():
-    assert len(data_set) == NUM_TOTAL_INSTANCES
-    assert len(list(train())) == NUM_TRAINING_INSTANCES
-    assert len(list(test())) == NUM_TOTAL_INSTANCES - NUM_TRAINING_INSTANCES
-
-
-if __name__ == '__main__':
-    unittest()
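To make the reader change concrete, here is a minimal consumption sketch (assumptions: each sample is a (word_ids, label) pair, per the numpy conversion removed from reader_creator, and labels follow the removed get_label_dict encoding of 'neg' as 0 and 'pos' as 1):

import numpy as np
import paddle.v2.dataset.sentiment as st

# After this commit the readers yield raw (word_ids, label) pairs;
# callers that want numpy arrays now convert the ids themselves.
for word_ids, label in st.train():
    features = np.array(word_ids, dtype=np.int32)  # formerly done inside reader_creator
    break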
New file: unit test for the sentiment data set (file path not shown in this diff):

# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import nltk
import paddle.v2.dataset.sentiment as st
from nltk.corpus import movie_reviews


class TestSentimentMethods(unittest.TestCase):
    def test_get_word_dict(self):
        word_dict = st.get_word_dict()[0:10]
        test_word_list = [(u',', 0), (u'the', 1), (u'.', 2), (u'a', 3),
                          (u'and', 4), (u'of', 5), (u'to', 6), (u"'", 7),
                          (u'is', 8), (u'in', 9)]
        for idx, each in enumerate(word_dict):
            self.assertEqual(each, test_word_list[idx])
        self.assertTrue("/root/.cache/paddle/dataset" in nltk.data.path)

    def test_sort_files(self):
        last_label = ''
        for sample_file in st.sort_files():
            current_label = sample_file.split("/")[0]
            self.assertNotEqual(current_label, last_label)
            last_label = current_label

    def test_data_set(self):
        data_set = st.load_sentiment_data()
        last_label = -1
        for each in st.test():
            self.assertNotEqual(each[1], last_label)
            last_label = each[1]
        self.assertEqual(len(data_set), st.NUM_TOTAL_INSTANCES)
        self.assertEqual(len(list(st.train())), st.NUM_TRAINING_INSTANCES)
        self.assertEqual(
            len(list(st.test())),
            (st.NUM_TOTAL_INSTANCES - st.NUM_TRAINING_INSTANCES))


if __name__ == '__main__':
    unittest.main()
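Because the module ends with unittest.main(), it can be run directly, e.g. with `python test_sentiment.py` (the file name test_sentiment.py is an assumption; it is not shown in this diff). Note that test_get_word_dict hard-codes "/root/.cache/paddle/dataset", so the test presumes common.DATA_HOME resolves to that path.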