未验证 提交 612e8014 编写于 作者: M MissPenguin 提交者: GitHub

Merge pull request #537 from tink2123/add_srn

Add SRN
......@@ -122,7 +122,10 @@ PaddleOCR开源的文本识别算法列表:
- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))
- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))
- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))
- [ ] SRN([paper](https://arxiv.org/abs/2003.12294))(百度自研, coming soon)
- [x] SRN([paper](https://arxiv.org/abs/2003.12294))(百度自研)
*备注:* SRN模型使用了数据扰动方法对上述提到对两个训练集进行增广,增广后的数据可以在[百度网盘](todo)上下载。
原始论文使用两阶段训练平均精度为89.74%,PaddleOCR中使用one-stage训练,平均精度为88.33%。两种预训练权重均在[下载链接](https://paddleocr.bj.bcebos.com/SRN/rec_r50fpn_vd_none_srn.tar)中。
参考[DTRB](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
......@@ -136,6 +139,7 @@ PaddleOCR开源的文本识别算法列表:
|STAR-Net|MobileNetV3|81.56%|rec_mv3_tps_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/rec_mv3_tps_bilstm_ctc.tar)|
|RARE|Resnet34_vd|84.90%|rec_r34_vd_tps_bilstm_attn|[下载链接](https://paddleocr.bj.bcebos.com/rec_r34_vd_tps_bilstm_attn.tar)|
|RARE|MobileNetV3|83.32%|rec_mv3_tps_bilstm_attn|[下载链接](https://paddleocr.bj.bcebos.com/rec_mv3_tps_bilstm_attn.tar)|
|SRN|Resnet50_vd_fpn|88.33%|rec_r50fpn_vd_none_srn|[下载链接](https://paddleocr.bj.bcebos.com/SRN/rec_r50fpn_vd_none_srn.tar)|
使用[LSVT](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_ch/datasets.md#1icdar2019-lsvt)街景数据集根据真值将图crop出来30w数据,进行位置校准。此外基于LSVT语料生成500w合成数据训练中文模型,相关配置和预训练文件如下:
......
Global:
algorithm: SRN
use_gpu: true
epoch_num: 72
log_smooth_window: 20
print_batch_step: 10
save_model_dir: output/rec_pvam_withrotate
save_epoch_step: 1
eval_batch_step: 8000
train_batch_size_per_card: 64
test_batch_size_per_card: 1
image_shape: [1, 64, 256]
max_text_length: 25
character_type: en
loss_type: srn
num_heads: 8
average_window: 0.15
max_average_window: 15625
min_average_window: 10000
reader_yml: ./configs/rec/rec_benchmark_reader.yml
pretrain_weights:
checkpoints:
save_inference_dir:
infer_img:
Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel
Backbone:
function: ppocr.modeling.backbones.rec_resnet50_fpn,ResNet
layers: 50
Head:
function: ppocr.modeling.heads.rec_srn_all_head,SRNPredict
encoder_type: rnn
num_encoder_TUs: 2
num_decoder_TUs: 4
hidden_dims: 512
SeqRNN:
hidden_size: 256
Loss:
function: ppocr.modeling.losses.rec_srn_loss,SRNLoss
Optimizer:
function: ppocr.optimizer,AdamDecay
base_lr: 0.0001
beta1: 0.9
beta2: 0.999
......@@ -32,6 +32,9 @@
| loss_type | 设置 loss 类型 | ctc | 支持两种loss: ctc / attention |
| distort | 设置是否使用数据增强 | false | 设置为true时,将在训练时随机进行扰动,支持的扰动操作可阅读[img_tools.py](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/ppocr/data/rec/img_tools.py) |
| use_space_char | 设置是否识别空格 | false | 仅在 character_type=ch 时支持空格 |
| average_window | ModelAverage优化器中的窗口长度计算比例 | 0.15 | 目前仅应用与SRN |
| max_average_window | 平均值计算窗口长度的最大值 | 15625 | 推荐设置为一轮训练中mini-batchs的数目|
| min_average_window | 平均值计算窗口长度的最小值 | 10000 | \ |
| reader_yml | 设置reader配置文件 | ./configs/rec/rec_icdar15_reader.yml | \ |
| pretrain_weights | 加载预训练模型路径 | ./pretrain_models/CRNN/best_accuracy | \ |
| checkpoints | 加载模型参数路径 | None | 用于中断后加载参数继续训练 |
......
......@@ -26,7 +26,7 @@ from ppocr.utils.utility import initial_logger
from ppocr.utils.utility import get_image_file_list
logger = initial_logger()
from .img_tools import process_image, get_img_data
from .img_tools import process_image, process_image_srn, get_img_data
class LMDBReader(object):
......@@ -43,6 +43,9 @@ class LMDBReader(object):
self.mode = params['mode']
self.drop_last = False
self.use_tps = False
self.num_heads = None
if "num_heads" in params:
self.num_heads = params['num_heads']
if "tps" in params:
self.ues_tps = True
self.use_distort = False
......@@ -119,12 +122,19 @@ class LMDBReader(object):
img = cv2.imread(single_img)
if img.shape[-1] == 1 or len(list(img.shape)) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
norm_img = process_image(
img=img,
image_shape=self.image_shape,
char_ops=self.char_ops,
tps=self.use_tps,
infer_mode=True)
if self.loss_type == 'srn':
norm_img = process_image_srn(
img=img,
image_shape=self.image_shape,
num_heads=self.num_heads,
max_text_length=self.max_text_length)
else:
norm_img = process_image(
img=img,
image_shape=self.image_shape,
char_ops=self.char_ops,
tps=self.use_tps,
infer_mode=True)
yield norm_img
else:
lmdb_sets = self.load_hierarchical_lmdb_dataset()
......@@ -144,14 +154,25 @@ class LMDBReader(object):
if sample_info is None:
continue
img, label = sample_info
outs = process_image(
img=img,
image_shape=self.image_shape,
label=label,
char_ops=self.char_ops,
loss_type=self.loss_type,
max_text_length=self.max_text_length,
distort=self.use_distort)
outs = []
if self.loss_type == "srn":
outs = process_image_srn(
img=img,
image_shape=self.image_shape,
num_heads=self.num_heads,
max_text_length=self.max_text_length,
label=label,
char_ops=self.char_ops,
loss_type=self.loss_type)
else:
outs = process_image(
img=img,
image_shape=self.image_shape,
label=label,
char_ops=self.char_ops,
loss_type=self.loss_type,
max_text_length=self.max_text_length)
if outs is None:
continue
yield outs
......
......@@ -381,3 +381,84 @@ def process_image(img,
assert False, "Unsupport loss_type %s in process_image"\
% loss_type
return (norm_img)
def resize_norm_img_srn(img, image_shape):
imgC, imgH, imgW = image_shape
img_black = np.zeros((imgH, imgW))
im_hei = img.shape[0]
im_wid = img.shape[1]
if im_wid <= im_hei * 1:
img_new = cv2.resize(img, (imgH * 1, imgH))
elif im_wid <= im_hei * 2:
img_new = cv2.resize(img, (imgH * 2, imgH))
elif im_wid <= im_hei * 3:
img_new = cv2.resize(img, (imgH * 3, imgH))
else:
img_new = cv2.resize(img, (imgW, imgH))
img_np = np.asarray(img_new)
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
img_black[:, 0:img_np.shape[1]] = img_np
img_black = img_black[:, :, np.newaxis]
row, col, c = img_black.shape
c = 1
return np.reshape(img_black, (c, row, col)).astype(np.float32)
def srn_other_inputs(image_shape,
num_heads,
max_text_length):
imgC, imgH, imgW = image_shape
feature_dim = int((imgH / 8) * (imgW / 8))
encoder_word_pos = np.array(range(0, feature_dim)).reshape((feature_dim, 1)).astype('int64')
gsrm_word_pos = np.array(range(0, max_text_length)).reshape((max_text_length, 1)).astype('int64')
lbl_weight = np.array([37] * max_text_length).reshape((-1,1)).astype('int64')
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape([-1, 1, max_text_length, max_text_length])
gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1, [1, num_heads, 1, 1]) * [-1e9]
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape([-1, 1, max_text_length, max_text_length])
gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2, [1, num_heads, 1, 1]) * [-1e9]
encoder_word_pos = encoder_word_pos[np.newaxis, :]
gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
return [lbl_weight, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2]
def process_image_srn(img,
image_shape,
num_heads,
max_text_length,
label=None,
char_ops=None,
loss_type=None):
norm_img = resize_norm_img_srn(img, image_shape)
norm_img = norm_img[np.newaxis, :]
[lbl_weight, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
srn_other_inputs(image_shape, num_heads, max_text_length)
if label is not None:
char_num = char_ops.get_char_num()
text = char_ops.encode(label)
if len(text) == 0 or len(text) > max_text_length:
return None
else:
if loss_type == "srn":
text_padded = [37] * max_text_length
for i in range(len(text)):
text_padded[i] = text[i]
lbl_weight[i] = [1.0]
text_padded = np.array(text_padded)
text = text_padded.reshape(-1, 1)
return (norm_img, text,encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2,lbl_weight)
else:
assert False, "Unsupport loss_type %s in process_image"\
% loss_type
return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2)
......@@ -58,6 +58,10 @@ class RecModel(object):
self.loss_type = global_params['loss_type']
self.image_shape = global_params['image_shape']
self.max_text_length = global_params['max_text_length']
if "num_heads" in params:
self.num_heads = global_params["num_heads"]
else:
self.num_heads = None
def create_feed(self, mode):
image_shape = deepcopy(self.image_shape)
......@@ -77,6 +81,48 @@ class RecModel(object):
lod_level=1)
feed_list = [image, label_in, label_out]
labels = {'label_in': label_in, 'label_out': label_out}
elif self.loss_type == "srn":
encoder_word_pos = fluid.data(
name="encoder_word_pos",
shape=[
-1, int((image_shape[-2] / 8) * (image_shape[-1] / 8)),
1
],
dtype="int64")
gsrm_word_pos = fluid.data(
name="gsrm_word_pos",
shape=[-1, self.max_text_length, 1],
dtype="int64")
gsrm_slf_attn_bias1 = fluid.data(
name="gsrm_slf_attn_bias1",
shape=[
-1, self.num_heads, self.max_text_length,
self.max_text_length
],
dtype="float32")
gsrm_slf_attn_bias2 = fluid.data(
name="gsrm_slf_attn_bias2",
shape=[
-1, self.num_heads, self.max_text_length,
self.max_text_length
],
dtype="float32")
lbl_weight = fluid.layers.data(
name="lbl_weight", shape=[-1, 1], dtype='int64')
label = fluid.data(
name='label', shape=[-1, 1], dtype='int32', lod_level=1)
feed_list = [
image, label, encoder_word_pos, gsrm_word_pos,
gsrm_slf_attn_bias1, gsrm_slf_attn_bias2, lbl_weight
]
labels = {
'label': label,
'encoder_word_pos': encoder_word_pos,
'gsrm_word_pos': gsrm_word_pos,
'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1,
'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2,
'lbl_weight': lbl_weight
}
else:
label = fluid.data(
name='label', shape=[None, 1], dtype='int32', lod_level=1)
......@@ -88,6 +134,8 @@ class RecModel(object):
use_double_buffer=True,
iterable=False)
else:
labels = None
loader = None
if self.char_type == "ch" and self.infer_img:
image_shape[-1] = -1
if self.tps != None:
......@@ -98,8 +146,42 @@ class RecModel(object):
)
image_shape = deepcopy(self.image_shape)
image = fluid.data(name='image', shape=image_shape, dtype='float32')
labels = None
loader = None
if self.loss_type == "srn":
encoder_word_pos = fluid.data(
name="encoder_word_pos",
shape=[
-1, int((image_shape[-2] / 8) * (image_shape[-1] / 8)),
1
],
dtype="int64")
gsrm_word_pos = fluid.data(
name="gsrm_word_pos",
shape=[-1, self.max_text_length, 1],
dtype="int64")
gsrm_slf_attn_bias1 = fluid.data(
name="gsrm_slf_attn_bias1",
shape=[
-1, self.num_heads, self.max_text_length,
self.max_text_length
],
dtype="float32")
gsrm_slf_attn_bias2 = fluid.data(
name="gsrm_slf_attn_bias2",
shape=[
-1, self.num_heads, self.max_text_length,
self.max_text_length
],
dtype="float32")
feed_list = [
image, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2
]
labels = {
'encoder_word_pos': encoder_word_pos,
'gsrm_word_pos': gsrm_word_pos,
'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1,
'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2
}
return image, labels, loader
def __call__(self, mode):
......@@ -117,13 +199,27 @@ class RecModel(object):
label = labels['label_out']
else:
label = labels['label']
outputs = {'total_loss':loss, 'decoded_out':\
decoded_out, 'label':label}
if self.loss_type == 'srn':
total_loss, img_loss, word_loss = self.loss(predicts, labels)
outputs = {
'total_loss': total_loss,
'img_loss': img_loss,
'word_loss': word_loss,
'decoded_out': decoded_out,
'label': label
}
else:
outputs = {'total_loss':loss, 'decoded_out':\
decoded_out, 'label':label}
return loader, outputs
elif mode == "export":
predict = predicts['predict']
if self.loss_type == "ctc":
predict = fluid.layers.softmax(predict)
if self.loss_type == "srn":
raise Exception(
"Warning! SRN does not support export model currently")
return [image, {'decoded_out': decoded_out, 'predicts': predict}]
else:
predict = predicts['predict']
......
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
__all__ = ["ResNet", "ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152"]
Trainable = True
w_nolr = fluid.ParamAttr(
trainable = Trainable)
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [30, 60, 90],
"steps": [0.1, 0.01, 0.001, 0.0001]
}
}
class ResNet():
def __init__(self, params):
self.layers = params['layers']
self.params = train_parameters
def __call__(self, input):
layers = self.layers
supported_layers = [18, 34, 50, 101, 152]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(supported_layers, layers)
if layers == 18:
depth = [2, 2, 2, 2]
elif layers == 34 or layers == 50:
depth = [3, 4, 6, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
stride_list = [(2,2),(2,2),(1,1),(1,1)]
num_filters = [64, 128, 256, 512]
conv = self.conv_bn_layer(
input=input, num_filters=64, filter_size=7, stride=2, act='relu', name="conv1")
F = []
if layers >= 50:
for block in range(len(depth)):
for i in range(depth[block]):
if layers in [101, 152] and block == 2:
if i == 0:
conv_name = "res" + str(block + 2) + "a"
else:
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
conv = self.bottleneck_block(
input=conv,
num_filters=num_filters[block],
stride=stride_list[block] if i == 0 else 1, name=conv_name)
F.append(conv)
base = F[-1]
for i in [-2, -3]:
b, c, w, h = F[i].shape
if (w,h) == base.shape[2:]:
base = base
else:
base = fluid.layers.conv2d_transpose( input=base, num_filters=c,filter_size=4, stride=2,
padding=1,act=None,
param_attr=w_nolr,
bias_attr=w_nolr)
base = fluid.layers.batch_norm(base, act = "relu", param_attr=w_nolr, bias_attr=w_nolr)
base = fluid.layers.concat([base, F[i]], axis=1)
base = fluid.layers.conv2d(base, num_filters=c, filter_size=1, param_attr=w_nolr, bias_attr=w_nolr)
base = fluid.layers.conv2d(base, num_filters=c, filter_size=3,padding = 1, param_attr=w_nolr, bias_attr=w_nolr)
base = fluid.layers.batch_norm(base, act = "relu", param_attr=w_nolr, bias_attr=w_nolr)
base = fluid.layers.conv2d(base, num_filters=512, filter_size=1,bias_attr=w_nolr,param_attr=w_nolr)
return base
def conv_bn_layer(self,
input,
num_filters,
filter_size,
stride=1,
groups=1,
act=None,
name=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size= 2 if stride==(1,1) else filter_size,
dilation = 2 if stride==(1,1) else 1,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights",trainable = Trainable),
bias_attr=False,
name=name + '.conv2d.output.1')
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
return fluid.layers.batch_norm(input=conv,
act=act,
name=bn_name + '.output.1',
param_attr=ParamAttr(name=bn_name + '_scale',trainable = Trainable),
bias_attr=ParamAttr(bn_name + '_offset',trainable = Trainable),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance', )
def shortcut(self, input, ch_out, stride, is_first, name):
ch_in = input.shape[1]
if ch_in != ch_out or stride != 1 or is_first == True:
if stride == (1,1):
return self.conv_bn_layer(input, ch_out, 1, 1, name=name)
else: #stride == (2,2)
return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
else:
return input
def bottleneck_block(self, input, num_filters, stride, name):
conv0 = self.conv_bn_layer(
input=input, num_filters=num_filters, filter_size=1, act='relu', name=name + "_branch2a")
conv1 = self.conv_bn_layer(
input=conv0,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu',
name=name + "_branch2b")
conv2 = self.conv_bn_layer(
input=conv1, num_filters=num_filters * 4, filter_size=1, act=None, name=name + "_branch2c")
short = self.shortcut(input, num_filters * 4, stride, is_first=False, name=name + "_branch1")
return fluid.layers.elementwise_add(x=short, y=conv2, act='relu', name=name + ".add.output.5")
def basic_block(self, input, num_filters, stride, is_first, name):
conv0 = self.conv_bn_layer(input=input, num_filters=num_filters, filter_size=3, act='relu', stride=stride,
name=name + "_branch2a")
conv1 = self.conv_bn_layer(input=conv0, num_filters=num_filters, filter_size=3, act=None,
name=name + "_branch2b")
short = self.shortcut(input, num_filters, stride, is_first, name=name + "_branch1")
return fluid.layers.elementwise_add(x=short, y=conv1, act='relu')
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
import numpy as np
from .self_attention.model import wrap_encoder
from .self_attention.model import wrap_encoder_forFeature
gradient_clip = 10
class SRNPredict(object):
def __init__(self, params):
super(SRNPredict, self).__init__()
self.char_num = params['char_num']
self.max_length = params['max_text_length']
self.num_heads = params['num_heads']
self.num_encoder_TUs = params['num_encoder_TUs']
self.num_decoder_TUs = params['num_decoder_TUs']
self.hidden_dims = params['hidden_dims']
def pvam(self, inputs, others):
b, c, h, w = inputs.shape
conv_features = fluid.layers.reshape(x=inputs, shape=[-1, c, h * w])
conv_features = fluid.layers.transpose(x=conv_features, perm=[0, 2, 1])
#===== Transformer encoder =====
b, t, c = conv_features.shape
encoder_word_pos = others["encoder_word_pos"]
gsrm_word_pos = others["gsrm_word_pos"]
enc_inputs = [conv_features, encoder_word_pos, None]
word_features = wrap_encoder_forFeature(
src_vocab_size=-1,
max_length=t,
n_layer=self.num_encoder_TUs,
n_head=self.num_heads,
d_key=int(self.hidden_dims / self.num_heads),
d_value=int(self.hidden_dims / self.num_heads),
d_model=self.hidden_dims,
d_inner_hid=self.hidden_dims,
prepostprocess_dropout=0.1,
attention_dropout=0.1,
relu_dropout=0.1,
preprocess_cmd="n",
postprocess_cmd="da",
weight_sharing=True,
enc_inputs=enc_inputs, )
fluid.clip.set_gradient_clip(
fluid.clip.GradientClipByValue(gradient_clip))
#===== Parallel Visual Attention Module =====
b, t, c = word_features.shape
word_features = fluid.layers.fc(word_features, c, num_flatten_dims=2)
word_features_ = fluid.layers.reshape(word_features, [-1, 1, t, c])
word_features_ = fluid.layers.expand(word_features_,
[1, self.max_length, 1, 1])
word_pos_feature = fluid.layers.embedding(gsrm_word_pos,
[self.max_length, c])
word_pos_ = fluid.layers.reshape(word_pos_feature,
[-1, self.max_length, 1, c])
word_pos_ = fluid.layers.expand(word_pos_, [1, 1, t, 1])
temp = fluid.layers.elementwise_add(
word_features_, word_pos_, act='tanh')
attention_weight = fluid.layers.fc(input=temp,
size=1,
num_flatten_dims=3,
bias_attr=False)
attention_weight = fluid.layers.reshape(
x=attention_weight, shape=[-1, self.max_length, t])
attention_weight = fluid.layers.softmax(input=attention_weight, axis=-1)
pvam_features = fluid.layers.matmul(attention_weight,
word_features) #[b, max_length, c]
return pvam_features
def gsrm(self, pvam_features, others):
#===== GSRM Visual-to-semantic embedding block =====
b, t, c = pvam_features.shape
word_out = fluid.layers.fc(
input=fluid.layers.reshape(pvam_features, [-1, c]),
size=self.char_num,
act="softmax")
#word_out.stop_gradient = True
word_ids = fluid.layers.argmax(word_out, axis=1)
word_ids.stop_gradient = True
word_ids = fluid.layers.reshape(x=word_ids, shape=[-1, t, 1])
#===== GSRM Semantic reasoning block =====
"""
This module is achieved through bi-transformers,
ngram_feature1 is the froward one, ngram_fetaure2 is the backward one
"""
pad_idx = self.char_num
gsrm_word_pos = others["gsrm_word_pos"]
gsrm_slf_attn_bias1 = others["gsrm_slf_attn_bias1"]
gsrm_slf_attn_bias2 = others["gsrm_slf_attn_bias2"]
def prepare_bi(word_ids):
"""
prepare bi for gsrm
word1 for forward; word2 for backward
"""
word1 = fluid.layers.cast(word_ids, "float32")
word1 = fluid.layers.pad(word1, [0, 0, 1, 0, 0, 0],
pad_value=1.0 * pad_idx)
word1 = fluid.layers.cast(word1, "int64")
word1 = word1[:, :-1, :]
word2 = word_ids
return word1, word2
word1, word2 = prepare_bi(word_ids)
word1.stop_gradient = True
word2.stop_gradient = True
enc_inputs_1 = [word1, gsrm_word_pos, gsrm_slf_attn_bias1]
enc_inputs_2 = [word2, gsrm_word_pos, gsrm_slf_attn_bias2]
gsrm_feature1 = wrap_encoder(
src_vocab_size=self.char_num + 1,
max_length=self.max_length,
n_layer=self.num_decoder_TUs,
n_head=self.num_heads,
d_key=int(self.hidden_dims / self.num_heads),
d_value=int(self.hidden_dims / self.num_heads),
d_model=self.hidden_dims,
d_inner_hid=self.hidden_dims,
prepostprocess_dropout=0.1,
attention_dropout=0.1,
relu_dropout=0.1,
preprocess_cmd="n",
postprocess_cmd="da",
weight_sharing=True,
enc_inputs=enc_inputs_1, )
gsrm_feature2 = wrap_encoder(
src_vocab_size=self.char_num + 1,
max_length=self.max_length,
n_layer=self.num_decoder_TUs,
n_head=self.num_heads,
d_key=int(self.hidden_dims / self.num_heads),
d_value=int(self.hidden_dims / self.num_heads),
d_model=self.hidden_dims,
d_inner_hid=self.hidden_dims,
prepostprocess_dropout=0.1,
attention_dropout=0.1,
relu_dropout=0.1,
preprocess_cmd="n",
postprocess_cmd="da",
weight_sharing=True,
enc_inputs=enc_inputs_2, )
gsrm_feature2 = fluid.layers.pad(gsrm_feature2, [0, 0, 0, 1, 0, 0],
pad_value=0.)
gsrm_feature2 = gsrm_feature2[:, 1:, ]
gsrm_features = gsrm_feature1 + gsrm_feature2
b, t, c = gsrm_features.shape
gsrm_out = fluid.layers.matmul(
x=gsrm_features,
y=fluid.default_main_program().global_block().var(
"src_word_emb_table"),
transpose_y=True)
b, t, c = gsrm_out.shape
gsrm_out = fluid.layers.softmax(input=fluid.layers.reshape(gsrm_out,
[-1, c]))
return gsrm_features, word_out, gsrm_out
def vsfd(self, pvam_features, gsrm_features):
#===== Visual-Semantic Fusion Decoder Module =====
b, t, c1 = pvam_features.shape
b, t, c2 = gsrm_features.shape
combine_features_ = fluid.layers.concat(
[pvam_features, gsrm_features], axis=2)
img_comb_features_ = fluid.layers.reshape(
x=combine_features_, shape=[-1, c1 + c2])
img_comb_features_map = fluid.layers.fc(input=img_comb_features_,
size=c1,
act="sigmoid")
img_comb_features_map = fluid.layers.reshape(
x=img_comb_features_map, shape=[-1, t, c1])
combine_features = img_comb_features_map * pvam_features + (
1.0 - img_comb_features_map) * gsrm_features
img_comb_features = fluid.layers.reshape(
x=combine_features, shape=[-1, c1])
fc_out = fluid.layers.fc(input=img_comb_features,
size=self.char_num,
act="softmax")
return fc_out
def __call__(self, inputs, others, mode=None):
pvam_features = self.pvam(inputs, others)
gsrm_features, word_out, gsrm_out = self.gsrm(pvam_features, others)
final_out = self.vsfd(pvam_features, gsrm_features)
_, decoded_out = fluid.layers.topk(input=final_out, k=1)
predicts = {
'predict': final_out,
'decoded_out': decoded_out,
'word_out': word_out,
'gsrm_out': gsrm_out
}
return predicts
此差异已折叠。
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
import paddle.fluid as fluid
class SRNLoss(object):
def __init__(self, params):
super(SRNLoss, self).__init__()
self.char_num = params['char_num']
def __call__(self, predicts, others):
predict = predicts['predict']
word_predict = predicts['word_out']
gsrm_predict = predicts['gsrm_out']
label = others['label']
lbl_weight = others['lbl_weight']
casted_label = fluid.layers.cast(x=label, dtype='int64')
cost_word = fluid.layers.cross_entropy(
input=word_predict, label=casted_label)
cost_gsrm = fluid.layers.cross_entropy(
input=gsrm_predict, label=casted_label)
cost_vsfd = fluid.layers.cross_entropy(
input=predict, label=casted_label)
cost_word = fluid.layers.reshape(
x=fluid.layers.reduce_sum(cost_word), shape=[1])
cost_gsrm = fluid.layers.reshape(
x=fluid.layers.reduce_sum(cost_gsrm), shape=[1])
cost_vsfd = fluid.layers.reshape(
x=fluid.layers.reduce_sum(cost_vsfd), shape=[1])
sum_cost = fluid.layers.sum(
[cost_word, cost_vsfd * 2.0, cost_gsrm * 0.15])
return [sum_cost, cost_vsfd, cost_word]
......@@ -25,6 +25,9 @@ class CharacterOps(object):
def __init__(self, config):
self.character_type = config['character_type']
self.loss_type = config['loss_type']
self.max_text_len = config['max_text_length']
if self.loss_type == "srn" and self.character_type != "en":
raise Exception("SRN can only support in character_type == en")
if self.character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
......@@ -54,6 +57,8 @@ class CharacterOps(object):
self.end_str = "eos"
if self.loss_type == "attention":
dict_character = [self.beg_str, self.end_str] + dict_character
elif self.loss_type == "srn":
dict_character = dict_character + [self.beg_str, self.end_str]
self.dict = {}
for i, char in enumerate(dict_character):
self.dict[char] = i
......@@ -147,6 +152,39 @@ def cal_predicts_accuracy(char_ops,
return acc, acc_num, img_num
def cal_predicts_accuracy_srn(char_ops,
preds,
labels,
max_text_len,
is_debug=False):
acc_num = 0
img_num = 0
total_len = preds.shape[0]
img_num = int(total_len / max_text_len)
for i in range(img_num):
cur_label = []
cur_pred = []
for j in range(max_text_len):
if labels[j + i * max_text_len] != 37: #0
cur_label.append(labels[j + i * max_text_len][0])
else:
break
for j in range(max_text_len + 1):
if j < len(cur_label) and preds[j + i * max_text_len][
0] != cur_label[j]:
break
elif j == len(cur_label) and j == max_text_len:
acc_num += 1
break
elif j == len(cur_label) and preds[j + i * max_text_len][0] == 37:
acc_num += 1
break
acc = acc_num * 1.0 / img_num
return acc, acc_num, img_num
def convert_rec_attention_infer_res(preds):
img_num = preds.shape[0]
target_lod = [0]
......
......@@ -29,7 +29,7 @@ FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
from ppocr.utils.character import cal_predicts_accuracy
from ppocr.utils.character import cal_predicts_accuracy, cal_predicts_accuracy_srn
from ppocr.utils.character import convert_rec_label_to_lod
from ppocr.utils.character import convert_rec_attention_infer_res
from ppocr.utils.utility import create_module
......@@ -60,22 +60,60 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
for ino in range(img_num):
img_list.append(data[ino][0])
label_list.append(data[ino][1])
img_list = np.concatenate(img_list, axis=0)
outs = exe.run(eval_info_dict['program'], \
if config['Global']['loss_type'] != "srn":
img_list = np.concatenate(img_list, axis=0)
outs = exe.run(eval_info_dict['program'], \
feed={'image': img_list}, \
fetch_list=eval_info_dict['fetch_varname_list'], \
return_numpy=False)
preds = np.array(outs[0])
if preds.shape[1] != 1:
preds, preds_lod = convert_rec_attention_infer_res(preds)
preds = np.array(outs[0])
if config['Global']['loss_type'] == "attention":
preds, preds_lod = convert_rec_attention_infer_res(preds)
else:
preds_lod = outs[0].lod()[0]
labels, labels_lod = convert_rec_label_to_lod(label_list)
acc, acc_num, sample_num = cal_predicts_accuracy(
char_ops, preds, preds_lod, labels, labels_lod,
is_remove_duplicate)
else:
preds_lod = outs[0].lod()[0]
labels, labels_lod = convert_rec_label_to_lod(label_list)
acc, acc_num, sample_num = cal_predicts_accuracy(
char_ops, preds, preds_lod, labels, labels_lod, is_remove_duplicate)
encoder_word_pos_list = []
gsrm_word_pos_list = []
gsrm_slf_attn_bias1_list = []
gsrm_slf_attn_bias2_list = []
for ino in range(img_num):
encoder_word_pos_list.append(data[ino][2])
gsrm_word_pos_list.append(data[ino][3])
gsrm_slf_attn_bias1_list.append(data[ino][4])
gsrm_slf_attn_bias2_list.append(data[ino][5])
img_list = np.concatenate(img_list, axis=0)
label_list = np.concatenate(label_list, axis=0)
encoder_word_pos_list = np.concatenate(
encoder_word_pos_list, axis=0).astype(np.int64)
gsrm_word_pos_list = np.concatenate(
gsrm_word_pos_list, axis=0).astype(np.int64)
gsrm_slf_attn_bias1_list = np.concatenate(
gsrm_slf_attn_bias1_list, axis=0).astype(np.float32)
gsrm_slf_attn_bias2_list = np.concatenate(
gsrm_slf_attn_bias2_list, axis=0).astype(np.float32)
labels = label_list
outs = exe.run(eval_info_dict['program'], \
feed={'image': img_list, 'encoder_word_pos': encoder_word_pos_list,
'gsrm_word_pos': gsrm_word_pos_list, 'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1_list,
'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2_list}, \
fetch_list=eval_info_dict['fetch_varname_list'], \
return_numpy=False)
preds = np.array(outs[0])
acc, acc_num, sample_num = cal_predicts_accuracy_srn(
char_ops, preds, labels, config['Global']['max_text_length'])
total_acc_num += acc_num
total_sample_num += sample_num
logger.info("eval batch id: {}, acc: {}".format(total_batch_num, acc))
#logger.info("eval batch id: {}, acc: {}".format(total_batch_num, acc))
total_batch_num += 1
avg_acc = total_acc_num * 1.0 / total_sample_num
metrics = {'avg_acc': avg_acc, "total_acc_num": total_acc_num, \
......
......@@ -40,7 +40,8 @@ class TextRecognizer(object):
char_ops_params = {
"character_type": args.rec_char_type,
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
"use_space_char": args.use_space_char,
"max_text_length": args.max_text_length
}
if self.rec_algorithm != "RARE":
char_ops_params['loss_type'] = 'ctc'
......
......@@ -59,6 +59,7 @@ def parse_args():
parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
parser.add_argument("--rec_char_type", type=str, default='ch')
parser.add_argument("--rec_batch_num", type=int, default=30)
parser.add_argument("--max_text_length", type=int, default=25)
parser.add_argument(
"--rec_char_dict_path",
type=str,
......
......@@ -64,7 +64,6 @@ def main():
exe = fluid.Executor(place)
rec_model = create_module(config['Architecture']['function'])(params=config)
startup_prog = fluid.Program()
eval_prog = fluid.Program()
with fluid.program_guard(eval_prog, startup_prog):
......@@ -86,10 +85,36 @@ def main():
for i in range(max_img_num):
logger.info("infer_img:%s" % infer_list[i])
img = next(blobs)
predict = exe.run(program=eval_prog,
feed={"image": img},
fetch_list=fetch_varname_list,
return_numpy=False)
if loss_type != "srn":
predict = exe.run(program=eval_prog,
feed={"image": img},
fetch_list=fetch_varname_list,
return_numpy=False)
else:
encoder_word_pos_list = []
gsrm_word_pos_list = []
gsrm_slf_attn_bias1_list = []
gsrm_slf_attn_bias2_list = []
encoder_word_pos_list.append(img[1])
gsrm_word_pos_list.append(img[2])
gsrm_slf_attn_bias1_list.append(img[3])
gsrm_slf_attn_bias2_list.append(img[4])
encoder_word_pos_list = np.concatenate(
encoder_word_pos_list, axis=0).astype(np.int64)
gsrm_word_pos_list = np.concatenate(
gsrm_word_pos_list, axis=0).astype(np.int64)
gsrm_slf_attn_bias1_list = np.concatenate(
gsrm_slf_attn_bias1_list, axis=0).astype(np.float32)
gsrm_slf_attn_bias2_list = np.concatenate(
gsrm_slf_attn_bias2_list, axis=0).astype(np.float32)
predict = exe.run(program=eval_prog, \
feed={'image': img[0], 'encoder_word_pos': encoder_word_pos_list,
'gsrm_word_pos': gsrm_word_pos_list, 'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1_list,
'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2_list}, \
fetch_list=fetch_varname_list, \
return_numpy=False)
if loss_type == "ctc":
preds = np.array(predict[0])
preds = preds.reshape(-1)
......@@ -114,7 +139,18 @@ def main():
score = np.mean(probs[0, 1:end_pos[1]])
preds = preds.reshape(-1)
preds_text = char_ops.decode(preds)
elif loss_type == "srn":
cur_pred = []
preds = np.array(predict[0])
preds = preds.reshape(-1)
probs = np.array(predict[1])
ind = np.argmax(probs, axis=1)
valid_ind = np.where(preds != 37)[0]
if len(valid_ind) == 0:
continue
score = np.mean(probs[valid_ind, ind[valid_ind]])
preds = preds[:valid_ind[-1] + 1]
preds_text = char_ops.decode(preds)
logger.info("\t index: {}".format(preds))
logger.info("\t word : {}".format(preds_text))
logger.info("\t score: {}".format(score))
......
......@@ -32,7 +32,8 @@ from eval_utils.eval_det_utils import eval_det_run
from eval_utils.eval_rec_utils import eval_rec_run
from ppocr.utils.save_load import save_model
import numpy as np
from ppocr.utils.character import cal_predicts_accuracy, CharacterOps
from ppocr.utils.character import cal_predicts_accuracy, cal_predicts_accuracy_srn, CharacterOps
class ArgsParser(ArgumentParser):
def __init__(self):
......@@ -176,8 +177,16 @@ def build(config, main_prog, startup_prog, mode):
fetch_name_list = list(outputs.keys())
fetch_varname_list = [outputs[v].name for v in fetch_name_list]
opt_loss_name = None
model_average = None
img_loss_name = None
word_loss_name = None
if mode == "train":
opt_loss = outputs['total_loss']
# srn loss
#img_loss = outputs['img_loss']
#word_loss = outputs['word_loss']
#img_loss_name = img_loss.name
#word_loss_name = word_loss.name
opt_params = config['Optimizer']
optimizer = create_module(opt_params['function'])(opt_params)
optimizer.minimize(opt_loss)
......@@ -185,7 +194,17 @@ def build(config, main_prog, startup_prog, mode):
global_lr = optimizer._global_learning_rate()
fetch_name_list.insert(0, "lr")
fetch_varname_list.insert(0, global_lr.name)
return (dataloader, fetch_name_list, fetch_varname_list, opt_loss_name)
if "loss_type" in config["Global"]:
if config['Global']["loss_type"] == 'srn':
model_average = fluid.optimizer.ModelAverage(
config['Global']['average_window'],
min_average_window=config['Global'][
'min_average_window'],
max_average_window=config['Global'][
'max_average_window'])
return (dataloader, fetch_name_list, fetch_varname_list, opt_loss_name,
model_average)
def build_export(config, main_prog, startup_prog):
......@@ -329,14 +348,20 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
lr = np.mean(np.array(train_outs[fetch_map['lr']]))
preds_idx = fetch_map['decoded_out']
preds = np.array(train_outs[preds_idx])
preds_lod = train_outs[preds_idx].lod()[0]
labels_idx = fetch_map['label']
labels = np.array(train_outs[labels_idx])
labels_lod = train_outs[labels_idx].lod()[0]
acc, acc_num, img_num = cal_predicts_accuracy(
config['Global']['char_ops'], preds, preds_lod, labels,
labels_lod)
if config['Global']['loss_type'] != 'srn':
preds_lod = train_outs[preds_idx].lod()[0]
labels_lod = train_outs[labels_idx].lod()[0]
acc, acc_num, img_num = cal_predicts_accuracy(
config['Global']['char_ops'], preds, preds_lod, labels,
labels_lod)
else:
acc, acc_num, img_num = cal_predicts_accuracy_srn(
config['Global']['char_ops'], preds, labels,
config['Global']['max_text_length'])
t2 = time.time()
train_batch_elapse = t2 - t1
stats = {'loss': loss, 'acc': acc}
......@@ -350,6 +375,9 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
if train_batch_id > 0 and\
train_batch_id % eval_batch_step == 0:
model_average = train_info_dict['model_average']
if model_average != None:
model_average.apply(exe)
metrics = eval_rec_run(exe, config, eval_info_dict, "eval")
eval_acc = metrics['avg_acc']
eval_sample_num = metrics['total_sample_num']
......@@ -375,6 +403,7 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
save_model(train_info_dict['train_program'], save_path)
return
def preprocess():
FLAGS = ArgsParser().parse_args()
config = load_config(FLAGS.config)
......@@ -386,8 +415,8 @@ def preprocess():
check_gpu(use_gpu)
alg = config['Global']['algorithm']
assert alg in ['EAST', 'DB', 'Rosetta', 'CRNN', 'STARNet', 'RARE']
if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE']:
assert alg in ['EAST', 'DB', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']
if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']:
config['Global']['char_ops'] = CharacterOps(config['Global'])
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
......
......@@ -52,6 +52,7 @@ def main():
train_fetch_name_list = train_build_outputs[1]
train_fetch_varname_list = train_build_outputs[2]
train_opt_loss_name = train_build_outputs[3]
model_average = train_build_outputs[-1]
eval_program = fluid.Program()
eval_build_outputs = program.build(
......@@ -85,7 +86,8 @@ def main():
'train_program':train_program,\
'reader':train_loader,\
'fetch_name_list':train_fetch_name_list,\
'fetch_varname_list':train_fetch_varname_list}
'fetch_varname_list':train_fetch_varname_list,\
'model_average': model_average}
eval_info_dict = {'program':eval_program,\
'reader':eval_reader,\
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册