From 6832ca029fe6d7bccd68fddcfe1aedc8e4d6618f Mon Sep 17 00:00:00 2001 From: tink2123 Date: Sat, 15 Aug 2020 12:39:07 +0800 Subject: [PATCH] update config --- .../rec_r50fpn_vd_none_srn_pvam_test_all.yml | 5 +- ppocr/data/rec/dataset_traversal.py | 49 +++--- ppocr/modeling/architectures/rec_model.py | 99 ++++++++++--- ppocr/modeling/heads/self_attention/model.py | 139 +++++++++--------- tools/eval_utils/eval_rec_utils.py | 21 ++- tools/program.py | 15 +- train_data | 1 + 7 files changed, 197 insertions(+), 132 deletions(-) create mode 120000 train_data diff --git a/configs/rec/rec_r50fpn_vd_none_srn_pvam_test_all.yml b/configs/rec/rec_r50fpn_vd_none_srn_pvam_test_all.yml index 933a7513..7a0f136c 100755 --- a/configs/rec/rec_r50fpn_vd_none_srn_pvam_test_all.yml +++ b/configs/rec/rec_r50fpn_vd_none_srn_pvam_test_all.yml @@ -17,11 +17,12 @@ Global: average_window: 0.15 max_average_window: 15625 min_average_window: 10000 - reader_yml: ./configs/rec/rec_srn_reader.yml + reader_yml: ./configs/rec/rec_benchmark_reader.yml pretrain_weights: checkpoints: save_inference_dir: - + infer_img: + Architecture: function: ppocr.modeling.architectures.rec_model,RecModel diff --git a/ppocr/data/rec/dataset_traversal.py b/ppocr/data/rec/dataset_traversal.py index 7135fca5..b46e37da 100755 --- a/ppocr/data/rec/dataset_traversal.py +++ b/ppocr/data/rec/dataset_traversal.py @@ -118,15 +118,14 @@ class LMDBReader(object): image_file_list = get_image_file_list(self.infer_img) for single_img in image_file_list: img = cv2.imread(single_img) - if img.shape[-1]==1 or len(list(img.shape))==2: + if img.shape[-1] == 1 or len(list(img.shape)) == 2: img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) if self.loss_type == 'srn': norm_img = process_image_srn( img=img, image_shape=self.image_shape, num_heads=self.num_heads, - max_text_length=self.max_text_length - ) + max_text_length=self.max_text_length) else: norm_img = process_image( img=img, @@ -135,20 +134,20 @@ class LMDBReader(object): tps=self.use_tps, infer_mode=True) yield norm_img - elif self.mode == 'test': - image_file_list = get_image_file_list(self.infer_img) - for single_img in image_file_list: - img = cv2.imread(single_img) - if img.shape[-1]==1 or len(list(img.shape))==2: - img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - norm_img = process_image( - img=img, - image_shape=self.image_shape, - char_ops=self.char_ops, - tps=self.use_tps, - infer_mode=True - ) - yield norm_img + #elif self.mode == 'eval': + # image_file_list = get_image_file_list(self.infer_img) + # for single_img in image_file_list: + # img = cv2.imread(single_img) + # if img.shape[-1]==1 or len(list(img.shape))==2: + # img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + # norm_img = process_image( + # img=img, + # image_shape=self.image_shape, + # char_ops=self.char_ops, + # tps=self.use_tps, + # infer_mode=True + # ) + # yield norm_img else: lmdb_sets = self.load_hierarchical_lmdb_dataset() if process_id == 0: @@ -169,14 +168,15 @@ class LMDBReader(object): img, label = sample_info outs = [] if self.loss_type == "srn": - outs = process_image_srn(img, self.image_shape, self.num_heads, - self.max_text_length, label, - self.char_ops, self.loss_type) + outs = process_image_srn( + img, self.image_shape, self.num_heads, + self.max_text_length, label, self.char_ops, + self.loss_type) else: - outs = process_image(img, self.image_shape, label, - self.char_ops, self.loss_type, - self.max_text_length) + outs = process_image( + img, self.image_shape, label, self.char_ops, + self.loss_type, self.max_text_length) if outs is None: continue yield outs @@ -184,6 +184,7 @@ class LMDBReader(object): if finish_read_num == len(lmdb_sets): break self.close_lmdb_dataset(lmdb_sets) + def batch_iter_reader(): batch_outs = [] for outs in sample_iter_reader(): @@ -311,4 +312,4 @@ class SimpleReader(object): if self.infer_img is None: return batch_iter_reader - return sample_iter_reader \ No newline at end of file + return sample_iter_reader diff --git a/ppocr/modeling/architectures/rec_model.py b/ppocr/modeling/architectures/rec_model.py index a030f362..d2e01a43 100755 --- a/ppocr/modeling/architectures/rec_model.py +++ b/ppocr/modeling/architectures/rec_model.py @@ -79,17 +79,45 @@ class RecModel(object): feed_list = [image, label_in, label_out] labels = {'label_in': label_in, 'label_out': label_out} elif self.loss_type == "srn": - encoder_word_pos = fluid.data(name="encoder_word_pos", shape=[-1, int((image_shape[-2] / 8) * (image_shape[-1] / 8)), 1], dtype="int64") - gsrm_word_pos = fluid.data(name="gsrm_word_pos", shape=[-1, self.max_text_length, 1], dtype="int64") - gsrm_slf_attn_bias1 = fluid.data(name="gsrm_slf_attn_bias1", shape=[-1, self.num_heads, self.max_text_length, self.max_text_length]) - gsrm_slf_attn_bias2 = fluid.data(name="gsrm_slf_attn_bias2", shape=[-1, self.num_heads, self.max_text_length, self.max_text_length]) - lbl_weight = fluid.layers.data(name="lbl_weight", shape=[-1, 1], dtype='int64') + encoder_word_pos = fluid.data( + name="encoder_word_pos", + shape=[ + -1, int((image_shape[-2] / 8) * (image_shape[-1] / 8)), + 1 + ], + dtype="int64") + gsrm_word_pos = fluid.data( + name="gsrm_word_pos", + shape=[-1, self.max_text_length, 1], + dtype="int64") + gsrm_slf_attn_bias1 = fluid.data( + name="gsrm_slf_attn_bias1", + shape=[ + -1, self.num_heads, self.max_text_length, + self.max_text_length + ]) + gsrm_slf_attn_bias2 = fluid.data( + name="gsrm_slf_attn_bias2", + shape=[ + -1, self.num_heads, self.max_text_length, + self.max_text_length + ]) + lbl_weight = fluid.layers.data( + name="lbl_weight", shape=[-1, 1], dtype='int64') label = fluid.data( name='label', shape=[-1, 1], dtype='int32', lod_level=1) - feed_list = [image, label, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2, lbl_weight] - labels = {'label': label, 'encoder_word_pos': encoder_word_pos, - 'gsrm_word_pos': gsrm_word_pos, 'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1, - 'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2,'lbl_weight':lbl_weight} + feed_list = [ + image, label, encoder_word_pos, gsrm_word_pos, + gsrm_slf_attn_bias1, gsrm_slf_attn_bias2, lbl_weight + ] + labels = { + 'label': label, + 'encoder_word_pos': encoder_word_pos, + 'gsrm_word_pos': gsrm_word_pos, + 'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1, + 'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2, + 'lbl_weight': lbl_weight + } else: label = fluid.data( name='label', shape=[None, 1], dtype='int32', lod_level=1) @@ -112,15 +140,41 @@ class RecModel(object): "We set img_shape to be the same , it may affect the inference effect" ) image_shape = deepcopy(self.image_shape) - image = fluid.data(name='image', shape=image_shape, dtype='float32') + image = fluid.data(name='image', shape=image_shape, dtype='float32') if self.loss_type == "srn": - encoder_word_pos = fluid.data(name="encoder_word_pos", shape=[-1, int((image_shape[-2] / 8) * (image_shape[-1] / 8)), 1], dtype="int64") - gsrm_word_pos = fluid.data(name="gsrm_word_pos", shape=[-1, self.max_text_length, 1], dtype="int64") - gsrm_slf_attn_bias1 = fluid.data(name="gsrm_slf_attn_bias1", shape=[-1, self.num_heads, self.max_text_length, self.max_text_length]) - gsrm_slf_attn_bias2 = fluid.data(name="gsrm_slf_attn_bias2", shape=[-1, self.num_heads, self.max_text_length, self.max_text_length]) - feed_list = [image, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] - labels = {'encoder_word_pos': encoder_word_pos, 'gsrm_word_pos': gsrm_word_pos, - 'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1, 'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2} + encoder_word_pos = fluid.data( + name="encoder_word_pos", + shape=[ + -1, int((image_shape[-2] / 8) * (image_shape[-1] / 8)), + 1 + ], + dtype="int64") + gsrm_word_pos = fluid.data( + name="gsrm_word_pos", + shape=[-1, self.max_text_length, 1], + dtype="int64") + gsrm_slf_attn_bias1 = fluid.data( + name="gsrm_slf_attn_bias1", + shape=[ + -1, self.num_heads, self.max_text_length, + self.max_text_length + ]) + gsrm_slf_attn_bias2 = fluid.data( + name="gsrm_slf_attn_bias2", + shape=[ + -1, self.num_heads, self.max_text_length, + self.max_text_length + ]) + feed_list = [ + image, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, + gsrm_slf_attn_bias2 + ] + labels = { + 'encoder_word_pos': encoder_word_pos, + 'gsrm_word_pos': gsrm_word_pos, + 'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1, + 'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2 + } return image, labels, loader def __call__(self, mode): @@ -140,8 +194,13 @@ class RecModel(object): label = labels['label'] if self.loss_type == 'srn': total_loss, img_loss, word_loss = self.loss(predicts, labels) - outputs = {'total_loss':total_loss, 'img_loss':img_loss, 'word_loss':word_loss, - 'decoded_out':decoded_out, 'label':label} + outputs = { + 'total_loss': total_loss, + 'img_loss': img_loss, + 'word_loss': word_loss, + 'decoded_out': decoded_out, + 'label': label + } else: outputs = {'total_loss':loss, 'decoded_out':\ decoded_out, 'label':label} @@ -156,4 +215,4 @@ class RecModel(object): predict = predicts['predict'] if self.loss_type == "ctc": predict = fluid.layers.softmax(predict) - return loader, {'decoded_out': decoded_out, 'predicts': predict} \ No newline at end of file + return loader, {'decoded_out': decoded_out, 'predicts': predict} diff --git a/ppocr/modeling/heads/self_attention/model.py b/ppocr/modeling/heads/self_attention/model.py index d4aecd5f..8ac1458b 100644 --- a/ppocr/modeling/heads/self_attention/model.py +++ b/ppocr/modeling/heads/self_attention/model.py @@ -4,8 +4,9 @@ import numpy as np import paddle.fluid as fluid import paddle.fluid.layers as layers -from .desc import * -from .config import ModelHyperParams,TrainTaskConfig +# Set seed for CE +dropout_seed = None + def wrap_layer_with_block(layer, block_idx): """ @@ -114,7 +115,7 @@ def multi_head_attention(queries, def __split_heads_qkv(queries, keys, values, n_head, d_key, d_value): """ - Reshape input tensors at the last dimension to split multi-heads + Reshape input tensors at the last dimension to split multi-heads and then transpose. Specifically, transform the input tensor with shape [bs, max_sequence_length, n_head * hidden_dim] to the output tensor with shape [bs, n_head, max_sequence_length, hidden_dim]. @@ -269,23 +270,24 @@ pre_process_layer = partial(pre_post_process_layer, None) post_process_layer = pre_post_process_layer -def prepare_encoder(src_word,#[b,t,c] - src_pos, - src_vocab_size, - src_emb_dim, - src_max_len, - dropout_rate=0., - bos_idx=0, - word_emb_param_name=None, - pos_enc_param_name=None): +def prepare_encoder( + src_word, #[b,t,c] + src_pos, + src_vocab_size, + src_emb_dim, + src_max_len, + dropout_rate=0., + bos_idx=0, + word_emb_param_name=None, + pos_enc_param_name=None): """Add word embeddings and position encodings. The output tensor has a shape of: [batch_size, max_src_length_in_batch, d_model]. This module is used at the bottom of the encoder stacks. """ - - src_word_emb =src_word#layers.concat(res,axis=1) - src_word_emb=layers.cast(src_word_emb,'float32') + + src_word_emb = src_word #layers.concat(res,axis=1) + src_word_emb = layers.cast(src_word_emb, 'float32') # print("src_word_emb",src_word_emb) src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim**0.5) @@ -302,14 +304,14 @@ def prepare_encoder(src_word,#[b,t,c] def prepare_decoder(src_word, - src_pos, - src_vocab_size, - src_emb_dim, - src_max_len, - dropout_rate=0., - bos_idx=0, - word_emb_param_name=None, - pos_enc_param_name=None): + src_pos, + src_vocab_size, + src_emb_dim, + src_max_len, + dropout_rate=0., + bos_idx=0, + word_emb_param_name=None, + pos_enc_param_name=None): """Add word embeddings and position encodings. The output tensor has a shape of: [batch_size, max_src_length_in_batch, d_model]. @@ -323,7 +325,7 @@ def prepare_decoder(src_word, name=word_emb_param_name, initializer=fluid.initializer.Normal(0., src_emb_dim**-0.5))) # print("target_word_emb",src_word_emb) - src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim ** 0.5) + src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim**0.5) src_pos_enc = layers.embedding( src_pos, size=[src_max_len, src_emb_dim], @@ -335,6 +337,7 @@ def prepare_decoder(src_word, enc_input, dropout_prob=dropout_rate, seed=dropout_seed, is_test=False) if dropout_rate else enc_input + # prepare_encoder = partial( # prepare_encoder_decoder, pos_enc_param_name=pos_enc_param_names[0]) # prepare_decoder = partial( @@ -595,21 +598,9 @@ def transformer(src_vocab_size, weights = all_inputs[-1] enc_output = wrap_encoder( - src_vocab_size, - ModelHyperParams.src_seq_len, - n_layer, - n_head, - d_key, - d_value, - d_model, - d_inner_hid, - prepostprocess_dropout, - attention_dropout, - relu_dropout, - preprocess_cmd, - postprocess_cmd, - weight_sharing, - enc_inputs) + src_vocab_size, 64, n_layer, n_head, d_key, d_value, d_model, + d_inner_hid, prepostprocess_dropout, attention_dropout, relu_dropout, + preprocess_cmd, postprocess_cmd, weight_sharing, enc_inputs) predict = wrap_decoder( trg_vocab_size, @@ -650,34 +641,34 @@ def transformer(src_vocab_size, def wrap_encoder_forFeature(src_vocab_size, - max_length, - n_layer, - n_head, - d_key, - d_value, - d_model, - d_inner_hid, - prepostprocess_dropout, - attention_dropout, - relu_dropout, - preprocess_cmd, - postprocess_cmd, - weight_sharing, - enc_inputs=None, - bos_idx=0): + max_length, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, + weight_sharing, + enc_inputs=None, + bos_idx=0): """ The wrapper assembles together all needed layers for the encoder. img, src_pos, src_slf_attn_bias = enc_inputs img """ - + if enc_inputs is None: # This is used to implement independent encoder program in inference. conv_features, src_pos, src_slf_attn_bias = make_all_inputs( encoder_data_input_fields) else: - conv_features, src_pos, src_slf_attn_bias = enc_inputs# - b,t,c = conv_features.shape + conv_features, src_pos, src_slf_attn_bias = enc_inputs # + b, t, c = conv_features.shape #""" # insert cnn #""" @@ -694,11 +685,11 @@ def wrap_encoder_forFeature(src_vocab_size, #b , c, h, w = feat.shape#h=6 #print(feat) #layers.Print(feat,message="conv_feat",summarize=10) - + #feat =layers.conv2d(feat,c,filter_size =[4 , 1],act="relu") #feat = layers.pool2d(feat,pool_stride=(3,1),pool_size=(3,1)) #src_word = layers.squeeze(feat,axes=[2]) #src_word [-1,c,ww] - + #feat = layers.transpose(feat, [0,3,1,2]) #src_word = layers.reshape(feat,[-1,w, c*h]) #src_word = layers.im2sequence( @@ -706,10 +697,10 @@ def wrap_encoder_forFeature(src_vocab_size, # stride=[1, 1], # filter_size=[feat.shape[2], 1]) #layers.Print(src_word,message="src_word",summarize=10) - + # print('feat',feat) #print("src_word",src_word) - + enc_input = prepare_encoder( conv_features, src_pos, @@ -718,7 +709,7 @@ def wrap_encoder_forFeature(src_vocab_size, max_length, prepostprocess_dropout, bos_idx=bos_idx, - word_emb_param_name=word_emb_param_names[0]) + word_emb_param_name="src_word_emb_table") enc_output = encoder( enc_input, @@ -736,6 +727,7 @@ def wrap_encoder_forFeature(src_vocab_size, postprocess_cmd, ) return enc_output + def wrap_encoder(src_vocab_size, max_length, n_layer, @@ -762,7 +754,7 @@ def wrap_encoder(src_vocab_size, src_word, src_pos, src_slf_attn_bias = make_all_inputs( encoder_data_input_fields) else: - src_word, src_pos, src_slf_attn_bias = enc_inputs# + src_word, src_pos, src_slf_attn_bias = enc_inputs # #""" # insert cnn #""" @@ -779,11 +771,11 @@ def wrap_encoder(src_vocab_size, #b , c, h, w = feat.shape#h=6 #print(feat) #layers.Print(feat,message="conv_feat",summarize=10) - + #feat =layers.conv2d(feat,c,filter_size =[4 , 1],act="relu") #feat = layers.pool2d(feat,pool_stride=(3,1),pool_size=(3,1)) #src_word = layers.squeeze(feat,axes=[2]) #src_word [-1,c,ww] - + #feat = layers.transpose(feat, [0,3,1,2]) #src_word = layers.reshape(feat,[-1,w, c*h]) #src_word = layers.im2sequence( @@ -791,7 +783,7 @@ def wrap_encoder(src_vocab_size, # stride=[1, 1], # filter_size=[feat.shape[2], 1]) #layers.Print(src_word,message="src_word",summarize=10) - + # print('feat',feat) #print("src_word",src_word) enc_input = prepare_decoder( @@ -802,7 +794,7 @@ def wrap_encoder(src_vocab_size, max_length, prepostprocess_dropout, bos_idx=bos_idx, - word_emb_param_name=word_emb_param_names[0]) + word_emb_param_name="src_word_emb_table") enc_output = encoder( enc_input, @@ -858,8 +850,8 @@ def wrap_decoder(trg_vocab_size, max_length, prepostprocess_dropout, bos_idx=bos_idx, - word_emb_param_name=word_emb_param_names[0] - if weight_sharing else word_emb_param_names[1]) + word_emb_param_name="src_word_emb_table" + if weight_sharing else "trg_word_emb_table") dec_output = decoder( dec_input, enc_output, @@ -886,7 +878,7 @@ def wrap_decoder(trg_vocab_size, predict = layers.matmul( x=dec_output, y=fluid.default_main_program().global_block().var( - word_emb_param_names[0]), + "trg_word_emb_table"), transpose_y=True) else: predict = layers.fc(input=dec_output, @@ -931,12 +923,13 @@ def fast_decode(src_vocab_size, enc_inputs_len = len(encoder_data_input_fields) dec_inputs_len = len(fast_decoder_data_input_fields) - enc_inputs = all_inputs[0:enc_inputs_len]#enc_inputs tensor - dec_inputs = all_inputs[enc_inputs_len:enc_inputs_len + dec_inputs_len]#dec_inputs tensor + enc_inputs = all_inputs[0:enc_inputs_len] #enc_inputs tensor + dec_inputs = all_inputs[enc_inputs_len:enc_inputs_len + + dec_inputs_len] #dec_inputs tensor enc_output = wrap_encoder( src_vocab_size, - ModelHyperParams.src_seq_len,##to do !!!!!???? + 64, ##to do !!!!!???? n_layer, n_head, d_key, diff --git a/tools/eval_utils/eval_rec_utils.py b/tools/eval_utils/eval_rec_utils.py index 3d496bd3..ecdf0aaf 100644 --- a/tools/eval_utils/eval_rec_utils.py +++ b/tools/eval_utils/eval_rec_utils.py @@ -61,7 +61,7 @@ def eval_rec_run(exe, config, eval_info_dict, mode): img_list.append(data[ino][0]) label_list.append(data[ino][1]) - if config['Global']['loss_type'] != "srn": + if config['Global']['loss_type'] != "srn": img_list = np.concatenate(img_list, axis=0) outs = exe.run(eval_info_dict['program'], \ feed={'image': img_list}, \ @@ -75,7 +75,8 @@ def eval_rec_run(exe, config, eval_info_dict, mode): preds_lod = outs[0].lod()[0] labels, labels_lod = convert_rec_label_to_lod(label_list) acc, acc_num, sample_num = cal_predicts_accuracy( - char_ops, preds, preds_lod, labels, labels_lod, is_remove_duplicate) + char_ops, preds, preds_lod, labels, labels_lod, + is_remove_duplicate) else: encoder_word_pos_list = [] gsrm_word_pos_list = [] @@ -89,15 +90,19 @@ def eval_rec_run(exe, config, eval_info_dict, mode): img_list = np.concatenate(img_list, axis=0) label_list = np.concatenate(label_list, axis=0) - encoder_word_pos_list = np.concatenate(encoder_word_pos_list, axis=0).astype(np.int64) - gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list, axis=0).astype(np.int64) - gsrm_slf_attn_bias1_list = np.concatenate(gsrm_slf_attn_bias1_list, axis=0).astype(np.float32) - gsrm_slf_attn_bias2_list = np.concatenate(gsrm_slf_attn_bias2_list, axis=0).astype(np.float32) + encoder_word_pos_list = np.concatenate( + encoder_word_pos_list, axis=0).astype(np.int64) + gsrm_word_pos_list = np.concatenate( + gsrm_word_pos_list, axis=0).astype(np.int64) + gsrm_slf_attn_bias1_list = np.concatenate( + gsrm_slf_attn_bias1_list, axis=0).astype(np.float32) + gsrm_slf_attn_bias2_list = np.concatenate( + gsrm_slf_attn_bias2_list, axis=0).astype(np.float32) labels = label_list outs = exe.run(eval_info_dict['program'], \ - feed={'image': img_list, 'encoder_word_pos': encoder_word_pos_list, + feed={'image': img_list, 'encoder_word_pos': encoder_word_pos_list, 'gsrm_word_pos': gsrm_word_pos_list, 'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1_list, 'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2_list}, \ fetch_list=eval_info_dict['fetch_varname_list'], \ @@ -108,7 +113,7 @@ def eval_rec_run(exe, config, eval_info_dict, mode): total_acc_num += acc_num total_sample_num += sample_num - logger.info("eval batch id: {}, acc: {}".format(total_batch_num, acc)) + #logger.info("eval batch id: {}, acc: {}".format(total_batch_num, acc)) total_batch_num += 1 avg_acc = total_acc_num * 1.0 / total_sample_num metrics = {'avg_acc': avg_acc, "total_acc_num": total_acc_num, \ diff --git a/tools/program.py b/tools/program.py index 64c827e7..6ebc27cb 100755 --- a/tools/program.py +++ b/tools/program.py @@ -34,6 +34,7 @@ from ppocr.utils.save_load import save_model import numpy as np from ppocr.utils.character import cal_predicts_accuracy, cal_predicts_accuracy_srn, CharacterOps + class ArgsParser(ArgumentParser): def __init__(self): super(ArgsParser, self).__init__( @@ -196,10 +197,13 @@ def build(config, main_prog, startup_prog, mode): if config['Global']["loss_type"] == 'srn': model_average = fluid.optimizer.ModelAverage( config['Global']['average_window'], - min_average_window=config['Global']['min_average_window'], - max_average_window=config['Global']['max_average_window']) + min_average_window=config['Global'][ + 'min_average_window'], + max_average_window=config['Global'][ + 'max_average_window']) - return (dataloader, fetch_name_list, fetch_varname_list, opt_loss_name,model_average) + return (dataloader, fetch_name_list, fetch_varname_list, opt_loss_name, + model_average) def build_export(config, main_prog, startup_prog): @@ -398,6 +402,7 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict): save_model(train_info_dict['train_program'], save_path) return + def preprocess(): FLAGS = ArgsParser().parse_args() config = load_config(FLAGS.config) @@ -409,8 +414,8 @@ def preprocess(): check_gpu(use_gpu) alg = config['Global']['algorithm'] - assert alg in ['EAST', 'DB', 'Rosetta', 'CRNN', 'STARNet', 'RARE'] - if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE']: + assert alg in ['EAST', 'DB', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN'] + if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']: config['Global']['char_ops'] = CharacterOps(config['Global']) place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace() diff --git a/train_data b/train_data new file mode 120000 index 00000000..7c2082ab --- /dev/null +++ b/train_data @@ -0,0 +1 @@ +/workspace/PaddleOCR/train_data/ \ No newline at end of file -- GitLab