From e07327e85110bfbf012ed0805ab9ce82ed5e3bc0 Mon Sep 17 00:00:00 2001 From: Tao Luo Date: Tue, 22 Sep 2020 20:37:39 +0800 Subject: [PATCH] fix mobilenet model time print (#4867) test=develop --- dygraph/mobilenet/train.py | 34 ++++++++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/dygraph/mobilenet/train.py b/dygraph/mobilenet/train.py index debd22ab..f620842b 100644 --- a/dygraph/mobilenet/train.py +++ b/dygraph/mobilenet/train.py @@ -38,6 +38,19 @@ args = parse_args() if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0: print_arguments(args) +class TimeCostAverage(object): + def __init__(self): + self.reset() + def reset(self): + self.cnt = 0 + self.total_time = 0 + def record(self, usetime): + self.cnt += 1 + self.total_time += usetime + def get_average(self): + if self.cnt == 0: + return 0 + return self.total_time / self.cnt def eval(net, test_data_loader, eop): total_loss = 0.0 @@ -170,6 +183,10 @@ def train_mobilenet(): t_last = 0 # 4.1 for each batch, call net() , backward(), and minimize() + batch_cost_avg = TimeCostAverage() + batch_reader_avg = TimeCostAverage() + batch_net_avg = TimeCostAverage() + batch_backward_avg = TimeCostAverage() batch_start = time.time() for img, label in train_data_loader(): if args.max_iter and total_batch_num == args.max_iter: @@ -208,16 +225,25 @@ def train_mobilenet(): # NOTE: used for benchmark train_batch_cost = time.time() - batch_start + batch_cost_avg.record(train_batch_cost) + batch_reader_avg.record(batch_reader_end - batch_start) + batch_net_avg.record(batch_net_end - batch_reader_end) + batch_backward_avg.record(batch_backward_end - batch_net_end) + total_batch_num = total_batch_num + 1 if batch_id % args.print_step == 0: print( "[Epoch %d, batch %d], avg_loss %.5f, acc_top1 %.5f, acc_top5 %.5f, batch_cost: %.5f s, net_t: %.5f s, backward_t: %.5f s, reader_t: %.5f s" % (eop, batch_id, avg_loss.numpy(), acc_top1.numpy(), - acc_top5.numpy(), train_batch_cost, - batch_net_end - batch_reader_end, - batch_backward_end - batch_net_end, - batch_reader_end - batch_start)) + acc_top5.numpy(), batch_cost_avg.get_average(), + batch_net_avg.get_average(), + batch_backward_avg.get_average(), + batch_reader_avg.get_average())) sys.stdout.flush() + batch_cost_avg.reset() + batch_net_avg.reset() + batch_backward_avg.reset() + batch_reader_avg.reset() batch_start = time.time() if args.ce: -- GitLab