From d815ec2354ed722d6a0e5d6102318ca0b5d9e687 Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Wed, 24 Jan 2018 16:51:50 +0800 Subject: [PATCH] Use test accuracy to exit(0) --- .../fluid/tests/book/test_recognize_digits.py | 31 +++++++++++++------ 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/python/paddle/v2/fluid/tests/book/test_recognize_digits.py b/python/paddle/v2/fluid/tests/book/test_recognize_digits.py index 4ecdcdc6327..19b93a66f47 100644 --- a/python/paddle/v2/fluid/tests/book/test_recognize_digits.py +++ b/python/paddle/v2/fluid/tests/book/test_recognize_digits.py @@ -16,6 +16,7 @@ import argparse import paddle.v2.fluid as fluid import paddle.v2 as paddle import sys +import numpy def parse_arg(): @@ -100,6 +101,8 @@ def main(): else: avg_loss, acc = net_conf(img, label) + test_program = fluid.default_main_program().clone() + optimizer = fluid.optimizer.Adam(learning_rate=0.001) optimizer.minimize(avg_loss) @@ -112,6 +115,8 @@ def main(): paddle.reader.shuffle( paddle.dataset.mnist.train(), buf_size=500), batch_size=BATCH_SIZE) + test_reader = paddle.batch( + paddle.dataset.mnist.test(), batch_size=BATCH_SIZE) feeder = fluid.DataFeeder(feed_list=[img, label], place=place) PASS_NUM = 100 @@ -119,21 +124,27 @@ def main(): for batch_id, data in enumerate(train_reader()): need_check = (batch_id + 1) % 10 == 0 + # train a mini-batch, fetch nothing + exe.run(feed=feeder.feed(data)) if need_check: - fetch_list = [avg_loss, acc] - else: - fetch_list = [] - - outs = exe.run(feed=feeder.feed(data), fetch_list=fetch_list) - if need_check: - avg_loss_np, acc_np = outs - if float(acc_np) > 0.9: + acc_set = [] + avg_loss_set = [] + for test_data in test_reader(): + acc_np, avg_loss_np = exe.run(program=test_program, + feed=feeder.feed(test_data), + fetch_list=[acc, avg_loss]) + acc_set.append(float(acc_np)) + avg_loss_set.append(float(avg_loss_np)) + # get test acc and loss + acc_val = numpy.array(acc_set).mean() + avg_loss_val = numpy.array(avg_loss_set).mean() + if float(acc_val) > 0.85: # test acc > 85% exit(0) else: print( - 'PassID {0:1}, BatchID {1:04}, Loss {2:2.2}, Acc {3:2.2}'. + 'PassID {0:1}, BatchID {1:04}, Test Loss {2:2.2}, Acc {3:2.2}'. format(pass_id, batch_id + 1, - float(avg_loss_np), float(acc_np))) + float(avg_loss_val), float(acc_val))) if __name__ == '__main__': -- GitLab