diff --git a/ppocr/losses/rec_srn_loss.py b/ppocr/losses/rec_srn_loss.py index d722ee0f22b7692e4c4c9bbc8d2f1f583d754dac..7d5b65ebaf1ee135d1fefe8d93ddc3f77985b132 100644 --- a/ppocr/losses/rec_srn_loss.py +++ b/ppocr/losses/rec_srn_loss.py @@ -42,6 +42,6 @@ class SRNLoss(nn.Layer): cost_gsrm = paddle.reshape(x=paddle.sum(cost_gsrm), shape=[1]) cost_vsfd = paddle.reshape(x=paddle.sum(cost_vsfd), shape=[1]) - sum_cost = cost_word + cost_vsfd * 2.0 + cost_gsrm * 0.15 + sum_cost = cost_word * 3.0 + cost_vsfd + cost_gsrm * 0.15 return {'loss': sum_cost, 'word_loss': cost_word, 'img_loss': cost_vsfd} diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index 867f920a3ca5aedf4c32ab8ed80e7cf35b4781c2..8c972a143ba506d9cd13f960802efdde5df3e54f 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -182,12 +182,12 @@ class SRNLabelDecode(BaseRecLabelDecode): preds_prob = np.reshape(preds_prob, [-1, 25]) - text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True) + text = self.decode(preds_idx, preds_prob) if label is None: text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) return text - label = self.decode(label, is_remove_duplicate=True) + label = self.decode(label) return text, label def decode(self, text_index, text_prob=None, is_remove_duplicate=False): diff --git a/tools/program.py b/tools/program.py index 885d45f5e91f364910057dbcd97753c1f2125b6c..f329dcd5782b21cc76d0537eb8a65d41adead19a 100755 --- a/tools/program.py +++ b/tools/program.py @@ -174,6 +174,7 @@ def train(config, best_model_dict = {main_indicator: 0} best_model_dict.update(pre_best_model_dict) train_stats = TrainingStats(log_smooth_window, ['lr']) + model_average = False model.train() if 'start_epoch' in best_model_dict: @@ -197,6 +198,7 @@ def train(config, if config['Architecture']['algorithm'] == "SRN": others = batch[-4:] preds = model(images, others) + model_average = True else: preds = model(images) loss = loss_class(preds, batch) @@ -242,12 +244,13 @@ def train(config, # eval if global_step > start_eval_step and \ (global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0: - model_average = paddle.optimizer.ModelAverage( - 0.15, - parameters=model.parameters(), - min_average_window=10000, - max_average_window=15625) - model_average.apply() + if model_average: + Model_Average = paddle.incubate.optimizer.ModelAverage( + 0.15, + parameters=model.parameters(), + min_average_window=10000, + max_average_window=15625) + Model_Average.apply() cur_metirc = eval(model, valid_dataloader, post_process_class, eval_class) cur_metirc_str = 'cur metirc, {}'.format(', '.join(