Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
weixin_41840029
PaddleOCR
提交
8d9cfade
P
PaddleOCR
项目概览
weixin_41840029
/
PaddleOCR
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleOCR
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8d9cfade
编写于
7月 04, 2022
作者:
W
wangjingyeye
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add polygons
上级
c26e7aee
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
68 addition
and
19 deletion
+68
-19
ppocr/metrics/eval_det_iou.py
ppocr/metrics/eval_det_iou.py
+2
-10
ppocr/postprocess/db_postprocess.py
ppocr/postprocess/db_postprocess.py
+63
-6
tools/infer_det.py
tools/infer_det.py
+3
-3
未找到文件。
ppocr/metrics/eval_det_iou.py
浏览文件 @
8d9cfade
...
...
@@ -83,14 +83,10 @@ class DetectionIoUEvaluator(object):
evaluationLog
=
""
# print(len(gt))
for
n
in
range
(
len
(
gt
)):
points
=
gt
[
n
][
'points'
]
# transcription = gt[n]['text']
dontCare
=
gt
[
n
][
'ignore'
]
# points = Polygon(points)
# points = points.buffer(0)
if
not
Polygon
(
points
).
is_valid
or
not
Polygon
(
points
).
is_simple
:
if
not
Polygon
(
points
).
is_valid
:
continue
gtPol
=
points
...
...
@@ -105,9 +101,7 @@ class DetectionIoUEvaluator(object):
for
n
in
range
(
len
(
pred
)):
points
=
pred
[
n
][
'points'
]
# points = Polygon(points)
# points = points.buffer(0)
if
not
Polygon
(
points
).
is_valid
or
not
Polygon
(
points
).
is_simple
:
if
not
Polygon
(
points
).
is_valid
:
continue
detPol
=
points
...
...
@@ -191,8 +185,6 @@ class DetectionIoUEvaluator(object):
methodHmean
=
0
if
methodRecall
+
methodPrecision
==
0
else
2
*
\
methodRecall
*
methodPrecision
/
(
methodRecall
+
methodPrecision
)
# print(methodRecall, methodPrecision, methodHmean)
# sys.exit(-1)
methodMetrics
=
{
'precision'
:
methodPrecision
,
'recall'
:
methodRecall
,
...
...
ppocr/postprocess/db_postprocess.py
浏览文件 @
8d9cfade
...
...
@@ -38,6 +38,7 @@ class DBPostProcess(object):
unclip_ratio
=
2.0
,
use_dilation
=
False
,
score_mode
=
"fast"
,
use_polygon
=
False
,
**
kwargs
):
self
.
thresh
=
thresh
self
.
box_thresh
=
box_thresh
...
...
@@ -45,6 +46,7 @@ class DBPostProcess(object):
self
.
unclip_ratio
=
unclip_ratio
self
.
min_size
=
3
self
.
score_mode
=
score_mode
self
.
use_polygon
=
use_polygon
assert
score_mode
in
[
"slow"
,
"fast"
],
"Score mode must be in [slow, fast] but got: {}"
.
format
(
score_mode
)
...
...
@@ -52,6 +54,56 @@ class DBPostProcess(object):
self
.
dilation_kernel
=
None
if
not
use_dilation
else
np
.
array
(
[[
1
,
1
],
[
1
,
1
]])
def
polygons_from_bitmap
(
self
,
pred
,
_bitmap
,
dest_width
,
dest_height
):
'''
_bitmap: single map with shape (1, H, W),
whose values are binarized as {0, 1}
'''
bitmap
=
_bitmap
height
,
width
=
bitmap
.
shape
boxes
=
[]
scores
=
[]
contours
,
_
=
cv2
.
findContours
((
bitmap
*
255
).
astype
(
np
.
uint8
),
cv2
.
RETR_LIST
,
cv2
.
CHAIN_APPROX_SIMPLE
)
for
contour
in
contours
[:
self
.
max_candidates
]:
epsilon
=
0.002
*
cv2
.
arcLength
(
contour
,
True
)
approx
=
cv2
.
approxPolyDP
(
contour
,
epsilon
,
True
)
points
=
approx
.
reshape
((
-
1
,
2
))
# print(points)
if
points
.
shape
[
0
]
<
4
:
continue
score
=
self
.
box_score_fast
(
pred
,
points
.
reshape
(
-
1
,
2
))
if
self
.
box_thresh
>
score
:
continue
if
points
.
shape
[
0
]
>
2
:
box
=
self
.
unclip
(
points
,
self
.
unclip_ratio
)
if
len
(
box
)
>
1
:
continue
else
:
continue
box
=
box
.
reshape
(
-
1
,
2
)
# print(box)
_
,
sside
=
self
.
get_mini_boxes
(
box
.
reshape
((
-
1
,
1
,
2
)))
if
sside
<
self
.
min_size
+
2
:
continue
box
=
np
.
array
(
box
)
box
[:,
0
]
=
np
.
clip
(
np
.
round
(
box
[:,
0
]
/
width
*
dest_width
),
0
,
dest_width
)
box
[:,
1
]
=
np
.
clip
(
np
.
round
(
box
[:,
1
]
/
height
*
dest_height
),
0
,
dest_height
)
boxes
.
append
(
box
.
tolist
())
scores
.
append
(
score
)
# print(boxes)
return
boxes
,
scores
def
boxes_from_bitmap
(
self
,
pred
,
_bitmap
,
dest_width
,
dest_height
):
'''
_bitmap: single map with shape (1, H, W),
...
...
@@ -85,7 +137,7 @@ class DBPostProcess(object):
if
self
.
box_thresh
>
score
:
continue
box
=
self
.
unclip
(
points
).
reshape
(
-
1
,
1
,
2
)
box
=
self
.
unclip
(
points
,
self
.
unclip_ratio
).
reshape
(
-
1
,
1
,
2
)
box
,
sside
=
self
.
get_mini_boxes
(
box
)
if
sside
<
self
.
min_size
+
2
:
continue
...
...
@@ -99,8 +151,7 @@ class DBPostProcess(object):
scores
.
append
(
score
)
return
np
.
array
(
boxes
,
dtype
=
np
.
int16
),
scores
def
unclip
(
self
,
box
):
unclip_ratio
=
self
.
unclip_ratio
def
unclip
(
self
,
box
,
unclip_ratio
):
poly
=
Polygon
(
box
)
distance
=
poly
.
area
*
unclip_ratio
/
poly
.
length
offset
=
pyclipper
.
PyclipperOffset
()
...
...
@@ -185,6 +236,10 @@ class DBPostProcess(object):
self
.
dilation_kernel
)
else
:
mask
=
segmentation
[
batch_index
]
if
self
.
use_polygon
:
boxes
,
scores
=
self
.
polygons_from_bitmap
(
pred
[
batch_index
],
mask
,
src_w
,
src_h
)
else
:
boxes
,
scores
=
self
.
boxes_from_bitmap
(
pred
[
batch_index
],
mask
,
src_w
,
src_h
)
...
...
@@ -202,6 +257,7 @@ class DistillationDBPostProcess(object):
unclip_ratio
=
1.5
,
use_dilation
=
False
,
score_mode
=
"fast"
,
use_polygon
=
False
,
**
kwargs
):
self
.
model_name
=
model_name
self
.
key
=
key
...
...
@@ -211,7 +267,8 @@ class DistillationDBPostProcess(object):
max_candidates
=
max_candidates
,
unclip_ratio
=
unclip_ratio
,
use_dilation
=
use_dilation
,
score_mode
=
score_mode
)
score_mode
=
score_mode
,
use_polygon
=
use_polygon
)
def
__call__
(
self
,
predicts
,
shape_list
):
results
=
{}
...
...
tools/infer_det.py
浏览文件 @
8d9cfade
...
...
@@ -44,7 +44,7 @@ def draw_det_res(dt_boxes, config, img, img_name, save_path):
import
cv2
src_im
=
img
for
box
in
dt_boxes
:
box
=
box
.
astype
(
np
.
int32
).
reshape
((
-
1
,
1
,
2
))
box
=
np
.
array
(
box
)
.
astype
(
np
.
int32
).
reshape
((
-
1
,
1
,
2
))
cv2
.
polylines
(
src_im
,
[
box
],
True
,
color
=
(
255
,
255
,
0
),
thickness
=
2
)
if
not
os
.
path
.
exists
(
save_path
):
os
.
makedirs
(
save_path
)
...
...
@@ -106,7 +106,7 @@ def main():
dt_boxes_list
=
[]
for
box
in
boxes
:
tmp_json
=
{
"transcription"
:
""
}
tmp_json
[
'points'
]
=
box
.
tolist
(
)
tmp_json
[
'points'
]
=
list
(
box
)
dt_boxes_list
.
append
(
tmp_json
)
det_box_json
[
k
]
=
dt_boxes_list
save_det_path
=
os
.
path
.
dirname
(
config
[
'Global'
][
...
...
@@ -118,7 +118,7 @@ def main():
# write result
for
box
in
boxes
:
tmp_json
=
{
"transcription"
:
""
}
tmp_json
[
'points'
]
=
box
.
tolist
(
)
tmp_json
[
'points'
]
=
list
(
box
)
dt_boxes_json
.
append
(
tmp_json
)
save_det_path
=
os
.
path
.
dirname
(
config
[
'Global'
][
'save_res_path'
])
+
"/det_results/"
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录