提交 93670ab5 编写于 作者: T tink2123

already

上级 297871d4
......@@ -3,7 +3,7 @@ Global:
epoch_num: 72
log_smooth_window: 20
print_batch_step: 5
save_model_dir: ./output/rec/srn
save_model_dir: ./output/rec/srn_new
save_epoch_step: 3
# evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step: [0, 5000]
......@@ -25,8 +25,10 @@ Global:
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
clip_norm: 10.0
lr:
name: Cosine
learning_rate: 0.0001
Architecture:
......@@ -58,7 +60,6 @@ Train:
dataset:
name: LMDBDataSet
data_dir: ./train_data/srn_train_data_duiqi
#label_file_list: ["./train_data/ic15_data/1.txt"]
transforms:
- DecodeImage: # load image
img_mode: BGR
......@@ -77,7 +78,7 @@ Train:
loader:
shuffle: False
batch_size_per_card: 64
drop_last: True
drop_last: False
num_workers: 4
Eval:
......
......@@ -359,6 +359,7 @@ class PrepareDecoder(nn.Layer):
self.emb0 = paddle.nn.Embedding(
num_embeddings=src_vocab_size,
embedding_dim=self.src_emb_dim,
padding_idx=bos_idx,
weight_attr=paddle.ParamAttr(
name=word_emb_param_name,
initializer=nn.initializer.Normal(0., src_emb_dim**-0.5)))
......
......@@ -182,14 +182,15 @@ class SRNLabelDecode(BaseRecLabelDecode):
preds_prob = np.reshape(preds_prob, [-1, 25])
text = self.decode(preds_idx, preds_prob)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
if label is None:
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
return text
label = self.decode(label, is_remove_duplicate=False)
label = self.decode(label, is_remove_duplicate=True)
return text, label
def decode(self, text_index, text_prob=None, is_remove_duplicate=True):
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
""" convert text-index into text-label. """
result_list = []
ignored_tokens = self.get_ignored_tokens()
......
......@@ -242,6 +242,12 @@ def train(config,
# eval
if global_step > start_eval_step and \
(global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
model_average = paddle.optimizer.ModelAverage(
0.15,
parameters=model.parameters(),
min_average_window=10000,
max_average_window=15625)
model_average.apply()
cur_metirc = eval(model, valid_dataloader, post_process_class,
eval_class)
cur_metirc_str = 'cur metirc, {}'.format(', '.join(
......@@ -277,6 +283,7 @@ def train(config,
best_model_dict[main_indicator],
global_step)
global_step += 1
optimizer.clear_grad()
batch_start = time.time()
if dist.get_rank() == 0:
save_model(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册