From 4b945982d3d962f3dd9936c9305a0f262b702b80 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Tue, 7 Mar 2017 01:26:00 +0800 Subject: [PATCH] Fix README --- understand_sentiment/README.md | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/understand_sentiment/README.md b/understand_sentiment/README.md index a4a359a..ccce07f 100644 --- a/understand_sentiment/README.md +++ b/understand_sentiment/README.md @@ -223,11 +223,11 @@ if __name__ == '__main__': ``` 加载数据字典,这里通过`word_dict()`API可以直接构造字典。`class_dim`是指样本类别数,该示例中样本只有正负两类。 ``` - train_reader = paddle.reader.batched( + train_reader = paddle.batch( paddle.reader.shuffle( lambda: paddle.dataset.imdb.train(word_dict), buf_size=1000), batch_size=100) - test_reader = paddle.reader.batched( + test_reader = paddle.batch( lambda: paddle.dataset.imdb.test(word_dict), batch_size=100) ``` @@ -272,12 +272,7 @@ Paddle中提供了一系列优化算法的API,这里使用Adam优化算法。 sys.stdout.write('.') sys.stdout.flush() if isinstance(event, paddle.event.EndPass): - result = trainer.test( - reader=paddle.reader.batched( - lambda: paddle.dataset.imdb.test(word_dict), - batch_size=128), - reader_dict={'word': 0, - 'label': 1}) + result = trainer.test(reader=test_reader, reader_dict=reader_dict) print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics) ``` 可以通过给train函数传递一个`event_handler`来获取每个batch和每个pass结束的状态。比如构造如下一个`event_handler`可以在每100个batch结束后输出cost和error;在每个pass结束后调用`trainer.test`计算一遍测试集并获得当前模型在测试集上的error。 @@ -288,15 +283,10 @@ Paddle中提供了一系列优化算法的API,这里使用Adam优化算法。 update_equation=adam_optimizer) trainer.train( - reader=paddle.reader.batched( - paddle.reader.shuffle( - lambda: paddle.dataset.imdb.train(word_dict), buf_size=1000), - batch_size=100), + reader=train_reader, event_handler=event_handler, - reader_dict={'word': 0, - 'label': 1}, - num_passes=10) - + reader_dict=reader_dict, + num_passes=2) ``` 程序运行之后的输出如下。 ``` -- GitLab