From 55fa686b5d1a62141435bf4504cdf1be3243ea9a Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Thu, 1 Mar 2018 08:33:40 +0000 Subject: [PATCH] Do not fetch if no print --- fluid/DeepASR/tools/profile.py | 3 ++- fluid/DeepASR/train.py | 17 +++++++++-------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/fluid/DeepASR/tools/profile.py b/fluid/DeepASR/tools/profile.py index cb0227c3..9d0b4769 100644 --- a/fluid/DeepASR/tools/profile.py +++ b/fluid/DeepASR/tools/profile.py @@ -169,7 +169,8 @@ def profile(args): outs = exe.run(fluid.default_main_program(), feed={"feature": feature_t, "label": label_t}, - fetch_list=[avg_cost, accuracy], + fetch_list=[avg_cost, accuracy] + if args.print_train_acc else [], return_numpy=False) if args.print_train_acc: diff --git a/fluid/DeepASR/train.py b/fluid/DeepASR/train.py index 9856dad7..b5d2239e 100644 --- a/fluid/DeepASR/train.py +++ b/fluid/DeepASR/train.py @@ -216,16 +216,17 @@ def train(args): label_t.set(labels, place) label_t.set_lod([lod]) - cost, acc = exe.run(fluid.default_main_program(), - feed={"feature": feature_t, - "label": label_t}, - fetch_list=[avg_cost, accuracy], - return_numpy=False) + to_print = batch_id > 0 and (batch_id % args.print_per_batches == 0) + outs = exe.run(fluid.default_main_program(), + feed={"feature": feature_t, + "label": label_t}, + fetch_list=[avg_cost, accuracy] if to_print else [], + return_numpy=False) - if batch_id > 0 and (batch_id % args.print_per_batches == 0): + if to_print: print("\nBatch %d, train cost: %f, train acc: %f" % - (batch_id, lodtensor_to_ndarray(cost)[0], - lodtensor_to_ndarray(acc)[0])) + (batch_id, lodtensor_to_ndarray(outs[0])[0], + lodtensor_to_ndarray(outs[1])[0])) # save the latest checkpoint if args.checkpoints != '': model_path = os.path.join(args.checkpoints, -- GitLab