From f04c97a0359d3eafa7a13807be5065199a5716d5 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Wed, 22 Nov 2017 10:05:09 +0800 Subject: [PATCH] refine test_understand_sentiment_lstm (#5781) * fix * Fix a bug --- .../book/test_understand_sentiment_lstm.py | 54 +++++++++++-------- 1 file changed, 31 insertions(+), 23 deletions(-) diff --git a/python/paddle/v2/fluid/tests/book/test_understand_sentiment_lstm.py b/python/paddle/v2/fluid/tests/book/test_understand_sentiment_lstm.py index 280f6e902c3..9a51a2f207e 100644 --- a/python/paddle/v2/fluid/tests/book/test_understand_sentiment_lstm.py +++ b/python/paddle/v2/fluid/tests/book/test_understand_sentiment_lstm.py @@ -54,17 +54,17 @@ def to_lodtensor(data, place): return res -def chop_data(data, chop_len=80, batch_len=50): +def chop_data(data, chop_len=80, batch_size=50): data = [(x[0][:chop_len], x[1]) for x in data if len(x[0]) >= chop_len] - return data[:batch_len] + return data[:batch_size] def prepare_feed_data(data, place): tensor_words = to_lodtensor(map(lambda x: x[0], data), place) label = np.array(map(lambda x: x[1], data)).astype("int64") - label = label.reshape([50, 1]) + label = label.reshape([len(label), 1]) tensor_label = core.LoDTensor() tensor_label.set(label, place) @@ -72,33 +72,41 @@ def prepare_feed_data(data, place): def main(): - word_dict = paddle.dataset.imdb.word_dict() - cost, acc = lstm_net(dict_dim=len(word_dict), class_dim=2) + BATCH_SIZE = 100 + PASS_NUM = 5 - batch_size = 100 - train_data = paddle.batch( - paddle.reader.buffered( - paddle.dataset.imdb.train(word_dict), size=batch_size * 10), - batch_size=batch_size) + word_dict = paddle.dataset.imdb.word_dict() + print "load word dict successfully" + dict_dim = len(word_dict) + class_dim = 2 - data = chop_data(next(train_data())) + cost, acc = lstm_net(dict_dim=dict_dim, class_dim=class_dim) + train_data = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.imdb.train(word_dict), buf_size=BATCH_SIZE * 10), + batch_size=BATCH_SIZE) place = core.CPUPlace() - tensor_words, tensor_label = prepare_feed_data(data, place) exe = Executor(place) + exe.run(framework.default_startup_program()) - while True: - outs = exe.run(framework.default_main_program(), - feed={"words": tensor_words, - "label": tensor_label}, - fetch_list=[cost, acc]) - cost_val = np.array(outs[0]) - acc_val = np.array(outs[1]) - - print("cost=" + str(cost_val) + " acc=" + str(acc_val)) - if acc_val > 0.9: - break + for pass_id in xrange(PASS_NUM): + for data in train_data(): + chopped_data = chop_data(data) + tensor_words, tensor_label = prepare_feed_data(chopped_data, place) + + outs = exe.run(framework.default_main_program(), + feed={"words": tensor_words, + "label": tensor_label}, + fetch_list=[cost, acc]) + cost_val = np.array(outs[0]) + acc_val = np.array(outs[1]) + + print("cost=" + str(cost_val) + " acc=" + str(acc_val)) + if acc_val > 0.7: + exit(0) + exit(1) if __name__ == '__main__': -- GitLab