Commit 812e21f3 authored by wen-bo-yang

add cross reading sample files and fix bugs

Parent 55d19fc4
...
@@ -72,7 +72,7 @@ setup(name="py_paddle",
       packages=['py_paddle'],
       include_dirs = include_dirs,
       install_requires = [
-        'nltk',
+        'nltk>=3.2.2',
         'numpy>=1.8.0',     # The numpy is required.
         'protobuf>=3.0.0'   # The paddle protobuf version
       ],
...
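For reference (not part of the commit): a one-liner to check whether an already-installed environment satisfies the new nltk floor, using pkg_resources from setuptools, which setup.py already depends on:

    import pkg_resources
    # Raises VersionConflict (or DistributionNotFound) if the pin is unmet.
    pkg_resources.require('nltk>=3.2.2')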
import os
__all__ = ['DATA_HOME']
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
if not os.path.exists(DATA_HOME):
    os.makedirs(DATA_HOME)
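A side note, not in the commit: the exists()/makedirs() pair can race when two processes import this module at once. A more robust Python 2 sketch catches EEXIST instead of checking first:

    import errno

    try:
        os.makedirs(DATA_HOME)
    except OSError as e:
        if e.errno != errno.EEXIST:  # ignore "already exists", re-raise the rest
            raise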
...
@@ -20,9 +20,9 @@
 The script fetch and preprocess movie_reviews data set
 that provided by NLTK
 """
 import nltk
 import numpy as np
+from itertools import chain
 from nltk.corpus import movie_reviews
 from config import DATA_HOME
...
@@ -50,9 +50,10 @@ def download_data_if_not_yet():
     except LookupError:
         print "Downloading movie_reviews data set, please wait....."
         nltk.download('movie_reviews', download_dir=DATA_HOME)
-        print "Download data set success......"
         # make sure that nltk can find the data
         nltk.data.path.append(DATA_HOME)
+        print "Download data set success....."
+        print "Path is " + nltk.data.find('corpora/movie_reviews').path
...
@@ -67,24 +68,39 @@ def get_word_dict():
     words_sort_list = words_freq.items()
     words_sort_list.sort(cmp=lambda a, b: b[1] - a[1])
     for index, word in enumerate(words_sort_list):
-        words_freq_sorted.append(word[0])
+        words_freq_sorted.append((word[0], index + 1))
     return words_freq_sorted
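The changed line makes get_word_dict() return (word, id) pairs, ranked by descending frequency with ids starting at 1, so callers can build a lookup table via dict(). A toy illustration (ours, not from the commit):

    freq = {'the': 3, 'movie': 2, 'bad': 1}
    pairs = freq.items()
    pairs.sort(cmp=lambda a, b: b[1] - a[1])  # most frequent first (Python 2)
    ids = [(word, index + 1) for index, (word, _) in enumerate(pairs)]
    # ids == [('the', 1), ('movie', 2), ('bad', 3)]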
+def sort_files():
+    """
+    Sort the sample files so negative and positive samples alternate
+    (cross reading)
+    :return:
+    files_list
+    """
+    download_data_if_not_yet()
+    neg_file_list = movie_reviews.fileids('neg')
+    pos_file_list = movie_reviews.fileids('pos')
+    files_list = list(chain.from_iterable(zip(neg_file_list, pos_file_list)))
+    return files_list
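chain.from_iterable(zip(neg, pos)) interleaves the two file lists one-for-one, which is the "cross reading" order the commit title refers to; note that zip() truncates to the shorter list, so an unpaired tail would be dropped. A toy run for illustration:

    from itertools import chain
    neg = ['neg/cv000.txt', 'neg/cv001.txt']
    pos = ['pos/cv000.txt', 'pos/cv001.txt']
    print list(chain.from_iterable(zip(neg, pos)))
    # ['neg/cv000.txt', 'pos/cv000.txt', 'neg/cv001.txt', 'pos/cv001.txt']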
 def load_sentiment_data():
     """
     Load the data set
     :return:
     data_set
     """
-    label_dict = get_label_dict()
+    data_set = list()
     download_data_if_not_yet()
-    words_freq = nltk.FreqDist(w.lower() for w in movie_reviews.words())
-    data_set = [([words_freq[word.lower()]
-                  for word in movie_reviews.words(fileid)],
-                 label_dict[category])
-                for category in movie_reviews.categories()
-                for fileid in movie_reviews.fileids(category)]
+    words_ids = dict(get_word_dict())
+    for sample_file in sort_files():
+        words_list = list()
+        category = 0 if 'neg' in sample_file else 1
+        for word in movie_reviews.words(sample_file):
+            words_list.append(words_ids[word.lower()])
+        data_set.append((words_list, category))
     return data_set
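Each element of data_set is now a ([word ids], label) pair, with label 0 for negative and 1 for positive; because sort_files() alternates the classes, any contiguous slice stays roughly label-balanced. A hypothetical split that relies on that property (helper name is ours, not the commit's):

    def split_data(data_set, ratio=0.8):
        # A plain slice keeps labels balanced thanks to the interleaved order.
        cut = int(len(data_set) * ratio)
        return data_set[:cut], data_set[cut:]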
...
@@ -98,9 +114,9 @@ def reader_creator(data):
     train data set or test data set
     """
     for each in data:
-        sentences = np.array(each[0], dtype=np.int32)
-        labels = np.array(each[1], dtype=np.int8)
-        yield sentences, labels
+        list_of_int = np.array(each[0], dtype=np.int32)
+        label = each[1]
+        yield list_of_int, label


 def train():
...
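reader_creator() now yields the label as a plain int instead of an np.int8 array. A minimal consumption sketch, assuming it is fed the list from load_sentiment_data() (that wiring is our assumption; the call site is collapsed above):

    data = load_sentiment_data()
    for word_ids, label in reader_creator(data):
        assert word_ids.dtype == np.int32  # ids stay a numpy array
        assert label in (0, 1)             # label is now a bare int
        break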