Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
weixin_41840029
PaddleOCR
提交
882ad395
P
PaddleOCR
项目概览
weixin_41840029
/
PaddleOCR
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleOCR
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
882ad395
编写于
11月 11, 2020
作者:
Z
zhoujun
提交者:
GitHub
11月 11, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1146 from WenmuZhou/dygraph_rc
add tps mdule
上级
dc6e724e
367c49df
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
402 addition
and
146 deletion
+402
-146
configs/rec/rec_mv3_none_bilstm_ctc.yml
configs/rec/rec_mv3_none_bilstm_ctc.yml
+1
-1
configs/rec/rec_r34_vd_tps_bilstm_ctc.yml
configs/rec/rec_r34_vd_tps_bilstm_ctc.yml
+100
-0
ppocr/modeling/architectures/base_model.py
ppocr/modeling/architectures/base_model.py
+6
-5
ppocr/modeling/necks/rnn.py
ppocr/modeling/necks/rnn.py
+5
-3
ppocr/modeling/transform/__init__.py
ppocr/modeling/transform/__init__.py
+3
-1
ppocr/modeling/transform/tps.py
ppocr/modeling/transform/tps.py
+287
-0
ppocr/postprocess/db_postprocess_torch.py
ppocr/postprocess/db_postprocess_torch.py
+0
-136
未找到文件。
configs/rec/rec_mv3_none_bilstm_ctc.yml
浏览文件 @
882ad395
...
...
@@ -72,7 +72,7 @@ Train:
-
KeepKeys
:
keep_keys
:
[
'
image'
,
'
label'
,
'
length'
]
# dataloader will return list in this order
loader
:
shuffle
:
Tru
e
shuffle
:
Fals
e
batch_size_per_card
:
256
drop_last
:
True
num_workers
:
8
...
...
configs/rec/rec_r34_vd_tps_bilstm_ctc.yml
0 → 100644
浏览文件 @
882ad395
Global
:
use_gpu
:
true
epoch_num
:
72
log_smooth_window
:
20
print_batch_step
:
10
save_model_dir
:
./output/rec/r34_vd_tps_bilstm_ctc/
save_epoch_step
:
3
# evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step
:
[
0
,
2000
]
# if pretrained_model is saved in static mode, load_static_weights must set to True
cal_metric_during_train
:
True
pretrained_model
:
checkpoints
:
save_inference_dir
:
use_visualdl
:
False
infer_img
:
doc/imgs_words/ch/word_1.jpg
# for data or label process
character_dict_path
:
character_type
:
en
max_text_length
:
25
infer_mode
:
False
use_space_char
:
False
Optimizer
:
name
:
Adam
beta1
:
0.9
beta2
:
0.999
lr
:
learning_rate
:
0.0005
regularizer
:
name
:
'
L2'
factor
:
0
Architecture
:
model_type
:
rec
algorithm
:
CRNN
Transform
:
name
:
TPS
num_fiducial
:
20
loc_lr
:
0.1
model_name
:
small
Backbone
:
name
:
ResNet
layers
:
34
Neck
:
name
:
SequenceEncoder
encoder_type
:
rnn
hidden_size
:
256
Head
:
name
:
CTCHead
fc_decay
:
0
Loss
:
name
:
CTCLoss
PostProcess
:
name
:
CTCLabelDecode
Metric
:
name
:
RecMetric
main_indicator
:
acc
Train
:
dataset
:
name
:
LMDBDateSet
data_dir
:
./train_data/data_lmdb_release/training/
transforms
:
-
DecodeImage
:
# load image
img_mode
:
BGR
channel_first
:
False
-
CTCLabelEncode
:
# Class handling label
-
RecResizeImg
:
image_shape
:
[
3
,
32
,
100
]
-
KeepKeys
:
keep_keys
:
[
'
image'
,
'
label'
,
'
length'
]
# dataloader will return list in this order
loader
:
shuffle
:
True
batch_size_per_card
:
256
drop_last
:
True
num_workers
:
8
Eval
:
dataset
:
name
:
LMDBDateSet
data_dir
:
./train_data/data_lmdb_release/validation/
transforms
:
-
DecodeImage
:
# load image
img_mode
:
BGR
channel_first
:
False
-
CTCLabelEncode
:
# Class handling label
-
RecResizeImg
:
image_shape
:
[
3
,
32
,
100
]
-
KeepKeys
:
keep_keys
:
[
'
image'
,
'
label'
,
'
length'
]
# dataloader will return list in this order
loader
:
shuffle
:
False
drop_last
:
False
batch_size_per_card
:
256
num_workers
:
4
ppocr/modeling/architectures/base_model.py
浏览文件 @
882ad395
...
...
@@ -16,13 +16,14 @@ from __future__ import division
from
__future__
import
print_function
from
paddle
import
nn
from
ppocr.modeling.transform
import
build_transform
from
ppocr.modeling.backbones
import
build_backbone
from
ppocr.modeling.necks
import
build_neck
from
ppocr.modeling.heads
import
build_head
__all__
=
[
'BaseModel'
]
class
BaseModel
(
nn
.
Layer
):
def
__init__
(
self
,
config
):
"""
...
...
@@ -31,7 +32,7 @@ class BaseModel(nn.Layer):
config (dict): the super parameters for module.
"""
super
(
BaseModel
,
self
).
__init__
()
in_channels
=
config
.
get
(
'in_channels'
,
3
)
model_type
=
config
[
'model_type'
]
# build transfrom,
...
...
@@ -50,7 +51,7 @@ class BaseModel(nn.Layer):
config
[
"Backbone"
][
'in_channels'
]
=
in_channels
self
.
backbone
=
build_backbone
(
config
[
"Backbone"
],
model_type
)
in_channels
=
self
.
backbone
.
out_channels
# build neck
# for rec, neck can be cnn,rnn or reshape(None)
# for det, neck can be FPN, BIFPN and so on.
...
...
@@ -62,7 +63,7 @@ class BaseModel(nn.Layer):
config
[
'Neck'
][
'in_channels'
]
=
in_channels
self
.
neck
=
build_neck
(
config
[
'Neck'
])
in_channels
=
self
.
neck
.
out_channels
# # build head, head is need for det, rec and cls
config
[
"Head"
][
'in_channels'
]
=
in_channels
self
.
head
=
build_head
(
config
[
"Head"
])
...
...
@@ -74,4 +75,4 @@ class BaseModel(nn.Layer):
if
self
.
use_neck
:
x
=
self
.
neck
(
x
)
x
=
self
.
head
(
x
)
return
x
\ No newline at end of file
return
x
ppocr/modeling/necks/rnn.py
浏览文件 @
882ad395
...
...
@@ -28,8 +28,9 @@ class Im2Seq(nn.Layer):
def
forward
(
self
,
x
):
B
,
C
,
H
,
W
=
x
.
shape
x
=
x
.
reshape
((
B
,
-
1
,
W
))
x
=
x
.
transpose
((
0
,
2
,
1
))
# (NTC)(batch, width, channels)
assert
H
==
1
x
=
x
.
squeeze
(
axis
=
2
)
x
=
x
.
transpose
([
0
,
2
,
1
])
# (NTC)(batch, width, channels)
return
x
...
...
@@ -76,7 +77,8 @@ class SequenceEncoder(nn.Layer):
'fc'
:
EncoderWithFC
,
'rnn'
:
EncoderWithRNN
}
assert
encoder_type
in
support_encoder_dict
,
'{} must in {}'
.
format
(
encoder_type
,
support_encoder_dict
.
keys
())
assert
encoder_type
in
support_encoder_dict
,
'{} must in {}'
.
format
(
encoder_type
,
support_encoder_dict
.
keys
())
self
.
encoder
=
support_encoder_dict
[
encoder_type
](
self
.
encoder_reshape
.
out_channels
,
hidden_size
)
...
...
ppocr/modeling/transform/__init__.py
浏览文件 @
882ad395
...
...
@@ -16,7 +16,9 @@ __all__ = ['build_transform']
def
build_transform
(
config
):
support_dict
=
[
''
]
from
.tps
import
TPS
support_dict
=
[
'TPS'
]
module_name
=
config
.
pop
(
'name'
)
assert
module_name
in
support_dict
,
Exception
(
...
...
ppocr/modeling/transform/tps.py
0 → 100644
浏览文件 @
882ad395
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
paddle
from
paddle
import
nn
,
ParamAttr
from
paddle.nn
import
functional
as
F
import
numpy
as
np
class
ConvBNLayer
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
,
stride
=
1
,
groups
=
1
,
act
=
None
,
name
=
None
):
super
(
ConvBNLayer
,
self
).
__init__
()
self
.
conv
=
nn
.
Conv2D
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
kernel_size
=
kernel_size
,
stride
=
stride
,
padding
=
(
kernel_size
-
1
)
//
2
,
groups
=
groups
,
weight_attr
=
ParamAttr
(
name
=
name
+
"_weights"
),
bias_attr
=
False
)
bn_name
=
"bn_"
+
name
self
.
bn
=
nn
.
BatchNorm
(
out_channels
,
act
=
act
,
param_attr
=
ParamAttr
(
name
=
bn_name
+
'_scale'
),
bias_attr
=
ParamAttr
(
bn_name
+
'_offset'
),
moving_mean_name
=
bn_name
+
'_mean'
,
moving_variance_name
=
bn_name
+
'_variance'
)
def
forward
(
self
,
x
):
x
=
self
.
conv
(
x
)
x
=
self
.
bn
(
x
)
return
x
class
LocalizationNetwork
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
num_fiducial
,
loc_lr
,
model_name
):
super
(
LocalizationNetwork
,
self
).
__init__
()
self
.
F
=
num_fiducial
F
=
num_fiducial
if
model_name
==
"large"
:
num_filters_list
=
[
64
,
128
,
256
,
512
]
fc_dim
=
256
else
:
num_filters_list
=
[
16
,
32
,
64
,
128
]
fc_dim
=
64
self
.
block_list
=
[]
for
fno
in
range
(
0
,
len
(
num_filters_list
)):
num_filters
=
num_filters_list
[
fno
]
name
=
"loc_conv%d"
%
fno
conv
=
self
.
add_sublayer
(
name
,
ConvBNLayer
(
in_channels
=
in_channels
,
out_channels
=
num_filters
,
kernel_size
=
3
,
act
=
'relu'
,
name
=
name
))
self
.
block_list
.
append
(
conv
)
if
fno
==
len
(
num_filters_list
)
-
1
:
pool
=
nn
.
AdaptiveAvgPool2D
(
1
)
else
:
pool
=
nn
.
MaxPool2D
(
kernel_size
=
2
,
stride
=
2
,
padding
=
0
)
in_channels
=
num_filters
self
.
block_list
.
append
(
pool
)
name
=
"loc_fc1"
self
.
fc1
=
nn
.
Linear
(
in_channels
,
fc_dim
,
weight_attr
=
ParamAttr
(
learning_rate
=
loc_lr
,
name
=
name
+
"_w"
),
bias_attr
=
ParamAttr
(
name
=
name
+
'.b_0'
),
name
=
name
)
# Init fc2 in LocalizationNetwork
initial_bias
=
self
.
get_initial_fiducials
()
initial_bias
=
initial_bias
.
reshape
(
-
1
)
name
=
"loc_fc2"
param_attr
=
ParamAttr
(
learning_rate
=
loc_lr
,
initializer
=
nn
.
initializer
.
Assign
(
np
.
zeros
([
fc_dim
,
F
*
2
])),
name
=
name
+
"_w"
)
bias_attr
=
ParamAttr
(
learning_rate
=
loc_lr
,
initializer
=
nn
.
initializer
.
Assign
(
initial_bias
),
name
=
name
+
"_b"
)
self
.
fc2
=
nn
.
Linear
(
fc_dim
,
F
*
2
,
weight_attr
=
param_attr
,
bias_attr
=
bias_attr
,
name
=
name
)
self
.
out_channels
=
F
*
2
def
forward
(
self
,
x
):
"""
Estimating parameters of geometric transformation
Args:
image: input
Return:
batch_C_prime: the matrix of the geometric transformation
"""
B
=
x
.
shape
[
0
]
i
=
0
for
block
in
self
.
block_list
:
x
=
block
(
x
)
x
=
x
.
reshape
([
B
,
-
1
])
x
=
self
.
fc1
(
x
)
x
=
F
.
relu
(
x
)
x
=
self
.
fc2
(
x
)
x
=
x
.
reshape
(
shape
=
[
-
1
,
self
.
F
,
2
])
return
x
def
get_initial_fiducials
(
self
):
""" see RARE paper Fig. 6 (a) """
F
=
self
.
F
ctrl_pts_x
=
np
.
linspace
(
-
1.0
,
1.0
,
int
(
F
/
2
))
ctrl_pts_y_top
=
np
.
linspace
(
0.0
,
-
1.0
,
num
=
int
(
F
/
2
))
ctrl_pts_y_bottom
=
np
.
linspace
(
1.0
,
0.0
,
num
=
int
(
F
/
2
))
ctrl_pts_top
=
np
.
stack
([
ctrl_pts_x
,
ctrl_pts_y_top
],
axis
=
1
)
ctrl_pts_bottom
=
np
.
stack
([
ctrl_pts_x
,
ctrl_pts_y_bottom
],
axis
=
1
)
initial_bias
=
np
.
concatenate
([
ctrl_pts_top
,
ctrl_pts_bottom
],
axis
=
0
)
return
initial_bias
class
GridGenerator
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
num_fiducial
):
super
(
GridGenerator
,
self
).
__init__
()
self
.
eps
=
1e-6
self
.
F
=
num_fiducial
name
=
"ex_fc"
initializer
=
nn
.
initializer
.
Constant
(
value
=
0.0
)
param_attr
=
ParamAttr
(
learning_rate
=
0.0
,
initializer
=
initializer
,
name
=
name
+
"_w"
)
bias_attr
=
ParamAttr
(
learning_rate
=
0.0
,
initializer
=
initializer
,
name
=
name
+
"_b"
)
self
.
fc
=
nn
.
Linear
(
in_channels
,
6
,
weight_attr
=
param_attr
,
bias_attr
=
bias_attr
,
name
=
name
)
def
forward
(
self
,
batch_C_prime
,
I_r_size
):
"""
Generate the grid for the grid_sampler.
Args:
batch_C_prime: the matrix of the geometric transformation
I_r_size: the shape of the input image
Return:
batch_P_prime: the grid for the grid_sampler
"""
C
=
self
.
build_C
()
P
=
self
.
build_P
(
I_r_size
)
inv_delta_C
=
self
.
build_inv_delta_C
(
C
).
astype
(
'float32'
)
P_hat
=
self
.
build_P_hat
(
C
,
P
).
astype
(
'float32'
)
inv_delta_C_tensor
=
paddle
.
to_tensor
(
inv_delta_C
)
inv_delta_C_tensor
.
stop_gradient
=
True
P_hat_tensor
=
paddle
.
to_tensor
(
P_hat
)
P_hat_tensor
.
stop_gradient
=
True
batch_C_ex_part_tensor
=
self
.
get_expand_tensor
(
batch_C_prime
)
batch_C_ex_part_tensor
.
stop_gradient
=
True
batch_C_prime_with_zeros
=
paddle
.
concat
(
[
batch_C_prime
,
batch_C_ex_part_tensor
],
axis
=
1
)
batch_T
=
paddle
.
matmul
(
inv_delta_C_tensor
,
batch_C_prime_with_zeros
)
batch_P_prime
=
paddle
.
matmul
(
P_hat_tensor
,
batch_T
)
return
batch_P_prime
def
build_C
(
self
):
""" Return coordinates of fiducial points in I_r; C """
F
=
self
.
F
ctrl_pts_x
=
np
.
linspace
(
-
1.0
,
1.0
,
int
(
F
/
2
))
ctrl_pts_y_top
=
-
1
*
np
.
ones
(
int
(
F
/
2
))
ctrl_pts_y_bottom
=
np
.
ones
(
int
(
F
/
2
))
ctrl_pts_top
=
np
.
stack
([
ctrl_pts_x
,
ctrl_pts_y_top
],
axis
=
1
)
ctrl_pts_bottom
=
np
.
stack
([
ctrl_pts_x
,
ctrl_pts_y_bottom
],
axis
=
1
)
C
=
np
.
concatenate
([
ctrl_pts_top
,
ctrl_pts_bottom
],
axis
=
0
)
return
C
# F x 2
def
build_P
(
self
,
I_r_size
):
I_r_width
,
I_r_height
=
I_r_size
I_r_grid_x
=
(
np
.
arange
(
-
I_r_width
,
I_r_width
,
2
)
+
1.0
)
\
/
I_r_width
# self.I_r_width
I_r_grid_y
=
(
np
.
arange
(
-
I_r_height
,
I_r_height
,
2
)
+
1.0
)
\
/
I_r_height
# self.I_r_height
# P: self.I_r_width x self.I_r_height x 2
P
=
np
.
stack
(
np
.
meshgrid
(
I_r_grid_x
,
I_r_grid_y
),
axis
=
2
)
# n (= self.I_r_width x self.I_r_height) x 2
return
P
.
reshape
([
-
1
,
2
])
def
build_inv_delta_C
(
self
,
C
):
""" Return inv_delta_C which is needed to calculate T """
F
=
self
.
F
hat_C
=
np
.
zeros
((
F
,
F
),
dtype
=
float
)
# F x F
for
i
in
range
(
0
,
F
):
for
j
in
range
(
i
,
F
):
r
=
np
.
linalg
.
norm
(
C
[
i
]
-
C
[
j
])
hat_C
[
i
,
j
]
=
r
hat_C
[
j
,
i
]
=
r
np
.
fill_diagonal
(
hat_C
,
1
)
hat_C
=
(
hat_C
**
2
)
*
np
.
log
(
hat_C
)
# print(C.shape, hat_C.shape)
delta_C
=
np
.
concatenate
(
# F+3 x F+3
[
np
.
concatenate
(
[
np
.
ones
((
F
,
1
)),
C
,
hat_C
],
axis
=
1
),
# F x F+3
np
.
concatenate
(
[
np
.
zeros
((
2
,
3
)),
np
.
transpose
(
C
)],
axis
=
1
),
# 2 x F+3
np
.
concatenate
(
[
np
.
zeros
((
1
,
3
)),
np
.
ones
((
1
,
F
))],
axis
=
1
)
# 1 x F+3
],
axis
=
0
)
inv_delta_C
=
np
.
linalg
.
inv
(
delta_C
)
return
inv_delta_C
# F+3 x F+3
def
build_P_hat
(
self
,
C
,
P
):
F
=
self
.
F
eps
=
self
.
eps
n
=
P
.
shape
[
0
]
# n (= self.I_r_width x self.I_r_height)
# P_tile: n x 2 -> n x 1 x 2 -> n x F x 2
P_tile
=
np
.
tile
(
np
.
expand_dims
(
P
,
axis
=
1
),
(
1
,
F
,
1
))
C_tile
=
np
.
expand_dims
(
C
,
axis
=
0
)
# 1 x F x 2
P_diff
=
P_tile
-
C_tile
# n x F x 2
# rbf_norm: n x F
rbf_norm
=
np
.
linalg
.
norm
(
P_diff
,
ord
=
2
,
axis
=
2
,
keepdims
=
False
)
# rbf: n x F
rbf
=
np
.
multiply
(
np
.
square
(
rbf_norm
),
np
.
log
(
rbf_norm
+
eps
))
P_hat
=
np
.
concatenate
([
np
.
ones
((
n
,
1
)),
P
,
rbf
],
axis
=
1
)
return
P_hat
# n x F+3
def
get_expand_tensor
(
self
,
batch_C_prime
):
B
=
batch_C_prime
.
shape
[
0
]
batch_C_prime
=
batch_C_prime
.
reshape
([
B
,
-
1
])
batch_C_ex_part_tensor
=
self
.
fc
(
batch_C_prime
)
batch_C_ex_part_tensor
=
batch_C_ex_part_tensor
.
reshape
([
-
1
,
3
,
2
])
return
batch_C_ex_part_tensor
class
TPS
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
num_fiducial
,
loc_lr
,
model_name
):
super
(
TPS
,
self
).
__init__
()
self
.
loc_net
=
LocalizationNetwork
(
in_channels
,
num_fiducial
,
loc_lr
,
model_name
)
self
.
grid_generator
=
GridGenerator
(
self
.
loc_net
.
out_channels
,
num_fiducial
)
self
.
out_channels
=
in_channels
def
forward
(
self
,
image
):
image
.
stop_gradient
=
False
I_r_size
=
[
image
.
shape
[
3
],
image
.
shape
[
2
]]
batch_C_prime
=
self
.
loc_net
(
image
)
batch_P_prime
=
self
.
grid_generator
(
batch_C_prime
,
I_r_size
)
batch_P_prime
=
batch_P_prime
.
reshape
(
[
-
1
,
image
.
shape
[
2
],
image
.
shape
[
3
],
2
])
batch_I_r
=
F
.
grid_sample
(
x
=
image
,
grid
=
batch_P_prime
)
return
batch_I_r
ppocr/postprocess/db_postprocess_torch.py
已删除
100644 → 0
浏览文件 @
dc6e724e
import
cv2
import
paddle
import
numpy
as
np
import
pyclipper
from
shapely.geometry
import
Polygon
class
DBPostProcess
():
def
__init__
(
self
,
thresh
=
0.3
,
box_thresh
=
0.7
,
max_candidates
=
1000
,
unclip_ratio
=
1.5
):
self
.
min_size
=
3
self
.
thresh
=
thresh
self
.
box_thresh
=
box_thresh
self
.
max_candidates
=
max_candidates
self
.
unclip_ratio
=
unclip_ratio
def
__call__
(
self
,
pred
,
shape_list
,
is_output_polygon
=
False
):
'''
batch: (image, polygons, ignore_tags
h_w_list: 包含[h,w]的数组
pred:
binary: text region segmentation map, with shape (N, 1,H, W)
'''
if
isinstance
(
pred
,
paddle
.
Tensor
):
pred
=
pred
.
numpy
()
pred
=
pred
[:,
0
,
:,
:]
segmentation
=
self
.
binarize
(
pred
)
batch_out
=
[]
for
batch_index
in
range
(
pred
.
shape
[
0
]):
height
,
width
=
shape_list
[
batch_index
]
boxes
,
scores
=
self
.
post_p
(
pred
[
batch_index
],
segmentation
[
batch_index
],
width
,
height
,
is_output_polygon
=
is_output_polygon
)
batch_out
.
append
({
"points"
:
boxes
})
return
batch_out
def
binarize
(
self
,
pred
):
return
pred
>
self
.
thresh
def
post_p
(
self
,
pred
,
bitmap
,
dest_width
,
dest_height
,
is_output_polygon
=
True
):
'''
_bitmap: single map with shape (H, W),
whose values are binarized as {0, 1}
'''
height
,
width
=
pred
.
shape
boxes
=
[]
new_scores
=
[]
contours
,
_
=
cv2
.
findContours
((
bitmap
*
255
).
astype
(
np
.
uint8
),
cv2
.
RETR_LIST
,
cv2
.
CHAIN_APPROX_SIMPLE
)
for
contour
in
contours
[:
self
.
max_candidates
]:
epsilon
=
0.005
*
cv2
.
arcLength
(
contour
,
True
)
approx
=
cv2
.
approxPolyDP
(
contour
,
epsilon
,
True
)
points
=
approx
.
reshape
((
-
1
,
2
))
if
points
.
shape
[
0
]
<
4
:
continue
score
=
self
.
box_score_fast
(
pred
,
points
.
reshape
(
-
1
,
2
))
if
self
.
box_thresh
>
score
:
continue
if
points
.
shape
[
0
]
>
2
:
box
=
self
.
unclip
(
points
,
unclip_ratio
=
self
.
unclip_ratio
)
if
len
(
box
)
>
1
or
len
(
box
)
==
0
:
continue
else
:
continue
four_point_box
,
sside
=
self
.
get_mini_boxes
(
box
.
reshape
((
-
1
,
1
,
2
)))
if
sside
<
self
.
min_size
+
2
:
continue
if
not
is_output_polygon
:
box
=
np
.
array
(
four_point_box
)
else
:
box
=
box
.
reshape
(
-
1
,
2
)
box
[:,
0
]
=
np
.
clip
(
np
.
round
(
box
[:,
0
]
/
width
*
dest_width
),
0
,
dest_width
)
box
[:,
1
]
=
np
.
clip
(
np
.
round
(
box
[:,
1
]
/
height
*
dest_height
),
0
,
dest_height
)
boxes
.
append
(
box
)
new_scores
.
append
(
score
)
return
boxes
,
new_scores
def
unclip
(
self
,
box
,
unclip_ratio
=
1.5
):
poly
=
Polygon
(
box
)
distance
=
poly
.
area
*
unclip_ratio
/
poly
.
length
offset
=
pyclipper
.
PyclipperOffset
()
offset
.
AddPath
(
box
,
pyclipper
.
JT_ROUND
,
pyclipper
.
ET_CLOSEDPOLYGON
)
expanded
=
np
.
array
(
offset
.
Execute
(
distance
))
return
expanded
def
get_mini_boxes
(
self
,
contour
):
bounding_box
=
cv2
.
minAreaRect
(
contour
)
points
=
sorted
(
list
(
cv2
.
boxPoints
(
bounding_box
)),
key
=
lambda
x
:
x
[
0
])
index_1
,
index_2
,
index_3
,
index_4
=
0
,
1
,
2
,
3
if
points
[
1
][
1
]
>
points
[
0
][
1
]:
index_1
=
0
index_4
=
1
else
:
index_1
=
1
index_4
=
0
if
points
[
3
][
1
]
>
points
[
2
][
1
]:
index_2
=
2
index_3
=
3
else
:
index_2
=
3
index_3
=
2
box
=
[
points
[
index_1
],
points
[
index_2
],
points
[
index_3
],
points
[
index_4
]
]
return
box
,
min
(
bounding_box
[
1
])
def
box_score_fast
(
self
,
bitmap
,
_box
):
h
,
w
=
bitmap
.
shape
[:
2
]
box
=
_box
.
copy
()
xmin
=
np
.
clip
(
np
.
floor
(
box
[:,
0
].
min
()).
astype
(
np
.
int
),
0
,
w
-
1
)
xmax
=
np
.
clip
(
np
.
ceil
(
box
[:,
0
].
max
()).
astype
(
np
.
int
),
0
,
w
-
1
)
ymin
=
np
.
clip
(
np
.
floor
(
box
[:,
1
].
min
()).
astype
(
np
.
int
),
0
,
h
-
1
)
ymax
=
np
.
clip
(
np
.
ceil
(
box
[:,
1
].
max
()).
astype
(
np
.
int
),
0
,
h
-
1
)
mask
=
np
.
zeros
((
ymax
-
ymin
+
1
,
xmax
-
xmin
+
1
),
dtype
=
np
.
uint8
)
box
[:,
0
]
=
box
[:,
0
]
-
xmin
box
[:,
1
]
=
box
[:,
1
]
-
ymin
cv2
.
fillPoly
(
mask
,
box
.
reshape
(
1
,
-
1
,
2
).
astype
(
np
.
int32
),
1
)
return
cv2
.
mean
(
bitmap
[
ymin
:
ymax
+
1
,
xmin
:
xmax
+
1
],
mask
)[
0
]
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录