diff --git a/configs/rec/rec_r50fpn_vd_none_srn.yml b/configs/rec/rec_r50fpn_vd_none_srn.yml index 7a0f136c28dd967aeb422d843a49cf65b934d7ca..30709e479f8da56b6bd7fe9ebf817a27bff9cc38 100755 --- a/configs/rec/rec_r50fpn_vd_none_srn.yml +++ b/configs/rec/rec_r50fpn_vd_none_srn.yml @@ -27,7 +27,7 @@ Architecture: function: ppocr.modeling.architectures.rec_model,RecModel Backbone: - function: ppocr.modeling.backbones.rec_resnet50_fpn,ResNet + function: ppocr.modeling.backbones.rec_resnet_fpn,ResNet layers: 50 Head: diff --git a/doc/doc_en/FAQ_en.md b/doc/doc_en/FAQ_en.md index a89567f7d912c815d802f021bf8b751f7d94e25c..25777be77b6393c09c38e3c319ca1bd50cc3b1e8 100644 --- a/doc/doc_en/FAQ_en.md +++ b/doc/doc_en/FAQ_en.md @@ -45,7 +45,7 @@ At present, the open source model, dataset and magnitude are as follows: Among them, the public datasets are opensourced, users can search and download by themselves, or refer to [Chinese data set](./datasets_en.md), synthetic data is not opensourced, users can use open-source synthesis tools to synthesize data themselves. Current available synthesis tools include [text_renderer](https://github.com/Sanster/text_renderer), [SynthText](https://github.com/ankush-me/SynthText), [TextRecognitionDataGenerator](https://github.com/Belval/TextRecognitionDataGenerator), etc. 10. **Error in using the model with TPS module for prediction** -Error message: Input(X) dims[3] and Input(Grid) dims[2] should be equal, but received X dimension[3](108) != Grid dimension[2](100) +Error message: Input(X) dims[3] and Input(Grid) dims[2] should be equal, but received X dimension[3]\(108) != Grid dimension[2]\(100) Solution:TPS does not support variable shape. Please set --rec_image_shape='3,32,100' and --rec_char_type='en' 11. **Custom dictionary used during training, the recognition results show that words do not appear in the dictionary** diff --git a/ppocr/data/rec/dataset_traversal.py b/ppocr/data/rec/dataset_traversal.py index ebee624ab74b2390323ab538627f459cb2353e8b..67cbf9b53ad7b877be8985d76627cdf97d49f423 100755 --- a/ppocr/data/rec/dataset_traversal.py +++ b/ppocr/data/rec/dataset_traversal.py @@ -214,6 +214,8 @@ class SimpleReader(object): self.mode = params['mode'] self.infer_img = params['infer_img'] self.use_tps = False + if "num_heads" in params: + self.num_heads = params['num_heads'] if "tps" in params: self.use_tps = True self.use_distort = False @@ -251,12 +253,19 @@ class SimpleReader(object): img = cv2.imread(single_img) if img.shape[-1] == 1 or len(list(img.shape)) == 2: img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - norm_img = process_image( - img=img, - image_shape=self.image_shape, - char_ops=self.char_ops, - tps=self.use_tps, - infer_mode=True) + if self.loss_type == 'srn': + norm_img = process_image_srn( + img=img, + image_shape=self.image_shape, + num_heads=self.num_heads, + max_text_length=self.max_text_length) + else: + norm_img = process_image( + img=img, + image_shape=self.image_shape, + char_ops=self.char_ops, + tps=self.use_tps, + infer_mode=True) yield norm_img else: with open(self.label_file_path, "rb") as fin: @@ -286,14 +295,25 @@ class SimpleReader(object): img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) label = substr[1] - outs = process_image( - img=img, - image_shape=self.image_shape, - label=label, - char_ops=self.char_ops, - loss_type=self.loss_type, - max_text_length=self.max_text_length, - distort=self.use_distort) + if self.loss_type == "srn": + outs = process_image_srn( + img=img, + image_shape=self.image_shape, + num_heads=self.num_heads, + max_text_length=self.max_text_length, + label=label, + char_ops=self.char_ops, + loss_type=self.loss_type) + + else: + outs = process_image( + img=img, + image_shape=self.image_shape, + label=label, + char_ops=self.char_ops, + loss_type=self.loss_type, + max_text_length=self.max_text_length, + distort=self.use_distort) if outs is None: continue yield outs diff --git a/ppocr/data/rec/img_tools.py b/ppocr/data/rec/img_tools.py index 527e0266ee33ac81e29b5610ed05f401860078a4..8b497e6b803ba0fffaefc3e12c366130504b9ce0 100755 --- a/ppocr/data/rec/img_tools.py +++ b/ppocr/data/rec/img_tools.py @@ -410,7 +410,8 @@ def resize_norm_img_srn(img, image_shape): def srn_other_inputs(image_shape, num_heads, - max_text_length): + max_text_length, + char_num): imgC, imgH, imgW = image_shape feature_dim = int((imgH / 8) * (imgW / 8)) @@ -418,7 +419,7 @@ def srn_other_inputs(image_shape, encoder_word_pos = np.array(range(0, feature_dim)).reshape((feature_dim, 1)).astype('int64') gsrm_word_pos = np.array(range(0, max_text_length)).reshape((max_text_length, 1)).astype('int64') - lbl_weight = np.array([37] * max_text_length).reshape((-1,1)).astype('int64') + lbl_weight = np.array([int(char_num-1)] * max_text_length).reshape((-1,1)).astype('int64') gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length)) gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape([-1, 1, max_text_length, max_text_length]) @@ -441,17 +442,18 @@ def process_image_srn(img, loss_type=None): norm_img = resize_norm_img_srn(img, image_shape) norm_img = norm_img[np.newaxis, :] + char_num = char_ops.get_char_num() + [lbl_weight, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \ - srn_other_inputs(image_shape, num_heads, max_text_length) + srn_other_inputs(image_shape, num_heads, max_text_length,char_num) if label is not None: - char_num = char_ops.get_char_num() text = char_ops.encode(label) if len(text) == 0 or len(text) > max_text_length: return None else: if loss_type == "srn": - text_padded = [37] * max_text_length + text_padded = [int(char_num-1)] * max_text_length for i in range(len(text)): text_padded[i] = text[i] lbl_weight[i] = [1.0] diff --git a/ppocr/modeling/backbones/rec_resnet50_fpn.py b/ppocr/modeling/backbones/rec_resnet_fpn.py similarity index 51% rename from ppocr/modeling/backbones/rec_resnet50_fpn.py rename to ppocr/modeling/backbones/rec_resnet_fpn.py index f6d72377fe4e2d3355a4510f070178ad48dd2a27..0a05b5def8b79943f045d9cc941835cddc82bfdb 100755 --- a/ppocr/modeling/backbones/rec_resnet50_fpn.py +++ b/ppocr/modeling/backbones/rec_resnet_fpn.py @@ -22,12 +22,12 @@ import paddle import paddle.fluid as fluid from paddle.fluid.param_attr import ParamAttr - -__all__ = ["ResNet", "ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152"] +__all__ = [ + "ResNet", "ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152" +] Trainable = True -w_nolr = fluid.ParamAttr( - trainable = Trainable) +w_nolr = fluid.ParamAttr(trainable=Trainable) train_parameters = { "input_size": [3, 224, 224], "input_mean": [0.485, 0.456, 0.406], @@ -40,12 +40,12 @@ train_parameters = { } } + class ResNet(): def __init__(self, params): self.layers = params['layers'] self.params = train_parameters - def __call__(self, input): layers = self.layers supported_layers = [18, 34, 50, 101, 152] @@ -60,12 +60,17 @@ class ResNet(): depth = [3, 4, 23, 3] elif layers == 152: depth = [3, 8, 36, 3] - stride_list = [(2,2),(2,2),(1,1),(1,1)] + stride_list = [(2, 2), (2, 2), (1, 1), (1, 1)] num_filters = [64, 128, 256, 512] conv = self.conv_bn_layer( - input=input, num_filters=64, filter_size=7, stride=2, act='relu', name="conv1") - F = [] + input=input, + num_filters=64, + filter_size=7, + stride=2, + act='relu', + name="conv1") + F = [] if layers >= 50: for block in range(len(depth)): for i in range(depth[block]): @@ -79,26 +84,67 @@ class ResNet(): conv = self.bottleneck_block( input=conv, num_filters=num_filters[block], - stride=stride_list[block] if i == 0 else 1, name=conv_name) + stride=stride_list[block] if i == 0 else 1, + name=conv_name) + F.append(conv) + else: + for block in range(len(depth)): + for i in range(depth[block]): + conv_name = "res" + str(block + 2) + chr(97 + i) + + if i == 0 and block != 0: + stride = (2, 1) + else: + stride = (1, 1) + + conv = self.basic_block( + input=conv, + num_filters=num_filters[block], + stride=stride, + if_first=block == i == 0, + name=conv_name) F.append(conv) base = F[-1] - for i in [-2, -3]: + for i in [-2, -3]: b, c, w, h = F[i].shape - if (w,h) == base.shape[2:]: + if (w, h) == base.shape[2:]: base = base else: - base = fluid.layers.conv2d_transpose( input=base, num_filters=c,filter_size=4, stride=2, - padding=1,act=None, + base = fluid.layers.conv2d_transpose( + input=base, + num_filters=c, + filter_size=4, + stride=2, + padding=1, + act=None, param_attr=w_nolr, bias_attr=w_nolr) - base = fluid.layers.batch_norm(base, act = "relu", param_attr=w_nolr, bias_attr=w_nolr) + base = fluid.layers.batch_norm( + base, act="relu", param_attr=w_nolr, bias_attr=w_nolr) base = fluid.layers.concat([base, F[i]], axis=1) - base = fluid.layers.conv2d(base, num_filters=c, filter_size=1, param_attr=w_nolr, bias_attr=w_nolr) - base = fluid.layers.conv2d(base, num_filters=c, filter_size=3,padding = 1, param_attr=w_nolr, bias_attr=w_nolr) - base = fluid.layers.batch_norm(base, act = "relu", param_attr=w_nolr, bias_attr=w_nolr) - - base = fluid.layers.conv2d(base, num_filters=512, filter_size=1,bias_attr=w_nolr,param_attr=w_nolr) + base = fluid.layers.conv2d( + base, + num_filters=c, + filter_size=1, + param_attr=w_nolr, + bias_attr=w_nolr) + base = fluid.layers.conv2d( + base, + num_filters=c, + filter_size=3, + padding=1, + param_attr=w_nolr, + bias_attr=w_nolr) + base = fluid.layers.batch_norm( + base, act="relu", param_attr=w_nolr, bias_attr=w_nolr) + + base = fluid.layers.conv2d( + base, + num_filters=512, + filter_size=1, + bias_attr=w_nolr, + param_attr=w_nolr) return base @@ -113,13 +159,14 @@ class ResNet(): conv = fluid.layers.conv2d( input=input, num_filters=num_filters, - filter_size= 2 if stride==(1,1) else filter_size, - dilation = 2 if stride==(1,1) else 1, + filter_size=2 if stride == (1, 1) else filter_size, + dilation=2 if stride == (1, 1) else 1, stride=stride, padding=(filter_size - 1) // 2, groups=groups, act=None, - param_attr=ParamAttr(name=name + "_weights",trainable = Trainable), + param_attr=ParamAttr( + name=name + "_weights", trainable=Trainable), bias_attr=False, name=name + '.conv2d.output.1') @@ -127,28 +174,35 @@ class ResNet(): bn_name = "bn_" + name else: bn_name = "bn" + name[3:] - return fluid.layers.batch_norm(input=conv, - act=act, - name=bn_name + '.output.1', - param_attr=ParamAttr(name=bn_name + '_scale',trainable = Trainable), - bias_attr=ParamAttr(bn_name + '_offset',trainable = Trainable), - moving_mean_name=bn_name + '_mean', - moving_variance_name=bn_name + '_variance', ) + return fluid.layers.batch_norm( + input=conv, + act=act, + name=bn_name + '.output.1', + param_attr=ParamAttr( + name=bn_name + '_scale', trainable=Trainable), + bias_attr=ParamAttr( + bn_name + '_offset', trainable=Trainable), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance', ) def shortcut(self, input, ch_out, stride, is_first, name): ch_in = input.shape[1] if ch_in != ch_out or stride != 1 or is_first == True: - if stride == (1,1): + if stride == (1, 1): return self.conv_bn_layer(input, ch_out, 1, 1, name=name) - else: #stride == (2,2) + else: #stride == (2,2) return self.conv_bn_layer(input, ch_out, 1, stride, name=name) - + else: return input def bottleneck_block(self, input, num_filters, stride, name): conv0 = self.conv_bn_layer( - input=input, num_filters=num_filters, filter_size=1, act='relu', name=name + "_branch2a") + input=input, + num_filters=num_filters, + filter_size=1, + act='relu', + name=name + "_branch2a") conv1 = self.conv_bn_layer( input=conv0, num_filters=num_filters, @@ -157,16 +211,36 @@ class ResNet(): act='relu', name=name + "_branch2b") conv2 = self.conv_bn_layer( - input=conv1, num_filters=num_filters * 4, filter_size=1, act=None, name=name + "_branch2c") + input=conv1, + num_filters=num_filters * 4, + filter_size=1, + act=None, + name=name + "_branch2c") - short = self.shortcut(input, num_filters * 4, stride, is_first=False, name=name + "_branch1") + short = self.shortcut( + input, + num_filters * 4, + stride, + is_first=False, + name=name + "_branch1") - return fluid.layers.elementwise_add(x=short, y=conv2, act='relu', name=name + ".add.output.5") + return fluid.layers.elementwise_add( + x=short, y=conv2, act='relu', name=name + ".add.output.5") def basic_block(self, input, num_filters, stride, is_first, name): - conv0 = self.conv_bn_layer(input=input, num_filters=num_filters, filter_size=3, act='relu', stride=stride, - name=name + "_branch2a") - conv1 = self.conv_bn_layer(input=conv0, num_filters=num_filters, filter_size=3, act=None, - name=name + "_branch2b") - short = self.shortcut(input, num_filters, stride, is_first, name=name + "_branch1") + conv0 = self.conv_bn_layer( + input=input, + num_filters=num_filters, + filter_size=3, + act='relu', + stride=stride, + name=name + "_branch2a") + conv1 = self.conv_bn_layer( + input=conv0, + num_filters=num_filters, + filter_size=3, + act=None, + name=name + "_branch2b") + short = self.shortcut( + input, num_filters, stride, is_first, name=name + "_branch1") return fluid.layers.elementwise_add(x=short, y=conv1, act='relu') diff --git a/ppocr/utils/character.py b/ppocr/utils/character.py index c7c93fc557604a32d12343d929c119fd787ee126..b4b2021e02c9905623fd9fad5c9673543569c1c2 100755 --- a/ppocr/utils/character.py +++ b/ppocr/utils/character.py @@ -26,8 +26,6 @@ class CharacterOps(object): self.character_type = config['character_type'] self.loss_type = config['loss_type'] self.max_text_len = config['max_text_length'] - if self.loss_type == "srn" and self.character_type != "en": - raise Exception("SRN can only support in character_type == en") if self.character_type == "en": self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" dict_character = list(self.character_str) @@ -160,13 +158,15 @@ def cal_predicts_accuracy_srn(char_ops, acc_num = 0 img_num = 0 + char_num = char_ops.get_char_num() + total_len = preds.shape[0] img_num = int(total_len / max_text_len) for i in range(img_num): cur_label = [] cur_pred = [] for j in range(max_text_len): - if labels[j + i * max_text_len] != 37: #0 + if labels[j + i * max_text_len] != int(char_num-1): #0 cur_label.append(labels[j + i * max_text_len][0]) else: break @@ -178,7 +178,7 @@ def cal_predicts_accuracy_srn(char_ops, elif j == len(cur_label) and j == max_text_len: acc_num += 1 break - elif j == len(cur_label) and preds[j + i * max_text_len][0] == 37: + elif j == len(cur_label) and preds[j + i * max_text_len][0] == int(char_num-1): acc_num += 1 break acc = acc_num * 1.0 / img_num diff --git a/tools/infer_rec.py b/tools/infer_rec.py index 7a81b3d4cedc26616fa1194baa9e4431c2176150..fd70cd66dccc2cb755efbd10c4d16c9f7a97146d 100755 --- a/tools/infer_rec.py +++ b/tools/infer_rec.py @@ -140,12 +140,12 @@ def main(): preds = preds.reshape(-1) preds_text = char_ops.decode(preds) elif loss_type == "srn": - cur_pred = [] + char_num = char_ops.get_char_num() preds = np.array(predict[0]) preds = preds.reshape(-1) probs = np.array(predict[1]) ind = np.argmax(probs, axis=1) - valid_ind = np.where(preds != 37)[0] + valid_ind = np.where(preds != int(char_num-1))[0] if len(valid_ind) == 0: continue score = np.mean(probs[valid_ind, ind[valid_ind]])