From bf4863c95082651cc8daf5c39455632ca6c113db Mon Sep 17 00:00:00 2001 From: tink2123 Date: Sat, 15 Aug 2020 15:45:55 +0800 Subject: [PATCH] update infer_rec for srn --- ppocr/data/rec/dataset_traversal.py | 41 +++-- ppocr/modeling/architectures/rec_model.py | 5 +- ppocr/modeling/backbones/rec_resnet_vd.py | 2 +- ppocr/modeling/heads/rec_srn_all_head.py | 194 ++++++++++++---------- ppocr/modeling/losses/rec_srn_loss.py | 39 ++--- ppocr/utils/character.py | 25 +-- tools/eval_utils/eval_rec_utils.py | 4 +- tools/infer_rec.py | 48 +++++- 8 files changed, 197 insertions(+), 161 deletions(-) diff --git a/ppocr/data/rec/dataset_traversal.py b/ppocr/data/rec/dataset_traversal.py index b46e37da..53c7e87b 100755 --- a/ppocr/data/rec/dataset_traversal.py +++ b/ppocr/data/rec/dataset_traversal.py @@ -40,10 +40,12 @@ class LMDBReader(object): self.image_shape = params['image_shape'] self.loss_type = params['loss_type'] self.max_text_length = params['max_text_length'] - self.num_heads = params['num_heads'] self.mode = params['mode'] self.drop_last = False self.use_tps = False + self.num_heads = None + if "num_heads" in params: + self.num_heads = params['num_heads'] if "tps" in params: self.ues_tps = True self.use_distort = False @@ -134,20 +136,6 @@ class LMDBReader(object): tps=self.use_tps, infer_mode=True) yield norm_img - #elif self.mode == 'eval': - # image_file_list = get_image_file_list(self.infer_img) - # for single_img in image_file_list: - # img = cv2.imread(single_img) - # if img.shape[-1]==1 or len(list(img.shape))==2: - # img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - # norm_img = process_image( - # img=img, - # image_shape=self.image_shape, - # char_ops=self.char_ops, - # tps=self.use_tps, - # infer_mode=True - # ) - # yield norm_img else: lmdb_sets = self.load_hierarchical_lmdb_dataset() if process_id == 0: @@ -169,14 +157,22 @@ class LMDBReader(object): outs = [] if self.loss_type == "srn": outs = process_image_srn( - img, self.image_shape, self.num_heads, - self.max_text_length, label, self.char_ops, - self.loss_type) + img=img, + image_shape=self.image_shape, + num_heads=self.num_heads, + max_text_length=self.max_text_length, + label=label, + char_ops=self.char_ops, + loss_type=self.loss_type) else: outs = process_image( - img, self.image_shape, label, self.char_ops, - self.loss_type, self.max_text_length) + img=img, + image_shape=self.image_shape, + label=label, + char_ops=self.char_ops, + loss_type=self.loss_type, + max_text_length=self.max_text_length) if outs is None: continue yield outs @@ -192,8 +188,9 @@ class LMDBReader(object): if len(batch_outs) == self.batch_size: yield batch_outs batch_outs = [] - if len(batch_outs) != 0: - yield batch_outs + if not self.drop_last: + if len(batch_outs) != 0: + yield batch_outs if self.infer_img is None: return batch_iter_reader diff --git a/ppocr/modeling/architectures/rec_model.py b/ppocr/modeling/architectures/rec_model.py index d2e01a43..5eacd5de 100755 --- a/ppocr/modeling/architectures/rec_model.py +++ b/ppocr/modeling/architectures/rec_model.py @@ -58,7 +58,10 @@ class RecModel(object): self.loss_type = global_params['loss_type'] self.image_shape = global_params['image_shape'] self.max_text_length = global_params['max_text_length'] - self.num_heads = global_params["num_heads"] + if "num_heads" in params: + self.num_heads = global_params["num_heads"] + else: + self.num_heads = None def create_feed(self, mode): image_shape = deepcopy(self.image_shape) diff --git a/ppocr/modeling/backbones/rec_resnet_vd.py b/ppocr/modeling/backbones/rec_resnet_vd.py index 2c7cd4c7..bc58c8ac 100755 --- a/ppocr/modeling/backbones/rec_resnet_vd.py +++ b/ppocr/modeling/backbones/rec_resnet_vd.py @@ -32,7 +32,7 @@ class ResNet(): def __init__(self, params): self.layers = params['layers'] self.is_3x3 = True - supported_layers = [18, 34, 50, 101, 152] + supported_layers = [18, 34, 50, 101, 152, 200] assert self.layers in supported_layers, \ "supported layers are {} but input layer is {}".format(supported_layers, self.layers) diff --git a/ppocr/modeling/heads/rec_srn_all_head.py b/ppocr/modeling/heads/rec_srn_all_head.py index bf1f4a44..e1bb955d 100755 --- a/ppocr/modeling/heads/rec_srn_all_head.py +++ b/ppocr/modeling/heads/rec_srn_all_head.py @@ -21,15 +21,12 @@ import math import paddle import paddle.fluid as fluid from paddle.fluid.param_attr import ParamAttr -#from .rec_seq_encoder import SequenceEncoder -#from ..common_functions import get_para_bias_attr import numpy as np from .self_attention.model import wrap_encoder from .self_attention.model import wrap_encoder_forFeature gradient_clip = 10 - class SRNPredict(object): def __init__(self, params): super(SRNPredict, self).__init__() @@ -41,7 +38,6 @@ class SRNPredict(object): self.num_decoder_TUs = params['num_decoder_TUs'] self.hidden_dims = params['hidden_dims'] - def pvam(self, inputs, others): b, c, h, w = inputs.shape @@ -53,52 +49,62 @@ class SRNPredict(object): encoder_word_pos = others["encoder_word_pos"] gsrm_word_pos = others["gsrm_word_pos"] - enc_inputs = [conv_features, encoder_word_pos, None] - word_features = wrap_encoder_forFeature(src_vocab_size=-1, - max_length=t, - n_layer=self.num_encoder_TUs, - n_head=self.num_heads, - d_key= int(self.hidden_dims / self.num_heads), - d_value= int(self.hidden_dims / self.num_heads), - d_model=self.hidden_dims, - d_inner_hid=self.hidden_dims, - prepostprocess_dropout=0.1, - attention_dropout=0.1, - relu_dropout=0.1, - preprocess_cmd="n", - postprocess_cmd="da", - weight_sharing=True, - enc_inputs=enc_inputs, - ) - fluid.clip.set_gradient_clip(fluid.clip.GradientClipByValue(gradient_clip)) + word_features = wrap_encoder_forFeature( + src_vocab_size=-1, + max_length=t, + n_layer=self.num_encoder_TUs, + n_head=self.num_heads, + d_key=int(self.hidden_dims / self.num_heads), + d_value=int(self.hidden_dims / self.num_heads), + d_model=self.hidden_dims, + d_inner_hid=self.hidden_dims, + prepostprocess_dropout=0.1, + attention_dropout=0.1, + relu_dropout=0.1, + preprocess_cmd="n", + postprocess_cmd="da", + weight_sharing=True, + enc_inputs=enc_inputs, ) + fluid.clip.set_gradient_clip( + fluid.clip.GradientClipByValue(gradient_clip)) #===== Parallel Visual Attention Module ===== b, t, c = word_features.shape - word_features = fluid.layers.fc(word_features, c, num_flatten_dims=2) + word_features = fluid.layers.fc(word_features, c, num_flatten_dims=2) word_features_ = fluid.layers.reshape(word_features, [-1, 1, t, c]) - word_features_ = fluid.layers.expand(word_features_, [1, self.max_length, 1, 1]) - word_pos_feature = fluid.layers.embedding(gsrm_word_pos, [self.max_length, c]) - word_pos_ = fluid.layers.reshape(word_pos_feature, [-1, self.max_length, 1, c]) + word_features_ = fluid.layers.expand(word_features_, + [1, self.max_length, 1, 1]) + word_pos_feature = fluid.layers.embedding(gsrm_word_pos, + [self.max_length, c]) + word_pos_ = fluid.layers.reshape(word_pos_feature, + [-1, self.max_length, 1, c]) word_pos_ = fluid.layers.expand(word_pos_, [1, 1, t, 1]) - temp = fluid.layers.elementwise_add(word_features_, word_pos_, act='tanh') + temp = fluid.layers.elementwise_add( + word_features_, word_pos_, act='tanh') + + attention_weight = fluid.layers.fc(input=temp, + size=1, + num_flatten_dims=3, + bias_attr=False) + attention_weight = fluid.layers.reshape( + x=attention_weight, shape=[-1, self.max_length, t]) + attention_weight = fluid.layers.softmax(input=attention_weight, axis=-1) - attention_weight = fluid.layers.fc(input=temp, size=1, num_flatten_dims=3, bias_attr=False) - attention_weight = fluid.layers.reshape(x=attention_weight, shape=[-1, self.max_length, t]) - attention_weight = fluid.layers.softmax(input=attention_weight, axis=-1) + pvam_features = fluid.layers.matmul(attention_weight, + word_features) #[b, max_length, c] - pvam_features = fluid.layers.matmul(attention_weight, word_features)#[b, max_length, c] - return pvam_features - + def gsrm(self, pvam_features, others): #===== GSRM Visual-to-semantic embedding block ===== b, t, c = pvam_features.shape - word_out = fluid.layers.fc(input=fluid.layers.reshape(pvam_features, [-1, c]), - size=self.char_num, - act="softmax") + word_out = fluid.layers.fc( + input=fluid.layers.reshape(pvam_features, [-1, c]), + size=self.char_num, + act="softmax") #word_out.stop_gradient = True word_ids = fluid.layers.argmax(word_out, axis=1) word_ids.stop_gradient = True @@ -106,7 +112,7 @@ class SRNPredict(object): #===== GSRM Semantic reasoning block ===== """ - This module is achieved through bi-transformers, + This module is achieved through bi-transformers, ngram_feature1 is the froward one, ngram_fetaure2 is the backward one """ pad_idx = self.char_num @@ -120,7 +126,8 @@ class SRNPredict(object): word1 for forward; word2 for backward """ word1 = fluid.layers.cast(word_ids, "float32") - word1 = fluid.layers.pad(word1, [0, 0, 1, 0, 0, 0], pad_value=1.0 * pad_idx) + word1 = fluid.layers.pad(word1, [0, 0, 1, 0, 0, 0], + pad_value=1.0 * pad_idx) word1 = fluid.layers.cast(word1, "int64") word1 = word1[:, :-1, :] word2 = word_ids @@ -132,39 +139,40 @@ class SRNPredict(object): enc_inputs_1 = [word1, gsrm_word_pos, gsrm_slf_attn_bias1] enc_inputs_2 = [word2, gsrm_word_pos, gsrm_slf_attn_bias2] - gsrm_feature1 = wrap_encoder(src_vocab_size=self.char_num + 1, - max_length=self.max_length, - n_layer=self.num_decoder_TUs, - n_head=self.num_heads, - d_key=int(self.hidden_dims / self.num_heads), - d_value=int(self.hidden_dims / self.num_heads), - d_model=self.hidden_dims, - d_inner_hid=self.hidden_dims, - prepostprocess_dropout=0.1, - attention_dropout=0.1, - relu_dropout=0.1, - preprocess_cmd="n", - postprocess_cmd="da", - weight_sharing=True, - enc_inputs=enc_inputs_1, - ) - gsrm_feature2 = wrap_encoder(src_vocab_size=self.char_num + 1, - max_length=self.max_length, - n_layer=self.num_decoder_TUs, - n_head=self.num_heads, - d_key=int(self.hidden_dims / self.num_heads), - d_value=int(self.hidden_dims / self.num_heads), - d_model=self.hidden_dims, - d_inner_hid=self.hidden_dims, - prepostprocess_dropout=0.1, - attention_dropout=0.1, - relu_dropout=0.1, - preprocess_cmd="n", - postprocess_cmd="da", - weight_sharing=True, - enc_inputs=enc_inputs_2, - ) - gsrm_feature2 = fluid.layers.pad(gsrm_feature2, [0, 0, 0, 1, 0, 0], pad_value=0.) + gsrm_feature1 = wrap_encoder( + src_vocab_size=self.char_num + 1, + max_length=self.max_length, + n_layer=self.num_decoder_TUs, + n_head=self.num_heads, + d_key=int(self.hidden_dims / self.num_heads), + d_value=int(self.hidden_dims / self.num_heads), + d_model=self.hidden_dims, + d_inner_hid=self.hidden_dims, + prepostprocess_dropout=0.1, + attention_dropout=0.1, + relu_dropout=0.1, + preprocess_cmd="n", + postprocess_cmd="da", + weight_sharing=True, + enc_inputs=enc_inputs_1, ) + gsrm_feature2 = wrap_encoder( + src_vocab_size=self.char_num + 1, + max_length=self.max_length, + n_layer=self.num_decoder_TUs, + n_head=self.num_heads, + d_key=int(self.hidden_dims / self.num_heads), + d_value=int(self.hidden_dims / self.num_heads), + d_model=self.hidden_dims, + d_inner_hid=self.hidden_dims, + prepostprocess_dropout=0.1, + attention_dropout=0.1, + relu_dropout=0.1, + preprocess_cmd="n", + postprocess_cmd="da", + weight_sharing=True, + enc_inputs=enc_inputs_2, ) + gsrm_feature2 = fluid.layers.pad(gsrm_feature2, [0, 0, 0, 1, 0, 0], + pad_value=0.) gsrm_feature2 = gsrm_feature2[:, 1:, ] gsrm_features = gsrm_feature1 + gsrm_feature2 @@ -172,10 +180,12 @@ class SRNPredict(object): gsrm_out = fluid.layers.matmul( x=gsrm_features, - y=fluid.default_main_program().global_block().var("src_word_emb_table"), + y=fluid.default_main_program().global_block().var( + "src_word_emb_table"), transpose_y=True) - b,t,c = gsrm_out.shape - gsrm_out = fluid.layers.softmax(input=fluid.layers.reshape(gsrm_out, [-1, c])) + b, t, c = gsrm_out.shape + gsrm_out = fluid.layers.softmax(input=fluid.layers.reshape(gsrm_out, + [-1, c])) return gsrm_features, word_out, gsrm_out @@ -184,19 +194,25 @@ class SRNPredict(object): #===== Visual-Semantic Fusion Decoder Module ===== b, t, c1 = pvam_features.shape b, t, c2 = gsrm_features.shape - combine_features_ = fluid.layers.concat([pvam_features, gsrm_features], axis=2) - img_comb_features_ = fluid.layers.reshape(x=combine_features_, shape=[-1, c1 + c2]) - img_comb_features_map = fluid.layers.fc(input=img_comb_features_, size=c1, act="sigmoid") - img_comb_features_map = fluid.layers.reshape(x=img_comb_features_map, shape=[-1, t, c1]) - combine_features = img_comb_features_map * pvam_features + (1.0 - img_comb_features_map) * gsrm_features - img_comb_features = fluid.layers.reshape(x=combine_features, shape=[-1, c1]) + combine_features_ = fluid.layers.concat( + [pvam_features, gsrm_features], axis=2) + img_comb_features_ = fluid.layers.reshape( + x=combine_features_, shape=[-1, c1 + c2]) + img_comb_features_map = fluid.layers.fc(input=img_comb_features_, + size=c1, + act="sigmoid") + img_comb_features_map = fluid.layers.reshape( + x=img_comb_features_map, shape=[-1, t, c1]) + combine_features = img_comb_features_map * pvam_features + ( + 1.0 - img_comb_features_map) * gsrm_features + img_comb_features = fluid.layers.reshape( + x=combine_features, shape=[-1, c1]) fc_out = fluid.layers.fc(input=img_comb_features, size=self.char_num, act="softmax") return fc_out - def __call__(self, inputs, others, mode=None): pvam_features = self.pvam(inputs, others) @@ -204,15 +220,11 @@ class SRNPredict(object): final_out = self.vsfd(pvam_features, gsrm_features) _, decoded_out = fluid.layers.topk(input=final_out, k=1) - predicts = {'predict': final_out, 'decoded_out': decoded_out, - 'word_out': word_out, 'gsrm_out': gsrm_out} + predicts = { + 'predict': final_out, + 'decoded_out': decoded_out, + 'word_out': word_out, + 'gsrm_out': gsrm_out + } return predicts - - - - - - - - diff --git a/ppocr/modeling/losses/rec_srn_loss.py b/ppocr/modeling/losses/rec_srn_loss.py index 68a480ac..b1ebd86f 100755 --- a/ppocr/modeling/losses/rec_srn_loss.py +++ b/ppocr/modeling/losses/rec_srn_loss.py @@ -35,24 +35,21 @@ class SRNLoss(object): lbl_weight = others['lbl_weight'] casted_label = fluid.layers.cast(x=label, dtype='int64') - cost_word = fluid.layers.cross_entropy(input=word_predict, label=casted_label) - cost_gsrm = fluid.layers.cross_entropy(input=gsrm_predict, label=casted_label) - cost_vsfd = fluid.layers.cross_entropy(input=predict, label=casted_label) - - #cost_word = cost_word * lbl_weight - #cost_gsrm = cost_gsrm * lbl_weight - #cost_vsfd = cost_vsfd * lbl_weight - - cost_word = fluid.layers.reshape(x=fluid.layers.reduce_sum(cost_word), shape=[1]) - cost_gsrm = fluid.layers.reshape(x=fluid.layers.reduce_sum(cost_gsrm), shape=[1]) - cost_vsfd = fluid.layers.reshape(x=fluid.layers.reduce_sum(cost_vsfd), shape=[1]) - - sum_cost = fluid.layers.sum([cost_word, cost_vsfd * 2.0, cost_gsrm * 0.15]) - - #sum_cost = fluid.layers.sum([cost_word * 3.0, cost_vsfd, cost_gsrm * 0.15]) - #sum_cost = cost_word - - #fluid.layers.Print(cost_word,message="word_cost") - #fluid.layers.Print(cost_vsfd,message="img_cost") - return [sum_cost,cost_vsfd,cost_word] - #return [sum_cost, cost_vsfd, cost_word] + cost_word = fluid.layers.cross_entropy( + input=word_predict, label=casted_label) + cost_gsrm = fluid.layers.cross_entropy( + input=gsrm_predict, label=casted_label) + cost_vsfd = fluid.layers.cross_entropy( + input=predict, label=casted_label) + + cost_word = fluid.layers.reshape( + x=fluid.layers.reduce_sum(cost_word), shape=[1]) + cost_gsrm = fluid.layers.reshape( + x=fluid.layers.reduce_sum(cost_gsrm), shape=[1]) + cost_vsfd = fluid.layers.reshape( + x=fluid.layers.reduce_sum(cost_vsfd), shape=[1]) + + sum_cost = fluid.layers.sum( + [cost_word, cost_vsfd * 2.0, cost_gsrm * 0.15]) + + return [sum_cost, cost_vsfd, cost_word] diff --git a/ppocr/utils/character.py b/ppocr/utils/character.py index 79d6f5ca..5f2963ac 100755 --- a/ppocr/utils/character.py +++ b/ppocr/utils/character.py @@ -149,38 +149,29 @@ def cal_predicts_accuracy(char_ops, acc = acc_num * 1.0 / img_num return acc, acc_num, img_num + def cal_predicts_accuracy_srn(char_ops, - preds, - labels, - max_text_len, - is_debug=False): + preds, + labels, + max_text_len, + is_debug=False): acc_num = 0 img_num = 0 total_len = preds.shape[0] img_num = int(total_len / max_text_len) - #print (img_num) for i in range(img_num): cur_label = [] cur_pred = [] for j in range(max_text_len): - if labels[j + i * max_text_len] != 37: #0 + if labels[j + i * max_text_len] != 37: #0 cur_label.append(labels[j + i * max_text_len][0]) else: break - if is_debug: - for j in range(max_text_len): - if preds[j + i * max_text_len] != 37: #0 - cur_pred.append(preds[j + i * max_text_len][0]) - else: - break - print ("cur_label: ", cur_label) - print ("cur_pred: ", cur_pred) - - for j in range(max_text_len + 1): - if j < len(cur_label) and preds[j + i * max_text_len][0] != cur_label[j]: + if j < len(cur_label) and preds[j + i * max_text_len][ + 0] != cur_label[j]: break elif j == len(cur_label) and j == max_text_len: acc_num += 1 diff --git a/tools/eval_utils/eval_rec_utils.py b/tools/eval_utils/eval_rec_utils.py index ecdf0aaf..5a653678 100644 --- a/tools/eval_utils/eval_rec_utils.py +++ b/tools/eval_utils/eval_rec_utils.py @@ -123,8 +123,8 @@ def eval_rec_run(exe, config, eval_info_dict, mode): def test_rec_benchmark(exe, config, eval_info_dict): " Evaluate lmdb dataset " - eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860', \ - 'IC13_857', 'IC15_1811', 'IC15_2077','SVTP', 'CUTE80'] + eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860','IC03_867', \ + 'IC13_857', 'IC13_1015', 'IC15_1811', 'IC15_2077','SVTP', 'CUTE80'] eval_data_dir = config['TestReader']['lmdb_sets_dir'] total_evaluation_data_number = 0 total_correct_number = 0 diff --git a/tools/infer_rec.py b/tools/infer_rec.py index 8cde44d8..21b503cc 100755 --- a/tools/infer_rec.py +++ b/tools/infer_rec.py @@ -64,7 +64,6 @@ def main(): exe = fluid.Executor(place) rec_model = create_module(config['Architecture']['function'])(params=config) - startup_prog = fluid.Program() eval_prog = fluid.Program() with fluid.program_guard(eval_prog, startup_prog): @@ -86,10 +85,36 @@ def main(): for i in range(max_img_num): logger.info("infer_img:%s" % infer_list[i]) img = next(blobs) - predict = exe.run(program=eval_prog, - feed={"image": img}, - fetch_list=fetch_varname_list, - return_numpy=False) + if loss_type != "srn": + predict = exe.run(program=eval_prog, + feed={"image": img}, + fetch_list=fetch_varname_list, + return_numpy=False) + else: + encoder_word_pos_list = [] + gsrm_word_pos_list = [] + gsrm_slf_attn_bias1_list = [] + gsrm_slf_attn_bias2_list = [] + encoder_word_pos_list.append(img[1]) + gsrm_word_pos_list.append(img[2]) + gsrm_slf_attn_bias1_list.append(img[3]) + gsrm_slf_attn_bias2_list.append(img[4]) + + encoder_word_pos_list = np.concatenate( + encoder_word_pos_list, axis=0).astype(np.int64) + gsrm_word_pos_list = np.concatenate( + gsrm_word_pos_list, axis=0).astype(np.int64) + gsrm_slf_attn_bias1_list = np.concatenate( + gsrm_slf_attn_bias1_list, axis=0).astype(np.float32) + gsrm_slf_attn_bias2_list = np.concatenate( + gsrm_slf_attn_bias2_list, axis=0).astype(np.float32) + + predict = exe.run(program=eval_prog, \ + feed={'image': img[0], 'encoder_word_pos': encoder_word_pos_list, + 'gsrm_word_pos': gsrm_word_pos_list, 'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1_list, + 'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2_list}, \ + fetch_list=fetch_varname_list, \ + return_numpy=False) if loss_type == "ctc": preds = np.array(predict[0]) preds = preds.reshape(-1) @@ -114,7 +139,18 @@ def main(): score = np.mean(probs[0, 1:end_pos[1]]) preds = preds.reshape(-1) preds_text = char_ops.decode(preds) - + elif loss_type == "srn": + cur_pred = [] + preds = np.array(predict[0]) + preds = preds.reshape(-1) + probs = np.array(predict[1]) + ind = np.argmax(probs, axis=1) + valid_ind = np.where(preds != 37)[0] + if len(valid_ind) == 0: + continue + score = np.mean(probs[valid_ind, ind[valid_ind]]) + preds = preds[:valid_ind[-1] + 1] + preds_text = char_ops.decode(preds) logger.info("\t index: {}".format(preds)) logger.info("\t word : {}".format(preds_text)) logger.info("\t score: {}".format(score)) -- GitLab