提交 55d19fc4 编写于 作者: W wen-bo-yang

fix bugs

上级 a6f25f3d
......@@ -72,7 +72,6 @@ setup(name="py_paddle",
packages=['py_paddle'],
include_dirs = include_dirs,
install_requires = [
'h5py',
'nltk',
'numpy>=1.8.0', # The numpy is required.
'protobuf>=3.0.0' # The paddle protobuf version
......
......@@ -2,7 +2,7 @@ import os
__all__ = ['DATA_HOME']
DATA_HOME = os.path.expanduser('~/.cache/paddle_data_set')
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
if not os.path.exists(DATA_HOME):
os.makedirs(DATA_HOME)
import random
# /usr/bin/env python
# -*- coding:utf-8 -*-
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The script fetch and preprocess movie_reviews data set
that provided by NLTK
"""
import nltk
import numpy as np
from nltk.corpus import movie_reviews
from config import DATA_HOME
__all__ = ['train', 'test', 'get_label_dict', 'get_word_dict']
SPLIT_NUM = 800
TOTAL_DATASET_NUM = 1000
NUM_TRAINING_INSTANCES = 1600
NUM_TOTAL_INSTANCES = 2000
def get_label_dict():
"""
Define the labels dict for dataset
"""
label_dict = {'neg': 0, 'pos': 1}
return label_dict
def is_download_data():
def download_data_if_not_yet():
"""
Download the data set, if the data set is not download.
"""
try:
# make sure that nltk can find the data
nltk.data.path.append(DATA_HOME)
movie_reviews.categories()
except LookupError:
print "dd"
print "Downloading movie_reviews data set, please wait....."
nltk.download('movie_reviews', download_dir=DATA_HOME)
print "Download data set success......"
# make sure that nltk can find the data
nltk.data.path.append(DATA_HOME)
def get_word_dict():
"""
Sorted the words by the frequency of words which occur in sample
:return:
words_freq_sorted
"""
words_freq_sorted = list()
is_download_data()
download_data_if_not_yet()
words_freq = nltk.FreqDist(w.lower() for w in movie_reviews.words())
words_sort_list = words_freq.items()
words_sort_list.sort(cmp=lambda a, b: b[1] - a[1])
print words_sort_list
for index, word in enumerate(words_sort_list):
words_freq_sorted.append(word[0])
return words_freq_sorted
def load_sentiment_data():
"""
Load the data set
:return:
data_set
"""
label_dict = get_label_dict()
is_download_data()
download_data_if_not_yet()
words_freq = nltk.FreqDist(w.lower() for w in movie_reviews.words())
data_set = [([words_freq[word]
for word in movie_reviews.words(fileid)], label_dict[category])
data_set = [([words_freq[word.lower()]
for word in movie_reviews.words(fileid)],
label_dict[category])
for category in movie_reviews.categories()
for fileid in movie_reviews.fileids(category)]
random.shuffle(data_set)
return data_set
data_set = load_sentiment_data()
def reader_creator(data_type):
if data_type == 'train':
for each in data_set[0:SPLIT_NUM]:
train_sentences = np.array(each[0], dtype=np.int32)
train_label = np.array(each[1], dtype=np.int8)
yield train_sentences, train_label
else:
for each in data_set[SPLIT_NUM:]:
test_sentences = np.array(each[0], dtype=np.int32)
test_label = np.array(each[1], dtype=np.int8)
yield test_sentences, test_label
def reader_creator(data):
"""
Reader creator, it format data set to numpy
:param data:
train data set or test data set
"""
for each in data:
sentences = np.array(each[0], dtype=np.int32)
labels = np.array(each[1], dtype=np.int8)
yield sentences, labels
def train():
return reader_creator('train')
"""
Default train set reader creator
"""
return reader_creator(data_set[0:NUM_TRAINING_INSTANCES])
def test():
return reader_creator('test')
"""
Default test set reader creator
"""
return reader_creator(data_set[NUM_TRAINING_INSTANCES:])
def unittest():
assert len(data_set) == NUM_TOTAL_INSTANCES
assert len(list(train())) == NUM_TRAINING_INSTANCES
assert len(list(test())) == NUM_TOTAL_INSTANCES - NUM_TRAINING_INSTANCES
if __name__ == '__main__':
for train in train():
print "train"
print train
for test in test():
print "test"
print test
unittest()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册