提交 93670ab5 编写于 作者: T tink2123

all ready

上级 297871d4
...@@ -3,7 +3,7 @@ Global: ...@@ -3,7 +3,7 @@ Global:
epoch_num: 72 epoch_num: 72
log_smooth_window: 20 log_smooth_window: 20
print_batch_step: 5 print_batch_step: 5
save_model_dir: ./output/rec/srn save_model_dir: ./output/rec/srn_new
save_epoch_step: 3 save_epoch_step: 3
# evaluation is run every 5000 iterations after the 4000th iteration # evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step: [0, 5000] eval_batch_step: [0, 5000]
...@@ -25,8 +25,10 @@ Global: ...@@ -25,8 +25,10 @@ Global:
Optimizer: Optimizer:
name: Adam name: Adam
beta1: 0.9
beta2: 0.999
clip_norm: 10.0
lr: lr:
name: Cosine
learning_rate: 0.0001 learning_rate: 0.0001
Architecture: Architecture:
...@@ -58,7 +60,6 @@ Train: ...@@ -58,7 +60,6 @@ Train:
dataset: dataset:
name: LMDBDataSet name: LMDBDataSet
data_dir: ./train_data/srn_train_data_duiqi data_dir: ./train_data/srn_train_data_duiqi
#label_file_list: ["./train_data/ic15_data/1.txt"]
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
...@@ -77,7 +78,7 @@ Train: ...@@ -77,7 +78,7 @@ Train:
loader: loader:
shuffle: False shuffle: False
batch_size_per_card: 64 batch_size_per_card: 64
drop_last: True drop_last: False
num_workers: 4 num_workers: 4
Eval: Eval:
......
...@@ -359,6 +359,7 @@ class PrepareDecoder(nn.Layer): ...@@ -359,6 +359,7 @@ class PrepareDecoder(nn.Layer):
self.emb0 = paddle.nn.Embedding( self.emb0 = paddle.nn.Embedding(
num_embeddings=src_vocab_size, num_embeddings=src_vocab_size,
embedding_dim=self.src_emb_dim, embedding_dim=self.src_emb_dim,
padding_idx=bos_idx,
weight_attr=paddle.ParamAttr( weight_attr=paddle.ParamAttr(
name=word_emb_param_name, name=word_emb_param_name,
initializer=nn.initializer.Normal(0., src_emb_dim**-0.5))) initializer=nn.initializer.Normal(0., src_emb_dim**-0.5)))
......
...@@ -182,14 +182,15 @@ class SRNLabelDecode(BaseRecLabelDecode): ...@@ -182,14 +182,15 @@ class SRNLabelDecode(BaseRecLabelDecode):
preds_prob = np.reshape(preds_prob, [-1, 25]) preds_prob = np.reshape(preds_prob, [-1, 25])
text = self.decode(preds_idx, preds_prob) text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
if label is None: if label is None:
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
return text return text
label = self.decode(label, is_remove_duplicate=False) label = self.decode(label, is_remove_duplicate=True)
return text, label return text, label
def decode(self, text_index, text_prob=None, is_remove_duplicate=True): def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
""" convert text-index into text-label. """ """ convert text-index into text-label. """
result_list = [] result_list = []
ignored_tokens = self.get_ignored_tokens() ignored_tokens = self.get_ignored_tokens()
......
...@@ -242,6 +242,12 @@ def train(config, ...@@ -242,6 +242,12 @@ def train(config,
# eval # eval
if global_step > start_eval_step and \ if global_step > start_eval_step and \
(global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0: (global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
model_average = paddle.optimizer.ModelAverage(
0.15,
parameters=model.parameters(),
min_average_window=10000,
max_average_window=15625)
model_average.apply()
cur_metirc = eval(model, valid_dataloader, post_process_class, cur_metirc = eval(model, valid_dataloader, post_process_class,
eval_class) eval_class)
cur_metirc_str = 'cur metirc, {}'.format(', '.join( cur_metirc_str = 'cur metirc, {}'.format(', '.join(
...@@ -277,6 +283,7 @@ def train(config, ...@@ -277,6 +283,7 @@ def train(config,
best_model_dict[main_indicator], best_model_dict[main_indicator],
global_step) global_step)
global_step += 1 global_step += 1
optimizer.clear_grad()
batch_start = time.time() batch_start = time.time()
if dist.get_rank() == 0: if dist.get_rank() == 0:
save_model( save_model(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册