未验证 提交 243ee52b 编写于 作者: Y Yibing Liu 提交者: GitHub

Merge pull request #677 from kuke/simp_fetch

Do not fetch if no print
...@@ -169,7 +169,8 @@ def profile(args): ...@@ -169,7 +169,8 @@ def profile(args):
outs = exe.run(fluid.default_main_program(), outs = exe.run(fluid.default_main_program(),
feed={"feature": feature_t, feed={"feature": feature_t,
"label": label_t}, "label": label_t},
fetch_list=[avg_cost, accuracy], fetch_list=[avg_cost, accuracy]
if args.print_train_acc else [],
return_numpy=False) return_numpy=False)
if args.print_train_acc: if args.print_train_acc:
......
...@@ -216,16 +216,17 @@ def train(args): ...@@ -216,16 +216,17 @@ def train(args):
label_t.set(labels, place) label_t.set(labels, place)
label_t.set_lod([lod]) label_t.set_lod([lod])
cost, acc = exe.run(fluid.default_main_program(), to_print = batch_id > 0 and (batch_id % args.print_per_batches == 0)
feed={"feature": feature_t, outs = exe.run(fluid.default_main_program(),
"label": label_t}, feed={"feature": feature_t,
fetch_list=[avg_cost, accuracy], "label": label_t},
return_numpy=False) fetch_list=[avg_cost, accuracy] if to_print else [],
return_numpy=False)
if batch_id > 0 and (batch_id % args.print_per_batches == 0): if to_print:
print("\nBatch %d, train cost: %f, train acc: %f" % print("\nBatch %d, train cost: %f, train acc: %f" %
(batch_id, lodtensor_to_ndarray(cost)[0], (batch_id, lodtensor_to_ndarray(outs[0])[0],
lodtensor_to_ndarray(acc)[0])) lodtensor_to_ndarray(outs[1])[0]))
# save the latest checkpoint # save the latest checkpoint
if args.checkpoints != '': if args.checkpoints != '':
model_path = os.path.join(args.checkpoints, model_path = os.path.join(args.checkpoints,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册