提交 4a265b52 编写于 作者: H hedaoyuan

Use reader in dataset imdb.py

上级 d3c755df
...@@ -2,10 +2,11 @@ import sys ...@@ -2,10 +2,11 @@ import sys
from os.path import join as join_path from os.path import join as join_path
import paddle.trainer_config_helpers.attrs as attrs import paddle.trainer_config_helpers.attrs as attrs
from paddle.trainer_config_helpers.poolings import MaxPooling from paddle.trainer_config_helpers.poolings import MaxPooling
import paddle.v2 as paddle
import paddle.v2.layer as layer import paddle.v2.layer as layer
import paddle.v2.activation as activation import paddle.v2.activation as activation
import paddle.v2.data_type as data_type import paddle.v2.data_type as data_type
import paddle.v2.dataset.imdb as imdb
import paddle.v2 as paddle
def sequence_conv_pool(input, def sequence_conv_pool(input,
...@@ -189,36 +190,15 @@ def stacked_lstm_net(input_dim, ...@@ -189,36 +190,15 @@ def stacked_lstm_net(input_dim,
return cost return cost
def data_reader(data_file, dict_file):
def reader():
with open(dict_file, 'r') as fdict, open(data_file, 'r') as fdata:
dictionary = dict()
for i, line in enumerate(fdict):
dictionary[line.split('\t')[0]] = i
for line_count, line in enumerate(fdata):
label, comment = line.strip().split('\t\t')
label = int(label)
words = comment.split()
word_slot = [dictionary[w] for w in words if w in dictionary]
yield (word_slot, label)
return reader
if __name__ == '__main__': if __name__ == '__main__':
# data file
train_file = "./data/pre-imdb/train_part_000"
test_file = "./data/pre-imdb/test_part_000"
dict_file = "./data/pre-imdb/dict.txt"
labels = "./data/pre-imdb/labels.list"
# init # init
paddle.init(use_gpu=True, trainer_count=4) paddle.init(use_gpu=True, trainer_count=4)
# network config # network config
dict_dim = len(open(dict_file).readlines()) print 'load dictionary...'
class_dim = len(open(labels).readlines()) word_dict = imdb.word_dict()
dict_dim = len(word_dict)
class_dim = 2
# Please choose the way to build the network # Please choose the way to build the network
# by uncommenting the corresponding line. # by uncommenting the corresponding line.
...@@ -246,7 +226,7 @@ if __name__ == '__main__': ...@@ -246,7 +226,7 @@ if __name__ == '__main__':
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
result = trainer.test( result = trainer.test(
reader=paddle.reader.batched( reader=paddle.reader.batched(
data_reader(test_file, dict_file), batch_size=128), lambda: imdb.test(word_dict), batch_size=128),
reader_dict={'word': 0, reader_dict={'word': 0,
'label': 1}) 'label': 1})
print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics) print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics)
...@@ -259,8 +239,8 @@ if __name__ == '__main__': ...@@ -259,8 +239,8 @@ if __name__ == '__main__':
trainer.train( trainer.train(
reader=paddle.reader.batched( reader=paddle.reader.batched(
paddle.reader.shuffle( paddle.reader.shuffle(
data_reader(train_file, dict_file), buf_size=4096), lambda: imdb.train(word_dict), buf_size=1000),
batch_size=128), batch_size=100),
event_handler=event_handler, event_handler=event_handler,
reader_dict={'word': 0, reader_dict={'word': 0,
'label': 1}, 'label': 1},
......
...@@ -118,3 +118,8 @@ def test(word_idx): ...@@ -118,3 +118,8 @@ def test(word_idx):
return reader_creator( return reader_creator(
re.compile("aclImdb/test/pos/.*\.txt$"), re.compile("aclImdb/test/pos/.*\.txt$"),
re.compile("aclImdb/test/neg/.*\.txt$"), word_idx, 1000) re.compile("aclImdb/test/neg/.*\.txt$"), word_idx, 1000)
def word_dict():
return build_dict(
re.compile("aclImdb/((train)|(test))/((pos)|(neg))/.*\.txt$"), 150)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册