Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
612e8014
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
612e8014
编写于
8月 16, 2020
作者:
M
MissPenguin
提交者:
GitHub
8月 16, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #537 from tink2123/add_srn
Add SRN
上级
bad9f6cd
9c893102
变更
18
展开全部
显示空白变更内容
内联
并排
Showing
18 changed file
with
1962 addition
and
48 deletion
+1962
-48
README_cn.md
README_cn.md
+5
-1
configs/rec/rec_r50fpn_vd_none_srn.yml
configs/rec/rec_r50fpn_vd_none_srn.yml
+49
-0
doc/doc_ch/config.md
doc/doc_ch/config.md
+3
-0
ppocr/data/rec/dataset_traversal.py
ppocr/data/rec/dataset_traversal.py
+36
-15
ppocr/data/rec/img_tools.py
ppocr/data/rec/img_tools.py
+81
-0
ppocr/modeling/architectures/rec_model.py
ppocr/modeling/architectures/rec_model.py
+100
-4
ppocr/modeling/backbones/rec_resnet50_fpn.py
ppocr/modeling/backbones/rec_resnet50_fpn.py
+172
-0
ppocr/modeling/heads/rec_srn_all_head.py
ppocr/modeling/heads/rec_srn_all_head.py
+230
-0
ppocr/modeling/heads/self_attention/__init__.py
ppocr/modeling/heads/self_attention/__init__.py
+0
-0
ppocr/modeling/heads/self_attention/model.py
ppocr/modeling/heads/self_attention/model.py
+1058
-0
ppocr/modeling/losses/rec_srn_loss.py
ppocr/modeling/losses/rec_srn_loss.py
+55
-0
ppocr/utils/character.py
ppocr/utils/character.py
+38
-0
tools/eval_utils/eval_rec_utils.py
tools/eval_utils/eval_rec_utils.py
+49
-11
tools/infer/predict_rec.py
tools/infer/predict_rec.py
+2
-1
tools/infer/utility.py
tools/infer/utility.py
+1
-0
tools/infer_rec.py
tools/infer_rec.py
+42
-6
tools/program.py
tools/program.py
+38
-9
tools/train.py
tools/train.py
+3
-1
未找到文件。
README_cn.md
浏览文件 @
612e8014
...
...
@@ -122,7 +122,10 @@ PaddleOCR开源的文本识别算法列表:
-
[
x] Rosetta([paper
](
https://arxiv.org/abs/1910.05085
)
)
-
[
x] STAR-Net([paper
](
http://www.bmva.org/bmvc/2016/papers/paper043/index.html
)
)
-
[
x] RARE([paper
](
https://arxiv.org/abs/1603.03915v1
)
)
-
[
] SRN([paper
](
https://arxiv.org/abs/2003.12294
)
)(百度自研, coming soon)
-
[
x] SRN([paper
](
https://arxiv.org/abs/2003.12294
)
)(百度自研)
*备注:*
SRN模型使用了数据扰动方法对上述提到对两个训练集进行增广,增广后的数据可以在
[
百度网盘
](
todo
)
上下载。
原始论文使用两阶段训练平均精度为89.74%,PaddleOCR中使用one-stage训练,平均精度为88.33%。两种预训练权重均在
[
下载链接
](
https://paddleocr.bj.bcebos.com/SRN/rec_r50fpn_vd_none_srn.tar
)
中。
参考
[
DTRB
](
https://arxiv.org/abs/1904.01906
)
文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
...
...
@@ -136,6 +139,7 @@ PaddleOCR开源的文本识别算法列表:
|STAR-Net|MobileNetV3|81.56%|rec_mv3_tps_bilstm_ctc|
[
下载链接
](
https://paddleocr.bj.bcebos.com/rec_mv3_tps_bilstm_ctc.tar
)
|
|RARE|Resnet34_vd|84.90%|rec_r34_vd_tps_bilstm_attn|
[
下载链接
](
https://paddleocr.bj.bcebos.com/rec_r34_vd_tps_bilstm_attn.tar
)
|
|RARE|MobileNetV3|83.32%|rec_mv3_tps_bilstm_attn|
[
下载链接
](
https://paddleocr.bj.bcebos.com/rec_mv3_tps_bilstm_attn.tar
)
|
|SRN|Resnet50_vd_fpn|88.33%|rec_r50fpn_vd_none_srn|
[
下载链接
](
https://paddleocr.bj.bcebos.com/SRN/rec_r50fpn_vd_none_srn.tar
)
|
使用
[
LSVT
](
https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_ch/datasets.md#1icdar2019-lsvt
)
街景数据集根据真值将图crop出来30w数据,进行位置校准。此外基于LSVT语料生成500w合成数据训练中文模型,相关配置和预训练文件如下:
...
...
configs/rec/rec_r50fpn_vd_none_srn.yml
0 → 100755
浏览文件 @
612e8014
Global
:
algorithm
:
SRN
use_gpu
:
true
epoch_num
:
72
log_smooth_window
:
20
print_batch_step
:
10
save_model_dir
:
output/rec_pvam_withrotate
save_epoch_step
:
1
eval_batch_step
:
8000
train_batch_size_per_card
:
64
test_batch_size_per_card
:
1
image_shape
:
[
1
,
64
,
256
]
max_text_length
:
25
character_type
:
en
loss_type
:
srn
num_heads
:
8
average_window
:
0.15
max_average_window
:
15625
min_average_window
:
10000
reader_yml
:
./configs/rec/rec_benchmark_reader.yml
pretrain_weights
:
checkpoints
:
save_inference_dir
:
infer_img
:
Architecture
:
function
:
ppocr.modeling.architectures.rec_model,RecModel
Backbone
:
function
:
ppocr.modeling.backbones.rec_resnet50_fpn,ResNet
layers
:
50
Head
:
function
:
ppocr.modeling.heads.rec_srn_all_head,SRNPredict
encoder_type
:
rnn
num_encoder_TUs
:
2
num_decoder_TUs
:
4
hidden_dims
:
512
SeqRNN
:
hidden_size
:
256
Loss
:
function
:
ppocr.modeling.losses.rec_srn_loss,SRNLoss
Optimizer
:
function
:
ppocr.optimizer,AdamDecay
base_lr
:
0.0001
beta1
:
0.9
beta2
:
0.999
doc/doc_ch/config.md
浏览文件 @
612e8014
...
...
@@ -32,6 +32,9 @@
| loss_type | 设置 loss 类型 | ctc | 支持两种loss: ctc / attention |
| distort | 设置是否使用数据增强 | false | 设置为true时,将在训练时随机进行扰动,支持的扰动操作可阅读
[
img_tools.py
](
https://github.com/PaddlePaddle/PaddleOCR/blob/develop/ppocr/data/rec/img_tools.py
)
|
| use_space_char | 设置是否识别空格 | false | 仅在 character_type=ch 时支持空格 |
| average_window | ModelAverage优化器中的窗口长度计算比例 | 0.15 | 目前仅应用与SRN |
| max_average_window | 平均值计算窗口长度的最大值 | 15625 | 推荐设置为一轮训练中mini-batchs的数目|
| min_average_window | 平均值计算窗口长度的最小值 | 10000 |
\
|
| reader_yml | 设置reader配置文件 | ./configs/rec/rec_icdar15_reader.yml |
\
|
| pretrain_weights | 加载预训练模型路径 | ./pretrain_models/CRNN/best_accuracy |
\
|
| checkpoints | 加载模型参数路径 | None | 用于中断后加载参数继续训练 |
...
...
ppocr/data/rec/dataset_traversal.py
浏览文件 @
612e8014
...
...
@@ -26,7 +26,7 @@ from ppocr.utils.utility import initial_logger
from
ppocr.utils.utility
import
get_image_file_list
logger
=
initial_logger
()
from
.img_tools
import
process_image
,
get_img_data
from
.img_tools
import
process_image
,
process_image_srn
,
get_img_data
class
LMDBReader
(
object
):
...
...
@@ -43,6 +43,9 @@ class LMDBReader(object):
self
.
mode
=
params
[
'mode'
]
self
.
drop_last
=
False
self
.
use_tps
=
False
self
.
num_heads
=
None
if
"num_heads"
in
params
:
self
.
num_heads
=
params
[
'num_heads'
]
if
"tps"
in
params
:
self
.
ues_tps
=
True
self
.
use_distort
=
False
...
...
@@ -119,6 +122,13 @@ class LMDBReader(object):
img
=
cv2
.
imread
(
single_img
)
if
img
.
shape
[
-
1
]
==
1
or
len
(
list
(
img
.
shape
))
==
2
:
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_GRAY2BGR
)
if
self
.
loss_type
==
'srn'
:
norm_img
=
process_image_srn
(
img
=
img
,
image_shape
=
self
.
image_shape
,
num_heads
=
self
.
num_heads
,
max_text_length
=
self
.
max_text_length
)
else
:
norm_img
=
process_image
(
img
=
img
,
image_shape
=
self
.
image_shape
,
...
...
@@ -144,14 +154,25 @@ class LMDBReader(object):
if
sample_info
is
None
:
continue
img
,
label
=
sample_info
outs
=
[]
if
self
.
loss_type
==
"srn"
:
outs
=
process_image_srn
(
img
=
img
,
image_shape
=
self
.
image_shape
,
num_heads
=
self
.
num_heads
,
max_text_length
=
self
.
max_text_length
,
label
=
label
,
char_ops
=
self
.
char_ops
,
loss_type
=
self
.
loss_type
)
else
:
outs
=
process_image
(
img
=
img
,
image_shape
=
self
.
image_shape
,
label
=
label
,
char_ops
=
self
.
char_ops
,
loss_type
=
self
.
loss_type
,
max_text_length
=
self
.
max_text_length
,
distort
=
self
.
use_distort
)
max_text_length
=
self
.
max_text_length
)
if
outs
is
None
:
continue
yield
outs
...
...
ppocr/data/rec/img_tools.py
浏览文件 @
612e8014
...
...
@@ -381,3 +381,84 @@ def process_image(img,
assert
False
,
"Unsupport loss_type %s in process_image"
\
%
loss_type
return
(
norm_img
)
def
resize_norm_img_srn
(
img
,
image_shape
):
imgC
,
imgH
,
imgW
=
image_shape
img_black
=
np
.
zeros
((
imgH
,
imgW
))
im_hei
=
img
.
shape
[
0
]
im_wid
=
img
.
shape
[
1
]
if
im_wid
<=
im_hei
*
1
:
img_new
=
cv2
.
resize
(
img
,
(
imgH
*
1
,
imgH
))
elif
im_wid
<=
im_hei
*
2
:
img_new
=
cv2
.
resize
(
img
,
(
imgH
*
2
,
imgH
))
elif
im_wid
<=
im_hei
*
3
:
img_new
=
cv2
.
resize
(
img
,
(
imgH
*
3
,
imgH
))
else
:
img_new
=
cv2
.
resize
(
img
,
(
imgW
,
imgH
))
img_np
=
np
.
asarray
(
img_new
)
img_np
=
cv2
.
cvtColor
(
img_np
,
cv2
.
COLOR_BGR2GRAY
)
img_black
[:,
0
:
img_np
.
shape
[
1
]]
=
img_np
img_black
=
img_black
[:,
:,
np
.
newaxis
]
row
,
col
,
c
=
img_black
.
shape
c
=
1
return
np
.
reshape
(
img_black
,
(
c
,
row
,
col
)).
astype
(
np
.
float32
)
def
srn_other_inputs
(
image_shape
,
num_heads
,
max_text_length
):
imgC
,
imgH
,
imgW
=
image_shape
feature_dim
=
int
((
imgH
/
8
)
*
(
imgW
/
8
))
encoder_word_pos
=
np
.
array
(
range
(
0
,
feature_dim
)).
reshape
((
feature_dim
,
1
)).
astype
(
'int64'
)
gsrm_word_pos
=
np
.
array
(
range
(
0
,
max_text_length
)).
reshape
((
max_text_length
,
1
)).
astype
(
'int64'
)
lbl_weight
=
np
.
array
([
37
]
*
max_text_length
).
reshape
((
-
1
,
1
)).
astype
(
'int64'
)
gsrm_attn_bias_data
=
np
.
ones
((
1
,
max_text_length
,
max_text_length
))
gsrm_slf_attn_bias1
=
np
.
triu
(
gsrm_attn_bias_data
,
1
).
reshape
([
-
1
,
1
,
max_text_length
,
max_text_length
])
gsrm_slf_attn_bias1
=
np
.
tile
(
gsrm_slf_attn_bias1
,
[
1
,
num_heads
,
1
,
1
])
*
[
-
1e9
]
gsrm_slf_attn_bias2
=
np
.
tril
(
gsrm_attn_bias_data
,
-
1
).
reshape
([
-
1
,
1
,
max_text_length
,
max_text_length
])
gsrm_slf_attn_bias2
=
np
.
tile
(
gsrm_slf_attn_bias2
,
[
1
,
num_heads
,
1
,
1
])
*
[
-
1e9
]
encoder_word_pos
=
encoder_word_pos
[
np
.
newaxis
,
:]
gsrm_word_pos
=
gsrm_word_pos
[
np
.
newaxis
,
:]
return
[
lbl_weight
,
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
]
def
process_image_srn
(
img
,
image_shape
,
num_heads
,
max_text_length
,
label
=
None
,
char_ops
=
None
,
loss_type
=
None
):
norm_img
=
resize_norm_img_srn
(
img
,
image_shape
)
norm_img
=
norm_img
[
np
.
newaxis
,
:]
[
lbl_weight
,
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
]
=
\
srn_other_inputs
(
image_shape
,
num_heads
,
max_text_length
)
if
label
is
not
None
:
char_num
=
char_ops
.
get_char_num
()
text
=
char_ops
.
encode
(
label
)
if
len
(
text
)
==
0
or
len
(
text
)
>
max_text_length
:
return
None
else
:
if
loss_type
==
"srn"
:
text_padded
=
[
37
]
*
max_text_length
for
i
in
range
(
len
(
text
)):
text_padded
[
i
]
=
text
[
i
]
lbl_weight
[
i
]
=
[
1.0
]
text_padded
=
np
.
array
(
text_padded
)
text
=
text_padded
.
reshape
(
-
1
,
1
)
return
(
norm_img
,
text
,
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
,
lbl_weight
)
else
:
assert
False
,
"Unsupport loss_type %s in process_image"
\
%
loss_type
return
(
norm_img
,
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
)
ppocr/modeling/architectures/rec_model.py
浏览文件 @
612e8014
...
...
@@ -58,6 +58,10 @@ class RecModel(object):
self
.
loss_type
=
global_params
[
'loss_type'
]
self
.
image_shape
=
global_params
[
'image_shape'
]
self
.
max_text_length
=
global_params
[
'max_text_length'
]
if
"num_heads"
in
params
:
self
.
num_heads
=
global_params
[
"num_heads"
]
else
:
self
.
num_heads
=
None
def
create_feed
(
self
,
mode
):
image_shape
=
deepcopy
(
self
.
image_shape
)
...
...
@@ -77,6 +81,48 @@ class RecModel(object):
lod_level
=
1
)
feed_list
=
[
image
,
label_in
,
label_out
]
labels
=
{
'label_in'
:
label_in
,
'label_out'
:
label_out
}
elif
self
.
loss_type
==
"srn"
:
encoder_word_pos
=
fluid
.
data
(
name
=
"encoder_word_pos"
,
shape
=
[
-
1
,
int
((
image_shape
[
-
2
]
/
8
)
*
(
image_shape
[
-
1
]
/
8
)),
1
],
dtype
=
"int64"
)
gsrm_word_pos
=
fluid
.
data
(
name
=
"gsrm_word_pos"
,
shape
=
[
-
1
,
self
.
max_text_length
,
1
],
dtype
=
"int64"
)
gsrm_slf_attn_bias1
=
fluid
.
data
(
name
=
"gsrm_slf_attn_bias1"
,
shape
=
[
-
1
,
self
.
num_heads
,
self
.
max_text_length
,
self
.
max_text_length
],
dtype
=
"float32"
)
gsrm_slf_attn_bias2
=
fluid
.
data
(
name
=
"gsrm_slf_attn_bias2"
,
shape
=
[
-
1
,
self
.
num_heads
,
self
.
max_text_length
,
self
.
max_text_length
],
dtype
=
"float32"
)
lbl_weight
=
fluid
.
layers
.
data
(
name
=
"lbl_weight"
,
shape
=
[
-
1
,
1
],
dtype
=
'int64'
)
label
=
fluid
.
data
(
name
=
'label'
,
shape
=
[
-
1
,
1
],
dtype
=
'int32'
,
lod_level
=
1
)
feed_list
=
[
image
,
label
,
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
,
lbl_weight
]
labels
=
{
'label'
:
label
,
'encoder_word_pos'
:
encoder_word_pos
,
'gsrm_word_pos'
:
gsrm_word_pos
,
'gsrm_slf_attn_bias1'
:
gsrm_slf_attn_bias1
,
'gsrm_slf_attn_bias2'
:
gsrm_slf_attn_bias2
,
'lbl_weight'
:
lbl_weight
}
else
:
label
=
fluid
.
data
(
name
=
'label'
,
shape
=
[
None
,
1
],
dtype
=
'int32'
,
lod_level
=
1
)
...
...
@@ -88,6 +134,8 @@ class RecModel(object):
use_double_buffer
=
True
,
iterable
=
False
)
else
:
labels
=
None
loader
=
None
if
self
.
char_type
==
"ch"
and
self
.
infer_img
:
image_shape
[
-
1
]
=
-
1
if
self
.
tps
!=
None
:
...
...
@@ -98,8 +146,42 @@ class RecModel(object):
)
image_shape
=
deepcopy
(
self
.
image_shape
)
image
=
fluid
.
data
(
name
=
'image'
,
shape
=
image_shape
,
dtype
=
'float32'
)
labels
=
None
loader
=
None
if
self
.
loss_type
==
"srn"
:
encoder_word_pos
=
fluid
.
data
(
name
=
"encoder_word_pos"
,
shape
=
[
-
1
,
int
((
image_shape
[
-
2
]
/
8
)
*
(
image_shape
[
-
1
]
/
8
)),
1
],
dtype
=
"int64"
)
gsrm_word_pos
=
fluid
.
data
(
name
=
"gsrm_word_pos"
,
shape
=
[
-
1
,
self
.
max_text_length
,
1
],
dtype
=
"int64"
)
gsrm_slf_attn_bias1
=
fluid
.
data
(
name
=
"gsrm_slf_attn_bias1"
,
shape
=
[
-
1
,
self
.
num_heads
,
self
.
max_text_length
,
self
.
max_text_length
],
dtype
=
"float32"
)
gsrm_slf_attn_bias2
=
fluid
.
data
(
name
=
"gsrm_slf_attn_bias2"
,
shape
=
[
-
1
,
self
.
num_heads
,
self
.
max_text_length
,
self
.
max_text_length
],
dtype
=
"float32"
)
feed_list
=
[
image
,
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
]
labels
=
{
'encoder_word_pos'
:
encoder_word_pos
,
'gsrm_word_pos'
:
gsrm_word_pos
,
'gsrm_slf_attn_bias1'
:
gsrm_slf_attn_bias1
,
'gsrm_slf_attn_bias2'
:
gsrm_slf_attn_bias2
}
return
image
,
labels
,
loader
def
__call__
(
self
,
mode
):
...
...
@@ -117,13 +199,27 @@ class RecModel(object):
label
=
labels
[
'label_out'
]
else
:
label
=
labels
[
'label'
]
if
self
.
loss_type
==
'srn'
:
total_loss
,
img_loss
,
word_loss
=
self
.
loss
(
predicts
,
labels
)
outputs
=
{
'total_loss'
:
total_loss
,
'img_loss'
:
img_loss
,
'word_loss'
:
word_loss
,
'decoded_out'
:
decoded_out
,
'label'
:
label
}
else
:
outputs
=
{
'total_loss'
:
loss
,
'decoded_out'
:
\
decoded_out
,
'label'
:
label
}
return
loader
,
outputs
elif
mode
==
"export"
:
predict
=
predicts
[
'predict'
]
if
self
.
loss_type
==
"ctc"
:
predict
=
fluid
.
layers
.
softmax
(
predict
)
if
self
.
loss_type
==
"srn"
:
raise
Exception
(
"Warning! SRN does not support export model currently"
)
return
[
image
,
{
'decoded_out'
:
decoded_out
,
'predicts'
:
predict
}]
else
:
predict
=
predicts
[
'predict'
]
...
...
ppocr/modeling/backbones/rec_resnet50_fpn.py
0 → 100755
浏览文件 @
612e8014
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
math
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid.param_attr
import
ParamAttr
__all__
=
[
"ResNet"
,
"ResNet18"
,
"ResNet34"
,
"ResNet50"
,
"ResNet101"
,
"ResNet152"
]
Trainable
=
True
w_nolr
=
fluid
.
ParamAttr
(
trainable
=
Trainable
)
train_parameters
=
{
"input_size"
:
[
3
,
224
,
224
],
"input_mean"
:
[
0.485
,
0.456
,
0.406
],
"input_std"
:
[
0.229
,
0.224
,
0.225
],
"learning_strategy"
:
{
"name"
:
"piecewise_decay"
,
"batch_size"
:
256
,
"epochs"
:
[
30
,
60
,
90
],
"steps"
:
[
0.1
,
0.01
,
0.001
,
0.0001
]
}
}
class
ResNet
():
def
__init__
(
self
,
params
):
self
.
layers
=
params
[
'layers'
]
self
.
params
=
train_parameters
def
__call__
(
self
,
input
):
layers
=
self
.
layers
supported_layers
=
[
18
,
34
,
50
,
101
,
152
]
assert
layers
in
supported_layers
,
\
"supported layers are {} but input layer is {}"
.
format
(
supported_layers
,
layers
)
if
layers
==
18
:
depth
=
[
2
,
2
,
2
,
2
]
elif
layers
==
34
or
layers
==
50
:
depth
=
[
3
,
4
,
6
,
3
]
elif
layers
==
101
:
depth
=
[
3
,
4
,
23
,
3
]
elif
layers
==
152
:
depth
=
[
3
,
8
,
36
,
3
]
stride_list
=
[(
2
,
2
),(
2
,
2
),(
1
,
1
),(
1
,
1
)]
num_filters
=
[
64
,
128
,
256
,
512
]
conv
=
self
.
conv_bn_layer
(
input
=
input
,
num_filters
=
64
,
filter_size
=
7
,
stride
=
2
,
act
=
'relu'
,
name
=
"conv1"
)
F
=
[]
if
layers
>=
50
:
for
block
in
range
(
len
(
depth
)):
for
i
in
range
(
depth
[
block
]):
if
layers
in
[
101
,
152
]
and
block
==
2
:
if
i
==
0
:
conv_name
=
"res"
+
str
(
block
+
2
)
+
"a"
else
:
conv_name
=
"res"
+
str
(
block
+
2
)
+
"b"
+
str
(
i
)
else
:
conv_name
=
"res"
+
str
(
block
+
2
)
+
chr
(
97
+
i
)
conv
=
self
.
bottleneck_block
(
input
=
conv
,
num_filters
=
num_filters
[
block
],
stride
=
stride_list
[
block
]
if
i
==
0
else
1
,
name
=
conv_name
)
F
.
append
(
conv
)
base
=
F
[
-
1
]
for
i
in
[
-
2
,
-
3
]:
b
,
c
,
w
,
h
=
F
[
i
].
shape
if
(
w
,
h
)
==
base
.
shape
[
2
:]:
base
=
base
else
:
base
=
fluid
.
layers
.
conv2d_transpose
(
input
=
base
,
num_filters
=
c
,
filter_size
=
4
,
stride
=
2
,
padding
=
1
,
act
=
None
,
param_attr
=
w_nolr
,
bias_attr
=
w_nolr
)
base
=
fluid
.
layers
.
batch_norm
(
base
,
act
=
"relu"
,
param_attr
=
w_nolr
,
bias_attr
=
w_nolr
)
base
=
fluid
.
layers
.
concat
([
base
,
F
[
i
]],
axis
=
1
)
base
=
fluid
.
layers
.
conv2d
(
base
,
num_filters
=
c
,
filter_size
=
1
,
param_attr
=
w_nolr
,
bias_attr
=
w_nolr
)
base
=
fluid
.
layers
.
conv2d
(
base
,
num_filters
=
c
,
filter_size
=
3
,
padding
=
1
,
param_attr
=
w_nolr
,
bias_attr
=
w_nolr
)
base
=
fluid
.
layers
.
batch_norm
(
base
,
act
=
"relu"
,
param_attr
=
w_nolr
,
bias_attr
=
w_nolr
)
base
=
fluid
.
layers
.
conv2d
(
base
,
num_filters
=
512
,
filter_size
=
1
,
bias_attr
=
w_nolr
,
param_attr
=
w_nolr
)
return
base
def
conv_bn_layer
(
self
,
input
,
num_filters
,
filter_size
,
stride
=
1
,
groups
=
1
,
act
=
None
,
name
=
None
):
conv
=
fluid
.
layers
.
conv2d
(
input
=
input
,
num_filters
=
num_filters
,
filter_size
=
2
if
stride
==
(
1
,
1
)
else
filter_size
,
dilation
=
2
if
stride
==
(
1
,
1
)
else
1
,
stride
=
stride
,
padding
=
(
filter_size
-
1
)
//
2
,
groups
=
groups
,
act
=
None
,
param_attr
=
ParamAttr
(
name
=
name
+
"_weights"
,
trainable
=
Trainable
),
bias_attr
=
False
,
name
=
name
+
'.conv2d.output.1'
)
if
name
==
"conv1"
:
bn_name
=
"bn_"
+
name
else
:
bn_name
=
"bn"
+
name
[
3
:]
return
fluid
.
layers
.
batch_norm
(
input
=
conv
,
act
=
act
,
name
=
bn_name
+
'.output.1'
,
param_attr
=
ParamAttr
(
name
=
bn_name
+
'_scale'
,
trainable
=
Trainable
),
bias_attr
=
ParamAttr
(
bn_name
+
'_offset'
,
trainable
=
Trainable
),
moving_mean_name
=
bn_name
+
'_mean'
,
moving_variance_name
=
bn_name
+
'_variance'
,
)
def
shortcut
(
self
,
input
,
ch_out
,
stride
,
is_first
,
name
):
ch_in
=
input
.
shape
[
1
]
if
ch_in
!=
ch_out
or
stride
!=
1
or
is_first
==
True
:
if
stride
==
(
1
,
1
):
return
self
.
conv_bn_layer
(
input
,
ch_out
,
1
,
1
,
name
=
name
)
else
:
#stride == (2,2)
return
self
.
conv_bn_layer
(
input
,
ch_out
,
1
,
stride
,
name
=
name
)
else
:
return
input
def
bottleneck_block
(
self
,
input
,
num_filters
,
stride
,
name
):
conv0
=
self
.
conv_bn_layer
(
input
=
input
,
num_filters
=
num_filters
,
filter_size
=
1
,
act
=
'relu'
,
name
=
name
+
"_branch2a"
)
conv1
=
self
.
conv_bn_layer
(
input
=
conv0
,
num_filters
=
num_filters
,
filter_size
=
3
,
stride
=
stride
,
act
=
'relu'
,
name
=
name
+
"_branch2b"
)
conv2
=
self
.
conv_bn_layer
(
input
=
conv1
,
num_filters
=
num_filters
*
4
,
filter_size
=
1
,
act
=
None
,
name
=
name
+
"_branch2c"
)
short
=
self
.
shortcut
(
input
,
num_filters
*
4
,
stride
,
is_first
=
False
,
name
=
name
+
"_branch1"
)
return
fluid
.
layers
.
elementwise_add
(
x
=
short
,
y
=
conv2
,
act
=
'relu'
,
name
=
name
+
".add.output.5"
)
def
basic_block
(
self
,
input
,
num_filters
,
stride
,
is_first
,
name
):
conv0
=
self
.
conv_bn_layer
(
input
=
input
,
num_filters
=
num_filters
,
filter_size
=
3
,
act
=
'relu'
,
stride
=
stride
,
name
=
name
+
"_branch2a"
)
conv1
=
self
.
conv_bn_layer
(
input
=
conv0
,
num_filters
=
num_filters
,
filter_size
=
3
,
act
=
None
,
name
=
name
+
"_branch2b"
)
short
=
self
.
shortcut
(
input
,
num_filters
,
stride
,
is_first
,
name
=
name
+
"_branch1"
)
return
fluid
.
layers
.
elementwise_add
(
x
=
short
,
y
=
conv1
,
act
=
'relu'
)
ppocr/modeling/heads/rec_srn_all_head.py
0 → 100755
浏览文件 @
612e8014
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
math
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid.param_attr
import
ParamAttr
import
numpy
as
np
from
.self_attention.model
import
wrap_encoder
from
.self_attention.model
import
wrap_encoder_forFeature
gradient_clip
=
10
class
SRNPredict
(
object
):
def
__init__
(
self
,
params
):
super
(
SRNPredict
,
self
).
__init__
()
self
.
char_num
=
params
[
'char_num'
]
self
.
max_length
=
params
[
'max_text_length'
]
self
.
num_heads
=
params
[
'num_heads'
]
self
.
num_encoder_TUs
=
params
[
'num_encoder_TUs'
]
self
.
num_decoder_TUs
=
params
[
'num_decoder_TUs'
]
self
.
hidden_dims
=
params
[
'hidden_dims'
]
def
pvam
(
self
,
inputs
,
others
):
b
,
c
,
h
,
w
=
inputs
.
shape
conv_features
=
fluid
.
layers
.
reshape
(
x
=
inputs
,
shape
=
[
-
1
,
c
,
h
*
w
])
conv_features
=
fluid
.
layers
.
transpose
(
x
=
conv_features
,
perm
=
[
0
,
2
,
1
])
#===== Transformer encoder =====
b
,
t
,
c
=
conv_features
.
shape
encoder_word_pos
=
others
[
"encoder_word_pos"
]
gsrm_word_pos
=
others
[
"gsrm_word_pos"
]
enc_inputs
=
[
conv_features
,
encoder_word_pos
,
None
]
word_features
=
wrap_encoder_forFeature
(
src_vocab_size
=-
1
,
max_length
=
t
,
n_layer
=
self
.
num_encoder_TUs
,
n_head
=
self
.
num_heads
,
d_key
=
int
(
self
.
hidden_dims
/
self
.
num_heads
),
d_value
=
int
(
self
.
hidden_dims
/
self
.
num_heads
),
d_model
=
self
.
hidden_dims
,
d_inner_hid
=
self
.
hidden_dims
,
prepostprocess_dropout
=
0.1
,
attention_dropout
=
0.1
,
relu_dropout
=
0.1
,
preprocess_cmd
=
"n"
,
postprocess_cmd
=
"da"
,
weight_sharing
=
True
,
enc_inputs
=
enc_inputs
,
)
fluid
.
clip
.
set_gradient_clip
(
fluid
.
clip
.
GradientClipByValue
(
gradient_clip
))
#===== Parallel Visual Attention Module =====
b
,
t
,
c
=
word_features
.
shape
word_features
=
fluid
.
layers
.
fc
(
word_features
,
c
,
num_flatten_dims
=
2
)
word_features_
=
fluid
.
layers
.
reshape
(
word_features
,
[
-
1
,
1
,
t
,
c
])
word_features_
=
fluid
.
layers
.
expand
(
word_features_
,
[
1
,
self
.
max_length
,
1
,
1
])
word_pos_feature
=
fluid
.
layers
.
embedding
(
gsrm_word_pos
,
[
self
.
max_length
,
c
])
word_pos_
=
fluid
.
layers
.
reshape
(
word_pos_feature
,
[
-
1
,
self
.
max_length
,
1
,
c
])
word_pos_
=
fluid
.
layers
.
expand
(
word_pos_
,
[
1
,
1
,
t
,
1
])
temp
=
fluid
.
layers
.
elementwise_add
(
word_features_
,
word_pos_
,
act
=
'tanh'
)
attention_weight
=
fluid
.
layers
.
fc
(
input
=
temp
,
size
=
1
,
num_flatten_dims
=
3
,
bias_attr
=
False
)
attention_weight
=
fluid
.
layers
.
reshape
(
x
=
attention_weight
,
shape
=
[
-
1
,
self
.
max_length
,
t
])
attention_weight
=
fluid
.
layers
.
softmax
(
input
=
attention_weight
,
axis
=-
1
)
pvam_features
=
fluid
.
layers
.
matmul
(
attention_weight
,
word_features
)
#[b, max_length, c]
return
pvam_features
def
gsrm
(
self
,
pvam_features
,
others
):
#===== GSRM Visual-to-semantic embedding block =====
b
,
t
,
c
=
pvam_features
.
shape
word_out
=
fluid
.
layers
.
fc
(
input
=
fluid
.
layers
.
reshape
(
pvam_features
,
[
-
1
,
c
]),
size
=
self
.
char_num
,
act
=
"softmax"
)
#word_out.stop_gradient = True
word_ids
=
fluid
.
layers
.
argmax
(
word_out
,
axis
=
1
)
word_ids
.
stop_gradient
=
True
word_ids
=
fluid
.
layers
.
reshape
(
x
=
word_ids
,
shape
=
[
-
1
,
t
,
1
])
#===== GSRM Semantic reasoning block =====
"""
This module is achieved through bi-transformers,
ngram_feature1 is the froward one, ngram_fetaure2 is the backward one
"""
pad_idx
=
self
.
char_num
gsrm_word_pos
=
others
[
"gsrm_word_pos"
]
gsrm_slf_attn_bias1
=
others
[
"gsrm_slf_attn_bias1"
]
gsrm_slf_attn_bias2
=
others
[
"gsrm_slf_attn_bias2"
]
def
prepare_bi
(
word_ids
):
"""
prepare bi for gsrm
word1 for forward; word2 for backward
"""
word1
=
fluid
.
layers
.
cast
(
word_ids
,
"float32"
)
word1
=
fluid
.
layers
.
pad
(
word1
,
[
0
,
0
,
1
,
0
,
0
,
0
],
pad_value
=
1.0
*
pad_idx
)
word1
=
fluid
.
layers
.
cast
(
word1
,
"int64"
)
word1
=
word1
[:,
:
-
1
,
:]
word2
=
word_ids
return
word1
,
word2
word1
,
word2
=
prepare_bi
(
word_ids
)
word1
.
stop_gradient
=
True
word2
.
stop_gradient
=
True
enc_inputs_1
=
[
word1
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
]
enc_inputs_2
=
[
word2
,
gsrm_word_pos
,
gsrm_slf_attn_bias2
]
gsrm_feature1
=
wrap_encoder
(
src_vocab_size
=
self
.
char_num
+
1
,
max_length
=
self
.
max_length
,
n_layer
=
self
.
num_decoder_TUs
,
n_head
=
self
.
num_heads
,
d_key
=
int
(
self
.
hidden_dims
/
self
.
num_heads
),
d_value
=
int
(
self
.
hidden_dims
/
self
.
num_heads
),
d_model
=
self
.
hidden_dims
,
d_inner_hid
=
self
.
hidden_dims
,
prepostprocess_dropout
=
0.1
,
attention_dropout
=
0.1
,
relu_dropout
=
0.1
,
preprocess_cmd
=
"n"
,
postprocess_cmd
=
"da"
,
weight_sharing
=
True
,
enc_inputs
=
enc_inputs_1
,
)
gsrm_feature2
=
wrap_encoder
(
src_vocab_size
=
self
.
char_num
+
1
,
max_length
=
self
.
max_length
,
n_layer
=
self
.
num_decoder_TUs
,
n_head
=
self
.
num_heads
,
d_key
=
int
(
self
.
hidden_dims
/
self
.
num_heads
),
d_value
=
int
(
self
.
hidden_dims
/
self
.
num_heads
),
d_model
=
self
.
hidden_dims
,
d_inner_hid
=
self
.
hidden_dims
,
prepostprocess_dropout
=
0.1
,
attention_dropout
=
0.1
,
relu_dropout
=
0.1
,
preprocess_cmd
=
"n"
,
postprocess_cmd
=
"da"
,
weight_sharing
=
True
,
enc_inputs
=
enc_inputs_2
,
)
gsrm_feature2
=
fluid
.
layers
.
pad
(
gsrm_feature2
,
[
0
,
0
,
0
,
1
,
0
,
0
],
pad_value
=
0.
)
gsrm_feature2
=
gsrm_feature2
[:,
1
:,
]
gsrm_features
=
gsrm_feature1
+
gsrm_feature2
b
,
t
,
c
=
gsrm_features
.
shape
gsrm_out
=
fluid
.
layers
.
matmul
(
x
=
gsrm_features
,
y
=
fluid
.
default_main_program
().
global_block
().
var
(
"src_word_emb_table"
),
transpose_y
=
True
)
b
,
t
,
c
=
gsrm_out
.
shape
gsrm_out
=
fluid
.
layers
.
softmax
(
input
=
fluid
.
layers
.
reshape
(
gsrm_out
,
[
-
1
,
c
]))
return
gsrm_features
,
word_out
,
gsrm_out
def
vsfd
(
self
,
pvam_features
,
gsrm_features
):
#===== Visual-Semantic Fusion Decoder Module =====
b
,
t
,
c1
=
pvam_features
.
shape
b
,
t
,
c2
=
gsrm_features
.
shape
combine_features_
=
fluid
.
layers
.
concat
(
[
pvam_features
,
gsrm_features
],
axis
=
2
)
img_comb_features_
=
fluid
.
layers
.
reshape
(
x
=
combine_features_
,
shape
=
[
-
1
,
c1
+
c2
])
img_comb_features_map
=
fluid
.
layers
.
fc
(
input
=
img_comb_features_
,
size
=
c1
,
act
=
"sigmoid"
)
img_comb_features_map
=
fluid
.
layers
.
reshape
(
x
=
img_comb_features_map
,
shape
=
[
-
1
,
t
,
c1
])
combine_features
=
img_comb_features_map
*
pvam_features
+
(
1.0
-
img_comb_features_map
)
*
gsrm_features
img_comb_features
=
fluid
.
layers
.
reshape
(
x
=
combine_features
,
shape
=
[
-
1
,
c1
])
fc_out
=
fluid
.
layers
.
fc
(
input
=
img_comb_features
,
size
=
self
.
char_num
,
act
=
"softmax"
)
return
fc_out
def
__call__
(
self
,
inputs
,
others
,
mode
=
None
):
pvam_features
=
self
.
pvam
(
inputs
,
others
)
gsrm_features
,
word_out
,
gsrm_out
=
self
.
gsrm
(
pvam_features
,
others
)
final_out
=
self
.
vsfd
(
pvam_features
,
gsrm_features
)
_
,
decoded_out
=
fluid
.
layers
.
topk
(
input
=
final_out
,
k
=
1
)
predicts
=
{
'predict'
:
final_out
,
'decoded_out'
:
decoded_out
,
'word_out'
:
word_out
,
'gsrm_out'
:
gsrm_out
}
return
predicts
ppocr/modeling/heads/self_attention/__init__.py
0 → 100644
浏览文件 @
612e8014
ppocr/modeling/heads/self_attention/model.py
0 → 100644
浏览文件 @
612e8014
此差异已折叠。
点击以展开。
ppocr/modeling/losses/rec_srn_loss.py
0 → 100755
浏览文件 @
612e8014
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
math
import
paddle
import
paddle.fluid
as
fluid
class
SRNLoss
(
object
):
def
__init__
(
self
,
params
):
super
(
SRNLoss
,
self
).
__init__
()
self
.
char_num
=
params
[
'char_num'
]
def
__call__
(
self
,
predicts
,
others
):
predict
=
predicts
[
'predict'
]
word_predict
=
predicts
[
'word_out'
]
gsrm_predict
=
predicts
[
'gsrm_out'
]
label
=
others
[
'label'
]
lbl_weight
=
others
[
'lbl_weight'
]
casted_label
=
fluid
.
layers
.
cast
(
x
=
label
,
dtype
=
'int64'
)
cost_word
=
fluid
.
layers
.
cross_entropy
(
input
=
word_predict
,
label
=
casted_label
)
cost_gsrm
=
fluid
.
layers
.
cross_entropy
(
input
=
gsrm_predict
,
label
=
casted_label
)
cost_vsfd
=
fluid
.
layers
.
cross_entropy
(
input
=
predict
,
label
=
casted_label
)
cost_word
=
fluid
.
layers
.
reshape
(
x
=
fluid
.
layers
.
reduce_sum
(
cost_word
),
shape
=
[
1
])
cost_gsrm
=
fluid
.
layers
.
reshape
(
x
=
fluid
.
layers
.
reduce_sum
(
cost_gsrm
),
shape
=
[
1
])
cost_vsfd
=
fluid
.
layers
.
reshape
(
x
=
fluid
.
layers
.
reduce_sum
(
cost_vsfd
),
shape
=
[
1
])
sum_cost
=
fluid
.
layers
.
sum
(
[
cost_word
,
cost_vsfd
*
2.0
,
cost_gsrm
*
0.15
])
return
[
sum_cost
,
cost_vsfd
,
cost_word
]
ppocr/utils/character.py
浏览文件 @
612e8014
...
...
@@ -25,6 +25,9 @@ class CharacterOps(object):
def
__init__
(
self
,
config
):
self
.
character_type
=
config
[
'character_type'
]
self
.
loss_type
=
config
[
'loss_type'
]
self
.
max_text_len
=
config
[
'max_text_length'
]
if
self
.
loss_type
==
"srn"
and
self
.
character_type
!=
"en"
:
raise
Exception
(
"SRN can only support in character_type == en"
)
if
self
.
character_type
==
"en"
:
self
.
character_str
=
"0123456789abcdefghijklmnopqrstuvwxyz"
dict_character
=
list
(
self
.
character_str
)
...
...
@@ -54,6 +57,8 @@ class CharacterOps(object):
self
.
end_str
=
"eos"
if
self
.
loss_type
==
"attention"
:
dict_character
=
[
self
.
beg_str
,
self
.
end_str
]
+
dict_character
elif
self
.
loss_type
==
"srn"
:
dict_character
=
dict_character
+
[
self
.
beg_str
,
self
.
end_str
]
self
.
dict
=
{}
for
i
,
char
in
enumerate
(
dict_character
):
self
.
dict
[
char
]
=
i
...
...
@@ -147,6 +152,39 @@ def cal_predicts_accuracy(char_ops,
return
acc
,
acc_num
,
img_num
def
cal_predicts_accuracy_srn
(
char_ops
,
preds
,
labels
,
max_text_len
,
is_debug
=
False
):
acc_num
=
0
img_num
=
0
total_len
=
preds
.
shape
[
0
]
img_num
=
int
(
total_len
/
max_text_len
)
for
i
in
range
(
img_num
):
cur_label
=
[]
cur_pred
=
[]
for
j
in
range
(
max_text_len
):
if
labels
[
j
+
i
*
max_text_len
]
!=
37
:
#0
cur_label
.
append
(
labels
[
j
+
i
*
max_text_len
][
0
])
else
:
break
for
j
in
range
(
max_text_len
+
1
):
if
j
<
len
(
cur_label
)
and
preds
[
j
+
i
*
max_text_len
][
0
]
!=
cur_label
[
j
]:
break
elif
j
==
len
(
cur_label
)
and
j
==
max_text_len
:
acc_num
+=
1
break
elif
j
==
len
(
cur_label
)
and
preds
[
j
+
i
*
max_text_len
][
0
]
==
37
:
acc_num
+=
1
break
acc
=
acc_num
*
1.0
/
img_num
return
acc
,
acc_num
,
img_num
def
convert_rec_attention_infer_res
(
preds
):
img_num
=
preds
.
shape
[
0
]
target_lod
=
[
0
]
...
...
tools/eval_utils/eval_rec_utils.py
浏览文件 @
612e8014
...
...
@@ -29,7 +29,7 @@ FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging
.
basicConfig
(
level
=
logging
.
INFO
,
format
=
FORMAT
)
logger
=
logging
.
getLogger
(
__name__
)
from
ppocr.utils.character
import
cal_predicts_accuracy
from
ppocr.utils.character
import
cal_predicts_accuracy
,
cal_predicts_accuracy_srn
from
ppocr.utils.character
import
convert_rec_label_to_lod
from
ppocr.utils.character
import
convert_rec_attention_infer_res
from
ppocr.utils.utility
import
create_module
...
...
@@ -60,22 +60,60 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
for
ino
in
range
(
img_num
):
img_list
.
append
(
data
[
ino
][
0
])
label_list
.
append
(
data
[
ino
][
1
])
if
config
[
'Global'
][
'loss_type'
]
!=
"srn"
:
img_list
=
np
.
concatenate
(
img_list
,
axis
=
0
)
outs
=
exe
.
run
(
eval_info_dict
[
'program'
],
\
feed
=
{
'image'
:
img_list
},
\
fetch_list
=
eval_info_dict
[
'fetch_varname_list'
],
\
return_numpy
=
False
)
preds
=
np
.
array
(
outs
[
0
])
if
preds
.
shape
[
1
]
!=
1
:
if
config
[
'Global'
][
'loss_type'
]
==
"attention"
:
preds
,
preds_lod
=
convert_rec_attention_infer_res
(
preds
)
else
:
preds_lod
=
outs
[
0
].
lod
()[
0
]
labels
,
labels_lod
=
convert_rec_label_to_lod
(
label_list
)
acc
,
acc_num
,
sample_num
=
cal_predicts_accuracy
(
char_ops
,
preds
,
preds_lod
,
labels
,
labels_lod
,
is_remove_duplicate
)
char_ops
,
preds
,
preds_lod
,
labels
,
labels_lod
,
is_remove_duplicate
)
else
:
encoder_word_pos_list
=
[]
gsrm_word_pos_list
=
[]
gsrm_slf_attn_bias1_list
=
[]
gsrm_slf_attn_bias2_list
=
[]
for
ino
in
range
(
img_num
):
encoder_word_pos_list
.
append
(
data
[
ino
][
2
])
gsrm_word_pos_list
.
append
(
data
[
ino
][
3
])
gsrm_slf_attn_bias1_list
.
append
(
data
[
ino
][
4
])
gsrm_slf_attn_bias2_list
.
append
(
data
[
ino
][
5
])
img_list
=
np
.
concatenate
(
img_list
,
axis
=
0
)
label_list
=
np
.
concatenate
(
label_list
,
axis
=
0
)
encoder_word_pos_list
=
np
.
concatenate
(
encoder_word_pos_list
,
axis
=
0
).
astype
(
np
.
int64
)
gsrm_word_pos_list
=
np
.
concatenate
(
gsrm_word_pos_list
,
axis
=
0
).
astype
(
np
.
int64
)
gsrm_slf_attn_bias1_list
=
np
.
concatenate
(
gsrm_slf_attn_bias1_list
,
axis
=
0
).
astype
(
np
.
float32
)
gsrm_slf_attn_bias2_list
=
np
.
concatenate
(
gsrm_slf_attn_bias2_list
,
axis
=
0
).
astype
(
np
.
float32
)
labels
=
label_list
outs
=
exe
.
run
(
eval_info_dict
[
'program'
],
\
feed
=
{
'image'
:
img_list
,
'encoder_word_pos'
:
encoder_word_pos_list
,
'gsrm_word_pos'
:
gsrm_word_pos_list
,
'gsrm_slf_attn_bias1'
:
gsrm_slf_attn_bias1_list
,
'gsrm_slf_attn_bias2'
:
gsrm_slf_attn_bias2_list
},
\
fetch_list
=
eval_info_dict
[
'fetch_varname_list'
],
\
return_numpy
=
False
)
preds
=
np
.
array
(
outs
[
0
])
acc
,
acc_num
,
sample_num
=
cal_predicts_accuracy_srn
(
char_ops
,
preds
,
labels
,
config
[
'Global'
][
'max_text_length'
])
total_acc_num
+=
acc_num
total_sample_num
+=
sample_num
logger
.
info
(
"eval batch id: {}, acc: {}"
.
format
(
total_batch_num
,
acc
))
#
logger.info("eval batch id: {}, acc: {}".format(total_batch_num, acc))
total_batch_num
+=
1
avg_acc
=
total_acc_num
*
1.0
/
total_sample_num
metrics
=
{
'avg_acc'
:
avg_acc
,
"total_acc_num"
:
total_acc_num
,
\
...
...
tools/infer/predict_rec.py
浏览文件 @
612e8014
...
...
@@ -40,7 +40,8 @@ class TextRecognizer(object):
char_ops_params
=
{
"character_type"
:
args
.
rec_char_type
,
"character_dict_path"
:
args
.
rec_char_dict_path
,
"use_space_char"
:
args
.
use_space_char
"use_space_char"
:
args
.
use_space_char
,
"max_text_length"
:
args
.
max_text_length
}
if
self
.
rec_algorithm
!=
"RARE"
:
char_ops_params
[
'loss_type'
]
=
'ctc'
...
...
tools/infer/utility.py
浏览文件 @
612e8014
...
...
@@ -59,6 +59,7 @@ def parse_args():
parser
.
add_argument
(
"--rec_image_shape"
,
type
=
str
,
default
=
"3, 32, 320"
)
parser
.
add_argument
(
"--rec_char_type"
,
type
=
str
,
default
=
'ch'
)
parser
.
add_argument
(
"--rec_batch_num"
,
type
=
int
,
default
=
30
)
parser
.
add_argument
(
"--max_text_length"
,
type
=
int
,
default
=
25
)
parser
.
add_argument
(
"--rec_char_dict_path"
,
type
=
str
,
...
...
tools/infer_rec.py
浏览文件 @
612e8014
...
...
@@ -64,7 +64,6 @@ def main():
exe
=
fluid
.
Executor
(
place
)
rec_model
=
create_module
(
config
[
'Architecture'
][
'function'
])(
params
=
config
)
startup_prog
=
fluid
.
Program
()
eval_prog
=
fluid
.
Program
()
with
fluid
.
program_guard
(
eval_prog
,
startup_prog
):
...
...
@@ -86,10 +85,36 @@ def main():
for
i
in
range
(
max_img_num
):
logger
.
info
(
"infer_img:%s"
%
infer_list
[
i
])
img
=
next
(
blobs
)
if
loss_type
!=
"srn"
:
predict
=
exe
.
run
(
program
=
eval_prog
,
feed
=
{
"image"
:
img
},
fetch_list
=
fetch_varname_list
,
return_numpy
=
False
)
else
:
encoder_word_pos_list
=
[]
gsrm_word_pos_list
=
[]
gsrm_slf_attn_bias1_list
=
[]
gsrm_slf_attn_bias2_list
=
[]
encoder_word_pos_list
.
append
(
img
[
1
])
gsrm_word_pos_list
.
append
(
img
[
2
])
gsrm_slf_attn_bias1_list
.
append
(
img
[
3
])
gsrm_slf_attn_bias2_list
.
append
(
img
[
4
])
encoder_word_pos_list
=
np
.
concatenate
(
encoder_word_pos_list
,
axis
=
0
).
astype
(
np
.
int64
)
gsrm_word_pos_list
=
np
.
concatenate
(
gsrm_word_pos_list
,
axis
=
0
).
astype
(
np
.
int64
)
gsrm_slf_attn_bias1_list
=
np
.
concatenate
(
gsrm_slf_attn_bias1_list
,
axis
=
0
).
astype
(
np
.
float32
)
gsrm_slf_attn_bias2_list
=
np
.
concatenate
(
gsrm_slf_attn_bias2_list
,
axis
=
0
).
astype
(
np
.
float32
)
predict
=
exe
.
run
(
program
=
eval_prog
,
\
feed
=
{
'image'
:
img
[
0
],
'encoder_word_pos'
:
encoder_word_pos_list
,
'gsrm_word_pos'
:
gsrm_word_pos_list
,
'gsrm_slf_attn_bias1'
:
gsrm_slf_attn_bias1_list
,
'gsrm_slf_attn_bias2'
:
gsrm_slf_attn_bias2_list
},
\
fetch_list
=
fetch_varname_list
,
\
return_numpy
=
False
)
if
loss_type
==
"ctc"
:
preds
=
np
.
array
(
predict
[
0
])
preds
=
preds
.
reshape
(
-
1
)
...
...
@@ -114,7 +139,18 @@ def main():
score
=
np
.
mean
(
probs
[
0
,
1
:
end_pos
[
1
]])
preds
=
preds
.
reshape
(
-
1
)
preds_text
=
char_ops
.
decode
(
preds
)
elif
loss_type
==
"srn"
:
cur_pred
=
[]
preds
=
np
.
array
(
predict
[
0
])
preds
=
preds
.
reshape
(
-
1
)
probs
=
np
.
array
(
predict
[
1
])
ind
=
np
.
argmax
(
probs
,
axis
=
1
)
valid_ind
=
np
.
where
(
preds
!=
37
)[
0
]
if
len
(
valid_ind
)
==
0
:
continue
score
=
np
.
mean
(
probs
[
valid_ind
,
ind
[
valid_ind
]])
preds
=
preds
[:
valid_ind
[
-
1
]
+
1
]
preds_text
=
char_ops
.
decode
(
preds
)
logger
.
info
(
"
\t
index: {}"
.
format
(
preds
))
logger
.
info
(
"
\t
word : {}"
.
format
(
preds_text
))
logger
.
info
(
"
\t
score: {}"
.
format
(
score
))
...
...
tools/program.py
浏览文件 @
612e8014
...
...
@@ -32,7 +32,8 @@ from eval_utils.eval_det_utils import eval_det_run
from
eval_utils.eval_rec_utils
import
eval_rec_run
from
ppocr.utils.save_load
import
save_model
import
numpy
as
np
from
ppocr.utils.character
import
cal_predicts_accuracy
,
CharacterOps
from
ppocr.utils.character
import
cal_predicts_accuracy
,
cal_predicts_accuracy_srn
,
CharacterOps
class
ArgsParser
(
ArgumentParser
):
def
__init__
(
self
):
...
...
@@ -176,8 +177,16 @@ def build(config, main_prog, startup_prog, mode):
fetch_name_list
=
list
(
outputs
.
keys
())
fetch_varname_list
=
[
outputs
[
v
].
name
for
v
in
fetch_name_list
]
opt_loss_name
=
None
model_average
=
None
img_loss_name
=
None
word_loss_name
=
None
if
mode
==
"train"
:
opt_loss
=
outputs
[
'total_loss'
]
# srn loss
#img_loss = outputs['img_loss']
#word_loss = outputs['word_loss']
#img_loss_name = img_loss.name
#word_loss_name = word_loss.name
opt_params
=
config
[
'Optimizer'
]
optimizer
=
create_module
(
opt_params
[
'function'
])(
opt_params
)
optimizer
.
minimize
(
opt_loss
)
...
...
@@ -185,7 +194,17 @@ def build(config, main_prog, startup_prog, mode):
global_lr
=
optimizer
.
_global_learning_rate
()
fetch_name_list
.
insert
(
0
,
"lr"
)
fetch_varname_list
.
insert
(
0
,
global_lr
.
name
)
return
(
dataloader
,
fetch_name_list
,
fetch_varname_list
,
opt_loss_name
)
if
"loss_type"
in
config
[
"Global"
]:
if
config
[
'Global'
][
"loss_type"
]
==
'srn'
:
model_average
=
fluid
.
optimizer
.
ModelAverage
(
config
[
'Global'
][
'average_window'
],
min_average_window
=
config
[
'Global'
][
'min_average_window'
],
max_average_window
=
config
[
'Global'
][
'max_average_window'
])
return
(
dataloader
,
fetch_name_list
,
fetch_varname_list
,
opt_loss_name
,
model_average
)
def
build_export
(
config
,
main_prog
,
startup_prog
):
...
...
@@ -329,14 +348,20 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
lr
=
np
.
mean
(
np
.
array
(
train_outs
[
fetch_map
[
'lr'
]]))
preds_idx
=
fetch_map
[
'decoded_out'
]
preds
=
np
.
array
(
train_outs
[
preds_idx
])
preds_lod
=
train_outs
[
preds_idx
].
lod
()[
0
]
labels_idx
=
fetch_map
[
'label'
]
labels
=
np
.
array
(
train_outs
[
labels_idx
])
if
config
[
'Global'
][
'loss_type'
]
!=
'srn'
:
preds_lod
=
train_outs
[
preds_idx
].
lod
()[
0
]
labels_lod
=
train_outs
[
labels_idx
].
lod
()[
0
]
acc
,
acc_num
,
img_num
=
cal_predicts_accuracy
(
config
[
'Global'
][
'char_ops'
],
preds
,
preds_lod
,
labels
,
labels_lod
)
else
:
acc
,
acc_num
,
img_num
=
cal_predicts_accuracy_srn
(
config
[
'Global'
][
'char_ops'
],
preds
,
labels
,
config
[
'Global'
][
'max_text_length'
])
t2
=
time
.
time
()
train_batch_elapse
=
t2
-
t1
stats
=
{
'loss'
:
loss
,
'acc'
:
acc
}
...
...
@@ -350,6 +375,9 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
if
train_batch_id
>
0
and
\
train_batch_id
%
eval_batch_step
==
0
:
model_average
=
train_info_dict
[
'model_average'
]
if
model_average
!=
None
:
model_average
.
apply
(
exe
)
metrics
=
eval_rec_run
(
exe
,
config
,
eval_info_dict
,
"eval"
)
eval_acc
=
metrics
[
'avg_acc'
]
eval_sample_num
=
metrics
[
'total_sample_num'
]
...
...
@@ -375,6 +403,7 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
save_model
(
train_info_dict
[
'train_program'
],
save_path
)
return
def
preprocess
():
FLAGS
=
ArgsParser
().
parse_args
()
config
=
load_config
(
FLAGS
.
config
)
...
...
@@ -386,8 +415,8 @@ def preprocess():
check_gpu
(
use_gpu
)
alg
=
config
[
'Global'
][
'algorithm'
]
assert
alg
in
[
'EAST'
,
'DB'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
]
if
alg
in
[
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
]:
assert
alg
in
[
'EAST'
,
'DB'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
,
'SRN'
]
if
alg
in
[
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
,
'SRN'
]:
config
[
'Global'
][
'char_ops'
]
=
CharacterOps
(
config
[
'Global'
])
place
=
fluid
.
CUDAPlace
(
0
)
if
use_gpu
else
fluid
.
CPUPlace
()
...
...
tools/train.py
浏览文件 @
612e8014
...
...
@@ -52,6 +52,7 @@ def main():
train_fetch_name_list
=
train_build_outputs
[
1
]
train_fetch_varname_list
=
train_build_outputs
[
2
]
train_opt_loss_name
=
train_build_outputs
[
3
]
model_average
=
train_build_outputs
[
-
1
]
eval_program
=
fluid
.
Program
()
eval_build_outputs
=
program
.
build
(
...
...
@@ -85,7 +86,8 @@ def main():
'train_program'
:
train_program
,
\
'reader'
:
train_loader
,
\
'fetch_name_list'
:
train_fetch_name_list
,
\
'fetch_varname_list'
:
train_fetch_varname_list
}
'fetch_varname_list'
:
train_fetch_varname_list
,
\
'model_average'
:
model_average
}
eval_info_dict
=
{
'program'
:
eval_program
,
\
'reader'
:
eval_reader
,
\
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录