Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleOCR
提交
330f08ff
P
PaddleOCR
项目概览
s920243400
/
PaddleOCR
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleOCR
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
330f08ff
编写于
6月 10, 2021
作者:
W
WenmuZhou
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix table infer bug
上级
e836ab7f
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
252 addition
and
5 deletion
+252
-5
ppocr/data/imaug/__init__.py
ppocr/data/imaug/__init__.py
+1
-0
ppocr/data/imaug/gen_table_mask.py
ppocr/data/imaug/gen_table_mask.py
+244
-0
ppocr/postprocess/__init__.py
ppocr/postprocess/__init__.py
+3
-2
ppstructure/table/predict_structure.py
ppstructure/table/predict_structure.py
+3
-2
ppstructure/table/predict_table.py
ppstructure/table/predict_table.py
+1
-1
未找到文件。
ppocr/data/imaug/__init__.py
浏览文件 @
330f08ff
...
@@ -29,6 +29,7 @@ from .label_ops import *
...
@@ -29,6 +29,7 @@ from .label_ops import *
from
.east_process
import
*
from
.east_process
import
*
from
.sast_process
import
*
from
.sast_process
import
*
from
.pg_process
import
*
from
.pg_process
import
*
from
.gen_table_mask
import
*
def
transform
(
data
,
ops
=
None
):
def
transform
(
data
,
ops
=
None
):
...
...
ppocr/data/imaug/gen_table_mask.py
0 → 100644
浏览文件 @
330f08ff
"""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
unicode_literals
import
sys
import
six
import
cv2
import
numpy
as
np
class
GenTableMask
(
object
):
""" gen table mask """
def
__init__
(
self
,
shrink_h_max
,
shrink_w_max
,
mask_type
=
0
,
**
kwargs
):
self
.
shrink_h_max
=
5
self
.
shrink_w_max
=
5
self
.
mask_type
=
mask_type
def
projection
(
self
,
erosion
,
h
,
w
,
spilt_threshold
=
0
):
# 水平投影
projection_map
=
np
.
ones_like
(
erosion
)
project_val_array
=
[
0
for
_
in
range
(
0
,
h
)]
for
j
in
range
(
0
,
h
):
for
i
in
range
(
0
,
w
):
if
erosion
[
j
,
i
]
==
255
:
project_val_array
[
j
]
+=
1
# 根据数组,获取切割点
start_idx
=
0
# 记录进入字符区的索引
end_idx
=
0
# 记录进入空白区域的索引
in_text
=
False
# 是否遍历到了字符区内
box_list
=
[]
for
i
in
range
(
len
(
project_val_array
)):
if
in_text
==
False
and
project_val_array
[
i
]
>
spilt_threshold
:
# 进入字符区了
in_text
=
True
start_idx
=
i
elif
project_val_array
[
i
]
<=
spilt_threshold
and
in_text
==
True
:
# 进入空白区了
end_idx
=
i
in_text
=
False
if
end_idx
-
start_idx
<=
2
:
continue
box_list
.
append
((
start_idx
,
end_idx
+
1
))
if
in_text
:
box_list
.
append
((
start_idx
,
h
-
1
))
# 绘制投影直方图
for
j
in
range
(
0
,
h
):
for
i
in
range
(
0
,
project_val_array
[
j
]):
projection_map
[
j
,
i
]
=
0
return
box_list
,
projection_map
def
projection_cx
(
self
,
box_img
):
box_gray_img
=
cv2
.
cvtColor
(
box_img
,
cv2
.
COLOR_BGR2GRAY
)
h
,
w
=
box_gray_img
.
shape
# 灰度图片进行二值化处理
ret
,
thresh1
=
cv2
.
threshold
(
box_gray_img
,
200
,
255
,
cv2
.
THRESH_BINARY_INV
)
# 纵向腐蚀
if
h
<
w
:
kernel
=
np
.
ones
((
2
,
1
),
np
.
uint8
)
erode
=
cv2
.
erode
(
thresh1
,
kernel
,
iterations
=
1
)
else
:
erode
=
thresh1
# 水平膨胀
kernel
=
np
.
ones
((
1
,
5
),
np
.
uint8
)
erosion
=
cv2
.
dilate
(
erode
,
kernel
,
iterations
=
1
)
# 水平投影
projection_map
=
np
.
ones_like
(
erosion
)
project_val_array
=
[
0
for
_
in
range
(
0
,
h
)]
for
j
in
range
(
0
,
h
):
for
i
in
range
(
0
,
w
):
if
erosion
[
j
,
i
]
==
255
:
project_val_array
[
j
]
+=
1
# 根据数组,获取切割点
start_idx
=
0
# 记录进入字符区的索引
end_idx
=
0
# 记录进入空白区域的索引
in_text
=
False
# 是否遍历到了字符区内
box_list
=
[]
spilt_threshold
=
0
for
i
in
range
(
len
(
project_val_array
)):
if
in_text
==
False
and
project_val_array
[
i
]
>
spilt_threshold
:
# 进入字符区了
in_text
=
True
start_idx
=
i
elif
project_val_array
[
i
]
<=
spilt_threshold
and
in_text
==
True
:
# 进入空白区了
end_idx
=
i
in_text
=
False
if
end_idx
-
start_idx
<=
2
:
continue
box_list
.
append
((
start_idx
,
end_idx
+
1
))
if
in_text
:
box_list
.
append
((
start_idx
,
h
-
1
))
# 绘制投影直方图
for
j
in
range
(
0
,
h
):
for
i
in
range
(
0
,
project_val_array
[
j
]):
projection_map
[
j
,
i
]
=
0
split_bbox_list
=
[]
if
len
(
box_list
)
>
1
:
for
i
,
(
h_start
,
h_end
)
in
enumerate
(
box_list
):
if
i
==
0
:
h_start
=
0
if
i
==
len
(
box_list
):
h_end
=
h
word_img
=
erosion
[
h_start
:
h_end
+
1
,
:]
word_h
,
word_w
=
word_img
.
shape
w_split_list
,
w_projection_map
=
self
.
projection
(
word_img
.
T
,
word_w
,
word_h
)
w_start
,
w_end
=
w_split_list
[
0
][
0
],
w_split_list
[
-
1
][
1
]
if
h_start
>
0
:
h_start
-=
1
h_end
+=
1
word_img
=
box_img
[
h_start
:
h_end
+
1
:,
w_start
:
w_end
+
1
,
:]
split_bbox_list
.
append
([
w_start
,
h_start
,
w_end
,
h_end
])
else
:
split_bbox_list
.
append
([
0
,
0
,
w
,
h
])
return
split_bbox_list
def
shrink_bbox
(
self
,
bbox
):
left
,
top
,
right
,
bottom
=
bbox
sh_h
=
min
(
max
(
int
((
bottom
-
top
)
*
0.1
),
1
),
self
.
shrink_h_max
)
sh_w
=
min
(
max
(
int
((
right
-
left
)
*
0.1
),
1
),
self
.
shrink_w_max
)
left_new
=
left
+
sh_w
right_new
=
right
-
sh_w
top_new
=
top
+
sh_h
bottom_new
=
bottom
-
sh_h
if
left_new
>=
right_new
:
left_new
=
left
right_new
=
right
if
top_new
>=
bottom_new
:
top_new
=
top
bottom_new
=
bottom
return
[
left_new
,
top_new
,
right_new
,
bottom_new
]
def
__call__
(
self
,
data
):
img
=
data
[
'image'
]
cells
=
data
[
'cells'
]
height
,
width
=
img
.
shape
[
0
:
2
]
if
self
.
mask_type
==
1
:
mask_img
=
np
.
zeros
((
height
,
width
),
dtype
=
np
.
float32
)
else
:
mask_img
=
np
.
zeros
((
height
,
width
,
3
),
dtype
=
np
.
float32
)
cell_num
=
len
(
cells
)
for
cno
in
range
(
cell_num
):
if
"bbox"
in
cells
[
cno
]:
bbox
=
cells
[
cno
][
'bbox'
]
left
,
top
,
right
,
bottom
=
bbox
box_img
=
img
[
top
:
bottom
,
left
:
right
,
:].
copy
()
split_bbox_list
=
self
.
projection_cx
(
box_img
)
for
sno
in
range
(
len
(
split_bbox_list
)):
split_bbox_list
[
sno
][
0
]
+=
left
split_bbox_list
[
sno
][
1
]
+=
top
split_bbox_list
[
sno
][
2
]
+=
left
split_bbox_list
[
sno
][
3
]
+=
top
for
sno
in
range
(
len
(
split_bbox_list
)):
left
,
top
,
right
,
bottom
=
split_bbox_list
[
sno
]
left
,
top
,
right
,
bottom
=
self
.
shrink_bbox
([
left
,
top
,
right
,
bottom
])
if
self
.
mask_type
==
1
:
mask_img
[
top
:
bottom
,
left
:
right
]
=
1.0
data
[
'mask_img'
]
=
mask_img
else
:
mask_img
[
top
:
bottom
,
left
:
right
,
:]
=
(
255
,
255
,
255
)
data
[
'image'
]
=
mask_img
return
data
class
ResizeTableImage
(
object
):
def
__init__
(
self
,
max_len
,
**
kwargs
):
super
(
ResizeTableImage
,
self
).
__init__
()
self
.
max_len
=
max_len
def
get_img_bbox
(
self
,
cells
):
bbox_list
=
[]
if
len
(
cells
)
==
0
:
return
bbox_list
cell_num
=
len
(
cells
)
for
cno
in
range
(
cell_num
):
if
"bbox"
in
cells
[
cno
]:
bbox
=
cells
[
cno
][
'bbox'
]
bbox_list
.
append
(
bbox
)
return
bbox_list
def
resize_img_table
(
self
,
img
,
bbox_list
,
max_len
):
height
,
width
=
img
.
shape
[
0
:
2
]
ratio
=
max_len
/
(
max
(
height
,
width
)
*
1.0
)
resize_h
=
int
(
height
*
ratio
)
resize_w
=
int
(
width
*
ratio
)
img_new
=
cv2
.
resize
(
img
,
(
resize_w
,
resize_h
))
bbox_list_new
=
[]
for
bno
in
range
(
len
(
bbox_list
)):
left
,
top
,
right
,
bottom
=
bbox_list
[
bno
].
copy
()
left
=
int
(
left
*
ratio
)
top
=
int
(
top
*
ratio
)
right
=
int
(
right
*
ratio
)
bottom
=
int
(
bottom
*
ratio
)
bbox_list_new
.
append
([
left
,
top
,
right
,
bottom
])
return
img_new
,
bbox_list_new
def
__call__
(
self
,
data
):
img
=
data
[
'image'
]
if
'cells'
not
in
data
:
cells
=
[]
else
:
cells
=
data
[
'cells'
]
bbox_list
=
self
.
get_img_bbox
(
cells
)
img_new
,
bbox_list_new
=
self
.
resize_img_table
(
img
,
bbox_list
,
self
.
max_len
)
data
[
'image'
]
=
img_new
cell_num
=
len
(
cells
)
bno
=
0
for
cno
in
range
(
cell_num
):
if
"bbox"
in
data
[
'cells'
][
cno
]:
data
[
'cells'
][
cno
][
'bbox'
]
=
bbox_list_new
[
bno
]
bno
+=
1
data
[
'max_len'
]
=
self
.
max_len
return
data
class
PaddingTableImage
(
object
):
def
__init__
(
self
,
**
kwargs
):
super
(
PaddingTableImage
,
self
).
__init__
()
def
__call__
(
self
,
data
):
img
=
data
[
'image'
]
max_len
=
data
[
'max_len'
]
padding_img
=
np
.
zeros
((
max_len
,
max_len
,
3
),
dtype
=
np
.
float32
)
height
,
width
=
img
.
shape
[
0
:
2
]
padding_img
[
0
:
height
,
0
:
width
,
:]
=
img
.
copy
()
data
[
'image'
]
=
padding_img
return
data
\ No newline at end of file
ppocr/postprocess/__init__.py
浏览文件 @
330f08ff
...
@@ -24,7 +24,8 @@ __all__ = ['build_post_process']
...
@@ -24,7 +24,8 @@ __all__ = ['build_post_process']
from
.db_postprocess
import
DBPostProcess
from
.db_postprocess
import
DBPostProcess
from
.east_postprocess
import
EASTPostProcess
from
.east_postprocess
import
EASTPostProcess
from
.sast_postprocess
import
SASTPostProcess
from
.sast_postprocess
import
SASTPostProcess
from
.rec_postprocess
import
CTCLabelDecode
,
AttnLabelDecode
,
SRNLabelDecode
,
DistillationCTCLabelDecode
from
.rec_postprocess
import
CTCLabelDecode
,
AttnLabelDecode
,
SRNLabelDecode
,
DistillationCTCLabelDecode
,
\
TableLabelDecode
from
.cls_postprocess
import
ClsPostProcess
from
.cls_postprocess
import
ClsPostProcess
from
.pg_postprocess
import
PGPostProcess
from
.pg_postprocess
import
PGPostProcess
...
@@ -33,7 +34,7 @@ def build_post_process(config, global_config=None):
...
@@ -33,7 +34,7 @@ def build_post_process(config, global_config=None):
support_dict
=
[
support_dict
=
[
'DBPostProcess'
,
'EASTPostProcess'
,
'SASTPostProcess'
,
'CTCLabelDecode'
,
'DBPostProcess'
,
'EASTPostProcess'
,
'SASTPostProcess'
,
'CTCLabelDecode'
,
'AttnLabelDecode'
,
'ClsPostProcess'
,
'SRNLabelDecode'
,
'PGPostProcess'
,
'AttnLabelDecode'
,
'ClsPostProcess'
,
'SRNLabelDecode'
,
'PGPostProcess'
,
'DistillationCTCLabelDecode'
'DistillationCTCLabelDecode'
,
'TableLabelDecode'
]
]
config
=
copy
.
deepcopy
(
config
)
config
=
copy
.
deepcopy
(
config
)
...
...
ppstructure/table/predict_structure.py
浏览文件 @
330f08ff
...
@@ -32,6 +32,7 @@ from ppocr.data import create_operators, transform
...
@@ -32,6 +32,7 @@ from ppocr.data import create_operators, transform
from
ppocr.postprocess
import
build_post_process
from
ppocr.postprocess
import
build_post_process
from
ppocr.utils.logging
import
get_logger
from
ppocr.utils.logging
import
get_logger
from
ppocr.utils.utility
import
get_image_file_list
,
check_and_read_gif
from
ppocr.utils.utility
import
get_image_file_list
,
check_and_read_gif
from
ppstructure.utility
import
parse_args
logger
=
get_logger
()
logger
=
get_logger
()
...
@@ -69,7 +70,7 @@ class TableStructurer(object):
...
@@ -69,7 +70,7 @@ class TableStructurer(object):
self
.
preprocess_op
=
create_operators
(
pre_process_list
)
self
.
preprocess_op
=
create_operators
(
pre_process_list
)
self
.
postprocess_op
=
build_post_process
(
postprocess_params
)
self
.
postprocess_op
=
build_post_process
(
postprocess_params
)
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
=
\
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
,
self
.
config
=
\
utility
.
create_predictor
(
args
,
'structure'
,
logger
)
utility
.
create_predictor
(
args
,
'structure'
,
logger
)
def
__call__
(
self
,
img
):
def
__call__
(
self
,
img
):
...
@@ -138,4 +139,4 @@ def main(args):
...
@@ -138,4 +139,4 @@ def main(args):
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
main
(
utility
.
parse_args
())
main
(
parse_args
())
ppstructure/table/predict_table.py
浏览文件 @
330f08ff
...
@@ -187,7 +187,7 @@ def main(args):
...
@@ -187,7 +187,7 @@ def main(args):
for
i
,
image_file
in
enumerate
(
image_file_list
):
for
i
,
image_file
in
enumerate
(
image_file_list
):
logger
.
info
(
"[{}/{}] {}"
.
format
(
i
,
img_num
,
image_file
))
logger
.
info
(
"[{}/{}] {}"
.
format
(
i
,
img_num
,
image_file
))
img
,
flag
=
check_and_read_gif
(
image_file
)
img
,
flag
=
check_and_read_gif
(
image_file
)
excel_path
=
os
.
path
.
join
(
args
.
table_
output
,
os
.
path
.
basename
(
image_file
).
split
(
'.'
)[
0
]
+
'.xlsx'
)
excel_path
=
os
.
path
.
join
(
args
.
output
,
os
.
path
.
basename
(
image_file
).
split
(
'.'
)[
0
]
+
'.xlsx'
)
if
not
flag
:
if
not
flag
:
img
=
cv2
.
imread
(
image_file
)
img
=
cv2
.
imread
(
image_file
)
if
img
is
None
:
if
img
is
None
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录