提交 bbd7580e 编写于 作者: K Kexin Zhao 提交者: daminglu

simplify recognize digits example code (#10722)

上级 2a636529
...@@ -71,24 +71,18 @@ def train(use_cuda, train_program, save_dirname): ...@@ -71,24 +71,18 @@ def train(use_cuda, train_program, save_dirname):
if isinstance(event, fluid.EndEpochEvent): if isinstance(event, fluid.EndEpochEvent):
test_reader = paddle.batch( test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=BATCH_SIZE) paddle.dataset.mnist.test(), batch_size=BATCH_SIZE)
test_metrics = trainer.test( avg_cost, acc = trainer.test(
reader=test_reader, feed_order=['img', 'label']) reader=test_reader, feed_order=['img', 'label'])
avg_cost_set = test_metrics[0]
acc_set = test_metrics[1]
# get test acc and loss
acc = numpy.array(acc_set).mean()
avg_cost = numpy.array(avg_cost_set).mean()
print("avg_cost: %s" % avg_cost) print("avg_cost: %s" % avg_cost)
print("acc : %s" % acc) print("acc : %s" % acc)
if float(acc) > 0.2: # Smaller value to increase CI speed if acc > 0.2: # Smaller value to increase CI speed
trainer.save_params(save_dirname) trainer.save_params(save_dirname)
else: else:
print('BatchID {0}, Test Loss {1:0.2}, Acc {2:0.2}'.format( print('BatchID {0}, Test Loss {1:0.2}, Acc {2:0.2}'.format(
event.epoch + 1, float(avg_cost), float(acc))) event.epoch + 1, avg_cost, acc))
if math.isnan(float(avg_cost)): if math.isnan(avg_cost):
sys.exit("got NaN loss, training failed.") sys.exit("got NaN loss, training failed.")
elif isinstance(event, fluid.EndStepEvent): elif isinstance(event, fluid.EndStepEvent):
print("Step {0}, Epoch {1} Metrics {2}".format( print("Step {0}, Epoch {1} Metrics {2}".format(
......
...@@ -55,24 +55,18 @@ def train(use_cuda, train_program, save_dirname): ...@@ -55,24 +55,18 @@ def train(use_cuda, train_program, save_dirname):
if isinstance(event, fluid.EndEpochEvent): if isinstance(event, fluid.EndEpochEvent):
test_reader = paddle.batch( test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=BATCH_SIZE) paddle.dataset.mnist.test(), batch_size=BATCH_SIZE)
test_metrics = trainer.test( avg_cost, acc = trainer.test(
reader=test_reader, feed_order=['img', 'label']) reader=test_reader, feed_order=['img', 'label'])
avg_cost_set = test_metrics[0]
acc_set = test_metrics[1]
# get test acc and loss
acc = numpy.array(acc_set).mean()
avg_cost = numpy.array(avg_cost_set).mean()
print("avg_cost: %s" % avg_cost) print("avg_cost: %s" % avg_cost)
print("acc : %s" % acc) print("acc : %s" % acc)
if float(acc) > 0.2: # Smaller value to increase CI speed if acc > 0.2: # Smaller value to increase CI speed
trainer.save_params(save_dirname) trainer.save_params(save_dirname)
else: else:
print('BatchID {0}, Test Loss {1:0.2}, Acc {2:0.2}'.format( print('BatchID {0}, Test Loss {1:0.2}, Acc {2:0.2}'.format(
event.epoch + 1, float(avg_cost), float(acc))) event.epoch + 1, avg_cost, acc))
if math.isnan(float(avg_cost)): if math.isnan(avg_cost):
sys.exit("got NaN loss, training failed.") sys.exit("got NaN loss, training failed.")
train_reader = paddle.batch( train_reader = paddle.batch(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册