提交 812e21f3 编写于 作者: W wen-bo-yang

add cross reading sample files and fix bugs

上级 55d19fc4
......@@ -72,7 +72,7 @@ setup(name="py_paddle",
packages=['py_paddle'],
include_dirs = include_dirs,
install_requires = [
'nltk',
'nltk>=3.2.2',
'numpy>=1.8.0', # The numpy is required.
'protobuf>=3.0.0' # The paddle protobuf version
],
......
import os
__all__ = ['DATA_HOME']
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
if not os.path.exists(DATA_HOME):
os.makedirs(DATA_HOME)
......@@ -20,9 +20,9 @@ The script fetch and preprocess movie_reviews data set
that provided by NLTK
"""
import nltk
import numpy as np
from itertools import chain
from nltk.corpus import movie_reviews
from config import DATA_HOME
......@@ -50,9 +50,10 @@ def download_data_if_not_yet():
except LookupError:
print "Downloading movie_reviews data set, please wait....."
nltk.download('movie_reviews', download_dir=DATA_HOME)
print "Download data set success......"
# make sure that nltk can find the data
nltk.data.path.append(DATA_HOME)
print "Download data set success....."
print "Path is " + nltk.data.find('corpora/movie_reviews').path
def get_word_dict():
......@@ -67,24 +68,39 @@ def get_word_dict():
words_sort_list = words_freq.items()
words_sort_list.sort(cmp=lambda a, b: b[1] - a[1])
for index, word in enumerate(words_sort_list):
words_freq_sorted.append(word[0])
words_freq_sorted.append((word[0], index + 1))
return words_freq_sorted
def sort_files():
"""
Sorted the sample for cross reading the sample
:return:
files_list
"""
files_list = list()
download_data_if_not_yet()
neg_file_list = movie_reviews.fileids('neg')
pos_file_list = movie_reviews.fileids('pos')
files_list = list(chain.from_iterable(zip(neg_file_list, pos_file_list)))
return files_list
def load_sentiment_data():
"""
Load the data set
:return:
data_set
"""
label_dict = get_label_dict()
data_set = list()
download_data_if_not_yet()
words_freq = nltk.FreqDist(w.lower() for w in movie_reviews.words())
data_set = [([words_freq[word.lower()]
for word in movie_reviews.words(fileid)],
label_dict[category])
for category in movie_reviews.categories()
for fileid in movie_reviews.fileids(category)]
words_ids = dict(get_word_dict())
for sample_file in sort_files():
words_list = list()
category = 0 if 'neg' in sample_file else 1
for word in movie_reviews.words(sample_file):
words_list.append(words_ids[word.lower()])
data_set.append((words_list, category))
return data_set
......@@ -98,9 +114,9 @@ def reader_creator(data):
train data set or test data set
"""
for each in data:
sentences = np.array(each[0], dtype=np.int32)
labels = np.array(each[1], dtype=np.int8)
yield sentences, labels
list_of_int = np.array(each[0], dtype=np.int32)
label = each[1]
yield list_of_int, label
def train():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册