提交 97a1c0c7 编写于 作者: H Helin Wang

update word2vec to work with current train api

上级 797da069
......@@ -58,7 +58,8 @@ def main():
if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 100 == 0:
result = trainer.test(
paddle.dataset.imikolov.test(word_dict, N), 128)
paddle.batch(
paddle.dataset.imikolov.test(word_dict, N), 32))
print "Pass %d, Batch %d, Cost %f, %s, Testing metrics %s" % (
event.pass_id, event.batch_id, event.cost, event.metrics,
result.metrics)
......@@ -70,10 +71,9 @@ def main():
regularization=paddle.optimizer.L2Regularization(8e-4))
trainer = paddle.trainer.SGD(cost, parameters, adam_optimizer)
trainer.train(
paddle.dataset.imikolov.train(word_dict, N),
paddle.batch(paddle.dataset.imikolov.train(word_dict, N), 32),
num_passes=30,
batch_size=128,
event_handler=event_handler, )
event_handler=event_handler)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册