From 229acda8565b3ef0a1e980389626998cfd822513 Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Mon, 7 Feb 2022 12:19:25 +0000 Subject: [PATCH] fix ips info and reduce interval of metric calc --- ppocr/losses/rec_ctc_loss.py | 3 ++- tools/program.py | 21 +++++++++++---------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/ppocr/losses/rec_ctc_loss.py b/ppocr/losses/rec_ctc_loss.py index 063d68e3..502fc8c5 100755 --- a/ppocr/losses/rec_ctc_loss.py +++ b/ppocr/losses/rec_ctc_loss.py @@ -31,7 +31,8 @@ class CTCLoss(nn.Layer): predicts = predicts[-1] predicts = predicts.transpose((1, 0, 2)) N, B, _ = predicts.shape - preds_lengths = paddle.to_tensor([N] * B, dtype='int64') + preds_lengths = paddle.to_tensor( + [N] * B, dtype='int64', place=paddle.CPUPlace()) labels = batch[1].astype("int32") label_lengths = batch[2].astype('int64') loss = self.loss_func(predicts, labels, preds_lengths, label_lengths) diff --git a/tools/program.py b/tools/program.py index 5ffb93d1..a0336916 100755 --- a/tools/program.py +++ b/tools/program.py @@ -146,6 +146,7 @@ def train(config, scaler=None): cal_metric_during_train = config['Global'].get('cal_metric_during_train', False) + calc_epoch_interval = config['Global'].get('calc_epoch_interval', 1) log_smooth_window = config['Global']['log_smooth_window'] epoch_num = config['Global']['epoch_num'] print_batch_step = config['Global']['print_batch_step'] @@ -244,6 +245,16 @@ def train(config, optimizer.step() optimizer.clear_grad() + if cal_metric_during_train and epoch % calc_epoch_interval == 0: # only rec and cls need + batch = [item.numpy() for item in batch] + if model_type in ['table', 'kie']: + eval_class(preds, batch) + else: + post_result = post_process_class(preds, batch[1]) + eval_class(post_result, batch) + metric = eval_class.get_metric() + train_stats.update(metric) + train_batch_time = time.time() - reader_start train_batch_cost += train_batch_time eta_meter.update(train_batch_time) @@ -258,16 +269,6 @@ def train(config, stats['lr'] = lr train_stats.update(stats) - if cal_metric_during_train: # only rec and cls need - batch = [item.numpy() for item in batch] - if model_type in ['table', 'kie']: - eval_class(preds, batch) - else: - post_result = post_process_class(preds, batch[1]) - eval_class(post_result, batch) - metric = eval_class.get_metric() - train_stats.update(metric) - if vdl_writer is not None and dist.get_rank() == 0: for k, v in train_stats.get().items(): vdl_writer.add_scalar('TRAIN/{}'.format(k), v, global_step) -- GitLab