Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
weixin_41840029
PaddleOCR
提交
f6532a0e
P
PaddleOCR
项目概览
weixin_41840029
/
PaddleOCR
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleOCR
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f6532a0e
编写于
4月 26, 2022
作者:
A
andyjpaddle
提交者:
GitHub
4月 26, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add ppocrv3 rec (#6033)
* add ppocrv3 rec
上级
6902d160
变更
30
隐藏空白更改
内联
并排
Showing
30 changed file
with
1548 addition
and
55 deletion
+1548
-55
configs/rec/ch_PP-OCRv3/ch_PP-OCRv3_rec.yml
configs/rec/ch_PP-OCRv3/ch_PP-OCRv3_rec.yml
+131
-0
configs/rec/ch_PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml
configs/rec/ch_PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml
+205
-0
ppocr/data/imaug/__init__.py
ppocr/data/imaug/__init__.py
+1
-1
ppocr/data/imaug/label_ops.py
ppocr/data/imaug/label_ops.py
+32
-0
ppocr/data/imaug/rec_img_aug.py
ppocr/data/imaug/rec_img_aug.py
+53
-5
ppocr/data/simple_dataset.py
ppocr/data/simple_dataset.py
+7
-3
ppocr/losses/__init__.py
ppocr/losses/__init__.py
+2
-1
ppocr/losses/basic_loss.py
ppocr/losses/basic_loss.py
+2
-2
ppocr/losses/combined_loss.py
ppocr/losses/combined_loss.py
+2
-0
ppocr/losses/distillation_loss.py
ppocr/losses/distillation_loss.py
+55
-3
ppocr/losses/rec_multi_loss.py
ppocr/losses/rec_multi_loss.py
+58
-0
ppocr/losses/rec_sar_loss.py
ppocr/losses/rec_sar_loss.py
+2
-1
ppocr/metrics/rec_metric.py
ppocr/metrics/rec_metric.py
+9
-3
ppocr/modeling/architectures/base_model.py
ppocr/modeling/architectures/base_model.py
+5
-1
ppocr/modeling/architectures/distillation_model.py
ppocr/modeling/architectures/distillation_model.py
+2
-2
ppocr/modeling/backbones/__init__.py
ppocr/modeling/backbones/__init__.py
+3
-1
ppocr/modeling/backbones/rec_mv1_enhance.py
ppocr/modeling/backbones/rec_mv1_enhance.py
+11
-4
ppocr/modeling/backbones/rec_svtrnet.py
ppocr/modeling/backbones/rec_svtrnet.py
+595
-0
ppocr/modeling/heads/__init__.py
ppocr/modeling/heads/__init__.py
+3
-1
ppocr/modeling/heads/rec_multi_head.py
ppocr/modeling/heads/rec_multi_head.py
+73
-0
ppocr/modeling/heads/rec_sar_head.py
ppocr/modeling/heads/rec_sar_head.py
+11
-3
ppocr/modeling/necks/rnn.py
ppocr/modeling/necks/rnn.py
+106
-7
ppocr/postprocess/__init__.py
ppocr/postprocess/__init__.py
+2
-1
ppocr/postprocess/rec_postprocess.py
ppocr/postprocess/rec_postprocess.py
+38
-0
tools/eval.py
tools/eval.py
+28
-4
tools/export_model.py
tools/export_model.py
+32
-2
tools/infer/predict_rec.py
tools/infer/predict_rec.py
+4
-2
tools/infer_rec.py
tools/infer_rec.py
+22
-2
tools/program.py
tools/program.py
+14
-4
tools/train.py
tools/train.py
+40
-2
未找到文件。
configs/rec/ch_PP-OCRv3/ch_PP-OCRv3_rec.yml
0 → 100644
浏览文件 @
f6532a0e
Global
:
debug
:
false
use_gpu
:
true
epoch_num
:
500
log_smooth_window
:
20
print_batch_step
:
10
save_model_dir
:
./output/rec_ppocr_v3
save_epoch_step
:
3
eval_batch_step
:
[
0
,
2000
]
cal_metric_during_train
:
true
pretrained_model
:
checkpoints
:
save_inference_dir
:
use_visualdl
:
false
infer_img
:
doc/imgs_words/ch/word_1.jpg
character_dict_path
:
ppocr/utils/ppocr_keys_v1.txt
max_text_length
:
&max_text_length
25
infer_mode
:
false
use_space_char
:
true
distributed
:
true
save_res_path
:
./output/rec/predicts_ppocrv3.txt
Optimizer
:
name
:
Adam
beta1
:
0.9
beta2
:
0.999
lr
:
name
:
Cosine
learning_rate
:
0.001
warmup_epoch
:
5
regularizer
:
name
:
L2
factor
:
3.0e-05
Architecture
:
model_type
:
rec
algorithm
:
SVTR
Transform
:
Backbone
:
name
:
MobileNetV1Enhance
scale
:
0.5
last_conv_stride
:
[
1
,
2
]
last_pool_type
:
avg
Head
:
name
:
MultiHead
head_list
:
-
CTCHead
:
Neck
:
name
:
svtr
dims
:
64
depth
:
2
hidden_dims
:
120
use_guide
:
True
Head
:
fc_decay
:
0.00001
-
SARHead
:
enc_dim
:
512
max_text_length
:
*max_text_length
Loss
:
name
:
MultiLoss
loss_config_list
:
-
CTCLoss
:
-
SARLoss
:
PostProcess
:
name
:
CTCLabelDecode
Metric
:
name
:
RecMetric
main_indicator
:
acc
ignore_space
:
True
Train
:
dataset
:
name
:
SimpleDataSet
data_dir
:
./train_data/
ext_op_transform_idx
:
1
label_file_list
:
-
./train_data/train_list.txt
transforms
:
-
DecodeImage
:
img_mode
:
BGR
channel_first
:
false
-
RecConAug
:
prob
:
0.5
ext_data_num
:
2
image_shape
:
[
48
,
320
,
3
]
-
RecAug
:
-
MultiLabelEncode
:
-
RecResizeImg
:
image_shape
:
[
3
,
48
,
320
]
-
KeepKeys
:
keep_keys
:
-
image
-
label_ctc
-
label_sar
-
length
-
valid_ratio
loader
:
shuffle
:
true
batch_size_per_card
:
128
drop_last
:
true
num_workers
:
4
Eval
:
dataset
:
name
:
SimpleDataSet
data_dir
:
./train_data
label_file_list
:
-
./train_data/val_list.txt
transforms
:
-
DecodeImage
:
img_mode
:
BGR
channel_first
:
false
-
MultiLabelEncode
:
-
RecResizeImg
:
image_shape
:
[
3
,
48
,
320
]
-
KeepKeys
:
keep_keys
:
-
image
-
label_ctc
-
label_sar
-
length
-
valid_ratio
loader
:
shuffle
:
false
drop_last
:
false
batch_size_per_card
:
128
num_workers
:
4
configs/rec/ch_PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml
0 → 100644
浏览文件 @
f6532a0e
Global
:
debug
:
false
use_gpu
:
true
epoch_num
:
800
log_smooth_window
:
20
print_batch_step
:
10
save_model_dir
:
./output/rec_ppocr_v3_distillation
save_epoch_step
:
3
eval_batch_step
:
[
0
,
2000
]
cal_metric_during_train
:
true
pretrained_model
:
checkpoints
:
save_inference_dir
:
use_visualdl
:
false
infer_img
:
doc/imgs_words/ch/word_1.jpg
character_dict_path
:
ppocr/utils/ppocr_keys_v1.txt
max_text_length
:
&max_text_length
25
infer_mode
:
false
use_space_char
:
true
distributed
:
true
save_res_path
:
./output/rec/predicts_ppocrv3_distillation.txt
Optimizer
:
name
:
Adam
beta1
:
0.9
beta2
:
0.999
lr
:
name
:
Piecewise
decay_epochs
:
[
700
,
800
]
values
:
[
0.0005
,
0.00005
]
warmup_epoch
:
5
regularizer
:
name
:
L2
factor
:
3.0e-05
Architecture
:
model_type
:
&model_type
"
rec"
name
:
DistillationModel
algorithm
:
Distillation
Models
:
Teacher
:
pretrained
:
freeze_params
:
false
return_all_feats
:
true
model_type
:
*model_type
algorithm
:
SVTR
Transform
:
Backbone
:
name
:
MobileNetV1Enhance
scale
:
0.5
last_conv_stride
:
[
1
,
2
]
last_pool_type
:
avg
Head
:
name
:
MultiHead
head_list
:
-
CTCHead
:
Neck
:
name
:
svtr
dims
:
64
depth
:
2
hidden_dims
:
120
use_guide
:
True
Head
:
fc_decay
:
0.00001
-
SARHead
:
enc_dim
:
512
max_text_length
:
*max_text_length
Student
:
pretrained
:
freeze_params
:
false
return_all_feats
:
true
model_type
:
*model_type
algorithm
:
SVTR
Transform
:
Backbone
:
name
:
MobileNetV1Enhance
scale
:
0.5
last_conv_stride
:
[
1
,
2
]
last_pool_type
:
avg
Head
:
name
:
MultiHead
head_list
:
-
CTCHead
:
Neck
:
name
:
svtr
dims
:
64
depth
:
2
hidden_dims
:
120
use_guide
:
True
Head
:
fc_decay
:
0.00001
-
SARHead
:
enc_dim
:
512
max_text_length
:
*max_text_length
Loss
:
name
:
CombinedLoss
loss_config_list
:
-
DistillationDMLLoss
:
weight
:
1.0
act
:
"
softmax"
use_log
:
true
model_name_pairs
:
-
[
"
Student"
,
"
Teacher"
]
key
:
head_out
multi_head
:
True
dis_head
:
ctc
name
:
dml_ctc
-
DistillationDMLLoss
:
weight
:
0.5
act
:
"
softmax"
use_log
:
true
model_name_pairs
:
-
[
"
Student"
,
"
Teacher"
]
key
:
head_out
multi_head
:
True
dis_head
:
sar
name
:
dml_sar
-
DistillationDistanceLoss
:
weight
:
1.0
mode
:
"
l2"
model_name_pairs
:
-
[
"
Student"
,
"
Teacher"
]
key
:
backbone_out
-
DistillationCTCLoss
:
weight
:
1.0
model_name_list
:
[
"
Student"
,
"
Teacher"
]
key
:
head_out
multi_head
:
True
-
DistillationSARLoss
:
weight
:
1.0
model_name_list
:
[
"
Student"
,
"
Teacher"
]
key
:
head_out
multi_head
:
True
PostProcess
:
name
:
DistillationCTCLabelDecode
model_name
:
[
"
Student"
,
"
Teacher"
]
key
:
head_out
multi_head
:
True
Metric
:
name
:
DistillationMetric
base_metric_name
:
RecMetric
main_indicator
:
acc
key
:
"
Student"
ignore_space
:
True
Train
:
dataset
:
name
:
SimpleDataSet
data_dir
:
./train_data/
ext_op_transform_idx
:
1
label_file_list
:
-
./train_data/train_list.txt
transforms
:
-
DecodeImage
:
img_mode
:
BGR
channel_first
:
false
-
RecConAug
:
prob
:
0.5
ext_data_num
:
2
image_shape
:
[
48
,
320
,
3
]
-
RecAug
:
-
MultiLabelEncode
:
-
RecResizeImg
:
image_shape
:
[
3
,
48
,
320
]
-
KeepKeys
:
keep_keys
:
-
image
-
label_ctc
-
label_sar
-
length
-
valid_ratio
loader
:
shuffle
:
true
batch_size_per_card
:
128
drop_last
:
true
num_workers
:
4
Eval
:
dataset
:
name
:
SimpleDataSet
data_dir
:
./train_data
label_file_list
:
-
./train_data/val_list.txt
transforms
:
-
DecodeImage
:
img_mode
:
BGR
channel_first
:
false
-
MultiLabelEncode
:
-
RecResizeImg
:
image_shape
:
[
3
,
48
,
320
]
-
KeepKeys
:
keep_keys
:
-
image
-
label_ctc
-
label_sar
-
length
-
valid_ratio
loader
:
shuffle
:
false
drop_last
:
false
batch_size_per_card
:
128
num_workers
:
4
ppocr/data/imaug/__init__.py
浏览文件 @
f6532a0e
...
...
@@ -22,7 +22,7 @@ from .make_shrink_map import MakeShrinkMap
from
.random_crop_data
import
EastRandomCropData
,
RandomCropImgMask
from
.make_pse_gt
import
MakePseGt
from
.rec_img_aug
import
RecAug
,
RecResizeImg
,
ClsResizeImg
,
\
from
.rec_img_aug
import
RecAug
,
Rec
ConAug
,
Rec
ResizeImg
,
ClsResizeImg
,
\
SRNRecResizeImg
,
NRTRRecResizeImg
,
SARRecResizeImg
,
PRENResizeImg
from
.randaugment
import
RandAugment
from
.copy_paste
import
CopyPaste
...
...
ppocr/data/imaug/label_ops.py
浏览文件 @
f6532a0e
...
...
@@ -22,6 +22,7 @@ import numpy as np
import
string
from
shapely.geometry
import
LineString
,
Point
,
Polygon
import
json
import
copy
from
ppocr.utils.logging
import
get_logger
...
...
@@ -1007,3 +1008,34 @@ class VQATokenLabelEncode(object):
gt_label
.
extend
([
self
.
label2id_map
[(
"i-"
+
label
).
upper
()]]
*
(
len
(
encode_res
[
"input_ids"
])
-
1
))
return
gt_label
class
MultiLabelEncode
(
BaseRecLabelEncode
):
def
__init__
(
self
,
max_text_length
,
character_dict_path
=
None
,
use_space_char
=
False
,
**
kwargs
):
super
(
MultiLabelEncode
,
self
).
__init__
(
max_text_length
,
character_dict_path
,
use_space_char
)
self
.
ctc_encode
=
CTCLabelEncode
(
max_text_length
,
character_dict_path
,
use_space_char
,
**
kwargs
)
self
.
sar_encode
=
SARLabelEncode
(
max_text_length
,
character_dict_path
,
use_space_char
,
**
kwargs
)
def
__call__
(
self
,
data
):
data_ctc
=
copy
.
deepcopy
(
data
)
data_sar
=
copy
.
deepcopy
(
data
)
data_out
=
dict
()
data_out
[
'img_path'
]
=
data
.
get
(
'img_path'
,
None
)
data_out
[
'image'
]
=
data
[
'image'
]
ctc
=
self
.
ctc_encode
.
__call__
(
data_ctc
)
sar
=
self
.
sar_encode
.
__call__
(
data_sar
)
if
ctc
is
None
or
sar
is
None
:
return
None
data_out
[
'label_ctc'
]
=
ctc
[
'label'
]
data_out
[
'label_sar'
]
=
sar
[
'label'
]
data_out
[
'length'
]
=
ctc
[
'length'
]
return
data_out
ppocr/data/imaug/rec_img_aug.py
浏览文件 @
f6532a0e
...
...
@@ -32,6 +32,49 @@ class RecAug(object):
return
data
class
RecConAug
(
object
):
def
__init__
(
self
,
prob
=
0.5
,
image_shape
=
(
32
,
320
,
3
),
max_text_length
=
25
,
ext_data_num
=
1
,
**
kwargs
):
self
.
ext_data_num
=
ext_data_num
self
.
prob
=
prob
self
.
max_text_length
=
max_text_length
self
.
image_shape
=
image_shape
self
.
max_wh_ratio
=
self
.
image_shape
[
1
]
/
self
.
image_shape
[
0
]
def
merge_ext_data
(
self
,
data
,
ext_data
):
ori_w
=
round
(
data
[
'image'
].
shape
[
1
]
/
data
[
'image'
].
shape
[
0
]
*
self
.
image_shape
[
0
])
ext_w
=
round
(
ext_data
[
'image'
].
shape
[
1
]
/
ext_data
[
'image'
].
shape
[
0
]
*
self
.
image_shape
[
0
])
data
[
'image'
]
=
cv2
.
resize
(
data
[
'image'
],
(
ori_w
,
self
.
image_shape
[
0
]))
ext_data
[
'image'
]
=
cv2
.
resize
(
ext_data
[
'image'
],
(
ext_w
,
self
.
image_shape
[
0
]))
data
[
'image'
]
=
np
.
concatenate
(
[
data
[
'image'
],
ext_data
[
'image'
]],
axis
=
1
)
data
[
"label"
]
+=
ext_data
[
"label"
]
return
data
def
__call__
(
self
,
data
):
rnd_num
=
random
.
random
()
if
rnd_num
>
self
.
prob
:
return
data
for
idx
,
ext_data
in
enumerate
(
data
[
"ext_data"
]):
if
len
(
data
[
"label"
])
+
len
(
ext_data
[
"label"
])
>
self
.
max_text_length
:
break
concat_ratio
=
data
[
'image'
].
shape
[
1
]
/
data
[
'image'
].
shape
[
0
]
+
ext_data
[
'image'
].
shape
[
1
]
/
ext_data
[
'image'
].
shape
[
0
]
if
concat_ratio
>
self
.
max_wh_ratio
:
break
data
=
self
.
merge_ext_data
(
data
,
ext_data
)
data
.
pop
(
"ext_data"
)
return
data
class
ClsResizeImg
(
object
):
def
__init__
(
self
,
image_shape
,
**
kwargs
):
self
.
image_shape
=
image_shape
...
...
@@ -98,10 +141,13 @@ class RecResizeImg(object):
def
__call__
(
self
,
data
):
img
=
data
[
'image'
]
if
self
.
infer_mode
and
self
.
character_dict_path
is
not
None
:
norm_img
=
resize_norm_img_chinese
(
img
,
self
.
image_shape
)
norm_img
,
valid_ratio
=
resize_norm_img_chinese
(
img
,
self
.
image_shape
)
else
:
norm_img
=
resize_norm_img
(
img
,
self
.
image_shape
,
self
.
padding
)
norm_img
,
valid_ratio
=
resize_norm_img
(
img
,
self
.
image_shape
,
self
.
padding
)
data
[
'image'
]
=
norm_img
data
[
'valid_ratio'
]
=
valid_ratio
return
data
...
...
@@ -220,7 +266,8 @@ def resize_norm_img(img, image_shape, padding=True):
resized_image
/=
0.5
padding_im
=
np
.
zeros
((
imgC
,
imgH
,
imgW
),
dtype
=
np
.
float32
)
padding_im
[:,
:,
0
:
resized_w
]
=
resized_image
return
padding_im
valid_ratio
=
min
(
1.0
,
float
(
resized_w
/
imgW
))
return
padding_im
,
valid_ratio
def
resize_norm_img_chinese
(
img
,
image_shape
):
...
...
@@ -230,7 +277,7 @@ def resize_norm_img_chinese(img, image_shape):
h
,
w
=
img
.
shape
[
0
],
img
.
shape
[
1
]
ratio
=
w
*
1.0
/
h
max_wh_ratio
=
max
(
max_wh_ratio
,
ratio
)
imgW
=
int
(
32
*
max_wh_ratio
)
imgW
=
int
(
imgH
*
max_wh_ratio
)
if
math
.
ceil
(
imgH
*
ratio
)
>
imgW
:
resized_w
=
imgW
else
:
...
...
@@ -246,7 +293,8 @@ def resize_norm_img_chinese(img, image_shape):
resized_image
/=
0.5
padding_im
=
np
.
zeros
((
imgC
,
imgH
,
imgW
),
dtype
=
np
.
float32
)
padding_im
[:,
:,
0
:
resized_w
]
=
resized_image
return
padding_im
valid_ratio
=
min
(
1.0
,
float
(
resized_w
/
imgW
))
return
padding_im
,
valid_ratio
def
resize_norm_img_srn
(
img
,
image_shape
):
...
...
ppocr/data/simple_dataset.py
浏览文件 @
f6532a0e
...
...
@@ -49,7 +49,8 @@ class SimpleDataSet(Dataset):
if
self
.
mode
==
"train"
and
self
.
do_shuffle
:
self
.
shuffle_data_random
()
self
.
ops
=
create_operators
(
dataset_config
[
'transforms'
],
global_config
)
self
.
ext_op_transform_idx
=
dataset_config
.
get
(
"ext_op_transform_idx"
,
2
)
self
.
need_reset
=
True
in
[
x
<
1
for
x
in
ratio_list
]
def
get_image_info_list
(
self
,
file_list
,
ratio_list
):
...
...
@@ -87,7 +88,7 @@ class SimpleDataSet(Dataset):
if
hasattr
(
op
,
'ext_data_num'
):
ext_data_num
=
getattr
(
op
,
'ext_data_num'
)
break
load_data_ops
=
self
.
ops
[:
2
]
load_data_ops
=
self
.
ops
[:
self
.
ext_op_transform_idx
]
ext_data
=
[]
while
len
(
ext_data
)
<
ext_data_num
:
...
...
@@ -108,8 +109,11 @@ class SimpleDataSet(Dataset):
data
[
'image'
]
=
img
data
=
transform
(
data
,
load_data_ops
)
if
data
is
None
or
data
[
'polys'
].
shape
[
1
]
!=
4
:
if
data
is
None
:
continue
if
'polys'
in
data
.
keys
():
if
data
[
'polys'
].
shape
[
1
]
!=
4
:
continue
ext_data
.
append
(
data
)
return
ext_data
...
...
ppocr/losses/__init__.py
浏览文件 @
f6532a0e
...
...
@@ -34,6 +34,7 @@ from .rec_nrtr_loss import NRTRLoss
from
.rec_sar_loss
import
SARLoss
from
.rec_aster_loss
import
AsterLoss
from
.rec_pren_loss
import
PRENLoss
from
.rec_multi_loss
import
MultiLoss
# cls loss
from
.cls_loss
import
ClsLoss
...
...
@@ -60,7 +61,7 @@ def build_loss(config):
'DBLoss'
,
'PSELoss'
,
'EASTLoss'
,
'SASTLoss'
,
'FCELoss'
,
'CTCLoss'
,
'ClsLoss'
,
'AttentionLoss'
,
'SRNLoss'
,
'PGLoss'
,
'CombinedLoss'
,
'NRTRLoss'
,
'TableAttentionLoss'
,
'SARLoss'
,
'AsterLoss'
,
'SDMGRLoss'
,
'VQASerTokenLayoutLMLoss'
,
'LossFromOutput'
,
'PRENLoss'
'VQASerTokenLayoutLMLoss'
,
'LossFromOutput'
,
'PRENLoss'
,
'MultiLoss'
]
config
=
copy
.
deepcopy
(
config
)
module_name
=
config
.
pop
(
'name'
)
...
...
ppocr/losses/basic_loss.py
浏览文件 @
f6532a0e
...
...
@@ -106,8 +106,8 @@ class DMLLoss(nn.Layer):
def
forward
(
self
,
out1
,
out2
):
if
self
.
act
is
not
None
:
out1
=
self
.
act
(
out1
)
out2
=
self
.
act
(
out2
)
out1
=
self
.
act
(
out1
)
+
1e-10
out2
=
self
.
act
(
out2
)
+
1e-10
if
self
.
use_log
:
# for recognition distillation, log is needed for feature map
log_out1
=
paddle
.
log
(
out1
)
...
...
ppocr/losses/combined_loss.py
浏览文件 @
f6532a0e
...
...
@@ -18,8 +18,10 @@ import paddle.nn as nn
from
.rec_ctc_loss
import
CTCLoss
from
.center_loss
import
CenterLoss
from
.ace_loss
import
ACELoss
from
.rec_sar_loss
import
SARLoss
from
.distillation_loss
import
DistillationCTCLoss
from
.distillation_loss
import
DistillationSARLoss
from
.distillation_loss
import
DistillationDMLLoss
from
.distillation_loss
import
DistillationDistanceLoss
,
DistillationDBLoss
,
DistillationDilaDBLoss
...
...
ppocr/losses/distillation_loss.py
浏览文件 @
f6532a0e
...
...
@@ -18,6 +18,7 @@ import numpy as np
import
cv2
from
.rec_ctc_loss
import
CTCLoss
from
.rec_sar_loss
import
SARLoss
from
.basic_loss
import
DMLLoss
from
.basic_loss
import
DistanceLoss
from
.det_db_loss
import
DBLoss
...
...
@@ -46,11 +47,15 @@ class DistillationDMLLoss(DMLLoss):
act
=
None
,
use_log
=
False
,
key
=
None
,
multi_head
=
False
,
dis_head
=
'ctc'
,
maps_name
=
None
,
name
=
"dml"
):
super
().
__init__
(
act
=
act
,
use_log
=
use_log
)
assert
isinstance
(
model_name_pairs
,
list
)
self
.
key
=
key
self
.
multi_head
=
multi_head
self
.
dis_head
=
dis_head
self
.
model_name_pairs
=
self
.
_check_model_name_pairs
(
model_name_pairs
)
self
.
name
=
name
self
.
maps_name
=
self
.
_check_maps_name
(
maps_name
)
...
...
@@ -97,7 +102,11 @@ class DistillationDMLLoss(DMLLoss):
out2
=
out2
[
self
.
key
]
if
self
.
maps_name
is
None
:
loss
=
super
().
forward
(
out1
,
out2
)
if
self
.
multi_head
:
loss
=
super
().
forward
(
out1
[
self
.
dis_head
],
out2
[
self
.
dis_head
])
else
:
loss
=
super
().
forward
(
out1
,
out2
)
if
isinstance
(
loss
,
dict
):
for
key
in
loss
:
loss_dict
[
"{}_{}_{}_{}"
.
format
(
key
,
pair
[
0
],
pair
[
1
],
...
...
@@ -123,11 +132,16 @@ class DistillationDMLLoss(DMLLoss):
class
DistillationCTCLoss
(
CTCLoss
):
def
__init__
(
self
,
model_name_list
=
[],
key
=
None
,
name
=
"loss_ctc"
):
def
__init__
(
self
,
model_name_list
=
[],
key
=
None
,
multi_head
=
False
,
name
=
"loss_ctc"
):
super
().
__init__
()
self
.
model_name_list
=
model_name_list
self
.
key
=
key
self
.
name
=
name
self
.
multi_head
=
multi_head
def
forward
(
self
,
predicts
,
batch
):
loss_dict
=
dict
()
...
...
@@ -135,7 +149,45 @@ class DistillationCTCLoss(CTCLoss):
out
=
predicts
[
model_name
]
if
self
.
key
is
not
None
:
out
=
out
[
self
.
key
]
loss
=
super
().
forward
(
out
,
batch
)
if
self
.
multi_head
:
assert
'ctc'
in
out
,
'multi head has multi out'
loss
=
super
().
forward
(
out
[
'ctc'
],
batch
[:
2
]
+
batch
[
3
:])
else
:
loss
=
super
().
forward
(
out
,
batch
)
if
isinstance
(
loss
,
dict
):
for
key
in
loss
:
loss_dict
[
"{}_{}_{}"
.
format
(
self
.
name
,
model_name
,
idx
)]
=
loss
[
key
]
else
:
loss_dict
[
"{}_{}"
.
format
(
self
.
name
,
model_name
)]
=
loss
return
loss_dict
class
DistillationSARLoss
(
SARLoss
):
def
__init__
(
self
,
model_name_list
=
[],
key
=
None
,
multi_head
=
False
,
name
=
"loss_sar"
,
**
kwargs
):
ignore_index
=
kwargs
.
get
(
'ignore_index'
,
92
)
super
().
__init__
(
ignore_index
=
ignore_index
)
self
.
model_name_list
=
model_name_list
self
.
key
=
key
self
.
name
=
name
self
.
multi_head
=
multi_head
def
forward
(
self
,
predicts
,
batch
):
loss_dict
=
dict
()
for
idx
,
model_name
in
enumerate
(
self
.
model_name_list
):
out
=
predicts
[
model_name
]
if
self
.
key
is
not
None
:
out
=
out
[
self
.
key
]
if
self
.
multi_head
:
assert
'sar'
in
out
,
'multi head has multi out'
loss
=
super
().
forward
(
out
[
'sar'
],
batch
[:
1
]
+
batch
[
2
:])
else
:
loss
=
super
().
forward
(
out
,
batch
)
if
isinstance
(
loss
,
dict
):
for
key
in
loss
:
loss_dict
[
"{}_{}_{}"
.
format
(
self
.
name
,
model_name
,
...
...
ppocr/losses/rec_multi_loss.py
0 → 100644
浏览文件 @
f6532a0e
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
paddle
from
paddle
import
nn
from
.rec_ctc_loss
import
CTCLoss
from
.rec_sar_loss
import
SARLoss
class
MultiLoss
(
nn
.
Layer
):
def
__init__
(
self
,
**
kwargs
):
super
().
__init__
()
self
.
loss_funcs
=
{}
self
.
loss_list
=
kwargs
.
pop
(
'loss_config_list'
)
self
.
weight_1
=
kwargs
.
get
(
'weight_1'
,
1.0
)
self
.
weight_2
=
kwargs
.
get
(
'weight_2'
,
1.0
)
self
.
gtc_loss
=
kwargs
.
get
(
'gtc_loss'
,
'sar'
)
for
loss_info
in
self
.
loss_list
:
for
name
,
param
in
loss_info
.
items
():
if
param
is
not
None
:
kwargs
.
update
(
param
)
loss
=
eval
(
name
)(
**
kwargs
)
self
.
loss_funcs
[
name
]
=
loss
def
forward
(
self
,
predicts
,
batch
):
self
.
total_loss
=
{}
total_loss
=
0.0
# batch [image, label_ctc, label_sar, length, valid_ratio]
for
name
,
loss_func
in
self
.
loss_funcs
.
items
():
if
name
==
'CTCLoss'
:
loss
=
loss_func
(
predicts
[
'ctc'
],
batch
[:
2
]
+
batch
[
3
:])[
'loss'
]
*
self
.
weight_1
elif
name
==
'SARLoss'
:
loss
=
loss_func
(
predicts
[
'sar'
],
batch
[:
1
]
+
batch
[
2
:])[
'loss'
]
*
self
.
weight_2
else
:
raise
NotImplementedError
(
'{} is not supported in MultiLoss yet'
.
format
(
name
))
self
.
total_loss
[
name
]
=
loss
total_loss
+=
loss
self
.
total_loss
[
'loss'
]
=
total_loss
return
self
.
total_loss
ppocr/losses/rec_sar_loss.py
浏览文件 @
f6532a0e
...
...
@@ -9,8 +9,9 @@ from paddle import nn
class
SARLoss
(
nn
.
Layer
):
def
__init__
(
self
,
**
kwargs
):
super
(
SARLoss
,
self
).
__init__
()
ignore_index
=
kwargs
.
get
(
'ignore_index'
,
92
)
# 6626
self
.
loss_func
=
paddle
.
nn
.
loss
.
CrossEntropyLoss
(
reduction
=
"mean"
,
ignore_index
=
92
)
reduction
=
"mean"
,
ignore_index
=
ignore_index
)
def
forward
(
self
,
predicts
,
batch
):
predict
=
predicts
[:,
:
...
...
ppocr/metrics/rec_metric.py
浏览文件 @
f6532a0e
...
...
@@ -17,9 +17,14 @@ import string
class
RecMetric
(
object
):
def
__init__
(
self
,
main_indicator
=
'acc'
,
is_filter
=
False
,
**
kwargs
):
def
__init__
(
self
,
main_indicator
=
'acc'
,
is_filter
=
False
,
ignore_space
=
True
,
**
kwargs
):
self
.
main_indicator
=
main_indicator
self
.
is_filter
=
is_filter
self
.
ignore_space
=
ignore_space
self
.
eps
=
1e-5
self
.
reset
()
...
...
@@ -34,8 +39,9 @@ class RecMetric(object):
all_num
=
0
norm_edit_dis
=
0.0
for
(
pred
,
pred_conf
),
(
target
,
_
)
in
zip
(
preds
,
labels
):
pred
=
pred
.
replace
(
" "
,
""
)
target
=
target
.
replace
(
" "
,
""
)
if
self
.
ignore_space
:
pred
=
pred
.
replace
(
" "
,
""
)
target
=
target
.
replace
(
" "
,
""
)
if
self
.
is_filter
:
pred
=
self
.
_normalize_text
(
pred
)
target
=
self
.
_normalize_text
(
target
)
...
...
ppocr/modeling/architectures/base_model.py
浏览文件 @
f6532a0e
...
...
@@ -83,7 +83,11 @@ class BaseModel(nn.Layer):
y
[
"neck_out"
]
=
x
if
self
.
use_head
:
x
=
self
.
head
(
x
,
targets
=
data
)
if
isinstance
(
x
,
dict
):
# for multi head, save ctc neck out for udml
if
isinstance
(
x
,
dict
)
and
'ctc_neck'
in
x
.
keys
():
y
[
"neck_out"
]
=
x
[
"ctc_neck"
]
y
[
"head_out"
]
=
x
elif
isinstance
(
x
,
dict
):
y
.
update
(
x
)
else
:
y
[
"head_out"
]
=
x
...
...
ppocr/modeling/architectures/distillation_model.py
浏览文件 @
f6532a0e
...
...
@@ -53,8 +53,8 @@ class DistillationModel(nn.Layer):
self
.
model_list
.
append
(
self
.
add_sublayer
(
key
,
model
))
self
.
model_name_list
.
append
(
key
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
,
data
=
None
):
result_dict
=
dict
()
for
idx
,
model_name
in
enumerate
(
self
.
model_name_list
):
result_dict
[
model_name
]
=
self
.
model_list
[
idx
](
x
)
result_dict
[
model_name
]
=
self
.
model_list
[
idx
](
x
,
data
)
return
result_dict
ppocr/modeling/backbones/__init__.py
浏览文件 @
f6532a0e
...
...
@@ -31,9 +31,11 @@ def build_backbone(config, model_type):
from
.rec_resnet_aster
import
ResNet_ASTER
from
.rec_micronet
import
MicroNet
from
.rec_efficientb3_pren
import
EfficientNetb3_PREN
from
.rec_svtrnet
import
SVTRNet
support_dict
=
[
'MobileNetV1Enhance'
,
'MobileNetV3'
,
'ResNet'
,
'ResNetFPN'
,
'MTB'
,
"ResNet31"
,
"ResNet_ASTER"
,
'MicroNet'
,
'EfficientNetb3_PREN'
"ResNet31"
,
"ResNet_ASTER"
,
'MicroNet'
,
'EfficientNetb3_PREN'
,
'SVTRNet'
]
elif
model_type
==
"e2e"
:
from
.e2e_resnet_vd_pg
import
ResNet
...
...
ppocr/modeling/backbones/rec_mv1_enhance.py
浏览文件 @
f6532a0e
...
...
@@ -103,7 +103,12 @@ class DepthwiseSeparable(nn.Layer):
class
MobileNetV1Enhance
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
=
3
,
scale
=
0.5
,
**
kwargs
):
def
__init__
(
self
,
in_channels
=
3
,
scale
=
0.5
,
last_conv_stride
=
1
,
last_pool_type
=
'max'
,
**
kwargs
):
super
().
__init__
()
self
.
scale
=
scale
self
.
block_list
=
[]
...
...
@@ -200,7 +205,7 @@ class MobileNetV1Enhance(nn.Layer):
num_filters1
=
1024
,
num_filters2
=
1024
,
num_groups
=
1024
,
stride
=
1
,
stride
=
last_conv_stride
,
dw_size
=
5
,
padding
=
2
,
use_se
=
True
,
...
...
@@ -208,8 +213,10 @@ class MobileNetV1Enhance(nn.Layer):
self
.
block_list
.
append
(
conv6
)
self
.
block_list
=
nn
.
Sequential
(
*
self
.
block_list
)
self
.
pool
=
nn
.
MaxPool2D
(
kernel_size
=
2
,
stride
=
2
,
padding
=
0
)
if
last_pool_type
==
'avg'
:
self
.
pool
=
nn
.
AvgPool2D
(
kernel_size
=
2
,
stride
=
2
,
padding
=
0
)
else
:
self
.
pool
=
nn
.
MaxPool2D
(
kernel_size
=
2
,
stride
=
2
,
padding
=
0
)
self
.
out_channels
=
int
(
1024
*
scale
)
def
forward
(
self
,
inputs
):
...
...
ppocr/modeling/backbones/rec_svtrnet.py
0 → 100644
浏览文件 @
f6532a0e
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
collections
import
Callable
from
paddle
import
ParamAttr
from
paddle.nn.initializer
import
KaimingNormal
import
numpy
as
np
import
paddle
import
paddle.nn
as
nn
from
paddle.nn.initializer
import
TruncatedNormal
,
Constant
,
Normal
trunc_normal_
=
TruncatedNormal
(
std
=
.
02
)
normal_
=
Normal
zeros_
=
Constant
(
value
=
0.
)
ones_
=
Constant
(
value
=
1.
)
def
drop_path
(
x
,
drop_prob
=
0.
,
training
=
False
):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
"""
if
drop_prob
==
0.
or
not
training
:
return
x
keep_prob
=
paddle
.
to_tensor
(
1
-
drop_prob
)
shape
=
(
paddle
.
shape
(
x
)[
0
],
)
+
(
1
,
)
*
(
x
.
ndim
-
1
)
random_tensor
=
keep_prob
+
paddle
.
rand
(
shape
,
dtype
=
x
.
dtype
)
random_tensor
=
paddle
.
floor
(
random_tensor
)
# binarize
output
=
x
.
divide
(
keep_prob
)
*
random_tensor
return
output
class
ConvBNLayer
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
0
,
bias_attr
=
False
,
groups
=
1
,
act
=
nn
.
GELU
):
super
().
__init__
()
self
.
conv
=
nn
.
Conv2D
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
kernel_size
=
kernel_size
,
stride
=
stride
,
padding
=
padding
,
groups
=
groups
,
weight_attr
=
paddle
.
ParamAttr
(
initializer
=
nn
.
initializer
.
KaimingUniform
()),
bias_attr
=
bias_attr
)
self
.
norm
=
nn
.
BatchNorm2D
(
out_channels
)
self
.
act
=
act
()
def
forward
(
self
,
inputs
):
out
=
self
.
conv
(
inputs
)
out
=
self
.
norm
(
out
)
out
=
self
.
act
(
out
)
return
out
class
DropPath
(
nn
.
Layer
):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def
__init__
(
self
,
drop_prob
=
None
):
super
(
DropPath
,
self
).
__init__
()
self
.
drop_prob
=
drop_prob
def
forward
(
self
,
x
):
return
drop_path
(
x
,
self
.
drop_prob
,
self
.
training
)
class
Identity
(
nn
.
Layer
):
def
__init__
(
self
):
super
(
Identity
,
self
).
__init__
()
def
forward
(
self
,
input
):
return
input
class
Mlp
(
nn
.
Layer
):
def
__init__
(
self
,
in_features
,
hidden_features
=
None
,
out_features
=
None
,
act_layer
=
nn
.
GELU
,
drop
=
0.
):
super
().
__init__
()
out_features
=
out_features
or
in_features
hidden_features
=
hidden_features
or
in_features
self
.
fc1
=
nn
.
Linear
(
in_features
,
hidden_features
)
self
.
act
=
act_layer
()
self
.
fc2
=
nn
.
Linear
(
hidden_features
,
out_features
)
self
.
drop
=
nn
.
Dropout
(
drop
)
def
forward
(
self
,
x
):
x
=
self
.
fc1
(
x
)
x
=
self
.
act
(
x
)
x
=
self
.
drop
(
x
)
x
=
self
.
fc2
(
x
)
x
=
self
.
drop
(
x
)
return
x
class
ConvMixer
(
nn
.
Layer
):
def
__init__
(
self
,
dim
,
num_heads
=
8
,
HW
=
[
8
,
25
],
local_k
=
[
3
,
3
],
):
super
().
__init__
()
self
.
HW
=
HW
self
.
dim
=
dim
self
.
local_mixer
=
nn
.
Conv2D
(
dim
,
dim
,
local_k
,
1
,
[
local_k
[
0
]
//
2
,
local_k
[
1
]
//
2
],
groups
=
num_heads
,
weight_attr
=
ParamAttr
(
initializer
=
KaimingNormal
()))
def
forward
(
self
,
x
):
h
=
self
.
HW
[
0
]
w
=
self
.
HW
[
1
]
x
=
x
.
transpose
([
0
,
2
,
1
]).
reshape
([
0
,
self
.
dim
,
h
,
w
])
x
=
self
.
local_mixer
(
x
)
x
=
x
.
flatten
(
2
).
transpose
([
0
,
2
,
1
])
return
x
class
Attention
(
nn
.
Layer
):
def
__init__
(
self
,
dim
,
num_heads
=
8
,
mixer
=
'Global'
,
HW
=
[
8
,
25
],
local_k
=
[
7
,
11
],
qkv_bias
=
False
,
qk_scale
=
None
,
attn_drop
=
0.
,
proj_drop
=
0.
):
super
().
__init__
()
self
.
num_heads
=
num_heads
head_dim
=
dim
//
num_heads
self
.
scale
=
qk_scale
or
head_dim
**-
0.5
self
.
qkv
=
nn
.
Linear
(
dim
,
dim
*
3
,
bias_attr
=
qkv_bias
)
self
.
attn_drop
=
nn
.
Dropout
(
attn_drop
)
self
.
proj
=
nn
.
Linear
(
dim
,
dim
)
self
.
proj_drop
=
nn
.
Dropout
(
proj_drop
)
self
.
HW
=
HW
if
HW
is
not
None
:
H
=
HW
[
0
]
W
=
HW
[
1
]
self
.
N
=
H
*
W
self
.
C
=
dim
if
mixer
==
'Local'
and
HW
is
not
None
:
hk
=
local_k
[
0
]
wk
=
local_k
[
1
]
mask
=
np
.
ones
([
H
*
W
,
H
*
W
])
for
h
in
range
(
H
):
for
w
in
range
(
W
):
for
kh
in
range
(
-
(
hk
//
2
),
(
hk
//
2
)
+
1
):
for
kw
in
range
(
-
(
wk
//
2
),
(
wk
//
2
)
+
1
):
if
H
>
(
h
+
kh
)
>=
0
and
W
>
(
w
+
kw
)
>=
0
:
mask
[
h
*
W
+
w
][(
h
+
kh
)
*
W
+
(
w
+
kw
)]
=
0
mask_paddle
=
paddle
.
to_tensor
(
mask
,
dtype
=
'float32'
)
mask_inf
=
paddle
.
full
([
H
*
W
,
H
*
W
],
'-inf'
,
dtype
=
'float32'
)
mask
=
paddle
.
where
(
mask_paddle
<
1
,
mask_paddle
,
mask_inf
)
self
.
mask
=
mask
.
unsqueeze
([
0
,
1
])
self
.
mixer
=
mixer
def
forward
(
self
,
x
):
if
self
.
HW
is
not
None
:
N
=
self
.
N
C
=
self
.
C
else
:
_
,
N
,
C
=
x
.
shape
qkv
=
self
.
qkv
(
x
).
reshape
((
0
,
N
,
3
,
self
.
num_heads
,
C
//
self
.
num_heads
)).
transpose
((
2
,
0
,
3
,
1
,
4
))
q
,
k
,
v
=
qkv
[
0
]
*
self
.
scale
,
qkv
[
1
],
qkv
[
2
]
attn
=
(
q
.
matmul
(
k
.
transpose
((
0
,
1
,
3
,
2
))))
if
self
.
mixer
==
'Local'
:
attn
+=
self
.
mask
attn
=
nn
.
functional
.
softmax
(
attn
,
axis
=-
1
)
attn
=
self
.
attn_drop
(
attn
)
x
=
(
attn
.
matmul
(
v
)).
transpose
((
0
,
2
,
1
,
3
)).
reshape
((
0
,
N
,
C
))
x
=
self
.
proj
(
x
)
x
=
self
.
proj_drop
(
x
)
return
x
class
Block
(
nn
.
Layer
):
def
__init__
(
self
,
dim
,
num_heads
,
mixer
=
'Global'
,
local_mixer
=
[
7
,
11
],
HW
=
[
8
,
25
],
mlp_ratio
=
4.
,
qkv_bias
=
False
,
qk_scale
=
None
,
drop
=
0.
,
attn_drop
=
0.
,
drop_path
=
0.
,
act_layer
=
nn
.
GELU
,
norm_layer
=
'nn.LayerNorm'
,
epsilon
=
1e-6
,
prenorm
=
True
):
super
().
__init__
()
if
isinstance
(
norm_layer
,
str
):
self
.
norm1
=
eval
(
norm_layer
)(
dim
,
epsilon
=
epsilon
)
elif
isinstance
(
norm_layer
,
Callable
):
self
.
norm1
=
norm_layer
(
dim
)
else
:
raise
TypeError
(
"The norm_layer must be str or paddle.nn.layer.Layer class"
)
if
mixer
==
'Global'
or
mixer
==
'Local'
:
self
.
mixer
=
Attention
(
dim
,
num_heads
=
num_heads
,
mixer
=
mixer
,
HW
=
HW
,
local_k
=
local_mixer
,
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
attn_drop
=
attn_drop
,
proj_drop
=
drop
)
elif
mixer
==
'Conv'
:
self
.
mixer
=
ConvMixer
(
dim
,
num_heads
=
num_heads
,
HW
=
HW
,
local_k
=
local_mixer
)
else
:
raise
TypeError
(
"The mixer must be one of [Global, Local, Conv]"
)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self
.
drop_path
=
DropPath
(
drop_path
)
if
drop_path
>
0.
else
Identity
()
if
isinstance
(
norm_layer
,
str
):
self
.
norm2
=
eval
(
norm_layer
)(
dim
,
epsilon
=
epsilon
)
elif
isinstance
(
norm_layer
,
Callable
):
self
.
norm2
=
norm_layer
(
dim
)
else
:
raise
TypeError
(
"The norm_layer must be str or paddle.nn.layer.Layer class"
)
mlp_hidden_dim
=
int
(
dim
*
mlp_ratio
)
self
.
mlp_ratio
=
mlp_ratio
self
.
mlp
=
Mlp
(
in_features
=
dim
,
hidden_features
=
mlp_hidden_dim
,
act_layer
=
act_layer
,
drop
=
drop
)
self
.
prenorm
=
prenorm
def
forward
(
self
,
x
):
if
self
.
prenorm
:
x
=
self
.
norm1
(
x
+
self
.
drop_path
(
self
.
mixer
(
x
)))
x
=
self
.
norm2
(
x
+
self
.
drop_path
(
self
.
mlp
(
x
)))
else
:
x
=
x
+
self
.
drop_path
(
self
.
mixer
(
self
.
norm1
(
x
)))
x
=
x
+
self
.
drop_path
(
self
.
mlp
(
self
.
norm2
(
x
)))
return
x
class
PatchEmbed
(
nn
.
Layer
):
""" Image to Patch Embedding
"""
def
__init__
(
self
,
img_size
=
[
32
,
100
],
in_channels
=
3
,
embed_dim
=
768
,
sub_num
=
2
):
super
().
__init__
()
num_patches
=
(
img_size
[
1
]
//
(
2
**
sub_num
))
*
\
(
img_size
[
0
]
//
(
2
**
sub_num
))
self
.
img_size
=
img_size
self
.
num_patches
=
num_patches
self
.
embed_dim
=
embed_dim
self
.
norm
=
None
if
sub_num
==
2
:
self
.
proj
=
nn
.
Sequential
(
ConvBNLayer
(
in_channels
,
embed_dim
//
2
,
3
,
2
,
1
,
act
=
nn
.
GELU
,
bias_attr
=
None
),
ConvBNLayer
(
embed_dim
//
2
,
embed_dim
,
3
,
2
,
1
,
act
=
nn
.
GELU
,
bias_attr
=
None
))
if
sub_num
==
3
:
self
.
proj
=
nn
.
Sequential
(
ConvBNLayer
(
in_channels
,
embed_dim
//
4
,
3
,
2
,
1
,
act
=
nn
.
GELU
,
bias_attr
=
None
),
ConvBNLayer
(
embed_dim
//
4
,
embed_dim
//
2
,
3
,
2
,
1
,
act
=
nn
.
GELU
,
bias_attr
=
None
),
ConvBNLayer
(
embed_dim
//
2
,
embed_dim
,
3
,
2
,
1
,
act
=
nn
.
GELU
,
bias_attr
=
None
),
)
def
forward
(
self
,
x
):
B
,
C
,
H
,
W
=
x
.
shape
assert
H
==
self
.
img_size
[
0
]
and
W
==
self
.
img_size
[
1
],
\
f
"Input image size (
{
H
}
*
{
W
}
) doesn't match model (
{
self
.
img_size
[
0
]
}
*
{
self
.
img_size
[
1
]
}
)."
x
=
self
.
proj
(
x
).
flatten
(
2
).
transpose
((
0
,
2
,
1
))
return
x
class
SubSample
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
types
=
'Pool'
,
stride
=
[
2
,
1
],
sub_norm
=
'nn.LayerNorm'
,
act
=
None
):
super
().
__init__
()
self
.
types
=
types
if
types
==
'Pool'
:
self
.
avgpool
=
nn
.
AvgPool2D
(
kernel_size
=
[
3
,
5
],
stride
=
stride
,
padding
=
[
1
,
2
])
self
.
maxpool
=
nn
.
MaxPool2D
(
kernel_size
=
[
3
,
5
],
stride
=
stride
,
padding
=
[
1
,
2
])
self
.
proj
=
nn
.
Linear
(
in_channels
,
out_channels
)
else
:
self
.
conv
=
nn
.
Conv2D
(
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
stride
,
padding
=
1
,
weight_attr
=
ParamAttr
(
initializer
=
KaimingNormal
()))
self
.
norm
=
eval
(
sub_norm
)(
out_channels
)
if
act
is
not
None
:
self
.
act
=
act
()
else
:
self
.
act
=
None
def
forward
(
self
,
x
):
if
self
.
types
==
'Pool'
:
x1
=
self
.
avgpool
(
x
)
x2
=
self
.
maxpool
(
x
)
x
=
(
x1
+
x2
)
*
0.5
out
=
self
.
proj
(
x
.
flatten
(
2
).
transpose
((
0
,
2
,
1
)))
else
:
x
=
self
.
conv
(
x
)
out
=
x
.
flatten
(
2
).
transpose
((
0
,
2
,
1
))
out
=
self
.
norm
(
out
)
if
self
.
act
is
not
None
:
out
=
self
.
act
(
out
)
return
out
class
SVTRNet
(
nn
.
Layer
):
def
__init__
(
self
,
img_size
=
[
32
,
100
],
in_channels
=
3
,
embed_dim
=
[
64
,
128
,
256
],
depth
=
[
3
,
6
,
3
],
num_heads
=
[
2
,
4
,
8
],
mixer
=
[
'Local'
]
*
6
+
[
'Global'
]
*
6
,
# Local atten, Global atten, Conv
local_mixer
=
[[
7
,
11
],
[
7
,
11
],
[
7
,
11
]],
patch_merging
=
'Conv'
,
# Conv, Pool, None
mlp_ratio
=
4
,
qkv_bias
=
True
,
qk_scale
=
None
,
drop_rate
=
0.
,
last_drop
=
0.1
,
attn_drop_rate
=
0.
,
drop_path_rate
=
0.1
,
norm_layer
=
'nn.LayerNorm'
,
sub_norm
=
'nn.LayerNorm'
,
epsilon
=
1e-6
,
out_channels
=
192
,
out_char_num
=
25
,
block_unit
=
'Block'
,
act
=
'nn.GELU'
,
last_stage
=
True
,
sub_num
=
2
,
prenorm
=
True
,
use_lenhead
=
False
,
**
kwargs
):
super
().
__init__
()
self
.
img_size
=
img_size
self
.
embed_dim
=
embed_dim
self
.
out_channels
=
out_channels
self
.
prenorm
=
prenorm
patch_merging
=
None
if
patch_merging
!=
'Conv'
and
patch_merging
!=
'Pool'
else
patch_merging
self
.
patch_embed
=
PatchEmbed
(
img_size
=
img_size
,
in_channels
=
in_channels
,
embed_dim
=
embed_dim
[
0
],
sub_num
=
sub_num
)
num_patches
=
self
.
patch_embed
.
num_patches
self
.
HW
=
[
img_size
[
0
]
//
(
2
**
sub_num
),
img_size
[
1
]
//
(
2
**
sub_num
)]
self
.
pos_embed
=
self
.
create_parameter
(
shape
=
[
1
,
num_patches
,
embed_dim
[
0
]],
default_initializer
=
zeros_
)
self
.
add_parameter
(
"pos_embed"
,
self
.
pos_embed
)
self
.
pos_drop
=
nn
.
Dropout
(
p
=
drop_rate
)
Block_unit
=
eval
(
block_unit
)
dpr
=
np
.
linspace
(
0
,
drop_path_rate
,
sum
(
depth
))
self
.
blocks1
=
nn
.
LayerList
([
Block_unit
(
dim
=
embed_dim
[
0
],
num_heads
=
num_heads
[
0
],
mixer
=
mixer
[
0
:
depth
[
0
]][
i
],
HW
=
self
.
HW
,
local_mixer
=
local_mixer
[
0
],
mlp_ratio
=
mlp_ratio
,
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
drop
=
drop_rate
,
act_layer
=
nn
.
Swish
,
attn_drop
=
attn_drop_rate
,
drop_path
=
dpr
[
0
:
depth
[
0
]][
i
],
norm_layer
=
norm_layer
,
epsilon
=
epsilon
,
prenorm
=
prenorm
)
for
i
in
range
(
depth
[
0
])
])
if
patch_merging
is
not
None
:
self
.
sub_sample1
=
SubSample
(
embed_dim
[
0
],
embed_dim
[
1
],
sub_norm
=
sub_norm
,
stride
=
[
2
,
1
],
types
=
patch_merging
)
HW
=
[
self
.
HW
[
0
]
//
2
,
self
.
HW
[
1
]]
else
:
HW
=
self
.
HW
self
.
patch_merging
=
patch_merging
self
.
blocks2
=
nn
.
LayerList
([
Block_unit
(
dim
=
embed_dim
[
1
],
num_heads
=
num_heads
[
1
],
mixer
=
mixer
[
depth
[
0
]:
depth
[
0
]
+
depth
[
1
]][
i
],
HW
=
HW
,
local_mixer
=
local_mixer
[
1
],
mlp_ratio
=
mlp_ratio
,
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
drop
=
drop_rate
,
act_layer
=
eval
(
act
),
attn_drop
=
attn_drop_rate
,
drop_path
=
dpr
[
depth
[
0
]:
depth
[
0
]
+
depth
[
1
]][
i
],
norm_layer
=
norm_layer
,
epsilon
=
epsilon
,
prenorm
=
prenorm
)
for
i
in
range
(
depth
[
1
])
])
if
patch_merging
is
not
None
:
self
.
sub_sample2
=
SubSample
(
embed_dim
[
1
],
embed_dim
[
2
],
sub_norm
=
sub_norm
,
stride
=
[
2
,
1
],
types
=
patch_merging
)
HW
=
[
self
.
HW
[
0
]
//
4
,
self
.
HW
[
1
]]
else
:
HW
=
self
.
HW
self
.
blocks3
=
nn
.
LayerList
([
Block_unit
(
dim
=
embed_dim
[
2
],
num_heads
=
num_heads
[
2
],
mixer
=
mixer
[
depth
[
0
]
+
depth
[
1
]:][
i
],
HW
=
HW
,
local_mixer
=
local_mixer
[
2
],
mlp_ratio
=
mlp_ratio
,
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
drop
=
drop_rate
,
act_layer
=
eval
(
act
),
attn_drop
=
attn_drop_rate
,
drop_path
=
dpr
[
depth
[
0
]
+
depth
[
1
]:][
i
],
norm_layer
=
norm_layer
,
epsilon
=
epsilon
,
prenorm
=
prenorm
)
for
i
in
range
(
depth
[
2
])
])
self
.
last_stage
=
last_stage
if
last_stage
:
self
.
avg_pool
=
nn
.
AdaptiveAvgPool2D
([
1
,
out_char_num
])
self
.
last_conv
=
nn
.
Conv2D
(
in_channels
=
embed_dim
[
2
],
out_channels
=
self
.
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
,
bias_attr
=
False
)
self
.
hardswish
=
nn
.
Hardswish
()
self
.
dropout
=
nn
.
Dropout
(
p
=
last_drop
,
mode
=
"downscale_in_infer"
)
if
not
prenorm
:
self
.
norm
=
eval
(
norm_layer
)(
embed_dim
[
-
1
],
epsilon
=
epsilon
)
self
.
use_lenhead
=
use_lenhead
if
use_lenhead
:
self
.
len_conv
=
nn
.
Linear
(
embed_dim
[
2
],
self
.
out_channels
)
self
.
hardswish_len
=
nn
.
Hardswish
()
self
.
dropout_len
=
nn
.
Dropout
(
p
=
last_drop
,
mode
=
"downscale_in_infer"
)
trunc_normal_
(
self
.
pos_embed
)
self
.
apply
(
self
.
_init_weights
)
def
_init_weights
(
self
,
m
):
if
isinstance
(
m
,
nn
.
Linear
):
trunc_normal_
(
m
.
weight
)
if
isinstance
(
m
,
nn
.
Linear
)
and
m
.
bias
is
not
None
:
zeros_
(
m
.
bias
)
elif
isinstance
(
m
,
nn
.
LayerNorm
):
zeros_
(
m
.
bias
)
ones_
(
m
.
weight
)
def
forward_features
(
self
,
x
):
x
=
self
.
patch_embed
(
x
)
x
=
x
+
self
.
pos_embed
x
=
self
.
pos_drop
(
x
)
for
blk
in
self
.
blocks1
:
x
=
blk
(
x
)
if
self
.
patch_merging
is
not
None
:
x
=
self
.
sub_sample1
(
x
.
transpose
([
0
,
2
,
1
]).
reshape
(
[
0
,
self
.
embed_dim
[
0
],
self
.
HW
[
0
],
self
.
HW
[
1
]]))
for
blk
in
self
.
blocks2
:
x
=
blk
(
x
)
if
self
.
patch_merging
is
not
None
:
x
=
self
.
sub_sample2
(
x
.
transpose
([
0
,
2
,
1
]).
reshape
(
[
0
,
self
.
embed_dim
[
1
],
self
.
HW
[
0
]
//
2
,
self
.
HW
[
1
]]))
for
blk
in
self
.
blocks3
:
x
=
blk
(
x
)
if
not
self
.
prenorm
:
x
=
self
.
norm
(
x
)
return
x
def
forward
(
self
,
x
):
x
=
self
.
forward_features
(
x
)
if
self
.
use_lenhead
:
len_x
=
self
.
len_conv
(
x
.
mean
(
1
))
len_x
=
self
.
dropout_len
(
self
.
hardswish_len
(
len_x
))
if
self
.
last_stage
:
if
self
.
patch_merging
is
not
None
:
h
=
self
.
HW
[
0
]
//
4
else
:
h
=
self
.
HW
[
0
]
x
=
self
.
avg_pool
(
x
.
transpose
([
0
,
2
,
1
]).
reshape
(
[
0
,
self
.
embed_dim
[
2
],
h
,
self
.
HW
[
1
]]))
x
=
self
.
last_conv
(
x
)
x
=
self
.
hardswish
(
x
)
x
=
self
.
dropout
(
x
)
if
self
.
use_lenhead
:
return
x
,
len_x
return
x
ppocr/modeling/heads/__init__.py
浏览文件 @
f6532a0e
...
...
@@ -32,6 +32,7 @@ def build_head(config):
from
.rec_sar_head
import
SARHead
from
.rec_aster_head
import
AsterHead
from
.rec_pren_head
import
PRENHead
from
.rec_multi_head
import
MultiHead
# cls head
from
.cls_head
import
ClsHead
...
...
@@ -44,7 +45,8 @@ def build_head(config):
support_dict
=
[
'DBHead'
,
'PSEHead'
,
'FCEHead'
,
'EASTHead'
,
'SASTHead'
,
'CTCHead'
,
'ClsHead'
,
'AttentionHead'
,
'SRNHead'
,
'PGHead'
,
'Transformer'
,
'TableAttentionHead'
,
'SARHead'
,
'AsterHead'
,
'SDMGRHead'
,
'PRENHead'
'TableAttentionHead'
,
'SARHead'
,
'AsterHead'
,
'SDMGRHead'
,
'PRENHead'
,
'MultiHead'
]
#table head
...
...
ppocr/modeling/heads/rec_multi_head.py
0 → 100644
浏览文件 @
f6532a0e
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
math
import
paddle
from
paddle
import
ParamAttr
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
ppocr.modeling.necks.rnn
import
Im2Seq
,
EncoderWithRNN
,
EncoderWithFC
,
SequenceEncoder
,
EncoderWithSVTR
from
.rec_ctc_head
import
CTCHead
from
.rec_sar_head
import
SARHead
class
MultiHead
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels_list
,
**
kwargs
):
super
().
__init__
()
self
.
head_list
=
kwargs
.
pop
(
'head_list'
)
self
.
gtc_head
=
'sar'
assert
len
(
self
.
head_list
)
>=
2
for
idx
,
head_name
in
enumerate
(
self
.
head_list
):
name
=
list
(
head_name
)[
0
]
if
name
==
'SARHead'
:
# sar head
sar_args
=
self
.
head_list
[
idx
][
name
]
self
.
sar_head
=
eval
(
name
)(
in_channels
=
in_channels
,
\
out_channels
=
out_channels_list
[
'SARLabelDecode'
],
**
sar_args
)
elif
name
==
'CTCHead'
:
# ctc neck
self
.
encoder_reshape
=
Im2Seq
(
in_channels
)
neck_args
=
self
.
head_list
[
idx
][
name
][
'Neck'
]
encoder_type
=
neck_args
.
pop
(
'name'
)
self
.
encoder
=
encoder_type
self
.
ctc_encoder
=
SequenceEncoder
(
in_channels
=
in_channels
,
\
encoder_type
=
encoder_type
,
**
neck_args
)
# ctc head
head_args
=
self
.
head_list
[
idx
][
name
][
'Head'
]
self
.
ctc_head
=
eval
(
name
)(
in_channels
=
self
.
ctc_encoder
.
out_channels
,
\
out_channels
=
out_channels_list
[
'CTCLabelDecode'
],
**
head_args
)
else
:
raise
NotImplementedError
(
'{} is not supported in MultiHead yet'
.
format
(
name
))
def
forward
(
self
,
x
,
targets
=
None
):
ctc_encoder
=
self
.
ctc_encoder
(
x
)
ctc_out
=
self
.
ctc_head
(
ctc_encoder
,
targets
)
head_out
=
dict
()
head_out
[
'ctc'
]
=
ctc_out
head_out
[
'ctc_neck'
]
=
ctc_encoder
# eval mode
if
not
self
.
training
:
return
ctc_out
if
self
.
gtc_head
==
'sar'
:
sar_out
=
self
.
sar_head
(
x
,
targets
[
1
:])
head_out
[
'sar'
]
=
sar_out
return
head_out
else
:
return
head_out
ppocr/modeling/heads/rec_sar_head.py
浏览文件 @
f6532a0e
...
...
@@ -349,7 +349,10 @@ class ParallelSARDecoder(BaseDecoder):
class
SARHead
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
enc_dim
=
512
,
max_text_length
=
30
,
enc_bi_rnn
=
False
,
enc_drop_rnn
=
0.1
,
enc_gru
=
False
,
...
...
@@ -358,14 +361,17 @@ class SARHead(nn.Layer):
dec_gru
=
False
,
d_k
=
512
,
pred_dropout
=
0.1
,
max_text_length
=
30
,
pred_concat
=
True
,
**
kwargs
):
super
(
SARHead
,
self
).
__init__
()
# encoder module
self
.
encoder
=
SAREncoder
(
enc_bi_rnn
=
enc_bi_rnn
,
enc_drop_rnn
=
enc_drop_rnn
,
enc_gru
=
enc_gru
)
enc_bi_rnn
=
enc_bi_rnn
,
enc_drop_rnn
=
enc_drop_rnn
,
enc_gru
=
enc_gru
,
d_model
=
in_channels
,
d_enc
=
enc_dim
)
# decoder module
self
.
decoder
=
ParallelSARDecoder
(
...
...
@@ -374,6 +380,8 @@ class SARHead(nn.Layer):
dec_bi_rnn
=
dec_bi_rnn
,
dec_drop_rnn
=
dec_drop_rnn
,
dec_gru
=
dec_gru
,
d_model
=
in_channels
,
d_enc
=
enc_dim
,
d_k
=
d_k
,
pred_dropout
=
pred_dropout
,
max_text_length
=
max_text_length
,
...
...
@@ -390,7 +398,7 @@ class SARHead(nn.Layer):
label
=
paddle
.
to_tensor
(
label
,
dtype
=
'int64'
)
final_out
=
self
.
decoder
(
feat
,
holistic_feat
,
label
,
img_metas
=
targets
)
if
not
self
.
training
:
else
:
final_out
=
self
.
decoder
(
feat
,
holistic_feat
,
...
...
ppocr/modeling/necks/rnn.py
浏览文件 @
f6532a0e
...
...
@@ -16,9 +16,11 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
paddle
from
paddle
import
nn
from
ppocr.modeling.heads.rec_ctc_head
import
get_para_bias_attr
from
ppocr.modeling.backbones.rec_svtrnet
import
Block
,
ConvBNLayer
,
trunc_normal_
,
zeros_
,
ones_
class
Im2Seq
(
nn
.
Layer
):
...
...
@@ -64,29 +66,126 @@ class EncoderWithFC(nn.Layer):
return
x
class
EncoderWithSVTR
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
dims
=
64
,
# XS
depth
=
2
,
hidden_dims
=
120
,
use_guide
=
False
,
num_heads
=
8
,
qkv_bias
=
True
,
mlp_ratio
=
2.0
,
drop_rate
=
0.1
,
attn_drop_rate
=
0.1
,
drop_path
=
0.
,
qk_scale
=
None
):
super
(
EncoderWithSVTR
,
self
).
__init__
()
self
.
depth
=
depth
self
.
use_guide
=
use_guide
self
.
conv1
=
ConvBNLayer
(
in_channels
,
in_channels
//
8
,
padding
=
1
,
act
=
nn
.
Swish
)
self
.
conv2
=
ConvBNLayer
(
in_channels
//
8
,
hidden_dims
,
kernel_size
=
1
,
act
=
nn
.
Swish
)
self
.
svtr_block
=
nn
.
LayerList
([
Block
(
dim
=
hidden_dims
,
num_heads
=
num_heads
,
mixer
=
'Global'
,
HW
=
None
,
mlp_ratio
=
mlp_ratio
,
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
drop
=
drop_rate
,
act_layer
=
nn
.
Swish
,
attn_drop
=
attn_drop_rate
,
drop_path
=
drop_path
,
norm_layer
=
'nn.LayerNorm'
,
epsilon
=
1e-05
,
prenorm
=
False
)
for
i
in
range
(
depth
)
])
self
.
norm
=
nn
.
LayerNorm
(
hidden_dims
,
epsilon
=
1e-6
)
self
.
conv3
=
ConvBNLayer
(
hidden_dims
,
in_channels
,
kernel_size
=
1
,
act
=
nn
.
Swish
)
# last conv-nxn, the input is concat of input tensor and conv3 output tensor
self
.
conv4
=
ConvBNLayer
(
2
*
in_channels
,
in_channels
//
8
,
padding
=
1
,
act
=
nn
.
Swish
)
self
.
conv1x1
=
ConvBNLayer
(
in_channels
//
8
,
dims
,
kernel_size
=
1
,
act
=
nn
.
Swish
)
self
.
out_channels
=
dims
self
.
apply
(
self
.
_init_weights
)
def
_init_weights
(
self
,
m
):
if
isinstance
(
m
,
nn
.
Linear
):
trunc_normal_
(
m
.
weight
)
if
isinstance
(
m
,
nn
.
Linear
)
and
m
.
bias
is
not
None
:
zeros_
(
m
.
bias
)
elif
isinstance
(
m
,
nn
.
LayerNorm
):
zeros_
(
m
.
bias
)
ones_
(
m
.
weight
)
def
forward
(
self
,
x
):
# for use guide
if
self
.
use_guide
:
z
=
x
.
clone
()
z
.
stop_gradient
=
True
else
:
z
=
x
# for short cut
h
=
z
# reduce dim
z
=
self
.
conv1
(
z
)
z
=
self
.
conv2
(
z
)
# SVTR global block
B
,
C
,
H
,
W
=
z
.
shape
z
=
z
.
flatten
(
2
).
transpose
([
0
,
2
,
1
])
for
blk
in
self
.
svtr_block
:
z
=
blk
(
z
)
z
=
self
.
norm
(
z
)
# last stage
z
=
z
.
reshape
([
0
,
H
,
W
,
C
]).
transpose
([
0
,
3
,
1
,
2
])
z
=
self
.
conv3
(
z
)
z
=
paddle
.
concat
((
h
,
z
),
axis
=
1
)
z
=
self
.
conv1x1
(
self
.
conv4
(
z
))
return
z
class
SequenceEncoder
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
encoder_type
,
hidden_size
=
48
,
**
kwargs
):
super
(
SequenceEncoder
,
self
).
__init__
()
self
.
encoder_reshape
=
Im2Seq
(
in_channels
)
self
.
out_channels
=
self
.
encoder_reshape
.
out_channels
self
.
encoder_type
=
encoder_type
if
encoder_type
==
'reshape'
:
self
.
only_reshape
=
True
else
:
support_encoder_dict
=
{
'reshape'
:
Im2Seq
,
'fc'
:
EncoderWithFC
,
'rnn'
:
EncoderWithRNN
'rnn'
:
EncoderWithRNN
,
'svtr'
:
EncoderWithSVTR
}
assert
encoder_type
in
support_encoder_dict
,
'{} must in {}'
.
format
(
encoder_type
,
support_encoder_dict
.
keys
())
self
.
encoder
=
support_encoder_dict
[
encoder_type
](
self
.
encoder_reshape
.
out_channels
,
hidden_size
)
if
encoder_type
==
"svtr"
:
self
.
encoder
=
support_encoder_dict
[
encoder_type
](
self
.
encoder_reshape
.
out_channels
,
**
kwargs
)
else
:
self
.
encoder
=
support_encoder_dict
[
encoder_type
](
self
.
encoder_reshape
.
out_channels
,
hidden_size
)
self
.
out_channels
=
self
.
encoder
.
out_channels
self
.
only_reshape
=
False
def
forward
(
self
,
x
):
x
=
self
.
encoder_reshape
(
x
)
if
not
self
.
only_reshape
:
if
self
.
encoder_type
!=
'svtr'
:
x
=
self
.
encoder_reshape
(
x
)
if
not
self
.
only_reshape
:
x
=
self
.
encoder
(
x
)
return
x
else
:
x
=
self
.
encoder
(
x
)
return
x
x
=
self
.
encoder_reshape
(
x
)
return
x
ppocr/postprocess/__init__.py
浏览文件 @
f6532a0e
...
...
@@ -41,7 +41,8 @@ def build_post_process(config, global_config=None):
'PGPostProcess'
,
'DistillationCTCLabelDecode'
,
'TableLabelDecode'
,
'DistillationDBPostProcess'
,
'NRTRLabelDecode'
,
'SARLabelDecode'
,
'SEEDLabelDecode'
,
'VQASerTokenLayoutLMPostProcess'
,
'VQAReTokenLayoutLMPostProcess'
,
'PRENLabelDecode'
'VQAReTokenLayoutLMPostProcess'
,
'PRENLabelDecode'
,
'DistillationSARLabelDecode'
]
if
config
[
'name'
]
==
'PSEPostProcess'
:
...
...
ppocr/postprocess/rec_postprocess.py
浏览文件 @
f6532a0e
...
...
@@ -117,6 +117,7 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
use_space_char
=
False
,
model_name
=
[
"student"
],
key
=
None
,
multi_head
=
False
,
**
kwargs
):
super
(
DistillationCTCLabelDecode
,
self
).
__init__
(
character_dict_path
,
use_space_char
)
...
...
@@ -125,6 +126,7 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
self
.
model_name
=
model_name
self
.
key
=
key
self
.
multi_head
=
multi_head
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
output
=
dict
()
...
...
@@ -132,6 +134,8 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
pred
=
preds
[
name
]
if
self
.
key
is
not
None
:
pred
=
pred
[
self
.
key
]
if
self
.
multi_head
and
isinstance
(
pred
,
dict
):
pred
=
pred
[
'ctc'
]
output
[
name
]
=
super
().
__call__
(
pred
,
label
=
label
,
*
args
,
**
kwargs
)
return
output
...
...
@@ -656,6 +660,40 @@ class SARLabelDecode(BaseRecLabelDecode):
return
[
self
.
padding_idx
]
class
DistillationSARLabelDecode
(
SARLabelDecode
):
"""
Convert
Convert between text-label and text-index
"""
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
False
,
model_name
=
[
"student"
],
key
=
None
,
multi_head
=
False
,
**
kwargs
):
super
(
DistillationSARLabelDecode
,
self
).
__init__
(
character_dict_path
,
use_space_char
)
if
not
isinstance
(
model_name
,
list
):
model_name
=
[
model_name
]
self
.
model_name
=
model_name
self
.
key
=
key
self
.
multi_head
=
multi_head
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
output
=
dict
()
for
name
in
self
.
model_name
:
pred
=
preds
[
name
]
if
self
.
key
is
not
None
:
pred
=
pred
[
self
.
key
]
if
self
.
multi_head
and
isinstance
(
pred
,
dict
):
pred
=
pred
[
'sar'
]
output
[
name
]
=
super
().
__call__
(
pred
,
label
=
label
,
*
args
,
**
kwargs
)
return
output
class
PRENLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
...
...
tools/eval.py
浏览文件 @
f6532a0e
...
...
@@ -47,14 +47,38 @@ def main():
if
config
[
'Architecture'
][
"algorithm"
]
in
[
"Distillation"
,
]:
# distillation model
for
key
in
config
[
'Architecture'
][
"Models"
]:
config
[
'Architecture'
][
"Models"
][
key
][
"Head"
][
'out_channels'
]
=
char_num
if
config
[
'Architecture'
][
'Models'
][
key
][
'Head'
][
'name'
]
==
'MultiHead'
:
# for multi head
out_channels_list
=
{}
if
config
[
'PostProcess'
][
'name'
]
==
'DistillationSARLabelDecode'
:
char_num
=
char_num
-
2
out_channels_list
[
'CTCLabelDecode'
]
=
char_num
out_channels_list
[
'SARLabelDecode'
]
=
char_num
+
2
config
[
'Architecture'
][
'Models'
][
key
][
'Head'
][
'out_channels_list'
]
=
out_channels_list
else
:
config
[
'Architecture'
][
"Models"
][
key
][
"Head"
][
'out_channels'
]
=
char_num
elif
config
[
'Architecture'
][
'Head'
][
'name'
]
==
'MultiHead'
:
# for multi head
out_channels_list
=
{}
if
config
[
'PostProcess'
][
'name'
]
==
'SARLabelDecode'
:
char_num
=
char_num
-
2
out_channels_list
[
'CTCLabelDecode'
]
=
char_num
out_channels_list
[
'SARLabelDecode'
]
=
char_num
+
2
config
[
'Architecture'
][
'Head'
][
'out_channels_list'
]
=
out_channels_list
else
:
# base rec model
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
char_num
model
=
build_model
(
config
[
'Architecture'
])
extra_input
=
config
[
'Architecture'
][
'algorithm'
]
in
[
"SRN"
,
"NRTR"
,
"SAR"
,
"SEED"
]
extra_input_models
=
[
"SRN"
,
"NRTR"
,
"SAR"
,
"SEED"
,
"SVTR"
]
if
config
[
'Architecture'
][
'algorithm'
]
==
'Distillation'
:
extra_input
=
config
[
'Architecture'
][
'Models'
][
'Teacher'
][
'algorithm'
]
in
extra_input_models
else
:
extra_input
=
config
[
'Architecture'
][
'algorithm'
]
in
extra_input_models
if
"model_type"
in
config
[
'Architecture'
].
keys
():
model_type
=
config
[
'Architecture'
][
'model_type'
]
else
:
...
...
tools/export_model.py
浏览文件 @
f6532a0e
...
...
@@ -55,6 +55,13 @@ def export_single_model(model, arch_config, save_path, logger):
shape
=
[
None
,
3
,
48
,
160
],
dtype
=
"float32"
),
]
model
=
to_static
(
model
,
input_spec
=
other_shape
)
elif
arch_config
[
"algorithm"
]
==
"SVTR"
:
if
arch_config
[
"Head"
][
"name"
]
==
'MultiHead'
:
other_shape
=
[
paddle
.
static
.
InputSpec
(
shape
=
[
None
,
3
,
48
,
-
1
],
dtype
=
"float32"
),
]
model
=
to_static
(
model
,
input_spec
=
other_shape
)
elif
arch_config
[
"algorithm"
]
==
"PREN"
:
other_shape
=
[
paddle
.
static
.
InputSpec
(
...
...
@@ -105,13 +112,36 @@ def main():
if
config
[
"Architecture"
][
"algorithm"
]
in
[
"Distillation"
,
]:
# distillation model
for
key
in
config
[
"Architecture"
][
"Models"
]:
config
[
"Architecture"
][
"Models"
][
key
][
"Head"
][
"out_channels"
]
=
char_num
if
config
[
"Architecture"
][
"Models"
][
key
][
"Head"
][
"name"
]
==
'MultiHead'
:
# multi head
out_channels_list
=
{}
if
config
[
'PostProcess'
][
'name'
]
==
'DistillationSARLabelDecode'
:
char_num
=
char_num
-
2
out_channels_list
[
'CTCLabelDecode'
]
=
char_num
out_channels_list
[
'SARLabelDecode'
]
=
char_num
+
2
loss_list
=
config
[
'Loss'
][
'loss_config_list'
]
config
[
'Architecture'
][
'Models'
][
key
][
'Head'
][
'out_channels_list'
]
=
out_channels_list
else
:
config
[
"Architecture"
][
"Models"
][
key
][
"Head"
][
"out_channels"
]
=
char_num
# just one final tensor needs to to exported for inference
config
[
"Architecture"
][
"Models"
][
key
][
"return_all_feats"
]
=
False
elif
config
[
'Architecture'
][
'Head'
][
'name'
]
==
'MultiHead'
:
# multi head
out_channels_list
=
{}
char_num
=
len
(
getattr
(
post_process_class
,
'character'
))
if
config
[
'PostProcess'
][
'name'
]
==
'SARLabelDecode'
:
char_num
=
char_num
-
2
out_channels_list
[
'CTCLabelDecode'
]
=
char_num
out_channels_list
[
'SARLabelDecode'
]
=
char_num
+
2
config
[
'Architecture'
][
'Head'
][
'out_channels_list'
]
=
out_channels_list
else
:
# base rec model
config
[
"Architecture"
][
"Head"
][
"out_channels"
]
=
char_num
model
=
build_model
(
config
[
"Architecture"
])
load_model
(
config
,
model
)
model
.
eval
()
...
...
tools/infer/predict_rec.py
浏览文件 @
f6532a0e
...
...
@@ -107,7 +107,7 @@ class TextRecognizer(object):
return
norm_img
.
astype
(
np
.
float32
)
/
128.
-
1.
assert
imgC
==
img
.
shape
[
2
]
imgW
=
int
((
32
*
max_wh_ratio
))
imgW
=
int
((
imgH
*
max_wh_ratio
))
if
self
.
use_onnx
:
w
=
self
.
input_tensor
.
shape
[
3
:][
0
]
if
w
is
not
None
and
w
>
0
:
...
...
@@ -255,7 +255,9 @@ class TextRecognizer(object):
for
beg_img_no
in
range
(
0
,
img_num
,
batch_num
):
end_img_no
=
min
(
img_num
,
beg_img_no
+
batch_num
)
norm_img_batch
=
[]
max_wh_ratio
=
0
imgC
,
imgH
,
imgW
=
self
.
rec_image_shape
max_wh_ratio
=
imgW
/
imgH
# max_wh_ratio = 0
for
ino
in
range
(
beg_img_no
,
end_img_no
):
h
,
w
=
img_list
[
indices
[
ino
]].
shape
[
0
:
2
]
wh_ratio
=
w
*
1.0
/
h
...
...
tools/infer_rec.py
浏览文件 @
f6532a0e
...
...
@@ -51,8 +51,28 @@ def main():
if
config
[
'Architecture'
][
"algorithm"
]
in
[
"Distillation"
,
]:
# distillation model
for
key
in
config
[
'Architecture'
][
"Models"
]:
config
[
'Architecture'
][
"Models"
][
key
][
"Head"
][
'out_channels'
]
=
char_num
if
config
[
'Architecture'
][
'Models'
][
key
][
'Head'
][
'name'
]
==
'MultiHead'
:
# for multi head
out_channels_list
=
{}
if
config
[
'PostProcess'
][
'name'
]
==
'DistillationSARLabelDecode'
:
char_num
=
char_num
-
2
out_channels_list
[
'CTCLabelDecode'
]
=
char_num
out_channels_list
[
'SARLabelDecode'
]
=
char_num
+
2
config
[
'Architecture'
][
'Models'
][
key
][
'Head'
][
'out_channels_list'
]
=
out_channels_list
else
:
config
[
'Architecture'
][
"Models"
][
key
][
"Head"
][
'out_channels'
]
=
char_num
elif
config
[
'Architecture'
][
'Head'
][
'name'
]
==
'MultiHead'
:
# for multi head loss
out_channels_list
=
{}
if
config
[
'PostProcess'
][
'name'
]
==
'SARLabelDecode'
:
char_num
=
char_num
-
2
out_channels_list
[
'CTCLabelDecode'
]
=
char_num
out_channels_list
[
'SARLabelDecode'
]
=
char_num
+
2
config
[
'Architecture'
][
'Head'
][
'out_channels_list'
]
=
out_channels_list
else
:
# base rec model
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
char_num
...
...
tools/program.py
浏览文件 @
f6532a0e
...
...
@@ -201,12 +201,17 @@ def train(config,
model
.
train
()
use_srn
=
config
[
'Architecture'
][
'algorithm'
]
==
"SRN"
extra_input
=
config
[
'Architecture'
][
'algorithm'
]
in
[
"SRN"
,
"NRTR"
,
"SAR"
,
"SEED"
]
extra_input_models
=
[
"SRN"
,
"NRTR"
,
"SAR"
,
"SEED"
,
"SVTR"
]
if
config
[
'Architecture'
][
'algorithm'
]
==
'Distillation'
:
extra_input
=
config
[
'Architecture'
][
'Models'
][
'Teacher'
][
'algorithm'
]
in
extra_input_models
else
:
extra_input
=
config
[
'Architecture'
][
'algorithm'
]
in
extra_input_models
try
:
model_type
=
config
[
'Architecture'
][
'model_type'
]
except
:
model_type
=
None
algorithm
=
config
[
'Architecture'
][
'algorithm'
]
start_epoch
=
best_model_dict
[
...
...
@@ -269,7 +274,12 @@ def train(config,
if
model_type
in
[
'table'
,
'kie'
]:
eval_class
(
preds
,
batch
)
else
:
post_result
=
post_process_class
(
preds
,
batch
[
1
])
if
config
[
'Loss'
][
'name'
]
in
[
'MultiLoss'
,
'MultiLoss_v2'
]:
# for multi head loss
post_result
=
post_process_class
(
preds
[
'ctc'
],
batch
[
1
])
# for CTC head out
else
:
post_result
=
post_process_class
(
preds
,
batch
[
1
])
eval_class
(
post_result
,
batch
)
metric
=
eval_class
.
get_metric
()
train_stats
.
update
(
metric
)
...
...
@@ -541,7 +551,7 @@ def preprocess(is_train=False):
assert
alg
in
[
'EAST'
,
'DB'
,
'SAST'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
,
'SRN'
,
'CLS'
,
'PGNet'
,
'Distillation'
,
'NRTR'
,
'TableAttn'
,
'SAR'
,
'PSE'
,
'SEED'
,
'SDMGR'
,
'LayoutXLM'
,
'LayoutLM'
,
'PREN'
,
'FCE'
'SEED'
,
'SDMGR'
,
'LayoutXLM'
,
'LayoutLM'
,
'PREN'
,
'FCE'
,
'SVTR'
]
device
=
'cpu'
...
...
tools/train.py
浏览文件 @
f6532a0e
...
...
@@ -74,11 +74,49 @@ def main(config, device, logger, vdl_writer):
if
config
[
'Architecture'
][
"algorithm"
]
in
[
"Distillation"
,
]:
# distillation model
for
key
in
config
[
'Architecture'
][
"Models"
]:
config
[
'Architecture'
][
"Models"
][
key
][
"Head"
][
'out_channels'
]
=
char_num
if
config
[
'Architecture'
][
'Models'
][
key
][
'Head'
][
'name'
]
==
'MultiHead'
:
# for multi head
if
config
[
'PostProcess'
][
'name'
]
==
'DistillationSARLabelDecode'
:
char_num
=
char_num
-
2
# update SARLoss params
assert
list
(
config
[
'Loss'
][
'loss_config_list'
][
-
1
].
keys
())[
0
]
==
'DistillationSARLoss'
config
[
'Loss'
][
'loss_config_list'
][
-
1
][
'DistillationSARLoss'
][
'ignore_index'
]
=
char_num
+
1
out_channels_list
=
{}
out_channels_list
[
'CTCLabelDecode'
]
=
char_num
out_channels_list
[
'SARLabelDecode'
]
=
char_num
+
2
config
[
'Architecture'
][
'Models'
][
key
][
'Head'
][
'out_channels_list'
]
=
out_channels_list
else
:
config
[
'Architecture'
][
"Models"
][
key
][
"Head"
][
'out_channels'
]
=
char_num
elif
config
[
'Architecture'
][
'Head'
][
'name'
]
==
'MultiHead'
:
# for multi head
if
config
[
'PostProcess'
][
'name'
]
==
'SARLabelDecode'
:
char_num
=
char_num
-
2
# update SARLoss params
assert
list
(
config
[
'Loss'
][
'loss_config_list'
][
1
].
keys
())[
0
]
==
'SARLoss'
if
config
[
'Loss'
][
'loss_config_list'
][
1
][
'SARLoss'
]
is
None
:
config
[
'Loss'
][
'loss_config_list'
][
1
][
'SARLoss'
]
=
{
'ignore_index'
:
char_num
+
1
}
else
:
config
[
'Loss'
][
'loss_config_list'
][
1
][
'SARLoss'
][
'ignore_index'
]
=
char_num
+
1
out_channels_list
=
{}
out_channels_list
[
'CTCLabelDecode'
]
=
char_num
out_channels_list
[
'SARLabelDecode'
]
=
char_num
+
2
config
[
'Architecture'
][
'Head'
][
'out_channels_list'
]
=
out_channels_list
else
:
# base rec model
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
char_num
if
config
[
'PostProcess'
][
'name'
]
==
'SARLabelDecode'
:
# for SAR model
config
[
'Loss'
][
'ignore_index'
]
=
char_num
-
1
model
=
build_model
(
config
[
'Architecture'
])
if
config
[
'Global'
][
'distributed'
]:
model
=
paddle
.
DataParallel
(
model
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录