未验证 提交 243ee52b 编写于 作者: Y Yibing Liu 提交者: GitHub

Merge pull request #677 from kuke/simp_fetch

Do not fetch if no print
......@@ -169,7 +169,8 @@ def profile(args):
outs = exe.run(fluid.default_main_program(),
feed={"feature": feature_t,
"label": label_t},
fetch_list=[avg_cost, accuracy],
fetch_list=[avg_cost, accuracy]
if args.print_train_acc else [],
return_numpy=False)
if args.print_train_acc:
......
......@@ -216,16 +216,17 @@ def train(args):
label_t.set(labels, place)
label_t.set_lod([lod])
cost, acc = exe.run(fluid.default_main_program(),
feed={"feature": feature_t,
"label": label_t},
fetch_list=[avg_cost, accuracy],
return_numpy=False)
to_print = batch_id > 0 and (batch_id % args.print_per_batches == 0)
outs = exe.run(fluid.default_main_program(),
feed={"feature": feature_t,
"label": label_t},
fetch_list=[avg_cost, accuracy] if to_print else [],
return_numpy=False)
if batch_id > 0 and (batch_id % args.print_per_batches == 0):
if to_print:
print("\nBatch %d, train cost: %f, train acc: %f" %
(batch_id, lodtensor_to_ndarray(cost)[0],
lodtensor_to_ndarray(acc)[0]))
(batch_id, lodtensor_to_ndarray(outs[0])[0],
lodtensor_to_ndarray(outs[1])[0]))
# save the latest checkpoint
if args.checkpoints != '':
model_path = os.path.join(args.checkpoints,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册