提交 803da664 编写于 作者: H hedaoyuan

Add test

上级 1d0a8c2f
......@@ -142,6 +142,28 @@ def data_reader():
yield (word_slot, label)
def test_reader():
data_dir = "./data/pre-imdb"
train_file = "train_part_000"
test_file = "test_part_000"
dict_file = "dict.txt"
train_file = join_path(data_dir, train_file)
test_file = join_path(data_dir, test_file)
dict_file = join_path(data_dir, dict_file)
with open(dict_file, 'r') as fdict, open(test_file, 'r') as ftest:
dictionary = dict()
for i, line in enumerate(fdict):
dictionary[line.split('\t')[0]] = i
for line_count, line in enumerate(ftest):
label, comment = line.strip().split('\t\t')
label = int(label)
words = comment.split()
word_slot = [dictionary[w] for w in words if w in dictionary]
yield (word_slot, label)
if __name__ == '__main__':
data_dir = "./data/pre-imdb"
train_list = "train.list"
......@@ -170,6 +192,13 @@ if __name__ == '__main__':
if event.batch_id % 100 == 0:
print "Pass %d, Batch %d, Cost %f, %s" % (
event.pass_id, event.batch_id, event.cost, event.metrics)
if isinstance(event, paddle.event.EndPass):
result = trainer.test(
reader=paddle.reader.batched(
test_reader, batch_size=128),
reader_dict={'word': 0,
'label': 1})
print "Test with Pass %d, %s" % (event.pass_id, result.metrics)
trainer = paddle.trainer.SGD(cost=cost,
parameters=parameters,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册