diff --git a/dygraph/ptb_lm/ptb_dy.py b/dygraph/ptb_lm/ptb_dy.py index 086901d7d5efb8a773551cdeaacb29829d966e80..48ff9bb7bdc3c210c8586bc744e10307c52a3b27 100644 --- a/dygraph/ptb_lm/ptb_dy.py +++ b/dygraph/ptb_lm/ptb_dy.py @@ -37,20 +37,25 @@ if sys.version[0] == '2': reload(sys) sys.setdefaultencoding("utf-8") + class TimeCostAverage(object): def __init__(self): self.reset() + def reset(self): self.cnt = 0 self.total_time = 0 + def record(self, usetime): self.cnt += 1 self.total_time += usetime + def get_average(self): if self.cnt == 0: return 0 return self.total_time / self.cnt + class SimpleLSTMRNN(fluid.Layer): def __init__(self, hidden_size, @@ -454,7 +459,8 @@ def train_ptb_lm(): "-- Epoch:[%d]; Batch:[%d]; ppl: %.5f, lr: %.5f, loss: %.5f, batch_cost: %.5f s, reader_cost: %.5f s" % (epoch_id, batch_id, ppl[0], sgd._global_learning_rate().numpy(), out_loss, - batch_cost_avg.get_average(), reader_cost_avg.get_average())) + batch_cost_avg.get_average(), + reader_cost_avg.get_average())) batch_cost_avg.reset() reader_cost_avg.reset() batch_start = time.time()