Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
a14f8da9
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
a14f8da9
编写于
9月 28, 2021
作者:
T
tink2123
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
polish seed code
上级
1effa5f3
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
54 addition
and
73 deletion
+54
-73
configs/rec/rec_resnet_stn_bilstm_att.yml
configs/rec/rec_resnet_stn_bilstm_att.yml
+8
-5
ppocr/data/imaug/label_ops.py
ppocr/data/imaug/label_ops.py
+0
-2
ppocr/data/imaug/rec_img_aug.py
ppocr/data/imaug/rec_img_aug.py
+13
-29
ppocr/modeling/transforms/__init__.py
ppocr/modeling/transforms/__init__.py
+1
-1
ppocr/modeling/transforms/stn.py
ppocr/modeling/transforms/stn.py
+24
-0
ppocr/modeling/transforms/tps.py
ppocr/modeling/transforms/tps.py
+0
-25
ppocr/postprocess/rec_postprocess.py
ppocr/postprocess/rec_postprocess.py
+0
-1
requirements.txt
requirements.txt
+2
-1
tools/program.py
tools/program.py
+6
-9
未找到文件。
configs/rec/rec_resnet_stn_bilstm_att.yml
浏览文件 @
a14f8da9
...
...
@@ -19,7 +19,6 @@ Global:
max_text_length
:
100
infer_mode
:
False
use_space_char
:
False
eval_filter
:
True
save_res_path
:
./output/rec/predicts_seed.txt
...
...
@@ -37,8 +36,8 @@ Optimizer:
Architecture
:
model_type
:
seed
algorithm
:
ASTER
model_type
:
rec
algorithm
:
seed
Transform
:
name
:
STN_ON
tps_inputsize
:
[
32
,
64
]
...
...
@@ -76,8 +75,10 @@ Train:
img_mode
:
BGR
channel_first
:
False
-
SEEDLabelEncode
:
# Class handling label
-
SEEDResize
:
-
RecResizeImg
:
character_type
:
en
image_shape
:
[
3
,
64
,
256
]
padding
:
False
-
KeepKeys
:
keep_keys
:
[
'
image'
,
'
label'
,
'
length'
,
'
fast_label'
]
# dataloader will return list in this order
loader
:
...
...
@@ -95,8 +96,10 @@ Eval:
img_mode
:
BGR
channel_first
:
False
-
SEEDLabelEncode
:
# Class handling label
-
SEEDResize
:
-
RecResizeImg
:
character_type
:
en
image_shape
:
[
3
,
64
,
256
]
padding
:
False
-
KeepKeys
:
keep_keys
:
[
'
image'
,
'
label'
,
'
length'
]
# dataloader will return list in this order
loader
:
...
...
ppocr/data/imaug/label_ops.py
浏览文件 @
a14f8da9
...
...
@@ -106,7 +106,6 @@ class BaseRecLabelEncode(object):
self
.
max_text_len
=
max_text_length
self
.
beg_str
=
"sos"
self
.
end_str
=
"eos"
self
.
unknown
=
"UNKNOWN"
if
character_type
==
"en"
:
self
.
character_str
=
"0123456789abcdefghijklmnopqrstuvwxyz"
dict_character
=
list
(
self
.
character_str
)
...
...
@@ -357,7 +356,6 @@ class SEEDLabelEncode(BaseRecLabelEncode):
character_type
,
use_space_char
)
def
add_special_char
(
self
,
dict_character
):
self
.
beg_str
=
"sos"
self
.
end_str
=
"eos"
dict_character
=
dict_character
+
[
self
.
end_str
]
return
dict_character
...
...
ppocr/data/imaug/rec_img_aug.py
浏览文件 @
a14f8da9
...
...
@@ -88,29 +88,19 @@ class RecResizeImg(object):
image_shape
,
infer_mode
=
False
,
character_type
=
'ch'
,
padding
=
True
,
**
kwargs
):
self
.
image_shape
=
image_shape
self
.
infer_mode
=
infer_mode
self
.
character_type
=
character_type
self
.
padding
=
padding
def
__call__
(
self
,
data
):
img
=
data
[
'image'
]
if
self
.
infer_mode
and
self
.
character_type
==
"ch"
:
norm_img
=
resize_norm_img_chinese
(
img
,
self
.
image_shape
)
else
:
norm_img
=
resize_norm_img
(
img
,
self
.
image_shape
)
data
[
'image'
]
=
norm_img
return
data
class
SEEDResize
(
object
):
def
__init__
(
self
,
image_shape
,
infer_mode
=
False
,
**
kwargs
):
self
.
image_shape
=
image_shape
self
.
infer_mode
=
infer_mode
def
__call__
(
self
,
data
):
img
=
data
[
'image'
]
norm_img
=
resize_no_padding_img
(
img
,
self
.
image_shape
)
norm_img
=
resize_norm_img
(
img
,
self
.
image_shape
,
self
.
padding
)
data
[
'image'
]
=
norm_img
return
data
...
...
@@ -186,16 +176,21 @@ def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
return
padding_im
,
resize_shape
,
pad_shape
,
valid_ratio
def
resize_norm_img
(
img
,
image_shape
):
def
resize_norm_img
(
img
,
image_shape
,
padding
=
True
):
imgC
,
imgH
,
imgW
=
image_shape
h
=
img
.
shape
[
0
]
w
=
img
.
shape
[
1
]
ratio
=
w
/
float
(
h
)
if
math
.
ceil
(
imgH
*
ratio
)
>
imgW
:
if
not
padding
:
resized_image
=
cv2
.
resize
(
img
,
(
imgW
,
imgH
),
interpolation
=
cv2
.
INTER_LINEAR
)
resized_w
=
imgW
else
:
resized_w
=
int
(
math
.
ceil
(
imgH
*
ratio
))
resized_image
=
cv2
.
resize
(
img
,
(
resized_w
,
imgH
))
ratio
=
w
/
float
(
h
)
if
math
.
ceil
(
imgH
*
ratio
)
>
imgW
:
resized_w
=
imgW
else
:
resized_w
=
int
(
math
.
ceil
(
imgH
*
ratio
))
resized_image
=
cv2
.
resize
(
img
,
(
resized_w
,
imgH
))
resized_image
=
resized_image
.
astype
(
'float32'
)
if
image_shape
[
0
]
==
1
:
resized_image
=
resized_image
/
255
...
...
@@ -209,17 +204,6 @@ def resize_norm_img(img, image_shape):
return
padding_im
def
resize_no_padding_img
(
img
,
image_shape
):
imgC
,
imgH
,
imgW
=
image_shape
resized_image
=
cv2
.
resize
(
img
,
(
imgW
,
imgH
),
interpolation
=
cv2
.
INTER_LINEAR
)
resized_image
=
resized_image
.
astype
(
'float32'
)
resized_image
=
resized_image
.
transpose
((
2
,
0
,
1
))
/
255
resized_image
-=
0.5
resized_image
/=
0.5
return
resized_image
def
resize_norm_img_chinese
(
img
,
image_shape
):
imgC
,
imgH
,
imgW
=
image_shape
# todo: change to 0 and modified image shape
...
...
ppocr/modeling/transforms/__init__.py
浏览文件 @
a14f8da9
...
...
@@ -17,7 +17,7 @@ __all__ = ['build_transform']
def
build_transform
(
config
):
from
.tps
import
TPS
from
.
tps
import
STN_ON
from
.
stn
import
STN_ON
support_dict
=
[
'TPS'
,
'STN_ON'
]
...
...
ppocr/modeling/transforms/stn.py
浏览文件 @
a14f8da9
...
...
@@ -22,6 +22,8 @@ from paddle import nn, ParamAttr
from
paddle.nn
import
functional
as
F
import
numpy
as
np
from
.tps_spatial_transformer
import
TPSSpatialTransformer
def
conv3x3_block
(
in_channels
,
out_channels
,
stride
=
1
):
n
=
3
*
3
*
out_channels
...
...
@@ -106,3 +108,25 @@ class STN(nn.Layer):
x
=
F
.
sigmoid
(
x
)
x
=
paddle
.
reshape
(
x
,
shape
=
[
-
1
,
self
.
num_ctrlpoints
,
2
])
return
img_feat
,
x
class
STN_ON
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
tps_inputsize
,
tps_outputsize
,
num_control_points
,
tps_margins
,
stn_activation
):
super
(
STN_ON
,
self
).
__init__
()
self
.
tps
=
TPSSpatialTransformer
(
output_image_size
=
tuple
(
tps_outputsize
),
num_control_points
=
num_control_points
,
margins
=
tuple
(
tps_margins
))
self
.
stn_head
=
STN
(
in_channels
=
in_channels
,
num_ctrlpoints
=
num_control_points
,
activation
=
stn_activation
)
self
.
tps_inputsize
=
tps_inputsize
self
.
out_channels
=
in_channels
def
forward
(
self
,
image
):
stn_input
=
paddle
.
nn
.
functional
.
interpolate
(
image
,
self
.
tps_inputsize
,
mode
=
"bilinear"
,
align_corners
=
True
)
stn_img_feat
,
ctrl_points
=
self
.
stn_head
(
stn_input
)
x
,
_
=
self
.
tps
(
image
,
ctrl_points
)
return
x
ppocr/modeling/transforms/tps.py
浏览文件 @
a14f8da9
...
...
@@ -22,9 +22,6 @@ from paddle import nn, ParamAttr
from
paddle.nn
import
functional
as
F
import
numpy
as
np
from
.tps_spatial_transformer
import
TPSSpatialTransformer
from
.stn
import
STN
class
ConvBNLayer
(
nn
.
Layer
):
def
__init__
(
self
,
...
...
@@ -305,25 +302,3 @@ class TPS(nn.Layer):
[
-
1
,
image
.
shape
[
2
],
image
.
shape
[
3
],
2
])
batch_I_r
=
F
.
grid_sample
(
x
=
image
,
grid
=
batch_P_prime
)
return
batch_I_r
class
STN_ON
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
tps_inputsize
,
tps_outputsize
,
num_control_points
,
tps_margins
,
stn_activation
):
super
(
STN_ON
,
self
).
__init__
()
self
.
tps
=
TPSSpatialTransformer
(
output_image_size
=
tuple
(
tps_outputsize
),
num_control_points
=
num_control_points
,
margins
=
tuple
(
tps_margins
))
self
.
stn_head
=
STN
(
in_channels
=
in_channels
,
num_ctrlpoints
=
num_control_points
,
activation
=
stn_activation
)
self
.
tps_inputsize
=
tps_inputsize
self
.
out_channels
=
in_channels
def
forward
(
self
,
image
):
stn_input
=
paddle
.
nn
.
functional
.
interpolate
(
image
,
self
.
tps_inputsize
,
mode
=
"bilinear"
,
align_corners
=
True
)
stn_img_feat
,
ctrl_points
=
self
.
stn_head
(
stn_input
)
x
,
_
=
self
.
tps
(
image
,
ctrl_points
)
return
x
ppocr/postprocess/rec_postprocess.py
浏览文件 @
a14f8da9
...
...
@@ -322,7 +322,6 @@ class SEEDLabelDecode(BaseRecLabelDecode):
def
add_special_char
(
self
,
dict_character
):
self
.
beg_str
=
"sos"
self
.
end_str
=
"eos"
dict_character
=
dict_character
dict_character
=
dict_character
+
[
self
.
end_str
]
return
dict_character
...
...
requirements.txt
浏览文件 @
a14f8da9
...
...
@@ -11,4 +11,5 @@ opencv-contrib-python==4.4.0.46
cython
lxml
premailer
openpyxl
\ No newline at end of file
openpyxl
fasttext
==0.9.1
\ No newline at end of file
tools/program.py
浏览文件 @
a14f8da9
...
...
@@ -186,9 +186,8 @@ def train(config,
model
.
train
()
use_srn
=
config
[
'Architecture'
][
'algorithm'
]
==
"SRN"
use_nrtr
=
config
[
'Architecture'
][
'algorithm'
]
==
"NRTR"
use_sar
=
config
[
'Architecture'
][
'algorithm'
]
==
'SAR'
use_seed
=
config
[
'Architecture'
][
'algorithm'
]
==
'SEED'
extra_input
=
config
[
'Architecture'
][
'algorithm'
]
in
[
"SRN"
,
"NRTR"
,
"SAR"
,
"SEED"
]
try
:
model_type
=
config
[
'Architecture'
][
'model_type'
]
except
:
...
...
@@ -217,7 +216,7 @@ def train(config,
images
=
batch
[
0
]
if
use_srn
:
model_average
=
True
if
use_srn
or
model_type
==
'table'
or
use_nrtr
or
use_sar
or
use_seed
:
if
model_type
==
'table'
or
extra_input
:
preds
=
model
(
images
,
data
=
batch
[
1
:])
else
:
preds
=
model
(
images
)
...
...
@@ -281,8 +280,7 @@ def train(config,
post_process_class
,
eval_class
,
model_type
,
use_srn
=
use_srn
,
use_sar
=
use_sar
)
extra_input
=
extra_input
)
cur_metric_str
=
'cur metric, {}'
.
format
(
', '
.
join
(
[
'{}: {}'
.
format
(
k
,
v
)
for
k
,
v
in
cur_metric
.
items
()]))
logger
.
info
(
cur_metric_str
)
...
...
@@ -354,8 +352,7 @@ def eval(model,
post_process_class
,
eval_class
,
model_type
=
None
,
use_srn
=
False
,
use_sar
=
False
):
extra_input
=
False
):
model
.
eval
()
with
paddle
.
no_grad
():
total_frame
=
0.0
...
...
@@ -368,7 +365,7 @@ def eval(model,
break
images
=
batch
[
0
]
start
=
time
.
time
()
if
use_srn
or
model_type
==
'table'
or
use_sar
:
if
model_type
==
'table'
or
extra_input
:
preds
=
model
(
images
,
data
=
batch
[
1
:])
else
:
preds
=
model
(
images
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录