Commit a14f8da9 authored by T tink2123

polish seed code

Parent 1effa5f3
@@ -19,7 +19,6 @@ Global:
   max_text_length: 100
   infer_mode: False
   use_space_char: False
-  eval_filter: True
   save_res_path: ./output/rec/predicts_seed.txt
@@ -37,8 +36,8 @@ Optimizer:
 Architecture:
-  model_type: seed
-  algorithm: ASTER
+  model_type: rec
+  algorithm: seed
   Transform:
     name: STN_ON
     tps_inputsize: [32, 64]
@@ -76,8 +75,10 @@ Train:
           img_mode: BGR
           channel_first: False
       - SEEDLabelEncode: # Class handling label
-      - SEEDResize:
+      - RecResizeImg:
+          character_type: en
           image_shape: [3, 64, 256]
+          padding: False
       - KeepKeys:
           keep_keys: ['image', 'label', 'length', 'fast_label'] # dataloader will return list in this order
   loader:
@@ -95,8 +96,10 @@ Eval:
           img_mode: BGR
           channel_first: False
       - SEEDLabelEncode: # Class handling label
-      - SEEDResize:
+      - RecResizeImg:
+          character_type: en
           image_shape: [3, 64, 256]
+          padding: False
       - KeepKeys:
           keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
   loader:
...
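Note: the `transforms` list in the config above is expanded into callable preprocessing ops at dataset-build time. A minimal, hedged sketch of that expansion; the `OP_REGISTRY` table and the `create_operators` name are illustrative stand-ins, not necessarily the framework's exact API:

    # Hedged sketch: turning a YAML `transforms` list into callable ops.
    # OP_REGISTRY is a hypothetical stand-in for the real class lookup.
    OP_REGISTRY = {}  # e.g. {"SEEDLabelEncode": SEEDLabelEncode, "RecResizeImg": RecResizeImg}

    def create_operators(op_param_list):
        ops = []
        for op_info in op_param_list:
            # Each entry is a single-key dict: {OpName: {param: value, ...}} or {OpName: None}
            op_name = list(op_info.keys())[0]
            params = op_info[op_name] or {}
            ops.append(OP_REGISTRY[op_name](**params))
        return ops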
@@ -106,7 +106,6 @@ class BaseRecLabelEncode(object):
         self.max_text_len = max_text_length
         self.beg_str = "sos"
         self.end_str = "eos"
-        self.unknown = "UNKNOWN"
         if character_type == "en":
             self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
             dict_character = list(self.character_str)
@@ -357,7 +356,6 @@ class SEEDLabelEncode(BaseRecLabelEncode):
                                               character_type, use_space_char)

     def add_special_char(self, dict_character):
-        self.beg_str = "sos"
         self.end_str = "eos"
         dict_character = dict_character + [self.end_str]
         return dict_character
...
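Net effect of the label-encoding change above: the SEED encoder now appends only an end-of-sequence token to the character dictionary, and the redundant "sos" assignment is dropped. A small, hedged standalone sketch of that behaviour (function name is illustrative):

    # Sketch of the SEED dictionary extension shown above: only "eos" is appended.
    def seed_add_special_char(dict_character, end_str="eos"):
        return dict_character + [end_str]

    chars = seed_add_special_char(list("0123456789abcdefghijklmnopqrstuvwxyz"))
    assert chars[-1] == "eos" and len(chars) == 37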
@@ -88,29 +88,19 @@ class RecResizeImg(object):
                  image_shape,
                  infer_mode=False,
                  character_type='ch',
+                 padding=True,
                  **kwargs):
         self.image_shape = image_shape
         self.infer_mode = infer_mode
         self.character_type = character_type
+        self.padding = padding

     def __call__(self, data):
         img = data['image']
         if self.infer_mode and self.character_type == "ch":
             norm_img = resize_norm_img_chinese(img, self.image_shape)
         else:
-            norm_img = resize_norm_img(img, self.image_shape)
-        data['image'] = norm_img
-        return data
-
-
-class SEEDResize(object):
-    def __init__(self, image_shape, infer_mode=False, **kwargs):
-        self.image_shape = image_shape
-        self.infer_mode = infer_mode
-
-    def __call__(self, data):
-        img = data['image']
-        norm_img = resize_no_padding_img(img, self.image_shape)
+            norm_img = resize_norm_img(img, self.image_shape, self.padding)
         data['image'] = norm_img
         return data
@@ -186,16 +176,21 @@ def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
     return padding_im, resize_shape, pad_shape, valid_ratio


-def resize_norm_img(img, image_shape):
+def resize_norm_img(img, image_shape, padding=True):
     imgC, imgH, imgW = image_shape
     h = img.shape[0]
     w = img.shape[1]
-    ratio = w / float(h)
-    if math.ceil(imgH * ratio) > imgW:
+    if not padding:
+        resized_image = cv2.resize(
+            img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
         resized_w = imgW
     else:
-        resized_w = int(math.ceil(imgH * ratio))
-    resized_image = cv2.resize(img, (resized_w, imgH))
+        ratio = w / float(h)
+        if math.ceil(imgH * ratio) > imgW:
+            resized_w = imgW
+        else:
+            resized_w = int(math.ceil(imgH * ratio))
+        resized_image = cv2.resize(img, (resized_w, imgH))
     resized_image = resized_image.astype('float32')
     if image_shape[0] == 1:
         resized_image = resized_image / 255
@@ -209,17 +204,6 @@ def resize_norm_img(img, image_shape):
     return padding_im


-def resize_no_padding_img(img, image_shape):
-    imgC, imgH, imgW = image_shape
-    resized_image = cv2.resize(
-        img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
-    resized_image = resized_image.astype('float32')
-    resized_image = resized_image.transpose((2, 0, 1)) / 255
-    resized_image -= 0.5
-    resized_image /= 0.5
-    return resized_image
-
-
 def resize_norm_img_chinese(img, image_shape):
     imgC, imgH, imgW = image_shape
     # todo: change to 0 and modified image shape
...
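The resize refactor above folds the old SEEDResize / resize_no_padding_img path into resize_norm_img behind a single padding switch: with padding=True the image keeps its aspect ratio and is right-padded with zeros, while padding=False (the SEED setting) stretches it directly to the target width. A self-contained sketch of both paths for a 3-channel input; the function name and dummy input below are illustrative:

    # Self-contained sketch of the padded / non-padded resize paths introduced above
    # (numpy + OpenCV only; normalisation constants follow the hunk; 3-channel input assumed).
    import math
    import cv2
    import numpy as np

    def resize_norm_img_sketch(img, image_shape=(3, 64, 256), padding=True):
        imgC, imgH, imgW = image_shape
        h, w = img.shape[:2]
        if not padding:
            # SEED path: stretch directly to the target size, no right-padding.
            resized_image = cv2.resize(img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
            resized_w = imgW
        else:
            # Default path: keep the aspect ratio, then pad the width with zeros.
            ratio = w / float(h)
            resized_w = imgW if math.ceil(imgH * ratio) > imgW else int(math.ceil(imgH * ratio))
            resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype('float32')
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image = (resized_image - 0.5) / 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, :resized_w] = resized_image
        return padding_im

    # Example: a SEED-style call with padding disabled.
    dummy = np.random.randint(0, 255, (48, 160, 3), dtype=np.uint8)
    out = resize_norm_img_sketch(dummy, (3, 64, 256), padding=False)
    assert out.shape == (3, 64, 256)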
@@ -17,7 +17,7 @@ __all__ = ['build_transform']


 def build_transform(config):
     from .tps import TPS
-    from .tps import STN_ON
+    from .stn import STN_ON

     support_dict = ['TPS', 'STN_ON']
...
@@ -22,6 +22,8 @@ from paddle import nn, ParamAttr
 from paddle.nn import functional as F
 import numpy as np

+from .tps_spatial_transformer import TPSSpatialTransformer
+

 def conv3x3_block(in_channels, out_channels, stride=1):
     n = 3 * 3 * out_channels
@@ -106,3 +108,25 @@ class STN(nn.Layer):
         x = F.sigmoid(x)
         x = paddle.reshape(x, shape=[-1, self.num_ctrlpoints, 2])
         return img_feat, x
+
+
+class STN_ON(nn.Layer):
+    def __init__(self, in_channels, tps_inputsize, tps_outputsize,
+                 num_control_points, tps_margins, stn_activation):
+        super(STN_ON, self).__init__()
+        self.tps = TPSSpatialTransformer(
+            output_image_size=tuple(tps_outputsize),
+            num_control_points=num_control_points,
+            margins=tuple(tps_margins))
+        self.stn_head = STN(in_channels=in_channels,
+                            num_ctrlpoints=num_control_points,
+                            activation=stn_activation)
+        self.tps_inputsize = tps_inputsize
+        self.out_channels = in_channels
+
+    def forward(self, image):
+        stn_input = paddle.nn.functional.interpolate(
+            image, self.tps_inputsize, mode="bilinear", align_corners=True)
+        stn_img_feat, ctrl_points = self.stn_head(stn_input)
+        x, _ = self.tps(image, ctrl_points)
+        return x
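With STN_ON now living next to STN and exported through build_transform, it can be instantiated directly from the Transform config. A hedged usage sketch; only `name` and `tps_inputsize` come from the config hunk above, the remaining parameter values are illustrative ASTER-style settings:

    # Usage sketch (assumes a PaddleOCR checkout at this commit; values other
    # than `name` and `tps_inputsize` are illustrative, not taken from the diff).
    import paddle
    from ppocr.modeling.transforms import build_transform

    config = {
        "name": "STN_ON",
        "in_channels": 3,
        "tps_inputsize": [32, 64],
        "tps_outputsize": [32, 100],
        "num_control_points": 20,
        "tps_margins": [0.05, 0.05],
        "stn_activation": "none",
    }
    stn = build_transform(config)
    x = paddle.rand([1, 3, 64, 256])
    y = stn(x)        # rectified image fed to the recognition backbone
    print(y.shape)    # expected [1, 3, 32, 100] given tps_outputsize above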
@@ -22,9 +22,6 @@ from paddle import nn, ParamAttr
 from paddle.nn import functional as F
 import numpy as np

-from .tps_spatial_transformer import TPSSpatialTransformer
-from .stn import STN
-

 class ConvBNLayer(nn.Layer):
     def __init__(self,
@@ -305,25 +302,3 @@ class TPS(nn.Layer):
             [-1, image.shape[2], image.shape[3], 2])
         batch_I_r = F.grid_sample(x=image, grid=batch_P_prime)
         return batch_I_r
-
-
-class STN_ON(nn.Layer):
-    def __init__(self, in_channels, tps_inputsize, tps_outputsize,
-                 num_control_points, tps_margins, stn_activation):
-        super(STN_ON, self).__init__()
-        self.tps = TPSSpatialTransformer(
-            output_image_size=tuple(tps_outputsize),
-            num_control_points=num_control_points,
-            margins=tuple(tps_margins))
-        self.stn_head = STN(in_channels=in_channels,
-                            num_ctrlpoints=num_control_points,
-                            activation=stn_activation)
-        self.tps_inputsize = tps_inputsize
-        self.out_channels = in_channels
-
-    def forward(self, image):
-        stn_input = paddle.nn.functional.interpolate(
-            image, self.tps_inputsize, mode="bilinear", align_corners=True)
-        stn_img_feat, ctrl_points = self.stn_head(stn_input)
-        x, _ = self.tps(image, ctrl_points)
-        return x
@@ -322,7 +322,6 @@ class SEEDLabelDecode(BaseRecLabelDecode):
     def add_special_char(self, dict_character):
         self.beg_str = "sos"
         self.end_str = "eos"
-        dict_character = dict_character
         dict_character = dict_character + [self.end_str]
         return dict_character
...
@@ -11,4 +11,5 @@ opencv-contrib-python==4.4.0.46
 cython
 lxml
 premailer
-openpyxl
\ No newline at end of file
+openpyxl
+fasttext==0.9.1
\ No newline at end of file
@@ -186,9 +186,8 @@ def train(config,
     model.train()

     use_srn = config['Architecture']['algorithm'] == "SRN"
-    use_nrtr = config['Architecture']['algorithm'] == "NRTR"
-    use_sar = config['Architecture']['algorithm'] == 'SAR'
-    use_seed = config['Architecture']['algorithm'] == 'SEED'
+    extra_input = config['Architecture'][
+        'algorithm'] in ["SRN", "NRTR", "SAR", "SEED"]
     try:
         model_type = config['Architecture']['model_type']
     except:
@@ -217,7 +216,7 @@ def train(config,
             images = batch[0]
             if use_srn:
                 model_average = True
-            if use_srn or model_type == 'table' or use_nrtr or use_sar or use_seed:
+            if model_type == 'table' or extra_input:
                 preds = model(images, data=batch[1:])
             else:
                 preds = model(images)
@@ -281,8 +280,7 @@ def train(config,
                     post_process_class,
                     eval_class,
                     model_type,
-                    use_srn=use_srn,
-                    use_sar=use_sar)
+                    extra_input=extra_input)
                 cur_metric_str = 'cur metric, {}'.format(', '.join(
                     ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
                 logger.info(cur_metric_str)
@@ -354,8 +352,7 @@ def eval(model,
          post_process_class,
          eval_class,
          model_type=None,
-         use_srn=False,
-         use_sar=False):
+         extra_input=False):
     model.eval()
     with paddle.no_grad():
         total_frame = 0.0
@@ -368,7 +365,7 @@ def eval(model,
                 break
             images = batch[0]
             start = time.time()
-            if use_srn or model_type == 'table' or use_sar:
+            if model_type == 'table' or extra_input:
                 preds = model(images, data=batch[1:])
             else:
                 preds = model(images)
...
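The train/eval refactor above collapses the per-algorithm booleans (use_nrtr, use_sar, use_seed) into a single extra_input flag that gates whether label-side tensors are passed to the model at forward time. A minimal standalone sketch of the resulting dispatch; `model` and `batch` are placeholders:

    # Sketch of the `extra_input` gating introduced above: algorithms that need
    # extra tensors at forward time share one flag instead of per-algorithm booleans.
    EXTRA_INPUT_ALGOS = ["SRN", "NRTR", "SAR", "SEED"]

    def forward_batch(model, batch, config):
        algorithm = config['Architecture']['algorithm']
        model_type = config['Architecture'].get('model_type')
        extra_input = algorithm in EXTRA_INPUT_ALGOS
        images = batch[0]
        if model_type == 'table' or extra_input:
            # these heads consume extra tensors (e.g. encoded labels) besides images
            return model(images, data=batch[1:])
        return model(images)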