Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleOCR
提交
42fe769c
P
PaddleOCR
项目概览
s920243400
/
PaddleOCR
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleOCR
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
42fe769c
编写于
8月 03, 2021
作者:
M
MissPenguin
提交者:
GitHub
8月 03, 2021
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #3506 from LDOUBLEV/fix_dyg_bugs
cherry-pick 3505
上级
a50b747b
c620b939
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
40 addition
and
28 deletion
+40
-28
ppocr/data/imaug/label_ops.py
ppocr/data/imaug/label_ops.py
+37
-26
tools/infer/utility.py
tools/infer/utility.py
+3
-2
未找到文件。
ppocr/data/imaug/label_ops.py
浏览文件 @
42fe769c
...
@@ -19,6 +19,7 @@ from __future__ import unicode_literals
...
@@ -19,6 +19,7 @@ from __future__ import unicode_literals
import
numpy
as
np
import
numpy
as
np
import
string
import
string
import
json
class
ClsLabelEncode
(
object
):
class
ClsLabelEncode
(
object
):
...
@@ -39,7 +40,6 @@ class DetLabelEncode(object):
...
@@ -39,7 +40,6 @@ class DetLabelEncode(object):
pass
pass
def
__call__
(
self
,
data
):
def
__call__
(
self
,
data
):
import
json
label
=
data
[
'label'
]
label
=
data
[
'label'
]
label
=
json
.
loads
(
label
)
label
=
json
.
loads
(
label
)
nBox
=
len
(
label
)
nBox
=
len
(
label
)
...
@@ -53,6 +53,8 @@ class DetLabelEncode(object):
...
@@ -53,6 +53,8 @@ class DetLabelEncode(object):
txt_tags
.
append
(
True
)
txt_tags
.
append
(
True
)
else
:
else
:
txt_tags
.
append
(
False
)
txt_tags
.
append
(
False
)
if
len
(
boxes
)
==
0
:
return
None
boxes
=
self
.
expand_points_num
(
boxes
)
boxes
=
self
.
expand_points_num
(
boxes
)
boxes
=
np
.
array
(
boxes
,
dtype
=
np
.
float32
)
boxes
=
np
.
array
(
boxes
,
dtype
=
np
.
float32
)
txt_tags
=
np
.
array
(
txt_tags
,
dtype
=
np
.
bool
)
txt_tags
=
np
.
array
(
txt_tags
,
dtype
=
np
.
bool
)
...
@@ -352,19 +354,22 @@ class SRNLabelEncode(BaseRecLabelEncode):
...
@@ -352,19 +354,22 @@ class SRNLabelEncode(BaseRecLabelEncode):
%
beg_or_end
%
beg_or_end
return
idx
return
idx
class
TableLabelEncode
(
object
):
class
TableLabelEncode
(
object
):
""" Convert between text-label and text-index """
""" Convert between text-label and text-index """
def
__init__
(
self
,
max_text_length
,
def
__init__
(
self
,
max_elem_length
,
max_text_length
,
max_cell_num
,
max_elem_length
,
character_dict_path
,
max_cell_num
,
span_weight
=
1.0
,
character_dict_path
,
**
kwargs
):
span_weight
=
1.0
,
**
kwargs
):
self
.
max_text_length
=
max_text_length
self
.
max_text_length
=
max_text_length
self
.
max_elem_length
=
max_elem_length
self
.
max_elem_length
=
max_elem_length
self
.
max_cell_num
=
max_cell_num
self
.
max_cell_num
=
max_cell_num
list_character
,
list_elem
=
self
.
load_char_elem_dict
(
character_dict_path
)
list_character
,
list_elem
=
self
.
load_char_elem_dict
(
character_dict_path
)
list_character
=
self
.
add_special_char
(
list_character
)
list_character
=
self
.
add_special_char
(
list_character
)
list_elem
=
self
.
add_special_char
(
list_elem
)
list_elem
=
self
.
add_special_char
(
list_elem
)
self
.
dict_character
=
{}
self
.
dict_character
=
{}
...
@@ -374,7 +379,7 @@ class TableLabelEncode(object):
...
@@ -374,7 +379,7 @@ class TableLabelEncode(object):
for
i
,
elem
in
enumerate
(
list_elem
):
for
i
,
elem
in
enumerate
(
list_elem
):
self
.
dict_elem
[
elem
]
=
i
self
.
dict_elem
[
elem
]
=
i
self
.
span_weight
=
span_weight
self
.
span_weight
=
span_weight
def
load_char_elem_dict
(
self
,
character_dict_path
):
def
load_char_elem_dict
(
self
,
character_dict_path
):
list_character
=
[]
list_character
=
[]
list_elem
=
[]
list_elem
=
[]
...
@@ -383,27 +388,28 @@ class TableLabelEncode(object):
...
@@ -383,27 +388,28 @@ class TableLabelEncode(object):
substr
=
lines
[
0
].
decode
(
'utf-8'
).
strip
(
"
\r\n
"
).
split
(
"
\t
"
)
substr
=
lines
[
0
].
decode
(
'utf-8'
).
strip
(
"
\r\n
"
).
split
(
"
\t
"
)
character_num
=
int
(
substr
[
0
])
character_num
=
int
(
substr
[
0
])
elem_num
=
int
(
substr
[
1
])
elem_num
=
int
(
substr
[
1
])
for
cno
in
range
(
1
,
1
+
character_num
):
for
cno
in
range
(
1
,
1
+
character_num
):
character
=
lines
[
cno
].
decode
(
'utf-8'
).
strip
(
"
\r\n
"
)
character
=
lines
[
cno
].
decode
(
'utf-8'
).
strip
(
"
\r\n
"
)
list_character
.
append
(
character
)
list_character
.
append
(
character
)
for
eno
in
range
(
1
+
character_num
,
1
+
character_num
+
elem_num
):
for
eno
in
range
(
1
+
character_num
,
1
+
character_num
+
elem_num
):
elem
=
lines
[
eno
].
decode
(
'utf-8'
).
strip
(
"
\r\n
"
)
elem
=
lines
[
eno
].
decode
(
'utf-8'
).
strip
(
"
\r\n
"
)
list_elem
.
append
(
elem
)
list_elem
.
append
(
elem
)
return
list_character
,
list_elem
return
list_character
,
list_elem
def
add_special_char
(
self
,
list_character
):
def
add_special_char
(
self
,
list_character
):
self
.
beg_str
=
"sos"
self
.
beg_str
=
"sos"
self
.
end_str
=
"eos"
self
.
end_str
=
"eos"
list_character
=
[
self
.
beg_str
]
+
list_character
+
[
self
.
end_str
]
list_character
=
[
self
.
beg_str
]
+
list_character
+
[
self
.
end_str
]
return
list_character
return
list_character
def
get_span_idx_list
(
self
):
def
get_span_idx_list
(
self
):
span_idx_list
=
[]
span_idx_list
=
[]
for
elem
in
self
.
dict_elem
:
for
elem
in
self
.
dict_elem
:
if
'span'
in
elem
:
if
'span'
in
elem
:
span_idx_list
.
append
(
self
.
dict_elem
[
elem
])
span_idx_list
.
append
(
self
.
dict_elem
[
elem
])
return
span_idx_list
return
span_idx_list
def
__call__
(
self
,
data
):
def
__call__
(
self
,
data
):
cells
=
data
[
'cells'
]
cells
=
data
[
'cells'
]
structure
=
data
[
'structure'
][
'tokens'
]
structure
=
data
[
'structure'
][
'tokens'
]
...
@@ -412,18 +418,22 @@ class TableLabelEncode(object):
...
@@ -412,18 +418,22 @@ class TableLabelEncode(object):
return
None
return
None
elem_num
=
len
(
structure
)
elem_num
=
len
(
structure
)
structure
=
[
0
]
+
structure
+
[
len
(
self
.
dict_elem
)
-
1
]
structure
=
[
0
]
+
structure
+
[
len
(
self
.
dict_elem
)
-
1
]
structure
=
structure
+
[
0
]
*
(
self
.
max_elem_length
+
2
-
len
(
structure
))
structure
=
structure
+
[
0
]
*
(
self
.
max_elem_length
+
2
-
len
(
structure
)
)
structure
=
np
.
array
(
structure
)
structure
=
np
.
array
(
structure
)
data
[
'structure'
]
=
structure
data
[
'structure'
]
=
structure
elem_char_idx1
=
self
.
dict_elem
[
'<td>'
]
elem_char_idx1
=
self
.
dict_elem
[
'<td>'
]
elem_char_idx2
=
self
.
dict_elem
[
'<td'
]
elem_char_idx2
=
self
.
dict_elem
[
'<td'
]
span_idx_list
=
self
.
get_span_idx_list
()
span_idx_list
=
self
.
get_span_idx_list
()
td_idx_list
=
np
.
logical_or
(
structure
==
elem_char_idx1
,
structure
==
elem_char_idx2
)
td_idx_list
=
np
.
logical_or
(
structure
==
elem_char_idx1
,
structure
==
elem_char_idx2
)
td_idx_list
=
np
.
where
(
td_idx_list
)[
0
]
td_idx_list
=
np
.
where
(
td_idx_list
)[
0
]
structure_mask
=
np
.
ones
((
self
.
max_elem_length
+
2
,
1
),
dtype
=
np
.
float32
)
structure_mask
=
np
.
ones
(
(
self
.
max_elem_length
+
2
,
1
),
dtype
=
np
.
float32
)
bbox_list
=
np
.
zeros
((
self
.
max_elem_length
+
2
,
4
),
dtype
=
np
.
float32
)
bbox_list
=
np
.
zeros
((
self
.
max_elem_length
+
2
,
4
),
dtype
=
np
.
float32
)
bbox_list_mask
=
np
.
zeros
((
self
.
max_elem_length
+
2
,
1
),
dtype
=
np
.
float32
)
bbox_list_mask
=
np
.
zeros
(
(
self
.
max_elem_length
+
2
,
1
),
dtype
=
np
.
float32
)
img_height
,
img_width
,
img_ch
=
data
[
'image'
].
shape
img_height
,
img_width
,
img_ch
=
data
[
'image'
].
shape
if
len
(
span_idx_list
)
>
0
:
if
len
(
span_idx_list
)
>
0
:
span_weight
=
len
(
td_idx_list
)
*
1.0
/
len
(
span_idx_list
)
span_weight
=
len
(
td_idx_list
)
*
1.0
/
len
(
span_idx_list
)
...
@@ -450,9 +460,11 @@ class TableLabelEncode(object):
...
@@ -450,9 +460,11 @@ class TableLabelEncode(object):
char_end_idx
=
self
.
get_beg_end_flag_idx
(
'end'
,
'char'
)
char_end_idx
=
self
.
get_beg_end_flag_idx
(
'end'
,
'char'
)
elem_beg_idx
=
self
.
get_beg_end_flag_idx
(
'beg'
,
'elem'
)
elem_beg_idx
=
self
.
get_beg_end_flag_idx
(
'beg'
,
'elem'
)
elem_end_idx
=
self
.
get_beg_end_flag_idx
(
'end'
,
'elem'
)
elem_end_idx
=
self
.
get_beg_end_flag_idx
(
'end'
,
'elem'
)
data
[
'sp_tokens'
]
=
np
.
array
([
char_beg_idx
,
char_end_idx
,
elem_beg_idx
,
data
[
'sp_tokens'
]
=
np
.
array
([
elem_end_idx
,
elem_char_idx1
,
elem_char_idx2
,
self
.
max_text_length
,
char_beg_idx
,
char_end_idx
,
elem_beg_idx
,
elem_end_idx
,
self
.
max_elem_length
,
self
.
max_cell_num
,
elem_num
])
elem_char_idx1
,
elem_char_idx2
,
self
.
max_text_length
,
self
.
max_elem_length
,
self
.
max_cell_num
,
elem_num
])
return
data
return
data
def
encode
(
self
,
text
,
char_or_elem
):
def
encode
(
self
,
text
,
char_or_elem
):
...
@@ -504,9 +516,8 @@ class TableLabelEncode(object):
...
@@ -504,9 +516,8 @@ class TableLabelEncode(object):
idx
=
np
.
array
(
self
.
dict_elem
[
self
.
end_str
])
idx
=
np
.
array
(
self
.
dict_elem
[
self
.
end_str
])
else
:
else
:
assert
False
,
"Unsupport type %s in get_beg_end_flag_idx of elem"
\
assert
False
,
"Unsupport type %s in get_beg_end_flag_idx of elem"
\
%
beg_or_end
%
beg_or_end
else
:
else
:
assert
False
,
"Unsupport type %s in char_or_elem"
\
assert
False
,
"Unsupport type %s in char_or_elem"
\
%
char_or_elem
%
char_or_elem
return
idx
return
idx
\ No newline at end of file
tools/infer/utility.py
浏览文件 @
42fe769c
...
@@ -24,6 +24,7 @@ from paddle import inference
...
@@ -24,6 +24,7 @@ from paddle import inference
import
time
import
time
from
ppocr.utils.logging
import
get_logger
from
ppocr.utils.logging
import
get_logger
def
str2bool
(
v
):
def
str2bool
(
v
):
return
v
.
lower
()
in
(
"true"
,
"t"
,
"1"
)
return
v
.
lower
()
in
(
"true"
,
"t"
,
"1"
)
...
@@ -47,8 +48,8 @@ def init_args():
...
@@ -47,8 +48,8 @@ def init_args():
# DB parmas
# DB parmas
parser
.
add_argument
(
"--det_db_thresh"
,
type
=
float
,
default
=
0.3
)
parser
.
add_argument
(
"--det_db_thresh"
,
type
=
float
,
default
=
0.3
)
parser
.
add_argument
(
"--det_db_box_thresh"
,
type
=
float
,
default
=
0.
5
)
parser
.
add_argument
(
"--det_db_box_thresh"
,
type
=
float
,
default
=
0.
6
)
parser
.
add_argument
(
"--det_db_unclip_ratio"
,
type
=
float
,
default
=
1.
6
)
parser
.
add_argument
(
"--det_db_unclip_ratio"
,
type
=
float
,
default
=
1.
5
)
parser
.
add_argument
(
"--max_batch_size"
,
type
=
int
,
default
=
10
)
parser
.
add_argument
(
"--max_batch_size"
,
type
=
int
,
default
=
10
)
parser
.
add_argument
(
"--use_dilation"
,
type
=
bool
,
default
=
False
)
parser
.
add_argument
(
"--use_dilation"
,
type
=
bool
,
default
=
False
)
parser
.
add_argument
(
"--det_db_score_mode"
,
type
=
str
,
default
=
"fast"
)
parser
.
add_argument
(
"--det_db_score_mode"
,
type
=
str
,
default
=
"fast"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录