PaddlePaddle / PaddleOCR
Commit bb49e1a5
Authored on Mar 08, 2021 by Jethong

ADD PGNet_v2

Parent: 1f76f449
Showing 10 changed files with 227 additions and 1226 deletions (+227 -1226)
ppocr/data/imaug/label_ops.py              +2   -1
ppocr/metrics/e2e_metric.py                +0   -17
ppocr/modeling/necks/pg_fpn.py             +160 -120
ppocr/utils/e2e_metric/Deteval.py          +13  -17
ppocr/utils/e2e_metric/polygon_fast.py     +13  -1
ppocr/utils/e2e_metric/tttt.py             +0   -881
ppocr/utils/e2e_utils/extract_textpoint.py +13  -0
ppocr/utils/e2e_utils/ski_thin.py          +13  -3
ppocr/utils/e2e_utils/visual.py            +13  -185
tools/program.py                           +0   -1
ppocr/data/imaug/label_ops.py

@@ -37,6 +37,7 @@ class ClsLabelEncode(object):
 class E2ELabelEncode(object):
     def __init__(self, label_list, **kwargs):
         self.label_list = label_list
+        self.max_len = 50

     def __call__(self, data):
         text_label_index_list, temp_text = [], []
@@ -47,7 +48,7 @@ class E2ELabelEncode(object):
         for c_ in text:
             if c_ in self.label_list:
                 temp_text.append(self.label_list.index(c_))
-        temp_text = temp_text + [36] * (50 - len(temp_text))
+        temp_text = temp_text + [36] * (self.max_len - len(temp_text))
         text_label_index_list.append(temp_text)
         data['strs'] = np.array(text_label_index_list)
         return data
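The intent of this hunk is that the pad target moves from a literal 50 into self.max_len, so the encoder and its padding always agree. A minimal sketch of the padding step in isolation (the label_list here is a toy stand-in; index 36 is the pad id used in the hunk above):

# Minimal sketch of the padding step above, outside the class.
# label_list is illustrative only; index 36 is the pad id from the hunk.
label_list = list("0123456789abcdefghijklmnopqrstuvwxyz")
max_len = 50

text = "milk"
temp_text = [label_list.index(c) for c in text if c in label_list]
temp_text = temp_text + [36] * (max_len - len(temp_text))
assert len(temp_text) == max_len  # every label row now has a fixed width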
ppocr/metrics/e2e_metric.py

@@ -32,16 +32,6 @@ class E2EMetric(object):
         self.reset()

     def __call__(self, preds, batch, **kwargs):
-        '''
-        batch: a list produced by dataloaders.
-            image: np.ndarray of shape (N, C, H, W).
-            ratio_list: np.ndarray of shape(N,2)
-            polygons: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
-            ignore_tags: np.ndarray of shape (N, K), indicates whether a region is ignorable or not.
-        preds: a list of dict produced by post process
-            points: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
-        '''
         gt_polyons_batch = batch[2]
         temp_gt_strs_batch = batch[3]
         ignore_tags_batch = batch[4]
@@ -72,13 +62,6 @@ class E2EMetric(object):
         self.results.append(result)

     def get_metric(self):
-        """
-        return metrics {
-                 'precision': 0,
-                 'recall': 0,
-                 'hmean': 0
-            }
-        """
         metircs = combine_results(self.results)
         self.reset()
         return metircs
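The deleted docstring was the only description of the batch layout, so for orientation: a hedged sketch of how this metric is typically driven, using only the indices __call__ reads (batch[2] = gt polygons, batch[3] = encoded gt strings, batch[4] = ignore tags). The constructor argument, model, and loader names are assumptions, not part of this diff.

# Hypothetical driver for E2EMetric; only the batch indices come from
# the code above, everything else is illustrative.
metric = E2EMetric(main_indicator='f_score_e2e')   # assumed constructor arg
for batch in eval_loader:                          # assumed PGNet eval dataloader
    preds = post_process(model(batch[0]))          # assumed model + postprocess
    metric(preds, batch)                           # accumulates one result per image
metrics = metric.get_metric()                      # combine_results() over all images
print(metrics['precision'], metrics['recall'], metrics['f_score_e2e'])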
ppocr/modeling/necks/pg_fpn.py

@@ -106,172 +106,212 @@ class DeConvBNLayer(nn.Layer):
         return x


-class FPN_Up_Fusion(nn.Layer):
-    def __init__(self, in_channels):
-        super(FPN_Up_Fusion, self).__init__()
-        in_channels = in_channels[::-1]
-        out_channels = [256, 256, 192, 192, 128]
-
-        self.h0_conv = ConvBNLayer(
-            in_channels[0], out_channels[0], 1, 1, act=None, name='conv_h0')
-        self.h1_conv = ConvBNLayer(
-            in_channels[1], out_channels[1], 1, 1, act=None, name='conv_h1')
-        self.h2_conv = ConvBNLayer(
-            in_channels[2], out_channels[2], 1, 1, act=None, name='conv_h2')
-        self.h3_conv = ConvBNLayer(
-            in_channels[3], out_channels[3], 1, 1, act=None, name='conv_h3')
-        self.h4_conv = ConvBNLayer(
-            in_channels[4], out_channels[4], 1, 1, act=None, name='conv_h4')
-
-        self.dconv0 = DeConvBNLayer(
-            in_channels=out_channels[0],
-            out_channels=out_channels[1],
-            name="dconv_{}".format(0))
-        self.dconv1 = DeConvBNLayer(
-            in_channels=out_channels[1],
-            out_channels=out_channels[2],
-            act=None,
-            name="dconv_{}".format(1))
-        self.dconv2 = DeConvBNLayer(
-            in_channels=out_channels[2],
-            out_channels=out_channels[3],
-            act=None,
-            name="dconv_{}".format(2))
-        self.dconv3 = DeConvBNLayer(
-            in_channels=out_channels[3],
-            out_channels=out_channels[4],
-            act=None,
-            name="dconv_{}".format(3))
-
-        self.conv_g1 = ConvBNLayer(
-            in_channels=out_channels[1],
-            out_channels=out_channels[1],
-            kernel_size=3,
-            stride=1,
-            act='relu',
-            name="conv_g{}".format(1))
-        self.conv_g2 = ConvBNLayer(
-            in_channels=out_channels[2],
-            out_channels=out_channels[2],
-            kernel_size=3,
-            stride=1,
-            act='relu',
-            name="conv_g{}".format(2))
-        self.conv_g3 = ConvBNLayer(
-            in_channels=out_channels[3],
-            out_channels=out_channels[3],
-            kernel_size=3,
-            stride=1,
-            act='relu',
-            name="conv_g{}".format(3))
-        self.conv_g4 = ConvBNLayer(
-            in_channels=out_channels[4],
-            out_channels=out_channels[4],
-            kernel_size=3,
-            stride=1,
-            act='relu',
-            name="conv_g{}".format(4))
-        self.convf = ConvBNLayer(
-            in_channels=out_channels[4],
-            out_channels=out_channels[4],
-            kernel_size=1,
-            stride=1,
-            act=None,
-            name="conv_f{}".format(4))
-
-    def _add_relu(self, x1, x2):
-        x = paddle.add(x=x1, y=x2)
-        x = F.relu(x)
-        return x
-
-    def forward(self, x):
-        f = x[2:][::-1]
-        h0 = self.h0_conv(f[0])
-        h1 = self.h1_conv(f[1])
-        h2 = self.h2_conv(f[2])
-        h3 = self.h3_conv(f[3])
-        h4 = self.h4_conv(f[4])
-
-        g0 = self.dconv0(h0)
-        g1 = self.dconv2(self.conv_g2(self._add_relu(g0, h1)))
-        g2 = self.dconv2(self.conv_g2(self._add_relu(g1, h2)))
-        g3 = self.dconv3(self.conv_g2(self._add_relu(g2, h3)))
-        g4 = self.dconv4(self.conv_g2(self._add_relu(g3, h4)))
-        return g4
-
-
-class FPN_Down_Fusion(nn.Layer):
-    def __init__(self, in_channels):
-        super(FPN_Down_Fusion, self).__init__()
-        out_channels = [32, 64, 128]
-        self.h0_conv = ConvBNLayer(
-            in_channels[0], out_channels[0], 3, 1, act=None, name='FPN_d1')
-        self.h1_conv = ConvBNLayer(
-            in_channels[1], out_channels[1], 3, 1, act=None, name='FPN_d2')
-        self.h2_conv = ConvBNLayer(
-            in_channels[2], out_channels[2], 3, 1, act=None, name='FPN_d3')
-        self.g0_conv = ConvBNLayer(
-            out_channels[0], out_channels[1], 3, 2, act=None, name='FPN_d4')
-        self.g1_conv = nn.Sequential(
-            ConvBNLayer(
-                out_channels[1], out_channels[1], 3, 1, act='relu',
-                name='FPN_d5'),
-            ConvBNLayer(
-                out_channels[1], out_channels[2], 3, 2, act=None,
-                name='FPN_d6'))
-        self.g2_conv = nn.Sequential(
-            ConvBNLayer(
-                out_channels[2], out_channels[2], 3, 1, act='relu',
-                name='FPN_d7'),
-            ConvBNLayer(
-                out_channels[2], out_channels[2], 1, 1, act=None,
-                name='FPN_d8'))
-
-    def forward(self, x):
-        f = x[:3]
-        h0 = self.h0_conv(f[0])
-        h1 = self.h1_conv(f[1])
-        h2 = self.h2_conv(f[2])
-        g0 = self.g0_conv(h0)
-        g1 = paddle.add(x=g0, y=h1)
-        g1 = F.relu(g1)
-        g1 = self.g1_conv(g1)
-        g2 = paddle.add(x=g1, y=h2)
-        g2 = F.relu(g2)
-        g2 = self.g2_conv(g2)
-        return g2
-
-
 class PGFPN(nn.Layer):
-    def __init__(self, in_channels, with_cab=False, **kwargs):
+    def __init__(self, in_channels, **kwargs):
         super(PGFPN, self).__init__()
-        self.in_channels = in_channels
-        self.with_cab = with_cab
-        self.FPN_Down_Fusion = FPN_Down_Fusion(self.in_channels)
-        self.FPN_Up_Fusion = FPN_Up_Fusion(self.in_channels)
+        num_inputs = [2048, 2048, 1024, 512, 256]
+        num_outputs = [256, 256, 192, 192, 128]
         self.out_channels = 128
+        # print(in_channels)
+        self.conv_bn_layer_1 = ConvBNLayer(
+            in_channels=3,
+            out_channels=32,
+            kernel_size=3,
+            stride=1,
+            act=None,
+            name='FPN_d1')
+        self.conv_bn_layer_2 = ConvBNLayer(
+            in_channels=64,
+            out_channels=64,
+            kernel_size=3,
+            stride=1,
+            act=None,
+            name='FPN_d2')
+        self.conv_bn_layer_3 = ConvBNLayer(
+            in_channels=256,
+            out_channels=128,
+            kernel_size=3,
+            stride=1,
+            act=None,
+            name='FPN_d3')
+        self.conv_bn_layer_4 = ConvBNLayer(
+            in_channels=32,
+            out_channels=64,
+            kernel_size=3,
+            stride=2,
+            act=None,
+            name='FPN_d4')
+        self.conv_bn_layer_5 = ConvBNLayer(
+            in_channels=64,
+            out_channels=64,
+            kernel_size=3,
+            stride=1,
+            act='relu',
+            name='FPN_d5')
+        self.conv_bn_layer_6 = ConvBNLayer(
+            in_channels=64,
+            out_channels=128,
+            kernel_size=3,
+            stride=2,
+            act=None,
+            name='FPN_d6')
+        self.conv_bn_layer_7 = ConvBNLayer(
+            in_channels=128,
+            out_channels=128,
+            kernel_size=3,
+            stride=1,
+            act='relu',
+            name='FPN_d7')
+        self.conv_bn_layer_8 = ConvBNLayer(
+            in_channels=128,
+            out_channels=128,
+            kernel_size=1,
+            stride=1,
+            act=None,
+            name='FPN_d8')
+        self.conv_h0 = ConvBNLayer(
+            in_channels=num_inputs[0],
+            out_channels=num_outputs[0],
+            kernel_size=1,
+            stride=1,
+            act=None,
+            name="conv_h{}".format(0))
+        self.conv_h1 = ConvBNLayer(
+            in_channels=num_inputs[1],
+            out_channels=num_outputs[1],
+            kernel_size=1,
+            stride=1,
+            act=None,
+            name="conv_h{}".format(1))
+        self.conv_h2 = ConvBNLayer(
+            in_channels=num_inputs[2],
+            out_channels=num_outputs[2],
+            kernel_size=1,
+            stride=1,
+            act=None,
+            name="conv_h{}".format(2))
+        self.conv_h3 = ConvBNLayer(
+            in_channels=num_inputs[3],
+            out_channels=num_outputs[3],
+            kernel_size=1,
+            stride=1,
+            act=None,
+            name="conv_h{}".format(3))
+        self.conv_h4 = ConvBNLayer(
+            in_channels=num_inputs[4],
+            out_channels=num_outputs[4],
+            kernel_size=1,
+            stride=1,
+            act=None,
+            name="conv_h{}".format(4))
+
+        self.dconv0 = DeConvBNLayer(
+            in_channels=num_outputs[0],
+            out_channels=num_outputs[0 + 1],
+            name="dconv_{}".format(0))
+        self.dconv1 = DeConvBNLayer(
+            in_channels=num_outputs[1],
+            out_channels=num_outputs[1 + 1],
+            act=None,
+            name="dconv_{}".format(1))
+        self.dconv2 = DeConvBNLayer(
+            in_channels=num_outputs[2],
+            out_channels=num_outputs[2 + 1],
+            act=None,
+            name="dconv_{}".format(2))
+        self.dconv3 = DeConvBNLayer(
+            in_channels=num_outputs[3],
+            out_channels=num_outputs[3 + 1],
+            act=None,
+            name="dconv_{}".format(3))
+
+        self.conv_g1 = ConvBNLayer(
+            in_channels=num_outputs[1],
+            out_channels=num_outputs[1],
+            kernel_size=3,
+            stride=1,
+            act='relu',
+            name="conv_g{}".format(1))
+        self.conv_g2 = ConvBNLayer(
+            in_channels=num_outputs[2],
+            out_channels=num_outputs[2],
+            kernel_size=3,
+            stride=1,
+            act='relu',
+            name="conv_g{}".format(2))
+        self.conv_g3 = ConvBNLayer(
+            in_channels=num_outputs[3],
+            out_channels=num_outputs[3],
+            kernel_size=3,
+            stride=1,
+            act='relu',
+            name="conv_g{}".format(3))
+        self.conv_g4 = ConvBNLayer(
+            in_channels=num_outputs[4],
+            out_channels=num_outputs[4],
+            kernel_size=3,
+            stride=1,
+            act='relu',
+            name="conv_g{}".format(4))
+        self.convf = ConvBNLayer(
+            in_channels=num_outputs[4],
+            out_channels=num_outputs[4],
+            kernel_size=1,
+            stride=1,
+            act=None,
+            name="conv_f{}".format(4))

     def forward(self, x):
-        # down fpn
-        f_down = self.FPN_Down_Fusion(x)
-        # up fpn
-        f_up = self.FPN_Up_Fusion(x)
-        # fusion
-        f_common = paddle.add(x=f_down, y=f_up)
+        c0, c1, c2, c3, c4, c5, c6 = x
+        # FPN_Down_Fusion
+        f = [c0, c1, c2]
+        g = [None, None, None]
+        h = [None, None, None]
+        h[0] = self.conv_bn_layer_1(f[0])
+        h[1] = self.conv_bn_layer_2(f[1])
+        h[2] = self.conv_bn_layer_3(f[2])
+
+        g[0] = self.conv_bn_layer_4(h[0])
+        g[1] = paddle.add(g[0], h[1])
+        g[1] = F.relu(g[1])
+        g[1] = self.conv_bn_layer_5(g[1])
+        g[1] = self.conv_bn_layer_6(g[1])
+        g[2] = paddle.add(g[1], h[2])
+        g[2] = F.relu(g[2])
+        g[2] = self.conv_bn_layer_7(g[2])
+        f_down = self.conv_bn_layer_8(g[2])
+
+        # FPN UP Fusion
+        f1 = [c6, c5, c4, c3, c2]
+        g = [None, None, None, None, None]
+        h = [None, None, None, None, None]
+        h[0] = self.conv_h0(f1[0])
+        h[1] = self.conv_h1(f1[1])
+        h[2] = self.conv_h2(f1[2])
+        h[3] = self.conv_h3(f1[3])
+        h[4] = self.conv_h4(f1[4])
+        g[0] = self.dconv0(h[0])
+        g[1] = paddle.add(g[0], h[1])
+        g[1] = F.relu(g[1])
+        g[1] = self.conv_g1(g[1])
+        g[1] = self.dconv1(g[1])
+        g[2] = paddle.add(g[1], h[2])
+        g[2] = F.relu(g[2])
+        g[2] = self.conv_g2(g[2])
+        g[2] = self.dconv2(g[2])
+        g[3] = paddle.add(g[2], h[3])
+        g[3] = F.relu(g[3])
+        g[3] = self.conv_g3(g[3])
+        g[3] = self.dconv3(g[3])
+        g[4] = paddle.add(x=g[3], y=h[4])
+        g[4] = F.relu(g[4])
+        g[4] = self.conv_g4(g[4])
+        f_up = self.convf(g[4])
+
+        f_common = paddle.add(f_down, f_up)
         f_common = F.relu(f_common)
         return f_common
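As a sanity check on the unified class, a hedged smoke test: the seven input maps and their sizes are assumptions patterned on the num_inputs table in __init__ (c0-c2 feed the down branch, c2-c6 the up branch), and the output shape assumes DeConvBNLayer upsamples by 2x.

import paddle

# Hypothetical smoke test for the new PGFPN (all shapes illustrative).
fpn = PGFPN(in_channels=[3, 64, 256, 512, 1024, 2048, 2048])
x = [
    paddle.rand([1, 3, 512, 512]),    # c0
    paddle.rand([1, 64, 256, 256]),   # c1
    paddle.rand([1, 256, 128, 128]),  # c2
    paddle.rand([1, 512, 64, 64]),    # c3
    paddle.rand([1, 1024, 32, 32]),   # c4
    paddle.rand([1, 2048, 16, 16]),   # c5
    paddle.rand([1, 2048, 8, 8]),     # c6
]
f_common = fpn(x)
print(f_common.shape)  # expect [1, 128, 128, 128]: both branches meet at 1/4 scale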
ppocr/utils/e2e_metric/Deteval.py

-from os import listdir
-import os, sys
-from scipy import io
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import numpy as np
 from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
-from tqdm import tqdm

 try:  # python2
     range = xrange
@@ -862,16 +871,3 @@ def combine_results(all_data):
         'f_score_e2e': f_score_e2e
     }
     return final
-# a = [1526, 642, 1565, 629, 1579, 627, 1593, 625, 1607, 623, 1620, 622, 1634, 620, 1659, 620, 1654, 681, 1631, 680, 1618,
-#      681, 1606, 681, 1594, 681, 1584, 682, 1573, 685, 1542, 694]
-# gt_dict = [{'points': np.array(a).reshape(-1, 2), 'text': 'MILK'}]
-# pred_dict = [{'points': np.array(a), 'text': 'ccc'},
-#              {'points': np.array(a), 'text': 'ccf'}]
-# result = []
-# for i in range(2):
-#     result.append(get_socre(gt_dict, pred_dict))
-# print(111)
-# a = combine_results(result)
-# print(a)
ppocr/utils/e2e_metric/polygon_fast.py

+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import numpy as np
 from shapely.geometry import Polygon
-#import Polygon
 """
 :param det_x: [1, N] Xs of detection's vertices
 :param det_y: [1, N] Ys of detection's vertices
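For reference, the two ratios this module feeds into DetEval can be re-derived directly with shapely; the coordinates below are made up and the variable names mirror the docstring.

import numpy as np
from shapely.geometry import Polygon

# Illustrative re-derivation of what polygon_fast's area(),
# area_of_intersection() and iod() compute from separate X/Y vertex lists.
det_x, det_y = [0, 4, 4, 0], [0, 0, 4, 4]   # 4x4 detection square
gt_x, gt_y = [2, 6, 6, 2], [0, 0, 4, 4]     # 4x4 gt square, shifted by 2

det = Polygon(list(zip(det_x, det_y)))
gt = Polygon(list(zip(gt_x, gt_y)))
inter = det.intersection(gt).area           # 8.0
print(inter / gt.area)                      # sigma-style: inter / gt area = 0.5
print(inter / det.area)                     # tau-style:   inter / det area = 0.5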
ppocr/utils/e2e_metric/tttt.py
deleted, 100644 → 0 (content below as of parent commit 1f76f449)
from os import listdir
import os, sys
from scipy import io
import numpy as np
from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
from tqdm import tqdm

try:  # python2
    range = xrange
except Exception:
    # python3
    range = range
"""
Input format: y0,x0, ..... yn,xn. Each detection is separated by the end of line token ('\n')'
"""

# if len(sys.argv) != 4:
#     print('\n usage: test.py pred_dir gt_dir savefile')
#     sys.exit()

global_tp = 0
global_fp = 0
global_fn = 0
tr = 0.7
tp = 0.6
fsc_k = 0.8
k = 2


def get_socre(gt_dict, pred_dict):
    # allInputs = listdir(input_dir)
    allInputs = 1
    global_pred_str = []
    global_gt_str = []
    global_sigma = []
    global_tau = []

    def input_reading_mod(pred_dict, input):
        """This helper reads input from txt files"""
        det = []
        n = len(pred_dict)
        for i in range(n):
            points = pred_dict[i]['points']
            text = pred_dict[i]['text']
            # for i in range(len(points)):
            point = ",".join(map(str, points.reshape(-1, )))
            det.append([point, text])
        return det

    def gt_reading_mod(gt_dict, gt_id):
        """This helper reads groundtruths from mat files"""
        # gt_id = gt_id.split('.')[0]
        gt = []
        n = len(gt_dict)
        for i in range(n):
            points = gt_dict[i]['points'].tolist()
            h = len(points)
            text = gt_dict[i]['text']
            xx = [
                np.array(['x:'], dtype='<U2'), 0,
                np.array(['y:'], dtype='<U2'), 0,
                np.array(['#'], dtype='<U1'), np.array(['#'], dtype='<U1')
            ]
            t_x, t_y = [], []
            for j in range(h):
                t_x.append(points[j][0])
                t_y.append(points[j][1])
            xx[1] = np.array([t_x], dtype='int16')
            xx[3] = np.array([t_y], dtype='int16')
            if text != "":
                xx[4] = np.array([text], dtype='U{}'.format(len(text)))
                xx[5] = np.array(['c'], dtype='<U1')
            gt.append(xx)
        return gt

    def detection_filtering(detections, groundtruths, threshold=0.5):
        for gt_id, gt in enumerate(groundtruths):
            print "liushanshan gt[1] = {}".format(gt[1])
            print "liushanshan gt[2] = {}".format(gt[2])
            print "liushanshan gt[3] = {}".format(gt[3])
            print "liushanshan gt[4] = {}".format(gt[4])
            print "liushanshan gt[5] = {}".format(gt[5])
            if (gt[5] == '#') and (gt[1].shape[1] > 1):
                gt_x = list(map(int, np.squeeze(gt[1])))
                gt_y = list(map(int, np.squeeze(gt[3])))
                for det_id, detection in enumerate(detections):
                    detection_orig = detection
                    detection = [float(x) for x in detection[0].split(',')]
                    # detection = detection.split(',')
                    detection = list(map(int, detection))
                    det_x = detection[0::2]
                    det_y = detection[1::2]
                    det_gt_iou = iod(det_x, det_y, gt_x, gt_y)
                    if det_gt_iou > threshold:
                        detections[det_id] = []

                detections[:] = [item for item in detections if item != []]
        return detections

    def sigma_calculation(det_x, det_y, gt_x, gt_y):
        """
        sigma = inter_area / gt_area
        """
        # print(area_of_intersection(det_x, det_y, gt_x, gt_y))
        return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
                         area(gt_x, gt_y)), 2)

    def tau_calculation(det_x, det_y, gt_x, gt_y):
        """
        tau = inter_area / det_area
        """
        # print "liushanshan det_x {}".format(det_x)
        # print "liushanshan det_y {}".format(det_y)
        # print "liushanshan area {}".format(area(det_x, det_y))
        # print "liushanshan tau = {}".format(np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) / area(det_x, det_y)), 2))
        if area(det_x, det_y) == 0.0:
            return 0
        return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
                         area(det_x, det_y)), 2)

    ##############################Initialization###################################
    ###############################################################################
    single_data = {}
    for input_id in range(allInputs):
        if (input_id != '.DS_Store') and (input_id != 'Pascal_result.txt') and (
                input_id != 'Pascal_result_curved.txt') and (
                input_id != 'Pascal_result_non_curved.txt') and (
                input_id != 'Deteval_result.txt') and (
                input_id != 'Deteval_result_curved.txt') \
                and (input_id != 'Deteval_result_non_curved.txt'):
            print(input_id)
            detections = input_reading_mod(pred_dict, input_id)
            # print "liushanshan detections = {}".format(detections)
            groundtruths = gt_reading_mod(gt_dict, input_id)
            detections = detection_filtering(
                detections,
                groundtruths)  # filters detections overlapping with DC area
            dc_id = []
            for i in range(len(groundtruths)):
                if groundtruths[i][5] == '#':
                    dc_id.append(i)
            cnt = 0
            for a in dc_id:
                num = a - cnt
                del groundtruths[num]
                cnt += 1

            local_sigma_table = np.zeros((len(groundtruths), len(detections)))
            local_tau_table = np.zeros((len(groundtruths), len(detections)))
            local_pred_str = {}
            local_gt_str = {}

            for gt_id, gt in enumerate(groundtruths):
                if len(detections) > 0:
                    for det_id, detection in enumerate(detections):
                        detection_orig = detection
                        detection = [float(x) for x in detection[0].split(',')]
                        detection = list(map(int, detection))
                        pred_seq_str = detection_orig[1].strip()
                        det_x = detection[0::2]
                        det_y = detection[1::2]
                        gt_x = list(map(int, np.squeeze(gt[1])))
                        gt_y = list(map(int, np.squeeze(gt[3])))
                        gt_seq_str = str(gt[4].tolist()[0])

                        local_sigma_table[gt_id, det_id] = sigma_calculation(
                            det_x, det_y, gt_x, gt_y)
                        local_tau_table[gt_id, det_id] = tau_calculation(
                            det_x, det_y, gt_x, gt_y)
                        local_pred_str[det_id] = pred_seq_str
                        local_gt_str[gt_id] = gt_seq_str

            global_sigma.append(local_sigma_table)
            global_tau.append(local_tau_table)
            global_pred_str.append(local_pred_str)
            global_gt_str.append(local_gt_str)
            print "liushanshan global_pred_str = {}".format(global_pred_str)
            print "liushanshan global_gt_str = {}".format(global_gt_str)

    single_data['sigma'] = global_sigma
    single_data['global_tau'] = global_tau
    single_data['global_pred_str'] = global_pred_str
    single_data['global_gt_str'] = global_gt_str
    return single_data


def combine_results(all_data):
    global_sigma, global_tau, global_pred_str, global_gt_str = [], [], [], []
    for data in all_data:
        global_sigma.append(data['sigma'])
        global_tau.append(data['global_tau'])
        global_pred_str.append(data['global_pred_str'])
        global_gt_str.append(data['global_gt_str'])

    global_accumulative_recall = 0
    global_accumulative_precision = 0
    total_num_gt = 0
    total_num_det = 0
    hit_str_count = 0
    hit_count = 0

    def one_to_one(local_sigma_table, local_tau_table,
                   local_accumulative_recall, local_accumulative_precision,
                   global_accumulative_recall, global_accumulative_precision,
                   gt_flag, det_flag, idy):
        hit_str_num = 0
        for gt_id in range(num_gt):
            gt_matching_qualified_sigma_candidates = np.where(
                local_sigma_table[gt_id, :] > tr)
            gt_matching_num_qualified_sigma_candidates = gt_matching_qualified_sigma_candidates[
                0].shape[0]
            gt_matching_qualified_tau_candidates = np.where(
                local_tau_table[gt_id, :] > tp)
            gt_matching_num_qualified_tau_candidates = gt_matching_qualified_tau_candidates[
                0].shape[0]

            det_matching_qualified_sigma_candidates = np.where(
                local_sigma_table[:, gt_matching_qualified_sigma_candidates[0]]
                > tr)
            det_matching_num_qualified_sigma_candidates = det_matching_qualified_sigma_candidates[
                0].shape[0]
            det_matching_qualified_tau_candidates = np.where(
                local_tau_table[:, gt_matching_qualified_tau_candidates[0]] >
                tp)
            det_matching_num_qualified_tau_candidates = det_matching_qualified_tau_candidates[
                0].shape[0]

            if (gt_matching_num_qualified_sigma_candidates == 1) and (
                    gt_matching_num_qualified_tau_candidates == 1) and \
                    (det_matching_num_qualified_sigma_candidates == 1) and (
                    det_matching_num_qualified_tau_candidates == 1):
                global_accumulative_recall = global_accumulative_recall + 1.0
                global_accumulative_precision = global_accumulative_precision + 1.0
                local_accumulative_recall = local_accumulative_recall + 1.0
                local_accumulative_precision = local_accumulative_precision + 1.0

                gt_flag[0, gt_id] = 1
                matched_det_id = np.where(local_sigma_table[gt_id, :] > tr)
                # recg start
                print "liushanshan one to one det_id = {}".format(matched_det_id)
                print "liushanshan one to one gt_id = {}".format(gt_id)
                gt_str_cur = global_gt_str[idy][gt_id]
                pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[0]]
                print "liushanshan one to one gt_str_cur = {}".format(gt_str_cur)
                print "liushanshan one to one pred_str_cur = {}".format(pred_str_cur)
                if pred_str_cur == gt_str_cur:
                    hit_str_num += 1
                else:
                    if pred_str_cur.lower() == gt_str_cur.lower():
                        hit_str_num += 1
                # recg end
                det_flag[0, matched_det_id] = 1
        return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num

    def one_to_many(local_sigma_table, local_tau_table,
                    local_accumulative_recall, local_accumulative_precision,
                    global_accumulative_recall, global_accumulative_precision,
                    gt_flag, det_flag, idy):
        hit_str_num = 0
        for gt_id in range(num_gt):
            # skip the following if the groundtruth was matched
            if gt_flag[0, gt_id] > 0:
                continue

            non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0)
            num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0]

            if num_non_zero_in_sigma >= k:
                ####search for all detections that overlaps with this groundtruth
                qualified_tau_candidates = np.where((local_tau_table[
                    gt_id, :] >= tp) & (det_flag[0, :] == 0))
                num_qualified_tau_candidates = qualified_tau_candidates[
                    0].shape[0]

                if num_qualified_tau_candidates == 1:
                    if ((local_tau_table[gt_id, qualified_tau_candidates] >= tp)
                            and (local_sigma_table[gt_id, qualified_tau_candidates] >= tr)):
                        # became an one-to-one case
                        global_accumulative_recall = global_accumulative_recall + 1.0
                        global_accumulative_precision = global_accumulative_precision + 1.0
                        local_accumulative_recall = local_accumulative_recall + 1.0
                        local_accumulative_precision = local_accumulative_precision + 1.0

                        gt_flag[0, gt_id] = 1
                        det_flag[0, qualified_tau_candidates] = 1
                        # recg start
                        print "liushanshan one to many det_id = {}".format(qualified_tau_candidates)
                        print "liushanshan one to many gt_id = {}".format(gt_id)
                        gt_str_cur = global_gt_str[idy][gt_id]
                        pred_str_cur = global_pred_str[idy][qualified_tau_candidates[0].tolist()[0]]
                        print "liushanshan one to many gt_str_cur = {}".format(gt_str_cur)
                        print "liushanshan one to many pred_str_cur = {}".format(pred_str_cur)
                        if pred_str_cur == gt_str_cur:
                            hit_str_num += 1
                        else:
                            if pred_str_cur.lower() == gt_str_cur.lower():
                                hit_str_num += 1
                        # recg end
                elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates]) >= tr):
                    gt_flag[0, gt_id] = 1
                    det_flag[0, qualified_tau_candidates] = 1
                    # recg start
                    print "liushanshan one to many det_id = {}".format(qualified_tau_candidates)
                    print "liushanshan one to many gt_id = {}".format(gt_id)
                    gt_str_cur = global_gt_str[idy][gt_id]
                    pred_str_cur = global_pred_str[idy][qualified_tau_candidates[0].tolist()[0]]
                    print "liushanshan one to many gt_str_cur = {}".format(gt_str_cur)
                    print "liushanshan one to many pred_str_cur = {}".format(pred_str_cur)
                    if pred_str_cur == gt_str_cur:
                        hit_str_num += 1
                    else:
                        if pred_str_cur.lower() == gt_str_cur.lower():
                            hit_str_num += 1
                    # recg end

                    global_accumulative_recall = global_accumulative_recall + fsc_k
                    global_accumulative_precision = global_accumulative_precision + num_qualified_tau_candidates * fsc_k

                    local_accumulative_recall = local_accumulative_recall + fsc_k
                    local_accumulative_precision = local_accumulative_precision + num_qualified_tau_candidates * fsc_k

        return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num

    def many_to_one(local_sigma_table, local_tau_table,
                    local_accumulative_recall, local_accumulative_precision,
                    global_accumulative_recall, global_accumulative_precision,
                    gt_flag, det_flag, idy):
        hit_str_num = 0
        for det_id in range(num_det):
            # skip the following if the detection was matched
            if det_flag[0, det_id] > 0:
                continue

            non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0)
            num_non_zero_in_tau = non_zero_in_tau[0].shape[0]

            if num_non_zero_in_tau >= k:
                ####search for all detections that overlaps with this groundtruth
                qualified_sigma_candidates = np.where((
                    local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0))
                num_qualified_sigma_candidates = qualified_sigma_candidates[
                    0].shape[0]

                if num_qualified_sigma_candidates == 1:
                    if ((local_tau_table[qualified_sigma_candidates, det_id] >= tp)
                            and (local_sigma_table[qualified_sigma_candidates, det_id] >= tr)):
                        # became an one-to-one case
                        global_accumulative_recall = global_accumulative_recall + 1.0
                        global_accumulative_precision = global_accumulative_precision + 1.0
                        local_accumulative_recall = local_accumulative_recall + 1.0
                        local_accumulative_precision = local_accumulative_precision + 1.0

                        gt_flag[0, qualified_sigma_candidates] = 1
                        det_flag[0, det_id] = 1
                        # recg start
                        print "liushanshan many to one det_id = {}".format(det_id)
                        print "liushanshan many to one gt_id = {}".format(qualified_sigma_candidates)
                        pred_str_cur = global_pred_str[idy][det_id]
                        gt_len = len(qualified_sigma_candidates[0])
                        for idx in range(gt_len):
                            ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
                            if not global_gt_str[idy].has_key(ele_gt_id):
                                continue
                            gt_str_cur = global_gt_str[idy][ele_gt_id]
                            print "liushanshan many to one gt_str_cur = {}".format(gt_str_cur)
                            print "liushanshan many to one pred_str_cur = {}".format(pred_str_cur)
                            if pred_str_cur == gt_str_cur:
                                hit_str_num += 1
                                break
                            else:
                                if pred_str_cur.lower() == gt_str_cur.lower():
                                    hit_str_num += 1
                                break
                        # recg end
                elif (np.sum(local_tau_table[qualified_sigma_candidates, det_id]) >= tp):
                    det_flag[0, det_id] = 1
                    gt_flag[0, qualified_sigma_candidates] = 1
                    # recg start
                    print "liushanshan many to one det_id = {}".format(det_id)
                    print "liushanshan many to one gt_id = {}".format(qualified_sigma_candidates)
                    pred_str_cur = global_pred_str[idy][det_id]
                    gt_len = len(qualified_sigma_candidates[0])
                    for idx in range(gt_len):
                        ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
                        if not global_gt_str[idy].has_key(ele_gt_id):
                            continue
                        gt_str_cur = global_gt_str[idy][ele_gt_id]
                        print "liushanshan many to one gt_str_cur = {}".format(gt_str_cur)
                        print "liushanshan many to one pred_str_cur = {}".format(pred_str_cur)
                        if pred_str_cur == gt_str_cur:
                            hit_str_num += 1
                            break
                        else:
                            if pred_str_cur.lower() == gt_str_cur.lower():
                                hit_str_num += 1
                                break
                            else:
                                print 'no match'
                    # recg end

                    global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k
                    global_accumulative_precision = global_accumulative_precision + fsc_k

                    local_accumulative_recall = local_accumulative_recall + num_qualified_sigma_candidates * fsc_k
                    local_accumulative_precision = local_accumulative_precision + fsc_k
        return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num

    for idx in range(len(global_sigma)):
        # print(allInputs[idx])
        local_sigma_table = np.array(global_sigma[idx])
        local_tau_table = global_tau[idx]

        num_gt = local_sigma_table.shape[0]
        num_det = local_sigma_table.shape[1]

        total_num_gt = total_num_gt + num_gt
        total_num_det = total_num_det + num_det

        local_accumulative_recall = 0
        local_accumulative_precision = 0
        gt_flag = np.zeros((1, num_gt))
        det_flag = np.zeros((1, num_det))

        #######first check for one-to-one case##########
        local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
            gt_flag, det_flag, hit_str_num = one_to_one(local_sigma_table, local_tau_table,
                                                        local_accumulative_recall, local_accumulative_precision,
                                                        global_accumulative_recall, global_accumulative_precision,
                                                        gt_flag, det_flag, idx)

        hit_str_count += hit_str_num
        #######then check for one-to-many case##########
        local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
            gt_flag, det_flag, hit_str_num = one_to_many(local_sigma_table, local_tau_table,
                                                         local_accumulative_recall, local_accumulative_precision,
                                                         global_accumulative_recall, global_accumulative_precision,
                                                         gt_flag, det_flag, idx)
        hit_str_count += hit_str_num
        #######then check for many-to-one case##########
        local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
            gt_flag, det_flag, hit_str_num = many_to_one(local_sigma_table, local_tau_table,
                                                         local_accumulative_recall, local_accumulative_precision,
                                                         global_accumulative_recall, global_accumulative_precision,
                                                         gt_flag, det_flag, idx)

    try:
        recall = global_accumulative_recall / total_num_gt
    except ZeroDivisionError:
        recall = 0

    try:
        precision = global_accumulative_precision / total_num_det
    except ZeroDivisionError:
        precision = 0

    try:
        f_score = 2 * precision * recall / (precision + recall)
    except ZeroDivisionError:
        f_score = 0

    try:
        seqerr = 1 - float(hit_str_count) / global_accumulative_recall
    except ZeroDivisionError:
        seqerr = 1

    try:
        recall_e2e = float(hit_str_count) / total_num_gt
    except ZeroDivisionError:
        recall_e2e = 0

    try:
        precision_e2e = float(hit_str_count) / total_num_det
    except ZeroDivisionError:
        precision_e2e = 0

    try:
        f_score_e2e = 2 * precision_e2e * recall_e2e / (precision_e2e + recall_e2e)
    except ZeroDivisionError:
        f_score_e2e = 0

    final = {
        'total_num_gt': total_num_gt,
        'total_num_det': total_num_det,
        'global_accumulative_recall': global_accumulative_recall,
        'hit_str_count': hit_str_count,
        'recall': recall,
        'precision': precision,
        'f_score': f_score,
        'seqerr': seqerr,
        'recall_e2e': recall_e2e,
        'precision_e2e': precision_e2e,
        'f_score_e2e': f_score_e2e
    }
    return final
# def combine_results(all_data):
# tr = 0.7
# tp = 0.6
# fsc_k = 0.8
# k = 2
# global_sigma = []
# global_tau = []
# global_pred_str = []
# global_gt_str = []
# for data in all_data:
# global_sigma.append(data['sigma'])
# global_tau.append(data['global_tau'])
# global_pred_str.append(data['global_pred_str'])
# global_gt_str.append(data['global_gt_str'])
#
# global_accumulative_recall = 0
# global_accumulative_precision = 0
# total_num_gt = 0
# total_num_det = 0
# hit_str_count = 0
# hit_count = 0
#
# def one_to_one(local_sigma_table, local_tau_table, local_accumulative_recall,
# local_accumulative_precision, global_accumulative_recall, global_accumulative_precision,
# gt_flag, det_flag, idy):
# hit_str_num = 0
# for gt_id in range(num_gt):
# gt_matching_qualified_sigma_candidates = np.where(local_sigma_table[gt_id, :] > tr)
# gt_matching_num_qualified_sigma_candidates = gt_matching_qualified_sigma_candidates[0].shape[0]
# gt_matching_qualified_tau_candidates = np.where(local_tau_table[gt_id, :] > tp)
# gt_matching_num_qualified_tau_candidates = gt_matching_qualified_tau_candidates[0].shape[0]
#
# det_matching_qualified_sigma_candidates = np.where(
# local_sigma_table[:, gt_matching_qualified_sigma_candidates[0]] > tr)
# det_matching_num_qualified_sigma_candidates = det_matching_qualified_sigma_candidates[0].shape[0]
# det_matching_qualified_tau_candidates = np.where(
# local_tau_table[:, gt_matching_qualified_tau_candidates[0]] > tp)
# det_matching_num_qualified_tau_candidates = det_matching_qualified_tau_candidates[0].shape[0]
#
# if (gt_matching_num_qualified_sigma_candidates == 1) and (gt_matching_num_qualified_tau_candidates == 1) and \
# (det_matching_num_qualified_sigma_candidates == 1) and (
# det_matching_num_qualified_tau_candidates == 1):
# global_accumulative_recall = global_accumulative_recall + 1.0
# global_accumulative_precision = global_accumulative_precision + 1.0
# local_accumulative_recall = local_accumulative_recall + 1.0
# local_accumulative_precision = local_accumulative_precision + 1.0
#
# gt_flag[0, gt_id] = 1
# matched_det_id = np.where(local_sigma_table[gt_id, :] > tr)
# # recg start
# print
# "liushanshan one to one det_id = {}".format(matched_det_id)
# print
# "liushanshan one to one gt_id = {}".format(gt_id)
# gt_str_cur = global_gt_str[idy][gt_id]
# pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[0]]
# print
# "liushanshan one to one gt_str_cur = {}".format(gt_str_cur)
# print
# "liushanshan one to one pred_str_cur = {}".format(pred_str_cur)
# if pred_str_cur == gt_str_cur:
# hit_str_num += 1
# else:
# if pred_str_cur.lower() == gt_str_cur.lower():
# hit_str_num += 1
# # recg end
# det_flag[0, matched_det_id] = 1
# return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
#
# def one_to_many(local_sigma_table, local_tau_table, local_accumulative_recall,
# local_accumulative_precision, global_accumulative_recall, global_accumulative_precision,
# gt_flag, det_flag, idy):
# hit_str_num = 0
# for gt_id in range(num_gt):
# # skip the following if the groundtruth was matched
# if gt_flag[0, gt_id] > 0:
# continue
#
# non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0)
# num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0]
#
# if num_non_zero_in_sigma >= k:
# ####search for all detections that overlaps with this groundtruth
# qualified_tau_candidates = np.where((local_tau_table[gt_id, :] >= tp) & (det_flag[0, :] == 0))
# num_qualified_tau_candidates = qualified_tau_candidates[0].shape[0]
#
# if num_qualified_tau_candidates == 1:
# if ((local_tau_table[gt_id, qualified_tau_candidates] >= tp) and (
# local_sigma_table[gt_id, qualified_tau_candidates] >= tr)):
# # became an one-to-one case
# global_accumulative_recall = global_accumulative_recall + 1.0
# global_accumulative_precision = global_accumulative_precision + 1.0
# local_accumulative_recall = local_accumulative_recall + 1.0
# local_accumulative_precision = local_accumulative_precision + 1.0
#
# gt_flag[0, gt_id] = 1
# det_flag[0, qualified_tau_candidates] = 1
# # recg start
# print
# "liushanshan one to many det_id = {}".format(qualified_tau_candidates)
# print
# "liushanshan one to many gt_id = {}".format(gt_id)
# gt_str_cur = global_gt_str[idy][gt_id]
# pred_str_cur = global_pred_str[idy][qualified_tau_candidates[0].tolist()[0]]
# print
# "liushanshan one to many gt_str_cur = {}".format(gt_str_cur)
# print
# "liushanshan one to many pred_str_cur = {}".format(pred_str_cur)
# if pred_str_cur == gt_str_cur:
# hit_str_num += 1
# else:
# if pred_str_cur.lower() == gt_str_cur.lower():
# hit_str_num += 1
# # recg end
# elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates]) >= tr):
# gt_flag[0, gt_id] = 1
# det_flag[0, qualified_tau_candidates] = 1
# # recg start
# print
# "liushanshan one to many det_id = {}".format(qualified_tau_candidates)
# print
# "liushanshan one to many gt_id = {}".format(gt_id)
# gt_str_cur = global_gt_str[idy][gt_id]
# pred_str_cur = global_pred_str[idy][qualified_tau_candidates[0].tolist()[0]]
# print
# "liushanshan one to many gt_str_cur = {}".format(gt_str_cur)
# print
# "liushanshan one to many pred_str_cur = {}".format(pred_str_cur)
# if pred_str_cur == gt_str_cur:
# hit_str_num += 1
# else:
# if pred_str_cur.lower() == gt_str_cur.lower():
# hit_str_num += 1
# # recg end
#
# global_accumulative_recall = global_accumulative_recall + fsc_k
# global_accumulative_precision = global_accumulative_precision + num_qualified_tau_candidates * fsc_k
#
# local_accumulative_recall = local_accumulative_recall + fsc_k
# local_accumulative_precision = local_accumulative_precision + num_qualified_tau_candidates * fsc_k
#
# return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
#
# def many_to_one(local_sigma_table, local_tau_table, local_accumulative_recall,
# local_accumulative_precision, global_accumulative_recall, global_accumulative_precision,
# gt_flag, det_flag, idy):
# hit_str_num = 0
# for det_id in range(num_det):
# # skip the following if the detection was matched
# if det_flag[0, det_id] > 0:
# continue
#
# non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0)
# num_non_zero_in_tau = non_zero_in_tau[0].shape[0]
#
# if num_non_zero_in_tau >= k:
# ####search for all detections that overlaps with this groundtruth
# qualified_sigma_candidates = np.where((local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0))
# num_qualified_sigma_candidates = qualified_sigma_candidates[0].shape[0]
#
# if num_qualified_sigma_candidates == 1:
# if ((local_tau_table[qualified_sigma_candidates, det_id] >= tp) and (
# local_sigma_table[qualified_sigma_candidates, det_id] >= tr)):
# # became an one-to-one case
# global_accumulative_recall = global_accumulative_recall + 1.0
# global_accumulative_precision = global_accumulative_precision + 1.0
# local_accumulative_recall = local_accumulative_recall + 1.0
# local_accumulative_precision = local_accumulative_precision + 1.0
#
# gt_flag[0, qualified_sigma_candidates] = 1
# det_flag[0, det_id] = 1
# # recg start
# print
# "liushanshan many to one det_id = {}".format(det_id)
# print
# "liushanshan many to one gt_id = {}".format(qualified_sigma_candidates)
# pred_str_cur = global_pred_str[idy][det_id]
# gt_len = len(qualified_sigma_candidates[0])
# for idx in range(gt_len):
# ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
# if ele_gt_id not in global_gt_str[idy]:
# continue
# gt_str_cur = global_gt_str[idy][ele_gt_id]
# print
# "liushanshan many to one gt_str_cur = {}".format(gt_str_cur)
# print
# "liushanshan many to one pred_str_cur = {}".format(pred_str_cur)
# if pred_str_cur == gt_str_cur:
# hit_str_num += 1
# break
# else:
# if pred_str_cur.lower() == gt_str_cur.lower():
# hit_str_num += 1
# break
# # recg end
# elif (np.sum(local_tau_table[qualified_sigma_candidates, det_id]) >= tp):
# det_flag[0, det_id] = 1
# gt_flag[0, qualified_sigma_candidates] = 1
# # recg start
# print
# "liushanshan many to one det_id = {}".format(det_id)
# print
# "liushanshan many to one gt_id = {}".format(qualified_sigma_candidates)
# pred_str_cur = global_pred_str[idy][det_id]
# gt_len = len(qualified_sigma_candidates[0])
# for idx in range(gt_len):
# ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
# if not global_gt_str[idy].has_key(ele_gt_id):
# continue
# gt_str_cur = global_gt_str[idy][ele_gt_id]
# print
# "liushanshan many to one gt_str_cur = {}".format(gt_str_cur)
# print
# "liushanshan many to one pred_str_cur = {}".format(pred_str_cur)
# if pred_str_cur == gt_str_cur:
# hit_str_num += 1
# break
# else:
# if pred_str_cur.lower() == gt_str_cur.lower():
# hit_str_num += 1
# break
# else:
# print
# 'no match'
# # recg end
#
# global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k
# global_accumulative_precision = global_accumulative_precision + fsc_k
#
# local_accumulative_recall = local_accumulative_recall + num_qualified_sigma_candidates * fsc_k
# local_accumulative_precision = local_accumulative_precision + fsc_k
# return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
#
# for idx in range(len(global_sigma)):
# local_sigma_table = np.array(global_sigma[idx])
# local_tau_table = np.array(global_tau[idx])
#
# num_gt = local_sigma_table.shape[0]
# num_det = local_sigma_table.shape[1]
#
# total_num_gt = total_num_gt + num_gt
# total_num_det = total_num_det + num_det
#
# local_accumulative_recall = 0
# local_accumulative_precision = 0
# gt_flag = np.zeros((1, num_gt))
# det_flag = np.zeros((1, num_det))
#
# #######first check for one-to-one case##########
# local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
# gt_flag, det_flag, hit_str_num = one_to_one(local_sigma_table, local_tau_table,
# local_accumulative_recall, local_accumulative_precision,
# global_accumulative_recall, global_accumulative_precision,
# gt_flag, det_flag, idx)
#
# hit_str_count += hit_str_num
# #######then check for one-to-many case##########
# local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
# gt_flag, det_flag, hit_str_num = one_to_many(local_sigma_table, local_tau_table,
# local_accumulative_recall, local_accumulative_precision,
# global_accumulative_recall, global_accumulative_precision,
# gt_flag, det_flag, idx)
# hit_str_count += hit_str_num
# #######then check for many-to-one case##########
# local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
# gt_flag, det_flag, hit_str_num = many_to_one(local_sigma_table, local_tau_table,
# local_accumulative_recall, local_accumulative_precision,
# global_accumulative_recall, global_accumulative_precision,
# gt_flag, det_flag, idx)
# try:
# recall = global_accumulative_recall / total_num_gt
# except ZeroDivisionError:
# recall = 0
#
# try:
# precision = global_accumulative_precision / total_num_det
# except ZeroDivisionError:
# precision = 0
#
# try:
# f_score = 2 * precision * recall / (precision + recall)
# except ZeroDivisionError:
# f_score = 0
#
# try:
# seqerr = 1 - float(hit_str_count) / global_accumulative_recall
# except ZeroDivisionError:
# seqerr = 1
#
# try:
# recall_e2e = float(hit_str_count) / total_num_gt
# except ZeroDivisionError:
# recall_e2e = 0
#
# try:
# precision_e2e = float(hit_str_count) / total_num_det
# except ZeroDivisionError:
# precision_e2e = 0
#
# try:
# f_score_e2e = 2 * precision_e2e * recall_e2e / (precision_e2e + recall_e2e)
# except ZeroDivisionError:
# f_score_e2e = 0
#
# final = {
# 'total_num_gt': total_num_gt,
# 'total_num_det': total_num_det,
# 'global_accumulative_recall': global_accumulative_recall,
# 'hit_str_count': hit_str_count,
# 'recall': recall,
# 'precision': precision,
# 'f_score': f_score,
# 'seqerr': seqerr,
# 'recall_e2e': recall_e2e,
# 'precision_e2e': precision_e2e,
# 'f_score_e2e': f_score_e2e
# }
# return final
a = [
    1526, 642, 1565, 629, 1579, 627, 1593, 625, 1607, 623, 1620, 622, 1634,
    620, 1659, 620, 1654, 681, 1631, 680, 1618, 681, 1606, 681, 1594, 681,
    1584, 682, 1573, 685, 1542, 694
]
gt_dict = [{'points': np.array(a).reshape(-1, 2), 'text': 'MILK'}]
pred_dict = [{'points': np.array(a), 'text': 'ccc'},
             {'points': np.array(a), 'text': 'ccf'}]
result = []
result.append(get_socre(gt_dict, gt_dict))
a = combine_results(result)
print(a)
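tttt.py was a scratch copy of the DetEval logic that lives on in Deteval.py (Python 2 print statements and dict.has_key included), which is why it could be deleted wholesale. For orientation, a toy illustration of the one-to-one criterion it implemented, with tr = 0.7 and tp = 0.6 as above:

import numpy as np

# Toy illustration of the DetEval one-to-one test: a gt/det pair matches
# when sigma (inter/gt area) > tr, tau (inter/det area) > tp, and neither
# the row nor the column has a competing candidate.
tr, tp = 0.7, 0.6
local_sigma_table = np.array([[0.85]])  # one gt, one det
local_tau_table = np.array([[0.90]])

sigma_candidates = np.where(local_sigma_table[0, :] > tr)[0]
tau_candidates = np.where(local_tau_table[0, :] > tp)[0]
one_to_one = sigma_candidates.shape[0] == 1 and tau_candidates.shape[0] == 1
print(one_to_one)  # True: this pair contributes 1.0 recall and 1.0 precision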
ppocr/utils/e2e_utils/extract_textpoint.py

+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """Contains various CTC decoders."""
 from __future__ import absolute_import
 from __future__ import division
ppocr/utils/e2e_utils/ski_thin.py

-"""
-Algorithms for computing the skeleton of a binary image
-"""
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import numpy as np
 from scipy import ndimage as ndi
ppocr/utils/e2e_utils/visual.py

-import os
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import numpy as np
 import cv2
 import time
-
-
-def visualize_e2e_result(im_fn, poly_list, seq_strs, src_im):
-    """
-    """
-    result_path = './out'
-    im_basename = os.path.basename(im_fn)
-    im_prefix = im_basename[:im_basename.rfind('.')]
-    vis_det_img = src_im.copy()
-    valid_set = 'partvgg'
-    gt_dir = "/Users/hongyongjie/Downloads/part_vgg_synth/train"
-    text_path = os.path.join(gt_dir, im_prefix + '.txt')
-    fid = open(text_path, 'r')
-    lines = [line.strip() for line in fid.readlines()]
-    for line in lines:
-        if valid_set == 'partvgg':
-            tokens = line.strip().split('\t')[0].split(',')
-            # tokens = line.strip().split(',')
-            coords = tokens[:]
-            coords = list(map(float, coords))
-            gt_poly = np.array(coords).reshape(1, 4, 2)
-        elif valid_set == 'totaltext':
-            tokens = line.strip().split('\t')[0].split(',')
-            coords = tokens[:]
-            coords_len = len(coords) / 2
-            coords = list(map(float, coords))
-            gt_poly = np.array(coords).reshape(1, coords_len, 2)
-        cv2.polylines(
-            vis_det_img,
-            np.array(gt_poly).astype(np.int32),
-            isClosed=True,
-            color=(255, 0, 0),
-            thickness=2)
-    for detected_poly, recognized_str in zip(poly_list, seq_strs):
-        cv2.polylines(
-            vis_det_img,
-            np.array(detected_poly[np.newaxis, ...]).astype(np.int32),
-            isClosed=True,
-            color=(0, 0, 255),
-            thickness=2)
-        cv2.putText(
-            vis_det_img,
-            recognized_str,
-            org=(int(detected_poly[0, 0]), int(detected_poly[0, 1])),
-            fontFace=cv2.FONT_HERSHEY_COMPLEX,
-            fontScale=0.7,
-            color=(0, 255, 0),
-            thickness=1)
-    if not os.path.exists(result_path):
-        os.makedirs(result_path)
-    cv2.imwrite("{}/{}_detection.jpg".format(result_path, im_prefix),
-                vis_det_img)
-
-
-def visualization_output(src_image, f_tcl, f_chars, output_dir,
-                         image_prefix=None):
-    """
-    """
-    # restore BGR image, CHW -> HWC
-    im_mean = [0.485, 0.456, 0.406]
-    im_std = [0.229, 0.224, 0.225]
-    im_mean = np.array(im_mean).reshape((3, 1, 1))
-    im_std = np.array(im_std).reshape((3, 1, 1))
-    src_image *= im_std
-    src_image += im_mean
-    src_image = src_image.transpose([1, 2, 0])
-    src_image = src_image[:, :, ::-1] * 255  # BGR -> RGB
-    H, W, _ = src_image.shape
-    file_prefix = image_prefix if image_prefix is not None else str(
-        int(time.time() * 1000))
-    if not os.path.exists(output_dir):
-        os.makedirs(output_dir)
-    # visualization f_tcl
-    tcl_file_name = os.path.join(output_dir, file_prefix + '_0_tcl.jpg')
-    vis_tcl_img = src_image.copy()
-    f_tcl_resized = cv2.resize(f_tcl, dsize=(W, H))
-    vis_tcl_img[:, :, 1] = f_tcl_resized * 255
-    cv2.imwrite(tcl_file_name, vis_tcl_img)
-    # visualization char maps
-    vis_char_img = src_image.copy()
-    # CHW -> HWC
-    char_file_name = os.path.join(output_dir, file_prefix + '_1_chars.jpg')
-    f_chars = np.argmax(f_chars, axis=2)[:, :, np.newaxis].astype('float32')
-    f_chars[f_chars < 95] = 1.0
-    f_chars[f_chars == 95] = 0.0
-    f_chars_resized = cv2.resize(f_chars, dsize=(W, H))
-    vis_char_img[:, :, 1] = f_chars_resized * 255
-    cv2.imwrite(char_file_name, vis_char_img)
-
-
-def visualize_point_result(im_fn, point_list, point_pair_list, src_im, gt_dir,
-                           result_path):
-    """
-    """
-    im_basename = os.path.basename(im_fn)
-    im_prefix = im_basename[:im_basename.rfind('.')]
-    vis_det_img = src_im.copy()
-    # draw gt bbox on the image.
-    text_path = os.path.join(gt_dir, im_prefix + '.txt')
-    fid = open(text_path, 'r')
-    lines = [line.strip() for line in fid.readlines()]
-    for line in lines:
-        tokens = line.strip().split('\t')
-        coords = tokens[0].split(',')
-        coords_len = len(coords)
-        coords = list(map(float, coords))
-        gt_poly = np.array(coords).reshape(1, coords_len / 2, 2)
-        cv2.polylines(
-            vis_det_img,
-            np.array(gt_poly).astype(np.int32),
-            isClosed=True,
-            color=(255, 255, 255),
-            thickness=1)
-    for point, point_pair in zip(point_list, point_pair_list):
-        cv2.line(
-            vis_det_img,
-            tuple(point_pair[0]),
-            tuple(point_pair[1]), (0, 255, 255),
-            thickness=1)
-        cv2.circle(vis_det_img, tuple(point), 2, (0, 0, 255))
-        cv2.circle(vis_det_img, tuple(point_pair[0]), 2, (255, 0, 0))
-        cv2.circle(vis_det_img, tuple(point_pair[1]), 2, (0, 255, 0))
-    if not os.path.exists(result_path):
-        os.makedirs(result_path)
-    cv2.imwrite("{}/{}_border_points.jpg".format(result_path, im_prefix),
-                vis_det_img)
-
-
 def resize_image(im, max_side_len=512):
     """
     resize image to a size multiple of max_stride which is required by the network
@@ -295,49 +169,3 @@ def norm2(x, axis=None):
 def cos(p1, p2):
     return (p1 * p2).sum() / (norm2(p1) * norm2(p2))
-
-
-def generate_direction_info(image_fn,
-                            H,
-                            W,
-                            ratio_h,
-                            ratio_w,
-                            max_length=640,
-                            out_scale=4,
-                            gt_dir=None):
-    """
-    """
-    im_basename = os.path.basename(image_fn)
-    im_prefix = im_basename[:im_basename.rfind('.')]
-    instance_direction_map = np.zeros(shape=[H // out_scale, W // out_scale, 3])
-    if gt_dir is None:
-        gt_dir = '/home/vis/huangzuming/data/SYNTH_DATA/part_vgg_synth_icdar/processed/val/poly'
-    # get gt label map
-    text_path = os.path.join(gt_dir, im_prefix + '.txt')
-    fid = open(text_path, 'r')
-    lines = [line.strip() for line in fid.readlines()]
-    for label_idx, line in enumerate(lines, start=1):
-        coords, txt = line.strip().split('\t')
-        if txt == '###':
-            continue
-        tokens = coords.strip().split(',')
-        coords = list(map(float, tokens))
-        poly = np.array(coords).reshape(4, 2) * np.array(
-            [ratio_w, ratio_h]).reshape(1, 2) / out_scale
-        mid_idx = poly.shape[0] // 2
-        direct_vector = ((poly[mid_idx] + poly[mid_idx - 1]) -
-                         (poly[0] + poly[-1])) / 2.0
-        direct_vector /= len(txt)
-        # l2_distance = norm2(direct_vector)
-        # avg_char_distance = l2_distance / len(txt)
-        avg_char_distance = 1.0
-        direct_label = (direct_vector[0], direct_vector[1], avg_char_distance)
-        cv2.fillPoly(
-            instance_direction_map,
-            poly.round().astype(np.int32)[np.newaxis, :, :], direct_label)
-    instance_direction_map = instance_direction_map.transpose([2, 0, 1])
-    return instance_direction_map[:2, ...]
tools/program.py

@@ -44,7 +44,6 @@ class ArgsParser(ArgumentParser):
     def parse_args(self, argv=None):
         args = super(ArgsParser, self).parse_args(argv)
-        args.config = '/Users/hongyongjie/project/PaddleOCR/configs/e2e/e2e_r50_vd_pg.yml'
         assert args.config is not None, \
             "Please specify --config=configure_file_path."
         args.opt = self._parse_opt(args.opt)
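Dropping the hard-coded args.config restores the command-line contract asserted two lines later. A hedged sketch of the intended call (the -c/--config flags come from ArgsParser itself; the config path is the PGNet config the deleted line pointed at, made relative to the repo root):

# Sketch: the config must now be passed explicitly via -c/--config.
from tools.program import ArgsParser

parser = ArgsParser()
args = parser.parse_args(['-c', 'configs/e2e/e2e_r50_vd_pg.yml'])
print(args.config)  # no longer overwritten by a developer-local path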