From 6115fcc5a73497157718eadb3bd596311ea83a55 Mon Sep 17 00:00:00 2001 From: wen-bo-yang Date: Thu, 2 Mar 2017 04:11:11 +0000 Subject: [PATCH] format by yapf --- python/paddle/v2/dataset/sentiment.py | 51 ++++++------------ .../paddle/v2/dataset/tests/test_sentiment.py | 52 +++++++++++++++++++ 2 files changed, 69 insertions(+), 34 deletions(-) create mode 100644 python/paddle/v2/dataset/tests/test_sentiment.py diff --git a/python/paddle/v2/dataset/sentiment.py b/python/paddle/v2/dataset/sentiment.py index 9825d2ef96a..1e7f222f4d2 100644 --- a/python/paddle/v2/dataset/sentiment.py +++ b/python/paddle/v2/dataset/sentiment.py @@ -20,38 +20,30 @@ The script fetch and preprocess movie_reviews data set that provided by NLTK """ +import paddle.v2.dataset.common as common +import collections import nltk import numpy as np from itertools import chain from nltk.corpus import movie_reviews -from config import DATA_HOME -__all__ = ['train', 'test', 'get_label_dict', 'get_word_dict'] +__all__ = ['train', 'test', 'get_word_dict'] NUM_TRAINING_INSTANCES = 1600 NUM_TOTAL_INSTANCES = 2000 -def get_label_dict(): - """ - Define the labels dict for dataset - """ - label_dict = {'neg': 0, 'pos': 1} - return label_dict - - def download_data_if_not_yet(): """ Download the data set, if the data set is not download. """ try: # make sure that nltk can find the data - nltk.data.path.append(DATA_HOME) + if common.DATA_HOME not in nltk.data.path: + nltk.data.path.append(common.DATA_HOME) movie_reviews.categories() except LookupError: print "Downloading movie_reviews data set, please wait....." - nltk.download('movie_reviews', download_dir=DATA_HOME) - # make sure that nltk can find the data - nltk.data.path.append(DATA_HOME) + nltk.download('movie_reviews', download_dir=common.DATA_HOME) print "Download data set success....." print "Path is " + nltk.data.find('corpora/movie_reviews').path @@ -63,12 +55,17 @@ def get_word_dict(): words_freq_sorted """ words_freq_sorted = list() + word_freq_dict = collections.defaultdict(int) download_data_if_not_yet() - words_freq = nltk.FreqDist(w.lower() for w in movie_reviews.words()) - words_sort_list = words_freq.items() + + for category in movie_reviews.categories(): + for field in movie_reviews.fileids(category): + for words in movie_reviews.words(field): + word_freq_dict[words] += 1 + words_sort_list = word_freq_dict.items() words_sort_list.sort(cmp=lambda a, b: b[1] - a[1]) for index, word in enumerate(words_sort_list): - words_freq_sorted.append((word[0], index + 1)) + words_freq_sorted.append((word[0], index)) return words_freq_sorted @@ -79,7 +76,6 @@ def sort_files(): files_list """ files_list = list() - download_data_if_not_yet() neg_file_list = movie_reviews.fileids('neg') pos_file_list = movie_reviews.fileids('pos') files_list = list(chain.from_iterable(zip(neg_file_list, pos_file_list))) @@ -104,9 +100,6 @@ def load_sentiment_data(): return data_set -data_set = load_sentiment_data() - - def reader_creator(data): """ Reader creator, it format data set to numpy @@ -114,15 +107,14 @@ def reader_creator(data): train data set or test data set """ for each in data: - list_of_int = np.array(each[0], dtype=np.int32) - label = each[1] - yield list_of_int, label + yield each[0], each[1] def train(): """ Default train set reader creator """ + data_set = load_sentiment_data() return reader_creator(data_set[0:NUM_TRAINING_INSTANCES]) @@ -130,14 +122,5 @@ def test(): """ Default test set reader creator """ + data_set = load_sentiment_data() return reader_creator(data_set[NUM_TRAINING_INSTANCES:]) - - -def unittest(): - assert len(data_set) == NUM_TOTAL_INSTANCES - assert len(list(train())) == NUM_TRAINING_INSTANCES - assert len(list(test())) == NUM_TOTAL_INSTANCES - NUM_TRAINING_INSTANCES - - -if __name__ == '__main__': - unittest() diff --git a/python/paddle/v2/dataset/tests/test_sentiment.py b/python/paddle/v2/dataset/tests/test_sentiment.py new file mode 100644 index 00000000000..48a14aad2a9 --- /dev/null +++ b/python/paddle/v2/dataset/tests/test_sentiment.py @@ -0,0 +1,52 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import nltk +import paddle.v2.dataset.sentiment as st +from nltk.corpus import movie_reviews + + +class TestSentimentMethods(unittest.TestCase): + def test_get_word_dict(self): + word_dict = st.get_word_dict()[0:10] + test_word_list = [(u',', 0), (u'the', 1), (u'.', 2), (u'a', 3), + (u'and', 4), (u'of', 5), (u'to', 6), (u"'", 7), + (u'is', 8), (u'in', 9)] + for idx, each in enumerate(word_dict): + self.assertEqual(each, test_word_list[idx]) + self.assertTrue("/root/.cache/paddle/dataset" in nltk.data.path) + + def test_sort_files(self): + last_label = '' + for sample_file in st.sort_files(): + current_label = sample_file.split("/")[0] + self.assertNotEqual(current_label, last_label) + last_label = current_label + + def test_data_set(self): + data_set = st.load_sentiment_data() + last_label = -1 + for each in st.test(): + self.assertNotEqual(each[1], last_label) + last_label = each[1] + self.assertEqual(len(data_set), st.NUM_TOTAL_INSTANCES) + self.assertEqual(len(list(st.train())), st.NUM_TRAINING_INSTANCES) + self.assertEqual( + len(list(st.test())), + (st.NUM_TOTAL_INSTANCES - st.NUM_TRAINING_INSTANCES)) + + +if __name__ == '__main__': + unittest.main() -- GitLab