提交 d815ec23 编写于 作者: Y Yang Yu

Use test accuracy to exit(0)

上级 ef55a8f6
......@@ -16,6 +16,7 @@ import argparse
import paddle.v2.fluid as fluid
import paddle.v2 as paddle
import sys
import numpy
def parse_arg():
......@@ -100,6 +101,8 @@ def main():
else:
avg_loss, acc = net_conf(img, label)
test_program = fluid.default_main_program().clone()
optimizer = fluid.optimizer.Adam(learning_rate=0.001)
optimizer.minimize(avg_loss)
......@@ -112,6 +115,8 @@ def main():
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=500),
batch_size=BATCH_SIZE)
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=BATCH_SIZE)
feeder = fluid.DataFeeder(feed_list=[img, label], place=place)
PASS_NUM = 100
......@@ -119,21 +124,27 @@ def main():
for batch_id, data in enumerate(train_reader()):
need_check = (batch_id + 1) % 10 == 0
# train a mini-batch, fetch nothing
exe.run(feed=feeder.feed(data))
if need_check:
fetch_list = [avg_loss, acc]
else:
fetch_list = []
outs = exe.run(feed=feeder.feed(data), fetch_list=fetch_list)
if need_check:
avg_loss_np, acc_np = outs
if float(acc_np) > 0.9:
acc_set = []
avg_loss_set = []
for test_data in test_reader():
acc_np, avg_loss_np = exe.run(program=test_program,
feed=feeder.feed(test_data),
fetch_list=[acc, avg_loss])
acc_set.append(float(acc_np))
avg_loss_set.append(float(avg_loss_np))
# get test acc and loss
acc_val = numpy.array(acc_set).mean()
avg_loss_val = numpy.array(avg_loss_set).mean()
if float(acc_val) > 0.85: # test acc > 85%
exit(0)
else:
print(
'PassID {0:1}, BatchID {1:04}, Loss {2:2.2}, Acc {3:2.2}'.
'PassID {0:1}, BatchID {1:04}, Test Loss {2:2.2}, Acc {3:2.2}'.
format(pass_id, batch_id + 1,
float(avg_loss_np), float(acc_np)))
float(avg_loss_val), float(acc_val)))
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册