Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
04e71041
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
04e71041
编写于
6月 28, 2022
作者:
W
wangjingyeye
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add db++
上级
1315cdfc
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
41 addition
and
13 deletion
+41
-13
configs/det/det_r50_db++_ic15.yml
configs/det/det_r50_db++_ic15.yml
+1
-1
configs/det/det_r50_db++_td_tr.yml
configs/det/det_r50_db++_td_tr.yml
+1
-1
tools/infer/predict_det.py
tools/infer/predict_det.py
+18
-1
tools/program.py
tools/program.py
+21
-10
未找到文件。
configs/det/det_r50_db++_ic15.yml
浏览文件 @
04e71041
...
...
@@ -18,7 +18,7 @@ Global:
save_res_path
:
./checkpoints/det_db/predicts_db.txt
Architecture
:
model_type
:
det
algorithm
:
DB
algorithm
:
DB
++
Transform
:
null
Backbone
:
name
:
ResNet
...
...
configs/det/det_r50_db++_td_tr.yml
浏览文件 @
04e71041
...
...
@@ -18,7 +18,7 @@ Global:
save_res_path
:
./checkpoints/det_db/predicts_db.txt
Architecture
:
model_type
:
det
algorithm
:
DB
algorithm
:
DB
++
Transform
:
null
Backbone
:
name
:
ResNet
...
...
tools/infer/predict_det.py
浏览文件 @
04e71041
...
...
@@ -67,6 +67,23 @@ class TextDetector(object):
postprocess_params
[
"unclip_ratio"
]
=
args
.
det_db_unclip_ratio
postprocess_params
[
"use_dilation"
]
=
args
.
use_dilation
postprocess_params
[
"score_mode"
]
=
args
.
det_db_score_mode
elif
self
.
det_algorithm
==
"DB++"
:
postprocess_params
[
'name'
]
=
'DBPostProcess'
postprocess_params
[
"thresh"
]
=
args
.
det_db_thresh
postprocess_params
[
"box_thresh"
]
=
args
.
det_db_box_thresh
postprocess_params
[
"max_candidates"
]
=
1000
postprocess_params
[
"unclip_ratio"
]
=
args
.
det_db_unclip_ratio
postprocess_params
[
"use_dilation"
]
=
args
.
use_dilation
postprocess_params
[
"score_mode"
]
=
args
.
det_db_score_mode
pre_process_list
[
1
]
=
{
'NormalizeImage'
:
{
'std'
:
[
1.0
,
1.0
,
1.0
],
'mean'
:
[
0.48109378172549
,
0.45752457890196
,
0.40787054090196
],
'scale'
:
'1./255.'
,
'order'
:
'hwc'
}
}
elif
self
.
det_algorithm
==
"EAST"
:
postprocess_params
[
'name'
]
=
'EASTPostProcess'
postprocess_params
[
"score_thresh"
]
=
args
.
det_east_score_thresh
...
...
@@ -231,7 +248,7 @@ class TextDetector(object):
preds
[
'f_score'
]
=
outputs
[
1
]
preds
[
'f_tco'
]
=
outputs
[
2
]
preds
[
'f_tvo'
]
=
outputs
[
3
]
elif
self
.
det_algorithm
in
[
'DB'
,
'PSE'
]:
elif
self
.
det_algorithm
in
[
'DB'
,
'PSE'
,
'DB++'
]:
preds
[
'maps'
]
=
outputs
[
0
]
elif
self
.
det_algorithm
==
'FCE'
:
for
i
,
output
in
enumerate
(
outputs
):
...
...
tools/program.py
浏览文件 @
04e71041
...
...
@@ -307,7 +307,8 @@ def train(config,
train_stats
.
update
(
stats
)
if
log_writer
is
not
None
and
dist
.
get_rank
()
==
0
:
log_writer
.
log_metrics
(
metrics
=
train_stats
.
get
(),
prefix
=
"TRAIN"
,
step
=
global_step
)
log_writer
.
log_metrics
(
metrics
=
train_stats
.
get
(),
prefix
=
"TRAIN"
,
step
=
global_step
)
if
dist
.
get_rank
()
==
0
and
(
(
global_step
>
0
and
global_step
%
print_batch_step
==
0
)
or
...
...
@@ -354,7 +355,8 @@ def train(config,
# logger metric
if
log_writer
is
not
None
:
log_writer
.
log_metrics
(
metrics
=
cur_metric
,
prefix
=
"EVAL"
,
step
=
global_step
)
log_writer
.
log_metrics
(
metrics
=
cur_metric
,
prefix
=
"EVAL"
,
step
=
global_step
)
if
cur_metric
[
main_indicator
]
>=
best_model_dict
[
main_indicator
]:
...
...
@@ -377,11 +379,18 @@ def train(config,
logger
.
info
(
best_str
)
# logger best metric
if
log_writer
is
not
None
:
log_writer
.
log_metrics
(
metrics
=
{
"best_{}"
.
format
(
main_indicator
):
best_model_dict
[
main_indicator
]
},
prefix
=
"EVAL"
,
step
=
global_step
)
log_writer
.
log_model
(
is_best
=
True
,
prefix
=
"best_accuracy"
,
metadata
=
best_model_dict
)
log_writer
.
log_metrics
(
metrics
=
{
"best_{}"
.
format
(
main_indicator
):
best_model_dict
[
main_indicator
]
},
prefix
=
"EVAL"
,
step
=
global_step
)
log_writer
.
log_model
(
is_best
=
True
,
prefix
=
"best_accuracy"
,
metadata
=
best_model_dict
)
reader_start
=
time
.
time
()
if
dist
.
get_rank
()
==
0
:
...
...
@@ -413,7 +422,8 @@ def train(config,
epoch
=
epoch
,
global_step
=
global_step
)
if
log_writer
is
not
None
:
log_writer
.
log_model
(
is_best
=
False
,
prefix
=
'iter_epoch_{}'
.
format
(
epoch
))
log_writer
.
log_model
(
is_best
=
False
,
prefix
=
'iter_epoch_{}'
.
format
(
epoch
))
best_str
=
'best metric, {}'
.
format
(
', '
.
join
(
[
'{}: {}'
.
format
(
k
,
v
)
for
k
,
v
in
best_model_dict
.
items
()]))
...
...
@@ -564,7 +574,7 @@ def preprocess(is_train=False):
assert
alg
in
[
'EAST'
,
'DB'
,
'SAST'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
,
'SRN'
,
'CLS'
,
'PGNet'
,
'Distillation'
,
'NRTR'
,
'TableAttn'
,
'SAR'
,
'PSE'
,
'SEED'
,
'SDMGR'
,
'LayoutXLM'
,
'LayoutLM'
,
'PREN'
,
'FCE'
,
'SVTR'
'SEED'
,
'SDMGR'
,
'LayoutXLM'
,
'LayoutLM'
,
'PREN'
,
'FCE'
,
'SVTR'
,
'DB++'
]
if
use_xpu
:
...
...
@@ -585,7 +595,8 @@ def preprocess(is_train=False):
vdl_writer_path
=
'{}/vdl/'
.
format
(
save_model_dir
)
log_writer
=
VDLLogger
(
save_model_dir
)
loggers
.
append
(
log_writer
)
if
(
'use_wandb'
in
config
[
'Global'
]
and
config
[
'Global'
][
'use_wandb'
])
or
'wandb'
in
config
:
if
(
'use_wandb'
in
config
[
'Global'
]
and
config
[
'Global'
][
'use_wandb'
])
or
'wandb'
in
config
:
save_dir
=
config
[
'Global'
][
'save_model_dir'
]
wandb_writer_path
=
"{}/wandb"
.
format
(
save_dir
)
if
"wandb"
in
config
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录