diff --git a/configs/rec/rec_resnet_stn_bilstm_att.yml b/configs/rec/rec_resnet_stn_bilstm_att.yml
index 7b5a9c7117b1220c94b4ee3cf6036f67ffbc13a0..b18bb685739597ee2667008f3549c915b6ad3060 100644
--- a/configs/rec/rec_resnet_stn_bilstm_att.yml
+++ b/configs/rec/rec_resnet_stn_bilstm_att.yml
@@ -19,7 +19,6 @@ Global:
   max_text_length: 100
   infer_mode: False
   use_space_char: False
-  eval_filter: True
   save_res_path: ./output/rec/predicts_seed.txt
 
 
@@ -37,8 +36,8 @@ Optimizer:
 
 
 Architecture:
-  model_type: seed
-  algorithm: ASTER
+  model_type: rec
+  algorithm: SEED
   Transform:
     name: STN_ON
     tps_inputsize: [32, 64]
@@ -76,8 +75,10 @@ Train:
           img_mode: BGR
           channel_first: False
       - SEEDLabelEncode: # Class handling label
-      - SEEDResize:
+      - RecResizeImg:
+          character_type: en
           image_shape: [3, 64, 256]
+          padding: False
       - KeepKeys:
           keep_keys: ['image', 'label', 'length', 'fast_label'] # dataloader will return list in this order
   loader:
@@ -95,8 +96,10 @@ Eval:
           img_mode: BGR
           channel_first: False
      - SEEDLabelEncode: # Class handling label
-      - SEEDResize:
+      - RecResizeImg:
+          character_type: en
           image_shape: [3, 64, 256]
+          padding: False
      - KeepKeys:
          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
   loader:
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index 45bb2a1fcbd6ce25d5853c2ccdc47dab7b17e23f..f761eaf6691aadb5b8e0452341cfe06927677dc9 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -106,7 +106,6 @@ class BaseRecLabelEncode(object):
         self.max_text_len = max_text_length
         self.beg_str = "sos"
         self.end_str = "eos"
-        self.unknown = "UNKNOWN"
         if character_type == "en":
             self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
             dict_character = list(self.character_str)
@@ -357,7 +356,6 @@ class SEEDLabelEncode(BaseRecLabelEncode):
                                               character_type, use_space_char)
 
     def add_special_char(self, dict_character):
-        self.beg_str = "sos"
         self.end_str = "eos"
         dict_character = dict_character + [self.end_str]
         return dict_character
diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py
index 6904d38e6d97b6211623c276123a3149a605910b..71ed8976db7de24a489d1f75612a9a9a67995ba2 100644
--- a/ppocr/data/imaug/rec_img_aug.py
+++ b/ppocr/data/imaug/rec_img_aug.py
@@ -88,29 +88,19 @@ class RecResizeImg(object):
                  image_shape,
                  infer_mode=False,
                  character_type='ch',
+                 padding=True,
                  **kwargs):
         self.image_shape = image_shape
         self.infer_mode = infer_mode
         self.character_type = character_type
+        self.padding = padding
 
     def __call__(self, data):
         img = data['image']
         if self.infer_mode and self.character_type == "ch":
             norm_img = resize_norm_img_chinese(img, self.image_shape)
         else:
-            norm_img = resize_norm_img(img, self.image_shape)
-        data['image'] = norm_img
-        return data
-
-
-class SEEDResize(object):
-    def __init__(self, image_shape, infer_mode=False, **kwargs):
-        self.image_shape = image_shape
-        self.infer_mode = infer_mode
-
-    def __call__(self, data):
-        img = data['image']
-        norm_img = resize_no_padding_img(img, self.image_shape)
+            norm_img = resize_norm_img(img, self.image_shape, self.padding)
         data['image'] = norm_img
         return data
 
@@ -186,16 +176,21 @@ def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
     return padding_im, resize_shape, pad_shape, valid_ratio
 
 
-def resize_norm_img(img, image_shape):
+def resize_norm_img(img, image_shape, padding=True):
     imgC, imgH, imgW = image_shape
     h = img.shape[0]
     w = img.shape[1]
-    ratio = w / float(h)
-    if math.ceil(imgH * ratio) > imgW:
+    if not padding:
+        resized_image = cv2.resize(
+            img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
         resized_w = imgW
     else:
-        resized_w = int(math.ceil(imgH * ratio))
-    resized_image = cv2.resize(img, (resized_w, imgH))
+        ratio = w / float(h)
+        if math.ceil(imgH * ratio) > imgW:
+            resized_w = imgW
+        else:
+            resized_w = int(math.ceil(imgH * ratio))
+        resized_image = cv2.resize(img, (resized_w, imgH))
     resized_image = resized_image.astype('float32')
     if image_shape[0] == 1:
         resized_image = resized_image / 255
@@ -209,17 +204,6 @@
     return padding_im
 
 
-def resize_no_padding_img(img, image_shape):
-    imgC, imgH, imgW = image_shape
-    resized_image = cv2.resize(
-        img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
-    resized_image = resized_image.astype('float32')
-    resized_image = resized_image.transpose((2, 0, 1)) / 255
-    resized_image -= 0.5
-    resized_image /= 0.5
-    return resized_image
-
-
 def resize_norm_img_chinese(img, image_shape):
     imgC, imgH, imgW = image_shape
     # todo: change to 0 and modified image shape
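
Note (not part of the diff): with `padding: False`, `RecResizeImg` now reproduces what the deleted `SEEDResize`/`resize_no_padding_img` pair did, stretching the crop straight to 64x256, while every other recognizer keeps the aspect-preserving resize-then-pad path. A minimal, self-contained sketch of the two branches for the 3-channel case (the `resize_demo` name is illustrative, not part of the codebase):

    import math

    import cv2
    import numpy as np

    def resize_demo(img, image_shape, padding=True):
        imgC, imgH, imgW = image_shape
        h, w = img.shape[:2]
        if not padding:
            # SEED path: stretch directly to the target size, ignoring aspect ratio.
            resized = cv2.resize(img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
            resized_w = imgW
        else:
            # Default path: scale by height, cap the width at imgW.
            ratio = w / float(h)
            resized_w = min(imgW, int(math.ceil(imgH * ratio)))
            resized = cv2.resize(img, (resized_w, imgH))
        resized = resized.astype('float32').transpose((2, 0, 1)) / 255  # HWC -> CHW, [0, 1]
        resized = (resized - 0.5) / 0.5  # normalize to [-1, 1]
        canvas = np.zeros((imgC, imgH, imgW), dtype=np.float32)  # zeros right-pad narrow crops
        canvas[:, :, 0:resized_w] = resized
        return canvas

    dummy = np.random.randint(0, 256, (48, 320, 3), dtype=np.uint8)
    print(resize_demo(dummy, [3, 64, 256], padding=False).shape)  # (3, 64, 256)
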
diff --git a/ppocr/modeling/transforms/__init__.py b/ppocr/modeling/transforms/__init__.py
index 0e02a1c0cf88eefa3f7dde4f606b7db3016e4e1c..405ab3cc6c0380654f61e42e523ddc85839139b3 100755
--- a/ppocr/modeling/transforms/__init__.py
+++ b/ppocr/modeling/transforms/__init__.py
@@ -17,7 +17,7 @@ __all__ = ['build_transform']
 
 def build_transform(config):
     from .tps import TPS
-    from .tps import STN_ON
+    from .stn import STN_ON
 
     support_dict = ['TPS', 'STN_ON']
 
diff --git a/ppocr/modeling/transforms/stn.py b/ppocr/modeling/transforms/stn.py
index 23bd21891f3eb5ea794e4d46f8236010acc65b45..215895f4c4c719f407f4998f7429d965e0529ddc 100644
--- a/ppocr/modeling/transforms/stn.py
+++ b/ppocr/modeling/transforms/stn.py
@@ -22,6 +22,8 @@ from paddle import nn, ParamAttr
 from paddle.nn import functional as F
 import numpy as np
 
+from .tps_spatial_transformer import TPSSpatialTransformer
+
 
 def conv3x3_block(in_channels, out_channels, stride=1):
     n = 3 * 3 * out_channels
@@ -106,3 +108,25 @@ class STN(nn.Layer):
         x = F.sigmoid(x)
         x = paddle.reshape(x, shape=[-1, self.num_ctrlpoints, 2])
         return img_feat, x
+
+
+class STN_ON(nn.Layer):
+    def __init__(self, in_channels, tps_inputsize, tps_outputsize,
+                 num_control_points, tps_margins, stn_activation):
+        super(STN_ON, self).__init__()
+        self.tps = TPSSpatialTransformer(
+            output_image_size=tuple(tps_outputsize),
+            num_control_points=num_control_points,
+            margins=tuple(tps_margins))
+        self.stn_head = STN(in_channels=in_channels,
+                            num_ctrlpoints=num_control_points,
+                            activation=stn_activation)
+        self.tps_inputsize = tps_inputsize
+        self.out_channels = in_channels
+
+    def forward(self, image):
+        stn_input = paddle.nn.functional.interpolate(
+            image, self.tps_inputsize, mode="bilinear", align_corners=True)
+        stn_img_feat, ctrl_points = self.stn_head(stn_input)
+        x, _ = self.tps(image, ctrl_points)
+        return x
diff --git a/ppocr/modeling/transforms/tps.py b/ppocr/modeling/transforms/tps.py
index 81221b0351944ef8a054d44d3857a8c0ad61e052..6cd68555369dd1ddbd6ccf5236688a4b957b8525 100644
--- a/ppocr/modeling/transforms/tps.py
+++ b/ppocr/modeling/transforms/tps.py
@@ -22,9 +22,6 @@ from paddle import nn, ParamAttr
 from paddle.nn import functional as F
 import numpy as np
 
-from .tps_spatial_transformer import TPSSpatialTransformer
-from .stn import STN
-
 
 class ConvBNLayer(nn.Layer):
     def __init__(self,
@@ -305,25 +302,3 @@
                                        [-1, image.shape[2], image.shape[3], 2])
         batch_I_r = F.grid_sample(x=image, grid=batch_P_prime)
         return batch_I_r
-
-
-class STN_ON(nn.Layer):
-    def __init__(self, in_channels, tps_inputsize, tps_outputsize,
-                 num_control_points, tps_margins, stn_activation):
-        super(STN_ON, self).__init__()
-        self.tps = TPSSpatialTransformer(
-            output_image_size=tuple(tps_outputsize),
-            num_control_points=num_control_points,
-            margins=tuple(tps_margins))
-        self.stn_head = STN(in_channels=in_channels,
-                            num_ctrlpoints=num_control_points,
-                            activation=stn_activation)
-        self.tps_inputsize = tps_inputsize
-        self.out_channels = in_channels
-
-    def forward(self, image):
-        stn_input = paddle.nn.functional.interpolate(
-            image, self.tps_inputsize, mode="bilinear", align_corners=True)
-        stn_img_feat, ctrl_points = self.stn_head(stn_input)
-        x, _ = self.tps(image, ctrl_points)
-        return x
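
Note (not part of the diff): moving `STN_ON` out of `tps.py` and importing `TPSSpatialTransformer` into `stn.py` keeps the whole STN+TPS rectification stack in one module and lets `build_transform` resolve `STN_ON` from `.stn`. The forward pass downsamples the input to `tps_inputsize` for the lightweight localization head, then applies the predicted control points to the full-resolution image. A hypothetical shape walk-through, assuming the repository is on the import path and using the values from `rec_resnet_stn_bilstm_att.yml`:

    import paddle

    from ppocr.modeling.transforms.stn import STN_ON

    stn_on = STN_ON(
        in_channels=3,
        tps_inputsize=[32, 64],    # STN head sees a cheap 32x64 thumbnail
        tps_outputsize=[32, 100],  # rectified size handed to the backbone
        num_control_points=20,
        tps_margins=[0.05, 0.05],
        stn_activation="none")

    x = paddle.rand([8, 3, 64, 256])  # batch as shaped by RecResizeImg above
    y = stn_on(x)
    print(y.shape)  # expected: [8, 3, 32, 100]
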
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index 6c28e2b7a9b5459413db988bde33d57cacdd33c5..16f7f76596649e59e3818f1c785feff6535ef499 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -322,7 +322,6 @@ class SEEDLabelDecode(BaseRecLabelDecode):
     def add_special_char(self, dict_character):
         self.beg_str = "sos"
         self.end_str = "eos"
-        dict_character = dict_character
         dict_character = dict_character + [self.end_str]
         return dict_character
 
diff --git a/requirements.txt b/requirements.txt
index 0b2366c5cd344260d7afab811b27e19499a89b26..311030f65f2dc2dad4a51821e64f2777e7621a0b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,4 +11,5 @@ opencv-contrib-python==4.4.0.46
 cython
 lxml
 premailer
-openpyxl
\ No newline at end of file
+openpyxl
+fasttext==0.9.1
\ No newline at end of file
diff --git a/tools/program.py b/tools/program.py
index 8a405d7d4f702efc0fa70f08302b331de5540b1b..8750dd9adcd51889bc1737985cad9f6fc2f8f4b3 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -186,9 +186,8 @@ def train(config,
     model.train()
 
     use_srn = config['Architecture']['algorithm'] == "SRN"
-    use_nrtr = config['Architecture']['algorithm'] == "NRTR"
-    use_sar = config['Architecture']['algorithm'] == 'SAR'
-    use_seed = config['Architecture']['algorithm'] == 'SEED'
+    extra_input = config['Architecture'][
+        'algorithm'] in ["SRN", "NRTR", "SAR", "SEED"]
     try:
         model_type = config['Architecture']['model_type']
     except:
@@ -217,7 +216,7 @@
             images = batch[0]
             if use_srn:
                 model_average = True
-            if use_srn or model_type == 'table' or use_nrtr or use_sar or use_seed:
+            if model_type == 'table' or extra_input:
                 preds = model(images, data=batch[1:])
             else:
                 preds = model(images)
@@ -281,8 +280,7 @@
                     post_process_class,
                     eval_class,
                     model_type,
-                    use_srn=use_srn,
-                    use_sar=use_sar)
+                    extra_input=extra_input)
                 cur_metric_str = 'cur metric, {}'.format(', '.join(
                     ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
                 logger.info(cur_metric_str)
@@ -354,8 +352,7 @@
          post_process_class,
          eval_class,
          model_type=None,
-         use_srn=False,
-         use_sar=False):
+         extra_input=False):
     model.eval()
     with paddle.no_grad():
         total_frame = 0.0
@@ -368,7 +365,7 @@
                 break
             images = batch[0]
             start = time.time()
-            if use_srn or model_type == 'table' or use_sar:
+            if model_type == 'table' or extra_input:
                 preds = model(images, data=batch[1:])
             else:
                 preds = model(images)
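
Note (not part of the diff): the four per-algorithm booleans (`use_srn`, `use_nrtr`, `use_sar`, `use_seed`) collapse into a single `extra_input` flag in both `train()` and `eval()`, so any algorithm that needs extra per-sample tensors at forward time takes the `model(images, data=batch[1:])` branch; for SEED those extras include the word embedding behind the `fast_label` key, which is presumably why `fasttext==0.9.1` joins requirements.txt. A minimal illustration of the consolidated dispatch (the config dict, `fake_model`, and `run_step` are made up):

    config = {"Architecture": {"algorithm": "SEED", "model_type": "rec"}}

    extra_input = config["Architecture"]["algorithm"] in ["SRN", "NRTR", "SAR", "SEED"]
    model_type = config["Architecture"].get("model_type")

    def run_step(model, batch):
        images = batch[0]
        # SEED, like SRN/NRTR/SAR, consumes ground-truth side inputs during training.
        if model_type == "table" or extra_input:
            return model(images, data=batch[1:])
        return model(images)

    fake_model = lambda images, data=None: ("preds", data)
    print(run_step(fake_model, ["images", "label", "length", "fast_label"]))
    # -> ('preds', ['label', 'length', 'fast_label'])
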