提交 b129f4d8 编写于 作者: F fengjiayi

Formatting code

上级 5236f6d2
......@@ -69,17 +69,18 @@ def main():
for pass_id in range(PASS_NUM):
pass_acc.reset()
for data in train_reader():
loss, acc, b_size = exe.run(fluid.default_main_program(),
feed=feeder.feed(data),
fetch_list=[avg_cost, batch_acc, batch_size])
loss, acc, b_size = exe.run(
fluid.default_main_program(),
feed=feeder.feed(data),
fetch_list=[avg_cost, batch_acc, batch_size])
pass_acc.add(value=acc, weight=b_size)
print("pass_id=" + str(pass_id) + " acc=" + str(acc[0]) + " pass_acc="
+ str(pass_acc.eval()[0]))
print("pass_id=" + str(pass_id) + " acc=" + str(acc[0]) +
" pass_acc=" + str(pass_acc.eval()[0]))
if loss < LOSS_THRESHOLD and pass_acc > ACC_THRESHOLD:
break
print("pass_id=" + str(pass_id) +
" pass_acc=" + str(pass_acc.eval()[0]))
print("pass_id=" + str(pass_id) + " pass_acc=" + str(pass_acc.eval()[
0]))
fluid.io.save_params(
exe, dirname='./mnist', main_program=fluid.default_main_program())
print('train mnist done')
......
......@@ -196,21 +196,24 @@ def train(learning_rate, batch_size, num_passes, model_save_dir='model'):
for pass_id in range(num_passes):
train_pass_acc_evaluator.reset()
for batch_id, data in enumerate(train_reader()):
loss, acc, size = exe.run(fluid.default_main_program(),
feed=feeder.feed(data),
fetch_list=[avg_cost, b_acc_var, b_size_var])
loss, acc, size = exe.run(
fluid.default_main_program(),
feed=feeder.feed(data),
fetch_list=[avg_cost, b_acc_var, b_size_var])
train_pass_acc_evaluator.add(value=acc, weight=size)
print("Pass {0}, batch {1}, loss {2}, acc {3}".format(
pass_id, batch_id, loss[0], acc[0]))
test_pass_acc_evaluator.reset()
for data in test_reader():
loss, acc, size = exe.run(inference_program,
feed=feeder.feed(data),
fetch_list=[avg_cost, b_acc_var, b_size_var])
loss, acc, size = exe.run(
inference_program,
feed=feeder.feed(data),
fetch_list=[avg_cost, b_acc_var, b_size_var])
test_pass_acc_evaluator.add(value=acc, weight=size)
print("End pass {0}, train_acc {1}, test_acc {2}".format(
pass_id, train_pass_acc_evaluator.eval(), test_pass_acc_evaluator.eval()))
pass_id,
train_pass_acc_evaluator.eval(), test_pass_acc_evaluator.eval()))
if pass_id % 10 == 0:
model_path = os.path.join(model_save_dir, str(pass_id))
print 'save models to %s' % (model_path)
......
......@@ -150,7 +150,8 @@ def main(dict_path):
train_pass_acc_evaluator.add(value=acc_val, weight=size_val)
if batch_id and batch_id % conf.log_period == 0:
print("Pass id: %d, batch id: %d, cost: %f, pass_acc: %f" %
(pass_id, batch_id, cost_val, train_pass_acc_evaluator.eval()))
(pass_id, batch_id, cost_val,
train_pass_acc_evaluator.eval()))
end_time = time.time()
total_time += (end_time - start_time)
pass_test_acc = test(exe)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册