提交 8fa09b82 编写于 作者: Q qiaolongfei

add some comment for api_train_v2 of seqtoseq

上级 c5fb4fd4
...@@ -73,31 +73,34 @@ def main(): ...@@ -73,31 +73,34 @@ def main():
cost = seqToseq_net_v2(source_dict_dim, target_dict_dim) cost = seqToseq_net_v2(source_dict_dim, target_dict_dim)
parameters = paddle.parameters.create(cost) parameters = paddle.parameters.create(cost)
# define optimize method and trainer
optimizer = paddle.optimizer.Adam(learning_rate=1e-4) optimizer = paddle.optimizer.Adam(learning_rate=1e-4)
def event_handler(event):
if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 10 == 0:
print "Pass %d, Batch %d, Cost %f, %s" % (
event.pass_id, event.batch_id, event.cost, event.metrics)
trainer = paddle.trainer.SGD(cost=cost, trainer = paddle.trainer.SGD(cost=cost,
parameters=parameters, parameters=parameters,
update_equation=optimizer) update_equation=optimizer)
# define data reader
reader_dict = { reader_dict = {
'source_language_word': 0, 'source_language_word': 0,
'target_language_word': 1, 'target_language_word': 1,
'target_language_next_word': 2 'target_language_next_word': 2
} }
trn_reader = paddle.reader.batched( wmt14_reader = paddle.reader.batched(
paddle.reader.shuffle( paddle.reader.shuffle(
train_reader("data/pre-wmt14/train/train"), buf_size=8192), train_reader("data/pre-wmt14/train/train"), buf_size=8192),
batch_size=5) batch_size=5)
# define event_handler callback
def event_handler(event):
if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 10 == 0:
print "Pass %d, Batch %d, Cost %f, %s" % (
event.pass_id, event.batch_id, event.cost, event.metrics)
# start to train
trainer.train( trainer.train(
reader=trn_reader, reader=wmt14_reader,
event_handler=event_handler, event_handler=event_handler,
num_passes=10000, num_passes=10000,
reader_dict=reader_dict) reader_dict=reader_dict)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册