diff --git a/PaddleNLP/language_model/train.py b/PaddleNLP/language_model/train.py index 33f9651f6500601e461f14e665b81862f97b93c7..5ff5e59da4849ab6ca984a5cec0453cac9720824 100644 --- a/PaddleNLP/language_model/train.py +++ b/PaddleNLP/language_model/train.py @@ -49,6 +49,19 @@ import pickle SEED = 123 +class TimeCostAverage(object): + def __init__(self): + self.reset() + def reset(self): + self.cnt = 0 + self.total_time = 0 + def record(self, usetime): + self.cnt += 1 + self.total_time += usetime + def get_average(self): + if self.cnt == 0: + return 0 + return self.total_time / self.cnt @contextlib.contextmanager def profile_context(profile=True, profiler_path='/tmp/paddingrnn.profile'): @@ -293,8 +306,10 @@ def main(): total_loss = 0 iters = 0 + batch_cost_avg = TimeCostAverage() init_hidden, init_cell = generate_init_data() + batch_start_time = time.time() for batch_id, batch in enumerate(train_data_iter): input_data_feed = prepare_input( batch, @@ -303,7 +318,6 @@ def main(): epoch_id=epoch_id, with_lr=True, device_count=device_count) - batch_start_time = time.time() fetch_outs = exe.run(train_program, feed=input_data_feed, fetch_list=[ @@ -313,6 +327,7 @@ def main(): use_program_cache=True) batch_time = time.time() - batch_start_time batch_times.append(batch_time) + batch_cost_avg.record(batch_time) cost_train = np.array(fetch_outs[0]) lr = np.array(fetch_outs[1]) @@ -324,13 +339,17 @@ def main(): ppl = np.exp(total_loss / iters) print( "-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f, lr: %.5f" - % (epoch_id, batch_id, batch_time, ppl[0], lr[0])) + % (epoch_id, batch_id, batch_cost_avg.get_average(), ppl[0], lr[0])) + batch_cost_avg.reset() # profiler tools for benchmark if args.profile and batch_id == log_interval: profiler.reset_profiler() elif args.profile and batch_id == (log_interval + 5): break + + batch_start_time = time.time() + ppl = np.exp(total_loss / iters) return ppl @@ -342,6 +361,7 @@ def main(): total_loss = 0 iters = 0 + batch_cost_avg = TimeCostAverage() dataloader.start() batch_id = 0 @@ -355,6 +375,7 @@ def main(): batch_time = time.time() - batch_start_time batch_times.append(batch_time) batch_start_time = time.time() + batch_cost_avg.record(batch_time) new_lr = generate_new_lr(epoch_id, device_count) data_feeds['learning_rate'] = new_lr @@ -381,7 +402,8 @@ def main(): ppl = np.exp(total_loss / iters) print( "-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f, lr: %.5f" - % (epoch_id, batch_id, batch_time, ppl[0], lr[0])) + % (epoch_id, batch_id, batch_cost_avg.get_average(), ppl[0], lr[0])) + batch_cost_avg.reset() batch_id += 1 # profiler tools for benchmark