diff --git a/04.word2vec/train.py b/04.word2vec/train.py index 3600025863cd91e9b2e2c1c0ffb19af9fc28070d..eb596673ce8ad55dfe8c5c258beb344e825b7c25 100644 --- a/04.word2vec/train.py +++ b/04.word2vec/train.py @@ -12,17 +12,15 @@ def wordemb(inlayer): input=inlayer, size=embsize, param_attr=paddle.attr.Param( - name="_proj", - initial_std=0.001, - learning_rate=1, - l2_rate=0, )) + name="_proj", initial_std=0.001, learning_rate=1, l2_rate=0)) return wordemb def main(): - paddle.init(use_gpu=False, trainer_count=1) + paddle.init(use_gpu=False, trainer_count=3) word_dict = paddle.dataset.imikolov.build_dict() dict_size = len(word_dict) + # Every layer takes integer value of range [0, dict_size) firstword = paddle.layer.data( name="firstw", type=paddle.data_type.integer_value(dict_size)) secondword = paddle.layer.data( @@ -57,22 +55,26 @@ def main(): def event_handler(event): if isinstance(event, paddle.event.EndIteration): if event.batch_id % 100 == 0: - result = trainer.test( - paddle.batch( - paddle.dataset.imikolov.test(word_dict, N), 32)) - print "Pass %d, Batch %d, Cost %f, %s, Testing metrics %s" % ( - event.pass_id, event.batch_id, event.cost, event.metrics, - result.metrics) + print "Pass %d, Batch %d, Cost %f, %s" % ( + event.pass_id, event.batch_id, event.cost, event.metrics) + + if isinstance(event, paddle.event.EndPass): + result = trainer.test( + paddle.batch(paddle.dataset.imikolov.test(word_dict, N), 32)) + print "Pass %d, Testing metrics %s" % (event.pass_id, + result.metrics) + with open("model_%d.tar" % event.pass_id, 'w') as f: + parameters.to_tar(f) cost = paddle.layer.classification_cost(input=predictword, label=nextword) parameters = paddle.parameters.create(cost) - adam_optimizer = paddle.optimizer.Adam( + adagrad = paddle.optimizer.AdaGrad( learning_rate=3e-3, regularization=paddle.optimizer.L2Regularization(8e-4)) - trainer = paddle.trainer.SGD(cost, parameters, adam_optimizer) + trainer = paddle.trainer.SGD(cost, parameters, adagrad) trainer.train( paddle.batch(paddle.dataset.imikolov.train(word_dict, N), 32), - num_passes=30, + num_passes=100, event_handler=event_handler)