提交 a14f8da9 编写于 作者: T tink2123

polish seed code

上级 1effa5f3
......@@ -19,7 +19,6 @@ Global:
max_text_length: 100
infer_mode: False
use_space_char: False
eval_filter: True
save_res_path: ./output/rec/predicts_seed.txt
......@@ -37,8 +36,8 @@ Optimizer:
Architecture:
model_type: seed
algorithm: ASTER
model_type: rec
algorithm: seed
Transform:
name: STN_ON
tps_inputsize: [32, 64]
......@@ -76,8 +75,10 @@ Train:
img_mode: BGR
channel_first: False
- SEEDLabelEncode: # Class handling label
- SEEDResize:
- RecResizeImg:
character_type: en
image_shape: [3, 64, 256]
padding: False
- KeepKeys:
keep_keys: ['image', 'label', 'length', 'fast_label'] # dataloader will return list in this order
loader:
......@@ -95,8 +96,10 @@ Eval:
img_mode: BGR
channel_first: False
- SEEDLabelEncode: # Class handling label
- SEEDResize:
- RecResizeImg:
character_type: en
image_shape: [3, 64, 256]
padding: False
- KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
loader:
......
......@@ -106,7 +106,6 @@ class BaseRecLabelEncode(object):
self.max_text_len = max_text_length
self.beg_str = "sos"
self.end_str = "eos"
self.unknown = "UNKNOWN"
if character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
......@@ -357,7 +356,6 @@ class SEEDLabelEncode(BaseRecLabelEncode):
character_type, use_space_char)
def add_special_char(self, dict_character):
self.beg_str = "sos"
self.end_str = "eos"
dict_character = dict_character + [self.end_str]
return dict_character
......
......@@ -88,29 +88,19 @@ class RecResizeImg(object):
image_shape,
infer_mode=False,
character_type='ch',
padding=True,
**kwargs):
self.image_shape = image_shape
self.infer_mode = infer_mode
self.character_type = character_type
self.padding = padding
def __call__(self, data):
img = data['image']
if self.infer_mode and self.character_type == "ch":
norm_img = resize_norm_img_chinese(img, self.image_shape)
else:
norm_img = resize_norm_img(img, self.image_shape)
data['image'] = norm_img
return data
class SEEDResize(object):
def __init__(self, image_shape, infer_mode=False, **kwargs):
self.image_shape = image_shape
self.infer_mode = infer_mode
def __call__(self, data):
img = data['image']
norm_img = resize_no_padding_img(img, self.image_shape)
norm_img = resize_norm_img(img, self.image_shape, self.padding)
data['image'] = norm_img
return data
......@@ -186,10 +176,15 @@ def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
return padding_im, resize_shape, pad_shape, valid_ratio
def resize_norm_img(img, image_shape):
def resize_norm_img(img, image_shape, padding=True):
imgC, imgH, imgW = image_shape
h = img.shape[0]
w = img.shape[1]
if not padding:
resized_image = cv2.resize(
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_w = imgW
else:
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
......@@ -209,17 +204,6 @@ def resize_norm_img(img, image_shape):
return padding_im
def resize_no_padding_img(img, image_shape):
imgC, imgH, imgW = image_shape
resized_image = cv2.resize(
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
return resized_image
def resize_norm_img_chinese(img, image_shape):
imgC, imgH, imgW = image_shape
# todo: change to 0 and modified image shape
......
......@@ -17,7 +17,7 @@ __all__ = ['build_transform']
def build_transform(config):
from .tps import TPS
from .tps import STN_ON
from .stn import STN_ON
support_dict = ['TPS', 'STN_ON']
......
......@@ -22,6 +22,8 @@ from paddle import nn, ParamAttr
from paddle.nn import functional as F
import numpy as np
from .tps_spatial_transformer import TPSSpatialTransformer
def conv3x3_block(in_channels, out_channels, stride=1):
n = 3 * 3 * out_channels
......@@ -106,3 +108,25 @@ class STN(nn.Layer):
x = F.sigmoid(x)
x = paddle.reshape(x, shape=[-1, self.num_ctrlpoints, 2])
return img_feat, x
class STN_ON(nn.Layer):
def __init__(self, in_channels, tps_inputsize, tps_outputsize,
num_control_points, tps_margins, stn_activation):
super(STN_ON, self).__init__()
self.tps = TPSSpatialTransformer(
output_image_size=tuple(tps_outputsize),
num_control_points=num_control_points,
margins=tuple(tps_margins))
self.stn_head = STN(in_channels=in_channels,
num_ctrlpoints=num_control_points,
activation=stn_activation)
self.tps_inputsize = tps_inputsize
self.out_channels = in_channels
def forward(self, image):
stn_input = paddle.nn.functional.interpolate(
image, self.tps_inputsize, mode="bilinear", align_corners=True)
stn_img_feat, ctrl_points = self.stn_head(stn_input)
x, _ = self.tps(image, ctrl_points)
return x
......@@ -22,9 +22,6 @@ from paddle import nn, ParamAttr
from paddle.nn import functional as F
import numpy as np
from .tps_spatial_transformer import TPSSpatialTransformer
from .stn import STN
class ConvBNLayer(nn.Layer):
def __init__(self,
......@@ -305,25 +302,3 @@ class TPS(nn.Layer):
[-1, image.shape[2], image.shape[3], 2])
batch_I_r = F.grid_sample(x=image, grid=batch_P_prime)
return batch_I_r
class STN_ON(nn.Layer):
def __init__(self, in_channels, tps_inputsize, tps_outputsize,
num_control_points, tps_margins, stn_activation):
super(STN_ON, self).__init__()
self.tps = TPSSpatialTransformer(
output_image_size=tuple(tps_outputsize),
num_control_points=num_control_points,
margins=tuple(tps_margins))
self.stn_head = STN(in_channels=in_channels,
num_ctrlpoints=num_control_points,
activation=stn_activation)
self.tps_inputsize = tps_inputsize
self.out_channels = in_channels
def forward(self, image):
stn_input = paddle.nn.functional.interpolate(
image, self.tps_inputsize, mode="bilinear", align_corners=True)
stn_img_feat, ctrl_points = self.stn_head(stn_input)
x, _ = self.tps(image, ctrl_points)
return x
......@@ -322,7 +322,6 @@ class SEEDLabelDecode(BaseRecLabelDecode):
def add_special_char(self, dict_character):
self.beg_str = "sos"
self.end_str = "eos"
dict_character = dict_character
dict_character = dict_character + [self.end_str]
return dict_character
......
......@@ -12,3 +12,4 @@ cython
lxml
premailer
openpyxl
fasttext==0.9.1
\ No newline at end of file
......@@ -186,9 +186,8 @@ def train(config,
model.train()
use_srn = config['Architecture']['algorithm'] == "SRN"
use_nrtr = config['Architecture']['algorithm'] == "NRTR"
use_sar = config['Architecture']['algorithm'] == 'SAR'
use_seed = config['Architecture']['algorithm'] == 'SEED'
extra_input = config['Architecture'][
'algorithm'] in ["SRN", "NRTR", "SAR", "SEED"]
try:
model_type = config['Architecture']['model_type']
except:
......@@ -217,7 +216,7 @@ def train(config,
images = batch[0]
if use_srn:
model_average = True
if use_srn or model_type == 'table' or use_nrtr or use_sar or use_seed:
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
else:
preds = model(images)
......@@ -281,8 +280,7 @@ def train(config,
post_process_class,
eval_class,
model_type,
use_srn=use_srn,
use_sar=use_sar)
extra_input=extra_input)
cur_metric_str = 'cur metric, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
logger.info(cur_metric_str)
......@@ -354,8 +352,7 @@ def eval(model,
post_process_class,
eval_class,
model_type=None,
use_srn=False,
use_sar=False):
extra_input=False):
model.eval()
with paddle.no_grad():
total_frame = 0.0
......@@ -368,7 +365,7 @@ def eval(model,
break
images = batch[0]
start = time.time()
if use_srn or model_type == 'table' or use_sar:
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
else:
preds = model(images)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册