提交 ed2f0de9 编写于 作者: T tink2123

Move ModelAverage from paddle.optimizer to paddle.incubate.optimizer

上级 93670ab5
...@@ -42,6 +42,6 @@ class SRNLoss(nn.Layer): ...@@ -42,6 +42,6 @@ class SRNLoss(nn.Layer):
cost_gsrm = paddle.reshape(x=paddle.sum(cost_gsrm), shape=[1]) cost_gsrm = paddle.reshape(x=paddle.sum(cost_gsrm), shape=[1])
cost_vsfd = paddle.reshape(x=paddle.sum(cost_vsfd), shape=[1]) cost_vsfd = paddle.reshape(x=paddle.sum(cost_vsfd), shape=[1])
sum_cost = cost_word + cost_vsfd * 2.0 + cost_gsrm * 0.15 sum_cost = cost_word * 3.0 + cost_vsfd + cost_gsrm * 0.15
return {'loss': sum_cost, 'word_loss': cost_word, 'img_loss': cost_vsfd} return {'loss': sum_cost, 'word_loss': cost_word, 'img_loss': cost_vsfd}
...@@ -182,12 +182,12 @@ class SRNLabelDecode(BaseRecLabelDecode): ...@@ -182,12 +182,12 @@ class SRNLabelDecode(BaseRecLabelDecode):
preds_prob = np.reshape(preds_prob, [-1, 25]) preds_prob = np.reshape(preds_prob, [-1, 25])
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True) text = self.decode(preds_idx, preds_prob)
if label is None: if label is None:
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
return text return text
label = self.decode(label, is_remove_duplicate=True) label = self.decode(label)
return text, label return text, label
def decode(self, text_index, text_prob=None, is_remove_duplicate=False): def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
......
...@@ -174,6 +174,7 @@ def train(config, ...@@ -174,6 +174,7 @@ def train(config,
best_model_dict = {main_indicator: 0} best_model_dict = {main_indicator: 0}
best_model_dict.update(pre_best_model_dict) best_model_dict.update(pre_best_model_dict)
train_stats = TrainingStats(log_smooth_window, ['lr']) train_stats = TrainingStats(log_smooth_window, ['lr'])
model_average = False
model.train() model.train()
if 'start_epoch' in best_model_dict: if 'start_epoch' in best_model_dict:
...@@ -197,6 +198,7 @@ def train(config, ...@@ -197,6 +198,7 @@ def train(config,
if config['Architecture']['algorithm'] == "SRN": if config['Architecture']['algorithm'] == "SRN":
others = batch[-4:] others = batch[-4:]
preds = model(images, others) preds = model(images, others)
model_average = True
else: else:
preds = model(images) preds = model(images)
loss = loss_class(preds, batch) loss = loss_class(preds, batch)
...@@ -242,12 +244,13 @@ def train(config, ...@@ -242,12 +244,13 @@ def train(config,
# eval # eval
if global_step > start_eval_step and \ if global_step > start_eval_step and \
(global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0: (global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
model_average = paddle.optimizer.ModelAverage( if model_average:
0.15, Model_Average = paddle.incubate.optimizer.ModelAverage(
parameters=model.parameters(), 0.15,
min_average_window=10000, parameters=model.parameters(),
max_average_window=15625) min_average_window=10000,
model_average.apply() max_average_window=15625)
Model_Average.apply()
cur_metirc = eval(model, valid_dataloader, post_process_class, cur_metirc = eval(model, valid_dataloader, post_process_class,
eval_class) eval_class)
cur_metirc_str = 'cur metirc, {}'.format(', '.join( cur_metirc_str = 'cur metirc, {}'.format(', '.join(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册