提交 d815ec23 编写于 作者: Y Yang Yu

Use test accuracy to exit(0)

上级 ef55a8f6
...@@ -16,6 +16,7 @@ import argparse ...@@ -16,6 +16,7 @@ import argparse
import paddle.v2.fluid as fluid import paddle.v2.fluid as fluid
import paddle.v2 as paddle import paddle.v2 as paddle
import sys import sys
import numpy
def parse_arg(): def parse_arg():
...@@ -100,6 +101,8 @@ def main(): ...@@ -100,6 +101,8 @@ def main():
else: else:
avg_loss, acc = net_conf(img, label) avg_loss, acc = net_conf(img, label)
test_program = fluid.default_main_program().clone()
optimizer = fluid.optimizer.Adam(learning_rate=0.001) optimizer = fluid.optimizer.Adam(learning_rate=0.001)
optimizer.minimize(avg_loss) optimizer.minimize(avg_loss)
...@@ -112,6 +115,8 @@ def main(): ...@@ -112,6 +115,8 @@ def main():
paddle.reader.shuffle( paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=500), paddle.dataset.mnist.train(), buf_size=500),
batch_size=BATCH_SIZE) batch_size=BATCH_SIZE)
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=BATCH_SIZE)
feeder = fluid.DataFeeder(feed_list=[img, label], place=place) feeder = fluid.DataFeeder(feed_list=[img, label], place=place)
PASS_NUM = 100 PASS_NUM = 100
...@@ -119,21 +124,27 @@ def main(): ...@@ -119,21 +124,27 @@ def main():
for batch_id, data in enumerate(train_reader()): for batch_id, data in enumerate(train_reader()):
need_check = (batch_id + 1) % 10 == 0 need_check = (batch_id + 1) % 10 == 0
# train a mini-batch, fetch nothing
exe.run(feed=feeder.feed(data))
if need_check: if need_check:
fetch_list = [avg_loss, acc] acc_set = []
else: avg_loss_set = []
fetch_list = [] for test_data in test_reader():
acc_np, avg_loss_np = exe.run(program=test_program,
outs = exe.run(feed=feeder.feed(data), fetch_list=fetch_list) feed=feeder.feed(test_data),
if need_check: fetch_list=[acc, avg_loss])
avg_loss_np, acc_np = outs acc_set.append(float(acc_np))
if float(acc_np) > 0.9: avg_loss_set.append(float(avg_loss_np))
# get test acc and loss
acc_val = numpy.array(acc_set).mean()
avg_loss_val = numpy.array(avg_loss_set).mean()
if float(acc_val) > 0.85: # test acc > 85%
exit(0) exit(0)
else: else:
print( print(
'PassID {0:1}, BatchID {1:04}, Loss {2:2.2}, Acc {3:2.2}'. 'PassID {0:1}, BatchID {1:04}, Test Loss {2:2.2}, Acc {3:2.2}'.
format(pass_id, batch_id + 1, format(pass_id, batch_id + 1,
float(avg_loss_np), float(acc_np))) float(avg_loss_val), float(acc_val)))
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册