Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
bde50863
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
bde50863
编写于
4月 27, 2022
作者:
X
xiaoting
提交者:
GitHub
4月 27, 2022
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #6061 from Topdu/dygraph
add_rec_svtr
上级
58a408ab
9974b81a
变更
14
显示空白变更内容
内联
并排
Showing
14 changed file
with
333 addition
and
55 deletion
+333
-55
configs/rec/rec_svtrnet.yml
configs/rec/rec_svtrnet.yml
+117
-0
deploy/slim/quantization/quant.py
deploy/slim/quantization/quant.py
+1
-1
ppocr/data/imaug/__init__.py
ppocr/data/imaug/__init__.py
+1
-1
ppocr/data/imaug/rec_img_aug.py
ppocr/data/imaug/rec_img_aug.py
+72
-0
ppocr/modeling/backbones/rec_svtrnet.py
ppocr/modeling/backbones/rec_svtrnet.py
+27
-25
ppocr/modeling/transforms/stn.py
ppocr/modeling/transforms/stn.py
+2
-0
ppocr/modeling/transforms/tps_spatial_transformer.py
ppocr/modeling/transforms/tps_spatial_transformer.py
+2
-2
ppocr/optimizer/__init__.py
ppocr/optimizer/__init__.py
+4
-2
ppocr/optimizer/optimizer.py
ppocr/optimizer/optimizer.py
+38
-14
ppocr/postprocess/__init__.py
ppocr/postprocess/__init__.py
+2
-2
ppocr/postprocess/rec_postprocess.py
ppocr/postprocess/rec_postprocess.py
+37
-0
tools/export_model.py
tools/export_model.py
+5
-0
tools/infer/predict_rec.py
tools/infer/predict_rec.py
+24
-7
tools/train.py
tools/train.py
+1
-1
未找到文件。
configs/rec/rec_svtrnet.yml
0 → 100644
浏览文件 @
bde50863
Global
:
use_gpu
:
True
epoch_num
:
20
log_smooth_window
:
20
print_batch_step
:
10
save_model_dir
:
./output/rec/svtr/
save_epoch_step
:
1
# evaluation is run every 2000 iterations after the 0th iteration
eval_batch_step
:
[
0
,
2000
]
cal_metric_during_train
:
True
pretrained_model
:
checkpoints
:
save_inference_dir
:
use_visualdl
:
False
infer_img
:
doc/imgs_words_en/word_10.png
# for data or label process
character_dict_path
:
character_type
:
en
max_text_length
:
25
infer_mode
:
False
use_space_char
:
False
save_res_path
:
./output/rec/predicts_svtr_tiny.txt
Optimizer
:
name
:
AdamW
beta1
:
0.9
beta2
:
0.99
epsilon
:
0.00000008
weight_decay
:
0.05
no_weight_decay_name
:
norm pos_embed
one_dim_param_no_weight_decay
:
true
lr
:
name
:
Cosine
learning_rate
:
0.0005
warmup_epoch
:
2
Architecture
:
model_type
:
rec
algorithm
:
SVTR
Transform
:
name
:
STN_ON
tps_inputsize
:
[
32
,
64
]
tps_outputsize
:
[
32
,
100
]
num_control_points
:
20
tps_margins
:
[
0.05
,
0.05
]
stn_activation
:
none
Backbone
:
name
:
SVTRNet
img_size
:
[
32
,
100
]
out_char_num
:
25
out_channels
:
192
patch_merging
:
'
Conv'
embed_dim
:
[
64
,
128
,
256
]
depth
:
[
3
,
6
,
3
]
num_heads
:
[
2
,
4
,
8
]
mixer
:
[
'
Local'
,
'
Local'
,
'
Local'
,
'
Local'
,
'
Local'
,
'
Local'
,
'
Global'
,
'
Global'
,
'
Global'
,
'
Global'
,
'
Global'
,
'
Global'
]
local_mixer
:
[[
7
,
11
],
[
7
,
11
],
[
7
,
11
]]
last_stage
:
True
prenorm
:
false
Neck
:
name
:
SequenceEncoder
encoder_type
:
reshape
Head
:
name
:
CTCHead
Loss
:
name
:
CTCLoss
PostProcess
:
name
:
CTCLabelDecode
Metric
:
name
:
RecMetric
main_indicator
:
acc
Train
:
dataset
:
name
:
LMDBDataSet
data_dir
:
./train_data/data_lmdb_release/training/
transforms
:
-
DecodeImage
:
# load image
img_mode
:
BGR
channel_first
:
False
-
CTCLabelEncode
:
# Class handling label
-
RecResizeImg
:
character_dict_path
:
image_shape
:
[
3
,
64
,
256
]
padding
:
False
-
KeepKeys
:
keep_keys
:
[
'
image'
,
'
label'
,
'
length'
]
# dataloader will return list in this order
loader
:
shuffle
:
True
batch_size_per_card
:
512
drop_last
:
True
num_workers
:
4
Eval
:
dataset
:
name
:
LMDBDataSet
data_dir
:
./train_data/data_lmdb_release/validation/
transforms
:
-
DecodeImage
:
# load image
img_mode
:
BGR
channel_first
:
False
-
CTCLabelEncode
:
# Class handling label
-
RecResizeImg
:
character_dict_path
:
image_shape
:
[
3
,
64
,
256
]
padding
:
False
-
KeepKeys
:
keep_keys
:
[
'
image'
,
'
label'
,
'
length'
]
# dataloader will return list in this order
loader
:
shuffle
:
False
drop_last
:
False
batch_size_per_card
:
256
num_workers
:
2
deploy/slim/quantization/quant.py
浏览文件 @
bde50863
...
...
@@ -137,7 +137,7 @@ def main(config, device, logger, vdl_writer):
config
[
'Optimizer'
],
epochs
=
config
[
'Global'
][
'epoch_num'
],
step_each_epoch
=
len
(
train_dataloader
),
parameters
=
model
.
parameters
()
)
model
=
model
)
# resume PACT training process
if
config
[
"Global"
][
"checkpoints"
]
is
not
None
:
...
...
ppocr/data/imaug/__init__.py
浏览文件 @
bde50863
...
...
@@ -23,7 +23,7 @@ from .random_crop_data import EastRandomCropData, RandomCropImgMask
from
.make_pse_gt
import
MakePseGt
from
.rec_img_aug
import
RecAug
,
RecConAug
,
RecResizeImg
,
ClsResizeImg
,
\
SRNRecResizeImg
,
NRTRRecResizeImg
,
SARRecResizeImg
,
PRENResizeImg
SRNRecResizeImg
,
NRTRRecResizeImg
,
SARRecResizeImg
,
PRENResizeImg
,
SVTRRecResizeImg
from
.randaugment
import
RandAugment
from
.copy_paste
import
CopyPaste
from
.ColorJitter
import
ColorJitter
...
...
ppocr/data/imaug/rec_img_aug.py
浏览文件 @
bde50863
...
...
@@ -16,6 +16,7 @@ import math
import
cv2
import
numpy
as
np
import
random
import
copy
from
PIL
import
Image
from
.text_image_aug
import
tia_perspective
,
tia_stretch
,
tia_distort
...
...
@@ -206,6 +207,25 @@ class PRENResizeImg(object):
return
data
class
SVTRRecResizeImg
(
object
):
def
__init__
(
self
,
image_shape
,
infer_mode
=
False
,
character_dict_path
=
'./ppocr/utils/ppocr_keys_v1.txt'
,
padding
=
True
,
**
kwargs
):
self
.
image_shape
=
image_shape
self
.
infer_mode
=
infer_mode
self
.
character_dict_path
=
character_dict_path
self
.
padding
=
padding
def
__call__
(
self
,
data
):
img
=
data
[
'image'
]
norm_img
=
resize_norm_img_svtr
(
img
,
self
.
image_shape
,
self
.
padding
)
data
[
'image'
]
=
norm_img
return
data
def
resize_norm_img_sar
(
img
,
image_shape
,
width_downsample_ratio
=
0.25
):
imgC
,
imgH
,
imgW_min
,
imgW_max
=
image_shape
h
=
img
.
shape
[
0
]
...
...
@@ -324,6 +344,58 @@ def resize_norm_img_srn(img, image_shape):
return
np
.
reshape
(
img_black
,
(
c
,
row
,
col
)).
astype
(
np
.
float32
)
def
resize_norm_img_svtr
(
img
,
image_shape
,
padding
=
False
):
imgC
,
imgH
,
imgW
=
image_shape
h
=
img
.
shape
[
0
]
w
=
img
.
shape
[
1
]
if
not
padding
:
if
h
>
2.0
*
w
:
image
=
Image
.
fromarray
(
img
)
image1
=
image
.
rotate
(
90
,
expand
=
True
)
image2
=
image
.
rotate
(
-
90
,
expand
=
True
)
img1
=
np
.
array
(
image1
)
img2
=
np
.
array
(
image2
)
else
:
img1
=
copy
.
deepcopy
(
img
)
img2
=
copy
.
deepcopy
(
img
)
resized_image
=
cv2
.
resize
(
img
,
(
imgW
,
imgH
),
interpolation
=
cv2
.
INTER_LINEAR
)
resized_image1
=
cv2
.
resize
(
img1
,
(
imgW
,
imgH
),
interpolation
=
cv2
.
INTER_LINEAR
)
resized_image2
=
cv2
.
resize
(
img2
,
(
imgW
,
imgH
),
interpolation
=
cv2
.
INTER_LINEAR
)
resized_w
=
imgW
else
:
ratio
=
w
/
float
(
h
)
if
math
.
ceil
(
imgH
*
ratio
)
>
imgW
:
resized_w
=
imgW
else
:
resized_w
=
int
(
math
.
ceil
(
imgH
*
ratio
))
resized_image
=
cv2
.
resize
(
img
,
(
resized_w
,
imgH
))
resized_image
=
resized_image
.
astype
(
'float32'
)
resized_image1
=
resized_image1
.
astype
(
'float32'
)
resized_image2
=
resized_image2
.
astype
(
'float32'
)
if
image_shape
[
0
]
==
1
:
resized_image
=
resized_image
/
255
resized_image
=
resized_image
[
np
.
newaxis
,
:]
else
:
resized_image
=
resized_image
.
transpose
((
2
,
0
,
1
))
/
255
resized_image1
=
resized_image1
.
transpose
((
2
,
0
,
1
))
/
255
resized_image2
=
resized_image2
.
transpose
((
2
,
0
,
1
))
/
255
resized_image
-=
0.5
resized_image
/=
0.5
resized_image1
-=
0.5
resized_image1
/=
0.5
resized_image2
-=
0.5
resized_image2
/=
0.5
padding_im
=
np
.
zeros
((
3
,
imgC
,
imgH
,
imgW
),
dtype
=
np
.
float32
)
padding_im
[
0
,
:,
:,
0
:
resized_w
]
=
resized_image
padding_im
[
1
,
:,
:,
0
:
resized_w
]
=
resized_image1
padding_im
[
2
,
:,
:,
0
:
resized_w
]
=
resized_image2
return
padding_im
def
srn_other_inputs
(
image_shape
,
num_heads
,
max_text_length
):
imgC
,
imgH
,
imgW
=
image_shape
...
...
ppocr/modeling/backbones/rec_svtrnet.py
浏览文件 @
bde50863
...
...
@@ -296,47 +296,49 @@ class PatchEmbed(nn.Layer):
if
sub_num
==
2
:
self
.
proj
=
nn
.
Sequential
(
ConvBNLayer
(
in_channels
,
embed_dim
//
2
,
3
,
2
,
1
,
in_channels
=
in_channels
,
out_channels
=
embed_dim
//
2
,
kernel_size
=
3
,
stride
=
2
,
padding
=
1
,
act
=
nn
.
GELU
,
bias_attr
=
None
),
ConvBNLayer
(
embed_dim
//
2
,
embed_dim
,
3
,
2
,
1
,
in_channels
=
embed_dim
//
2
,
out_channels
=
embed_dim
,
kernel_size
=
3
,
stride
=
2
,
padding
=
1
,
act
=
nn
.
GELU
,
bias_attr
=
None
))
if
sub_num
==
3
:
self
.
proj
=
nn
.
Sequential
(
ConvBNLayer
(
in_channels
,
embed_dim
//
4
,
3
,
2
,
1
,
in_channels
=
in_channels
,
out_channels
=
embed_dim
//
4
,
kernel_size
=
3
,
stride
=
2
,
padding
=
1
,
act
=
nn
.
GELU
,
bias_attr
=
None
),
ConvBNLayer
(
embed_dim
//
4
,
embed_dim
//
2
,
3
,
2
,
1
,
in_channels
=
embed_dim
//
4
,
out_channels
=
embed_dim
//
2
,
kernel_size
=
3
,
stride
=
2
,
padding
=
1
,
act
=
nn
.
GELU
,
bias_attr
=
None
),
ConvBNLayer
(
embed_dim
//
2
,
embed_dim
,
3
,
2
,
1
,
in_channels
=
embed_dim
//
2
,
out_channels
=
embed_dim
,
kernel_size
=
3
,
stride
=
2
,
padding
=
1
,
act
=
nn
.
GELU
,
bias_attr
=
None
)
,
)
bias_attr
=
None
))
def
forward
(
self
,
x
):
B
,
C
,
H
,
W
=
x
.
shape
...
...
@@ -455,7 +457,7 @@ class SVTRNet(nn.Layer):
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
drop
=
drop_rate
,
act_layer
=
nn
.
Swish
,
act_layer
=
eval
(
act
)
,
attn_drop
=
attn_drop_rate
,
drop_path
=
dpr
[
0
:
depth
[
0
]][
i
],
norm_layer
=
norm_layer
,
...
...
ppocr/modeling/transforms/stn.py
浏览文件 @
bde50863
...
...
@@ -128,6 +128,8 @@ class STN_ON(nn.Layer):
self
.
out_channels
=
in_channels
def
forward
(
self
,
image
):
if
len
(
image
.
shape
)
==
5
:
image
=
image
.
reshape
([
0
,
image
.
shape
[
-
3
],
image
.
shape
[
-
2
],
image
.
shape
[
-
1
]])
stn_input
=
paddle
.
nn
.
functional
.
interpolate
(
image
,
self
.
tps_inputsize
,
mode
=
"bilinear"
,
align_corners
=
True
)
stn_img_feat
,
ctrl_points
=
self
.
stn_head
(
stn_input
)
...
...
ppocr/modeling/transforms/tps_spatial_transformer.py
浏览文件 @
bde50863
...
...
@@ -138,9 +138,9 @@ class TPSSpatialTransformer(nn.Layer):
assert
source_control_points
.
shape
[
2
]
==
2
batch_size
=
paddle
.
shape
(
source_control_points
)[
0
]
self
.
padding_matrix
=
paddle
.
expand
(
padding_matrix
=
paddle
.
expand
(
self
.
padding_matrix
,
shape
=
[
batch_size
,
3
,
2
])
Y
=
paddle
.
concat
([
source_control_points
,
self
.
padding_matrix
],
1
)
Y
=
paddle
.
concat
([
source_control_points
,
padding_matrix
],
1
)
mapping_matrix
=
paddle
.
matmul
(
self
.
inverse_kernel
,
Y
)
source_coordinate
=
paddle
.
matmul
(
self
.
target_coordinate_repr
,
mapping_matrix
)
...
...
ppocr/optimizer/__init__.py
浏览文件 @
bde50863
...
...
@@ -30,7 +30,7 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
return
lr
def
build_optimizer
(
config
,
epochs
,
step_each_epoch
,
parameters
):
def
build_optimizer
(
config
,
epochs
,
step_each_epoch
,
model
):
from
.
import
regularizer
,
optimizer
config
=
copy
.
deepcopy
(
config
)
# step1 build lr
...
...
@@ -43,6 +43,8 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
if
not
hasattr
(
regularizer
,
reg_name
):
reg_name
+=
'Decay'
reg
=
getattr
(
regularizer
,
reg_name
)(
**
reg_config
)()
elif
'weight_decay'
in
config
:
reg
=
config
.
pop
(
'weight_decay'
)
else
:
reg
=
None
...
...
@@ -57,4 +59,4 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
weight_decay
=
reg
,
grad_clip
=
grad_clip
,
**
config
)
return
optim
(
parameters
),
lr
return
optim
(
model
),
lr
ppocr/optimizer/optimizer.py
浏览文件 @
bde50863
...
...
@@ -42,13 +42,13 @@ class Momentum(object):
self
.
weight_decay
=
weight_decay
self
.
grad_clip
=
grad_clip
def
__call__
(
self
,
parameters
):
def
__call__
(
self
,
model
):
opt
=
optim
.
Momentum
(
learning_rate
=
self
.
learning_rate
,
momentum
=
self
.
momentum
,
weight_decay
=
self
.
weight_decay
,
grad_clip
=
self
.
grad_clip
,
parameters
=
parameters
)
parameters
=
model
.
parameters
()
)
return
opt
...
...
@@ -75,7 +75,7 @@ class Adam(object):
self
.
name
=
name
self
.
lazy_mode
=
lazy_mode
def
__call__
(
self
,
parameters
):
def
__call__
(
self
,
model
):
opt
=
optim
.
Adam
(
learning_rate
=
self
.
learning_rate
,
beta1
=
self
.
beta1
,
...
...
@@ -85,7 +85,7 @@ class Adam(object):
grad_clip
=
self
.
grad_clip
,
name
=
self
.
name
,
lazy_mode
=
self
.
lazy_mode
,
parameters
=
parameters
)
parameters
=
model
.
parameters
()
)
return
opt
...
...
@@ -117,7 +117,7 @@ class RMSProp(object):
self
.
weight_decay
=
weight_decay
self
.
grad_clip
=
grad_clip
def
__call__
(
self
,
parameters
):
def
__call__
(
self
,
model
):
opt
=
optim
.
RMSProp
(
learning_rate
=
self
.
learning_rate
,
momentum
=
self
.
momentum
,
...
...
@@ -125,7 +125,7 @@ class RMSProp(object):
epsilon
=
self
.
epsilon
,
weight_decay
=
self
.
weight_decay
,
grad_clip
=
self
.
grad_clip
,
parameters
=
parameters
)
parameters
=
model
.
parameters
()
)
return
opt
...
...
@@ -148,7 +148,7 @@ class Adadelta(object):
self
.
grad_clip
=
grad_clip
self
.
name
=
name
def
__call__
(
self
,
parameters
):
def
__call__
(
self
,
model
):
opt
=
optim
.
Adadelta
(
learning_rate
=
self
.
learning_rate
,
epsilon
=
self
.
epsilon
,
...
...
@@ -156,7 +156,7 @@ class Adadelta(object):
weight_decay
=
self
.
weight_decay
,
grad_clip
=
self
.
grad_clip
,
name
=
self
.
name
,
parameters
=
parameters
)
parameters
=
model
.
parameters
()
)
return
opt
...
...
@@ -165,31 +165,55 @@ class AdamW(object):
learning_rate
=
0.001
,
beta1
=
0.9
,
beta2
=
0.999
,
epsilon
=
1e-
0
8
,
epsilon
=
1e-8
,
weight_decay
=
0.01
,
multi_precision
=
False
,
grad_clip
=
None
,
no_weight_decay_name
=
None
,
one_dim_param_no_weight_decay
=
False
,
name
=
None
,
lazy_mode
=
False
,
**
kwargs
):
**
args
):
super
().
__init__
()
self
.
learning_rate
=
learning_rate
self
.
beta1
=
beta1
self
.
beta2
=
beta2
self
.
epsilon
=
epsilon
self
.
learning_rate
=
learning_rate
self
.
grad_clip
=
grad_clip
self
.
weight_decay
=
0.01
if
weight_decay
is
None
else
weight_decay
self
.
grad_clip
=
grad_clip
self
.
name
=
name
self
.
lazy_mode
=
lazy_mode
self
.
multi_precision
=
multi_precision
self
.
no_weight_decay_name_list
=
no_weight_decay_name
.
split
(
)
if
no_weight_decay_name
else
[]
self
.
one_dim_param_no_weight_decay
=
one_dim_param_no_weight_decay
def
__call__
(
self
,
model
):
parameters
=
model
.
parameters
()
self
.
no_weight_decay_param_name_list
=
[
p
.
name
for
n
,
p
in
model
.
named_parameters
()
if
any
(
nd
in
n
for
nd
in
self
.
no_weight_decay_name_list
)
]
if
self
.
one_dim_param_no_weight_decay
:
self
.
no_weight_decay_param_name_list
+=
[
p
.
name
for
n
,
p
in
model
.
named_parameters
()
if
len
(
p
.
shape
)
==
1
]
def
__call__
(
self
,
parameters
):
opt
=
optim
.
AdamW
(
learning_rate
=
self
.
learning_rate
,
beta1
=
self
.
beta1
,
beta2
=
self
.
beta2
,
epsilon
=
self
.
epsilon
,
parameters
=
parameters
,
weight_decay
=
self
.
weight_decay
,
multi_precision
=
self
.
multi_precision
,
grad_clip
=
self
.
grad_clip
,
name
=
self
.
name
,
lazy_mode
=
self
.
lazy_mode
,
parameters
=
parameters
)
apply_decay_param_fun
=
self
.
_apply_decay_param_fun
)
return
opt
def
_apply_decay_param_fun
(
self
,
name
):
return
name
not
in
self
.
no_weight_decay_param_name_list
\ No newline at end of file
ppocr/postprocess/__init__.py
浏览文件 @
bde50863
...
...
@@ -27,7 +27,7 @@ from .sast_postprocess import SASTPostProcess
from
.fce_postprocess
import
FCEPostProcess
from
.rec_postprocess
import
CTCLabelDecode
,
AttnLabelDecode
,
SRNLabelDecode
,
\
DistillationCTCLabelDecode
,
TableLabelDecode
,
NRTRLabelDecode
,
SARLabelDecode
,
\
SEEDLabelDecode
,
PRENLabelDecode
SEEDLabelDecode
,
PRENLabelDecode
,
SVTRLabelDecode
from
.cls_postprocess
import
ClsPostProcess
from
.pg_postprocess
import
PGPostProcess
from
.vqa_token_ser_layoutlm_postprocess
import
VQASerTokenLayoutLMPostProcess
...
...
@@ -42,7 +42,7 @@ def build_post_process(config, global_config=None):
'DistillationDBPostProcess'
,
'NRTRLabelDecode'
,
'SARLabelDecode'
,
'SEEDLabelDecode'
,
'VQASerTokenLayoutLMPostProcess'
,
'VQAReTokenLayoutLMPostProcess'
,
'PRENLabelDecode'
,
'DistillationSARLabelDecode'
'DistillationSARLabelDecode'
,
'SVTRLabelDecode'
]
if
config
[
'name'
]
==
'PSEPostProcess'
:
...
...
ppocr/postprocess/rec_postprocess.py
浏览文件 @
bde50863
...
...
@@ -752,3 +752,40 @@ class PRENLabelDecode(BaseRecLabelDecode):
return
text
label
=
self
.
decode
(
label
)
return
text
,
label
class
SVTRLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
False
,
**
kwargs
):
super
(
SVTRLabelDecode
,
self
).
__init__
(
character_dict_path
,
use_space_char
)
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
if
isinstance
(
preds
,
tuple
):
preds
=
preds
[
-
1
]
if
isinstance
(
preds
,
paddle
.
Tensor
):
preds
=
preds
.
numpy
()
preds_idx
=
preds
.
argmax
(
axis
=-
1
)
preds_prob
=
preds
.
max
(
axis
=-
1
)
text
=
self
.
decode
(
preds_idx
,
preds_prob
,
is_remove_duplicate
=
True
)
return_text
=
[]
for
i
in
range
(
0
,
len
(
text
),
3
):
text0
=
text
[
i
]
text1
=
text
[
i
+
1
]
text2
=
text
[
i
+
2
]
text_pred
=
[
text0
[
0
],
text1
[
0
],
text2
[
0
]]
text_prob
=
[
text0
[
1
],
text1
[
1
],
text2
[
1
]]
id_max
=
text_prob
.
index
(
max
(
text_prob
))
return_text
.
append
((
text_pred
[
id_max
],
text_prob
[
id_max
]))
if
label
is
None
:
return
return_text
label
=
self
.
decode
(
label
)
return
return_text
,
label
def
add_special_char
(
self
,
dict_character
):
dict_character
=
[
'blank'
]
+
dict_character
return
dict_character
\ No newline at end of file
tools/export_model.py
浏览文件 @
bde50863
...
...
@@ -61,6 +61,11 @@ def export_single_model(model, arch_config, save_path, logger):
paddle
.
static
.
InputSpec
(
shape
=
[
None
,
3
,
48
,
-
1
],
dtype
=
"float32"
),
]
else
:
other_shape
=
[
paddle
.
static
.
InputSpec
(
shape
=
[
None
,
3
,
64
,
256
],
dtype
=
"float32"
),
]
model
=
to_static
(
model
,
input_spec
=
other_shape
)
elif
arch_config
[
"algorithm"
]
==
"PREN"
:
other_shape
=
[
...
...
tools/infer/predict_rec.py
浏览文件 @
bde50863
...
...
@@ -132,6 +132,17 @@ class TextRecognizer(object):
padding_im
[:,
:,
0
:
resized_w
]
=
resized_image
return
padding_im
def
resize_norm_img_svtr
(
self
,
img
,
image_shape
):
imgC
,
imgH
,
imgW
=
image_shape
resized_image
=
cv2
.
resize
(
img
,
(
imgW
,
imgH
),
interpolation
=
cv2
.
INTER_LINEAR
)
resized_image
=
resized_image
.
astype
(
'float32'
)
resized_image
=
resized_image
.
transpose
((
2
,
0
,
1
))
/
255
resized_image
-=
0.5
resized_image
/=
0.5
return
resized_image
def
resize_norm_img_srn
(
self
,
img
,
image_shape
):
imgC
,
imgH
,
imgW
=
image_shape
...
...
@@ -263,12 +274,8 @@ class TextRecognizer(object):
wh_ratio
=
w
*
1.0
/
h
max_wh_ratio
=
max
(
max_wh_ratio
,
wh_ratio
)
for
ino
in
range
(
beg_img_no
,
end_img_no
):
if
self
.
rec_algorithm
!=
"SRN"
and
self
.
rec_algorithm
!=
"SAR"
:
norm_img
=
self
.
resize_norm_img
(
img_list
[
indices
[
ino
]],
max_wh_ratio
)
norm_img
=
norm_img
[
np
.
newaxis
,
:]
norm_img_batch
.
append
(
norm_img
)
elif
self
.
rec_algorithm
==
"SAR"
:
if
self
.
rec_algorithm
==
"SAR"
:
norm_img
,
_
,
_
,
valid_ratio
=
self
.
resize_norm_img_sar
(
img_list
[
indices
[
ino
]],
self
.
rec_image_shape
)
norm_img
=
norm_img
[
np
.
newaxis
,
:]
...
...
@@ -276,7 +283,7 @@ class TextRecognizer(object):
valid_ratios
=
[]
valid_ratios
.
append
(
valid_ratio
)
norm_img_batch
.
append
(
norm_img
)
el
se
:
el
if
self
.
rec_algorithm
==
"SRN"
:
norm_img
=
self
.
process_image_srn
(
img_list
[
indices
[
ino
]],
self
.
rec_image_shape
,
8
,
25
)
encoder_word_pos_list
=
[]
...
...
@@ -288,6 +295,16 @@ class TextRecognizer(object):
gsrm_slf_attn_bias1_list
.
append
(
norm_img
[
3
])
gsrm_slf_attn_bias2_list
.
append
(
norm_img
[
4
])
norm_img_batch
.
append
(
norm_img
[
0
])
elif
self
.
rec_algorithm
==
"SVTR"
:
norm_img
=
self
.
resize_norm_img_svtr
(
img_list
[
indices
[
ino
]],
self
.
rec_image_shape
)
norm_img
=
norm_img
[
np
.
newaxis
,
:]
norm_img_batch
.
append
(
norm_img
)
else
:
norm_img
=
self
.
resize_norm_img
(
img_list
[
indices
[
ino
]],
max_wh_ratio
)
norm_img
=
norm_img
[
np
.
newaxis
,
:]
norm_img_batch
.
append
(
norm_img
)
norm_img_batch
=
np
.
concatenate
(
norm_img_batch
)
norm_img_batch
=
norm_img_batch
.
copy
()
if
self
.
benchmark
:
...
...
tools/train.py
浏览文件 @
bde50863
...
...
@@ -129,7 +129,7 @@ def main(config, device, logger, vdl_writer):
config
[
'Optimizer'
],
epochs
=
config
[
'Global'
][
'epoch_num'
],
step_each_epoch
=
len
(
train_dataloader
),
parameters
=
model
.
parameters
()
)
model
=
model
)
# build metric
eval_class
=
build_metric
(
config
[
'Metric'
])
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录