Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
612e8014
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
612e8014
编写于
8月 16, 2020
作者:
M
MissPenguin
提交者:
GitHub
8月 16, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #537 from tink2123/add_srn
Add SRN
上级
bad9f6cd
9c893102
变更
18
隐藏空白更改
内联
并排
Showing
18 changed file
with
1962 addition
and
48 deletion
+1962
-48
README_cn.md
README_cn.md
+5
-1
configs/rec/rec_r50fpn_vd_none_srn.yml
configs/rec/rec_r50fpn_vd_none_srn.yml
+49
-0
doc/doc_ch/config.md
doc/doc_ch/config.md
+3
-0
ppocr/data/rec/dataset_traversal.py
ppocr/data/rec/dataset_traversal.py
+36
-15
ppocr/data/rec/img_tools.py
ppocr/data/rec/img_tools.py
+81
-0
ppocr/modeling/architectures/rec_model.py
ppocr/modeling/architectures/rec_model.py
+100
-4
ppocr/modeling/backbones/rec_resnet50_fpn.py
ppocr/modeling/backbones/rec_resnet50_fpn.py
+172
-0
ppocr/modeling/heads/rec_srn_all_head.py
ppocr/modeling/heads/rec_srn_all_head.py
+230
-0
ppocr/modeling/heads/self_attention/__init__.py
ppocr/modeling/heads/self_attention/__init__.py
+0
-0
ppocr/modeling/heads/self_attention/model.py
ppocr/modeling/heads/self_attention/model.py
+1058
-0
ppocr/modeling/losses/rec_srn_loss.py
ppocr/modeling/losses/rec_srn_loss.py
+55
-0
ppocr/utils/character.py
ppocr/utils/character.py
+38
-0
tools/eval_utils/eval_rec_utils.py
tools/eval_utils/eval_rec_utils.py
+49
-11
tools/infer/predict_rec.py
tools/infer/predict_rec.py
+2
-1
tools/infer/utility.py
tools/infer/utility.py
+1
-0
tools/infer_rec.py
tools/infer_rec.py
+42
-6
tools/program.py
tools/program.py
+38
-9
tools/train.py
tools/train.py
+3
-1
未找到文件。
README_cn.md
浏览文件 @
612e8014
...
@@ -122,7 +122,10 @@ PaddleOCR开源的文本识别算法列表:
...
@@ -122,7 +122,10 @@ PaddleOCR开源的文本识别算法列表:
-
[
x] Rosetta([paper
](
https://arxiv.org/abs/1910.05085
)
)
-
[
x] Rosetta([paper
](
https://arxiv.org/abs/1910.05085
)
)
-
[
x] STAR-Net([paper
](
http://www.bmva.org/bmvc/2016/papers/paper043/index.html
)
)
-
[
x] STAR-Net([paper
](
http://www.bmva.org/bmvc/2016/papers/paper043/index.html
)
)
-
[
x] RARE([paper
](
https://arxiv.org/abs/1603.03915v1
)
)
-
[
x] RARE([paper
](
https://arxiv.org/abs/1603.03915v1
)
)
-
[
] SRN([paper
](
https://arxiv.org/abs/2003.12294
)
)(百度自研, coming soon)
-
[
x] SRN([paper
](
https://arxiv.org/abs/2003.12294
)
)(百度自研)
*备注:*
SRN模型使用了数据扰动方法对上述提到对两个训练集进行增广,增广后的数据可以在
[
百度网盘
](
todo
)
上下载。
原始论文使用两阶段训练平均精度为89.74%,PaddleOCR中使用one-stage训练,平均精度为88.33%。两种预训练权重均在
[
下载链接
](
https://paddleocr.bj.bcebos.com/SRN/rec_r50fpn_vd_none_srn.tar
)
中。
参考
[
DTRB
](
https://arxiv.org/abs/1904.01906
)
文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
参考
[
DTRB
](
https://arxiv.org/abs/1904.01906
)
文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
...
@@ -136,6 +139,7 @@ PaddleOCR开源的文本识别算法列表:
...
@@ -136,6 +139,7 @@ PaddleOCR开源的文本识别算法列表:
|STAR-Net|MobileNetV3|81.56%|rec_mv3_tps_bilstm_ctc|
[
下载链接
](
https://paddleocr.bj.bcebos.com/rec_mv3_tps_bilstm_ctc.tar
)
|
|STAR-Net|MobileNetV3|81.56%|rec_mv3_tps_bilstm_ctc|
[
下载链接
](
https://paddleocr.bj.bcebos.com/rec_mv3_tps_bilstm_ctc.tar
)
|
|RARE|Resnet34_vd|84.90%|rec_r34_vd_tps_bilstm_attn|
[
下载链接
](
https://paddleocr.bj.bcebos.com/rec_r34_vd_tps_bilstm_attn.tar
)
|
|RARE|Resnet34_vd|84.90%|rec_r34_vd_tps_bilstm_attn|
[
下载链接
](
https://paddleocr.bj.bcebos.com/rec_r34_vd_tps_bilstm_attn.tar
)
|
|RARE|MobileNetV3|83.32%|rec_mv3_tps_bilstm_attn|
[
下载链接
](
https://paddleocr.bj.bcebos.com/rec_mv3_tps_bilstm_attn.tar
)
|
|RARE|MobileNetV3|83.32%|rec_mv3_tps_bilstm_attn|
[
下载链接
](
https://paddleocr.bj.bcebos.com/rec_mv3_tps_bilstm_attn.tar
)
|
|SRN|Resnet50_vd_fpn|88.33%|rec_r50fpn_vd_none_srn|
[
下载链接
](
https://paddleocr.bj.bcebos.com/SRN/rec_r50fpn_vd_none_srn.tar
)
|
使用
[
LSVT
](
https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_ch/datasets.md#1icdar2019-lsvt
)
街景数据集根据真值将图crop出来30w数据,进行位置校准。此外基于LSVT语料生成500w合成数据训练中文模型,相关配置和预训练文件如下:
使用
[
LSVT
](
https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_ch/datasets.md#1icdar2019-lsvt
)
街景数据集根据真值将图crop出来30w数据,进行位置校准。此外基于LSVT语料生成500w合成数据训练中文模型,相关配置和预训练文件如下:
...
...
configs/rec/rec_r50fpn_vd_none_srn.yml
0 → 100755
浏览文件 @
612e8014
Global
:
algorithm
:
SRN
use_gpu
:
true
epoch_num
:
72
log_smooth_window
:
20
print_batch_step
:
10
save_model_dir
:
output/rec_pvam_withrotate
save_epoch_step
:
1
eval_batch_step
:
8000
train_batch_size_per_card
:
64
test_batch_size_per_card
:
1
image_shape
:
[
1
,
64
,
256
]
max_text_length
:
25
character_type
:
en
loss_type
:
srn
num_heads
:
8
average_window
:
0.15
max_average_window
:
15625
min_average_window
:
10000
reader_yml
:
./configs/rec/rec_benchmark_reader.yml
pretrain_weights
:
checkpoints
:
save_inference_dir
:
infer_img
:
Architecture
:
function
:
ppocr.modeling.architectures.rec_model,RecModel
Backbone
:
function
:
ppocr.modeling.backbones.rec_resnet50_fpn,ResNet
layers
:
50
Head
:
function
:
ppocr.modeling.heads.rec_srn_all_head,SRNPredict
encoder_type
:
rnn
num_encoder_TUs
:
2
num_decoder_TUs
:
4
hidden_dims
:
512
SeqRNN
:
hidden_size
:
256
Loss
:
function
:
ppocr.modeling.losses.rec_srn_loss,SRNLoss
Optimizer
:
function
:
ppocr.optimizer,AdamDecay
base_lr
:
0.0001
beta1
:
0.9
beta2
:
0.999
doc/doc_ch/config.md
浏览文件 @
612e8014
...
@@ -32,6 +32,9 @@
...
@@ -32,6 +32,9 @@
| loss_type | 设置 loss 类型 | ctc | 支持两种loss: ctc / attention |
| loss_type | 设置 loss 类型 | ctc | 支持两种loss: ctc / attention |
| distort | 设置是否使用数据增强 | false | 设置为true时,将在训练时随机进行扰动,支持的扰动操作可阅读
[
img_tools.py
](
https://github.com/PaddlePaddle/PaddleOCR/blob/develop/ppocr/data/rec/img_tools.py
)
|
| distort | 设置是否使用数据增强 | false | 设置为true时,将在训练时随机进行扰动,支持的扰动操作可阅读
[
img_tools.py
](
https://github.com/PaddlePaddle/PaddleOCR/blob/develop/ppocr/data/rec/img_tools.py
)
|
| use_space_char | 设置是否识别空格 | false | 仅在 character_type=ch 时支持空格 |
| use_space_char | 设置是否识别空格 | false | 仅在 character_type=ch 时支持空格 |
| average_window | ModelAverage优化器中的窗口长度计算比例 | 0.15 | 目前仅应用与SRN |
| max_average_window | 平均值计算窗口长度的最大值 | 15625 | 推荐设置为一轮训练中mini-batchs的数目|
| min_average_window | 平均值计算窗口长度的最小值 | 10000 |
\
|
| reader_yml | 设置reader配置文件 | ./configs/rec/rec_icdar15_reader.yml |
\
|
| reader_yml | 设置reader配置文件 | ./configs/rec/rec_icdar15_reader.yml |
\
|
| pretrain_weights | 加载预训练模型路径 | ./pretrain_models/CRNN/best_accuracy |
\
|
| pretrain_weights | 加载预训练模型路径 | ./pretrain_models/CRNN/best_accuracy |
\
|
| checkpoints | 加载模型参数路径 | None | 用于中断后加载参数继续训练 |
| checkpoints | 加载模型参数路径 | None | 用于中断后加载参数继续训练 |
...
...
ppocr/data/rec/dataset_traversal.py
浏览文件 @
612e8014
...
@@ -26,7 +26,7 @@ from ppocr.utils.utility import initial_logger
...
@@ -26,7 +26,7 @@ from ppocr.utils.utility import initial_logger
from
ppocr.utils.utility
import
get_image_file_list
from
ppocr.utils.utility
import
get_image_file_list
logger
=
initial_logger
()
logger
=
initial_logger
()
from
.img_tools
import
process_image
,
get_img_data
from
.img_tools
import
process_image
,
process_image_srn
,
get_img_data
class
LMDBReader
(
object
):
class
LMDBReader
(
object
):
...
@@ -43,6 +43,9 @@ class LMDBReader(object):
...
@@ -43,6 +43,9 @@ class LMDBReader(object):
self
.
mode
=
params
[
'mode'
]
self
.
mode
=
params
[
'mode'
]
self
.
drop_last
=
False
self
.
drop_last
=
False
self
.
use_tps
=
False
self
.
use_tps
=
False
self
.
num_heads
=
None
if
"num_heads"
in
params
:
self
.
num_heads
=
params
[
'num_heads'
]
if
"tps"
in
params
:
if
"tps"
in
params
:
self
.
ues_tps
=
True
self
.
ues_tps
=
True
self
.
use_distort
=
False
self
.
use_distort
=
False
...
@@ -119,12 +122,19 @@ class LMDBReader(object):
...
@@ -119,12 +122,19 @@ class LMDBReader(object):
img
=
cv2
.
imread
(
single_img
)
img
=
cv2
.
imread
(
single_img
)
if
img
.
shape
[
-
1
]
==
1
or
len
(
list
(
img
.
shape
))
==
2
:
if
img
.
shape
[
-
1
]
==
1
or
len
(
list
(
img
.
shape
))
==
2
:
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_GRAY2BGR
)
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_GRAY2BGR
)
norm_img
=
process_image
(
if
self
.
loss_type
==
'srn'
:
img
=
img
,
norm_img
=
process_image_srn
(
image_shape
=
self
.
image_shape
,
img
=
img
,
char_ops
=
self
.
char_ops
,
image_shape
=
self
.
image_shape
,
tps
=
self
.
use_tps
,
num_heads
=
self
.
num_heads
,
infer_mode
=
True
)
max_text_length
=
self
.
max_text_length
)
else
:
norm_img
=
process_image
(
img
=
img
,
image_shape
=
self
.
image_shape
,
char_ops
=
self
.
char_ops
,
tps
=
self
.
use_tps
,
infer_mode
=
True
)
yield
norm_img
yield
norm_img
else
:
else
:
lmdb_sets
=
self
.
load_hierarchical_lmdb_dataset
()
lmdb_sets
=
self
.
load_hierarchical_lmdb_dataset
()
...
@@ -144,14 +154,25 @@ class LMDBReader(object):
...
@@ -144,14 +154,25 @@ class LMDBReader(object):
if
sample_info
is
None
:
if
sample_info
is
None
:
continue
continue
img
,
label
=
sample_info
img
,
label
=
sample_info
outs
=
process_image
(
outs
=
[]
img
=
img
,
if
self
.
loss_type
==
"srn"
:
image_shape
=
self
.
image_shape
,
outs
=
process_image_srn
(
label
=
label
,
img
=
img
,
char_ops
=
self
.
char_ops
,
image_shape
=
self
.
image_shape
,
loss_type
=
self
.
loss_type
,
num_heads
=
self
.
num_heads
,
max_text_length
=
self
.
max_text_length
,
max_text_length
=
self
.
max_text_length
,
distort
=
self
.
use_distort
)
label
=
label
,
char_ops
=
self
.
char_ops
,
loss_type
=
self
.
loss_type
)
else
:
outs
=
process_image
(
img
=
img
,
image_shape
=
self
.
image_shape
,
label
=
label
,
char_ops
=
self
.
char_ops
,
loss_type
=
self
.
loss_type
,
max_text_length
=
self
.
max_text_length
)
if
outs
is
None
:
if
outs
is
None
:
continue
continue
yield
outs
yield
outs
...
...
ppocr/data/rec/img_tools.py
浏览文件 @
612e8014
...
@@ -381,3 +381,84 @@ def process_image(img,
...
@@ -381,3 +381,84 @@ def process_image(img,
assert
False
,
"Unsupport loss_type %s in process_image"
\
assert
False
,
"Unsupport loss_type %s in process_image"
\
%
loss_type
%
loss_type
return
(
norm_img
)
return
(
norm_img
)
def
resize_norm_img_srn
(
img
,
image_shape
):
imgC
,
imgH
,
imgW
=
image_shape
img_black
=
np
.
zeros
((
imgH
,
imgW
))
im_hei
=
img
.
shape
[
0
]
im_wid
=
img
.
shape
[
1
]
if
im_wid
<=
im_hei
*
1
:
img_new
=
cv2
.
resize
(
img
,
(
imgH
*
1
,
imgH
))
elif
im_wid
<=
im_hei
*
2
:
img_new
=
cv2
.
resize
(
img
,
(
imgH
*
2
,
imgH
))
elif
im_wid
<=
im_hei
*
3
:
img_new
=
cv2
.
resize
(
img
,
(
imgH
*
3
,
imgH
))
else
:
img_new
=
cv2
.
resize
(
img
,
(
imgW
,
imgH
))
img_np
=
np
.
asarray
(
img_new
)
img_np
=
cv2
.
cvtColor
(
img_np
,
cv2
.
COLOR_BGR2GRAY
)
img_black
[:,
0
:
img_np
.
shape
[
1
]]
=
img_np
img_black
=
img_black
[:,
:,
np
.
newaxis
]
row
,
col
,
c
=
img_black
.
shape
c
=
1
return
np
.
reshape
(
img_black
,
(
c
,
row
,
col
)).
astype
(
np
.
float32
)
def
srn_other_inputs
(
image_shape
,
num_heads
,
max_text_length
):
imgC
,
imgH
,
imgW
=
image_shape
feature_dim
=
int
((
imgH
/
8
)
*
(
imgW
/
8
))
encoder_word_pos
=
np
.
array
(
range
(
0
,
feature_dim
)).
reshape
((
feature_dim
,
1
)).
astype
(
'int64'
)
gsrm_word_pos
=
np
.
array
(
range
(
0
,
max_text_length
)).
reshape
((
max_text_length
,
1
)).
astype
(
'int64'
)
lbl_weight
=
np
.
array
([
37
]
*
max_text_length
).
reshape
((
-
1
,
1
)).
astype
(
'int64'
)
gsrm_attn_bias_data
=
np
.
ones
((
1
,
max_text_length
,
max_text_length
))
gsrm_slf_attn_bias1
=
np
.
triu
(
gsrm_attn_bias_data
,
1
).
reshape
([
-
1
,
1
,
max_text_length
,
max_text_length
])
gsrm_slf_attn_bias1
=
np
.
tile
(
gsrm_slf_attn_bias1
,
[
1
,
num_heads
,
1
,
1
])
*
[
-
1e9
]
gsrm_slf_attn_bias2
=
np
.
tril
(
gsrm_attn_bias_data
,
-
1
).
reshape
([
-
1
,
1
,
max_text_length
,
max_text_length
])
gsrm_slf_attn_bias2
=
np
.
tile
(
gsrm_slf_attn_bias2
,
[
1
,
num_heads
,
1
,
1
])
*
[
-
1e9
]
encoder_word_pos
=
encoder_word_pos
[
np
.
newaxis
,
:]
gsrm_word_pos
=
gsrm_word_pos
[
np
.
newaxis
,
:]
return
[
lbl_weight
,
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
]
def
process_image_srn
(
img
,
image_shape
,
num_heads
,
max_text_length
,
label
=
None
,
char_ops
=
None
,
loss_type
=
None
):
norm_img
=
resize_norm_img_srn
(
img
,
image_shape
)
norm_img
=
norm_img
[
np
.
newaxis
,
:]
[
lbl_weight
,
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
]
=
\
srn_other_inputs
(
image_shape
,
num_heads
,
max_text_length
)
if
label
is
not
None
:
char_num
=
char_ops
.
get_char_num
()
text
=
char_ops
.
encode
(
label
)
if
len
(
text
)
==
0
or
len
(
text
)
>
max_text_length
:
return
None
else
:
if
loss_type
==
"srn"
:
text_padded
=
[
37
]
*
max_text_length
for
i
in
range
(
len
(
text
)):
text_padded
[
i
]
=
text
[
i
]
lbl_weight
[
i
]
=
[
1.0
]
text_padded
=
np
.
array
(
text_padded
)
text
=
text_padded
.
reshape
(
-
1
,
1
)
return
(
norm_img
,
text
,
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
,
lbl_weight
)
else
:
assert
False
,
"Unsupport loss_type %s in process_image"
\
%
loss_type
return
(
norm_img
,
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
)
ppocr/modeling/architectures/rec_model.py
浏览文件 @
612e8014
...
@@ -58,6 +58,10 @@ class RecModel(object):
...
@@ -58,6 +58,10 @@ class RecModel(object):
self
.
loss_type
=
global_params
[
'loss_type'
]
self
.
loss_type
=
global_params
[
'loss_type'
]
self
.
image_shape
=
global_params
[
'image_shape'
]
self
.
image_shape
=
global_params
[
'image_shape'
]
self
.
max_text_length
=
global_params
[
'max_text_length'
]
self
.
max_text_length
=
global_params
[
'max_text_length'
]
if
"num_heads"
in
params
:
self
.
num_heads
=
global_params
[
"num_heads"
]
else
:
self
.
num_heads
=
None
def
create_feed
(
self
,
mode
):
def
create_feed
(
self
,
mode
):
image_shape
=
deepcopy
(
self
.
image_shape
)
image_shape
=
deepcopy
(
self
.
image_shape
)
...
@@ -77,6 +81,48 @@ class RecModel(object):
...
@@ -77,6 +81,48 @@ class RecModel(object):
lod_level
=
1
)
lod_level
=
1
)
feed_list
=
[
image
,
label_in
,
label_out
]
feed_list
=
[
image
,
label_in
,
label_out
]
labels
=
{
'label_in'
:
label_in
,
'label_out'
:
label_out
}
labels
=
{
'label_in'
:
label_in
,
'label_out'
:
label_out
}
elif
self
.
loss_type
==
"srn"
:
encoder_word_pos
=
fluid
.
data
(
name
=
"encoder_word_pos"
,
shape
=
[
-
1
,
int
((
image_shape
[
-
2
]
/
8
)
*
(
image_shape
[
-
1
]
/
8
)),
1
],
dtype
=
"int64"
)
gsrm_word_pos
=
fluid
.
data
(
name
=
"gsrm_word_pos"
,
shape
=
[
-
1
,
self
.
max_text_length
,
1
],
dtype
=
"int64"
)
gsrm_slf_attn_bias1
=
fluid
.
data
(
name
=
"gsrm_slf_attn_bias1"
,
shape
=
[
-
1
,
self
.
num_heads
,
self
.
max_text_length
,
self
.
max_text_length
],
dtype
=
"float32"
)
gsrm_slf_attn_bias2
=
fluid
.
data
(
name
=
"gsrm_slf_attn_bias2"
,
shape
=
[
-
1
,
self
.
num_heads
,
self
.
max_text_length
,
self
.
max_text_length
],
dtype
=
"float32"
)
lbl_weight
=
fluid
.
layers
.
data
(
name
=
"lbl_weight"
,
shape
=
[
-
1
,
1
],
dtype
=
'int64'
)
label
=
fluid
.
data
(
name
=
'label'
,
shape
=
[
-
1
,
1
],
dtype
=
'int32'
,
lod_level
=
1
)
feed_list
=
[
image
,
label
,
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
,
lbl_weight
]
labels
=
{
'label'
:
label
,
'encoder_word_pos'
:
encoder_word_pos
,
'gsrm_word_pos'
:
gsrm_word_pos
,
'gsrm_slf_attn_bias1'
:
gsrm_slf_attn_bias1
,
'gsrm_slf_attn_bias2'
:
gsrm_slf_attn_bias2
,
'lbl_weight'
:
lbl_weight
}
else
:
else
:
label
=
fluid
.
data
(
label
=
fluid
.
data
(
name
=
'label'
,
shape
=
[
None
,
1
],
dtype
=
'int32'
,
lod_level
=
1
)
name
=
'label'
,
shape
=
[
None
,
1
],
dtype
=
'int32'
,
lod_level
=
1
)
...
@@ -88,6 +134,8 @@ class RecModel(object):
...
@@ -88,6 +134,8 @@ class RecModel(object):
use_double_buffer
=
True
,
use_double_buffer
=
True
,
iterable
=
False
)
iterable
=
False
)
else
:
else
:
labels
=
None
loader
=
None
if
self
.
char_type
==
"ch"
and
self
.
infer_img
:
if
self
.
char_type
==
"ch"
and
self
.
infer_img
:
image_shape
[
-
1
]
=
-
1
image_shape
[
-
1
]
=
-
1
if
self
.
tps
!=
None
:
if
self
.
tps
!=
None
:
...
@@ -98,8 +146,42 @@ class RecModel(object):
...
@@ -98,8 +146,42 @@ class RecModel(object):
)
)
image_shape
=
deepcopy
(
self
.
image_shape
)
image_shape
=
deepcopy
(
self
.
image_shape
)
image
=
fluid
.
data
(
name
=
'image'
,
shape
=
image_shape
,
dtype
=
'float32'
)
image
=
fluid
.
data
(
name
=
'image'
,
shape
=
image_shape
,
dtype
=
'float32'
)
labels
=
None
if
self
.
loss_type
==
"srn"
:
loader
=
None
encoder_word_pos
=
fluid
.
data
(
name
=
"encoder_word_pos"
,
shape
=
[
-
1
,
int
((
image_shape
[
-
2
]
/
8
)
*
(
image_shape
[
-
1
]
/
8
)),
1
],
dtype
=
"int64"
)
gsrm_word_pos
=
fluid
.
data
(
name
=
"gsrm_word_pos"
,
shape
=
[
-
1
,
self
.
max_text_length
,
1
],
dtype
=
"int64"
)
gsrm_slf_attn_bias1
=
fluid
.
data
(
name
=
"gsrm_slf_attn_bias1"
,
shape
=
[
-
1
,
self
.
num_heads
,
self
.
max_text_length
,
self
.
max_text_length
],
dtype
=
"float32"
)
gsrm_slf_attn_bias2
=
fluid
.
data
(
name
=
"gsrm_slf_attn_bias2"
,
shape
=
[
-
1
,
self
.
num_heads
,
self
.
max_text_length
,
self
.
max_text_length
],
dtype
=
"float32"
)
feed_list
=
[
image
,
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
]
labels
=
{
'encoder_word_pos'
:
encoder_word_pos
,
'gsrm_word_pos'
:
gsrm_word_pos
,
'gsrm_slf_attn_bias1'
:
gsrm_slf_attn_bias1
,
'gsrm_slf_attn_bias2'
:
gsrm_slf_attn_bias2
}
return
image
,
labels
,
loader
return
image
,
labels
,
loader
def
__call__
(
self
,
mode
):
def
__call__
(
self
,
mode
):
...
@@ -117,13 +199,27 @@ class RecModel(object):
...
@@ -117,13 +199,27 @@ class RecModel(object):
label
=
labels
[
'label_out'
]
label
=
labels
[
'label_out'
]
else
:
else
:
label
=
labels
[
'label'
]
label
=
labels
[
'label'
]
outputs
=
{
'total_loss'
:
loss
,
'decoded_out'
:
\
if
self
.
loss_type
==
'srn'
:
decoded_out
,
'label'
:
label
}
total_loss
,
img_loss
,
word_loss
=
self
.
loss
(
predicts
,
labels
)
outputs
=
{
'total_loss'
:
total_loss
,
'img_loss'
:
img_loss
,
'word_loss'
:
word_loss
,
'decoded_out'
:
decoded_out
,
'label'
:
label
}
else
:
outputs
=
{
'total_loss'
:
loss
,
'decoded_out'
:
\
decoded_out
,
'label'
:
label
}
return
loader
,
outputs
return
loader
,
outputs
elif
mode
==
"export"
:
elif
mode
==
"export"
:
predict
=
predicts
[
'predict'
]
predict
=
predicts
[
'predict'
]
if
self
.
loss_type
==
"ctc"
:
if
self
.
loss_type
==
"ctc"
:
predict
=
fluid
.
layers
.
softmax
(
predict
)
predict
=
fluid
.
layers
.
softmax
(
predict
)
if
self
.
loss_type
==
"srn"
:
raise
Exception
(
"Warning! SRN does not support export model currently"
)
return
[
image
,
{
'decoded_out'
:
decoded_out
,
'predicts'
:
predict
}]
return
[
image
,
{
'decoded_out'
:
decoded_out
,
'predicts'
:
predict
}]
else
:
else
:
predict
=
predicts
[
'predict'
]
predict
=
predicts
[
'predict'
]
...
...
ppocr/modeling/backbones/rec_resnet50_fpn.py
0 → 100755
浏览文件 @
612e8014
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
math
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid.param_attr
import
ParamAttr
__all__
=
[
"ResNet"
,
"ResNet18"
,
"ResNet34"
,
"ResNet50"
,
"ResNet101"
,
"ResNet152"
]
Trainable
=
True
w_nolr
=
fluid
.
ParamAttr
(
trainable
=
Trainable
)
train_parameters
=
{
"input_size"
:
[
3
,
224
,
224
],
"input_mean"
:
[
0.485
,
0.456
,
0.406
],
"input_std"
:
[
0.229
,
0.224
,
0.225
],
"learning_strategy"
:
{
"name"
:
"piecewise_decay"
,
"batch_size"
:
256
,
"epochs"
:
[
30
,
60
,
90
],
"steps"
:
[
0.1
,
0.01
,
0.001
,
0.0001
]
}
}
class
ResNet
():
def
__init__
(
self
,
params
):
self
.
layers
=
params
[
'layers'
]
self
.
params
=
train_parameters
def
__call__
(
self
,
input
):
layers
=
self
.
layers
supported_layers
=
[
18
,
34
,
50
,
101
,
152
]
assert
layers
in
supported_layers
,
\
"supported layers are {} but input layer is {}"
.
format
(
supported_layers
,
layers
)
if
layers
==
18
:
depth
=
[
2
,
2
,
2
,
2
]
elif
layers
==
34
or
layers
==
50
:
depth
=
[
3
,
4
,
6
,
3
]
elif
layers
==
101
:
depth
=
[
3
,
4
,
23
,
3
]
elif
layers
==
152
:
depth
=
[
3
,
8
,
36
,
3
]
stride_list
=
[(
2
,
2
),(
2
,
2
),(
1
,
1
),(
1
,
1
)]
num_filters
=
[
64
,
128
,
256
,
512
]
conv
=
self
.
conv_bn_layer
(
input
=
input
,
num_filters
=
64
,
filter_size
=
7
,
stride
=
2
,
act
=
'relu'
,
name
=
"conv1"
)
F
=
[]
if
layers
>=
50
:
for
block
in
range
(
len
(
depth
)):
for
i
in
range
(
depth
[
block
]):
if
layers
in
[
101
,
152
]
and
block
==
2
:
if
i
==
0
:
conv_name
=
"res"
+
str
(
block
+
2
)
+
"a"
else
:
conv_name
=
"res"
+
str
(
block
+
2
)
+
"b"
+
str
(
i
)
else
:
conv_name
=
"res"
+
str
(
block
+
2
)
+
chr
(
97
+
i
)
conv
=
self
.
bottleneck_block
(
input
=
conv
,
num_filters
=
num_filters
[
block
],
stride
=
stride_list
[
block
]
if
i
==
0
else
1
,
name
=
conv_name
)
F
.
append
(
conv
)
base
=
F
[
-
1
]
for
i
in
[
-
2
,
-
3
]:
b
,
c
,
w
,
h
=
F
[
i
].
shape
if
(
w
,
h
)
==
base
.
shape
[
2
:]:
base
=
base
else
:
base
=
fluid
.
layers
.
conv2d_transpose
(
input
=
base
,
num_filters
=
c
,
filter_size
=
4
,
stride
=
2
,
padding
=
1
,
act
=
None
,
param_attr
=
w_nolr
,
bias_attr
=
w_nolr
)
base
=
fluid
.
layers
.
batch_norm
(
base
,
act
=
"relu"
,
param_attr
=
w_nolr
,
bias_attr
=
w_nolr
)
base
=
fluid
.
layers
.
concat
([
base
,
F
[
i
]],
axis
=
1
)
base
=
fluid
.
layers
.
conv2d
(
base
,
num_filters
=
c
,
filter_size
=
1
,
param_attr
=
w_nolr
,
bias_attr
=
w_nolr
)
base
=
fluid
.
layers
.
conv2d
(
base
,
num_filters
=
c
,
filter_size
=
3
,
padding
=
1
,
param_attr
=
w_nolr
,
bias_attr
=
w_nolr
)
base
=
fluid
.
layers
.
batch_norm
(
base
,
act
=
"relu"
,
param_attr
=
w_nolr
,
bias_attr
=
w_nolr
)
base
=
fluid
.
layers
.
conv2d
(
base
,
num_filters
=
512
,
filter_size
=
1
,
bias_attr
=
w_nolr
,
param_attr
=
w_nolr
)
return
base
def
conv_bn_layer
(
self
,
input
,
num_filters
,
filter_size
,
stride
=
1
,
groups
=
1
,
act
=
None
,
name
=
None
):
conv
=
fluid
.
layers
.
conv2d
(
input
=
input
,
num_filters
=
num_filters
,
filter_size
=
2
if
stride
==
(
1
,
1
)
else
filter_size
,
dilation
=
2
if
stride
==
(
1
,
1
)
else
1
,
stride
=
stride
,
padding
=
(
filter_size
-
1
)
//
2
,
groups
=
groups
,
act
=
None
,
param_attr
=
ParamAttr
(
name
=
name
+
"_weights"
,
trainable
=
Trainable
),
bias_attr
=
False
,
name
=
name
+
'.conv2d.output.1'
)
if
name
==
"conv1"
:
bn_name
=
"bn_"
+
name
else
:
bn_name
=
"bn"
+
name
[
3
:]
return
fluid
.
layers
.
batch_norm
(
input
=
conv
,
act
=
act
,
name
=
bn_name
+
'.output.1'
,
param_attr
=
ParamAttr
(
name
=
bn_name
+
'_scale'
,
trainable
=
Trainable
),
bias_attr
=
ParamAttr
(
bn_name
+
'_offset'
,
trainable
=
Trainable
),
moving_mean_name
=
bn_name
+
'_mean'
,
moving_variance_name
=
bn_name
+
'_variance'
,
)
def
shortcut
(
self
,
input
,
ch_out
,
stride
,
is_first
,
name
):
ch_in
=
input
.
shape
[
1
]
if
ch_in
!=
ch_out
or
stride
!=
1
or
is_first
==
True
:
if
stride
==
(
1
,
1
):
return
self
.
conv_bn_layer
(
input
,
ch_out
,
1
,
1
,
name
=
name
)
else
:
#stride == (2,2)
return
self
.
conv_bn_layer
(
input
,
ch_out
,
1
,
stride
,
name
=
name
)
else
:
return
input
def
bottleneck_block
(
self
,
input
,
num_filters
,
stride
,
name
):
conv0
=
self
.
conv_bn_layer
(
input
=
input
,
num_filters
=
num_filters
,
filter_size
=
1
,
act
=
'relu'
,
name
=
name
+
"_branch2a"
)
conv1
=
self
.
conv_bn_layer
(
input
=
conv0
,
num_filters
=
num_filters
,
filter_size
=
3
,
stride
=
stride
,
act
=
'relu'
,
name
=
name
+
"_branch2b"
)
conv2
=
self
.
conv_bn_layer
(
input
=
conv1
,
num_filters
=
num_filters
*
4
,
filter_size
=
1
,
act
=
None
,
name
=
name
+
"_branch2c"
)
short
=
self
.
shortcut
(
input
,
num_filters
*
4
,
stride
,
is_first
=
False
,
name
=
name
+
"_branch1"
)
return
fluid
.
layers
.
elementwise_add
(
x
=
short
,
y
=
conv2
,
act
=
'relu'
,
name
=
name
+
".add.output.5"
)
def
basic_block
(
self
,
input
,
num_filters
,
stride
,
is_first
,
name
):
conv0
=
self
.
conv_bn_layer
(
input
=
input
,
num_filters
=
num_filters
,
filter_size
=
3
,
act
=
'relu'
,
stride
=
stride
,
name
=
name
+
"_branch2a"
)
conv1
=
self
.
conv_bn_layer
(
input
=
conv0
,
num_filters
=
num_filters
,
filter_size
=
3
,
act
=
None
,
name
=
name
+
"_branch2b"
)
short
=
self
.
shortcut
(
input
,
num_filters
,
stride
,
is_first
,
name
=
name
+
"_branch1"
)
return
fluid
.
layers
.
elementwise_add
(
x
=
short
,
y
=
conv1
,
act
=
'relu'
)
ppocr/modeling/heads/rec_srn_all_head.py
0 → 100755
浏览文件 @
612e8014
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
math
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid.param_attr
import
ParamAttr
import
numpy
as
np
from
.self_attention.model
import
wrap_encoder
from
.self_attention.model
import
wrap_encoder_forFeature
gradient_clip
=
10
class
SRNPredict
(
object
):
def
__init__
(
self
,
params
):
super
(
SRNPredict
,
self
).
__init__
()
self
.
char_num
=
params
[
'char_num'
]
self
.
max_length
=
params
[
'max_text_length'
]
self
.
num_heads
=
params
[
'num_heads'
]
self
.
num_encoder_TUs
=
params
[
'num_encoder_TUs'
]
self
.
num_decoder_TUs
=
params
[
'num_decoder_TUs'
]
self
.
hidden_dims
=
params
[
'hidden_dims'
]
def
pvam
(
self
,
inputs
,
others
):
b
,
c
,
h
,
w
=
inputs
.
shape
conv_features
=
fluid
.
layers
.
reshape
(
x
=
inputs
,
shape
=
[
-
1
,
c
,
h
*
w
])
conv_features
=
fluid
.
layers
.
transpose
(
x
=
conv_features
,
perm
=
[
0
,
2
,
1
])
#===== Transformer encoder =====
b
,
t
,
c
=
conv_features
.
shape
encoder_word_pos
=
others
[
"encoder_word_pos"
]
gsrm_word_pos
=
others
[
"gsrm_word_pos"
]
enc_inputs
=
[
conv_features
,
encoder_word_pos
,
None
]
word_features
=
wrap_encoder_forFeature
(
src_vocab_size
=-
1
,
max_length
=
t
,
n_layer
=
self
.
num_encoder_TUs
,
n_head
=
self
.
num_heads
,
d_key
=
int
(
self
.
hidden_dims
/
self
.
num_heads
),
d_value
=
int
(
self
.
hidden_dims
/
self
.
num_heads
),
d_model
=
self
.
hidden_dims
,
d_inner_hid
=
self
.
hidden_dims
,
prepostprocess_dropout
=
0.1
,
attention_dropout
=
0.1
,
relu_dropout
=
0.1
,
preprocess_cmd
=
"n"
,
postprocess_cmd
=
"da"
,
weight_sharing
=
True
,
enc_inputs
=
enc_inputs
,
)
fluid
.
clip
.
set_gradient_clip
(
fluid
.
clip
.
GradientClipByValue
(
gradient_clip
))
#===== Parallel Visual Attention Module =====
b
,
t
,
c
=
word_features
.
shape
word_features
=
fluid
.
layers
.
fc
(
word_features
,
c
,
num_flatten_dims
=
2
)
word_features_
=
fluid
.
layers
.
reshape
(
word_features
,
[
-
1
,
1
,
t
,
c
])
word_features_
=
fluid
.
layers
.
expand
(
word_features_
,
[
1
,
self
.
max_length
,
1
,
1
])
word_pos_feature
=
fluid
.
layers
.
embedding
(
gsrm_word_pos
,
[
self
.
max_length
,
c
])
word_pos_
=
fluid
.
layers
.
reshape
(
word_pos_feature
,
[
-
1
,
self
.
max_length
,
1
,
c
])
word_pos_
=
fluid
.
layers
.
expand
(
word_pos_
,
[
1
,
1
,
t
,
1
])
temp
=
fluid
.
layers
.
elementwise_add
(
word_features_
,
word_pos_
,
act
=
'tanh'
)
attention_weight
=
fluid
.
layers
.
fc
(
input
=
temp
,
size
=
1
,
num_flatten_dims
=
3
,
bias_attr
=
False
)
attention_weight
=
fluid
.
layers
.
reshape
(
x
=
attention_weight
,
shape
=
[
-
1
,
self
.
max_length
,
t
])
attention_weight
=
fluid
.
layers
.
softmax
(
input
=
attention_weight
,
axis
=-
1
)
pvam_features
=
fluid
.
layers
.
matmul
(
attention_weight
,
word_features
)
#[b, max_length, c]
return
pvam_features
def
gsrm
(
self
,
pvam_features
,
others
):
#===== GSRM Visual-to-semantic embedding block =====
b
,
t
,
c
=
pvam_features
.
shape
word_out
=
fluid
.
layers
.
fc
(
input
=
fluid
.
layers
.
reshape
(
pvam_features
,
[
-
1
,
c
]),
size
=
self
.
char_num
,
act
=
"softmax"
)
#word_out.stop_gradient = True
word_ids
=
fluid
.
layers
.
argmax
(
word_out
,
axis
=
1
)
word_ids
.
stop_gradient
=
True
word_ids
=
fluid
.
layers
.
reshape
(
x
=
word_ids
,
shape
=
[
-
1
,
t
,
1
])
#===== GSRM Semantic reasoning block =====
"""
This module is achieved through bi-transformers,
ngram_feature1 is the froward one, ngram_fetaure2 is the backward one
"""
pad_idx
=
self
.
char_num
gsrm_word_pos
=
others
[
"gsrm_word_pos"
]
gsrm_slf_attn_bias1
=
others
[
"gsrm_slf_attn_bias1"
]
gsrm_slf_attn_bias2
=
others
[
"gsrm_slf_attn_bias2"
]
def
prepare_bi
(
word_ids
):
"""
prepare bi for gsrm
word1 for forward; word2 for backward
"""
word1
=
fluid
.
layers
.
cast
(
word_ids
,
"float32"
)
word1
=
fluid
.
layers
.
pad
(
word1
,
[
0
,
0
,
1
,
0
,
0
,
0
],
pad_value
=
1.0
*
pad_idx
)
word1
=
fluid
.
layers
.
cast
(
word1
,
"int64"
)
word1
=
word1
[:,
:
-
1
,
:]
word2
=
word_ids
return
word1
,
word2
word1
,
word2
=
prepare_bi
(
word_ids
)
word1
.
stop_gradient
=
True
word2
.
stop_gradient
=
True
enc_inputs_1
=
[
word1
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
]
enc_inputs_2
=
[
word2
,
gsrm_word_pos
,
gsrm_slf_attn_bias2
]
gsrm_feature1
=
wrap_encoder
(
src_vocab_size
=
self
.
char_num
+
1
,
max_length
=
self
.
max_length
,
n_layer
=
self
.
num_decoder_TUs
,
n_head
=
self
.
num_heads
,
d_key
=
int
(
self
.
hidden_dims
/
self
.
num_heads
),
d_value
=
int
(
self
.
hidden_dims
/
self
.
num_heads
),
d_model
=
self
.
hidden_dims
,
d_inner_hid
=
self
.
hidden_dims
,
prepostprocess_dropout
=
0.1
,
attention_dropout
=
0.1
,
relu_dropout
=
0.1
,
preprocess_cmd
=
"n"
,
postprocess_cmd
=
"da"
,
weight_sharing
=
True
,
enc_inputs
=
enc_inputs_1
,
)
gsrm_feature2
=
wrap_encoder
(
src_vocab_size
=
self
.
char_num
+
1
,
max_length
=
self
.
max_length
,
n_layer
=
self
.
num_decoder_TUs
,
n_head
=
self
.
num_heads
,
d_key
=
int
(
self
.
hidden_dims
/
self
.
num_heads
),
d_value
=
int
(
self
.
hidden_dims
/
self
.
num_heads
),
d_model
=
self
.
hidden_dims
,
d_inner_hid
=
self
.
hidden_dims
,
prepostprocess_dropout
=
0.1
,
attention_dropout
=
0.1
,
relu_dropout
=
0.1
,
preprocess_cmd
=
"n"
,
postprocess_cmd
=
"da"
,
weight_sharing
=
True
,
enc_inputs
=
enc_inputs_2
,
)
gsrm_feature2
=
fluid
.
layers
.
pad
(
gsrm_feature2
,
[
0
,
0
,
0
,
1
,
0
,
0
],
pad_value
=
0.
)
gsrm_feature2
=
gsrm_feature2
[:,
1
:,
]
gsrm_features
=
gsrm_feature1
+
gsrm_feature2
b
,
t
,
c
=
gsrm_features
.
shape
gsrm_out
=
fluid
.
layers
.
matmul
(
x
=
gsrm_features
,
y
=
fluid
.
default_main_program
().
global_block
().
var
(
"src_word_emb_table"
),
transpose_y
=
True
)
b
,
t
,
c
=
gsrm_out
.
shape
gsrm_out
=
fluid
.
layers
.
softmax
(
input
=
fluid
.
layers
.
reshape
(
gsrm_out
,
[
-
1
,
c
]))
return
gsrm_features
,
word_out
,
gsrm_out
def
vsfd
(
self
,
pvam_features
,
gsrm_features
):
#===== Visual-Semantic Fusion Decoder Module =====
b
,
t
,
c1
=
pvam_features
.
shape
b
,
t
,
c2
=
gsrm_features
.
shape
combine_features_
=
fluid
.
layers
.
concat
(
[
pvam_features
,
gsrm_features
],
axis
=
2
)
img_comb_features_
=
fluid
.
layers
.
reshape
(
x
=
combine_features_
,
shape
=
[
-
1
,
c1
+
c2
])
img_comb_features_map
=
fluid
.
layers
.
fc
(
input
=
img_comb_features_
,
size
=
c1
,
act
=
"sigmoid"
)
img_comb_features_map
=
fluid
.
layers
.
reshape
(
x
=
img_comb_features_map
,
shape
=
[
-
1
,
t
,
c1
])
combine_features
=
img_comb_features_map
*
pvam_features
+
(
1.0
-
img_comb_features_map
)
*
gsrm_features
img_comb_features
=
fluid
.
layers
.
reshape
(
x
=
combine_features
,
shape
=
[
-
1
,
c1
])
fc_out
=
fluid
.
layers
.
fc
(
input
=
img_comb_features
,
size
=
self
.
char_num
,
act
=
"softmax"
)
return
fc_out
def
__call__
(
self
,
inputs
,
others
,
mode
=
None
):
pvam_features
=
self
.
pvam
(
inputs
,
others
)
gsrm_features
,
word_out
,
gsrm_out
=
self
.
gsrm
(
pvam_features
,
others
)
final_out
=
self
.
vsfd
(
pvam_features
,
gsrm_features
)
_
,
decoded_out
=
fluid
.
layers
.
topk
(
input
=
final_out
,
k
=
1
)
predicts
=
{
'predict'
:
final_out
,
'decoded_out'
:
decoded_out
,
'word_out'
:
word_out
,
'gsrm_out'
:
gsrm_out
}
return
predicts
ppocr/modeling/heads/self_attention/__init__.py
0 → 100644
浏览文件 @
612e8014
ppocr/modeling/heads/self_attention/model.py
0 → 100644
浏览文件 @
612e8014
from
functools
import
partial
import
numpy
as
np
import
paddle.fluid
as
fluid
import
paddle.fluid.layers
as
layers
# Set seed for CE
dropout_seed
=
None
def
wrap_layer_with_block
(
layer
,
block_idx
):
"""
Make layer define support indicating block, by which we can add layers
to other blocks within current block. This will make it easy to define
cache among while loop.
"""
class
BlockGuard
(
object
):
"""
BlockGuard class.
BlockGuard class is used to switch to the given block in a program by
using the Python `with` keyword.
"""
def
__init__
(
self
,
block_idx
=
None
,
main_program
=
None
):
self
.
main_program
=
fluid
.
default_main_program
(
)
if
main_program
is
None
else
main_program
self
.
old_block_idx
=
self
.
main_program
.
current_block
().
idx
self
.
new_block_idx
=
block_idx
def
__enter__
(
self
):
self
.
main_program
.
current_block_idx
=
self
.
new_block_idx
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
self
.
main_program
.
current_block_idx
=
self
.
old_block_idx
if
exc_type
is
not
None
:
return
False
# re-raise exception
return
True
def
layer_wrapper
(
*
args
,
**
kwargs
):
with
BlockGuard
(
block_idx
):
return
layer
(
*
args
,
**
kwargs
)
return
layer_wrapper
def
position_encoding_init
(
n_position
,
d_pos_vec
):
"""
Generate the initial values for the sinusoid position encoding table.
"""
channels
=
d_pos_vec
position
=
np
.
arange
(
n_position
)
num_timescales
=
channels
//
2
log_timescale_increment
=
(
np
.
log
(
float
(
1e4
)
/
float
(
1
))
/
(
num_timescales
-
1
))
inv_timescales
=
np
.
exp
(
np
.
arange
(
num_timescales
))
*
-
log_timescale_increment
scaled_time
=
np
.
expand_dims
(
position
,
1
)
*
np
.
expand_dims
(
inv_timescales
,
0
)
signal
=
np
.
concatenate
([
np
.
sin
(
scaled_time
),
np
.
cos
(
scaled_time
)],
axis
=
1
)
signal
=
np
.
pad
(
signal
,
[[
0
,
0
],
[
0
,
np
.
mod
(
channels
,
2
)]],
'constant'
)
position_enc
=
signal
return
position_enc
.
astype
(
"float32"
)
def
multi_head_attention
(
queries
,
keys
,
values
,
attn_bias
,
d_key
,
d_value
,
d_model
,
n_head
=
1
,
dropout_rate
=
0.
,
cache
=
None
,
gather_idx
=
None
,
static_kv
=
False
):
"""
Multi-Head Attention. Note that attn_bias is added to the logit before
computing softmax activiation to mask certain selected positions so that
they will not considered in attention weights.
"""
keys
=
queries
if
keys
is
None
else
keys
values
=
keys
if
values
is
None
else
values
if
not
(
len
(
queries
.
shape
)
==
len
(
keys
.
shape
)
==
len
(
values
.
shape
)
==
3
):
raise
ValueError
(
"Inputs: quries, keys and values should all be 3-D tensors."
)
def
__compute_qkv
(
queries
,
keys
,
values
,
n_head
,
d_key
,
d_value
):
"""
Add linear projection to queries, keys, and values.
"""
q
=
layers
.
fc
(
input
=
queries
,
size
=
d_key
*
n_head
,
bias_attr
=
False
,
num_flatten_dims
=
2
)
# For encoder-decoder attention in inference, insert the ops and vars
# into global block to use as cache among beam search.
fc_layer
=
wrap_layer_with_block
(
layers
.
fc
,
fluid
.
default_main_program
().
current_block
()
.
parent_idx
)
if
cache
is
not
None
and
static_kv
else
layers
.
fc
k
=
fc_layer
(
input
=
keys
,
size
=
d_key
*
n_head
,
bias_attr
=
False
,
num_flatten_dims
=
2
)
v
=
fc_layer
(
input
=
values
,
size
=
d_value
*
n_head
,
bias_attr
=
False
,
num_flatten_dims
=
2
)
return
q
,
k
,
v
def
__split_heads_qkv
(
queries
,
keys
,
values
,
n_head
,
d_key
,
d_value
):
"""
Reshape input tensors at the last dimension to split multi-heads
and then transpose. Specifically, transform the input tensor with shape
[bs, max_sequence_length, n_head * hidden_dim] to the output tensor
with shape [bs, n_head, max_sequence_length, hidden_dim].
"""
# The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size.
reshaped_q
=
layers
.
reshape
(
x
=
queries
,
shape
=
[
0
,
0
,
n_head
,
d_key
],
inplace
=
True
)
# permuate the dimensions into:
# [batch_size, n_head, max_sequence_len, hidden_size_per_head]
q
=
layers
.
transpose
(
x
=
reshaped_q
,
perm
=
[
0
,
2
,
1
,
3
])
# For encoder-decoder attention in inference, insert the ops and vars
# into global block to use as cache among beam search.
reshape_layer
=
wrap_layer_with_block
(
layers
.
reshape
,
fluid
.
default_main_program
().
current_block
()
.
parent_idx
)
if
cache
is
not
None
and
static_kv
else
layers
.
reshape
transpose_layer
=
wrap_layer_with_block
(
layers
.
transpose
,
fluid
.
default_main_program
().
current_block
().
parent_idx
)
if
cache
is
not
None
and
static_kv
else
layers
.
transpose
reshaped_k
=
reshape_layer
(
x
=
keys
,
shape
=
[
0
,
0
,
n_head
,
d_key
],
inplace
=
True
)
k
=
transpose_layer
(
x
=
reshaped_k
,
perm
=
[
0
,
2
,
1
,
3
])
reshaped_v
=
reshape_layer
(
x
=
values
,
shape
=
[
0
,
0
,
n_head
,
d_value
],
inplace
=
True
)
v
=
transpose_layer
(
x
=
reshaped_v
,
perm
=
[
0
,
2
,
1
,
3
])
if
cache
is
not
None
:
# only for faster inference
if
static_kv
:
# For encoder-decoder attention in inference
cache_k
,
cache_v
=
cache
[
"static_k"
],
cache
[
"static_v"
]
# To init the static_k and static_v in cache.
# Maybe we can use condition_op(if_else) to do these at the first
# step in while loop to replace these, however it might be less
# efficient.
static_cache_init
=
wrap_layer_with_block
(
layers
.
assign
,
fluid
.
default_main_program
().
current_block
().
parent_idx
)
static_cache_init
(
k
,
cache_k
)
static_cache_init
(
v
,
cache_v
)
else
:
# For decoder self-attention in inference
cache_k
,
cache_v
=
cache
[
"k"
],
cache
[
"v"
]
# gather cell states corresponding to selected parent
select_k
=
layers
.
gather
(
cache_k
,
index
=
gather_idx
)
select_v
=
layers
.
gather
(
cache_v
,
index
=
gather_idx
)
if
not
static_kv
:
# For self attention in inference, use cache and concat time steps.
select_k
=
layers
.
concat
([
select_k
,
k
],
axis
=
2
)
select_v
=
layers
.
concat
([
select_v
,
v
],
axis
=
2
)
# update cell states(caches) cached in global block
layers
.
assign
(
select_k
,
cache_k
)
layers
.
assign
(
select_v
,
cache_v
)
return
q
,
select_k
,
select_v
return
q
,
k
,
v
def
__combine_heads
(
x
):
"""
Transpose and then reshape the last two dimensions of inpunt tensor x
so that it becomes one dimension, which is reverse to __split_heads.
"""
if
len
(
x
.
shape
)
!=
4
:
raise
ValueError
(
"Input(x) should be a 4-D Tensor."
)
trans_x
=
layers
.
transpose
(
x
,
perm
=
[
0
,
2
,
1
,
3
])
# The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size.
return
layers
.
reshape
(
x
=
trans_x
,
shape
=
[
0
,
0
,
trans_x
.
shape
[
2
]
*
trans_x
.
shape
[
3
]],
inplace
=
True
)
def
scaled_dot_product_attention
(
q
,
k
,
v
,
attn_bias
,
d_key
,
dropout_rate
):
"""
Scaled Dot-Product Attention
"""
# print(q)
# print(k)
product
=
layers
.
matmul
(
x
=
q
,
y
=
k
,
transpose_y
=
True
,
alpha
=
d_key
**-
0.5
)
if
attn_bias
:
product
+=
attn_bias
weights
=
layers
.
softmax
(
product
)
if
dropout_rate
:
weights
=
layers
.
dropout
(
weights
,
dropout_prob
=
dropout_rate
,
seed
=
dropout_seed
,
is_test
=
False
)
out
=
layers
.
matmul
(
weights
,
v
)
return
out
q
,
k
,
v
=
__compute_qkv
(
queries
,
keys
,
values
,
n_head
,
d_key
,
d_value
)
q
,
k
,
v
=
__split_heads_qkv
(
q
,
k
,
v
,
n_head
,
d_key
,
d_value
)
ctx_multiheads
=
scaled_dot_product_attention
(
q
,
k
,
v
,
attn_bias
,
d_model
,
dropout_rate
)
out
=
__combine_heads
(
ctx_multiheads
)
# Project back to the model size.
proj_out
=
layers
.
fc
(
input
=
out
,
size
=
d_model
,
bias_attr
=
False
,
num_flatten_dims
=
2
)
return
proj_out
def
positionwise_feed_forward
(
x
,
d_inner_hid
,
d_hid
,
dropout_rate
):
"""
Position-wise Feed-Forward Networks.
This module consists of two linear transformations with a ReLU activation
in between, which is applied to each position separately and identically.
"""
hidden
=
layers
.
fc
(
input
=
x
,
size
=
d_inner_hid
,
num_flatten_dims
=
2
,
act
=
"relu"
)
if
dropout_rate
:
hidden
=
layers
.
dropout
(
hidden
,
dropout_prob
=
dropout_rate
,
seed
=
dropout_seed
,
is_test
=
False
)
out
=
layers
.
fc
(
input
=
hidden
,
size
=
d_hid
,
num_flatten_dims
=
2
)
return
out
def
pre_post_process_layer
(
prev_out
,
out
,
process_cmd
,
dropout_rate
=
0.
):
"""
Add residual connection, layer normalization and droput to the out tensor
optionally according to the value of process_cmd.
This will be used before or after multi-head attention and position-wise
feed-forward networks.
"""
for
cmd
in
process_cmd
:
if
cmd
==
"a"
:
# add residual connection
out
=
out
+
prev_out
if
prev_out
else
out
elif
cmd
==
"n"
:
# add layer normalization
out
=
layers
.
layer_norm
(
out
,
begin_norm_axis
=
len
(
out
.
shape
)
-
1
,
param_attr
=
fluid
.
initializer
.
Constant
(
1.
),
bias_attr
=
fluid
.
initializer
.
Constant
(
0.
))
elif
cmd
==
"d"
:
# add dropout
if
dropout_rate
:
out
=
layers
.
dropout
(
out
,
dropout_prob
=
dropout_rate
,
seed
=
dropout_seed
,
is_test
=
False
)
return
out
pre_process_layer
=
partial
(
pre_post_process_layer
,
None
)
post_process_layer
=
pre_post_process_layer
def
prepare_encoder
(
src_word
,
#[b,t,c]
src_pos
,
src_vocab_size
,
src_emb_dim
,
src_max_len
,
dropout_rate
=
0.
,
bos_idx
=
0
,
word_emb_param_name
=
None
,
pos_enc_param_name
=
None
):
"""Add word embeddings and position encodings.
The output tensor has a shape of:
[batch_size, max_src_length_in_batch, d_model].
This module is used at the bottom of the encoder stacks.
"""
src_word_emb
=
src_word
#layers.concat(res,axis=1)
src_word_emb
=
layers
.
cast
(
src_word_emb
,
'float32'
)
# print("src_word_emb",src_word_emb)
src_word_emb
=
layers
.
scale
(
x
=
src_word_emb
,
scale
=
src_emb_dim
**
0.5
)
src_pos_enc
=
layers
.
embedding
(
src_pos
,
size
=
[
src_max_len
,
src_emb_dim
],
param_attr
=
fluid
.
ParamAttr
(
name
=
pos_enc_param_name
,
trainable
=
False
))
src_pos_enc
.
stop_gradient
=
True
enc_input
=
src_word_emb
+
src_pos_enc
return
layers
.
dropout
(
enc_input
,
dropout_prob
=
dropout_rate
,
seed
=
dropout_seed
,
is_test
=
False
)
if
dropout_rate
else
enc_input
def
prepare_decoder
(
src_word
,
src_pos
,
src_vocab_size
,
src_emb_dim
,
src_max_len
,
dropout_rate
=
0.
,
bos_idx
=
0
,
word_emb_param_name
=
None
,
pos_enc_param_name
=
None
):
"""Add word embeddings and position encodings.
The output tensor has a shape of:
[batch_size, max_src_length_in_batch, d_model].
This module is used at the bottom of the encoder stacks.
"""
src_word_emb
=
layers
.
embedding
(
src_word
,
size
=
[
src_vocab_size
,
src_emb_dim
],
padding_idx
=
bos_idx
,
# set embedding of bos to 0
param_attr
=
fluid
.
ParamAttr
(
name
=
word_emb_param_name
,
initializer
=
fluid
.
initializer
.
Normal
(
0.
,
src_emb_dim
**-
0.5
)))
# print("target_word_emb",src_word_emb)
src_word_emb
=
layers
.
scale
(
x
=
src_word_emb
,
scale
=
src_emb_dim
**
0.5
)
src_pos_enc
=
layers
.
embedding
(
src_pos
,
size
=
[
src_max_len
,
src_emb_dim
],
param_attr
=
fluid
.
ParamAttr
(
name
=
pos_enc_param_name
,
trainable
=
False
))
src_pos_enc
.
stop_gradient
=
True
enc_input
=
src_word_emb
+
src_pos_enc
return
layers
.
dropout
(
enc_input
,
dropout_prob
=
dropout_rate
,
seed
=
dropout_seed
,
is_test
=
False
)
if
dropout_rate
else
enc_input
# prepare_encoder = partial(
# prepare_encoder_decoder, pos_enc_param_name=pos_enc_param_names[0])
# prepare_decoder = partial(
# prepare_encoder_decoder, pos_enc_param_name=pos_enc_param_names[1])
def
encoder_layer
(
enc_input
,
attn_bias
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
=
"n"
,
postprocess_cmd
=
"da"
):
"""The encoder layers that can be stacked to form a deep encoder.
This module consits of a multi-head (self) attention followed by
position-wise feed-forward networks and both the two components companied
with the post_process_layer to add residual connection, layer normalization
and droput.
"""
attn_output
=
multi_head_attention
(
pre_process_layer
(
enc_input
,
preprocess_cmd
,
prepostprocess_dropout
),
None
,
None
,
attn_bias
,
d_key
,
d_value
,
d_model
,
n_head
,
attention_dropout
)
attn_output
=
post_process_layer
(
enc_input
,
attn_output
,
postprocess_cmd
,
prepostprocess_dropout
)
ffd_output
=
positionwise_feed_forward
(
pre_process_layer
(
attn_output
,
preprocess_cmd
,
prepostprocess_dropout
),
d_inner_hid
,
d_model
,
relu_dropout
)
return
post_process_layer
(
attn_output
,
ffd_output
,
postprocess_cmd
,
prepostprocess_dropout
)
def
encoder
(
enc_input
,
attn_bias
,
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
=
"n"
,
postprocess_cmd
=
"da"
):
"""
The encoder is composed of a stack of identical layers returned by calling
encoder_layer.
"""
for
i
in
range
(
n_layer
):
enc_output
=
encoder_layer
(
enc_input
,
attn_bias
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
,
postprocess_cmd
,
)
enc_input
=
enc_output
enc_output
=
pre_process_layer
(
enc_output
,
preprocess_cmd
,
prepostprocess_dropout
)
return
enc_output
def
decoder_layer
(
dec_input
,
enc_output
,
slf_attn_bias
,
dec_enc_attn_bias
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
,
postprocess_cmd
,
cache
=
None
,
gather_idx
=
None
):
""" The layer to be stacked in decoder part.
The structure of this module is similar to that in the encoder part except
a multi-head attention is added to implement encoder-decoder attention.
"""
slf_attn_output
=
multi_head_attention
(
pre_process_layer
(
dec_input
,
preprocess_cmd
,
prepostprocess_dropout
),
None
,
None
,
slf_attn_bias
,
d_key
,
d_value
,
d_model
,
n_head
,
attention_dropout
,
cache
=
cache
,
gather_idx
=
gather_idx
)
slf_attn_output
=
post_process_layer
(
dec_input
,
slf_attn_output
,
postprocess_cmd
,
prepostprocess_dropout
,
)
enc_attn_output
=
multi_head_attention
(
pre_process_layer
(
slf_attn_output
,
preprocess_cmd
,
prepostprocess_dropout
),
enc_output
,
enc_output
,
dec_enc_attn_bias
,
d_key
,
d_value
,
d_model
,
n_head
,
attention_dropout
,
cache
=
cache
,
gather_idx
=
gather_idx
,
static_kv
=
True
)
enc_attn_output
=
post_process_layer
(
slf_attn_output
,
enc_attn_output
,
postprocess_cmd
,
prepostprocess_dropout
,
)
ffd_output
=
positionwise_feed_forward
(
pre_process_layer
(
enc_attn_output
,
preprocess_cmd
,
prepostprocess_dropout
),
d_inner_hid
,
d_model
,
relu_dropout
,
)
dec_output
=
post_process_layer
(
enc_attn_output
,
ffd_output
,
postprocess_cmd
,
prepostprocess_dropout
,
)
return
dec_output
def
decoder
(
dec_input
,
enc_output
,
dec_slf_attn_bias
,
dec_enc_attn_bias
,
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
,
postprocess_cmd
,
caches
=
None
,
gather_idx
=
None
):
"""
The decoder is composed of a stack of identical decoder_layer layers.
"""
for
i
in
range
(
n_layer
):
dec_output
=
decoder_layer
(
dec_input
,
enc_output
,
dec_slf_attn_bias
,
dec_enc_attn_bias
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
,
postprocess_cmd
,
cache
=
None
if
caches
is
None
else
caches
[
i
],
gather_idx
=
gather_idx
)
dec_input
=
dec_output
dec_output
=
pre_process_layer
(
dec_output
,
preprocess_cmd
,
prepostprocess_dropout
)
return
dec_output
def
make_all_inputs
(
input_fields
):
"""
Define the input data layers for the transformer model.
"""
inputs
=
[]
for
input_field
in
input_fields
:
input_var
=
layers
.
data
(
name
=
input_field
,
shape
=
input_descs
[
input_field
][
0
],
dtype
=
input_descs
[
input_field
][
1
],
lod_level
=
input_descs
[
input_field
][
2
]
if
len
(
input_descs
[
input_field
])
==
3
else
0
,
append_batch_size
=
False
)
inputs
.
append
(
input_var
)
return
inputs
def
make_all_py_reader_inputs
(
input_fields
,
is_test
=
False
):
reader
=
layers
.
py_reader
(
capacity
=
20
,
name
=
"test_reader"
if
is_test
else
"train_reader"
,
shapes
=
[
input_descs
[
input_field
][
0
]
for
input_field
in
input_fields
],
dtypes
=
[
input_descs
[
input_field
][
1
]
for
input_field
in
input_fields
],
lod_levels
=
[
input_descs
[
input_field
][
2
]
if
len
(
input_descs
[
input_field
])
==
3
else
0
for
input_field
in
input_fields
])
return
layers
.
read_file
(
reader
),
reader
def
transformer
(
src_vocab_size
,
trg_vocab_size
,
max_length
,
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
,
postprocess_cmd
,
weight_sharing
,
label_smooth_eps
,
bos_idx
=
0
,
use_py_reader
=
False
,
is_test
=
False
):
if
weight_sharing
:
assert
src_vocab_size
==
trg_vocab_size
,
(
"Vocabularies in source and target should be same for weight sharing."
)
data_input_names
=
encoder_data_input_fields
+
\
decoder_data_input_fields
[:
-
1
]
+
label_data_input_fields
if
use_py_reader
:
all_inputs
,
reader
=
make_all_py_reader_inputs
(
data_input_names
,
is_test
)
else
:
all_inputs
=
make_all_inputs
(
data_input_names
)
# print("all inputs",all_inputs)
enc_inputs_len
=
len
(
encoder_data_input_fields
)
dec_inputs_len
=
len
(
decoder_data_input_fields
[:
-
1
])
enc_inputs
=
all_inputs
[
0
:
enc_inputs_len
]
dec_inputs
=
all_inputs
[
enc_inputs_len
:
enc_inputs_len
+
dec_inputs_len
]
label
=
all_inputs
[
-
2
]
weights
=
all_inputs
[
-
1
]
enc_output
=
wrap_encoder
(
src_vocab_size
,
64
,
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
,
postprocess_cmd
,
weight_sharing
,
enc_inputs
)
predict
=
wrap_decoder
(
trg_vocab_size
,
max_length
,
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
,
postprocess_cmd
,
weight_sharing
,
dec_inputs
,
enc_output
,
)
# Padding index do not contribute to the total loss. The weights is used to
# cancel padding index in calculating the loss.
if
label_smooth_eps
:
label
=
layers
.
label_smooth
(
label
=
layers
.
one_hot
(
input
=
label
,
depth
=
trg_vocab_size
),
epsilon
=
label_smooth_eps
)
cost
=
layers
.
softmax_with_cross_entropy
(
logits
=
predict
,
label
=
label
,
soft_label
=
True
if
label_smooth_eps
else
False
)
weighted_cost
=
cost
*
weights
sum_cost
=
layers
.
reduce_sum
(
weighted_cost
)
token_num
=
layers
.
reduce_sum
(
weights
)
token_num
.
stop_gradient
=
True
avg_cost
=
sum_cost
/
token_num
return
sum_cost
,
avg_cost
,
predict
,
token_num
,
reader
if
use_py_reader
else
None
def
wrap_encoder_forFeature
(
src_vocab_size
,
max_length
,
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
,
postprocess_cmd
,
weight_sharing
,
enc_inputs
=
None
,
bos_idx
=
0
):
"""
The wrapper assembles together all needed layers for the encoder.
img, src_pos, src_slf_attn_bias = enc_inputs
img
"""
if
enc_inputs
is
None
:
# This is used to implement independent encoder program in inference.
conv_features
,
src_pos
,
src_slf_attn_bias
=
make_all_inputs
(
encoder_data_input_fields
)
else
:
conv_features
,
src_pos
,
src_slf_attn_bias
=
enc_inputs
#
b
,
t
,
c
=
conv_features
.
shape
#"""
# insert cnn
#"""
#import basemodel
# feat = basemodel.resnet_50(img)
# mycrnn = basemodel.CRNN()
# feat = mycrnn.ocr_convs(img,use_cudnn=TrainTaskConfig.use_gpu)
# b, c, w, h = feat.shape
# src_word = layers.reshape(feat, shape=[-1, c, w * h])
#myconv8 = basemodel.conv8()
#feat = myconv8.net(img )
#b , c, h, w = feat.shape#h=6
#print(feat)
#layers.Print(feat,message="conv_feat",summarize=10)
#feat =layers.conv2d(feat,c,filter_size =[4 , 1],act="relu")
#feat = layers.pool2d(feat,pool_stride=(3,1),pool_size=(3,1))
#src_word = layers.squeeze(feat,axes=[2]) #src_word [-1,c,ww]
#feat = layers.transpose(feat, [0,3,1,2])
#src_word = layers.reshape(feat,[-1,w, c*h])
#src_word = layers.im2sequence(
# input=feat,
# stride=[1, 1],
# filter_size=[feat.shape[2], 1])
#layers.Print(src_word,message="src_word",summarize=10)
# print('feat',feat)
#print("src_word",src_word)
enc_input
=
prepare_encoder
(
conv_features
,
src_pos
,
src_vocab_size
,
d_model
,
max_length
,
prepostprocess_dropout
,
bos_idx
=
bos_idx
,
word_emb_param_name
=
"src_word_emb_table"
)
enc_output
=
encoder
(
enc_input
,
src_slf_attn_bias
,
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
,
postprocess_cmd
,
)
return
enc_output
def
wrap_encoder
(
src_vocab_size
,
max_length
,
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
,
postprocess_cmd
,
weight_sharing
,
enc_inputs
=
None
,
bos_idx
=
0
):
"""
The wrapper assembles together all needed layers for the encoder.
img, src_pos, src_slf_attn_bias = enc_inputs
img
"""
if
enc_inputs
is
None
:
# This is used to implement independent encoder program in inference.
src_word
,
src_pos
,
src_slf_attn_bias
=
make_all_inputs
(
encoder_data_input_fields
)
else
:
src_word
,
src_pos
,
src_slf_attn_bias
=
enc_inputs
#
#"""
# insert cnn
#"""
#import basemodel
# feat = basemodel.resnet_50(img)
# mycrnn = basemodel.CRNN()
# feat = mycrnn.ocr_convs(img,use_cudnn=TrainTaskConfig.use_gpu)
# b, c, w, h = feat.shape
# src_word = layers.reshape(feat, shape=[-1, c, w * h])
#myconv8 = basemodel.conv8()
#feat = myconv8.net(img )
#b , c, h, w = feat.shape#h=6
#print(feat)
#layers.Print(feat,message="conv_feat",summarize=10)
#feat =layers.conv2d(feat,c,filter_size =[4 , 1],act="relu")
#feat = layers.pool2d(feat,pool_stride=(3,1),pool_size=(3,1))
#src_word = layers.squeeze(feat,axes=[2]) #src_word [-1,c,ww]
#feat = layers.transpose(feat, [0,3,1,2])
#src_word = layers.reshape(feat,[-1,w, c*h])
#src_word = layers.im2sequence(
# input=feat,
# stride=[1, 1],
# filter_size=[feat.shape[2], 1])
#layers.Print(src_word,message="src_word",summarize=10)
# print('feat',feat)
#print("src_word",src_word)
enc_input
=
prepare_decoder
(
src_word
,
src_pos
,
src_vocab_size
,
d_model
,
max_length
,
prepostprocess_dropout
,
bos_idx
=
bos_idx
,
word_emb_param_name
=
"src_word_emb_table"
)
enc_output
=
encoder
(
enc_input
,
src_slf_attn_bias
,
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
,
postprocess_cmd
,
)
return
enc_output
def
wrap_decoder
(
trg_vocab_size
,
max_length
,
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
,
postprocess_cmd
,
weight_sharing
,
dec_inputs
=
None
,
enc_output
=
None
,
caches
=
None
,
gather_idx
=
None
,
bos_idx
=
0
):
"""
The wrapper assembles together all needed layers for the decoder.
"""
if
dec_inputs
is
None
:
# This is used to implement independent decoder program in inference.
trg_word
,
trg_pos
,
trg_slf_attn_bias
,
trg_src_attn_bias
,
enc_output
=
\
make_all_inputs
(
decoder_data_input_fields
)
else
:
trg_word
,
trg_pos
,
trg_slf_attn_bias
,
trg_src_attn_bias
=
dec_inputs
dec_input
=
prepare_decoder
(
trg_word
,
trg_pos
,
trg_vocab_size
,
d_model
,
max_length
,
prepostprocess_dropout
,
bos_idx
=
bos_idx
,
word_emb_param_name
=
"src_word_emb_table"
if
weight_sharing
else
"trg_word_emb_table"
)
dec_output
=
decoder
(
dec_input
,
enc_output
,
trg_slf_attn_bias
,
trg_src_attn_bias
,
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
,
postprocess_cmd
,
caches
=
caches
,
gather_idx
=
gather_idx
)
return
dec_output
# Reshape to 2D tensor to use GEMM instead of BatchedGEMM
dec_output
=
layers
.
reshape
(
dec_output
,
shape
=
[
-
1
,
dec_output
.
shape
[
-
1
]],
inplace
=
True
)
if
weight_sharing
:
predict
=
layers
.
matmul
(
x
=
dec_output
,
y
=
fluid
.
default_main_program
().
global_block
().
var
(
"trg_word_emb_table"
),
transpose_y
=
True
)
else
:
predict
=
layers
.
fc
(
input
=
dec_output
,
size
=
trg_vocab_size
,
bias_attr
=
False
)
if
dec_inputs
is
None
:
# Return probs for independent decoder program.
predict
=
layers
.
softmax
(
predict
)
return
predict
def
fast_decode
(
src_vocab_size
,
trg_vocab_size
,
max_in_len
,
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
,
postprocess_cmd
,
weight_sharing
,
beam_size
,
max_out_len
,
bos_idx
,
eos_idx
,
use_py_reader
=
False
):
"""
Use beam search to decode. Caches will be used to store states of history
steps which can make the decoding faster.
"""
data_input_names
=
encoder_data_input_fields
+
fast_decoder_data_input_fields
if
use_py_reader
:
all_inputs
,
reader
=
make_all_py_reader_inputs
(
data_input_names
)
else
:
all_inputs
=
make_all_inputs
(
data_input_names
)
enc_inputs_len
=
len
(
encoder_data_input_fields
)
dec_inputs_len
=
len
(
fast_decoder_data_input_fields
)
enc_inputs
=
all_inputs
[
0
:
enc_inputs_len
]
#enc_inputs tensor
dec_inputs
=
all_inputs
[
enc_inputs_len
:
enc_inputs_len
+
dec_inputs_len
]
#dec_inputs tensor
enc_output
=
wrap_encoder
(
src_vocab_size
,
64
,
##to do !!!!!????
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
,
postprocess_cmd
,
weight_sharing
,
enc_inputs
,
bos_idx
=
bos_idx
)
start_tokens
,
init_scores
,
parent_idx
,
trg_src_attn_bias
=
dec_inputs
def
beam_search
():
max_len
=
layers
.
fill_constant
(
shape
=
[
1
],
dtype
=
start_tokens
.
dtype
,
value
=
max_out_len
,
force_cpu
=
True
)
step_idx
=
layers
.
fill_constant
(
shape
=
[
1
],
dtype
=
start_tokens
.
dtype
,
value
=
0
,
force_cpu
=
True
)
cond
=
layers
.
less_than
(
x
=
step_idx
,
y
=
max_len
)
# default force_cpu=True
while_op
=
layers
.
While
(
cond
)
# array states will be stored for each step.
ids
=
layers
.
array_write
(
layers
.
reshape
(
start_tokens
,
(
-
1
,
1
)),
step_idx
)
scores
=
layers
.
array_write
(
init_scores
,
step_idx
)
# cell states will be overwrited at each step.
# caches contains states of history steps in decoder self-attention
# and static encoder output projections in encoder-decoder attention
# to reduce redundant computation.
caches
=
[
{
"k"
:
# for self attention
layers
.
fill_constant_batch_size_like
(
input
=
start_tokens
,
shape
=
[
-
1
,
n_head
,
0
,
d_key
],
dtype
=
enc_output
.
dtype
,
value
=
0
),
"v"
:
# for self attention
layers
.
fill_constant_batch_size_like
(
input
=
start_tokens
,
shape
=
[
-
1
,
n_head
,
0
,
d_value
],
dtype
=
enc_output
.
dtype
,
value
=
0
),
"static_k"
:
# for encoder-decoder attention
layers
.
create_tensor
(
dtype
=
enc_output
.
dtype
),
"static_v"
:
# for encoder-decoder attention
layers
.
create_tensor
(
dtype
=
enc_output
.
dtype
)
}
for
i
in
range
(
n_layer
)
]
with
while_op
.
block
():
pre_ids
=
layers
.
array_read
(
array
=
ids
,
i
=
step_idx
)
# Since beam_search_op dosen't enforce pre_ids' shape, we can do
# inplace reshape here which actually change the shape of pre_ids.
pre_ids
=
layers
.
reshape
(
pre_ids
,
(
-
1
,
1
,
1
),
inplace
=
True
)
pre_scores
=
layers
.
array_read
(
array
=
scores
,
i
=
step_idx
)
# gather cell states corresponding to selected parent
pre_src_attn_bias
=
layers
.
gather
(
trg_src_attn_bias
,
index
=
parent_idx
)
pre_pos
=
layers
.
elementwise_mul
(
x
=
layers
.
fill_constant_batch_size_like
(
input
=
pre_src_attn_bias
,
# cann't use lod tensor here
value
=
1
,
shape
=
[
-
1
,
1
,
1
],
dtype
=
pre_ids
.
dtype
),
y
=
step_idx
,
axis
=
0
)
logits
=
wrap_decoder
(
trg_vocab_size
,
max_in_len
,
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
,
postprocess_cmd
,
weight_sharing
,
dec_inputs
=
(
pre_ids
,
pre_pos
,
None
,
pre_src_attn_bias
),
enc_output
=
enc_output
,
caches
=
caches
,
gather_idx
=
parent_idx
,
bos_idx
=
bos_idx
)
# intra-beam topK
topk_scores
,
topk_indices
=
layers
.
topk
(
input
=
layers
.
softmax
(
logits
),
k
=
beam_size
)
accu_scores
=
layers
.
elementwise_add
(
x
=
layers
.
log
(
topk_scores
),
y
=
pre_scores
,
axis
=
0
)
# beam_search op uses lod to differentiate branches.
accu_scores
=
layers
.
lod_reset
(
accu_scores
,
pre_ids
)
# topK reduction across beams, also contain special handle of
# end beams and end sentences(batch reduction)
selected_ids
,
selected_scores
,
gather_idx
=
layers
.
beam_search
(
pre_ids
=
pre_ids
,
pre_scores
=
pre_scores
,
ids
=
topk_indices
,
scores
=
accu_scores
,
beam_size
=
beam_size
,
end_id
=
eos_idx
,
return_parent_idx
=
True
)
layers
.
increment
(
x
=
step_idx
,
value
=
1.0
,
in_place
=
True
)
# cell states(caches) have been updated in wrap_decoder,
# only need to update beam search states here.
layers
.
array_write
(
selected_ids
,
i
=
step_idx
,
array
=
ids
)
layers
.
array_write
(
selected_scores
,
i
=
step_idx
,
array
=
scores
)
layers
.
assign
(
gather_idx
,
parent_idx
)
layers
.
assign
(
pre_src_attn_bias
,
trg_src_attn_bias
)
length_cond
=
layers
.
less_than
(
x
=
step_idx
,
y
=
max_len
)
finish_cond
=
layers
.
logical_not
(
layers
.
is_empty
(
x
=
selected_ids
))
layers
.
logical_and
(
x
=
length_cond
,
y
=
finish_cond
,
out
=
cond
)
finished_ids
,
finished_scores
=
layers
.
beam_search_decode
(
ids
,
scores
,
beam_size
=
beam_size
,
end_id
=
eos_idx
)
return
finished_ids
,
finished_scores
finished_ids
,
finished_scores
=
beam_search
()
return
finished_ids
,
finished_scores
,
reader
if
use_py_reader
else
None
ppocr/modeling/losses/rec_srn_loss.py
0 → 100755
浏览文件 @
612e8014
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
math
import
paddle
import
paddle.fluid
as
fluid
class
SRNLoss
(
object
):
def
__init__
(
self
,
params
):
super
(
SRNLoss
,
self
).
__init__
()
self
.
char_num
=
params
[
'char_num'
]
def
__call__
(
self
,
predicts
,
others
):
predict
=
predicts
[
'predict'
]
word_predict
=
predicts
[
'word_out'
]
gsrm_predict
=
predicts
[
'gsrm_out'
]
label
=
others
[
'label'
]
lbl_weight
=
others
[
'lbl_weight'
]
casted_label
=
fluid
.
layers
.
cast
(
x
=
label
,
dtype
=
'int64'
)
cost_word
=
fluid
.
layers
.
cross_entropy
(
input
=
word_predict
,
label
=
casted_label
)
cost_gsrm
=
fluid
.
layers
.
cross_entropy
(
input
=
gsrm_predict
,
label
=
casted_label
)
cost_vsfd
=
fluid
.
layers
.
cross_entropy
(
input
=
predict
,
label
=
casted_label
)
cost_word
=
fluid
.
layers
.
reshape
(
x
=
fluid
.
layers
.
reduce_sum
(
cost_word
),
shape
=
[
1
])
cost_gsrm
=
fluid
.
layers
.
reshape
(
x
=
fluid
.
layers
.
reduce_sum
(
cost_gsrm
),
shape
=
[
1
])
cost_vsfd
=
fluid
.
layers
.
reshape
(
x
=
fluid
.
layers
.
reduce_sum
(
cost_vsfd
),
shape
=
[
1
])
sum_cost
=
fluid
.
layers
.
sum
(
[
cost_word
,
cost_vsfd
*
2.0
,
cost_gsrm
*
0.15
])
return
[
sum_cost
,
cost_vsfd
,
cost_word
]
ppocr/utils/character.py
浏览文件 @
612e8014
...
@@ -25,6 +25,9 @@ class CharacterOps(object):
...
@@ -25,6 +25,9 @@ class CharacterOps(object):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
self
.
character_type
=
config
[
'character_type'
]
self
.
character_type
=
config
[
'character_type'
]
self
.
loss_type
=
config
[
'loss_type'
]
self
.
loss_type
=
config
[
'loss_type'
]
self
.
max_text_len
=
config
[
'max_text_length'
]
if
self
.
loss_type
==
"srn"
and
self
.
character_type
!=
"en"
:
raise
Exception
(
"SRN can only support in character_type == en"
)
if
self
.
character_type
==
"en"
:
if
self
.
character_type
==
"en"
:
self
.
character_str
=
"0123456789abcdefghijklmnopqrstuvwxyz"
self
.
character_str
=
"0123456789abcdefghijklmnopqrstuvwxyz"
dict_character
=
list
(
self
.
character_str
)
dict_character
=
list
(
self
.
character_str
)
...
@@ -54,6 +57,8 @@ class CharacterOps(object):
...
@@ -54,6 +57,8 @@ class CharacterOps(object):
self
.
end_str
=
"eos"
self
.
end_str
=
"eos"
if
self
.
loss_type
==
"attention"
:
if
self
.
loss_type
==
"attention"
:
dict_character
=
[
self
.
beg_str
,
self
.
end_str
]
+
dict_character
dict_character
=
[
self
.
beg_str
,
self
.
end_str
]
+
dict_character
elif
self
.
loss_type
==
"srn"
:
dict_character
=
dict_character
+
[
self
.
beg_str
,
self
.
end_str
]
self
.
dict
=
{}
self
.
dict
=
{}
for
i
,
char
in
enumerate
(
dict_character
):
for
i
,
char
in
enumerate
(
dict_character
):
self
.
dict
[
char
]
=
i
self
.
dict
[
char
]
=
i
...
@@ -147,6 +152,39 @@ def cal_predicts_accuracy(char_ops,
...
@@ -147,6 +152,39 @@ def cal_predicts_accuracy(char_ops,
return
acc
,
acc_num
,
img_num
return
acc
,
acc_num
,
img_num
def
cal_predicts_accuracy_srn
(
char_ops
,
preds
,
labels
,
max_text_len
,
is_debug
=
False
):
acc_num
=
0
img_num
=
0
total_len
=
preds
.
shape
[
0
]
img_num
=
int
(
total_len
/
max_text_len
)
for
i
in
range
(
img_num
):
cur_label
=
[]
cur_pred
=
[]
for
j
in
range
(
max_text_len
):
if
labels
[
j
+
i
*
max_text_len
]
!=
37
:
#0
cur_label
.
append
(
labels
[
j
+
i
*
max_text_len
][
0
])
else
:
break
for
j
in
range
(
max_text_len
+
1
):
if
j
<
len
(
cur_label
)
and
preds
[
j
+
i
*
max_text_len
][
0
]
!=
cur_label
[
j
]:
break
elif
j
==
len
(
cur_label
)
and
j
==
max_text_len
:
acc_num
+=
1
break
elif
j
==
len
(
cur_label
)
and
preds
[
j
+
i
*
max_text_len
][
0
]
==
37
:
acc_num
+=
1
break
acc
=
acc_num
*
1.0
/
img_num
return
acc
,
acc_num
,
img_num
def
convert_rec_attention_infer_res
(
preds
):
def
convert_rec_attention_infer_res
(
preds
):
img_num
=
preds
.
shape
[
0
]
img_num
=
preds
.
shape
[
0
]
target_lod
=
[
0
]
target_lod
=
[
0
]
...
...
tools/eval_utils/eval_rec_utils.py
浏览文件 @
612e8014
...
@@ -29,7 +29,7 @@ FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
...
@@ -29,7 +29,7 @@ FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging
.
basicConfig
(
level
=
logging
.
INFO
,
format
=
FORMAT
)
logging
.
basicConfig
(
level
=
logging
.
INFO
,
format
=
FORMAT
)
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
from
ppocr.utils.character
import
cal_predicts_accuracy
from
ppocr.utils.character
import
cal_predicts_accuracy
,
cal_predicts_accuracy_srn
from
ppocr.utils.character
import
convert_rec_label_to_lod
from
ppocr.utils.character
import
convert_rec_label_to_lod
from
ppocr.utils.character
import
convert_rec_attention_infer_res
from
ppocr.utils.character
import
convert_rec_attention_infer_res
from
ppocr.utils.utility
import
create_module
from
ppocr.utils.utility
import
create_module
...
@@ -60,22 +60,60 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
...
@@ -60,22 +60,60 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
for
ino
in
range
(
img_num
):
for
ino
in
range
(
img_num
):
img_list
.
append
(
data
[
ino
][
0
])
img_list
.
append
(
data
[
ino
][
0
])
label_list
.
append
(
data
[
ino
][
1
])
label_list
.
append
(
data
[
ino
][
1
])
img_list
=
np
.
concatenate
(
img_list
,
axis
=
0
)
outs
=
exe
.
run
(
eval_info_dict
[
'program'
],
\
if
config
[
'Global'
][
'loss_type'
]
!=
"srn"
:
img_list
=
np
.
concatenate
(
img_list
,
axis
=
0
)
outs
=
exe
.
run
(
eval_info_dict
[
'program'
],
\
feed
=
{
'image'
:
img_list
},
\
feed
=
{
'image'
:
img_list
},
\
fetch_list
=
eval_info_dict
[
'fetch_varname_list'
],
\
fetch_list
=
eval_info_dict
[
'fetch_varname_list'
],
\
return_numpy
=
False
)
return_numpy
=
False
)
preds
=
np
.
array
(
outs
[
0
])
preds
=
np
.
array
(
outs
[
0
])
if
preds
.
shape
[
1
]
!=
1
:
preds
,
preds_lod
=
convert_rec_attention_infer_res
(
preds
)
if
config
[
'Global'
][
'loss_type'
]
==
"attention"
:
preds
,
preds_lod
=
convert_rec_attention_infer_res
(
preds
)
else
:
preds_lod
=
outs
[
0
].
lod
()[
0
]
labels
,
labels_lod
=
convert_rec_label_to_lod
(
label_list
)
acc
,
acc_num
,
sample_num
=
cal_predicts_accuracy
(
char_ops
,
preds
,
preds_lod
,
labels
,
labels_lod
,
is_remove_duplicate
)
else
:
else
:
preds_lod
=
outs
[
0
].
lod
()[
0
]
encoder_word_pos_list
=
[]
labels
,
labels_lod
=
convert_rec_label_to_lod
(
label_list
)
gsrm_word_pos_list
=
[]
acc
,
acc_num
,
sample_num
=
cal_predicts_accuracy
(
gsrm_slf_attn_bias1_list
=
[]
char_ops
,
preds
,
preds_lod
,
labels
,
labels_lod
,
is_remove_duplicate
)
gsrm_slf_attn_bias2_list
=
[]
for
ino
in
range
(
img_num
):
encoder_word_pos_list
.
append
(
data
[
ino
][
2
])
gsrm_word_pos_list
.
append
(
data
[
ino
][
3
])
gsrm_slf_attn_bias1_list
.
append
(
data
[
ino
][
4
])
gsrm_slf_attn_bias2_list
.
append
(
data
[
ino
][
5
])
img_list
=
np
.
concatenate
(
img_list
,
axis
=
0
)
label_list
=
np
.
concatenate
(
label_list
,
axis
=
0
)
encoder_word_pos_list
=
np
.
concatenate
(
encoder_word_pos_list
,
axis
=
0
).
astype
(
np
.
int64
)
gsrm_word_pos_list
=
np
.
concatenate
(
gsrm_word_pos_list
,
axis
=
0
).
astype
(
np
.
int64
)
gsrm_slf_attn_bias1_list
=
np
.
concatenate
(
gsrm_slf_attn_bias1_list
,
axis
=
0
).
astype
(
np
.
float32
)
gsrm_slf_attn_bias2_list
=
np
.
concatenate
(
gsrm_slf_attn_bias2_list
,
axis
=
0
).
astype
(
np
.
float32
)
labels
=
label_list
outs
=
exe
.
run
(
eval_info_dict
[
'program'
],
\
feed
=
{
'image'
:
img_list
,
'encoder_word_pos'
:
encoder_word_pos_list
,
'gsrm_word_pos'
:
gsrm_word_pos_list
,
'gsrm_slf_attn_bias1'
:
gsrm_slf_attn_bias1_list
,
'gsrm_slf_attn_bias2'
:
gsrm_slf_attn_bias2_list
},
\
fetch_list
=
eval_info_dict
[
'fetch_varname_list'
],
\
return_numpy
=
False
)
preds
=
np
.
array
(
outs
[
0
])
acc
,
acc_num
,
sample_num
=
cal_predicts_accuracy_srn
(
char_ops
,
preds
,
labels
,
config
[
'Global'
][
'max_text_length'
])
total_acc_num
+=
acc_num
total_acc_num
+=
acc_num
total_sample_num
+=
sample_num
total_sample_num
+=
sample_num
logger
.
info
(
"eval batch id: {}, acc: {}"
.
format
(
total_batch_num
,
acc
))
#
logger.info("eval batch id: {}, acc: {}".format(total_batch_num, acc))
total_batch_num
+=
1
total_batch_num
+=
1
avg_acc
=
total_acc_num
*
1.0
/
total_sample_num
avg_acc
=
total_acc_num
*
1.0
/
total_sample_num
metrics
=
{
'avg_acc'
:
avg_acc
,
"total_acc_num"
:
total_acc_num
,
\
metrics
=
{
'avg_acc'
:
avg_acc
,
"total_acc_num"
:
total_acc_num
,
\
...
...
tools/infer/predict_rec.py
浏览文件 @
612e8014
...
@@ -40,7 +40,8 @@ class TextRecognizer(object):
...
@@ -40,7 +40,8 @@ class TextRecognizer(object):
char_ops_params
=
{
char_ops_params
=
{
"character_type"
:
args
.
rec_char_type
,
"character_type"
:
args
.
rec_char_type
,
"character_dict_path"
:
args
.
rec_char_dict_path
,
"character_dict_path"
:
args
.
rec_char_dict_path
,
"use_space_char"
:
args
.
use_space_char
"use_space_char"
:
args
.
use_space_char
,
"max_text_length"
:
args
.
max_text_length
}
}
if
self
.
rec_algorithm
!=
"RARE"
:
if
self
.
rec_algorithm
!=
"RARE"
:
char_ops_params
[
'loss_type'
]
=
'ctc'
char_ops_params
[
'loss_type'
]
=
'ctc'
...
...
tools/infer/utility.py
浏览文件 @
612e8014
...
@@ -59,6 +59,7 @@ def parse_args():
...
@@ -59,6 +59,7 @@ def parse_args():
parser
.
add_argument
(
"--rec_image_shape"
,
type
=
str
,
default
=
"3, 32, 320"
)
parser
.
add_argument
(
"--rec_image_shape"
,
type
=
str
,
default
=
"3, 32, 320"
)
parser
.
add_argument
(
"--rec_char_type"
,
type
=
str
,
default
=
'ch'
)
parser
.
add_argument
(
"--rec_char_type"
,
type
=
str
,
default
=
'ch'
)
parser
.
add_argument
(
"--rec_batch_num"
,
type
=
int
,
default
=
30
)
parser
.
add_argument
(
"--rec_batch_num"
,
type
=
int
,
default
=
30
)
parser
.
add_argument
(
"--max_text_length"
,
type
=
int
,
default
=
25
)
parser
.
add_argument
(
parser
.
add_argument
(
"--rec_char_dict_path"
,
"--rec_char_dict_path"
,
type
=
str
,
type
=
str
,
...
...
tools/infer_rec.py
浏览文件 @
612e8014
...
@@ -64,7 +64,6 @@ def main():
...
@@ -64,7 +64,6 @@ def main():
exe
=
fluid
.
Executor
(
place
)
exe
=
fluid
.
Executor
(
place
)
rec_model
=
create_module
(
config
[
'Architecture'
][
'function'
])(
params
=
config
)
rec_model
=
create_module
(
config
[
'Architecture'
][
'function'
])(
params
=
config
)
startup_prog
=
fluid
.
Program
()
startup_prog
=
fluid
.
Program
()
eval_prog
=
fluid
.
Program
()
eval_prog
=
fluid
.
Program
()
with
fluid
.
program_guard
(
eval_prog
,
startup_prog
):
with
fluid
.
program_guard
(
eval_prog
,
startup_prog
):
...
@@ -86,10 +85,36 @@ def main():
...
@@ -86,10 +85,36 @@ def main():
for
i
in
range
(
max_img_num
):
for
i
in
range
(
max_img_num
):
logger
.
info
(
"infer_img:%s"
%
infer_list
[
i
])
logger
.
info
(
"infer_img:%s"
%
infer_list
[
i
])
img
=
next
(
blobs
)
img
=
next
(
blobs
)
predict
=
exe
.
run
(
program
=
eval_prog
,
if
loss_type
!=
"srn"
:
feed
=
{
"image"
:
img
},
predict
=
exe
.
run
(
program
=
eval_prog
,
fetch_list
=
fetch_varname_list
,
feed
=
{
"image"
:
img
},
return_numpy
=
False
)
fetch_list
=
fetch_varname_list
,
return_numpy
=
False
)
else
:
encoder_word_pos_list
=
[]
gsrm_word_pos_list
=
[]
gsrm_slf_attn_bias1_list
=
[]
gsrm_slf_attn_bias2_list
=
[]
encoder_word_pos_list
.
append
(
img
[
1
])
gsrm_word_pos_list
.
append
(
img
[
2
])
gsrm_slf_attn_bias1_list
.
append
(
img
[
3
])
gsrm_slf_attn_bias2_list
.
append
(
img
[
4
])
encoder_word_pos_list
=
np
.
concatenate
(
encoder_word_pos_list
,
axis
=
0
).
astype
(
np
.
int64
)
gsrm_word_pos_list
=
np
.
concatenate
(
gsrm_word_pos_list
,
axis
=
0
).
astype
(
np
.
int64
)
gsrm_slf_attn_bias1_list
=
np
.
concatenate
(
gsrm_slf_attn_bias1_list
,
axis
=
0
).
astype
(
np
.
float32
)
gsrm_slf_attn_bias2_list
=
np
.
concatenate
(
gsrm_slf_attn_bias2_list
,
axis
=
0
).
astype
(
np
.
float32
)
predict
=
exe
.
run
(
program
=
eval_prog
,
\
feed
=
{
'image'
:
img
[
0
],
'encoder_word_pos'
:
encoder_word_pos_list
,
'gsrm_word_pos'
:
gsrm_word_pos_list
,
'gsrm_slf_attn_bias1'
:
gsrm_slf_attn_bias1_list
,
'gsrm_slf_attn_bias2'
:
gsrm_slf_attn_bias2_list
},
\
fetch_list
=
fetch_varname_list
,
\
return_numpy
=
False
)
if
loss_type
==
"ctc"
:
if
loss_type
==
"ctc"
:
preds
=
np
.
array
(
predict
[
0
])
preds
=
np
.
array
(
predict
[
0
])
preds
=
preds
.
reshape
(
-
1
)
preds
=
preds
.
reshape
(
-
1
)
...
@@ -114,7 +139,18 @@ def main():
...
@@ -114,7 +139,18 @@ def main():
score
=
np
.
mean
(
probs
[
0
,
1
:
end_pos
[
1
]])
score
=
np
.
mean
(
probs
[
0
,
1
:
end_pos
[
1
]])
preds
=
preds
.
reshape
(
-
1
)
preds
=
preds
.
reshape
(
-
1
)
preds_text
=
char_ops
.
decode
(
preds
)
preds_text
=
char_ops
.
decode
(
preds
)
elif
loss_type
==
"srn"
:
cur_pred
=
[]
preds
=
np
.
array
(
predict
[
0
])
preds
=
preds
.
reshape
(
-
1
)
probs
=
np
.
array
(
predict
[
1
])
ind
=
np
.
argmax
(
probs
,
axis
=
1
)
valid_ind
=
np
.
where
(
preds
!=
37
)[
0
]
if
len
(
valid_ind
)
==
0
:
continue
score
=
np
.
mean
(
probs
[
valid_ind
,
ind
[
valid_ind
]])
preds
=
preds
[:
valid_ind
[
-
1
]
+
1
]
preds_text
=
char_ops
.
decode
(
preds
)
logger
.
info
(
"
\t
index: {}"
.
format
(
preds
))
logger
.
info
(
"
\t
index: {}"
.
format
(
preds
))
logger
.
info
(
"
\t
word : {}"
.
format
(
preds_text
))
logger
.
info
(
"
\t
word : {}"
.
format
(
preds_text
))
logger
.
info
(
"
\t
score: {}"
.
format
(
score
))
logger
.
info
(
"
\t
score: {}"
.
format
(
score
))
...
...
tools/program.py
浏览文件 @
612e8014
...
@@ -32,7 +32,8 @@ from eval_utils.eval_det_utils import eval_det_run
...
@@ -32,7 +32,8 @@ from eval_utils.eval_det_utils import eval_det_run
from
eval_utils.eval_rec_utils
import
eval_rec_run
from
eval_utils.eval_rec_utils
import
eval_rec_run
from
ppocr.utils.save_load
import
save_model
from
ppocr.utils.save_load
import
save_model
import
numpy
as
np
import
numpy
as
np
from
ppocr.utils.character
import
cal_predicts_accuracy
,
CharacterOps
from
ppocr.utils.character
import
cal_predicts_accuracy
,
cal_predicts_accuracy_srn
,
CharacterOps
class
ArgsParser
(
ArgumentParser
):
class
ArgsParser
(
ArgumentParser
):
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -176,8 +177,16 @@ def build(config, main_prog, startup_prog, mode):
...
@@ -176,8 +177,16 @@ def build(config, main_prog, startup_prog, mode):
fetch_name_list
=
list
(
outputs
.
keys
())
fetch_name_list
=
list
(
outputs
.
keys
())
fetch_varname_list
=
[
outputs
[
v
].
name
for
v
in
fetch_name_list
]
fetch_varname_list
=
[
outputs
[
v
].
name
for
v
in
fetch_name_list
]
opt_loss_name
=
None
opt_loss_name
=
None
model_average
=
None
img_loss_name
=
None
word_loss_name
=
None
if
mode
==
"train"
:
if
mode
==
"train"
:
opt_loss
=
outputs
[
'total_loss'
]
opt_loss
=
outputs
[
'total_loss'
]
# srn loss
#img_loss = outputs['img_loss']
#word_loss = outputs['word_loss']
#img_loss_name = img_loss.name
#word_loss_name = word_loss.name
opt_params
=
config
[
'Optimizer'
]
opt_params
=
config
[
'Optimizer'
]
optimizer
=
create_module
(
opt_params
[
'function'
])(
opt_params
)
optimizer
=
create_module
(
opt_params
[
'function'
])(
opt_params
)
optimizer
.
minimize
(
opt_loss
)
optimizer
.
minimize
(
opt_loss
)
...
@@ -185,7 +194,17 @@ def build(config, main_prog, startup_prog, mode):
...
@@ -185,7 +194,17 @@ def build(config, main_prog, startup_prog, mode):
global_lr
=
optimizer
.
_global_learning_rate
()
global_lr
=
optimizer
.
_global_learning_rate
()
fetch_name_list
.
insert
(
0
,
"lr"
)
fetch_name_list
.
insert
(
0
,
"lr"
)
fetch_varname_list
.
insert
(
0
,
global_lr
.
name
)
fetch_varname_list
.
insert
(
0
,
global_lr
.
name
)
return
(
dataloader
,
fetch_name_list
,
fetch_varname_list
,
opt_loss_name
)
if
"loss_type"
in
config
[
"Global"
]:
if
config
[
'Global'
][
"loss_type"
]
==
'srn'
:
model_average
=
fluid
.
optimizer
.
ModelAverage
(
config
[
'Global'
][
'average_window'
],
min_average_window
=
config
[
'Global'
][
'min_average_window'
],
max_average_window
=
config
[
'Global'
][
'max_average_window'
])
return
(
dataloader
,
fetch_name_list
,
fetch_varname_list
,
opt_loss_name
,
model_average
)
def
build_export
(
config
,
main_prog
,
startup_prog
):
def
build_export
(
config
,
main_prog
,
startup_prog
):
...
@@ -329,14 +348,20 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
...
@@ -329,14 +348,20 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
lr
=
np
.
mean
(
np
.
array
(
train_outs
[
fetch_map
[
'lr'
]]))
lr
=
np
.
mean
(
np
.
array
(
train_outs
[
fetch_map
[
'lr'
]]))
preds_idx
=
fetch_map
[
'decoded_out'
]
preds_idx
=
fetch_map
[
'decoded_out'
]
preds
=
np
.
array
(
train_outs
[
preds_idx
])
preds
=
np
.
array
(
train_outs
[
preds_idx
])
preds_lod
=
train_outs
[
preds_idx
].
lod
()[
0
]
labels_idx
=
fetch_map
[
'label'
]
labels_idx
=
fetch_map
[
'label'
]
labels
=
np
.
array
(
train_outs
[
labels_idx
])
labels
=
np
.
array
(
train_outs
[
labels_idx
])
labels_lod
=
train_outs
[
labels_idx
].
lod
()[
0
]
acc
,
acc_num
,
img_num
=
cal_predicts_accuracy
(
if
config
[
'Global'
][
'loss_type'
]
!=
'srn'
:
config
[
'Global'
][
'char_ops'
],
preds
,
preds_lod
,
labels
,
preds_lod
=
train_outs
[
preds_idx
].
lod
()[
0
]
labels_lod
)
labels_lod
=
train_outs
[
labels_idx
].
lod
()[
0
]
acc
,
acc_num
,
img_num
=
cal_predicts_accuracy
(
config
[
'Global'
][
'char_ops'
],
preds
,
preds_lod
,
labels
,
labels_lod
)
else
:
acc
,
acc_num
,
img_num
=
cal_predicts_accuracy_srn
(
config
[
'Global'
][
'char_ops'
],
preds
,
labels
,
config
[
'Global'
][
'max_text_length'
])
t2
=
time
.
time
()
t2
=
time
.
time
()
train_batch_elapse
=
t2
-
t1
train_batch_elapse
=
t2
-
t1
stats
=
{
'loss'
:
loss
,
'acc'
:
acc
}
stats
=
{
'loss'
:
loss
,
'acc'
:
acc
}
...
@@ -350,6 +375,9 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
...
@@ -350,6 +375,9 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
if
train_batch_id
>
0
and
\
if
train_batch_id
>
0
and
\
train_batch_id
%
eval_batch_step
==
0
:
train_batch_id
%
eval_batch_step
==
0
:
model_average
=
train_info_dict
[
'model_average'
]
if
model_average
!=
None
:
model_average
.
apply
(
exe
)
metrics
=
eval_rec_run
(
exe
,
config
,
eval_info_dict
,
"eval"
)
metrics
=
eval_rec_run
(
exe
,
config
,
eval_info_dict
,
"eval"
)
eval_acc
=
metrics
[
'avg_acc'
]
eval_acc
=
metrics
[
'avg_acc'
]
eval_sample_num
=
metrics
[
'total_sample_num'
]
eval_sample_num
=
metrics
[
'total_sample_num'
]
...
@@ -375,6 +403,7 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
...
@@ -375,6 +403,7 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
save_model
(
train_info_dict
[
'train_program'
],
save_path
)
save_model
(
train_info_dict
[
'train_program'
],
save_path
)
return
return
def
preprocess
():
def
preprocess
():
FLAGS
=
ArgsParser
().
parse_args
()
FLAGS
=
ArgsParser
().
parse_args
()
config
=
load_config
(
FLAGS
.
config
)
config
=
load_config
(
FLAGS
.
config
)
...
@@ -386,8 +415,8 @@ def preprocess():
...
@@ -386,8 +415,8 @@ def preprocess():
check_gpu
(
use_gpu
)
check_gpu
(
use_gpu
)
alg
=
config
[
'Global'
][
'algorithm'
]
alg
=
config
[
'Global'
][
'algorithm'
]
assert
alg
in
[
'EAST'
,
'DB'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
]
assert
alg
in
[
'EAST'
,
'DB'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
,
'SRN'
]
if
alg
in
[
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
]:
if
alg
in
[
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
,
'SRN'
]:
config
[
'Global'
][
'char_ops'
]
=
CharacterOps
(
config
[
'Global'
])
config
[
'Global'
][
'char_ops'
]
=
CharacterOps
(
config
[
'Global'
])
place
=
fluid
.
CUDAPlace
(
0
)
if
use_gpu
else
fluid
.
CPUPlace
()
place
=
fluid
.
CUDAPlace
(
0
)
if
use_gpu
else
fluid
.
CPUPlace
()
...
...
tools/train.py
浏览文件 @
612e8014
...
@@ -52,6 +52,7 @@ def main():
...
@@ -52,6 +52,7 @@ def main():
train_fetch_name_list
=
train_build_outputs
[
1
]
train_fetch_name_list
=
train_build_outputs
[
1
]
train_fetch_varname_list
=
train_build_outputs
[
2
]
train_fetch_varname_list
=
train_build_outputs
[
2
]
train_opt_loss_name
=
train_build_outputs
[
3
]
train_opt_loss_name
=
train_build_outputs
[
3
]
model_average
=
train_build_outputs
[
-
1
]
eval_program
=
fluid
.
Program
()
eval_program
=
fluid
.
Program
()
eval_build_outputs
=
program
.
build
(
eval_build_outputs
=
program
.
build
(
...
@@ -85,7 +86,8 @@ def main():
...
@@ -85,7 +86,8 @@ def main():
'train_program'
:
train_program
,
\
'train_program'
:
train_program
,
\
'reader'
:
train_loader
,
\
'reader'
:
train_loader
,
\
'fetch_name_list'
:
train_fetch_name_list
,
\
'fetch_name_list'
:
train_fetch_name_list
,
\
'fetch_varname_list'
:
train_fetch_varname_list
}
'fetch_varname_list'
:
train_fetch_varname_list
,
\
'model_average'
:
model_average
}
eval_info_dict
=
{
'program'
:
eval_program
,
\
eval_info_dict
=
{
'program'
:
eval_program
,
\
'reader'
:
eval_reader
,
\
'reader'
:
eval_reader
,
\
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录