提交 d3c755df 编写于 作者: H hedaoyuan

Refine code

上级 0a33f170
import sys
from os.path import join as join_path from os.path import join as join_path
import paddle.trainer_config_helpers.attrs as attrs import paddle.trainer_config_helpers.attrs as attrs
from paddle.trainer_config_helpers.poolings import MaxPooling from paddle.trainer_config_helpers.poolings import MaxPooling
...@@ -188,16 +189,9 @@ def stacked_lstm_net(input_dim, ...@@ -188,16 +189,9 @@ def stacked_lstm_net(input_dim,
return cost return cost
def data_reader(): def data_reader(data_file, dict_file):
data_dir = "./data/pre-imdb" def reader():
train_file = "train_part_000" with open(dict_file, 'r') as fdict, open(data_file, 'r') as fdata:
test_file = "test_part_000"
dict_file = "dict.txt"
train_file = join_path(data_dir, train_file)
test_file = join_path(data_dir, test_file)
dict_file = join_path(data_dir, dict_file)
with open(dict_file, 'r') as fdict, open(train_file, 'r') as fdata:
dictionary = dict() dictionary = dict()
for i, line in enumerate(fdict): for i, line in enumerate(fdict):
dictionary[line.split('\t')[0]] = i dictionary[line.split('\t')[0]] = i
...@@ -209,67 +203,55 @@ def data_reader(): ...@@ -209,67 +203,55 @@ def data_reader():
word_slot = [dictionary[w] for w in words if w in dictionary] word_slot = [dictionary[w] for w in words if w in dictionary]
yield (word_slot, label) yield (word_slot, label)
return reader
def test_reader():
data_dir = "./data/pre-imdb"
train_file = "train_part_000"
test_file = "test_part_000"
dict_file = "dict.txt"
train_file = join_path(data_dir, train_file)
test_file = join_path(data_dir, test_file)
dict_file = join_path(data_dir, dict_file)
with open(dict_file, 'r') as fdict, open(test_file, 'r') as ftest:
dictionary = dict()
for i, line in enumerate(fdict):
dictionary[line.split('\t')[0]] = i
for line_count, line in enumerate(ftest):
label, comment = line.strip().split('\t\t')
label = int(label)
words = comment.split()
word_slot = [dictionary[w] for w in words if w in dictionary]
yield (word_slot, label)
if __name__ == '__main__': if __name__ == '__main__':
data_dir = "./data/pre-imdb" # data file
train_list = "train.list" train_file = "./data/pre-imdb/train_part_000"
test_list = "test.list" test_file = "./data/pre-imdb/test_part_000"
dict_file = "dict.txt" dict_file = "./data/pre-imdb/dict.txt"
dict_dim = len(open(join_path(data_dir, "dict.txt")).readlines()) labels = "./data/pre-imdb/labels.list"
class_dim = len(open(join_path(data_dir, 'labels.list')).readlines())
is_predict = False
# init # init
paddle.init(use_gpu=True, trainer_count=4) paddle.init(use_gpu=True, trainer_count=4)
# network config # network config
# cost = convolution_net(dict_dim, class_dim=class_dim, is_predict=is_predict) dict_dim = len(open(dict_file).readlines())
cost = stacked_lstm_net( class_dim = len(open(labels).readlines())
dict_dim, class_dim=class_dim, stacked_num=3, is_predict=is_predict)
# Please choose the way to build the network
# by uncommenting the corresponding line.
cost = convolution_net(dict_dim, class_dim=class_dim)
# cost = stacked_lstm_net(dict_dim, class_dim=class_dim, stacked_num=3)
# create parameters # create parameters
parameters = paddle.parameters.create(cost) parameters = paddle.parameters.create(cost)
# create optimizer
adam_optimizer = paddle.optimizer.Adam( adam_optimizer = paddle.optimizer.Adam(
learning_rate=2e-3, learning_rate=2e-3,
regularization=paddle.optimizer.L2Regularization(rate=8e-4), regularization=paddle.optimizer.L2Regularization(rate=8e-4),
model_average=paddle.optimizer.ModelAverage(average_window=0.5)) model_average=paddle.optimizer.ModelAverage(average_window=0.5))
# End batch and end pass event handler
def event_handler(event): def event_handler(event):
if isinstance(event, paddle.event.EndIteration): if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 100 == 0: if event.batch_id % 100 == 0:
print "Pass %d, Batch %d, Cost %f, %s" % ( print "\nPass %d, Batch %d, Cost %f, %s" % (
event.pass_id, event.batch_id, event.cost, event.metrics) event.pass_id, event.batch_id, event.cost, event.metrics)
else:
sys.stdout.write('.')
sys.stdout.flush()
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
result = trainer.test( result = trainer.test(
reader=paddle.reader.batched( reader=paddle.reader.batched(
test_reader, batch_size=128), data_reader(test_file, dict_file), batch_size=128),
reader_dict={'word': 0, reader_dict={'word': 0,
'label': 1}) 'label': 1})
print "Test with Pass %d, %s" % (event.pass_id, result.metrics) print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics)
# create trainer
trainer = paddle.trainer.SGD(cost=cost, trainer = paddle.trainer.SGD(cost=cost,
parameters=parameters, parameters=parameters,
update_equation=adam_optimizer) update_equation=adam_optimizer)
...@@ -277,7 +259,8 @@ if __name__ == '__main__': ...@@ -277,7 +259,8 @@ if __name__ == '__main__':
trainer.train( trainer.train(
reader=paddle.reader.batched( reader=paddle.reader.batched(
paddle.reader.shuffle( paddle.reader.shuffle(
data_reader, buf_size=4096), batch_size=128), data_reader(train_file, dict_file), buf_size=4096),
batch_size=128),
event_handler=event_handler, event_handler=event_handler,
reader_dict={'word': 0, reader_dict={'word': 0,
'label': 1}, 'label': 1},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册