From 8fa09b82f9c182b8bec4b538b7c7da9ace03b4eb Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Sat, 4 Mar 2017 21:50:56 +0800 Subject: [PATCH] add some comment for api_train_v2 of seqtoseq --- demo/seqToseq/api_train_v2.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/demo/seqToseq/api_train_v2.py b/demo/seqToseq/api_train_v2.py index 8b9db4352..a5f59ec37 100644 --- a/demo/seqToseq/api_train_v2.py +++ b/demo/seqToseq/api_train_v2.py @@ -73,31 +73,34 @@ def main(): cost = seqToseq_net_v2(source_dict_dim, target_dict_dim) parameters = paddle.parameters.create(cost) + # define optimize method and trainer optimizer = paddle.optimizer.Adam(learning_rate=1e-4) - - def event_handler(event): - if isinstance(event, paddle.event.EndIteration): - if event.batch_id % 10 == 0: - print "Pass %d, Batch %d, Cost %f, %s" % ( - event.pass_id, event.batch_id, event.cost, event.metrics) - trainer = paddle.trainer.SGD(cost=cost, parameters=parameters, update_equation=optimizer) + # define data reader reader_dict = { 'source_language_word': 0, 'target_language_word': 1, 'target_language_next_word': 2 } - trn_reader = paddle.reader.batched( + wmt14_reader = paddle.reader.batched( paddle.reader.shuffle( train_reader("data/pre-wmt14/train/train"), buf_size=8192), batch_size=5) + # define event_handler callback + def event_handler(event): + if isinstance(event, paddle.event.EndIteration): + if event.batch_id % 10 == 0: + print "Pass %d, Batch %d, Cost %f, %s" % ( + event.pass_id, event.batch_id, event.cost, event.metrics) + + # start to train trainer.train( - reader=trn_reader, + reader=wmt14_reader, event_handler=event_handler, num_passes=10000, reader_dict=reader_dict) -- GitLab