Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
c455034f
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c455034f
编写于
4月 16, 2021
作者:
M
MissPenguin
提交者:
GitHub
4月 16, 2021
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #2497 from JetHong/rel/fix_data_input_format
Rel/fix data input format
上级
46236706
68e6a362
变更
11
显示空白变更内容
内联
并排
Showing
11 changed file
with
70 addition
and
152 deletion
+70
-152
configs/e2e/e2e_r50_vd_pg.yml
configs/e2e/e2e_r50_vd_pg.yml
+6
-6
doc/doc_ch/pgnet.md
doc/doc_ch/pgnet.md
+1
-0
ppocr/data/imaug/label_ops.py
ppocr/data/imaug/label_ops.py
+23
-21
ppocr/data/imaug/pg_process.py
ppocr/data/imaug/pg_process.py
+4
-4
ppocr/data/pgnet_dataset.py
ppocr/data/pgnet_dataset.py
+24
-92
ppocr/metrics/e2e_metric.py
ppocr/metrics/e2e_metric.py
+3
-3
ppocr/utils/e2e_metric/Deteval.py
ppocr/utils/e2e_metric/Deteval.py
+1
-1
ppocr/utils/e2e_utils/extract_textpoint_slow.py
ppocr/utils/e2e_utils/extract_textpoint_slow.py
+3
-1
ppocr/utils/e2e_utils/pgnet_pp_utils.py
ppocr/utils/e2e_utils/pgnet_pp_utils.py
+3
-22
tools/infer/predict_e2e.py
tools/infer/predict_e2e.py
+1
-1
tools/infer_e2e.py
tools/infer_e2e.py
+1
-1
未找到文件。
configs/e2e/e2e_r50_vd_pg.yml
浏览文件 @
c455034f
...
@@ -62,20 +62,21 @@ PostProcess:
...
@@ -62,20 +62,21 @@ PostProcess:
mode
:
fast
# fast or slow two ways
mode
:
fast
# fast or slow two ways
Metric
:
Metric
:
name
:
E2EMetric
name
:
E2EMetric
gt_mat_dir
:
# the dir of gt_mat
gt_mat_dir
:
./train_data/total_text/gt
# the dir of gt_mat
character_dict_path
:
ppocr/utils/ic15_dict.txt
character_dict_path
:
ppocr/utils/ic15_dict.txt
main_indicator
:
f_score_e2e
main_indicator
:
f_score_e2e
Train
:
Train
:
dataset
:
dataset
:
name
:
PGDataSet
name
:
PGDataSet
label_file_list
:
[
.././train_data/total_text/train/
]
data_dir
:
./train_data/total_text/train
label_file_list
:
[
./train_data/total_text/train/
]
ratio_list
:
[
1.0
]
ratio_list
:
[
1.0
]
data_format
:
icdar
#two data format: icdar/textnet
transforms
:
transforms
:
-
DecodeImage
:
# load image
-
DecodeImage
:
# load image
img_mode
:
BGR
img_mode
:
BGR
channel_first
:
False
channel_first
:
False
-
E2ELabelEncode
:
-
PGProcessTrain
:
-
PGProcessTrain
:
batch_size
:
14
# same as loader: batch_size_per_card
batch_size
:
14
# same as loader: batch_size_per_card
min_crop_size
:
24
min_crop_size
:
24
...
@@ -92,13 +93,12 @@ Train:
...
@@ -92,13 +93,12 @@ Train:
Eval
:
Eval
:
dataset
:
dataset
:
name
:
PGDataSet
name
:
PGDataSet
data_dir
:
./train_data/
data_dir
:
./train_data/
total_text/test
label_file_list
:
[
./train_data/total_text/test/
]
label_file_list
:
[
./train_data/total_text/test/
]
transforms
:
transforms
:
-
DecodeImage
:
# load image
-
DecodeImage
:
# load image
img_mode
:
RGB
img_mode
:
RGB
channel_first
:
False
channel_first
:
False
-
E2ELabelEncode
:
-
E2EResizeForTest
:
-
E2EResizeForTest
:
max_side_len
:
768
max_side_len
:
768
-
NormalizeImage
:
-
NormalizeImage
:
...
@@ -108,7 +108,7 @@ Eval:
...
@@ -108,7 +108,7 @@ Eval:
order
:
'
hwc'
order
:
'
hwc'
-
ToCHWImage
:
-
ToCHWImage
:
-
KeepKeys
:
-
KeepKeys
:
keep_keys
:
[
'
image'
,
'
shape'
,
'
polys'
,
'
strs'
,
'
tags'
,
'
img_id'
]
keep_keys
:
[
'
image'
,
'
shape'
,
'
img_id'
]
loader
:
loader
:
shuffle
:
False
shuffle
:
False
drop_last
:
False
drop_last
:
False
...
...
doc/doc_ch/pgnet.md
浏览文件 @
c455034f
...
@@ -30,6 +30,7 @@ PGNet算法细节详见[论文](https://www.aaai.org/AAAI21Papers/AAAI-2885.Wang
...
@@ -30,6 +30,7 @@ PGNet算法细节详见[论文](https://www.aaai.org/AAAI21Papers/AAAI-2885.Wang
测试集:Total-Text
测试集:Total-Text
测试环境: NVIDIA Tesla V100-SXM2-16GB
测试环境: NVIDIA Tesla V100-SXM2-16GB
|PGNetA|det_precision|det_recall|det_f_score|e2e_precision|e2e_recall|e2e_f_score|FPS|下载|
|PGNetA|det_precision|det_recall|det_f_score|e2e_precision|e2e_recall|e2e_f_score|FPS|下载|
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
|Paper|85.30|86.80|86.1|-|-|61.7|38.20 (size=640)|-|
|Paper|85.30|86.80|86.1|-|-|61.7|38.20 (size=640)|-|
...
...
ppocr/data/imaug/label_ops.py
浏览文件 @
c455034f
...
@@ -187,29 +187,31 @@ class CTCLabelEncode(BaseRecLabelEncode):
...
@@ -187,29 +187,31 @@ class CTCLabelEncode(BaseRecLabelEncode):
return
dict_character
return
dict_character
class
E2ELabelEncode
(
BaseRecLabelEncode
):
class
E2ELabelEncode
(
object
):
def
__init__
(
self
,
def
__init__
(
self
,
**
kwargs
):
max_text_length
,
pass
character_dict_path
=
None
,
character_type
=
'EN'
,
use_space_char
=
False
,
**
kwargs
):
super
(
E2ELabelEncode
,
self
).
__init__
(
max_text_length
,
character_dict_path
,
character_type
,
use_space_char
)
self
.
pad_num
=
len
(
self
.
dict
)
# the length to pad
def
__call__
(
self
,
data
):
def
__call__
(
self
,
data
):
texts
=
data
[
'strs'
]
import
json
temp_texts
=
[]
label
=
data
[
'label'
]
for
text
in
texts
:
label
=
json
.
loads
(
label
)
text
=
text
.
lower
()
nBox
=
len
(
label
)
text
=
self
.
encode
(
text
)
boxes
,
txts
,
txt_tags
=
[],
[],
[]
if
text
is
None
:
for
bno
in
range
(
0
,
nBox
):
return
None
box
=
label
[
bno
][
'points'
]
text
=
text
+
[
self
.
pad_num
]
*
(
self
.
max_text_len
-
len
(
text
))
txt
=
label
[
bno
][
'transcription'
]
temp_texts
.
append
(
text
)
boxes
.
append
(
box
)
data
[
'strs'
]
=
np
.
array
(
temp_texts
)
txts
.
append
(
txt
)
if
txt
in
[
'*'
,
'###'
]:
txt_tags
.
append
(
True
)
else
:
txt_tags
.
append
(
False
)
boxes
=
np
.
array
(
boxes
,
dtype
=
np
.
float32
)
txt_tags
=
np
.
array
(
txt_tags
,
dtype
=
np
.
bool
)
data
[
'polys'
]
=
boxes
data
[
'texts'
]
=
txts
data
[
'ignore_tags'
]
=
txt_tags
return
data
return
data
...
...
ppocr/data/imaug/pg_process.py
浏览文件 @
c455034f
...
@@ -88,7 +88,7 @@ class PGProcessTrain(object):
...
@@ -88,7 +88,7 @@ class PGProcessTrain(object):
return
min_area_quad
return
min_area_quad
def
check_and_validate_polys
(
self
,
polys
,
tags
,
xxx_todo_changem
e
):
def
check_and_validate_polys
(
self
,
polys
,
tags
,
im_siz
e
):
"""
"""
check so that the text poly is in the same direction,
check so that the text poly is in the same direction,
and also filter some invalid polygons
and also filter some invalid polygons
...
@@ -96,7 +96,7 @@ class PGProcessTrain(object):
...
@@ -96,7 +96,7 @@ class PGProcessTrain(object):
:param tags:
:param tags:
:return:
:return:
"""
"""
(
h
,
w
)
=
xxx_todo_changem
e
(
h
,
w
)
=
im_siz
e
if
polys
.
shape
[
0
]
==
0
:
if
polys
.
shape
[
0
]
==
0
:
return
polys
,
np
.
array
([]),
np
.
array
([])
return
polys
,
np
.
array
([]),
np
.
array
([])
polys
[:,
:,
0
]
=
np
.
clip
(
polys
[:,
:,
0
],
0
,
w
-
1
)
polys
[:,
:,
0
]
=
np
.
clip
(
polys
[:,
:,
0
],
0
,
w
-
1
)
...
@@ -750,8 +750,8 @@ class PGProcessTrain(object):
...
@@ -750,8 +750,8 @@ class PGProcessTrain(object):
input_size
=
512
input_size
=
512
im
=
data
[
'image'
]
im
=
data
[
'image'
]
text_polys
=
data
[
'polys'
]
text_polys
=
data
[
'polys'
]
text_tags
=
data
[
'tags'
]
text_tags
=
data
[
'
ignore_
tags'
]
text_strs
=
data
[
'
str
s'
]
text_strs
=
data
[
'
text
s'
]
h
,
w
,
_
=
im
.
shape
h
,
w
,
_
=
im
.
shape
text_polys
,
text_tags
,
hv_tags
=
self
.
check_and_validate_polys
(
text_polys
,
text_tags
,
hv_tags
=
self
.
check_and_validate_polys
(
text_polys
,
text_tags
,
(
h
,
w
))
text_polys
,
text_tags
,
(
h
,
w
))
...
...
ppocr/data/pgnet_dataset.py
浏览文件 @
c455034f
...
@@ -29,20 +29,20 @@ class PGDataSet(Dataset):
...
@@ -29,20 +29,20 @@ class PGDataSet(Dataset):
dataset_config
=
config
[
mode
][
'dataset'
]
dataset_config
=
config
[
mode
][
'dataset'
]
loader_config
=
config
[
mode
][
'loader'
]
loader_config
=
config
[
mode
][
'loader'
]
self
.
delimiter
=
dataset_config
.
get
(
'delimiter'
,
'
\t
'
)
label_file_list
=
dataset_config
.
pop
(
'label_file_list'
)
label_file_list
=
dataset_config
.
pop
(
'label_file_list'
)
data_source_num
=
len
(
label_file_list
)
data_source_num
=
len
(
label_file_list
)
ratio_list
=
dataset_config
.
get
(
"ratio_list"
,
[
1.0
])
ratio_list
=
dataset_config
.
get
(
"ratio_list"
,
[
1.0
])
if
isinstance
(
ratio_list
,
(
float
,
int
)):
if
isinstance
(
ratio_list
,
(
float
,
int
)):
ratio_list
=
[
float
(
ratio_list
)]
*
int
(
data_source_num
)
ratio_list
=
[
float
(
ratio_list
)]
*
int
(
data_source_num
)
self
.
data_format
=
dataset_config
.
get
(
'data_format'
,
'icdar'
)
assert
len
(
assert
len
(
ratio_list
ratio_list
)
==
data_source_num
,
"The length of ratio_list should be the same as the file_list."
)
==
data_source_num
,
"The length of ratio_list should be the same as the file_list."
self
.
data_dir
=
dataset_config
[
'data_dir'
]
self
.
do_shuffle
=
loader_config
[
'shuffle'
]
self
.
do_shuffle
=
loader_config
[
'shuffle'
]
logger
.
info
(
"Initialize indexs of datasets:%s"
%
label_file_list
)
logger
.
info
(
"Initialize indexs of datasets:%s"
%
label_file_list
)
self
.
data_lines
=
self
.
get_image_info_list
(
label_file_list
,
ratio_list
,
self
.
data_lines
=
self
.
get_image_info_list
(
label_file_list
,
ratio_list
)
self
.
data_format
)
self
.
data_idx_order_list
=
list
(
range
(
len
(
self
.
data_lines
)))
self
.
data_idx_order_list
=
list
(
range
(
len
(
self
.
data_lines
)))
if
mode
.
lower
()
==
"train"
:
if
mode
.
lower
()
==
"train"
:
self
.
shuffle_data_random
()
self
.
shuffle_data_random
()
...
@@ -55,108 +55,40 @@ class PGDataSet(Dataset):
...
@@ -55,108 +55,40 @@ class PGDataSet(Dataset):
random
.
shuffle
(
self
.
data_lines
)
random
.
shuffle
(
self
.
data_lines
)
return
return
def
extract_polys
(
self
,
poly_txt_path
):
def
get_image_info_list
(
self
,
file_list
,
ratio_list
):
"""
Read text_polys, txt_tags, txts from give txt file.
"""
text_polys
,
txt_tags
,
txts
=
[],
[],
[]
with
open
(
poly_txt_path
)
as
f
:
for
line
in
f
.
readlines
():
poly_str
,
txt
=
line
.
strip
().
split
(
'
\t
'
)
poly
=
list
(
map
(
float
,
poly_str
.
split
(
','
)))
text_polys
.
append
(
np
.
array
(
poly
,
dtype
=
np
.
float32
).
reshape
(
-
1
,
2
))
txts
.
append
(
txt
)
txt_tags
.
append
(
txt
==
'###'
)
return
np
.
array
(
list
(
map
(
np
.
array
,
text_polys
))),
\
np
.
array
(
txt_tags
,
dtype
=
np
.
bool
),
txts
def
extract_info_textnet
(
self
,
im_fn
,
img_dir
=
''
):
"""
Extract information from line in textnet format.
"""
info_list
=
im_fn
.
split
(
'
\t
'
)
img_path
=
''
for
ext
in
[
'jpg'
,
'bmp'
,
'png'
,
'jpeg'
,
'rgb'
,
'tif'
,
'tiff'
,
'gif'
,
'JPG'
]:
if
os
.
path
.
exists
(
os
.
path
.
join
(
img_dir
,
info_list
[
0
]
+
"."
+
ext
)):
img_path
=
os
.
path
.
join
(
img_dir
,
info_list
[
0
]
+
"."
+
ext
)
break
if
img_path
==
''
:
print
(
'Image {0} NOT found in {1}, and it will be ignored.'
.
format
(
info_list
[
0
],
img_dir
))
nBox
=
(
len
(
info_list
)
-
1
)
//
9
wordBBs
,
txts
,
txt_tags
=
[],
[],
[]
for
n
in
range
(
0
,
nBox
):
wordBB
=
list
(
map
(
float
,
info_list
[
n
*
9
+
1
:(
n
+
1
)
*
9
]))
txt
=
info_list
[(
n
+
1
)
*
9
]
wordBBs
.
append
([[
wordBB
[
0
],
wordBB
[
1
]],
[
wordBB
[
2
],
wordBB
[
3
]],
[
wordBB
[
4
],
wordBB
[
5
]],
[
wordBB
[
6
],
wordBB
[
7
]]])
txts
.
append
(
txt
)
if
txt
==
'###'
:
txt_tags
.
append
(
True
)
else
:
txt_tags
.
append
(
False
)
return
img_path
,
np
.
array
(
wordBBs
,
dtype
=
np
.
float32
),
txt_tags
,
txts
def
get_image_info_list
(
self
,
file_list
,
ratio_list
,
data_format
=
'textnet'
):
if
isinstance
(
file_list
,
str
):
if
isinstance
(
file_list
,
str
):
file_list
=
[
file_list
]
file_list
=
[
file_list
]
data_lines
=
[]
data_lines
=
[]
for
idx
,
data_source
in
enumerate
(
file_list
):
for
idx
,
file
in
enumerate
(
file_list
):
image_files
=
[]
with
open
(
file
,
"rb"
)
as
f
:
if
data_format
==
'icdar'
:
lines
=
f
.
readlines
()
image_files
=
[(
data_source
,
x
)
for
x
in
if
self
.
mode
==
"train"
or
ratio_list
[
idx
]
<
1.0
:
os
.
listdir
(
os
.
path
.
join
(
data_source
,
'rgb'
))
if
x
.
split
(
'.'
)[
-
1
]
in
[
'jpg'
,
'bmp'
,
'png'
,
'jpeg'
,
'rgb'
,
'tif'
,
'tiff'
,
'gif'
,
'JPG'
]]
elif
data_format
==
'textnet'
:
with
open
(
data_source
)
as
f
:
image_files
=
[(
data_source
,
x
.
strip
())
for
x
in
f
.
readlines
()]
else
:
print
(
"Unrecognized data format..."
)
exit
(
-
1
)
random
.
seed
(
self
.
seed
)
random
.
seed
(
self
.
seed
)
image_files
=
random
.
sample
(
lines
=
random
.
sample
(
lines
,
image_files
,
round
(
len
(
image_fil
es
)
*
ratio_list
[
idx
]))
round
(
len
(
lin
es
)
*
ratio_list
[
idx
]))
data_lines
.
extend
(
image_fil
es
)
data_lines
.
extend
(
lin
es
)
return
data_lines
return
data_lines
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
file_idx
=
self
.
data_idx_order_list
[
idx
]
file_idx
=
self
.
data_idx_order_list
[
idx
]
data_
path
,
data_
line
=
self
.
data_lines
[
file_idx
]
data_line
=
self
.
data_lines
[
file_idx
]
try
:
try
:
if
self
.
data_format
==
'icdar'
:
data_line
=
data_line
.
decode
(
'utf-8'
)
im_path
=
os
.
path
.
join
(
data_path
,
'rgb'
,
data_line
)
substr
=
data_line
.
strip
(
"
\n
"
).
split
(
self
.
delimiter
)
poly_path
=
os
.
path
.
join
(
data_path
,
'poly'
,
file_name
=
substr
[
0
]
data_line
.
split
(
'.'
)[
0
]
+
'.txt'
)
label
=
substr
[
1
]
text_polys
,
text_tags
,
text_strs
=
self
.
extract_polys
(
poly_path
)
img_path
=
os
.
path
.
join
(
self
.
data_dir
,
file_name
)
if
self
.
mode
.
lower
()
==
'eval'
:
img_id
=
int
(
data_line
.
split
(
"."
)[
0
][
7
:])
else
:
else
:
image_dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
data_path
),
'image'
)
img_id
=
0
im_path
,
text_polys
,
text_tags
,
text_strs
=
self
.
extract_info_textnet
(
data
=
{
'img_path'
:
img_path
,
'label'
:
label
,
'img_id'
:
img_id
}
data_line
,
image_dir
)
if
not
os
.
path
.
exists
(
img_path
):
img_id
=
int
(
data_line
.
split
(
"."
)[
0
][
3
:])
raise
Exception
(
"{} does not exist!"
.
format
(
img_path
))
data
=
{
'img_path'
:
im_path
,
'polys'
:
text_polys
,
'tags'
:
text_tags
,
'strs'
:
text_strs
,
'img_id'
:
img_id
}
with
open
(
data
[
'img_path'
],
'rb'
)
as
f
:
with
open
(
data
[
'img_path'
],
'rb'
)
as
f
:
img
=
f
.
read
()
img
=
f
.
read
()
data
[
'image'
]
=
img
data
[
'image'
]
=
img
outs
=
transform
(
data
,
self
.
ops
)
outs
=
transform
(
data
,
self
.
ops
)
except
Exception
as
e
:
except
Exception
as
e
:
self
.
logger
.
error
(
self
.
logger
.
error
(
"When parsing line {}, error happened with msg: {}"
.
format
(
"When parsing line {}, error happened with msg: {}"
.
format
(
...
...
ppocr/metrics/e2e_metric.py
浏览文件 @
c455034f
...
@@ -35,11 +35,11 @@ class E2EMetric(object):
...
@@ -35,11 +35,11 @@ class E2EMetric(object):
self
.
reset
()
self
.
reset
()
def
__call__
(
self
,
preds
,
batch
,
**
kwargs
):
def
__call__
(
self
,
preds
,
batch
,
**
kwargs
):
img_id
=
batch
[
5
][
0
]
img_id
=
batch
[
2
][
0
]
e2e_info_list
=
[{
e2e_info_list
=
[{
'points'
:
det_polyon
,
'points'
:
det_polyon
,
'text'
:
pred_str
'text
s
'
:
pred_str
}
for
det_polyon
,
pred_str
in
zip
(
preds
[
'points'
],
preds
[
'
str
s'
])]
}
for
det_polyon
,
pred_str
in
zip
(
preds
[
'points'
],
preds
[
'
text
s'
])]
result
=
get_socre
(
self
.
gt_mat_dir
,
img_id
,
e2e_info_list
)
result
=
get_socre
(
self
.
gt_mat_dir
,
img_id
,
e2e_info_list
)
self
.
results
.
append
(
result
)
self
.
results
.
append
(
result
)
...
...
ppocr/utils/e2e_metric/Deteval.py
浏览文件 @
c455034f
...
@@ -26,7 +26,7 @@ def get_socre(gt_dir, img_id, pred_dict):
...
@@ -26,7 +26,7 @@ def get_socre(gt_dir, img_id, pred_dict):
n
=
len
(
pred_dict
)
n
=
len
(
pred_dict
)
for
i
in
range
(
n
):
for
i
in
range
(
n
):
points
=
pred_dict
[
i
][
'points'
]
points
=
pred_dict
[
i
][
'points'
]
text
=
pred_dict
[
i
][
'text'
]
text
=
pred_dict
[
i
][
'text
s
'
]
point
=
","
.
join
(
map
(
str
,
points
.
reshape
(
-
1
,
)))
point
=
","
.
join
(
map
(
str
,
points
.
reshape
(
-
1
,
)))
det
.
append
([
point
,
text
])
det
.
append
([
point
,
text
])
return
det
return
det
...
...
ppocr/utils/e2e_utils/extract_textpoint_slow.py
浏览文件 @
c455034f
...
@@ -342,6 +342,7 @@ def generate_pivot_list_curved(p_score,
...
@@ -342,6 +342,7 @@ def generate_pivot_list_curved(p_score,
center_pos_yxs
=
[]
center_pos_yxs
=
[]
end_points_yxs
=
[]
end_points_yxs
=
[]
instance_center_pos_yxs
=
[]
instance_center_pos_yxs
=
[]
pred_strs
=
[]
if
instance_count
>
0
:
if
instance_count
>
0
:
for
instance_id
in
range
(
1
,
instance_count
):
for
instance_id
in
range
(
1
,
instance_count
):
pos_list
=
[]
pos_list
=
[]
...
@@ -367,12 +368,13 @@ def generate_pivot_list_curved(p_score,
...
@@ -367,12 +368,13 @@ def generate_pivot_list_curved(p_score,
if
is_backbone
:
if
is_backbone
:
keep_yxs_list_with_id
=
add_id
(
keep_yxs_list
,
image_id
=
image_id
)
keep_yxs_list_with_id
=
add_id
(
keep_yxs_list
,
image_id
=
image_id
)
instance_center_pos_yxs
.
append
(
keep_yxs_list_with_id
)
instance_center_pos_yxs
.
append
(
keep_yxs_list_with_id
)
pred_strs
.
append
(
decoded_str
)
else
:
else
:
end_points_yxs
.
extend
((
keep_yxs_list
[
0
],
keep_yxs_list
[
-
1
]))
end_points_yxs
.
extend
((
keep_yxs_list
[
0
],
keep_yxs_list
[
-
1
]))
center_pos_yxs
.
extend
(
keep_yxs_list
)
center_pos_yxs
.
extend
(
keep_yxs_list
)
if
is_backbone
:
if
is_backbone
:
return
instance_center_pos_yxs
return
pred_strs
,
instance_center_pos_yxs
else
:
else
:
return
center_pos_yxs
,
end_points_yxs
return
center_pos_yxs
,
end_points_yxs
...
...
ppocr/utils/e2e_utils/pgnet_pp_utils.py
浏览文件 @
c455034f
...
@@ -64,7 +64,7 @@ class PGNet_PostProcess(object):
...
@@ -64,7 +64,7 @@ class PGNet_PostProcess(object):
src_w
,
src_h
,
self
.
valid_set
)
src_w
,
src_h
,
self
.
valid_set
)
data
=
{
data
=
{
'points'
:
poly_list
,
'points'
:
poly_list
,
'
str
s'
:
keep_str_list
,
'
text
s'
:
keep_str_list
,
}
}
return
data
return
data
...
@@ -85,32 +85,13 @@ class PGNet_PostProcess(object):
...
@@ -85,32 +85,13 @@ class PGNet_PostProcess(object):
p_char
=
p_char
[
0
]
p_char
=
p_char
[
0
]
src_h
,
src_w
,
ratio_h
,
ratio_w
=
self
.
shape_list
[
0
]
src_h
,
src_w
,
ratio_h
,
ratio_w
=
self
.
shape_list
[
0
]
is_curved
=
self
.
valid_set
==
"totaltext"
is_curved
=
self
.
valid_set
==
"totaltext"
instance_yxs_list
=
generate_pivot_list_slow
(
char_seq_idx_set
,
instance_yxs_list
=
generate_pivot_list_slow
(
p_score
,
p_score
,
p_char
,
p_char
,
p_direction
,
p_direction
,
score_thresh
=
self
.
score_thresh
,
score_thresh
=
self
.
score_thresh
,
is_backbone
=
True
,
is_backbone
=
True
,
is_curved
=
is_curved
)
is_curved
=
is_curved
)
p_char
=
paddle
.
to_tensor
(
np
.
expand_dims
(
p_char
,
axis
=
0
))
char_seq_idx_set
=
[]
for
i
in
range
(
len
(
instance_yxs_list
)):
gather_info_lod
=
paddle
.
to_tensor
(
instance_yxs_list
[
i
])
f_char_map
=
paddle
.
transpose
(
p_char
,
[
0
,
2
,
3
,
1
])
feature_seq
=
paddle
.
gather_nd
(
f_char_map
,
gather_info_lod
)
feature_seq
=
np
.
expand_dims
(
feature_seq
.
numpy
(),
axis
=
0
)
feature_len
=
[
len
(
feature_seq
[
0
])]
featyre_seq
=
paddle
.
to_tensor
(
feature_seq
)
feature_len
=
np
.
array
([
feature_len
]).
astype
(
np
.
int64
)
length
=
paddle
.
to_tensor
(
feature_len
)
seq_pred
=
paddle
.
fluid
.
layers
.
ctc_greedy_decoder
(
input
=
featyre_seq
,
blank
=
36
,
input_length
=
length
)
seq_pred_str
=
seq_pred
[
0
].
numpy
().
tolist
()[
0
]
seq_len
=
seq_pred
[
1
].
numpy
()[
0
][
0
]
temp_t
=
[]
for
c
in
seq_pred_str
[:
seq_len
]:
temp_t
.
append
(
c
)
char_seq_idx_set
.
append
(
temp_t
)
seq_strs
=
[]
seq_strs
=
[]
for
char_idx_set
in
char_seq_idx_set
:
for
char_idx_set
in
char_seq_idx_set
:
pr_str
=
''
.
join
([
self
.
Lexicon_Table
[
pos
]
for
pos
in
char_idx_set
])
pr_str
=
''
.
join
([
self
.
Lexicon_Table
[
pos
]
for
pos
in
char_idx_set
])
...
@@ -176,6 +157,6 @@ class PGNet_PostProcess(object):
...
@@ -176,6 +157,6 @@ class PGNet_PostProcess(object):
exit
(
-
1
)
exit
(
-
1
)
data
=
{
data
=
{
'points'
:
poly_list
,
'points'
:
poly_list
,
'
str
s'
:
keep_str_list
,
'
text
s'
:
keep_str_list
,
}
}
return
data
return
data
tools/infer/predict_e2e.py
浏览文件 @
c455034f
...
@@ -122,7 +122,7 @@ class TextE2E(object):
...
@@ -122,7 +122,7 @@ class TextE2E(object):
else
:
else
:
raise
NotImplementedError
raise
NotImplementedError
post_result
=
self
.
postprocess_op
(
preds
,
shape_list
)
post_result
=
self
.
postprocess_op
(
preds
,
shape_list
)
points
,
strs
=
post_result
[
'points'
],
post_result
[
'
str
s'
]
points
,
strs
=
post_result
[
'points'
],
post_result
[
'
text
s'
]
dt_boxes
=
self
.
filter_tag_det_res_only_clip
(
points
,
ori_im
.
shape
)
dt_boxes
=
self
.
filter_tag_det_res_only_clip
(
points
,
ori_im
.
shape
)
elapse
=
time
.
time
()
-
starttime
elapse
=
time
.
time
()
-
starttime
return
dt_boxes
,
strs
,
elapse
return
dt_boxes
,
strs
,
elapse
...
...
tools/infer_e2e.py
浏览文件 @
c455034f
...
@@ -103,7 +103,7 @@ def main():
...
@@ -103,7 +103,7 @@ def main():
images
=
paddle
.
to_tensor
(
images
)
images
=
paddle
.
to_tensor
(
images
)
preds
=
model
(
images
)
preds
=
model
(
images
)
post_result
=
post_process_class
(
preds
,
shape_list
)
post_result
=
post_process_class
(
preds
,
shape_list
)
points
,
strs
=
post_result
[
'points'
],
post_result
[
'
str
s'
]
points
,
strs
=
post_result
[
'points'
],
post_result
[
'
text
s'
]
# write resule
# write resule
dt_boxes_json
=
[]
dt_boxes_json
=
[]
for
poly
,
str
in
zip
(
points
,
strs
):
for
poly
,
str
in
zip
(
points
,
strs
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录