提交 6115fcc5 编写于 作者: W wen-bo-yang

format by yapf

上级 812e21f3
......@@ -20,38 +20,30 @@ The script fetch and preprocess movie_reviews data set
that provided by NLTK
"""
import paddle.v2.dataset.common as common
import collections
import nltk
import numpy as np
from itertools import chain
from nltk.corpus import movie_reviews
from config import DATA_HOME
__all__ = ['train', 'test', 'get_label_dict', 'get_word_dict']
__all__ = ['train', 'test', 'get_word_dict']
NUM_TRAINING_INSTANCES = 1600
NUM_TOTAL_INSTANCES = 2000
def get_label_dict():
"""
Define the labels dict for dataset
"""
label_dict = {'neg': 0, 'pos': 1}
return label_dict
def download_data_if_not_yet():
"""
Download the data set, if the data set is not download.
"""
try:
# make sure that nltk can find the data
nltk.data.path.append(DATA_HOME)
if common.DATA_HOME not in nltk.data.path:
nltk.data.path.append(common.DATA_HOME)
movie_reviews.categories()
except LookupError:
print "Downloading movie_reviews data set, please wait....."
nltk.download('movie_reviews', download_dir=DATA_HOME)
# make sure that nltk can find the data
nltk.data.path.append(DATA_HOME)
nltk.download('movie_reviews', download_dir=common.DATA_HOME)
print "Download data set success....."
print "Path is " + nltk.data.find('corpora/movie_reviews').path
......@@ -63,12 +55,17 @@ def get_word_dict():
words_freq_sorted
"""
words_freq_sorted = list()
word_freq_dict = collections.defaultdict(int)
download_data_if_not_yet()
words_freq = nltk.FreqDist(w.lower() for w in movie_reviews.words())
words_sort_list = words_freq.items()
for category in movie_reviews.categories():
for field in movie_reviews.fileids(category):
for words in movie_reviews.words(field):
word_freq_dict[words] += 1
words_sort_list = word_freq_dict.items()
words_sort_list.sort(cmp=lambda a, b: b[1] - a[1])
for index, word in enumerate(words_sort_list):
words_freq_sorted.append((word[0], index + 1))
words_freq_sorted.append((word[0], index))
return words_freq_sorted
......@@ -79,7 +76,6 @@ def sort_files():
files_list
"""
files_list = list()
download_data_if_not_yet()
neg_file_list = movie_reviews.fileids('neg')
pos_file_list = movie_reviews.fileids('pos')
files_list = list(chain.from_iterable(zip(neg_file_list, pos_file_list)))
......@@ -104,9 +100,6 @@ def load_sentiment_data():
return data_set
data_set = load_sentiment_data()
def reader_creator(data):
"""
Reader creator, it format data set to numpy
......@@ -114,15 +107,14 @@ def reader_creator(data):
train data set or test data set
"""
for each in data:
list_of_int = np.array(each[0], dtype=np.int32)
label = each[1]
yield list_of_int, label
yield each[0], each[1]
def train():
"""
Default train set reader creator
"""
data_set = load_sentiment_data()
return reader_creator(data_set[0:NUM_TRAINING_INSTANCES])
......@@ -130,14 +122,5 @@ def test():
"""
Default test set reader creator
"""
data_set = load_sentiment_data()
return reader_creator(data_set[NUM_TRAINING_INSTANCES:])
def unittest():
assert len(data_set) == NUM_TOTAL_INSTANCES
assert len(list(train())) == NUM_TRAINING_INSTANCES
assert len(list(test())) == NUM_TOTAL_INSTANCES - NUM_TRAINING_INSTANCES
if __name__ == '__main__':
unittest()
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import nltk
import paddle.v2.dataset.sentiment as st
from nltk.corpus import movie_reviews
class TestSentimentMethods(unittest.TestCase):
def test_get_word_dict(self):
word_dict = st.get_word_dict()[0:10]
test_word_list = [(u',', 0), (u'the', 1), (u'.', 2), (u'a', 3),
(u'and', 4), (u'of', 5), (u'to', 6), (u"'", 7),
(u'is', 8), (u'in', 9)]
for idx, each in enumerate(word_dict):
self.assertEqual(each, test_word_list[idx])
self.assertTrue("/root/.cache/paddle/dataset" in nltk.data.path)
def test_sort_files(self):
last_label = ''
for sample_file in st.sort_files():
current_label = sample_file.split("/")[0]
self.assertNotEqual(current_label, last_label)
last_label = current_label
def test_data_set(self):
data_set = st.load_sentiment_data()
last_label = -1
for each in st.test():
self.assertNotEqual(each[1], last_label)
last_label = each[1]
self.assertEqual(len(data_set), st.NUM_TOTAL_INSTANCES)
self.assertEqual(len(list(st.train())), st.NUM_TRAINING_INSTANCES)
self.assertEqual(
len(list(st.test())),
(st.NUM_TOTAL_INSTANCES - st.NUM_TRAINING_INSTANCES))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册