Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
310d399b
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
310d399b
编写于
3月 08, 2021
作者:
J
Jethong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
ADD PGnet_v3
上级
bb49e1a5
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
4 addition
and
21 deletion
+4
-21
ppocr/data/pgnet_dataset.py
ppocr/data/pgnet_dataset.py
+4
-2
ppocr/postprocess/pg_postprocess.py
ppocr/postprocess/pg_postprocess.py
+0
-19
未找到文件。
ppocr/data/pgnet_dataset.py
浏览文件 @
310d399b
...
@@ -19,10 +19,11 @@ import random
...
@@ -19,10 +19,11 @@ import random
class
PGDateSet
(
Dataset
):
class
PGDateSet
(
Dataset
):
def
__init__
(
self
,
config
,
mode
,
logger
):
def
__init__
(
self
,
config
,
mode
,
logger
,
seed
=
None
):
super
(
PGDateSet
,
self
).
__init__
()
super
(
PGDateSet
,
self
).
__init__
()
self
.
logger
=
logger
self
.
logger
=
logger
self
.
seed
=
seed
global_config
=
config
[
'Global'
]
global_config
=
config
[
'Global'
]
dataset_config
=
config
[
mode
][
'dataset'
]
dataset_config
=
config
[
mode
][
'dataset'
]
loader_config
=
config
[
mode
][
'loader'
]
loader_config
=
config
[
mode
][
'loader'
]
...
@@ -36,7 +37,6 @@ class PGDateSet(Dataset):
...
@@ -36,7 +37,6 @@ class PGDateSet(Dataset):
assert
len
(
assert
len
(
ratio_list
ratio_list
)
==
data_source_num
,
"The length of ratio_list should be the same as the file_list."
)
==
data_source_num
,
"The length of ratio_list should be the same as the file_list."
# self.data_dir = dataset_config['data_dir']
self
.
do_shuffle
=
loader_config
[
'shuffle'
]
self
.
do_shuffle
=
loader_config
[
'shuffle'
]
logger
.
info
(
"Initialize indexs of datasets:%s"
%
label_file_list
)
logger
.
info
(
"Initialize indexs of datasets:%s"
%
label_file_list
)
...
@@ -50,6 +50,7 @@ class PGDateSet(Dataset):
...
@@ -50,6 +50,7 @@ class PGDateSet(Dataset):
def
shuffle_data_random
(
self
):
def
shuffle_data_random
(
self
):
if
self
.
do_shuffle
:
if
self
.
do_shuffle
:
random
.
seed
(
self
.
seed
)
random
.
shuffle
(
self
.
data_lines
)
random
.
shuffle
(
self
.
data_lines
)
return
return
...
@@ -122,6 +123,7 @@ class PGDateSet(Dataset):
...
@@ -122,6 +123,7 @@ class PGDateSet(Dataset):
else
:
else
:
print
(
"Unrecognized data format..."
)
print
(
"Unrecognized data format..."
)
exit
(
-
1
)
exit
(
-
1
)
random
.
seed
(
self
.
seed
)
image_files
=
random
.
sample
(
image_files
=
random
.
sample
(
image_files
,
round
(
len
(
image_files
)
*
ratio_list
[
idx
]))
image_files
,
round
(
len
(
image_files
)
*
ratio_list
[
idx
]))
data_lines
.
extend
(
image_files
)
data_lines
.
extend
(
image_files
)
...
...
ppocr/postprocess/pg_postprocess.py
浏览文件 @
310d399b
...
@@ -113,7 +113,6 @@ class PGPostProcess(object):
...
@@ -113,7 +113,6 @@ class PGPostProcess(object):
all_point_pair_list
=
[]
all_point_pair_list
=
[]
for
yx_center_line
,
keep_str
in
zip
(
instance_yxs_list
,
seq_strs
):
for
yx_center_line
,
keep_str
in
zip
(
instance_yxs_list
,
seq_strs
):
if
len
(
yx_center_line
)
==
1
:
if
len
(
yx_center_line
)
==
1
:
print
(
'the length of tcl point is less than 2, repeat'
)
yx_center_line
.
append
(
yx_center_line
[
-
1
])
yx_center_line
.
append
(
yx_center_line
[
-
1
])
# expand corresponding offset for total-text.
# expand corresponding offset for total-text.
...
@@ -148,7 +147,6 @@ class PGPostProcess(object):
...
@@ -148,7 +147,6 @@ class PGPostProcess(object):
# ndarry: (x, 2)
# ndarry: (x, 2)
detected_poly
,
pair_length_info
=
point_pair2poly
(
point_pair_list
)
detected_poly
,
pair_length_info
=
point_pair2poly
(
point_pair_list
)
print
(
'expand along width. {}'
.
format
(
detected_poly
.
shape
))
detected_poly
=
expand_poly_along_width
(
detected_poly
=
expand_poly_along_width
(
detected_poly
,
shrink_ratio_of_width
=
0.2
)
detected_poly
,
shrink_ratio_of_width
=
0.2
)
detected_poly
[:,
0
]
=
np
.
clip
(
detected_poly
[:,
0
]
=
np
.
clip
(
...
@@ -157,7 +155,6 @@ class PGPostProcess(object):
...
@@ -157,7 +155,6 @@ class PGPostProcess(object):
detected_poly
[:,
1
],
a_min
=
0
,
a_max
=
src_h
)
detected_poly
[:,
1
],
a_min
=
0
,
a_max
=
src_h
)
if
len
(
keep_str
)
<
2
:
if
len
(
keep_str
)
<
2
:
print
(
'--> too short, {}'
.
format
(
keep_str
))
continue
continue
keep_str_list
.
append
(
keep_str
)
keep_str_list
.
append
(
keep_str
)
...
@@ -175,20 +172,4 @@ class PGPostProcess(object):
...
@@ -175,20 +172,4 @@ class PGPostProcess(object):
'points'
:
poly_list
,
'points'
:
poly_list
,
'strs'
:
keep_str_list
,
'strs'
:
keep_str_list
,
}
}
# visualization
# if self.save_visualization:
# visualize_e2e_result(im_fn, poly_list, keep_str_list, src_im)
# visualize_point_result(im_fn, all_point_list, all_point_pair_list, src_im)
# save detected boxes
# txt_dir = (result_path[:-1] if result_path.endswith('/') else result_path) + '_txt_anno'
# if not os.path.exists(txt_dir):
# os.makedirs(txt_dir)
# res_file = os.path.join(txt_dir, '{}.txt'.format(im_prefix))
# with open(res_file, 'w') as f:
# for i_box, box in enumerate(poly_list):
# seq_str = keep_str_list[i_box]
# box = np.round(box).astype('int32')
# box_str = ','.join(str(s) for s in (box.flatten().tolist()))
# f.write('{}\t{}\r\n'.format(box_str, seq_str))
return
data
return
data
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录