Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
c1fd4664
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
1 年多 前同步成功
通知
1533
Star
32963
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c1fd4664
编写于
12月 30, 2020
作者:
T
tink2123
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add srn for dygraph
上级
de3e2e7c
变更
28
隐藏空白更改
内联
并排
Showing
28 changed file
with
1594 addition
and
70 deletion
+1594
-70
configs/rec/rec_mv3_none_bilstm_ctc.yml
configs/rec/rec_mv3_none_bilstm_ctc.yml
+3
-3
configs/rec/rec_mv3_none_none_ctc.yml
configs/rec/rec_mv3_none_none_ctc.yml
+2
-2
configs/rec/rec_mv3_tps_bilstm_ctc.yml
configs/rec/rec_mv3_tps_bilstm_ctc.yml
+2
-2
configs/rec/rec_r34_vd_none_bilstm_ctc.yml
configs/rec/rec_r34_vd_none_bilstm_ctc.yml
+2
-2
configs/rec/rec_r34_vd_none_none_ctc.yml
configs/rec/rec_r34_vd_none_none_ctc.yml
+2
-2
configs/rec/rec_r34_vd_tps_bilstm_ctc.yml
configs/rec/rec_r34_vd_tps_bilstm_ctc.yml
+2
-2
configs/rec/rec_r50_fpn_srn.yml
configs/rec/rec_r50_fpn_srn.yml
+106
-0
ppocr/data/__init__.py
ppocr/data/__init__.py
+2
-2
ppocr/data/imaug/__init__.py
ppocr/data/imaug/__init__.py
+1
-1
ppocr/data/imaug/label_ops.py
ppocr/data/imaug/label_ops.py
+48
-0
ppocr/data/imaug/rec_img_aug.py
ppocr/data/imaug/rec_img_aug.py
+75
-15
ppocr/data/lmdb_dataset.py
ppocr/data/lmdb_dataset.py
+2
-2
ppocr/losses/__init__.py
ppocr/losses/__init__.py
+4
-1
ppocr/losses/rec_srn_loss.py
ppocr/losses/rec_srn_loss.py
+47
-0
ppocr/metrics/__init__.py
ppocr/metrics/__init__.py
+1
-0
ppocr/metrics/rec_metric.py
ppocr/metrics/rec_metric.py
+1
-3
ppocr/modeling/architectures/base_model.py
ppocr/modeling/architectures/base_model.py
+5
-2
ppocr/modeling/backbones/__init__.py
ppocr/modeling/backbones/__init__.py
+2
-1
ppocr/modeling/backbones/rec_resnet_fpn.py
ppocr/modeling/backbones/rec_resnet_fpn.py
+307
-0
ppocr/modeling/heads/__init__.py
ppocr/modeling/heads/__init__.py
+4
-1
ppocr/modeling/heads/rec_srn_head.py
ppocr/modeling/heads/rec_srn_head.py
+279
-0
ppocr/modeling/heads/self_attention.py
ppocr/modeling/heads/self_attention.py
+408
-0
ppocr/postprocess/__init__.py
ppocr/postprocess/__init__.py
+3
-2
ppocr/postprocess/rec_postprocess.py
ppocr/postprocess/rec_postprocess.py
+83
-1
tools/export_model.py
tools/export_model.py
+34
-7
tools/infer/predict_rec.py
tools/infer/predict_rec.py
+136
-13
tools/infer_rec.py
tools/infer_rec.py
+23
-2
tools/program.py
tools/program.py
+10
-4
未找到文件。
configs/rec/rec_mv3_none_bilstm_ctc.yml
浏览文件 @
c1fd4664
Global
:
Global
:
use_gpu
:
t
rue
use_gpu
:
T
rue
epoch_num
:
72
epoch_num
:
72
log_smooth_window
:
20
log_smooth_window
:
20
print_batch_step
:
10
print_batch_step
:
10
...
@@ -59,7 +59,7 @@ Metric:
...
@@ -59,7 +59,7 @@ Metric:
Train
:
Train
:
dataset
:
dataset
:
name
:
LMDBDat
e
Set
name
:
LMDBDat
a
Set
data_dir
:
./train_data/data_lmdb_release/training/
data_dir
:
./train_data/data_lmdb_release/training/
transforms
:
transforms
:
-
DecodeImage
:
# load image
-
DecodeImage
:
# load image
...
@@ -78,7 +78,7 @@ Train:
...
@@ -78,7 +78,7 @@ Train:
Eval
:
Eval
:
dataset
:
dataset
:
name
:
LMDBDat
e
Set
name
:
LMDBDat
a
Set
data_dir
:
./train_data/data_lmdb_release/validation/
data_dir
:
./train_data/data_lmdb_release/validation/
transforms
:
transforms
:
-
DecodeImage
:
# load image
-
DecodeImage
:
# load image
...
...
configs/rec/rec_mv3_none_none_ctc.yml
浏览文件 @
c1fd4664
...
@@ -58,7 +58,7 @@ Metric:
...
@@ -58,7 +58,7 @@ Metric:
Train
:
Train
:
dataset
:
dataset
:
name
:
LMDBDat
e
Set
name
:
LMDBDat
a
Set
data_dir
:
./train_data/data_lmdb_release/training/
data_dir
:
./train_data/data_lmdb_release/training/
transforms
:
transforms
:
-
DecodeImage
:
# load image
-
DecodeImage
:
# load image
...
@@ -77,7 +77,7 @@ Train:
...
@@ -77,7 +77,7 @@ Train:
Eval
:
Eval
:
dataset
:
dataset
:
name
:
LMDBDat
e
Set
name
:
LMDBDat
a
Set
data_dir
:
./train_data/data_lmdb_release/validation/
data_dir
:
./train_data/data_lmdb_release/validation/
transforms
:
transforms
:
-
DecodeImage
:
# load image
-
DecodeImage
:
# load image
...
...
configs/rec/rec_mv3_tps_bilstm_ctc.yml
浏览文件 @
c1fd4664
...
@@ -63,7 +63,7 @@ Metric:
...
@@ -63,7 +63,7 @@ Metric:
Train
:
Train
:
dataset
:
dataset
:
name
:
LMDBDat
e
Set
name
:
LMDBDat
a
Set
data_dir
:
./train_data/data_lmdb_release/training/
data_dir
:
./train_data/data_lmdb_release/training/
transforms
:
transforms
:
-
DecodeImage
:
# load image
-
DecodeImage
:
# load image
...
@@ -82,7 +82,7 @@ Train:
...
@@ -82,7 +82,7 @@ Train:
Eval
:
Eval
:
dataset
:
dataset
:
name
:
LMDBDat
e
Set
name
:
LMDBDat
a
Set
data_dir
:
./train_data/data_lmdb_release/validation/
data_dir
:
./train_data/data_lmdb_release/validation/
transforms
:
transforms
:
-
DecodeImage
:
# load image
-
DecodeImage
:
# load image
...
...
configs/rec/rec_r34_vd_none_bilstm_ctc.yml
浏览文件 @
c1fd4664
...
@@ -58,7 +58,7 @@ Metric:
...
@@ -58,7 +58,7 @@ Metric:
Train
:
Train
:
dataset
:
dataset
:
name
:
LMDBDat
e
Set
name
:
LMDBDat
a
Set
data_dir
:
./train_data/data_lmdb_release/training/
data_dir
:
./train_data/data_lmdb_release/training/
transforms
:
transforms
:
-
DecodeImage
:
# load image
-
DecodeImage
:
# load image
...
@@ -77,7 +77,7 @@ Train:
...
@@ -77,7 +77,7 @@ Train:
Eval
:
Eval
:
dataset
:
dataset
:
name
:
LMDBDat
e
Set
name
:
LMDBDat
a
Set
data_dir
:
./train_data/data_lmdb_release/validation/
data_dir
:
./train_data/data_lmdb_release/validation/
transforms
:
transforms
:
-
DecodeImage
:
# load image
-
DecodeImage
:
# load image
...
...
configs/rec/rec_r34_vd_none_none_ctc.yml
浏览文件 @
c1fd4664
...
@@ -56,7 +56,7 @@ Metric:
...
@@ -56,7 +56,7 @@ Metric:
Train
:
Train
:
dataset
:
dataset
:
name
:
LMDBDat
e
Set
name
:
LMDBDat
a
Set
data_dir
:
./train_data/data_lmdb_release/training/
data_dir
:
./train_data/data_lmdb_release/training/
transforms
:
transforms
:
-
DecodeImage
:
# load image
-
DecodeImage
:
# load image
...
@@ -75,7 +75,7 @@ Train:
...
@@ -75,7 +75,7 @@ Train:
Eval
:
Eval
:
dataset
:
dataset
:
name
:
LMDBDat
e
Set
name
:
LMDBDat
a
Set
data_dir
:
./train_data/data_lmdb_release/validation/
data_dir
:
./train_data/data_lmdb_release/validation/
transforms
:
transforms
:
-
DecodeImage
:
# load image
-
DecodeImage
:
# load image
...
...
configs/rec/rec_r34_vd_tps_bilstm_ctc.yml
浏览文件 @
c1fd4664
...
@@ -62,7 +62,7 @@ Metric:
...
@@ -62,7 +62,7 @@ Metric:
Train
:
Train
:
dataset
:
dataset
:
name
:
LMDBDat
e
Set
name
:
LMDBDat
a
Set
data_dir
:
./train_data/data_lmdb_release/training/
data_dir
:
./train_data/data_lmdb_release/training/
transforms
:
transforms
:
-
DecodeImage
:
# load image
-
DecodeImage
:
# load image
...
@@ -81,7 +81,7 @@ Train:
...
@@ -81,7 +81,7 @@ Train:
Eval
:
Eval
:
dataset
:
dataset
:
name
:
LMDBDat
e
Set
name
:
LMDBDat
a
Set
data_dir
:
./train_data/data_lmdb_release/validation/
data_dir
:
./train_data/data_lmdb_release/validation/
transforms
:
transforms
:
-
DecodeImage
:
# load image
-
DecodeImage
:
# load image
...
...
configs/rec/rec_r50_fpn_srn.yml
0 → 100644
浏览文件 @
c1fd4664
Global
:
use_gpu
:
True
epoch_num
:
72
log_smooth_window
:
20
print_batch_step
:
5
save_model_dir
:
./output/rec/srn
save_epoch_step
:
3
# evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step
:
[
0
,
5000
]
# if pretrained_model is saved in static mode, load_static_weights must set to True
cal_metric_during_train
:
True
pretrained_model
:
checkpoints
:
save_inference_dir
:
use_visualdl
:
False
infer_img
:
doc/imgs_words/ch/word_1.jpg
# for data or label process
character_dict_path
:
character_type
:
en
max_text_length
:
25
num_heads
:
8
infer_mode
:
False
use_space_char
:
False
Optimizer
:
name
:
Adam
lr
:
name
:
Cosine
learning_rate
:
0.0001
Architecture
:
model_type
:
rec
algorithm
:
SRN
in_channels
:
1
Transform
:
Backbone
:
name
:
ResNetFPN
Head
:
name
:
SRNHead
max_text_length
:
25
num_heads
:
8
num_encoder_TUs
:
2
num_decoder_TUs
:
4
hidden_dims
:
512
Loss
:
name
:
SRNLoss
PostProcess
:
name
:
SRNLabelDecode
Metric
:
name
:
RecMetric
main_indicator
:
acc
Train
:
dataset
:
name
:
LMDBDataSet
data_dir
:
./train_data/srn_train_data_duiqi
#label_file_list: ["./train_data/ic15_data/1.txt"]
transforms
:
-
DecodeImage
:
# load image
img_mode
:
BGR
channel_first
:
False
-
SRNLabelEncode
:
# Class handling label
-
SRNRecResizeImg
:
image_shape
:
[
1
,
64
,
256
]
-
KeepKeys
:
keep_keys
:
[
'
image'
,
'
label'
,
'
length'
,
'
encoder_word_pos'
,
'
gsrm_word_pos'
,
'
gsrm_slf_attn_bias1'
,
'
gsrm_slf_attn_bias2'
]
# dataloader will return list in this order
loader
:
shuffle
:
False
batch_size_per_card
:
64
drop_last
:
True
num_workers
:
4
Eval
:
dataset
:
name
:
LMDBDataSet
data_dir
:
./train_data/data_lmdb_release/evaluation
transforms
:
-
DecodeImage
:
# load image
img_mode
:
BGR
channel_first
:
False
-
SRNLabelEncode
:
# Class handling label
-
SRNRecResizeImg
:
image_shape
:
[
1
,
64
,
256
]
-
KeepKeys
:
keep_keys
:
[
'
image'
,
'
label'
,
'
length'
,
'
encoder_word_pos'
,
'
gsrm_word_pos'
,
'
gsrm_slf_attn_bias1'
,
'
gsrm_slf_attn_bias2'
]
loader
:
shuffle
:
False
drop_last
:
False
batch_size_per_card
:
32
num_workers
:
4
ppocr/data/__init__.py
浏览文件 @
c1fd4664
...
@@ -33,7 +33,7 @@ import paddle.distributed as dist
...
@@ -33,7 +33,7 @@ import paddle.distributed as dist
from
ppocr.data.imaug
import
transform
,
create_operators
from
ppocr.data.imaug
import
transform
,
create_operators
from
ppocr.data.simple_dataset
import
SimpleDataSet
from
ppocr.data.simple_dataset
import
SimpleDataSet
from
ppocr.data.lmdb_dataset
import
LMDBDat
e
Set
from
ppocr.data.lmdb_dataset
import
LMDBDat
a
Set
__all__
=
[
'build_dataloader'
,
'transform'
,
'create_operators'
]
__all__
=
[
'build_dataloader'
,
'transform'
,
'create_operators'
]
...
@@ -54,7 +54,7 @@ signal.signal(signal.SIGTERM, term_mp)
...
@@ -54,7 +54,7 @@ signal.signal(signal.SIGTERM, term_mp)
def
build_dataloader
(
config
,
mode
,
device
,
logger
):
def
build_dataloader
(
config
,
mode
,
device
,
logger
):
config
=
copy
.
deepcopy
(
config
)
config
=
copy
.
deepcopy
(
config
)
support_dict
=
[
'SimpleDataSet'
,
'LMDBDat
e
Set'
]
support_dict
=
[
'SimpleDataSet'
,
'LMDBDat
a
Set'
]
module_name
=
config
[
mode
][
'dataset'
][
'name'
]
module_name
=
config
[
mode
][
'dataset'
][
'name'
]
assert
module_name
in
support_dict
,
Exception
(
assert
module_name
in
support_dict
,
Exception
(
'DataSet only support {}'
.
format
(
support_dict
))
'DataSet only support {}'
.
format
(
support_dict
))
...
...
ppocr/data/imaug/__init__.py
浏览文件 @
c1fd4664
...
@@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap
...
@@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap
from
.make_shrink_map
import
MakeShrinkMap
from
.make_shrink_map
import
MakeShrinkMap
from
.random_crop_data
import
EastRandomCropData
,
PSERandomCrop
from
.random_crop_data
import
EastRandomCropData
,
PSERandomCrop
from
.rec_img_aug
import
RecAug
,
RecResizeImg
,
ClsResizeImg
from
.rec_img_aug
import
RecAug
,
RecResizeImg
,
ClsResizeImg
,
SRNRecResizeImg
from
.randaugment
import
RandAugment
from
.randaugment
import
RandAugment
from
.operators
import
*
from
.operators
import
*
from
.label_ops
import
*
from
.label_ops
import
*
...
...
ppocr/data/imaug/label_ops.py
浏览文件 @
c1fd4664
...
@@ -98,6 +98,8 @@ class BaseRecLabelEncode(object):
...
@@ -98,6 +98,8 @@ class BaseRecLabelEncode(object):
support_character_type
,
character_type
)
support_character_type
,
character_type
)
self
.
max_text_len
=
max_text_length
self
.
max_text_len
=
max_text_length
self
.
beg_str
=
"sos"
self
.
end_str
=
"eos"
if
character_type
==
"en"
:
if
character_type
==
"en"
:
self
.
character_str
=
"0123456789abcdefghijklmnopqrstuvwxyz"
self
.
character_str
=
"0123456789abcdefghijklmnopqrstuvwxyz"
dict_character
=
list
(
self
.
character_str
)
dict_character
=
list
(
self
.
character_str
)
...
@@ -213,3 +215,49 @@ class AttnLabelEncode(BaseRecLabelEncode):
...
@@ -213,3 +215,49 @@ class AttnLabelEncode(BaseRecLabelEncode):
assert
False
,
"Unsupport type %s in get_beg_end_flag_idx"
\
assert
False
,
"Unsupport type %s in get_beg_end_flag_idx"
\
%
beg_or_end
%
beg_or_end
return
idx
return
idx
class
SRNLabelEncode
(
BaseRecLabelEncode
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
max_text_length
=
25
,
character_dict_path
=
None
,
character_type
=
'en'
,
use_space_char
=
False
,
**
kwargs
):
super
(
SRNLabelEncode
,
self
).
__init__
(
max_text_length
,
character_dict_path
,
character_type
,
use_space_char
)
def
add_special_char
(
self
,
dict_character
):
dict_character
=
dict_character
+
[
self
.
beg_str
,
self
.
end_str
]
return
dict_character
def
__call__
(
self
,
data
):
text
=
data
[
'label'
]
text
=
self
.
encode
(
text
)
char_num
=
len
(
self
.
character_str
)
if
text
is
None
:
return
None
if
len
(
text
)
>
self
.
max_text_len
:
return
None
data
[
'length'
]
=
np
.
array
(
len
(
text
))
text
=
text
+
[
char_num
]
*
(
self
.
max_text_len
-
len
(
text
))
data
[
'label'
]
=
np
.
array
(
text
)
return
data
def
get_ignored_tokens
(
self
):
beg_idx
=
self
.
get_beg_end_flag_idx
(
"beg"
)
end_idx
=
self
.
get_beg_end_flag_idx
(
"end"
)
return
[
beg_idx
,
end_idx
]
def
get_beg_end_flag_idx
(
self
,
beg_or_end
):
if
beg_or_end
==
"beg"
:
idx
=
np
.
array
(
self
.
dict
[
self
.
beg_str
])
elif
beg_or_end
==
"end"
:
idx
=
np
.
array
(
self
.
dict
[
self
.
end_str
])
else
:
assert
False
,
"Unsupport type %s in get_beg_end_flag_idx"
\
%
beg_or_end
return
idx
ppocr/data/imaug/rec_img_aug.py
浏览文件 @
c1fd4664
...
@@ -12,20 +12,6 @@
...
@@ -12,20 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
math
import
math
import
cv2
import
cv2
import
numpy
as
np
import
numpy
as
np
...
@@ -77,6 +63,26 @@ class RecResizeImg(object):
...
@@ -77,6 +63,26 @@ class RecResizeImg(object):
return
data
return
data
class
SRNRecResizeImg
(
object
):
def
__init__
(
self
,
image_shape
,
num_heads
,
max_text_length
,
**
kwargs
):
self
.
image_shape
=
image_shape
self
.
num_heads
=
num_heads
self
.
max_text_length
=
max_text_length
def
__call__
(
self
,
data
):
img
=
data
[
'image'
]
norm_img
=
resize_norm_img_srn
(
img
,
self
.
image_shape
)
data
[
'image'
]
=
norm_img
[
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
]
=
\
srn_other_inputs
(
self
.
image_shape
,
self
.
num_heads
,
self
.
max_text_length
)
data
[
'encoder_word_pos'
]
=
encoder_word_pos
data
[
'gsrm_word_pos'
]
=
gsrm_word_pos
data
[
'gsrm_slf_attn_bias1'
]
=
gsrm_slf_attn_bias1
data
[
'gsrm_slf_attn_bias2'
]
=
gsrm_slf_attn_bias2
return
data
def
resize_norm_img
(
img
,
image_shape
):
def
resize_norm_img
(
img
,
image_shape
):
imgC
,
imgH
,
imgW
=
image_shape
imgC
,
imgH
,
imgW
=
image_shape
h
=
img
.
shape
[
0
]
h
=
img
.
shape
[
0
]
...
@@ -103,7 +109,7 @@ def resize_norm_img(img, image_shape):
...
@@ -103,7 +109,7 @@ def resize_norm_img(img, image_shape):
def
resize_norm_img_chinese
(
img
,
image_shape
):
def
resize_norm_img_chinese
(
img
,
image_shape
):
imgC
,
imgH
,
imgW
=
image_shape
imgC
,
imgH
,
imgW
=
image_shape
# todo: change to 0 and modified image shape
# todo: change to 0 and modified image shape
max_wh_ratio
=
0
max_wh_ratio
=
imgW
*
1.0
/
imgH
h
,
w
=
img
.
shape
[
0
],
img
.
shape
[
1
]
h
,
w
=
img
.
shape
[
0
],
img
.
shape
[
1
]
ratio
=
w
*
1.0
/
h
ratio
=
w
*
1.0
/
h
max_wh_ratio
=
max
(
max_wh_ratio
,
ratio
)
max_wh_ratio
=
max
(
max_wh_ratio
,
ratio
)
...
@@ -126,6 +132,60 @@ def resize_norm_img_chinese(img, image_shape):
...
@@ -126,6 +132,60 @@ def resize_norm_img_chinese(img, image_shape):
return
padding_im
return
padding_im
def
resize_norm_img_srn
(
img
,
image_shape
):
imgC
,
imgH
,
imgW
=
image_shape
img_black
=
np
.
zeros
((
imgH
,
imgW
))
im_hei
=
img
.
shape
[
0
]
im_wid
=
img
.
shape
[
1
]
if
im_wid
<=
im_hei
*
1
:
img_new
=
cv2
.
resize
(
img
,
(
imgH
*
1
,
imgH
))
elif
im_wid
<=
im_hei
*
2
:
img_new
=
cv2
.
resize
(
img
,
(
imgH
*
2
,
imgH
))
elif
im_wid
<=
im_hei
*
3
:
img_new
=
cv2
.
resize
(
img
,
(
imgH
*
3
,
imgH
))
else
:
img_new
=
cv2
.
resize
(
img
,
(
imgW
,
imgH
))
img_np
=
np
.
asarray
(
img_new
)
img_np
=
cv2
.
cvtColor
(
img_np
,
cv2
.
COLOR_BGR2GRAY
)
img_black
[:,
0
:
img_np
.
shape
[
1
]]
=
img_np
img_black
=
img_black
[:,
:,
np
.
newaxis
]
row
,
col
,
c
=
img_black
.
shape
c
=
1
return
np
.
reshape
(
img_black
,
(
c
,
row
,
col
)).
astype
(
np
.
float32
)
def
srn_other_inputs
(
image_shape
,
num_heads
,
max_text_length
):
imgC
,
imgH
,
imgW
=
image_shape
feature_dim
=
int
((
imgH
/
8
)
*
(
imgW
/
8
))
encoder_word_pos
=
np
.
array
(
range
(
0
,
feature_dim
)).
reshape
(
(
feature_dim
,
1
)).
astype
(
'int64'
)
gsrm_word_pos
=
np
.
array
(
range
(
0
,
max_text_length
)).
reshape
(
(
max_text_length
,
1
)).
astype
(
'int64'
)
gsrm_attn_bias_data
=
np
.
ones
((
1
,
max_text_length
,
max_text_length
))
gsrm_slf_attn_bias1
=
np
.
triu
(
gsrm_attn_bias_data
,
1
).
reshape
(
[
1
,
max_text_length
,
max_text_length
])
gsrm_slf_attn_bias1
=
np
.
tile
(
gsrm_slf_attn_bias1
,
[
num_heads
,
1
,
1
])
*
[
-
1e9
]
gsrm_slf_attn_bias2
=
np
.
tril
(
gsrm_attn_bias_data
,
-
1
).
reshape
(
[
1
,
max_text_length
,
max_text_length
])
gsrm_slf_attn_bias2
=
np
.
tile
(
gsrm_slf_attn_bias2
,
[
num_heads
,
1
,
1
])
*
[
-
1e9
]
return
[
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
]
def
flag
():
def
flag
():
"""
"""
flag
flag
...
...
ppocr/data/lmdb_dataset.py
浏览文件 @
c1fd4664
...
@@ -20,9 +20,9 @@ import cv2
...
@@ -20,9 +20,9 @@ import cv2
from
.imaug
import
transform
,
create_operators
from
.imaug
import
transform
,
create_operators
class
LMDBDat
e
Set
(
Dataset
):
class
LMDBDat
a
Set
(
Dataset
):
def
__init__
(
self
,
config
,
mode
,
logger
):
def
__init__
(
self
,
config
,
mode
,
logger
):
super
(
LMDBDat
e
Set
,
self
).
__init__
()
super
(
LMDBDat
a
Set
,
self
).
__init__
()
global_config
=
config
[
'Global'
]
global_config
=
config
[
'Global'
]
dataset_config
=
config
[
mode
][
'dataset'
]
dataset_config
=
config
[
mode
][
'dataset'
]
...
...
ppocr/losses/__init__.py
浏览文件 @
c1fd4664
...
@@ -23,11 +23,14 @@ def build_loss(config):
...
@@ -23,11 +23,14 @@ def build_loss(config):
# rec loss
# rec loss
from
.rec_ctc_loss
import
CTCLoss
from
.rec_ctc_loss
import
CTCLoss
from
.rec_srn_loss
import
SRNLoss
# cls loss
# cls loss
from
.cls_loss
import
ClsLoss
from
.cls_loss
import
ClsLoss
support_dict
=
[
'DBLoss'
,
'EASTLoss'
,
'SASTLoss'
,
'CTCLoss'
,
'ClsLoss'
]
support_dict
=
[
'DBLoss'
,
'EASTLoss'
,
'SASTLoss'
,
'CTCLoss'
,
'ClsLoss'
,
'SRNLoss'
]
config
=
copy
.
deepcopy
(
config
)
config
=
copy
.
deepcopy
(
config
)
module_name
=
config
.
pop
(
'name'
)
module_name
=
config
.
pop
(
'name'
)
...
...
ppocr/losses/rec_srn_loss.py
0 → 100644
浏览文件 @
c1fd4664
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
paddle
from
paddle
import
nn
class
SRNLoss
(
nn
.
Layer
):
def
__init__
(
self
,
**
kwargs
):
super
(
SRNLoss
,
self
).
__init__
()
self
.
loss_func
=
paddle
.
nn
.
loss
.
CrossEntropyLoss
(
reduction
=
"sum"
)
def
forward
(
self
,
predicts
,
batch
):
predict
=
predicts
[
'predict'
]
word_predict
=
predicts
[
'word_out'
]
gsrm_predict
=
predicts
[
'gsrm_out'
]
label
=
batch
[
1
]
casted_label
=
paddle
.
cast
(
x
=
label
,
dtype
=
'int64'
)
casted_label
=
paddle
.
reshape
(
x
=
casted_label
,
shape
=
[
-
1
,
1
])
cost_word
=
self
.
loss_func
(
word_predict
,
label
=
casted_label
)
cost_gsrm
=
self
.
loss_func
(
gsrm_predict
,
label
=
casted_label
)
cost_vsfd
=
self
.
loss_func
(
predict
,
label
=
casted_label
)
cost_word
=
paddle
.
reshape
(
x
=
paddle
.
sum
(
cost_word
),
shape
=
[
1
])
cost_gsrm
=
paddle
.
reshape
(
x
=
paddle
.
sum
(
cost_gsrm
),
shape
=
[
1
])
cost_vsfd
=
paddle
.
reshape
(
x
=
paddle
.
sum
(
cost_vsfd
),
shape
=
[
1
])
sum_cost
=
cost_word
+
cost_vsfd
*
2.0
+
cost_gsrm
*
0.15
return
{
'loss'
:
sum_cost
,
'word_loss'
:
cost_word
,
'img_loss'
:
cost_vsfd
}
ppocr/metrics/__init__.py
浏览文件 @
c1fd4664
...
@@ -26,6 +26,7 @@ def build_metric(config):
...
@@ -26,6 +26,7 @@ def build_metric(config):
from
.det_metric
import
DetMetric
from
.det_metric
import
DetMetric
from
.rec_metric
import
RecMetric
from
.rec_metric
import
RecMetric
from
.cls_metric
import
ClsMetric
from
.cls_metric
import
ClsMetric
from
.rec_metric
import
RecMetric
support_dict
=
[
'DetMetric'
,
'RecMetric'
,
'ClsMetric'
]
support_dict
=
[
'DetMetric'
,
'RecMetric'
,
'ClsMetric'
]
...
...
ppocr/metrics/rec_metric.py
浏览文件 @
c1fd4664
...
@@ -31,8 +31,6 @@ class RecMetric(object):
...
@@ -31,8 +31,6 @@ class RecMetric(object):
if
pred
==
target
:
if
pred
==
target
:
correct_num
+=
1
correct_num
+=
1
all_num
+=
1
all_num
+=
1
# if all_num < 10 and kwargs.get('show_str', False):
# print('{} -> {}'.format(pred, target))
self
.
correct_num
+=
correct_num
self
.
correct_num
+=
correct_num
self
.
all_num
+=
all_num
self
.
all_num
+=
all_num
self
.
norm_edit_dis
+=
norm_edit_dis
self
.
norm_edit_dis
+=
norm_edit_dis
...
@@ -48,7 +46,7 @@ class RecMetric(object):
...
@@ -48,7 +46,7 @@ class RecMetric(object):
'norm_edit_dis': 0,
'norm_edit_dis': 0,
}
}
"""
"""
acc
=
self
.
correct_num
/
self
.
all_num
acc
=
1.0
*
self
.
correct_num
/
self
.
all_num
norm_edit_dis
=
1
-
self
.
norm_edit_dis
/
self
.
all_num
norm_edit_dis
=
1
-
self
.
norm_edit_dis
/
self
.
all_num
self
.
reset
()
self
.
reset
()
return
{
'acc'
:
acc
,
'norm_edit_dis'
:
norm_edit_dis
}
return
{
'acc'
:
acc
,
'norm_edit_dis'
:
norm_edit_dis
}
...
...
ppocr/modeling/architectures/base_model.py
浏览文件 @
c1fd4664
...
@@ -68,11 +68,14 @@ class BaseModel(nn.Layer):
...
@@ -68,11 +68,14 @@ class BaseModel(nn.Layer):
config
[
"Head"
][
'in_channels'
]
=
in_channels
config
[
"Head"
][
'in_channels'
]
=
in_channels
self
.
head
=
build_head
(
config
[
"Head"
])
self
.
head
=
build_head
(
config
[
"Head"
])
def
forward
(
self
,
x
):
def
forward
(
self
,
x
,
data
=
None
):
if
self
.
use_transform
:
if
self
.
use_transform
:
x
=
self
.
transform
(
x
)
x
=
self
.
transform
(
x
)
x
=
self
.
backbone
(
x
)
x
=
self
.
backbone
(
x
)
if
self
.
use_neck
:
if
self
.
use_neck
:
x
=
self
.
neck
(
x
)
x
=
self
.
neck
(
x
)
x
=
self
.
head
(
x
)
if
data
is
None
:
x
=
self
.
head
(
x
)
else
:
x
=
self
.
head
(
x
,
data
)
return
x
return
x
ppocr/modeling/backbones/__init__.py
浏览文件 @
c1fd4664
...
@@ -24,7 +24,8 @@ def build_backbone(config, model_type):
...
@@ -24,7 +24,8 @@ def build_backbone(config, model_type):
elif
model_type
==
'rec'
or
model_type
==
'cls'
:
elif
model_type
==
'rec'
or
model_type
==
'cls'
:
from
.rec_mobilenet_v3
import
MobileNetV3
from
.rec_mobilenet_v3
import
MobileNetV3
from
.rec_resnet_vd
import
ResNet
from
.rec_resnet_vd
import
ResNet
support_dict
=
[
'MobileNetV3'
,
'ResNet'
,
'ResNet_FPN'
]
from
.rec_resnet_fpn
import
ResNetFPN
support_dict
=
[
'MobileNetV3'
,
'ResNet'
,
'ResNetFPN'
]
else
:
else
:
raise
NotImplementedError
raise
NotImplementedError
...
...
ppocr/modeling/backbones/rec_resnet_fpn.py
0 → 100644
浏览文件 @
c1fd4664
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
paddle
import
nn
,
ParamAttr
from
paddle.nn
import
functional
as
F
import
paddle.fluid
as
fluid
import
paddle
import
numpy
as
np
__all__
=
[
"ResNetFPN"
]
class
ResNetFPN
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
=
1
,
layers
=
50
,
**
kwargs
):
super
(
ResNetFPN
,
self
).
__init__
()
supported_layers
=
{
18
:
{
'depth'
:
[
2
,
2
,
2
,
2
],
'block_class'
:
BasicBlock
},
34
:
{
'depth'
:
[
3
,
4
,
6
,
3
],
'block_class'
:
BasicBlock
},
50
:
{
'depth'
:
[
3
,
4
,
6
,
3
],
'block_class'
:
BottleneckBlock
},
101
:
{
'depth'
:
[
3
,
4
,
23
,
3
],
'block_class'
:
BottleneckBlock
},
152
:
{
'depth'
:
[
3
,
8
,
36
,
3
],
'block_class'
:
BottleneckBlock
}
}
stride_list
=
[(
2
,
2
),
(
2
,
2
),
(
1
,
1
),
(
1
,
1
)]
num_filters
=
[
64
,
128
,
256
,
512
]
self
.
depth
=
supported_layers
[
layers
][
'depth'
]
self
.
F
=
[]
self
.
conv
=
ConvBNLayer
(
in_channels
=
in_channels
,
out_channels
=
64
,
kernel_size
=
7
,
stride
=
2
,
act
=
"relu"
,
name
=
"conv1"
)
self
.
block_list
=
[]
in_ch
=
64
if
layers
>=
50
:
for
block
in
range
(
len
(
self
.
depth
)):
for
i
in
range
(
self
.
depth
[
block
]):
if
layers
in
[
101
,
152
]
and
block
==
2
:
if
i
==
0
:
conv_name
=
"res"
+
str
(
block
+
2
)
+
"a"
else
:
conv_name
=
"res"
+
str
(
block
+
2
)
+
"b"
+
str
(
i
)
else
:
conv_name
=
"res"
+
str
(
block
+
2
)
+
chr
(
97
+
i
)
block_list
=
self
.
add_sublayer
(
"bottleneckBlock_{}_{}"
.
format
(
block
,
i
),
BottleneckBlock
(
in_channels
=
in_ch
,
out_channels
=
num_filters
[
block
],
stride
=
stride_list
[
block
]
if
i
==
0
else
1
,
name
=
conv_name
))
in_ch
=
num_filters
[
block
]
*
4
self
.
block_list
.
append
(
block_list
)
self
.
F
.
append
(
block_list
)
else
:
for
block
in
range
(
len
(
self
.
depth
)):
for
i
in
range
(
self
.
depth
[
block
]):
conv_name
=
"res"
+
str
(
block
+
2
)
+
chr
(
97
+
i
)
if
i
==
0
and
block
!=
0
:
stride
=
(
2
,
1
)
else
:
stride
=
(
1
,
1
)
basic_block
=
self
.
add_sublayer
(
conv_name
,
BasicBlock
(
in_channels
=
in_ch
,
out_channels
=
num_filters
[
block
],
stride
=
stride_list
[
block
]
if
i
==
0
else
1
,
is_first
=
block
==
i
==
0
,
name
=
conv_name
))
in_ch
=
basic_block
.
out_channels
self
.
block_list
.
append
(
basic_block
)
out_ch_list
=
[
in_ch
//
4
,
in_ch
//
2
,
in_ch
]
self
.
base_block
=
[]
self
.
conv_trans
=
[]
self
.
bn_block
=
[]
for
i
in
[
-
2
,
-
3
]:
in_channels
=
out_ch_list
[
i
+
1
]
+
out_ch_list
[
i
]
self
.
base_block
.
append
(
self
.
add_sublayer
(
"F_{}_base_block_0"
.
format
(
i
),
nn
.
Conv2D
(
in_channels
=
in_channels
,
out_channels
=
out_ch_list
[
i
],
kernel_size
=
1
,
weight_attr
=
ParamAttr
(
trainable
=
True
),
bias_attr
=
ParamAttr
(
trainable
=
True
))))
self
.
base_block
.
append
(
self
.
add_sublayer
(
"F_{}_base_block_1"
.
format
(
i
),
nn
.
Conv2D
(
in_channels
=
out_ch_list
[
i
],
out_channels
=
out_ch_list
[
i
],
kernel_size
=
3
,
padding
=
1
,
weight_attr
=
ParamAttr
(
trainable
=
True
),
bias_attr
=
ParamAttr
(
trainable
=
True
))))
self
.
base_block
.
append
(
self
.
add_sublayer
(
"F_{}_base_block_2"
.
format
(
i
),
nn
.
BatchNorm
(
num_channels
=
out_ch_list
[
i
],
act
=
"relu"
,
param_attr
=
ParamAttr
(
trainable
=
True
),
bias_attr
=
ParamAttr
(
trainable
=
True
))))
self
.
base_block
.
append
(
self
.
add_sublayer
(
"F_{}_base_block_3"
.
format
(
i
),
nn
.
Conv2D
(
in_channels
=
out_ch_list
[
i
],
out_channels
=
512
,
kernel_size
=
1
,
bias_attr
=
ParamAttr
(
trainable
=
True
),
weight_attr
=
ParamAttr
(
trainable
=
True
))))
self
.
out_channels
=
512
def
__call__
(
self
,
x
):
x
=
self
.
conv
(
x
)
fpn_list
=
[]
F
=
[]
for
i
in
range
(
len
(
self
.
depth
)):
fpn_list
.
append
(
np
.
sum
(
self
.
depth
[:
i
+
1
]))
for
i
,
block
in
enumerate
(
self
.
block_list
):
x
=
block
(
x
)
for
number
in
fpn_list
:
if
i
+
1
==
number
:
F
.
append
(
x
)
base
=
F
[
-
1
]
j
=
0
for
i
,
block
in
enumerate
(
self
.
base_block
):
if
i
%
3
==
0
and
i
<
6
:
j
=
j
+
1
b
,
c
,
w
,
h
=
F
[
-
j
-
1
].
shape
if
[
w
,
h
]
==
list
(
base
.
shape
[
2
:]):
base
=
base
else
:
base
=
self
.
conv_trans
[
j
-
1
](
base
)
base
=
self
.
bn_block
[
j
-
1
](
base
)
base
=
paddle
.
concat
([
base
,
F
[
-
j
-
1
]],
axis
=
1
)
base
=
block
(
base
)
return
base
class
ConvBNLayer
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
,
stride
=
1
,
groups
=
1
,
act
=
None
,
name
=
None
):
super
(
ConvBNLayer
,
self
).
__init__
()
self
.
conv
=
nn
.
Conv2D
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
kernel_size
=
2
if
stride
==
(
1
,
1
)
else
kernel_size
,
dilation
=
2
if
stride
==
(
1
,
1
)
else
1
,
stride
=
stride
,
padding
=
(
kernel_size
-
1
)
//
2
,
groups
=
groups
,
weight_attr
=
ParamAttr
(
name
=
name
+
'.conv2d.output.1.w_0'
),
bias_attr
=
False
,
)
if
name
==
"conv1"
:
bn_name
=
"bn_"
+
name
else
:
bn_name
=
"bn"
+
name
[
3
:]
self
.
bn
=
nn
.
BatchNorm
(
num_channels
=
out_channels
,
act
=
act
,
param_attr
=
ParamAttr
(
name
=
name
+
'.output.1.w_0'
),
bias_attr
=
ParamAttr
(
name
=
name
+
'.output.1.b_0'
),
moving_mean_name
=
bn_name
+
"_mean"
,
moving_variance_name
=
bn_name
+
"_variance"
)
def
__call__
(
self
,
x
):
x
=
self
.
conv
(
x
)
x
=
self
.
bn
(
x
)
return
x
class
ShortCut
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
stride
,
name
,
is_first
=
False
):
super
(
ShortCut
,
self
).
__init__
()
self
.
use_conv
=
True
if
in_channels
!=
out_channels
or
stride
!=
1
or
is_first
==
True
:
if
stride
==
(
1
,
1
):
self
.
conv
=
ConvBNLayer
(
in_channels
,
out_channels
,
1
,
1
,
name
=
name
)
else
:
# stride==(2,2)
self
.
conv
=
ConvBNLayer
(
in_channels
,
out_channels
,
1
,
stride
,
name
=
name
)
else
:
self
.
use_conv
=
False
def
forward
(
self
,
x
):
if
self
.
use_conv
:
x
=
self
.
conv
(
x
)
return
x
class
BottleneckBlock
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
stride
,
name
):
super
(
BottleneckBlock
,
self
).
__init__
()
self
.
conv0
=
ConvBNLayer
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
kernel_size
=
1
,
act
=
'relu'
,
name
=
name
+
"_branch2a"
)
self
.
conv1
=
ConvBNLayer
(
in_channels
=
out_channels
,
out_channels
=
out_channels
,
kernel_size
=
3
,
stride
=
stride
,
act
=
'relu'
,
name
=
name
+
"_branch2b"
)
self
.
conv2
=
ConvBNLayer
(
in_channels
=
out_channels
,
out_channels
=
out_channels
*
4
,
kernel_size
=
1
,
act
=
None
,
name
=
name
+
"_branch2c"
)
self
.
short
=
ShortCut
(
in_channels
=
in_channels
,
out_channels
=
out_channels
*
4
,
stride
=
stride
,
is_first
=
False
,
name
=
name
+
"_branch1"
)
self
.
out_channels
=
out_channels
*
4
def
forward
(
self
,
x
):
y
=
self
.
conv0
(
x
)
y
=
self
.
conv1
(
y
)
y
=
self
.
conv2
(
y
)
y
=
y
+
self
.
short
(
x
)
y
=
F
.
relu
(
y
)
return
y
class
BasicBlock
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
stride
,
name
,
is_first
):
super
(
BasicBlock
,
self
).
__init__
()
self
.
conv0
=
ConvBNLayer
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
kernel_size
=
3
,
act
=
'relu'
,
stride
=
stride
,
name
=
name
+
"_branch2a"
)
self
.
conv1
=
ConvBNLayer
(
in_channels
=
out_channels
,
out_channels
=
out_channels
,
kernel_size
=
3
,
act
=
None
,
name
=
name
+
"_branch2b"
)
self
.
short
=
ShortCut
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
stride
=
stride
,
is_first
=
is_first
,
name
=
name
+
"_branch1"
)
self
.
out_channels
=
out_channels
def
forward
(
self
,
x
):
y
=
self
.
conv0
(
x
)
y
=
self
.
conv1
(
y
)
y
=
y
+
self
.
short
(
x
)
return
F
.
relu
(
y
)
ppocr/modeling/heads/__init__.py
浏览文件 @
c1fd4664
...
@@ -23,10 +23,13 @@ def build_head(config):
...
@@ -23,10 +23,13 @@ def build_head(config):
# rec head
# rec head
from
.rec_ctc_head
import
CTCHead
from
.rec_ctc_head
import
CTCHead
from
.rec_srn_head
import
SRNHead
# cls head
# cls head
from
.cls_head
import
ClsHead
from
.cls_head
import
ClsHead
support_dict
=
[
'DBHead'
,
'EASTHead'
,
'SASTHead'
,
'CTCHead'
,
'ClsHead'
]
support_dict
=
[
'DBHead'
,
'EASTHead'
,
'SASTHead'
,
'CTCHead'
,
'ClsHead'
,
'SRNHead'
]
module_name
=
config
.
pop
(
'name'
)
module_name
=
config
.
pop
(
'name'
)
assert
module_name
in
support_dict
,
Exception
(
'head only support {}'
.
format
(
assert
module_name
in
support_dict
,
Exception
(
'head only support {}'
.
format
(
...
...
ppocr/modeling/heads/rec_srn_head.py
0 → 100644
浏览文件 @
c1fd4664
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
math
import
paddle
from
paddle
import
nn
,
ParamAttr
from
paddle.nn
import
functional
as
F
import
paddle.fluid
as
fluid
import
numpy
as
np
from
.self_attention
import
WrapEncoderForFeature
from
.self_attention
import
WrapEncoder
from
paddle.static
import
Program
from
ppocr.modeling.backbones.rec_resnet_fpn
import
ResNetFPN
import
paddle.fluid.framework
as
framework
from
collections
import
OrderedDict
gradient_clip
=
10
class
PVAM
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
char_num
,
max_text_length
,
num_heads
,
num_encoder_tus
,
hidden_dims
):
super
(
PVAM
,
self
).
__init__
()
self
.
char_num
=
char_num
self
.
max_length
=
max_text_length
self
.
num_heads
=
num_heads
self
.
num_encoder_TUs
=
num_encoder_tus
self
.
hidden_dims
=
hidden_dims
# Transformer encoder
t
=
256
c
=
512
self
.
wrap_encoder_for_feature
=
WrapEncoderForFeature
(
src_vocab_size
=
1
,
max_length
=
t
,
n_layer
=
self
.
num_encoder_TUs
,
n_head
=
self
.
num_heads
,
d_key
=
int
(
self
.
hidden_dims
/
self
.
num_heads
),
d_value
=
int
(
self
.
hidden_dims
/
self
.
num_heads
),
d_model
=
self
.
hidden_dims
,
d_inner_hid
=
self
.
hidden_dims
,
prepostprocess_dropout
=
0.1
,
attention_dropout
=
0.1
,
relu_dropout
=
0.1
,
preprocess_cmd
=
"n"
,
postprocess_cmd
=
"da"
,
weight_sharing
=
True
)
# PVAM
self
.
flatten0
=
paddle
.
nn
.
Flatten
(
start_axis
=
0
,
stop_axis
=
1
)
self
.
fc0
=
paddle
.
nn
.
Linear
(
in_features
=
in_channels
,
out_features
=
in_channels
,
)
self
.
emb
=
paddle
.
nn
.
Embedding
(
num_embeddings
=
self
.
max_length
,
embedding_dim
=
in_channels
)
self
.
flatten1
=
paddle
.
nn
.
Flatten
(
start_axis
=
0
,
stop_axis
=
2
)
self
.
fc1
=
paddle
.
nn
.
Linear
(
in_features
=
in_channels
,
out_features
=
1
,
bias_attr
=
False
)
def
forward
(
self
,
inputs
,
encoder_word_pos
,
gsrm_word_pos
):
b
,
c
,
h
,
w
=
inputs
.
shape
conv_features
=
paddle
.
reshape
(
inputs
,
shape
=
[
-
1
,
c
,
h
*
w
])
conv_features
=
paddle
.
transpose
(
conv_features
,
perm
=
[
0
,
2
,
1
])
# transformer encoder
b
,
t
,
c
=
conv_features
.
shape
enc_inputs
=
[
conv_features
,
encoder_word_pos
,
None
]
word_features
=
self
.
wrap_encoder_for_feature
(
enc_inputs
)
# pvam
b
,
t
,
c
=
word_features
.
shape
word_features
=
self
.
fc0
(
word_features
)
word_features_
=
paddle
.
reshape
(
word_features
,
[
-
1
,
1
,
t
,
c
])
word_features_
=
paddle
.
tile
(
word_features_
,
[
1
,
self
.
max_length
,
1
,
1
])
word_pos_feature
=
self
.
emb
(
gsrm_word_pos
)
word_pos_feature_
=
paddle
.
reshape
(
word_pos_feature
,
[
-
1
,
self
.
max_length
,
1
,
c
])
word_pos_feature_
=
paddle
.
tile
(
word_pos_feature_
,
[
1
,
1
,
t
,
1
])
y
=
word_pos_feature_
+
word_features_
y
=
F
.
tanh
(
y
)
attention_weight
=
self
.
fc1
(
y
)
attention_weight
=
paddle
.
reshape
(
attention_weight
,
shape
=
[
-
1
,
self
.
max_length
,
t
])
attention_weight
=
F
.
softmax
(
attention_weight
,
axis
=-
1
)
pvam_features
=
paddle
.
matmul
(
attention_weight
,
word_features
)
#[b, max_length, c]
return
pvam_features
class
GSRM
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
char_num
,
max_text_length
,
num_heads
,
num_encoder_tus
,
num_decoder_tus
,
hidden_dims
):
super
(
GSRM
,
self
).
__init__
()
self
.
char_num
=
char_num
self
.
max_length
=
max_text_length
self
.
num_heads
=
num_heads
self
.
num_encoder_TUs
=
num_encoder_tus
self
.
num_decoder_TUs
=
num_decoder_tus
self
.
hidden_dims
=
hidden_dims
self
.
fc0
=
paddle
.
nn
.
Linear
(
in_features
=
in_channels
,
out_features
=
self
.
char_num
)
self
.
wrap_encoder0
=
WrapEncoder
(
src_vocab_size
=
self
.
char_num
+
1
,
max_length
=
self
.
max_length
,
n_layer
=
self
.
num_decoder_TUs
,
n_head
=
self
.
num_heads
,
d_key
=
int
(
self
.
hidden_dims
/
self
.
num_heads
),
d_value
=
int
(
self
.
hidden_dims
/
self
.
num_heads
),
d_model
=
self
.
hidden_dims
,
d_inner_hid
=
self
.
hidden_dims
,
prepostprocess_dropout
=
0.1
,
attention_dropout
=
0.1
,
relu_dropout
=
0.1
,
preprocess_cmd
=
"n"
,
postprocess_cmd
=
"da"
,
weight_sharing
=
True
)
self
.
wrap_encoder1
=
WrapEncoder
(
src_vocab_size
=
self
.
char_num
+
1
,
max_length
=
self
.
max_length
,
n_layer
=
self
.
num_decoder_TUs
,
n_head
=
self
.
num_heads
,
d_key
=
int
(
self
.
hidden_dims
/
self
.
num_heads
),
d_value
=
int
(
self
.
hidden_dims
/
self
.
num_heads
),
d_model
=
self
.
hidden_dims
,
d_inner_hid
=
self
.
hidden_dims
,
prepostprocess_dropout
=
0.1
,
attention_dropout
=
0.1
,
relu_dropout
=
0.1
,
preprocess_cmd
=
"n"
,
postprocess_cmd
=
"da"
,
weight_sharing
=
True
)
self
.
mul
=
lambda
x
:
paddle
.
matmul
(
x
=
x
,
y
=
self
.
wrap_encoder0
.
prepare_decoder
.
emb0
.
weight
,
transpose_y
=
True
)
def
forward
(
self
,
inputs
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
):
# ===== GSRM Visual-to-semantic embedding block =====
b
,
t
,
c
=
inputs
.
shape
pvam_features
=
paddle
.
reshape
(
inputs
,
[
-
1
,
c
])
word_out
=
self
.
fc0
(
pvam_features
)
word_ids
=
paddle
.
argmax
(
F
.
softmax
(
word_out
),
axis
=
1
)
word_ids
=
paddle
.
reshape
(
x
=
word_ids
,
shape
=
[
-
1
,
t
,
1
])
#===== GSRM Semantic reasoning block =====
"""
This module is achieved through bi-transformers,
ngram_feature1 is the froward one, ngram_fetaure2 is the backward one
"""
pad_idx
=
self
.
char_num
word1
=
paddle
.
cast
(
word_ids
,
"float32"
)
word1
=
F
.
pad
(
word1
,
[
1
,
0
],
value
=
1.0
*
pad_idx
,
data_format
=
"NLC"
)
word1
=
paddle
.
cast
(
word1
,
"int64"
)
word1
=
word1
[:,
:
-
1
,
:]
word2
=
word_ids
enc_inputs_1
=
[
word1
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
]
enc_inputs_2
=
[
word2
,
gsrm_word_pos
,
gsrm_slf_attn_bias2
]
gsrm_feature1
=
self
.
wrap_encoder0
(
enc_inputs_1
)
gsrm_feature2
=
self
.
wrap_encoder1
(
enc_inputs_2
)
gsrm_feature2
=
F
.
pad
(
gsrm_feature2
,
[
0
,
1
],
value
=
0.
,
data_format
=
"NLC"
)
gsrm_feature2
=
gsrm_feature2
[:,
1
:,
]
gsrm_features
=
gsrm_feature1
+
gsrm_feature2
gsrm_out
=
self
.
mul
(
gsrm_features
)
b
,
t
,
c
=
gsrm_out
.
shape
gsrm_out
=
paddle
.
reshape
(
gsrm_out
,
[
-
1
,
c
])
return
gsrm_features
,
word_out
,
gsrm_out
class
VSFD
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
=
512
,
pvam_ch
=
512
,
char_num
=
38
):
super
(
VSFD
,
self
).
__init__
()
self
.
char_num
=
char_num
self
.
fc0
=
paddle
.
nn
.
Linear
(
in_features
=
in_channels
*
2
,
out_features
=
pvam_ch
)
self
.
fc1
=
paddle
.
nn
.
Linear
(
in_features
=
pvam_ch
,
out_features
=
self
.
char_num
)
def
forward
(
self
,
pvam_feature
,
gsrm_feature
):
b
,
t
,
c1
=
pvam_feature
.
shape
b
,
t
,
c2
=
gsrm_feature
.
shape
combine_feature_
=
paddle
.
concat
([
pvam_feature
,
gsrm_feature
],
axis
=
2
)
img_comb_feature_
=
paddle
.
reshape
(
combine_feature_
,
shape
=
[
-
1
,
c1
+
c2
])
img_comb_feature_map
=
self
.
fc0
(
img_comb_feature_
)
img_comb_feature_map
=
F
.
sigmoid
(
img_comb_feature_map
)
img_comb_feature_map
=
paddle
.
reshape
(
img_comb_feature_map
,
shape
=
[
-
1
,
t
,
c1
])
combine_feature
=
img_comb_feature_map
*
pvam_feature
+
(
1.0
-
img_comb_feature_map
)
*
gsrm_feature
img_comb_feature
=
paddle
.
reshape
(
combine_feature
,
shape
=
[
-
1
,
c1
])
out
=
self
.
fc1
(
img_comb_feature
)
return
out
class
SRNHead
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
max_text_length
,
num_heads
,
num_encoder_TUs
,
num_decoder_TUs
,
hidden_dims
,
**
kwargs
):
super
(
SRNHead
,
self
).
__init__
()
self
.
char_num
=
out_channels
self
.
max_length
=
max_text_length
self
.
num_heads
=
num_heads
self
.
num_encoder_TUs
=
num_encoder_TUs
self
.
num_decoder_TUs
=
num_decoder_TUs
self
.
hidden_dims
=
hidden_dims
self
.
pvam
=
PVAM
(
in_channels
=
in_channels
,
char_num
=
self
.
char_num
,
max_text_length
=
self
.
max_length
,
num_heads
=
self
.
num_heads
,
num_encoder_tus
=
self
.
num_encoder_TUs
,
hidden_dims
=
self
.
hidden_dims
)
self
.
gsrm
=
GSRM
(
in_channels
=
in_channels
,
char_num
=
self
.
char_num
,
max_text_length
=
self
.
max_length
,
num_heads
=
self
.
num_heads
,
num_encoder_tus
=
self
.
num_encoder_TUs
,
num_decoder_tus
=
self
.
num_decoder_TUs
,
hidden_dims
=
self
.
hidden_dims
)
self
.
vsfd
=
VSFD
(
in_channels
=
in_channels
)
self
.
gsrm
.
wrap_encoder1
.
prepare_decoder
.
emb0
=
self
.
gsrm
.
wrap_encoder0
.
prepare_decoder
.
emb0
def
forward
(
self
,
inputs
,
others
):
encoder_word_pos
=
others
[
0
]
gsrm_word_pos
=
others
[
1
]
gsrm_slf_attn_bias1
=
others
[
2
]
gsrm_slf_attn_bias2
=
others
[
3
]
pvam_feature
=
self
.
pvam
(
inputs
,
encoder_word_pos
,
gsrm_word_pos
)
gsrm_feature
,
word_out
,
gsrm_out
=
self
.
gsrm
(
pvam_feature
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
)
final_out
=
self
.
vsfd
(
pvam_feature
,
gsrm_feature
)
if
not
self
.
training
:
final_out
=
F
.
softmax
(
final_out
,
axis
=
1
)
_
,
decoded_out
=
paddle
.
topk
(
final_out
,
k
=
1
)
predicts
=
OrderedDict
([
(
'predict'
,
final_out
),
(
'pvam_feature'
,
pvam_feature
),
(
'decoded_out'
,
decoded_out
),
(
'word_out'
,
word_out
),
(
'gsrm_out'
,
gsrm_out
),
])
return
predicts
ppocr/modeling/heads/self_attention.py
0 → 100644
浏览文件 @
c1fd4664
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
math
import
paddle
from
paddle
import
ParamAttr
,
nn
from
paddle
import
nn
,
ParamAttr
from
paddle.nn
import
functional
as
F
import
paddle.fluid
as
fluid
import
numpy
as
np
gradient_clip
=
10
class
WrapEncoderForFeature
(
nn
.
Layer
):
def
__init__
(
self
,
src_vocab_size
,
max_length
,
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
,
postprocess_cmd
,
weight_sharing
,
bos_idx
=
0
):
super
(
WrapEncoderForFeature
,
self
).
__init__
()
self
.
prepare_encoder
=
PrepareEncoder
(
src_vocab_size
,
d_model
,
max_length
,
prepostprocess_dropout
,
bos_idx
=
bos_idx
,
word_emb_param_name
=
"src_word_emb_table"
)
self
.
encoder
=
Encoder
(
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
,
postprocess_cmd
)
def
forward
(
self
,
enc_inputs
):
conv_features
,
src_pos
,
src_slf_attn_bias
=
enc_inputs
enc_input
=
self
.
prepare_encoder
(
conv_features
,
src_pos
)
enc_output
=
self
.
encoder
(
enc_input
,
src_slf_attn_bias
)
return
enc_output
class
WrapEncoder
(
nn
.
Layer
):
"""
embedder + encoder
"""
def
__init__
(
self
,
src_vocab_size
,
max_length
,
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
,
postprocess_cmd
,
weight_sharing
,
bos_idx
=
0
):
super
(
WrapEncoder
,
self
).
__init__
()
self
.
prepare_decoder
=
PrepareDecoder
(
src_vocab_size
,
d_model
,
max_length
,
prepostprocess_dropout
,
bos_idx
=
bos_idx
)
self
.
encoder
=
Encoder
(
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
,
postprocess_cmd
)
def
forward
(
self
,
enc_inputs
):
src_word
,
src_pos
,
src_slf_attn_bias
=
enc_inputs
enc_input
=
self
.
prepare_decoder
(
src_word
,
src_pos
)
enc_output
=
self
.
encoder
(
enc_input
,
src_slf_attn_bias
)
return
enc_output
class
Encoder
(
nn
.
Layer
):
"""
encoder
"""
def
__init__
(
self
,
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
=
"n"
,
postprocess_cmd
=
"da"
):
super
(
Encoder
,
self
).
__init__
()
self
.
encoder_layers
=
list
()
for
i
in
range
(
n_layer
):
self
.
encoder_layers
.
append
(
self
.
add_sublayer
(
"layer_%d"
%
i
,
EncoderLayer
(
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
,
postprocess_cmd
)))
self
.
processer
=
PrePostProcessLayer
(
preprocess_cmd
,
d_model
,
prepostprocess_dropout
)
def
forward
(
self
,
enc_input
,
attn_bias
):
for
encoder_layer
in
self
.
encoder_layers
:
enc_output
=
encoder_layer
(
enc_input
,
attn_bias
)
enc_input
=
enc_output
enc_output
=
self
.
processer
(
enc_output
)
return
enc_output
class
EncoderLayer
(
nn
.
Layer
):
"""
EncoderLayer
"""
def
__init__
(
self
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
=
"n"
,
postprocess_cmd
=
"da"
):
super
(
EncoderLayer
,
self
).
__init__
()
self
.
preprocesser1
=
PrePostProcessLayer
(
preprocess_cmd
,
d_model
,
prepostprocess_dropout
)
self
.
self_attn
=
MultiHeadAttention
(
d_key
,
d_value
,
d_model
,
n_head
,
attention_dropout
)
self
.
postprocesser1
=
PrePostProcessLayer
(
postprocess_cmd
,
d_model
,
prepostprocess_dropout
)
self
.
preprocesser2
=
PrePostProcessLayer
(
preprocess_cmd
,
d_model
,
prepostprocess_dropout
)
self
.
ffn
=
FFN
(
d_inner_hid
,
d_model
,
relu_dropout
)
self
.
postprocesser2
=
PrePostProcessLayer
(
postprocess_cmd
,
d_model
,
prepostprocess_dropout
)
def
forward
(
self
,
enc_input
,
attn_bias
):
attn_output
=
self
.
self_attn
(
self
.
preprocesser1
(
enc_input
),
None
,
None
,
attn_bias
)
attn_output
=
self
.
postprocesser1
(
attn_output
,
enc_input
)
ffn_output
=
self
.
ffn
(
self
.
preprocesser2
(
attn_output
))
ffn_output
=
self
.
postprocesser2
(
ffn_output
,
attn_output
)
return
ffn_output
class
MultiHeadAttention
(
nn
.
Layer
):
"""
Multi-Head Attention
"""
def
__init__
(
self
,
d_key
,
d_value
,
d_model
,
n_head
=
1
,
dropout_rate
=
0.
):
super
(
MultiHeadAttention
,
self
).
__init__
()
self
.
n_head
=
n_head
self
.
d_key
=
d_key
self
.
d_value
=
d_value
self
.
d_model
=
d_model
self
.
dropout_rate
=
dropout_rate
self
.
q_fc
=
paddle
.
nn
.
Linear
(
in_features
=
d_model
,
out_features
=
d_key
*
n_head
,
bias_attr
=
False
)
self
.
k_fc
=
paddle
.
nn
.
Linear
(
in_features
=
d_model
,
out_features
=
d_key
*
n_head
,
bias_attr
=
False
)
self
.
v_fc
=
paddle
.
nn
.
Linear
(
in_features
=
d_model
,
out_features
=
d_value
*
n_head
,
bias_attr
=
False
)
self
.
proj_fc
=
paddle
.
nn
.
Linear
(
in_features
=
d_value
*
n_head
,
out_features
=
d_model
,
bias_attr
=
False
)
def
_prepare_qkv
(
self
,
queries
,
keys
,
values
,
cache
=
None
):
if
keys
is
None
:
# self-attention
keys
,
values
=
queries
,
queries
static_kv
=
False
else
:
# cross-attention
static_kv
=
True
q
=
self
.
q_fc
(
queries
)
q
=
paddle
.
reshape
(
x
=
q
,
shape
=
[
0
,
0
,
self
.
n_head
,
self
.
d_key
])
q
=
paddle
.
transpose
(
x
=
q
,
perm
=
[
0
,
2
,
1
,
3
])
if
cache
is
not
None
and
static_kv
and
"static_k"
in
cache
:
# for encoder-decoder attention in inference and has cached
k
=
cache
[
"static_k"
]
v
=
cache
[
"static_v"
]
else
:
k
=
self
.
k_fc
(
keys
)
v
=
self
.
v_fc
(
values
)
k
=
paddle
.
reshape
(
x
=
k
,
shape
=
[
0
,
0
,
self
.
n_head
,
self
.
d_key
])
k
=
paddle
.
transpose
(
x
=
k
,
perm
=
[
0
,
2
,
1
,
3
])
v
=
paddle
.
reshape
(
x
=
v
,
shape
=
[
0
,
0
,
self
.
n_head
,
self
.
d_value
])
v
=
paddle
.
transpose
(
x
=
v
,
perm
=
[
0
,
2
,
1
,
3
])
if
cache
is
not
None
:
if
static_kv
and
not
"static_k"
in
cache
:
# for encoder-decoder attention in inference and has not cached
cache
[
"static_k"
],
cache
[
"static_v"
]
=
k
,
v
elif
not
static_kv
:
# for decoder self-attention in inference
cache_k
,
cache_v
=
cache
[
"k"
],
cache
[
"v"
]
k
=
paddle
.
concat
([
cache_k
,
k
],
axis
=
2
)
v
=
paddle
.
concat
([
cache_v
,
v
],
axis
=
2
)
cache
[
"k"
],
cache
[
"v"
]
=
k
,
v
return
q
,
k
,
v
def
forward
(
self
,
queries
,
keys
,
values
,
attn_bias
,
cache
=
None
):
# compute q ,k ,v
keys
=
queries
if
keys
is
None
else
keys
values
=
keys
if
values
is
None
else
values
q
,
k
,
v
=
self
.
_prepare_qkv
(
queries
,
keys
,
values
,
cache
)
# scale dot product attention
product
=
paddle
.
matmul
(
x
=
q
,
y
=
k
,
transpose_y
=
True
)
product
=
product
*
self
.
d_model
**-
0.5
if
attn_bias
is
not
None
:
product
+=
attn_bias
weights
=
F
.
softmax
(
product
)
if
self
.
dropout_rate
:
weights
=
F
.
dropout
(
weights
,
p
=
self
.
dropout_rate
,
mode
=
"downscale_in_infer"
)
out
=
paddle
.
matmul
(
weights
,
v
)
# combine heads
out
=
paddle
.
transpose
(
out
,
perm
=
[
0
,
2
,
1
,
3
])
out
=
paddle
.
reshape
(
x
=
out
,
shape
=
[
0
,
0
,
out
.
shape
[
2
]
*
out
.
shape
[
3
]])
# project to output
out
=
self
.
proj_fc
(
out
)
return
out
class
PrePostProcessLayer
(
nn
.
Layer
):
"""
PrePostProcessLayer
"""
def
__init__
(
self
,
process_cmd
,
d_model
,
dropout_rate
):
super
(
PrePostProcessLayer
,
self
).
__init__
()
self
.
process_cmd
=
process_cmd
self
.
functors
=
[]
for
cmd
in
self
.
process_cmd
:
if
cmd
==
"a"
:
# add residual connection
self
.
functors
.
append
(
lambda
x
,
y
:
x
+
y
if
y
is
not
None
else
x
)
elif
cmd
==
"n"
:
# add layer normalization
self
.
functors
.
append
(
self
.
add_sublayer
(
"layer_norm_%d"
%
len
(
self
.
sublayers
(
include_sublayers
=
False
)),
paddle
.
nn
.
LayerNorm
(
normalized_shape
=
d_model
,
weight_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Constant
(
1.
)),
bias_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Constant
(
0.
)))))
elif
cmd
==
"d"
:
# add dropout
self
.
functors
.
append
(
lambda
x
:
F
.
dropout
(
x
,
p
=
dropout_rate
,
mode
=
"downscale_in_infer"
)
if
dropout_rate
else
x
)
def
forward
(
self
,
x
,
residual
=
None
):
for
i
,
cmd
in
enumerate
(
self
.
process_cmd
):
if
cmd
==
"a"
:
x
=
self
.
functors
[
i
](
x
,
residual
)
else
:
x
=
self
.
functors
[
i
](
x
)
return
x
class
PrepareEncoder
(
nn
.
Layer
):
def
__init__
(
self
,
src_vocab_size
,
src_emb_dim
,
src_max_len
,
dropout_rate
=
0
,
bos_idx
=
0
,
word_emb_param_name
=
None
,
pos_enc_param_name
=
None
):
super
(
PrepareEncoder
,
self
).
__init__
()
self
.
src_emb_dim
=
src_emb_dim
self
.
src_max_len
=
src_max_len
self
.
emb
=
paddle
.
nn
.
Embedding
(
num_embeddings
=
self
.
src_max_len
,
embedding_dim
=
self
.
src_emb_dim
,
sparse
=
True
)
self
.
dropout_rate
=
dropout_rate
def
forward
(
self
,
src_word
,
src_pos
):
src_word_emb
=
src_word
src_word_emb
=
fluid
.
layers
.
cast
(
src_word_emb
,
'float32'
)
src_word_emb
=
paddle
.
scale
(
x
=
src_word_emb
,
scale
=
self
.
src_emb_dim
**
0.5
)
src_pos
=
paddle
.
squeeze
(
src_pos
,
axis
=-
1
)
src_pos_enc
=
self
.
emb
(
src_pos
)
src_pos_enc
.
stop_gradient
=
True
enc_input
=
src_word_emb
+
src_pos_enc
if
self
.
dropout_rate
:
out
=
F
.
dropout
(
x
=
enc_input
,
p
=
self
.
dropout_rate
,
mode
=
"downscale_in_infer"
)
else
:
out
=
enc_input
return
out
class
PrepareDecoder
(
nn
.
Layer
):
def
__init__
(
self
,
src_vocab_size
,
src_emb_dim
,
src_max_len
,
dropout_rate
=
0
,
bos_idx
=
0
,
word_emb_param_name
=
None
,
pos_enc_param_name
=
None
):
super
(
PrepareDecoder
,
self
).
__init__
()
self
.
src_emb_dim
=
src_emb_dim
"""
self.emb0 = Embedding(num_embeddings=src_vocab_size,
embedding_dim=src_emb_dim)
"""
self
.
emb0
=
paddle
.
nn
.
Embedding
(
num_embeddings
=
src_vocab_size
,
embedding_dim
=
self
.
src_emb_dim
,
weight_attr
=
paddle
.
ParamAttr
(
name
=
word_emb_param_name
,
initializer
=
nn
.
initializer
.
Normal
(
0.
,
src_emb_dim
**-
0.5
)))
self
.
emb1
=
paddle
.
nn
.
Embedding
(
num_embeddings
=
src_max_len
,
embedding_dim
=
self
.
src_emb_dim
,
weight_attr
=
paddle
.
ParamAttr
(
name
=
pos_enc_param_name
))
self
.
dropout_rate
=
dropout_rate
def
forward
(
self
,
src_word
,
src_pos
):
src_word
=
fluid
.
layers
.
cast
(
src_word
,
'int64'
)
src_word
=
paddle
.
squeeze
(
src_word
,
axis
=-
1
)
src_word_emb
=
self
.
emb0
(
src_word
)
src_word_emb
=
paddle
.
scale
(
x
=
src_word_emb
,
scale
=
self
.
src_emb_dim
**
0.5
)
src_pos
=
paddle
.
squeeze
(
src_pos
,
axis
=-
1
)
src_pos_enc
=
self
.
emb1
(
src_pos
)
src_pos_enc
.
stop_gradient
=
True
enc_input
=
src_word_emb
+
src_pos_enc
if
self
.
dropout_rate
:
out
=
F
.
dropout
(
x
=
enc_input
,
p
=
self
.
dropout_rate
,
mode
=
"downscale_in_infer"
)
else
:
out
=
enc_input
return
out
class
FFN
(
nn
.
Layer
):
"""
Feed-Forward Network
"""
def
__init__
(
self
,
d_inner_hid
,
d_model
,
dropout_rate
):
super
(
FFN
,
self
).
__init__
()
self
.
dropout_rate
=
dropout_rate
self
.
fc1
=
paddle
.
nn
.
Linear
(
in_features
=
d_model
,
out_features
=
d_inner_hid
)
self
.
fc2
=
paddle
.
nn
.
Linear
(
in_features
=
d_inner_hid
,
out_features
=
d_model
)
def
forward
(
self
,
x
):
hidden
=
self
.
fc1
(
x
)
hidden
=
F
.
relu
(
hidden
)
if
self
.
dropout_rate
:
hidden
=
F
.
dropout
(
hidden
,
p
=
self
.
dropout_rate
,
mode
=
"downscale_in_infer"
)
out
=
self
.
fc2
(
hidden
)
return
out
ppocr/postprocess/__init__.py
浏览文件 @
c1fd4664
...
@@ -26,11 +26,12 @@ def build_post_process(config, global_config=None):
...
@@ -26,11 +26,12 @@ def build_post_process(config, global_config=None):
from
.db_postprocess
import
DBPostProcess
from
.db_postprocess
import
DBPostProcess
from
.east_postprocess
import
EASTPostProcess
from
.east_postprocess
import
EASTPostProcess
from
.sast_postprocess
import
SASTPostProcess
from
.sast_postprocess
import
SASTPostProcess
from
.rec_postprocess
import
CTCLabelDecode
,
AttnLabelDecode
from
.rec_postprocess
import
CTCLabelDecode
,
AttnLabelDecode
,
SRNLabelDecode
from
.cls_postprocess
import
ClsPostProcess
from
.cls_postprocess
import
ClsPostProcess
support_dict
=
[
support_dict
=
[
'DBPostProcess'
,
'EASTPostProcess'
,
'SASTPostProcess'
,
'CTCLabelDecode'
,
'AttnLabelDecode'
,
'ClsPostProcess'
'DBPostProcess'
,
'EASTPostProcess'
,
'SASTPostProcess'
,
'CTCLabelDecode'
,
'AttnLabelDecode'
,
'ClsPostProcess'
,
'SRNLabelDecode'
]
]
config
=
copy
.
deepcopy
(
config
)
config
=
copy
.
deepcopy
(
config
)
...
...
ppocr/postprocess/rec_postprocess.py
浏览文件 @
c1fd4664
...
@@ -29,6 +29,9 @@ class BaseRecLabelDecode(object):
...
@@ -29,6 +29,9 @@ class BaseRecLabelDecode(object):
assert
character_type
in
support_character_type
,
"Only {} are supported now but get {}"
.
format
(
assert
character_type
in
support_character_type
,
"Only {} are supported now but get {}"
.
format
(
support_character_type
,
character_type
)
support_character_type
,
character_type
)
self
.
beg_str
=
"sos"
self
.
end_str
=
"eos"
if
character_type
==
"en"
:
if
character_type
==
"en"
:
self
.
character_str
=
"0123456789abcdefghijklmnopqrstuvwxyz"
self
.
character_str
=
"0123456789abcdefghijklmnopqrstuvwxyz"
dict_character
=
list
(
self
.
character_str
)
dict_character
=
list
(
self
.
character_str
)
...
@@ -104,7 +107,6 @@ class CTCLabelDecode(BaseRecLabelDecode):
...
@@ -104,7 +107,6 @@ class CTCLabelDecode(BaseRecLabelDecode):
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
if
isinstance
(
preds
,
paddle
.
Tensor
):
if
isinstance
(
preds
,
paddle
.
Tensor
):
preds
=
preds
.
numpy
()
preds
=
preds
.
numpy
()
preds_idx
=
preds
.
argmax
(
axis
=
2
)
preds_idx
=
preds
.
argmax
(
axis
=
2
)
preds_prob
=
preds
.
max
(
axis
=
2
)
preds_prob
=
preds
.
max
(
axis
=
2
)
text
=
self
.
decode
(
preds_idx
,
preds_prob
)
text
=
self
.
decode
(
preds_idx
,
preds_prob
)
...
@@ -153,3 +155,83 @@ class AttnLabelDecode(BaseRecLabelDecode):
...
@@ -153,3 +155,83 @@ class AttnLabelDecode(BaseRecLabelDecode):
assert
False
,
"unsupport type %s in get_beg_end_flag_idx"
\
assert
False
,
"unsupport type %s in get_beg_end_flag_idx"
\
%
beg_or_end
%
beg_or_end
return
idx
return
idx
class
SRNLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
character_dict_path
=
None
,
character_type
=
'en'
,
use_space_char
=
False
,
**
kwargs
):
super
(
SRNLabelDecode
,
self
).
__init__
(
character_dict_path
,
character_type
,
use_space_char
)
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
pred
=
preds
[
'predict'
]
char_num
=
len
(
self
.
character_str
)
+
2
if
isinstance
(
pred
,
paddle
.
Tensor
):
pred
=
pred
.
numpy
()
pred
=
np
.
reshape
(
pred
,
[
-
1
,
char_num
])
preds_idx
=
np
.
argmax
(
pred
,
axis
=
1
)
preds_prob
=
np
.
max
(
pred
,
axis
=
1
)
preds_idx
=
np
.
reshape
(
preds_idx
,
[
-
1
,
25
])
preds_prob
=
np
.
reshape
(
preds_prob
,
[
-
1
,
25
])
text
=
self
.
decode
(
preds_idx
,
preds_prob
)
if
label
is
None
:
return
text
label
=
self
.
decode
(
label
,
is_remove_duplicate
=
False
)
return
text
,
label
def
decode
(
self
,
text_index
,
text_prob
=
None
,
is_remove_duplicate
=
True
):
""" convert text-index into text-label. """
result_list
=
[]
ignored_tokens
=
self
.
get_ignored_tokens
()
batch_size
=
len
(
text_index
)
for
batch_idx
in
range
(
batch_size
):
char_list
=
[]
conf_list
=
[]
for
idx
in
range
(
len
(
text_index
[
batch_idx
])):
if
text_index
[
batch_idx
][
idx
]
in
ignored_tokens
:
continue
if
is_remove_duplicate
:
# only for predict
if
idx
>
0
and
text_index
[
batch_idx
][
idx
-
1
]
==
text_index
[
batch_idx
][
idx
]:
continue
char_list
.
append
(
self
.
character
[
int
(
text_index
[
batch_idx
][
idx
])])
if
text_prob
is
not
None
:
conf_list
.
append
(
text_prob
[
batch_idx
][
idx
])
else
:
conf_list
.
append
(
1
)
text
=
''
.
join
(
char_list
)
result_list
.
append
((
text
,
np
.
mean
(
conf_list
)))
return
result_list
def
add_special_char
(
self
,
dict_character
):
dict_character
=
dict_character
+
[
self
.
beg_str
,
self
.
end_str
]
return
dict_character
def
get_ignored_tokens
(
self
):
beg_idx
=
self
.
get_beg_end_flag_idx
(
"beg"
)
end_idx
=
self
.
get_beg_end_flag_idx
(
"end"
)
return
[
beg_idx
,
end_idx
]
def
get_beg_end_flag_idx
(
self
,
beg_or_end
):
if
beg_or_end
==
"beg"
:
idx
=
np
.
array
(
self
.
dict
[
self
.
beg_str
])
elif
beg_or_end
==
"end"
:
idx
=
np
.
array
(
self
.
dict
[
self
.
end_str
])
else
:
assert
False
,
"unsupport type %s in get_beg_end_flag_idx"
\
%
beg_or_end
return
idx
tools/export_model.py
浏览文件 @
c1fd4664
...
@@ -31,6 +31,14 @@ from ppocr.utils.logging import get_logger
...
@@ -31,6 +31,14 @@ from ppocr.utils.logging import get_logger
from
tools.program
import
load_config
,
merge_config
,
ArgsParser
from
tools.program
import
load_config
,
merge_config
,
ArgsParser
def
parse_args
():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"-c"
,
"--config"
,
help
=
"configuration file to use"
)
parser
.
add_argument
(
"-o"
,
"--output_path"
,
type
=
str
,
default
=
'./output/infer/'
)
return
parser
.
parse_args
()
def
main
():
def
main
():
FLAGS
=
ArgsParser
().
parse_args
()
FLAGS
=
ArgsParser
().
parse_args
()
config
=
load_config
(
FLAGS
.
config
)
config
=
load_config
(
FLAGS
.
config
)
...
@@ -51,14 +59,33 @@ def main():
...
@@ -51,14 +59,33 @@ def main():
model
.
eval
()
model
.
eval
()
save_path
=
'{}/inference'
.
format
(
config
[
'Global'
][
'save_inference_dir'
])
save_path
=
'{}/inference'
.
format
(
config
[
'Global'
][
'save_inference_dir'
])
infer_shape
=
[
3
,
32
,
100
]
if
config
[
'Architecture'
][
'model_type'
]
!=
"det"
else
[
3
,
640
,
640
]
if
config
[
'Architecture'
][
'algorithm'
]
==
"SRN"
:
model
=
to_static
(
other_shape
=
[
model
,
input_spec
=
[
paddle
.
static
.
InputSpec
(
paddle
.
static
.
InputSpec
(
shape
=
[
None
]
+
infer_shape
,
dtype
=
'float32'
)
shape
=
[
None
,
1
,
64
,
256
],
dtype
=
'float32'
),
[
])
paddle
.
static
.
InputSpec
(
shape
=
[
None
,
256
,
1
],
dtype
=
"int64"
),
paddle
.
static
.
InputSpec
(
shape
=
[
None
,
25
,
1
],
dtype
=
"int64"
),
paddle
.
static
.
InputSpec
(
shape
=
[
None
,
8
,
25
,
25
],
dtype
=
"int64"
),
paddle
.
static
.
InputSpec
(
shape
=
[
None
,
8
,
25
,
25
],
dtype
=
"int64"
)
]
]
model
=
to_static
(
model
,
input_spec
=
other_shape
)
else
:
infer_shape
=
[
3
,
32
,
100
]
if
config
[
'Architecture'
][
'model_type'
]
!=
"det"
else
[
3
,
640
,
640
]
model
=
to_static
(
model
,
input_spec
=
[
paddle
.
static
.
InputSpec
(
shape
=
[
None
]
+
infer_shape
,
dtype
=
'float32'
)
])
paddle
.
jit
.
save
(
model
,
save_path
)
paddle
.
jit
.
save
(
model
,
save_path
)
logger
.
info
(
'inference model is saved to {}'
.
format
(
save_path
))
logger
.
info
(
'inference model is saved to {}'
.
format
(
save_path
))
...
...
tools/infer/predict_rec.py
浏览文件 @
c1fd4664
...
@@ -25,6 +25,7 @@ import numpy as np
...
@@ -25,6 +25,7 @@ import numpy as np
import
math
import
math
import
time
import
time
import
traceback
import
traceback
import
paddle
import
tools.infer.utility
as
utility
import
tools.infer.utility
as
utility
from
ppocr.postprocess
import
build_post_process
from
ppocr.postprocess
import
build_post_process
...
@@ -46,6 +47,13 @@ class TextRecognizer(object):
...
@@ -46,6 +47,13 @@ class TextRecognizer(object):
"character_dict_path"
:
args
.
rec_char_dict_path
,
"character_dict_path"
:
args
.
rec_char_dict_path
,
"use_space_char"
:
args
.
use_space_char
"use_space_char"
:
args
.
use_space_char
}
}
if
self
.
rec_algorithm
==
"SRN"
:
postprocess_params
=
{
'name'
:
'SRNLabelDecode'
,
"character_type"
:
args
.
rec_char_type
,
"character_dict_path"
:
args
.
rec_char_dict_path
,
"use_space_char"
:
args
.
use_space_char
}
self
.
postprocess_op
=
build_post_process
(
postprocess_params
)
self
.
postprocess_op
=
build_post_process
(
postprocess_params
)
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
=
\
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
=
\
utility
.
create_predictor
(
args
,
'rec'
,
logger
)
utility
.
create_predictor
(
args
,
'rec'
,
logger
)
...
@@ -70,6 +78,78 @@ class TextRecognizer(object):
...
@@ -70,6 +78,78 @@ class TextRecognizer(object):
padding_im
[:,
:,
0
:
resized_w
]
=
resized_image
padding_im
[:,
:,
0
:
resized_w
]
=
resized_image
return
padding_im
return
padding_im
def
resize_norm_img_srn
(
self
,
img
,
image_shape
):
imgC
,
imgH
,
imgW
=
image_shape
img_black
=
np
.
zeros
((
imgH
,
imgW
))
im_hei
=
img
.
shape
[
0
]
im_wid
=
img
.
shape
[
1
]
if
im_wid
<=
im_hei
*
1
:
img_new
=
cv2
.
resize
(
img
,
(
imgH
*
1
,
imgH
))
elif
im_wid
<=
im_hei
*
2
:
img_new
=
cv2
.
resize
(
img
,
(
imgH
*
2
,
imgH
))
elif
im_wid
<=
im_hei
*
3
:
img_new
=
cv2
.
resize
(
img
,
(
imgH
*
3
,
imgH
))
else
:
img_new
=
cv2
.
resize
(
img
,
(
imgW
,
imgH
))
img_np
=
np
.
asarray
(
img_new
)
img_np
=
cv2
.
cvtColor
(
img_np
,
cv2
.
COLOR_BGR2GRAY
)
img_black
[:,
0
:
img_np
.
shape
[
1
]]
=
img_np
img_black
=
img_black
[:,
:,
np
.
newaxis
]
row
,
col
,
c
=
img_black
.
shape
c
=
1
return
np
.
reshape
(
img_black
,
(
c
,
row
,
col
)).
astype
(
np
.
float32
)
def
srn_other_inputs
(
self
,
image_shape
,
num_heads
,
max_text_length
):
imgC
,
imgH
,
imgW
=
image_shape
feature_dim
=
int
((
imgH
/
8
)
*
(
imgW
/
8
))
encoder_word_pos
=
np
.
array
(
range
(
0
,
feature_dim
)).
reshape
(
(
feature_dim
,
1
)).
astype
(
'int64'
)
gsrm_word_pos
=
np
.
array
(
range
(
0
,
max_text_length
)).
reshape
(
(
max_text_length
,
1
)).
astype
(
'int64'
)
gsrm_attn_bias_data
=
np
.
ones
((
1
,
max_text_length
,
max_text_length
))
gsrm_slf_attn_bias1
=
np
.
triu
(
gsrm_attn_bias_data
,
1
).
reshape
(
[
-
1
,
1
,
max_text_length
,
max_text_length
])
gsrm_slf_attn_bias1
=
np
.
tile
(
gsrm_slf_attn_bias1
,
[
1
,
num_heads
,
1
,
1
]).
astype
(
'float32'
)
*
[
-
1e9
]
gsrm_slf_attn_bias2
=
np
.
tril
(
gsrm_attn_bias_data
,
-
1
).
reshape
(
[
-
1
,
1
,
max_text_length
,
max_text_length
])
gsrm_slf_attn_bias2
=
np
.
tile
(
gsrm_slf_attn_bias2
,
[
1
,
num_heads
,
1
,
1
]).
astype
(
'float32'
)
*
[
-
1e9
]
encoder_word_pos
=
encoder_word_pos
[
np
.
newaxis
,
:]
gsrm_word_pos
=
gsrm_word_pos
[
np
.
newaxis
,
:]
return
[
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
]
def
process_image_srn
(
self
,
img
,
image_shape
,
num_heads
,
max_text_length
):
norm_img
=
self
.
resize_norm_img_srn
(
img
,
image_shape
)
norm_img
=
norm_img
[
np
.
newaxis
,
:]
[
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
]
=
\
self
.
srn_other_inputs
(
image_shape
,
num_heads
,
max_text_length
)
gsrm_slf_attn_bias1
=
gsrm_slf_attn_bias1
.
astype
(
np
.
float32
)
gsrm_slf_attn_bias2
=
gsrm_slf_attn_bias2
.
astype
(
np
.
float32
)
encoder_word_pos
=
encoder_word_pos
.
astype
(
np
.
int64
)
gsrm_word_pos
=
gsrm_word_pos
.
astype
(
np
.
int64
)
return
(
norm_img
,
encoder_word_pos
,
gsrm_word_pos
,
gsrm_slf_attn_bias1
,
gsrm_slf_attn_bias2
)
def
__call__
(
self
,
img_list
):
def
__call__
(
self
,
img_list
):
img_num
=
len
(
img_list
)
img_num
=
len
(
img_list
)
# Calculate the aspect ratio of all text bars
# Calculate the aspect ratio of all text bars
...
@@ -93,21 +173,64 @@ class TextRecognizer(object):
...
@@ -93,21 +173,64 @@ class TextRecognizer(object):
wh_ratio
=
w
*
1.0
/
h
wh_ratio
=
w
*
1.0
/
h
max_wh_ratio
=
max
(
max_wh_ratio
,
wh_ratio
)
max_wh_ratio
=
max
(
max_wh_ratio
,
wh_ratio
)
for
ino
in
range
(
beg_img_no
,
end_img_no
):
for
ino
in
range
(
beg_img_no
,
end_img_no
):
# norm_img = self.resize_norm_img(img_list[ino], max_wh_ratio)
if
self
.
rec_algorithm
!=
"SRN"
:
norm_img
=
self
.
resize_norm_img
(
img_list
[
indices
[
ino
]],
norm_img
=
self
.
resize_norm_img
(
img_list
[
indices
[
ino
]],
max_wh_ratio
)
max_wh_ratio
)
norm_img
=
norm_img
[
np
.
newaxis
,
:]
norm_img
=
norm_img
[
np
.
newaxis
,
:]
norm_img_batch
.
append
(
norm_img
)
norm_img_batch
.
append
(
norm_img
)
else
:
norm_img
=
self
.
process_image_srn
(
img_list
[
indices
[
ino
]],
self
.
rec_image_shape
,
8
,
25
)
encoder_word_pos_list
=
[]
gsrm_word_pos_list
=
[]
gsrm_slf_attn_bias1_list
=
[]
gsrm_slf_attn_bias2_list
=
[]
encoder_word_pos_list
.
append
(
norm_img
[
1
])
gsrm_word_pos_list
.
append
(
norm_img
[
2
])
gsrm_slf_attn_bias1_list
.
append
(
norm_img
[
3
])
gsrm_slf_attn_bias2_list
.
append
(
norm_img
[
4
])
norm_img_batch
.
append
(
norm_img
[
0
])
norm_img_batch
=
np
.
concatenate
(
norm_img_batch
)
norm_img_batch
=
np
.
concatenate
(
norm_img_batch
)
norm_img_batch
=
norm_img_batch
.
copy
()
norm_img_batch
=
norm_img_batch
.
copy
()
starttime
=
time
.
time
()
self
.
input_tensor
.
copy_from_cpu
(
norm_img_batch
)
if
self
.
rec_algorithm
==
"SRN"
:
self
.
predictor
.
run
()
starttime
=
time
.
time
()
outputs
=
[]
encoder_word_pos_list
=
np
.
concatenate
(
encoder_word_pos_list
)
for
output_tensor
in
self
.
output_tensors
:
gsrm_word_pos_list
=
np
.
concatenate
(
gsrm_word_pos_list
)
output
=
output_tensor
.
copy_to_cpu
()
gsrm_slf_attn_bias1_list
=
np
.
concatenate
(
outputs
.
append
(
output
)
gsrm_slf_attn_bias1_list
)
preds
=
outputs
[
0
]
gsrm_slf_attn_bias2_list
=
np
.
concatenate
(
gsrm_slf_attn_bias2_list
)
inputs
=
[
norm_img_batch
,
encoder_word_pos_list
,
gsrm_word_pos_list
,
gsrm_slf_attn_bias1_list
,
gsrm_slf_attn_bias2_list
,
]
input_names
=
self
.
predictor
.
get_input_names
()
for
i
in
range
(
len
(
input_names
)):
input_tensor
=
self
.
predictor
.
get_input_handle
(
input_names
[
i
])
input_tensor
.
copy_from_cpu
(
inputs
[
i
])
self
.
predictor
.
run
()
outputs
=
[]
for
output_tensor
in
self
.
output_tensors
:
output
=
output_tensor
.
copy_to_cpu
()
outputs
.
append
(
output
)
preds
=
{
"predict"
:
outputs
[
2
]}
else
:
starttime
=
time
.
time
()
self
.
input_tensor
.
copy_from_cpu
(
norm_img_batch
)
self
.
predictor
.
run
()
outputs
=
[]
for
output_tensor
in
self
.
output_tensors
:
output
=
output_tensor
.
copy_to_cpu
()
outputs
.
append
(
output
)
preds
=
outputs
[
0
]
rec_result
=
self
.
postprocess_op
(
preds
)
rec_result
=
self
.
postprocess_op
(
preds
)
for
rno
in
range
(
len
(
rec_result
)):
for
rno
in
range
(
len
(
rec_result
)):
rec_res
[
indices
[
beg_img_no
+
rno
]]
=
rec_result
[
rno
]
rec_res
[
indices
[
beg_img_no
+
rno
]]
=
rec_result
[
rno
]
...
...
tools/infer_rec.py
浏览文件 @
c1fd4664
...
@@ -62,7 +62,13 @@ def main():
...
@@ -62,7 +62,13 @@ def main():
elif
op_name
in
[
'RecResizeImg'
]:
elif
op_name
in
[
'RecResizeImg'
]:
op
[
op_name
][
'infer_mode'
]
=
True
op
[
op_name
][
'infer_mode'
]
=
True
elif
op_name
==
'KeepKeys'
:
elif
op_name
==
'KeepKeys'
:
op
[
op_name
][
'keep_keys'
]
=
[
'image'
]
if
config
[
'Architecture'
][
'algorithm'
]
==
"SRN"
:
op
[
op_name
][
'keep_keys'
]
=
[
'image'
,
'encoder_word_pos'
,
'gsrm_word_pos'
,
'gsrm_slf_attn_bias1'
,
'gsrm_slf_attn_bias2'
]
else
:
op
[
op_name
][
'keep_keys'
]
=
[
'image'
]
transforms
.
append
(
op
)
transforms
.
append
(
op
)
global_config
[
'infer_mode'
]
=
True
global_config
[
'infer_mode'
]
=
True
ops
=
create_operators
(
transforms
,
global_config
)
ops
=
create_operators
(
transforms
,
global_config
)
...
@@ -74,10 +80,25 @@ def main():
...
@@ -74,10 +80,25 @@ def main():
img
=
f
.
read
()
img
=
f
.
read
()
data
=
{
'image'
:
img
}
data
=
{
'image'
:
img
}
batch
=
transform
(
data
,
ops
)
batch
=
transform
(
data
,
ops
)
if
config
[
'Architecture'
][
'algorithm'
]
==
"SRN"
:
encoder_word_pos_list
=
np
.
expand_dims
(
batch
[
1
],
axis
=
0
)
gsrm_word_pos_list
=
np
.
expand_dims
(
batch
[
2
],
axis
=
0
)
gsrm_slf_attn_bias1_list
=
np
.
expand_dims
(
batch
[
3
],
axis
=
0
)
gsrm_slf_attn_bias2_list
=
np
.
expand_dims
(
batch
[
4
],
axis
=
0
)
others
=
[
paddle
.
to_tensor
(
encoder_word_pos_list
),
paddle
.
to_tensor
(
gsrm_word_pos_list
),
paddle
.
to_tensor
(
gsrm_slf_attn_bias1_list
),
paddle
.
to_tensor
(
gsrm_slf_attn_bias2_list
)
]
images
=
np
.
expand_dims
(
batch
[
0
],
axis
=
0
)
images
=
np
.
expand_dims
(
batch
[
0
],
axis
=
0
)
images
=
paddle
.
to_tensor
(
images
)
images
=
paddle
.
to_tensor
(
images
)
preds
=
model
(
images
)
if
config
[
'Architecture'
][
'algorithm'
]
==
"SRN"
:
preds
=
model
(
images
,
others
)
else
:
preds
=
model
(
images
)
post_result
=
post_process_class
(
preds
)
post_result
=
post_process_class
(
preds
)
for
rec_reuslt
in
post_result
:
for
rec_reuslt
in
post_result
:
logger
.
info
(
'
\t
result: {}'
.
format
(
rec_reuslt
))
logger
.
info
(
'
\t
result: {}'
.
format
(
rec_reuslt
))
...
...
tools/program.py
浏览文件 @
c1fd4664
...
@@ -179,9 +179,9 @@ def train(config,
...
@@ -179,9 +179,9 @@ def train(config,
if
'start_epoch'
in
best_model_dict
:
if
'start_epoch'
in
best_model_dict
:
start_epoch
=
best_model_dict
[
'start_epoch'
]
start_epoch
=
best_model_dict
[
'start_epoch'
]
else
:
else
:
start_epoch
=
1
start_epoch
=
0
for
epoch
in
range
(
start_epoch
,
epoch_num
+
1
):
for
epoch
in
range
(
start_epoch
,
epoch_num
):
if
epoch
>
0
:
if
epoch
>
0
:
train_dataloader
=
build_dataloader
(
config
,
'Train'
,
device
,
logger
)
train_dataloader
=
build_dataloader
(
config
,
'Train'
,
device
,
logger
)
train_batch_cost
=
0.0
train_batch_cost
=
0.0
...
@@ -194,7 +194,11 @@ def train(config,
...
@@ -194,7 +194,11 @@ def train(config,
break
break
lr
=
optimizer
.
get_lr
()
lr
=
optimizer
.
get_lr
()
images
=
batch
[
0
]
images
=
batch
[
0
]
preds
=
model
(
images
)
if
config
[
'Architecture'
][
'algorithm'
]
==
"SRN"
:
others
=
batch
[
-
4
:]
preds
=
model
(
images
,
others
)
else
:
preds
=
model
(
images
)
loss
=
loss_class
(
preds
,
batch
)
loss
=
loss_class
(
preds
,
batch
)
avg_loss
=
loss
[
'loss'
]
avg_loss
=
loss
[
'loss'
]
avg_loss
.
backward
()
avg_loss
.
backward
()
...
@@ -212,6 +216,7 @@ def train(config,
...
@@ -212,6 +216,7 @@ def train(config,
stats
[
'lr'
]
=
lr
stats
[
'lr'
]
=
lr
train_stats
.
update
(
stats
)
train_stats
.
update
(
stats
)
#cal_metric_during_train = False
if
cal_metric_during_train
:
# onlt rec and cls need
if
cal_metric_during_train
:
# onlt rec and cls need
batch
=
[
item
.
numpy
()
for
item
in
batch
]
batch
=
[
item
.
numpy
()
for
item
in
batch
]
post_result
=
post_process_class
(
preds
,
batch
[
1
])
post_result
=
post_process_class
(
preds
,
batch
[
1
])
...
@@ -312,8 +317,9 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
...
@@ -312,8 +317,9 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
if
idx
>=
len
(
valid_dataloader
):
if
idx
>=
len
(
valid_dataloader
):
break
break
images
=
batch
[
0
]
images
=
batch
[
0
]
others
=
batch
[
-
4
:]
start
=
time
.
time
()
start
=
time
.
time
()
preds
=
model
(
images
)
preds
=
model
(
images
,
others
)
batch
=
[
item
.
numpy
()
for
item
in
batch
]
batch
=
[
item
.
numpy
()
for
item
in
batch
]
# Obtain usable results from post-processing methods
# Obtain usable results from post-processing methods
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录