提交 d3c755df 编写于 作者: H hedaoyuan

Refine code

上级 0a33f170
import sys
from os.path import join as join_path
import paddle.trainer_config_helpers.attrs as attrs
from paddle.trainer_config_helpers.poolings import MaxPooling
......@@ -188,88 +189,69 @@ def stacked_lstm_net(input_dim,
return cost
def data_reader():
data_dir = "./data/pre-imdb"
train_file = "train_part_000"
test_file = "test_part_000"
dict_file = "dict.txt"
train_file = join_path(data_dir, train_file)
test_file = join_path(data_dir, test_file)
dict_file = join_path(data_dir, dict_file)
with open(dict_file, 'r') as fdict, open(train_file, 'r') as fdata:
dictionary = dict()
for i, line in enumerate(fdict):
dictionary[line.split('\t')[0]] = i
for line_count, line in enumerate(fdata):
label, comment = line.strip().split('\t\t')
label = int(label)
words = comment.split()
word_slot = [dictionary[w] for w in words if w in dictionary]
yield (word_slot, label)
def test_reader():
data_dir = "./data/pre-imdb"
train_file = "train_part_000"
test_file = "test_part_000"
dict_file = "dict.txt"
train_file = join_path(data_dir, train_file)
test_file = join_path(data_dir, test_file)
dict_file = join_path(data_dir, dict_file)
with open(dict_file, 'r') as fdict, open(test_file, 'r') as ftest:
dictionary = dict()
for i, line in enumerate(fdict):
dictionary[line.split('\t')[0]] = i
for line_count, line in enumerate(ftest):
label, comment = line.strip().split('\t\t')
label = int(label)
words = comment.split()
word_slot = [dictionary[w] for w in words if w in dictionary]
yield (word_slot, label)
def data_reader(data_file, dict_file):
def reader():
with open(dict_file, 'r') as fdict, open(data_file, 'r') as fdata:
dictionary = dict()
for i, line in enumerate(fdict):
dictionary[line.split('\t')[0]] = i
for line_count, line in enumerate(fdata):
label, comment = line.strip().split('\t\t')
label = int(label)
words = comment.split()
word_slot = [dictionary[w] for w in words if w in dictionary]
yield (word_slot, label)
return reader
if __name__ == '__main__':
data_dir = "./data/pre-imdb"
train_list = "train.list"
test_list = "test.list"
dict_file = "dict.txt"
dict_dim = len(open(join_path(data_dir, "dict.txt")).readlines())
class_dim = len(open(join_path(data_dir, 'labels.list')).readlines())
is_predict = False
# data file
train_file = "./data/pre-imdb/train_part_000"
test_file = "./data/pre-imdb/test_part_000"
dict_file = "./data/pre-imdb/dict.txt"
labels = "./data/pre-imdb/labels.list"
# init
paddle.init(use_gpu=True, trainer_count=4)
# network config
# cost = convolution_net(dict_dim, class_dim=class_dim, is_predict=is_predict)
cost = stacked_lstm_net(
dict_dim, class_dim=class_dim, stacked_num=3, is_predict=is_predict)
dict_dim = len(open(dict_file).readlines())
class_dim = len(open(labels).readlines())
# Please choose the way to build the network
# by uncommenting the corresponding line.
cost = convolution_net(dict_dim, class_dim=class_dim)
# cost = stacked_lstm_net(dict_dim, class_dim=class_dim, stacked_num=3)
# create parameters
parameters = paddle.parameters.create(cost)
# create optimizer
adam_optimizer = paddle.optimizer.Adam(
learning_rate=2e-3,
regularization=paddle.optimizer.L2Regularization(rate=8e-4),
model_average=paddle.optimizer.ModelAverage(average_window=0.5))
# End batch and end pass event handler
def event_handler(event):
if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 100 == 0:
print "Pass %d, Batch %d, Cost %f, %s" % (
print "\nPass %d, Batch %d, Cost %f, %s" % (
event.pass_id, event.batch_id, event.cost, event.metrics)
else:
sys.stdout.write('.')
sys.stdout.flush()
if isinstance(event, paddle.event.EndPass):
result = trainer.test(
reader=paddle.reader.batched(
test_reader, batch_size=128),
data_reader(test_file, dict_file), batch_size=128),
reader_dict={'word': 0,
'label': 1})
print "Test with Pass %d, %s" % (event.pass_id, result.metrics)
print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics)
# create trainer
trainer = paddle.trainer.SGD(cost=cost,
parameters=parameters,
update_equation=adam_optimizer)
......@@ -277,7 +259,8 @@ if __name__ == '__main__':
trainer.train(
reader=paddle.reader.batched(
paddle.reader.shuffle(
data_reader, buf_size=4096), batch_size=128),
data_reader(train_file, dict_file), buf_size=4096),
batch_size=128),
event_handler=event_handler,
reader_dict={'word': 0,
'label': 1},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册