diff --git a/demo/seqToseq/api_train_v2.py b/demo/seqToseq/api_train_v2.py index 8b9db4352ddb1293c86a66f29176d8ec6e9e5ec8..a5f59ec379738eb5bed3e7559739cae38582ed06 100644 --- a/demo/seqToseq/api_train_v2.py +++ b/demo/seqToseq/api_train_v2.py @@ -73,31 +73,34 @@ def main(): cost = seqToseq_net_v2(source_dict_dim, target_dict_dim) parameters = paddle.parameters.create(cost) + # define optimize method and trainer optimizer = paddle.optimizer.Adam(learning_rate=1e-4) - - def event_handler(event): - if isinstance(event, paddle.event.EndIteration): - if event.batch_id % 10 == 0: - print "Pass %d, Batch %d, Cost %f, %s" % ( - event.pass_id, event.batch_id, event.cost, event.metrics) - trainer = paddle.trainer.SGD(cost=cost, parameters=parameters, update_equation=optimizer) + # define data reader reader_dict = { 'source_language_word': 0, 'target_language_word': 1, 'target_language_next_word': 2 } - trn_reader = paddle.reader.batched( + wmt14_reader = paddle.reader.batched( paddle.reader.shuffle( train_reader("data/pre-wmt14/train/train"), buf_size=8192), batch_size=5) + # define event_handler callback + def event_handler(event): + if isinstance(event, paddle.event.EndIteration): + if event.batch_id % 10 == 0: + print "Pass %d, Batch %d, Cost %f, %s" % ( + event.pass_id, event.batch_id, event.cost, event.metrics) + + # start to train trainer.train( - reader=trn_reader, + reader=wmt14_reader, event_handler=event_handler, num_passes=10000, reader_dict=reader_dict)