From ed2f0de95e58298ee733ee83976ef43079a613a0 Mon Sep 17 00:00:00 2001 From: tink2123 Date: Fri, 22 Jan 2021 03:15:56 +0000 Subject: [PATCH] mv model_average to incubate --- ppocr/losses/rec_srn_loss.py | 2 +- ppocr/postprocess/rec_postprocess.py | 4 ++-- tools/program.py | 15 +++++++++------ 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/ppocr/losses/rec_srn_loss.py b/ppocr/losses/rec_srn_loss.py index d722ee0f..7d5b65eb 100644 --- a/ppocr/losses/rec_srn_loss.py +++ b/ppocr/losses/rec_srn_loss.py @@ -42,6 +42,6 @@ class SRNLoss(nn.Layer): cost_gsrm = paddle.reshape(x=paddle.sum(cost_gsrm), shape=[1]) cost_vsfd = paddle.reshape(x=paddle.sum(cost_vsfd), shape=[1]) - sum_cost = cost_word + cost_vsfd * 2.0 + cost_gsrm * 0.15 + sum_cost = cost_word * 3.0 + cost_vsfd + cost_gsrm * 0.15 return {'loss': sum_cost, 'word_loss': cost_word, 'img_loss': cost_vsfd} diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index 867f920a..8c972a14 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -182,12 +182,12 @@ class SRNLabelDecode(BaseRecLabelDecode): preds_prob = np.reshape(preds_prob, [-1, 25]) - text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True) + text = self.decode(preds_idx, preds_prob) if label is None: text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) return text - label = self.decode(label, is_remove_duplicate=True) + label = self.decode(label) return text, label def decode(self, text_index, text_prob=None, is_remove_duplicate=False): diff --git a/tools/program.py b/tools/program.py index 885d45f5..f329dcd5 100755 --- a/tools/program.py +++ b/tools/program.py @@ -174,6 +174,7 @@ def train(config, best_model_dict = {main_indicator: 0} best_model_dict.update(pre_best_model_dict) train_stats = TrainingStats(log_smooth_window, ['lr']) + model_average = False model.train() if 'start_epoch' in best_model_dict: @@ -197,6 +198,7 @@ def train(config, if config['Architecture']['algorithm'] == "SRN": others = batch[-4:] preds = model(images, others) + model_average = True else: preds = model(images) loss = loss_class(preds, batch) @@ -242,12 +244,13 @@ def train(config, # eval if global_step > start_eval_step and \ (global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0: - model_average = paddle.optimizer.ModelAverage( - 0.15, - parameters=model.parameters(), - min_average_window=10000, - max_average_window=15625) - model_average.apply() + if model_average: + Model_Average = paddle.incubate.optimizer.ModelAverage( + 0.15, + parameters=model.parameters(), + min_average_window=10000, + max_average_window=15625) + Model_Average.apply() cur_metirc = eval(model, valid_dataloader, post_process_class, eval_class) cur_metirc_str = 'cur metirc, {}'.format(', '.join( -- GitLab