From e8229015225e4660339b916b53d6247bfcac351a Mon Sep 17 00:00:00 2001
From: WenmuZhou
Date: Mon, 16 Nov 2020 19:00:27 +0800
Subject: [PATCH] Make logs conform to the benchmark spec
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 tools/program.py | 24 +++++++++++++++++-------
 1 file changed, 17 insertions(+), 7 deletions(-)

diff --git a/tools/program.py b/tools/program.py
index c2b9306c..a279cf65 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -185,12 +185,15 @@ def train(config,
     for epoch in range(start_epoch, epoch_num):
         if epoch > 0:
             train_dataloader = build_dataloader(config, 'Train', device, logger)
-
+        train_batch_cost = 0.0
+        train_reader_cost = 0.0
+        batch_sum = 0
+        batch_start = time.time()
         for idx, batch in enumerate(train_dataloader):
+            train_reader_cost += time.time() - batch_start
             if idx >= len(train_dataloader):
                 break
             lr = optimizer.get_lr()
-            t1 = time.time()
             images = batch[0]
             preds = model(images)
             loss = loss_class(preds, batch)
@@ -198,6 +201,10 @@ def train(config,
             avg_loss.backward()
             optimizer.step()
             optimizer.clear_grad()
+
+            train_batch_cost += time.time() - batch_start
+            batch_sum += len(images)
+
             if not isinstance(lr_scheduler, float):
                 lr_scheduler.step()
 
@@ -213,9 +220,6 @@ def train(config,
                 metirc = eval_class.get_metric()
                 train_stats.update(metirc)
 
-            t2 = time.time()
-            train_batch_elapse = t2 - t1
-
             if vdl_writer is not None and dist.get_rank() == 0:
                 for k, v in train_stats.get().items():
                     vdl_writer.add_scalar('TRAIN/{}'.format(k), v, global_step)
@@ -224,9 +228,15 @@ def train(config,
             if dist.get_rank(
             ) == 0 and global_step > 0 and global_step % print_batch_step == 0:
                 logs = train_stats.log()
-                strs = 'epoch: [{}/{}], iter: {}, {}, time: {:.3f}'.format(
-                    epoch, epoch_num, global_step, logs, train_batch_elapse)
+                strs = 'epoch: [{}/{}], iter: {}, {}, reader_cost: {:.5f}s, batch_cost: {:.5f}s, samples: {}, ips: {:.5f}'.format(
+                    epoch, epoch_num, global_step, logs, train_reader_cost /
+                    print_batch_step, train_batch_cost / print_batch_step,
+                    batch_sum, batch_sum / train_batch_cost)
                 logger.info(strs)
+                train_batch_cost = 0.0
+                train_reader_cost = 0.0
+                batch_sum = 0
+                batch_start = time.time()
             # eval
             if global_step > start_eval_step and \
                 (global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
--
GitLab
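
Note: the counters above implement the usual reader_cost / batch_cost / ips
accounting: reader_cost accumulates the time spent waiting on the dataloader,
batch_cost the wall time of whole iterations, and ips (samples per second) is
batch_sum divided by the window's batch_cost. A minimal, self-contained sketch
of that pattern follows; make_dataloader and train_step are hypothetical
stand-ins for the Paddle dataloader and the forward/backward/update in
train(), and, unlike the patch as committed, batch_start is also reset at the
end of every iteration so both costs measure only the current step.

    import time

    def make_dataloader(num_batches=25, batch_size=8):
        # Hypothetical stand-in for the dataloader: yields (images, ...) tuples.
        for _ in range(num_batches):
            yield ([0] * batch_size,)

    def train_step(images):
        # Hypothetical stand-in for forward/backward/update.
        time.sleep(0.002)

    print_batch_step = 10      # log every N iterations, as in the patch
    train_reader_cost = 0.0    # accumulated dataloader wait time
    train_batch_cost = 0.0     # accumulated wall time of whole iterations
    batch_sum = 0              # samples seen since the last log line

    batch_start = time.time()
    for idx, batch in enumerate(make_dataloader()):
        train_reader_cost += time.time() - batch_start  # dataloader wait for this step
        images = batch[0]
        train_step(images)
        train_batch_cost += time.time() - batch_start   # read + compute for this step
        batch_sum += len(images)
        if (idx + 1) % print_batch_step == 0:
            print('reader_cost: {:.5f}s, batch_cost: {:.5f}s, samples: {}, ips: {:.5f}'.format(
                train_reader_cost / print_batch_step,   # mean dataloader wait per iter
                train_batch_cost / print_batch_step,    # mean wall time per iter
                batch_sum,
                batch_sum / train_batch_cost))          # throughput over the window
            train_reader_cost = 0.0
            train_batch_cost = 0.0
            batch_sum = 0
        batch_start = time.time()  # per-iteration reset keeps the costs disjoint

With the per-iteration reset, reader_cost plus compute time approximately
equals batch_cost in each logging window, which is what the reported ips
figure assumes.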