diff --git a/demo/word2vec/train_v2.py b/demo/word2vec/train_v2.py index 60235c93eab83d04599f2d451c73b64e6315ab87..7d952b446f9db432062fc3305a6b65b0ad66dd47 100644 --- a/demo/word2vec/train_v2.py +++ b/demo/word2vec/train_v2.py @@ -58,7 +58,8 @@ def main(): if isinstance(event, paddle.event.EndIteration): if event.batch_id % 100 == 0: result = trainer.test( - paddle.dataset.imikolov.test(word_dict, N), 128) + paddle.batch( + paddle.dataset.imikolov.test(word_dict, N), 32)) print "Pass %d, Batch %d, Cost %f, %s, Testing metrics %s" % ( event.pass_id, event.batch_id, event.cost, event.metrics, result.metrics) @@ -70,10 +71,9 @@ def main(): regularization=paddle.optimizer.L2Regularization(8e-4)) trainer = paddle.trainer.SGD(cost, parameters, adam_optimizer) trainer.train( - paddle.dataset.imikolov.train(word_dict, N), + paddle.batch(paddle.dataset.imikolov.train(word_dict, N), 32), num_passes=30, - batch_size=128, - event_handler=event_handler, ) + event_handler=event_handler) if __name__ == '__main__':