Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
weixin_41840029
PaddleOCR
提交
b8a65d43
P
PaddleOCR
项目概览
weixin_41840029
/
PaddleOCR
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleOCR
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b8a65d43
编写于
7月 08, 2021
作者:
L
LDOUBLEV
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix eval bug
上级
0742f5c5
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
28 addition
and
31 deletion
+28
-31
ppocr/metrics/det_metric.py
ppocr/metrics/det_metric.py
+3
-3
ppocr/metrics/distillation_metric.py
ppocr/metrics/distillation_metric.py
+4
-7
ppocr/modeling/architectures/distillation_model.py
ppocr/modeling/architectures/distillation_model.py
+1
-1
ppocr/postprocess/db_postprocess.py
ppocr/postprocess/db_postprocess.py
+12
-14
ppocr/utils/save_load.py
ppocr/utils/save_load.py
+1
-1
tools/eval.py
tools/eval.py
+4
-3
tools/program.py
tools/program.py
+1
-0
tools/train.py
tools/train.py
+2
-2
未找到文件。
ppocr/metrics/det_metric.py
浏览文件 @
b8a65d43
...
...
@@ -55,9 +55,9 @@ class DetMetric(object):
result
=
self
.
evaluator
.
evaluate_image
(
gt_info_list
,
det_info_list
)
self
.
results
.
append
(
result
)
metircs
=
self
.
evaluator
.
combine_results
(
self
.
results
)
self
.
reset
()
return
metircs
#
metircs = self.evaluator.combine_results(self.results)
#
self.reset()
#
return metircs
def
get_metric
(
self
):
"""
...
...
ppocr/metrics/distillation_metric.py
浏览文件 @
b8a65d43
...
...
@@ -24,8 +24,8 @@ from .cls_metric import ClsMetric
class
DistillationMetric
(
object
):
def
__init__
(
self
,
key
=
None
,
base_metric_name
=
"RecMetric"
,
main_indicator
=
'acc'
,
base_metric_name
=
None
,
main_indicator
=
None
,
**
kwargs
):
self
.
main_indicator
=
main_indicator
self
.
key
=
key
...
...
@@ -42,16 +42,13 @@ class DistillationMetric(object):
main_indicator
=
self
.
main_indicator
,
**
self
.
kwargs
)
self
.
metrics
[
key
].
reset
()
def
__call__
(
self
,
preds
,
*
args
,
**
kwargs
):
def
__call__
(
self
,
preds
,
batch
,
**
kwargs
):
assert
isinstance
(
preds
,
dict
)
if
self
.
metrics
is
None
:
self
.
_init_metrcis
(
preds
)
output
=
dict
()
for
key
in
preds
:
metric
=
self
.
metrics
[
key
].
__call__
(
preds
[
key
],
*
args
,
**
kwargs
)
for
sub_key
in
metric
:
output
[
"{}_{}"
.
format
(
key
,
sub_key
)]
=
metric
[
sub_key
]
return
output
self
.
metrics
[
key
].
__call__
(
preds
[
key
],
batch
,
**
kwargs
)
def
get_metric
(
self
):
"""
...
...
ppocr/modeling/architectures/distillation_model.py
浏览文件 @
b8a65d43
...
...
@@ -46,7 +46,7 @@ class DistillationModel(nn.Layer):
pretrained
=
model_config
.
pop
(
"pretrained"
)
model
=
BaseModel
(
model_config
)
if
pretrained
is
not
None
:
load_pretrained_params
(
model
,
pretrained
)
model
=
load_pretrained_params
(
model
,
pretrained
)
if
freeze_params
:
for
param
in
model
.
parameters
():
param
.
trainable
=
False
...
...
ppocr/postprocess/db_postprocess.py
浏览文件 @
b8a65d43
...
...
@@ -189,29 +189,27 @@ class DBPostProcess(object):
return
boxes_batch
class
DistillationDBPostProcess
(
DBPostProcess
):
def
__init__
(
self
,
model_name
=
[
"student"
],
class
DistillationDBPostProcess
(
object
):
def
__init__
(
self
,
model_name
=
[
"student"
],
key
=
None
,
thresh
=
0.3
,
box_thresh
=
0.
7
,
box_thresh
=
0.
6
,
max_candidates
=
1000
,
unclip_ratio
=
2.0
,
unclip_ratio
=
1.5
,
use_dilation
=
False
,
score_mode
=
"fast"
,
**
kwargs
):
super
().
__init__
()
if
not
isinstance
(
model_name
,
list
):
model_name
=
[
model_name
]
self
.
model_name
=
model_name
self
.
key
=
key
self
.
post_process
=
DBPostProcess
(
thresh
=
thresh
,
box_thresh
=
box_thresh
,
max_candidates
=
max_candidates
,
unclip_ratio
=
unclip_ratio
,
use_dilation
=
use_dilation
,
score_mode
=
score_mode
)
def
__call__
(
self
,
predicts
,
shape_list
):
results
=
{}
for
name
in
self
.
model_name
:
pred
=
predicts
[
name
]
if
self
.
key
is
not
None
:
pred
=
pred
[
self
.
key
]
results
[
name
]
=
super
().
__call__
(
pred
,
shape_list
=
shape_list
)
for
k
in
self
.
model_name
:
results
[
k
]
=
self
.
post_process
(
predicts
[
k
],
shape_list
=
shape_list
)
return
results
ppocr/utils/save_load.py
浏览文件 @
b8a65d43
...
...
@@ -136,7 +136,7 @@ def load_pretrained_params(model, path):
)
model
.
set_state_dict
(
new_state_dict
)
print
(
f
"load pretrain successful from
{
path
}
"
)
return
True
return
model
def
save_model
(
model
,
optimizer
,
...
...
tools/eval.py
浏览文件 @
b8a65d43
...
...
@@ -27,7 +27,7 @@ from ppocr.data import build_dataloader
from
ppocr.modeling.architectures
import
build_model
from
ppocr.postprocess
import
build_post_process
from
ppocr.metrics
import
build_metric
from
ppocr.utils.save_load
import
init_model
from
ppocr.utils.save_load
import
init_model
,
load_pretrained_params
from
ppocr.utils.utility
import
print_dict
import
tools.program
as
program
...
...
@@ -59,7 +59,8 @@ def main():
model_type
=
config
[
'Architecture'
][
'model_type'
]
else
:
model_type
=
None
best_model_dict
=
init_model
(
config
,
model
)
best_model_dict
=
init_model
(
config
,
model
,
model_type
)
if
len
(
best_model_dict
):
logger
.
info
(
'metric in ckpt ***************'
)
for
k
,
v
in
best_model_dict
.
items
():
...
...
tools/program.py
浏览文件 @
b8a65d43
...
...
@@ -374,6 +374,7 @@ def eval(model,
eval_class
(
preds
,
batch
)
else
:
post_result
=
post_process_class
(
preds
,
batch
[
1
])
# post_result = post_result_["Student"]
eval_class
(
post_result
,
batch
)
pbar
.
update
(
1
)
total_frame
+=
len
(
images
)
...
...
tools/train.py
浏览文件 @
b8a65d43
...
...
@@ -97,8 +97,8 @@ def main(config, device, logger, vdl_writer):
# build metric
eval_class
=
build_metric
(
config
[
'Metric'
])
# load pretrain model
pre_best_model_dict
=
load_dygraph_params
(
config
,
model
,
logger
,
optimizer
)
#
pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer)
pre_best_model_dict
=
{}
logger
.
info
(
'train dataloader has {} iters'
.
format
(
len
(
train_dataloader
)))
if
valid_dataloader
is
not
None
:
logger
.
info
(
'valid dataloader has {} iters'
.
format
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录