From cf186f39f1f3dc24abfceeb415d3c0e9d9ce1212 Mon Sep 17 00:00:00 2001 From: wanghuancoder Date: Tue, 22 Sep 2020 20:32:43 +0800 Subject: [PATCH] fix ptb_dy time print for benchmark, test=develop (#4866) --- dygraph/ptb_lm/ptb_dy.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/dygraph/ptb_lm/ptb_dy.py b/dygraph/ptb_lm/ptb_dy.py index f38a8c93..086901d7 100644 --- a/dygraph/ptb_lm/ptb_dy.py +++ b/dygraph/ptb_lm/ptb_dy.py @@ -37,6 +37,19 @@ if sys.version[0] == '2': reload(sys) sys.setdefaultencoding("utf-8") +class TimeCostAverage(object): + def __init__(self): + self.reset() + def reset(self): + self.cnt = 0 + self.total_time = 0 + def record(self, usetime): + self.cnt += 1 + self.total_time += usetime + def get_average(self): + if self.cnt == 0: + return 0 + return self.total_time / self.cnt class SimpleLSTMRNN(fluid.Layer): def __init__(self, @@ -405,10 +418,17 @@ def train_ptb_lm(): init_hidden = to_variable(init_hidden_data) init_cell = to_variable(init_cell_data) + batch_cost_avg = TimeCostAverage() + reader_cost_avg = TimeCostAverage() + batch_start = time.time() for batch_id, batch in enumerate(train_data_loader): if args.max_iter and total_batch_num == args.max_iter: return + + train_reader_cost = time.time() - batch_start + reader_cost_avg.record(train_reader_cost) + x, y = batch dy_loss, last_hidden, last_cell = ptb_model(x, y, init_hidden, @@ -426,13 +446,17 @@ def train_ptb_lm(): total_batch_num = total_batch_num + 1 #this is for benchmark train_batch_cost = time.time() - batch_start + batch_cost_avg.record(train_batch_cost) + if batch_id > 0 and batch_id % log_interval == 0: ppl = np.exp(total_loss / iters) print( - "-- Epoch:[%d]; Batch:[%d]; ppl: %.5f, lr: %.5f, loss: %.5f, batch_cost: %.5f s" + "-- Epoch:[%d]; Batch:[%d]; ppl: %.5f, lr: %.5f, loss: %.5f, batch_cost: %.5f s, reader_cost: %.5f s" % (epoch_id, batch_id, ppl[0], sgd._global_learning_rate().numpy(), out_loss, - train_batch_cost)) + batch_cost_avg.get_average(), reader_cost_avg.get_average())) + batch_cost_avg.reset() + reader_cost_avg.reset() batch_start = time.time() ppl = np.exp(total_loss / iters) -- GitLab