From b9b8c88830b69119e793f72c236d22541e8d564f Mon Sep 17 00:00:00 2001 From: wanghuancoder Date: Wed, 23 Sep 2020 18:48:09 +0800 Subject: [PATCH] use pre-commit formate code ptb_dy.py (#4871) * fix ptb_dy time print for benchmark, test=develop * use pre-commit formate code --- dygraph/ptb_lm/ptb_dy.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/dygraph/ptb_lm/ptb_dy.py b/dygraph/ptb_lm/ptb_dy.py index 086901d7..48ff9bb7 100644 --- a/dygraph/ptb_lm/ptb_dy.py +++ b/dygraph/ptb_lm/ptb_dy.py @@ -37,20 +37,25 @@ if sys.version[0] == '2': reload(sys) sys.setdefaultencoding("utf-8") + class TimeCostAverage(object): def __init__(self): self.reset() + def reset(self): self.cnt = 0 self.total_time = 0 + def record(self, usetime): self.cnt += 1 self.total_time += usetime + def get_average(self): if self.cnt == 0: return 0 return self.total_time / self.cnt + class SimpleLSTMRNN(fluid.Layer): def __init__(self, hidden_size, @@ -454,7 +459,8 @@ def train_ptb_lm(): "-- Epoch:[%d]; Batch:[%d]; ppl: %.5f, lr: %.5f, loss: %.5f, batch_cost: %.5f s, reader_cost: %.5f s" % (epoch_id, batch_id, ppl[0], sgd._global_learning_rate().numpy(), out_loss, - batch_cost_avg.get_average(), reader_cost_avg.get_average())) + batch_cost_avg.get_average(), + reader_cost_avg.get_average())) batch_cost_avg.reset() reader_cost_avg.reset() batch_start = time.time() -- GitLab