weixin_41840029 / PaddleOCR (fork of PaddlePaddle / PaddleOCR)

Commit 0002349d
Author: z37757
Date: Sep 27, 2022
Parent: 6a8a0eeb

    add text recognition algorithm rflearning

24 changed files with 1,301 additions and 16 deletions (+1301 / -16)
configs/rec/rec_resnet_rfl_att.yml                         +113   -0
configs/rec/rec_resnet_rfl_visual.yml                      +110   -0
ppocr/data/imaug/__init__.py                               +2     -1
ppocr/data/imaug/label_ops.py                              +56    -0
ppocr/data/imaug/rec_img_aug.py                            +39    -3
ppocr/losses/__init__.py                                   +2     -1
ppocr/losses/rec_rfl_loss.py                               +61    -0
ppocr/metrics/__init__.py                                  +2     -2
ppocr/metrics/rec_metric.py                                +39    -1
ppocr/modeling/backbones/__init__.py                       +2     -1
ppocr/modeling/backbones/rec_resnet_rfl.py                 +348   -0
ppocr/modeling/heads/__init__.py                           +2     -1
ppocr/modeling/heads/rec_att_head.py                       +2     -0
ppocr/modeling/heads/rec_rfl_head.py                       +109   -0
ppocr/modeling/necks/__init__.py                           +3     -1
ppocr/modeling/necks/rf_adaptor.py                         +137   -0
ppocr/optimizer/__init__.py                                +3     -0
ppocr/postprocess/__init__.py                              +2     -2
ppocr/postprocess/rec_postprocess.py                       +86    -0
test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml        +111   -0
test_tipc/configs/rec_resnet_rfl/train_infer_python.txt    +53    -0
tools/export_model.py                                      +1     -1
tools/infer/predict_rec.py                                 +16    -0
tools/program.py                                           +2     -2
configs/rec/rec_resnet_rfl_att.yml  (new file, 0 → 100644)

Global:
  use_gpu: True
  epoch_num: 6
  log_smooth_window: 20
  print_batch_step: 50
  save_model_dir: ./output/rec/rec_resnet_rfl_att/
  save_epoch_step: 1
  # evaluation is run every 5000 iterations after the 4000th iteration
  eval_batch_step: [0, 5000]
  cal_metric_during_train: True
  pretrained_model: ./pretrain_models/rec_resnet_rfl_visual/best_accuracy.pdparams
  checkpoints:
  save_inference_dir:
  use_visualdl: False
  infer_img: doc/imgs_words_en/word_10.png
  # for data or label process
  character_dict_path:
  max_text_length: 25
  infer_mode: False
  use_space_char: False
  save_res_path: ./output/rec/rec_resnet_rfl.txt

Optimizer:
  name: AdamW
  beta1: 0.9
  beta2: 0.999
  weight_decay: 0.0
  clip_norm_global: 5.0
  lr:
    name: Piecewise
    decay_epochs: [3, 4, 5]
    values: [0.001, 0.0003, 0.00009, 0.000027]

Architecture:
  model_type: rec
  algorithm: RFL
  in_channels: 1
  Transform:
    name: TPS
    num_fiducial: 20
    loc_lr: 1.0
    model_name: large
  Backbone:
    name: ResNetRFL
    use_cnt: True
    use_seq: True
  Neck:
    name: RFAdaptor
    use_v2s: True
    use_s2v: True
  Head:
    name: RFLHead
    in_channels: 512
    hidden_size: 256
    batch_max_legnth: 25
    out_channels: 38
    use_cnt: True
    use_seq: True

Loss:
  name: RFLLoss
  # ignore_index: 0

PostProcess:
  name: RFLLabelDecode

Metric:
  name: RecMetric
  main_indicator: acc

Train:
  dataset:
    name: LMDBDataSet
    data_dir: ./train_data/rfl_dataset2/training
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - RFLLabelEncode: # Class handling label
      - RFLRecResizeImg:
          image_shape: [1, 32, 100]
          padding: false
          interpolation: 2
      - KeepKeys:
          keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order
  loader:
    shuffle: True
    batch_size_per_card: 64
    drop_last: True
    num_workers: 8

Eval:
  dataset:
    name: LMDBDataSet
    data_dir: ./train_data/rfl_dataset2/evaluation
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - RFLLabelEncode: # Class handling label
      - RFLRecResizeImg:
          image_shape: [1, 32, 100]
          padding: false
          interpolation: 2
      - KeepKeys:
          keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 256
    num_workers: 8
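
For a quick look at what this config wires together, here is a minimal sketch (not part of the commit) that loads the file with PyYAML and prints the RFL module names; it assumes you run it from a checkout that contains the path below.

import yaml

with open('configs/rec/rec_resnet_rfl_att.yml') as f:
    cfg = yaml.safe_load(f)

arch = cfg['Architecture']
# expected: RFL ResNetRFL RFAdaptor RFLHead
print(arch['algorithm'], arch['Backbone']['name'],
      arch['Neck']['name'], arch['Head']['name'])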

configs/rec/rec_resnet_rfl_visual.yml  (new file, 0 → 100644)

Global:
  use_gpu: True
  epoch_num: 6
  log_smooth_window: 20
  print_batch_step: 50
  save_model_dir: ./output/rec/rec_resnet_rfl_visual/
  save_epoch_step: 1
  # evaluation is run every 5000 iterations after the 4000th iteration
  eval_batch_step: [0, 5000]
  cal_metric_during_train: False
  pretrained_model:
  checkpoints:
  save_inference_dir:
  use_visualdl: False
  infer_img: doc/imgs_words_en/word_10.png
  # for data or label process
  character_dict_path:
  max_text_length: 25
  infer_mode: False
  use_space_char: False
  save_res_path: ./output/rec/rec_resnet_rfl_visual.txt

Optimizer:
  name: AdamW
  beta1: 0.9
  beta2: 0.999
  weight_decay: 0.0
  clip_norm_global: 5.0
  lr:
    name: Piecewise
    decay_epochs: [3, 4, 5]
    values: [0.001, 0.0003, 0.00009, 0.000027]

Architecture:
  model_type: rec
  algorithm: RFL
  in_channels: 1
  Transform:
    name: TPS
    num_fiducial: 20
    loc_lr: 1.0
    model_name: large
  Backbone:
    name: ResNetRFL
    use_cnt: True
    use_seq: False
  Neck:
    name: RFAdaptor
    use_v2s: False
    use_s2v: False
  Head:
    name: RFLHead
    in_channels: 512
    hidden_size: 256
    batch_max_legnth: 25
    out_channels: 38
    use_cnt: True
    use_seq: False

Loss:
  name: RFLLoss

PostProcess:
  name: RFLLabelDecode

Metric:
  name: CNTMetric
  main_indicator: acc

Train:
  dataset:
    name: LMDBDataSet
    data_dir: ./train_data/rfl_dataset2/training
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - RFLLabelEncode: # Class handling label
      - RFLRecResizeImg:
          image_shape: [1, 32, 100]
          padding: false
          interpolation: 2
      - KeepKeys:
          keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order
  loader:
    shuffle: True
    batch_size_per_card: 64
    drop_last: True
    num_workers: 8

Eval:
  dataset:
    name: LMDBDataSet
    data_dir: ./train_data/rfl_dataset2/evaluation
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - RFLLabelEncode: # Class handling label
      - RFLRecResizeImg:
          image_shape: [1, 32, 100]
          padding: false
          interpolation: 2
      - KeepKeys:
          keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 256
    num_workers: 8

ppocr/data/imaug/__init__.py

@@ -26,7 +26,8 @@ from .make_pse_gt import MakePseGt
 from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
     SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
-    ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg, RobustScannerRecResizeImg
+    ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg, RobustScannerRecResizeImg, \
+    RFLRecResizeImg
 from .ssl_img_aug import SSLRotateResize
 from .randaugment import RandAugment
 from .copy_paste import CopyPaste

ppocr/data/imaug/label_ops.py

@@ -488,6 +488,62 @@ class AttnLabelEncode(BaseRecLabelEncode):
         return idx
 
 
+class RFLLabelEncode(BaseRecLabelEncode):
+    """ Convert between text-label and text-index """
+
+    def __init__(self,
+                 max_text_length,
+                 character_dict_path=None,
+                 use_space_char=False,
+                 **kwargs):
+        super(RFLLabelEncode, self).__init__(max_text_length,
+                                             character_dict_path,
+                                             use_space_char)
+
+    def add_special_char(self, dict_character):
+        self.beg_str = "sos"
+        self.end_str = "eos"
+        dict_character = [self.beg_str] + dict_character + [self.end_str]
+        return dict_character
+
+    def encode_cnt(self, text):
+        cnt_label = [0.0] * len(self.character)
+        for char_ in text:
+            cnt_label[char_] += 1
+        return np.array(cnt_label)
+
+    def __call__(self, data):
+        text = data['label']
+        text = self.encode(text)
+        if text is None:
+            return None
+        if len(text) >= self.max_text_len:
+            return None
+        cnt_label = self.encode_cnt(text)
+        data['length'] = np.array(len(text))
+        text = [0] + text + [len(self.character) - 1] + [0] * (
+            self.max_text_len - len(text) - 2)
+        if len(text) != self.max_text_len:
+            return None
+        data['label'] = np.array(text)
+        data['cnt_label'] = cnt_label
+        return data
+
+    def get_ignored_tokens(self):
+        beg_idx = self.get_beg_end_flag_idx("beg")
+        end_idx = self.get_beg_end_flag_idx("end")
+        return [beg_idx, end_idx]
+
+    def get_beg_end_flag_idx(self, beg_or_end):
+        if beg_or_end == "beg":
+            idx = np.array(self.dict[self.beg_str])
+        elif beg_or_end == "end":
+            idx = np.array(self.dict[self.end_str])
+        else:
+            assert False, "Unsupport type %s in get_beg_end_flag_idx" \
+                % beg_or_end
+        return idx
+
+
 class SEEDLabelEncode(BaseRecLabelEncode):
     """ Convert between text-label and text-index """
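
The counting target built by encode_cnt is simply a per-character occurrence histogram over the encoded label. A standalone sketch of the same arithmetic, using a hypothetical five-symbol dictionary instead of the real character set:

import numpy as np

# hypothetical dictionary: index 0 = 'sos', 4 = 'eos' (cf. add_special_char above)
character = ['sos', 'a', 'b', 'c', 'eos']

def encode_cnt(text_indices, num_chars):
    # mirrors RFLLabelEncode.encode_cnt: count how often each index occurs
    cnt_label = [0.0] * num_chars
    for idx in text_indices:
        cnt_label[idx] += 1
    return np.array(cnt_label)

# "abca" encoded against the hypothetical dictionary -> indices [1, 2, 3, 1]
print(encode_cnt([1, 2, 3, 1], len(character)))  # [0. 2. 1. 1. 0.]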

ppocr/data/imaug/rec_img_aug.py

@@ -237,6 +237,33 @@ class VLRecResizeImg(object):
         return data
 
 
+class RFLRecResizeImg(object):
+    def __init__(self, image_shape, padding=True, interpolation=1, **kwargs):
+        self.image_shape = image_shape
+        self.padding = padding
+
+        self.interpolation = interpolation
+        if self.interpolation == 0:
+            self.interpolation = cv2.INTER_NEAREST
+        elif self.interpolation == 1:
+            self.interpolation = cv2.INTER_LINEAR
+        elif self.interpolation == 2:
+            self.interpolation = cv2.INTER_CUBIC
+        elif self.interpolation == 3:
+            self.interpolation = cv2.INTER_AREA
+        else:
+            raise Exception("Unsupported interpolation type !!!")
+
+    def __call__(self, data):
+        img = data['image']
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+        norm_img, valid_ratio = resize_norm_img(img, self.image_shape,
+                                                self.padding,
+                                                self.interpolation)
+        data['image'] = norm_img
+        data['valid_ratio'] = valid_ratio
+        return data
+
+
 class SRNRecResizeImg(object):
     def __init__(self, image_shape, num_heads, max_text_length, **kwargs):
         self.image_shape = image_shape

@@ -414,8 +441,13 @@ class SVTRRecResizeImg(object):
         data['valid_ratio'] = valid_ratio
         return data
 
 
 class RobustScannerRecResizeImg(object):
-    def __init__(self, image_shape, max_text_length, width_downsample_ratio=0.25, **kwargs):
+    def __init__(self,
+                 image_shape,
+                 max_text_length,
+                 width_downsample_ratio=0.25,
+                 **kwargs):
         self.image_shape = image_shape
         self.width_downsample_ratio = width_downsample_ratio
         self.max_text_length = max_text_length

@@ -432,6 +464,7 @@ class RobustScannerRecResizeImg(object):
         data['word_positons'] = word_positons
         return data
 
+
 def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
     imgC, imgH, imgW_min, imgW_max = image_shape
     h = img.shape[0]

@@ -467,13 +500,16 @@ def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
     return padding_im, resize_shape, pad_shape, valid_ratio
 
 
-def resize_norm_img(img, image_shape, padding=True):
+def resize_norm_img(img,
+                    image_shape,
+                    padding=True,
+                    interpolation=cv2.INTER_LINEAR):
     imgC, imgH, imgW = image_shape
     h = img.shape[0]
     w = img.shape[1]
     if not padding:
         resized_image = cv2.resize(
-            img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
+            img, (imgW, imgH), interpolation=interpolation)
         resized_w = imgW
     else:
         ratio = w / float(h)
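
A minimal usage sketch of the new resize op (assuming the PaddleOCR package is importable and OpenCV is installed); interpolation=2 selects cv2.INTER_CUBIC, matching the `interpolation: 2` entries in the RFL configs above:

import numpy as np
from ppocr.data.imaug.rec_img_aug import RFLRecResizeImg

op = RFLRecResizeImg(image_shape=[1, 32, 100], padding=False, interpolation=2)
data = {'image': np.zeros((48, 160, 3), dtype=np.uint8)}  # dummy BGR word crop
out = op(data)
# grayscale, resized and normalized to the [1, 32, 100] shape expected by ResNetRFL
print(out['image'].shape, out['valid_ratio'])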

ppocr/losses/__init__.py

@@ -38,6 +38,7 @@ from .rec_pren_loss import PRENLoss
 from .rec_multi_loss import MultiLoss
 from .rec_vl_loss import VLLoss
 from .rec_spin_att_loss import SPINAttentionLoss
+from .rec_rfl_loss import RFLLoss
 
 # cls loss
 from .cls_loss import ClsLoss

@@ -69,7 +70,7 @@ def build_loss(config):
         'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
         'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
         'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss',
-        'SLALoss', 'CTLoss'
+        'SLALoss', 'CTLoss', 'RFLLoss'
     ]
     config = copy.deepcopy(config)
     module_name = config.pop('name')

ppocr/losses/rec_rfl_loss.py  (new file, 0 → 100644)

# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
from paddle import nn

from .basic_loss import CELoss, DistanceLoss


class RFLLoss(nn.Layer):
    def __init__(self, ignore_index=-100, **kwargs):
        super().__init__()
        self.cnt_loss = nn.MSELoss(**kwargs)
        self.seq_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)

    def forward(self, predicts, batch):

        self.total_loss = {}
        total_loss = 0.0
        # batch [image, label, length, cnt_label]
        if predicts[0] is not None:
            cnt_loss = self.cnt_loss(predicts[0],
                                     paddle.cast(batch[3], paddle.float32))
            self.total_loss['cnt_loss'] = cnt_loss
            total_loss += cnt_loss
        if predicts[1] is not None:
            targets = batch[1].astype("int64")
            label_lengths = batch[2].astype('int64')
            batch_size, num_steps, num_classes = predicts[1].shape[0], predicts[
                1].shape[1], predicts[1].shape[2]
            assert len(targets.shape) == len(list(predicts[1].shape)) - 1, \
                "The target's shape and inputs's shape is [N, d] and [N, num_steps]"

            inputs = predicts[1][:, :-1, :]
            targets = targets[:, 1:]

            inputs = paddle.reshape(inputs, [-1, inputs.shape[-1]])
            targets = paddle.reshape(targets, [-1])
            seq_loss = self.seq_loss(inputs, targets)
            self.total_loss['seq_loss'] = seq_loss
            total_loss += seq_loss

        self.total_loss['loss'] = total_loss
        return self.total_loss
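
As a sanity check of the loss plumbing, the sketch below feeds random counting and sequence predictions plus a matching dummy batch through RFLLoss. The batch layout [image, label, length, cnt_label] and the 38-class / 25-step sizes follow the configs above; the import path assumes a PaddleOCR checkout on sys.path, and the tensor shapes are illustrative assumptions rather than the exact shapes the dataloader produces.

import paddle
from ppocr.losses.rec_rfl_loss import RFLLoss

batch_size, num_classes, max_len = 4, 38, 25

cnt_pred = paddle.rand([batch_size, num_classes])                  # counting branch
seq_pred = paddle.rand([batch_size, max_len + 1, num_classes])     # attention branch

image = paddle.rand([batch_size, 1, 32, 100])
label = paddle.randint(0, num_classes, [batch_size, max_len + 1])  # sos/eos-padded indices
length = paddle.randint(1, max_len, [batch_size])
cnt_label = paddle.rand([batch_size, num_classes])                 # per-char count targets

losses = RFLLoss()([cnt_pred, seq_pred], [image, label, length, cnt_label])
print({k: float(v) for k, v in losses.items()})  # cnt_loss, seq_loss, loss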

ppocr/metrics/__init__.py

@@ -22,7 +22,7 @@ import copy
 __all__ = ["build_metric"]
 
 from .det_metric import DetMetric, DetFCEMetric
-from .rec_metric import RecMetric
+from .rec_metric import RecMetric, CNTMetric
 from .cls_metric import ClsMetric
 from .e2e_metric import E2EMetric
 from .distillation_metric import DistillationMetric

@@ -38,7 +38,7 @@ def build_metric(config):
     support_dict = [
         "DetMetric", "DetFCEMetric", "RecMetric", "ClsMetric", "E2EMetric",
         "DistillationMetric", "TableMetric", 'KIEMetric', 'VQASerTokenMetric',
-        'VQAReTokenMetric', 'SRMetric', 'CTMetric'
+        'VQAReTokenMetric', 'SRMetric', 'CTMetric', 'CNTMetric'
     ]
 
     config = copy.deepcopy(config)

ppocr/metrics/rec_metric.py

@@ -16,7 +16,6 @@ from rapidfuzz.distance import Levenshtein
 import string
 
-
 class RecMetric(object):
     def __init__(self,
                  main_indicator='acc',

@@ -74,3 +73,42 @@ class RecMetric(object):
         self.correct_num = 0
         self.all_num = 0
         self.norm_edit_dis = 0
+
+
+class CNTMetric(object):
+    def __init__(self, main_indicator='acc', **kwargs):
+        self.main_indicator = main_indicator
+        self.eps = 1e-5
+        self.reset()
+
+    def _normalize_text(self, text):
+        text = ''.join(
+            filter(lambda x: x in (string.digits + string.ascii_letters), text))
+        return text.lower()
+
+    def __call__(self, pred_label, *args, **kwargs):
+        preds, labels = pred_label
+        correct_num = 0
+        all_num = 0
+        for pred, target in zip(preds, labels):
+            if pred == target:
+                correct_num += 1
+            all_num += 1
+        self.correct_num += correct_num
+        self.all_num += all_num
+        return {'acc': correct_num / (all_num + self.eps), }
+
+    def get_metric(self):
+        """
+        return metrics {
+                 'acc': 0,
+                 'norm_edit_dis': 0,
+            }
+        """
+        acc = 1.0 * self.correct_num / (self.all_num + self.eps)
+        self.reset()
+        return {'acc': acc}
+
+    def reset(self):
+        self.correct_num = 0
+        self.all_num = 0
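
CNTMetric only scores exact matches between predicted and labeled character counts; a minimal usage sketch (assuming the repo is importable):

from ppocr.metrics.rec_metric import CNTMetric

metric = CNTMetric(main_indicator='acc')
# one batch: predicted text lengths vs. ground-truth lengths
print(metric(([5, 7, 3], [5, 6, 3])))  # {'acc': ~0.667} for this batch
print(metric.get_metric())             # accumulated accuracy, then internal counters reset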

ppocr/modeling/backbones/__init__.py

@@ -42,10 +42,11 @@ def build_backbone(config, model_type):
         from .rec_efficientb3_pren import EfficientNetb3_PREN
         from .rec_svtrnet import SVTRNet
         from .rec_vitstr import ViTSTR
+        from .rec_resnet_rfl import ResNetRFL
         support_dict = [
             'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
             'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet',
-            'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32'
+            'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32', 'ResNetRFL'
         ]
     elif model_type == 'e2e':
         from .e2e_resnet_vd_pg import ResNet

ppocr/modeling/backbones/rec_resnet_rfl.py  (new file, 0 → 100644)

# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/hikopensource/DAVAR-Lab-OCR/blob/main/davarocr/davar_rcg/models/backbones/ResNetRFL.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn

from paddle.nn.initializer import TruncatedNormal, Constant, Normal, KaimingNormal

kaiming_init_ = KaimingNormal()
zeros_ = Constant(value=0.)
ones_ = Constant(value=1.)


class BasicBlock(nn.Layer):
    """Res-net Basic Block"""
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 norm_type='BN', **kwargs):
        """
        Args:
            inplanes (int): input channel
            planes (int): channels of the middle feature
            stride (int): stride of the convolution
            downsample (int): type of the down_sample
            norm_type (str): type of the normalization
            **kwargs (None): backup parameter
        """
        super(BasicBlock, self).__init__()
        self.conv1 = self._conv3x3(inplanes, planes)
        self.bn1 = nn.BatchNorm(planes)
        self.conv2 = self._conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm(planes)
        self.relu = nn.ReLU()
        self.downsample = downsample
        self.stride = stride

    def _conv3x3(self, in_planes, out_planes, stride=1):
        return nn.Conv2D(
            in_planes,
            out_planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias_attr=False)

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)

        return out


class ResNetRFL(nn.Layer):
    def __init__(self, in_channels, out_channels=512, use_cnt=True,
                 use_seq=True):
        """
        Args:
            in_channels (int): input channel
            out_channels (int): output channel
        """
        super(ResNetRFL, self).__init__()
        assert use_cnt or use_seq
        self.use_cnt, self.use_seq = use_cnt, use_seq
        self.backbone = RFLBase(in_channels)

        self.out_channels = out_channels
        self.out_channels_block = [
            int(self.out_channels / 4), int(self.out_channels / 2),
            self.out_channels, self.out_channels
        ]
        block = BasicBlock
        layers = [1, 2, 5, 3]
        self.inplanes = int(self.out_channels // 2)

        self.relu = nn.ReLU()
        if self.use_seq:
            self.maxpool3 = nn.MaxPool2D(
                kernel_size=2, stride=(2, 1), padding=(0, 1))
            self.layer3 = self._make_layer(
                block, self.out_channels_block[2], layers[2], stride=1)
            self.conv3 = nn.Conv2D(
                self.out_channels_block[2],
                self.out_channels_block[2],
                kernel_size=3,
                stride=1,
                padding=1,
                bias_attr=False)
            self.bn3 = nn.BatchNorm(self.out_channels_block[2])

            self.layer4 = self._make_layer(
                block, self.out_channels_block[3], layers[3], stride=1)
            self.conv4_1 = nn.Conv2D(
                self.out_channels_block[3],
                self.out_channels_block[3],
                kernel_size=2,
                stride=(2, 1),
                padding=(0, 1),
                bias_attr=False)
            self.bn4_1 = nn.BatchNorm(self.out_channels_block[3])
            self.conv4_2 = nn.Conv2D(
                self.out_channels_block[3],
                self.out_channels_block[3],
                kernel_size=2,
                stride=1,
                padding=0,
                bias_attr=False)
            self.bn4_2 = nn.BatchNorm(self.out_channels_block[3])

        if self.use_cnt:
            self.inplanes = int(self.out_channels // 2)
            self.v_maxpool3 = nn.MaxPool2D(
                kernel_size=2, stride=(2, 1), padding=(0, 1))
            self.v_layer3 = self._make_layer(
                block, self.out_channels_block[2], layers[2], stride=1)
            self.v_conv3 = nn.Conv2D(
                self.out_channels_block[2],
                self.out_channels_block[2],
                kernel_size=3,
                stride=1,
                padding=1,
                bias_attr=False)
            self.v_bn3 = nn.BatchNorm(self.out_channels_block[2])

            self.v_layer4 = self._make_layer(
                block, self.out_channels_block[3], layers[3], stride=1)
            self.v_conv4_1 = nn.Conv2D(
                self.out_channels_block[3],
                self.out_channels_block[3],
                kernel_size=2,
                stride=(2, 1),
                padding=(0, 1),
                bias_attr=False)
            self.v_bn4_1 = nn.BatchNorm(self.out_channels_block[3])
            self.v_conv4_2 = nn.Conv2D(
                self.out_channels_block[3],
                self.out_channels_block[3],
                kernel_size=2,
                stride=1,
                padding=0,
                bias_attr=False)
            self.v_bn4_2 = nn.BatchNorm(self.out_channels_block[3])

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2D(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias_attr=False),
                nn.BatchNorm(planes * block.expansion), )

        layers = list()
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, inputs):
        x_1 = self.backbone(inputs)

        if self.use_cnt:
            v_x = self.v_maxpool3(x_1)
            v_x = self.v_layer3(v_x)
            v_x = self.v_conv3(v_x)
            v_x = self.v_bn3(v_x)
            visual_feature_2 = self.relu(v_x)

            v_x = self.v_layer4(visual_feature_2)
            v_x = self.v_conv4_1(v_x)
            v_x = self.v_bn4_1(v_x)
            v_x = self.relu(v_x)
            v_x = self.v_conv4_2(v_x)
            v_x = self.v_bn4_2(v_x)
            visual_feature_3 = self.relu(v_x)
        else:
            visual_feature_3 = None
        if self.use_seq:
            x = self.maxpool3(x_1)
            x = self.layer3(x)
            x = self.conv3(x)
            x = self.bn3(x)
            x_2 = self.relu(x)

            x = self.layer4(x_2)
            x = self.conv4_1(x)
            x = self.bn4_1(x)
            x = self.relu(x)
            x = self.conv4_2(x)
            x = self.bn4_2(x)
            x_3 = self.relu(x)
        else:
            x_3 = None

        return [visual_feature_3, x_3]


class ResNetBase(nn.Layer):
    def __init__(self, in_channels, out_channels, block, layers):
        super(ResNetBase, self).__init__()

        self.out_channels_block = [
            int(out_channels / 4), int(out_channels / 2), out_channels,
            out_channels
        ]

        self.inplanes = int(out_channels / 8)
        self.conv0_1 = nn.Conv2D(
            in_channels,
            int(out_channels / 16),
            kernel_size=3,
            stride=1,
            padding=1,
            bias_attr=False)
        self.bn0_1 = nn.BatchNorm(int(out_channels / 16))
        self.conv0_2 = nn.Conv2D(
            int(out_channels / 16),
            self.inplanes,
            kernel_size=3,
            stride=1,
            padding=1,
            bias_attr=False)
        self.bn0_2 = nn.BatchNorm(self.inplanes)
        self.relu = nn.ReLU()

        self.maxpool1 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
        self.layer1 = self._make_layer(block, self.out_channels_block[0],
                                       layers[0])
        self.conv1 = nn.Conv2D(
            self.out_channels_block[0],
            self.out_channels_block[0],
            kernel_size=3,
            stride=1,
            padding=1,
            bias_attr=False)
        self.bn1 = nn.BatchNorm(self.out_channels_block[0])

        self.maxpool2 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
        self.layer2 = self._make_layer(
            block, self.out_channels_block[1], layers[1], stride=1)
        self.conv2 = nn.Conv2D(
            self.out_channels_block[1],
            self.out_channels_block[1],
            kernel_size=3,
            stride=1,
            padding=1,
            bias_attr=False)
        self.bn2 = nn.BatchNorm(self.out_channels_block[1])

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2D(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias_attr=False),
                nn.BatchNorm(planes * block.expansion), )

        layers = list()
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv0_1(x)
        x = self.bn0_1(x)
        x = self.relu(x)
        x = self.conv0_2(x)
        x = self.bn0_2(x)
        x = self.relu(x)

        x = self.maxpool1(x)
        x = self.layer1(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.maxpool2(x)
        x = self.layer2(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        return x


class RFLBase(nn.Layer):
    """ Reciprocal feature learning share backbone network"""

    def __init__(self, in_channels, out_channels=512):
        super(RFLBase, self).__init__()
        self.ConvNet = ResNetBase(in_channels, out_channels, BasicBlock,
                                  [1, 2, 5, 3])

    def forward(self, inputs):
        return self.ConvNet(inputs)
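
The backbone returns a two-element list: counting-branch features and sequence-branch features, both 512-channel maps (or None when the corresponding branch is disabled). A minimal shape-trace sketch, assuming a PaddleOCR checkout is importable:

import paddle
from ppocr.modeling.backbones.rec_resnet_rfl import ResNetRFL

model = ResNetRFL(in_channels=1, out_channels=512, use_cnt=True, use_seq=True)
x = paddle.rand([2, 1, 32, 100])          # matches image_shape [1, 32, 100] in the configs
visual_feat, seq_feat = model(x)
print(visual_feat.shape, seq_feat.shape)  # two 512-channel feature maps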

ppocr/modeling/heads/__init__.py

@@ -38,6 +38,7 @@ def build_head(config):
     from .rec_abinet_head import ABINetHead
     from .rec_robustscanner_head import RobustScannerHead
     from .rec_visionlan_head import VLHead
+    from .rec_rfl_head import RFLHead
 
     # cls head
     from .cls_head import ClsHead

@@ -53,7 +54,7 @@ def build_head(config):
         'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
         'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
         'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
-        'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head'
+        'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head', 'RFLHead'
     ]
 
     #table head

ppocr/modeling/heads/rec_att_head.py

@@ -149,6 +149,8 @@ class AttentionLSTM(nn.Layer):
         else:
             targets = paddle.zeros(shape=[batch_size], dtype="int32")
             probs = None
+            char_onehots = None
+            alpha = None
 
             for i in range(num_steps):
                 char_onehots = self._char_to_onehot(

ppocr/modeling/heads/rec_rfl_head.py  (new file, 0 → 100644)

# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/hikopensource/DAVAR-Lab-OCR/blob/main/davarocr/davar_rcg/models/sequence_heads/counting_head.py
"""
import paddle
import paddle.nn as nn
from paddle.nn.initializer import TruncatedNormal, Constant, Normal, KaimingNormal

from .rec_att_head import AttentionLSTM

kaiming_init_ = KaimingNormal()
zeros_ = Constant(value=0.)
ones_ = Constant(value=1.)


class CNTHead(nn.Layer):
    def __init__(self,
                 embed_size=512,
                 encode_length=26,
                 out_channels=38,
                 **kwargs):
        super(CNTHead, self).__init__()

        self.out_channels = out_channels

        self.Wv_fusion = nn.Linear(embed_size, embed_size, bias_attr=False)
        self.Prediction_visual = nn.Linear(encode_length * embed_size,
                                           self.out_channels)

    def forward(self, visual_feature):

        b, c, h, w = visual_feature.shape
        visual_feature = visual_feature.reshape([b, c, h * w]).transpose(
            [0, 2, 1])
        visual_feature_num = self.Wv_fusion(visual_feature)  # batch * 26 * 512
        b, n, c = visual_feature_num.shape
        # using visual feature directly calculate the text length
        visual_feature_num = visual_feature_num.reshape([b, n * c])
        prediction_visual = self.Prediction_visual(visual_feature_num)

        return prediction_visual


class RFLHead(nn.Layer):
    def __init__(self,
                 in_channels=512,
                 hidden_size=256,
                 batch_max_legnth=25,
                 out_channels=38,
                 use_cnt=True,
                 use_seq=True,
                 **kwargs):

        super(RFLHead, self).__init__()
        assert use_cnt or use_seq
        self.use_cnt = use_cnt
        self.use_seq = use_seq
        if self.use_cnt:
            self.cnt_head = CNTHead(
                embed_size=in_channels,
                encode_length=batch_max_legnth + 1,
                out_channels=out_channels,
                **kwargs)
        if self.use_seq:
            self.seq_head = AttentionLSTM(
                in_channels=in_channels,
                out_channels=out_channels,
                hidden_size=hidden_size,
                **kwargs)
        self.batch_max_legnth = batch_max_legnth
        self.num_class = out_channels
        self.apply(self.init_weights)

    def init_weights(self, m):
        if isinstance(m, nn.Linear):
            kaiming_init_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                zeros_(m.bias)

    def forward(self, x, targets=None):
        cnt_inputs, seq_inputs = x
        if self.use_cnt:
            cnt_outputs = self.cnt_head(cnt_inputs)
        else:
            cnt_outputs = None
        if self.use_seq:
            if self.training:
                seq_outputs = self.seq_head(seq_inputs, targets[0],
                                            self.batch_max_legnth)
            else:
                seq_outputs = self.seq_head(seq_inputs, None,
                                            self.batch_max_legnth)
        else:
            seq_outputs = None

        return cnt_outputs, seq_outputs
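
CNTHead flattens the column features of the counting branch into one vector and projects it to a character-count prediction; the encode_length it is built with must equal the number of spatial positions (h * w) of its input. A standalone shape check under that assumption:

import paddle
from ppocr.modeling.heads.rec_rfl_head import CNTHead

head = CNTHead(embed_size=512, encode_length=26, out_channels=38)
feat = paddle.rand([2, 512, 1, 26])  # h * w == encode_length == batch_max_legnth + 1
print(head(feat).shape)              # [2, 38]: one count estimate per character class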

ppocr/modeling/necks/__init__.py

@@ -27,9 +27,11 @@ def build_neck(config):
     from .pren_fpn import PRENFPN
     from .csp_pan import CSPPAN
     from .ct_fpn import CTFPN
+    from .rf_adaptor import RFAdaptor
     support_dict = [
         'FPN', 'FCEFPN', 'LKPAN', 'DBFPN', 'RSEFPN', 'EASTFPN', 'SASTFPN',
-        'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN', 'CSPPAN', 'CTFPN'
+        'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN', 'CSPPAN', 'CTFPN',
+        'RFAdaptor'
     ]
 
     module_name = config.pop('name')

ppocr/modeling/necks/rf_adaptor.py  (new file, 0 → 100644)

# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/hikopensource/DAVAR-Lab-OCR/blob/main/davarocr/davar_rcg/models/connects/single_block/RFAdaptor.py
"""
import paddle
import paddle.nn as nn
from paddle.nn.initializer import TruncatedNormal, Constant, Normal, KaimingNormal

kaiming_init_ = KaimingNormal()
zeros_ = Constant(value=0.)
ones_ = Constant(value=1.)


class S2VAdaptor(nn.Layer):
    """ Semantic to Visual adaptation module"""

    def __init__(self, in_channels=512):
        super(S2VAdaptor, self).__init__()

        self.in_channels = in_channels  # 512

        # feature strengthen module, channel attention
        self.channel_inter = nn.Linear(
            self.in_channels, self.in_channels, bias_attr=False)
        self.channel_bn = nn.BatchNorm1D(self.in_channels)
        self.channel_act = nn.ReLU()
        self.apply(self.init_weights)

    def init_weights(self, m):
        if isinstance(m, nn.Conv2D):
            kaiming_init_(m.weight)
            if isinstance(m, nn.Conv2D) and m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, (nn.BatchNorm, nn.BatchNorm2D, nn.BatchNorm1D)):
            zeros_(m.bias)
            ones_(m.weight)

    def forward(self, semantic):
        semantic_source = semantic  # batch, channel, height, width

        # feature transformation
        semantic = semantic.squeeze(2).transpose(
            [0, 2, 1])  # batch, width, channel
        channel_att = self.channel_inter(semantic)  # batch, width, channel
        channel_att = channel_att.transpose([0, 2, 1])  # batch, channel, width
        channel_bn = self.channel_bn(channel_att)  # batch, channel, width
        channel_att = self.channel_act(channel_bn)  # batch, channel, width

        # Feature enhancement
        channel_output = semantic_source * channel_att.unsqueeze(
            -2)  # batch, channel, 1, width

        return channel_output


class V2SAdaptor(nn.Layer):
    """ Visual to Semantic adaptation module"""

    def __init__(self, in_channels=512, return_mask=False):
        super(V2SAdaptor, self).__init__()

        # parameter initialization
        self.in_channels = in_channels
        self.return_mask = return_mask

        # output transformation
        self.channel_inter = nn.Linear(
            self.in_channels, self.in_channels, bias_attr=False)
        self.channel_bn = nn.BatchNorm1D(self.in_channels)
        self.channel_act = nn.ReLU()

    def forward(self, visual):
        # Feature enhancement
        visual = visual.squeeze(2).transpose([0, 2, 1])  # batch, width, channel
        channel_att = self.channel_inter(visual)  # batch, width, channel
        channel_att = channel_att.transpose([0, 2, 1])  # batch, channel, width
        channel_bn = self.channel_bn(channel_att)  # batch, channel, width
        channel_att = self.channel_act(channel_bn)  # batch, channel, width

        # size alignment
        channel_output = channel_att.unsqueeze(-2)  # batch, width, channel

        if self.return_mask:
            return channel_output, channel_att
        return channel_output


class RFAdaptor(nn.Layer):
    def __init__(self, in_channels=512, use_v2s=True, use_s2v=True, **kwargs):
        super(RFAdaptor, self).__init__()
        if use_v2s is True:
            self.neck_v2s = V2SAdaptor(in_channels=in_channels, **kwargs)
        else:
            self.neck_v2s = None
        if use_s2v is True:
            self.neck_s2v = S2VAdaptor(in_channels=in_channels, **kwargs)
        else:
            self.neck_s2v = None
        self.out_channels = in_channels

    def forward(self, x):
        visual_feature, rcg_feature = x
        if visual_feature is not None:
            batch, source_channels, v_source_height, v_source_width = visual_feature.shape
            visual_feature = visual_feature.reshape(
                [batch, source_channels, 1, v_source_height * v_source_width])

        if self.neck_v2s is not None:
            v_rcg_feature = rcg_feature * self.neck_v2s(visual_feature)
        else:
            v_rcg_feature = rcg_feature

        if self.neck_s2v is not None:
            v_visual_feature = visual_feature + self.neck_s2v(rcg_feature)
        else:
            v_visual_feature = visual_feature
        if v_rcg_feature is not None:
            batch, source_channels, source_height, source_width = v_rcg_feature.shape
            v_rcg_feature = v_rcg_feature.reshape(
                [batch, source_channels, 1, source_height * source_width])
            v_rcg_feature = v_rcg_feature.squeeze(2).transpose([0, 2, 1])
        return v_visual_feature, v_rcg_feature
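
RFAdaptor takes the [visual, recognition] feature pair from ResNetRFL, lets each branch gate the other, and hands the recognition feature to the head as a (batch, width, channel) sequence. A minimal shape sketch, assuming the repo is importable and 26-column features as in the configs:

import paddle
from ppocr.modeling.necks.rf_adaptor import RFAdaptor

neck = RFAdaptor(in_channels=512, use_v2s=True, use_s2v=True)
visual = paddle.rand([2, 512, 1, 26])  # counting-branch feature map
rcg = paddle.rand([2, 512, 1, 26])     # sequence-branch feature map
v_visual, v_rcg = neck([visual, rcg])
print(v_visual.shape, v_rcg.shape)     # roughly [2, 512, 1, 26] and [2, 26, 512]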

ppocr/optimizer/__init__.py

@@ -53,6 +53,9 @@ def build_optimizer(config, epochs, step_each_epoch, model):
     if 'clip_norm' in config:
         clip_norm = config.pop('clip_norm')
         grad_clip = paddle.nn.ClipGradByNorm(clip_norm=clip_norm)
+    elif 'clip_norm_global' in config:
+        clip_norm = config.pop('clip_norm_global')
+        grad_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=clip_norm)
     else:
         grad_clip = None
     optim = getattr(optimizer, optim_name)(learning_rate=lr,

ppocr/postprocess/__init__.py

@@ -28,7 +28,7 @@ from .fce_postprocess import FCEPostProcess
 from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
     DistillationCTCLabelDecode, NRTRLabelDecode, SARLabelDecode, \
     SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode, ABINetLabelDecode, \
-    SPINLabelDecode, VLLabelDecode
+    SPINLabelDecode, VLLabelDecode, RFLLabelDecode
 from .cls_postprocess import ClsPostProcess
 from .pg_postprocess import PGPostProcess
 from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess, DistillationSerPostProcess

@@ -49,7 +49,7 @@ def build_post_process(config, global_config=None):
         'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode',
         'TableMasterLabelDecode', 'SPINLabelDecode',
         'DistillationSerPostProcess', 'DistillationRePostProcess',
-        'VLLabelDecode', 'PicoDetPostProcess', 'CTPostProcess'
+        'VLLabelDecode', 'PicoDetPostProcess', 'CTPostProcess', 'RFLLabelDecode'
     ]
 
     if config['name'] == 'PSEPostProcess':

ppocr/postprocess/rec_postprocess.py

@@ -242,6 +242,92 @@ class AttnLabelDecode(BaseRecLabelDecode):
         return idx
 
 
+class RFLLabelDecode(BaseRecLabelDecode):
+    """ Convert between text-label and text-index """
+
+    def __init__(self, character_dict_path=None, use_space_char=False,
+                 **kwargs):
+        super(RFLLabelDecode, self).__init__(character_dict_path,
+                                             use_space_char)
+
+    def add_special_char(self, dict_character):
+        self.beg_str = "sos"
+        self.end_str = "eos"
+        dict_character = dict_character
+        dict_character = [self.beg_str] + dict_character + [self.end_str]
+        return dict_character
+
+    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+        """ convert text-index into text-label. """
+        result_list = []
+        ignored_tokens = self.get_ignored_tokens()
+        [beg_idx, end_idx] = self.get_ignored_tokens()
+        batch_size = len(text_index)
+        for batch_idx in range(batch_size):
+            char_list = []
+            conf_list = []
+            for idx in range(len(text_index[batch_idx])):
+                if text_index[batch_idx][idx] in ignored_tokens:
+                    continue
+                if int(text_index[batch_idx][idx]) == int(end_idx):
+                    break
+                if is_remove_duplicate:
+                    # only for predict
+                    if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
+                            batch_idx][idx]:
+                        continue
+                char_list.append(self.character[int(text_index[batch_idx][
+                    idx])])
+                if text_prob is not None:
+                    conf_list.append(text_prob[batch_idx][idx])
+                else:
+                    conf_list.append(1)
+            text = ''.join(char_list)
+            result_list.append((text, np.mean(conf_list).tolist()))
+        return result_list
+
+    def __call__(self, preds, label=None, *args, **kwargs):
+        cnt_pred, preds = preds
+        if preds is not None:
+            if isinstance(preds, paddle.Tensor):
+                preds = preds.numpy()
+            preds_idx = preds.argmax(axis=2)
+            preds_prob = preds.max(axis=2)
+            text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
+
+            if label is None:
+                return text
+            label = self.decode(label, is_remove_duplicate=False)
+            return text, label
+
+        else:
+            cnt_length = []
+            for lens in cnt_pred:
+                length = round(paddle.sum(lens).item())
+                cnt_length.append(length)
+            if label is None:
+                return cnt_length
+            label = self.decode(label, is_remove_duplicate=False)
+            length = [len(res[0]) for res in label]
+            return cnt_length, length
+
+    def get_ignored_tokens(self):
+        beg_idx = self.get_beg_end_flag_idx("beg")
+        end_idx = self.get_beg_end_flag_idx("end")
+        return [beg_idx, end_idx]
+
+    def get_beg_end_flag_idx(self, beg_or_end):
+        if beg_or_end == "beg":
+            idx = np.array(self.dict[self.beg_str])
+        elif beg_or_end == "end":
+            idx = np.array(self.dict[self.end_str])
+        else:
+            assert False, "unsupport type %s in get_beg_end_flag_idx" \
+                % beg_or_end
+        return idx
+
+
 class SEEDLabelDecode(BaseRecLabelDecode):
     """ Convert between text-label and text-index """
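
RFLLabelDecode therefore handles two cases: with sequence logits it returns (text, confidence) pairs like AttnLabelDecode, and with only counting predictions it returns per-sample length estimates. A minimal sketch of the counting-only path (the one paired with CNTMetric in the visual config), assuming the repo is importable:

import paddle
from ppocr.postprocess.rec_postprocess import RFLLabelDecode

decoder = RFLLabelDecode(character_dict_path=None, use_space_char=False)
# counting-only output: predicted per-character counts for two samples
cnt_pred = paddle.to_tensor([[0.9, 1.1, 0.0, 2.1],
                             [0.2, 0.1, 0.0, 0.1]])
print(decoder((cnt_pred, None)))  # rounded row sums -> estimated lengths, e.g. [4, 0]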

test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml  (new file, 0 → 100644)

Global:
  use_gpu: True
  epoch_num: 6
  log_smooth_window: 20
  print_batch_step: 50
  save_model_dir: ./output/rec/rec_resnet_rfl/
  save_epoch_step: 1
  # evaluation is run every 5000 iterations after the 4000th iteration
  eval_batch_step: [0, 5000]
  cal_metric_during_train: False
  pretrained_model:
  checkpoints:
  save_inference_dir:
  use_visualdl: False
  infer_img: doc/imgs_words_en/word_10.png
  # for data or label process
  character_dict_path:
  max_text_length: 25
  infer_mode: False
  use_space_char: False
  save_res_path: ./output/rec/rec_resnet_rfl.txt

Optimizer:
  name: AdamW
  beta1: 0.9
  beta2: 0.999
  weight_decay: 0.0
  clip_norm_global: 5.0
  lr:
    name: Piecewise
    decay_epochs: [3, 4, 5]
    values: [0.001, 0.0003, 0.00009, 0.000027]

Architecture:
  model_type: rec
  algorithm: RFL
  in_channels: 1
  Transform:
    name: TPS
    num_fiducial: 20
    loc_lr: 1.0
    model_name: large
  Backbone:
    name: ResNetRFL
    use_cnt: True
    use_seq: True
  Neck:
    name: RFAdaptor
    use_v2s: True
    use_s2v: True
  Head:
    name: RFLHead
    in_channels: 512
    hidden_size: 256
    batch_max_legnth: 25
    out_channels: 38
    use_cnt: True
    use_seq: True

Loss:
  name: RFLLoss

PostProcess:
  name: RFLLabelDecode

Metric:
  name: RecMetric
  main_indicator: acc

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/ic15_data/
    label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - RFLLabelEncode: # Class handling label
      - RFLRecResizeImg:
          image_shape: [1, 32, 100]
          interpolation: 2
      - KeepKeys:
          keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order
  loader:
    shuffle: True
    batch_size_per_card: 64
    drop_last: True
    num_workers: 8

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/ic15_data
    label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - RFLLabelEncode: # Class handling label
      - RFLRecResizeImg:
          image_shape: [1, 32, 100]
          interpolation: 2
      - KeepKeys:
          keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 256
    num_workers: 8

test_tipc/configs/rec_resnet_rfl/train_infer_python.txt  (new file, 0 → 100644)

===========================train_params===========================
model_name:rec_resnet_rfl
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:norm_train
norm_train:tools/train.py -c test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml -o
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:tools/eval.py -c test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
norm_export:tools/export_model.py -c test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml -o
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
##
train_model:./inference/rec_resnet_rfl_train/best_accuracy
infer_export:tools/export_model.py -c test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_image_shape="1,32,100" --rec_algorithm="RFL" --min_subgraph_size=5
--use_gpu:True|False
--enable_mkldnn:False
--cpu_threads:6
--rec_batch_num:1
--use_tensorrt:False
--precision:fp32
--rec_model_dir:
--image_dir:./inference/rec_inference
--save_log_path:./test/output/
--benchmark:True
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[1,32,100]}]

tools/export_model.py

@@ -99,7 +99,7 @@ def export_single_model(model,
         ]
         # print([None, 3, 32, 128])
         model = to_static(model, input_spec=other_shape)
-    elif arch_config["algorithm"] in ["NRTR", "SPIN"]:
+    elif arch_config["algorithm"] in ["NRTR", "SPIN", 'RFL']:
         other_shape = [
             paddle.static.InputSpec(
                 shape=[None, 1, 32, 100], dtype="float32"),

tools/infer/predict_rec.py

@@ -100,6 +100,12 @@ class TextRecognizer(object):
                 "use_space_char": args.use_space_char,
                 "rm_symbol": True
             }
+        elif self.rec_algorithm == 'RFL':
+            postprocess_params = {
+                'name': 'RFLLabelDecode',
+                "character_dict_path": None,
+                "use_space_char": args.use_space_char
+            }
         self.postprocess_op = build_post_process(postprocess_params)
         self.predictor, self.input_tensor, self.output_tensors, self.config = \
             utility.create_predictor(args, 'rec', logger)

@@ -143,6 +149,16 @@ class TextRecognizer(object):
             else:
                 norm_img = norm_img.astype(np.float32) / 128. - 1.
             return norm_img
+        elif self.rec_algorithm == 'RFL':
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+            resized_image = cv2.resize(
+                img, (imgW, imgH), interpolation=cv2.INTER_CUBIC)
+            resized_image = resized_image.astype('float32')
+            resized_image = resized_image / 255
+            resized_image = resized_image[np.newaxis, :]
+            resized_image -= 0.5
+            resized_image /= 0.5
+            return resized_image
 
         assert imgC == img.shape[2]
         imgW = int((imgH * max_wh_ratio))
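
The RFL branch added above boils down to: grayscale, bicubic resize to 32x100, scale to [0, 1], add a channel axis, then shift to [-1, 1]. A standalone sketch of the same arithmetic on a dummy crop:

import cv2
import numpy as np

img = np.random.randint(0, 255, (48, 160, 3), dtype=np.uint8)  # dummy BGR word crop

gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
resized = cv2.resize(gray, (100, 32), interpolation=cv2.INTER_CUBIC)  # (w, h) order
norm = resized.astype('float32') / 255
norm = norm[np.newaxis, :]    # -> (1, 32, 100), matching rec_image_shape
norm = (norm - 0.5) / 0.5     # values in [-1, 1]
print(norm.shape, float(norm.min()), float(norm.max()))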

tools/program.py

@@ -217,7 +217,7 @@ def train(config,
     use_srn = config['Architecture']['algorithm'] == "SRN"
     extra_input_models = [
         "SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN",
-        "RobustScanner"
+        "RobustScanner", "RFL"
     ]
     extra_input = False
     if config['Architecture']['algorithm'] == 'Distillation':

@@ -625,7 +625,7 @@ def preprocess(is_train=False):
         'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
         'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
         'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
-        'Gestalt', 'SLANet', 'RobustScanner', 'CT'
+        'Gestalt', 'SLANet', 'RobustScanner', 'CT', 'RFL'
     ]
 
     if use_xpu: