weixin_41840029 / PaddleOCR
Forked from PaddlePaddle / PaddleOCR (in sync with the fork source)
Commit bde8cad0 (unverified)
Authored Aug 08, 2022 by topduke; committed by GitHub on Aug 08, 2022

add svtr ch model (#7134)

Parent: 05b6d296

Showing 5 changed files with 191 additions and 9 deletions (+191 -9)
configs/rec/rec_svtrnet.yml        +3 -5
configs/rec/rec_svtrnet_ch.yml     +155 -0
ppocr/data/imaug/__init__.py       +2 -1
ppocr/data/imaug/rec_img_aug.py    +15 -0
tools/export_model.py              +16 -3

configs/rec/rec_svtrnet.yml @ bde8cad0

@@ -83,8 +83,7 @@ Train:
           img_mode: BGR
           channel_first: False
       - CTCLabelEncode: # Class handling label
-      - RecResizeImg:
-          character_dict_path:
+      - SVTRRecResizeImg:
           image_shape: [3, 64, 256]
           padding: False
       - KeepKeys:
@@ -98,14 +97,13 @@ Train:
 Eval:
   dataset:
     name: LMDBDataSet
-    data_dir: ./train_data/data_lmdb_release/validation/
+    data_dir: ./train_data/data_lmdb_release/evaluation/
     transforms:
       - DecodeImage: # load image
           img_mode: BGR
           channel_first: False
       - CTCLabelEncode: # Class handling label
-      - RecResizeImg:
-          character_dict_path:
+      - SVTRRecResizeImg:
           image_shape: [3, 64, 256]
           padding: False
       - KeepKeys:

configs/rec/rec_svtrnet_ch.yml @ bde8cad0 (new file, mode 0 → 100644)

+Global:
+  use_gpu: true
+  epoch_num: 100
+  log_smooth_window: 20
+  print_batch_step: 10
+  save_model_dir: ./output/rec/svtr_ch_all/
+  save_epoch_step: 10
+  eval_batch_step:
+  - 0
+  - 2000
+  cal_metric_during_train: true
+  pretrained_model: null
+  checkpoints: null
+  save_inference_dir: null
+  use_visualdl: false
+  infer_img: doc/imgs_words/ch/word_1.jpg
+  character_dict_path: ppocr/utils/ppocr_keys_v1.txt
+  max_text_length: 25
+  infer_mode: false
+  use_space_char: true
+  save_res_path: ./output/rec/predicts_svtr_tiny_ch_all.txt
+Optimizer:
+  name: AdamW
+  beta1: 0.9
+  beta2: 0.99
+  epsilon: 8.0e-08
+  weight_decay: 0.05
+  no_weight_decay_name: norm pos_embed
+  one_dim_param_no_weight_decay: true
+  lr:
+    name: Cosine
+    learning_rate: 0.0005
+    warmup_epoch: 2
+Architecture:
+  model_type: rec
+  algorithm: SVTR
+  Transform: null
+  Backbone:
+    name: SVTRNet
+    img_size:
+    - 32
+    - 320
+    out_char_num: 40
+    out_channels: 96
+    patch_merging: Conv
+    embed_dim:
+    - 64
+    - 128
+    - 256
+    depth:
+    - 3
+    - 6
+    - 3
+    num_heads:
+    - 2
+    - 4
+    - 8
+    mixer:
+    - Local
+    - Local
+    - Local
+    - Local
+    - Local
+    - Local
+    - Global
+    - Global
+    - Global
+    - Global
+    - Global
+    - Global
+    local_mixer:
+    - - 7
+      - 11
+    - - 7
+      - 11
+    - - 7
+      - 11
+    last_stage: true
+    prenorm: false
+  Neck:
+    name: SequenceEncoder
+    encoder_type: reshape
+  Head:
+    name: CTCHead
+Loss:
+  name: CTCLoss
+PostProcess:
+  name: CTCLabelDecode
+Metric:
+  name: RecMetric
+  main_indicator: acc
+Train:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data
+    label_file_list:
+    - ./train_data/train_list.txt
+    ext_op_transform_idx: 1
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - RecConAug:
+        prob: 0.5
+        ext_data_num: 2
+        image_shape:
+        - 32
+        - 320
+        - 3
+    - RecAug: null
+    - CTCLabelEncode: null
+    - SVTRRecResizeImg:
+        image_shape:
+        - 3
+        - 32
+        - 320
+        padding: true
+    - KeepKeys:
+        keep_keys:
+        - image
+        - label
+        - length
+  loader:
+    shuffle: true
+    batch_size_per_card: 256
+    drop_last: true
+    num_workers: 8
+Eval:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data
+    label_file_list:
+    - ./train_data/val_list.txt
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - CTCLabelEncode: null
+    - SVTRRecResizeImg:
+        image_shape:
+        - 3
+        - 32
+        - 320
+        padding: true
+    - KeepKeys:
+        keep_keys:
+        - image
+        - label
+        - length
+  loader:
+    shuffle: false
+    drop_last: false
+    batch_size_per_card: 256
+    num_workers: 2
+profiler_options: null
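
Several values in the new config depend on each other. `mixer` has 12 entries while `depth` sums to 3 + 6 + 3 = 12, and the backbone `img_size` of [32, 320] matches the height and width in `SVTRRecResizeImg.image_shape`. The sketch below checks both relations. It assumes that `mixer` is meant to list one entry per backbone block, which the diff suggests but does not state. The dict literal is a hand-copied excerpt of the values above, and `check_svtr_config` is a hypothetical helper, not part of PaddleOCR.

```python
# Hand-copied excerpt of configs/rec/rec_svtrnet_ch.yml (values from the diff above).
CONFIG = {
    "Architecture": {
        "Backbone": {
            "img_size": [32, 320],
            "depth": [3, 6, 3],
            "mixer": ["Local"] * 6 + ["Global"] * 6,
        }
    },
    "Eval": {
        "dataset": {
            "transforms": [
                {"DecodeImage": {"img_mode": "BGR", "channel_first": False}},
                {"CTCLabelEncode": None},
                {"SVTRRecResizeImg": {"image_shape": [3, 32, 320], "padding": True}},
                {"KeepKeys": {"keep_keys": ["image", "label", "length"]}},
            ]
        }
    },
}


def check_svtr_config(config):
    """Hypothetical helper: return a list of inconsistency messages (empty if none)."""
    problems = []
    backbone = config["Architecture"]["Backbone"]
    # Assumption: one mixer entry per block, so the counts should agree.
    if len(backbone["mixer"]) != sum(backbone["depth"]):
        problems.append("mixer length does not match sum(depth)")
    for op in config["Eval"]["dataset"]["transforms"]:
        name, params = next(iter(op.items()))
        if name == "SVTRRecResizeImg":
            _, height, width = params["image_shape"]
            if [height, width] != list(backbone["img_size"]):
                problems.append("image_shape (H, W) does not match Backbone.img_size")
    return problems


if __name__ == "__main__":
    print(check_svtr_config(CONFIG))  # prints []
```

A check like this could run before training starts, so a mismatch between the resize transform and the backbone is reported by name instead of as a shape error inside the model.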

ppocr/data/imaug/__init__.py @ bde8cad0

@@ -23,7 +23,8 @@ from .random_crop_data import EastRandomCropData, RandomCropImgMask
 from .make_pse_gt import MakePseGt
 from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
-    SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg
+    SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg, \
+    SVTRRecResizeImg
 from .ssl_img_aug import SSLRotateResize
 from .randaugment import RandAugment
 from .copy_paste import CopyPaste
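
The extra name in this import matters because the YAML configs refer to transforms by class name (`- SVTRRecResizeImg:`). The sketch below shows one way such name-based construction can work. It assumes the pipeline resolves each operator name against the names made available by this module. `build_ops`, `REGISTRY`, and the stand-in classes are hypothetical illustrations, not PaddleOCR's actual code.

```python
class KeepKeys:
    """Stand-in transform: keep only the listed keys of the sample dict."""

    def __init__(self, keep_keys, **kwargs):
        self.keep_keys = keep_keys

    def __call__(self, data):
        return {key: data[key] for key in self.keep_keys}


class SVTRRecResizeImg:
    """Stand-in that only records its parameters (no real resizing here)."""

    def __init__(self, image_shape, padding=True, **kwargs):
        self.image_shape = image_shape
        self.padding = padding

    def __call__(self, data):
        return data


# Hypothetical registry: a class can be named in a config only if it is listed
# here, much as a transform must be imported into the package namespace.
REGISTRY = {cls.__name__: cls for cls in (KeepKeys, SVTRRecResizeImg)}


def build_ops(op_param_list):
    """Turn [{"OpName": {params}}, ...] into a list of callable instances."""
    ops = []
    for entry in op_param_list:
        (name, params), = entry.items()  # each entry has exactly one key
        if name not in REGISTRY:
            raise KeyError(f"unknown transform: {name}")
        ops.append(REGISTRY[name](**(params or {})))
    return ops


if __name__ == "__main__":
    ops = build_ops([
        {"SVTRRecResizeImg": {"image_shape": [3, 32, 320], "padding": True}},
        {"KeepKeys": {"keep_keys": ["image", "label"]}},
    ])
    sample = {"image": "img", "label": "text", "extra": 1}
    for op in ops:
        sample = op(sample)
    print(sample)  # prints {'image': 'img', 'label': 'text'}
```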

ppocr/data/imaug/rec_img_aug.py @ bde8cad0

@@ -207,6 +207,21 @@ class PRENResizeImg(object):
         return data
 
 
+class SVTRRecResizeImg(object):
+    def __init__(self, image_shape, padding=True, **kwargs):
+        self.image_shape = image_shape
+        self.padding = padding
+
+    def __call__(self, data):
+        img = data['image']
+
+        norm_img, valid_ratio = resize_norm_img(img, self.image_shape,
+                                                self.padding)
+        data['image'] = norm_img
+        data['valid_ratio'] = valid_ratio
+        return data
+
+
 def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
     imgC, imgH, imgW_min, imgW_max = image_shape
     h = img.shape[0]
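
`SVTRRecResizeImg` delegates to `resize_norm_img`, which this diff calls but does not show. The sketch below covers only the width arithmetic that a padding-style resize typically performs, to illustrate what `valid_ratio` can mean. It assumes the image is scaled to the target height with its aspect ratio preserved, capped at the target width, and padded for the remainder. `padded_resize_geometry` is a hypothetical name, and the real function may differ in detail.

```python
import math


def padded_resize_geometry(src_h, src_w, image_shape, padding=True):
    """Return (resized_w, valid_ratio) for a target image_shape of [C, H, W].

    Assumption: with padding, the image is scaled to height H keeping its
    aspect ratio, capped at width W, and the rest of the row is padded.
    Without padding, the image is stretched to exactly W.
    """
    _, target_h, target_w = image_shape
    if padding:
        resized_w = min(target_w, math.ceil(target_h * src_w / src_h))
    else:
        resized_w = target_w
    # Fraction of the target width that holds real image content.
    valid_ratio = min(1.0, resized_w / target_w)
    return resized_w, valid_ratio


if __name__ == "__main__":
    shape = [3, 32, 320]
    print(padded_resize_geometry(32, 160, shape))         # prints (160, 0.5)
    print(padded_resize_geometry(48, 960, shape))         # prints (320, 1.0)
    print(padded_resize_geometry(32, 160, shape, False))  # prints (320, 1.0)
```

Under this reading, `padding: true` in `rec_svtrnet_ch.yml` keeps short text lines undistorted, while `padding: False` in `rec_svtrnet.yml` stretches every crop to the full width.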

tools/export_model.py @ bde8cad0

@@ -31,7 +31,12 @@ from ppocr.utils.logging import get_logger
 from tools.program import load_config, merge_config, ArgsParser
 
 
-def export_single_model(model, arch_config, save_path, logger, quanter=None):
+def export_single_model(model,
+                        arch_config,
+                        save_path,
+                        logger,
+                        input_shape=None,
+                        quanter=None):
     if arch_config["algorithm"] == "SRN":
         max_text_length = arch_config["Head"]["max_text_length"]
         other_shape = [
@@ -64,7 +69,7 @@ def export_single_model(model, arch_config, save_path, logger, quanter=None):
         else:
             other_shape = [
                 paddle.static.InputSpec(
-                    shape=[None, 3, 64, 256], dtype="float32"),
+                    shape=[None] + input_shape, dtype="float32"),
             ]
         model = to_static(model, input_spec=other_shape)
     elif arch_config["algorithm"] == "PREN":
@@ -157,6 +162,13 @@ def main():
     arch_config = config["Architecture"]
 
+    if arch_config["algorithm"] == "SVTR" and arch_config["Head"][
+            "name"] != 'MultiHead':
+        input_shape = config["Eval"]["dataset"]["transforms"][-2][
+            'SVTRRecResizeImg']['image_shape']
+    else:
+        input_shape = None
+
     if arch_config["algorithm"] in ["Distillation", ]:  # distillation model
         archs = list(arch_config["Models"].values())
         for idx, name in enumerate(model.model_name_list):
@@ -165,7 +177,8 @@ def main():
                                 sub_model_save_path, logger)
     else:
         save_path = os.path.join(save_path, "inference")
-        export_single_model(model, arch_config, save_path, logger)
+        export_single_model(
+            model, arch_config, save_path, logger, input_shape=input_shape)
 
 
 if __name__ == "__main__":
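
In `main()`, the export shape is read from the second-to-last Eval transform, so it relies on `SVTRRecResizeImg` sitting at index `-2`. That holds for the configs in this commit, where it comes right before `KeepKeys`. A minimal sketch of a lookup by name, which does not depend on position, follows. `find_image_shape` and `build_input_spec_shape` are hypothetical helpers, not part of `tools/export_model.py`, and the plain list stands in for the `shape` argument passed to `paddle.static.InputSpec`.

```python
def find_image_shape(transforms, op_name="SVTRRecResizeImg"):
    """Return image_shape of the first transform called op_name, else None."""
    for entry in transforms:
        params = entry.get(op_name)
        if params is not None:
            return list(params["image_shape"])
    return None


def build_input_spec_shape(input_shape):
    """Prepend a dynamic batch dimension, as `[None] + input_shape` does."""
    if input_shape is None:
        # `[None] + None` would raise TypeError, so fail with a clearer message.
        raise ValueError("input_shape is required for this export path")
    return [None] + list(input_shape)


if __name__ == "__main__":
    eval_transforms = [
        {"DecodeImage": {"img_mode": "BGR", "channel_first": False}},
        {"CTCLabelEncode": None},
        {"SVTRRecResizeImg": {"image_shape": [3, 32, 320], "padding": True}},
        {"KeepKeys": {"keep_keys": ["image", "label", "length"]}},
    ]
    shape = find_image_shape(eval_transforms)
    print(shape)                          # prints [3, 32, 320]
    print(build_input_spec_shape(shape))  # prints [None, 3, 32, 320]
```

Applied to `configs/rec/rec_svtrnet.yml`, the same lookup would yield `[3, 64, 256]`, which matches the shape that was hard-coded before this commit.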