Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
af25e256
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
接近 2 年 前同步成功
通知
116
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
af25e256
编写于
9月 23, 2021
作者:
W
weishengyu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
modify format
上级
9e975699
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
62 addition
and
68 deletion
+62
-68
ppcls/engine/evaluation/classification.py
ppcls/engine/evaluation/classification.py
+15
-17
ppcls/engine/evaluation/retrieval.py
ppcls/engine/evaluation/retrieval.py
+22
-23
ppcls/engine/train/train.py
ppcls/engine/train/train.py
+25
-28
未找到文件。
ppcls/engine/evaluation/classification.py
浏览文件 @
af25e256
...
@@ -22,7 +22,7 @@ from ppcls.utils.misc import AverageMeter
...
@@ -22,7 +22,7 @@ from ppcls.utils.misc import AverageMeter
from
ppcls.utils
import
logger
from
ppcls.utils
import
logger
def
classification_eval
(
e
valer
,
epoch_id
=
0
):
def
classification_eval
(
e
ngine
,
epoch_id
=
0
):
output_info
=
dict
()
output_info
=
dict
()
time_info
=
{
time_info
=
{
"batch_cost"
:
AverageMeter
(
"batch_cost"
:
AverageMeter
(
...
@@ -30,21 +30,19 @@ def classification_eval(evaler, epoch_id=0):
...
@@ -30,21 +30,19 @@ def classification_eval(evaler, epoch_id=0):
"reader_cost"
:
AverageMeter
(
"reader_cost"
:
AverageMeter
(
"reader_cost"
,
".5f"
,
postfix
=
" s,"
),
"reader_cost"
,
".5f"
,
postfix
=
" s,"
),
}
}
print_batch_step
=
e
valer
.
config
[
"Global"
][
"print_batch_step"
]
print_batch_step
=
e
ngine
.
config
[
"Global"
][
"print_batch_step"
]
metric_key
=
None
metric_key
=
None
tic
=
time
.
time
()
tic
=
time
.
time
()
eval_dataloader
=
evaler
.
eval_dataloader
if
evaler
.
use_dali
else
evaler
.
eval_dataloader
(
max_iter
=
len
(
engine
.
eval_dataloader
)
-
1
if
platform
.
system
(
)
)
==
"Windows"
else
len
(
engine
.
eval_dataloader
)
max_iter
=
len
(
evaler
.
eval_dataloader
)
-
1
if
platform
.
system
(
for
iter_id
,
batch
in
enumerate
(
engine
.
eval_dataloader
):
)
==
"Windows"
else
len
(
evaler
.
eval_dataloader
)
for
iter_id
,
batch
in
enumerate
(
eval_dataloader
):
if
iter_id
>=
max_iter
:
if
iter_id
>=
max_iter
:
break
break
if
iter_id
==
5
:
if
iter_id
==
5
:
for
key
in
time_info
:
for
key
in
time_info
:
time_info
[
key
].
reset
()
time_info
[
key
].
reset
()
if
e
valer
.
use_dali
:
if
e
ngine
.
use_dali
:
batch
=
[
batch
=
[
paddle
.
to_tensor
(
batch
[
0
][
'data'
]),
paddle
.
to_tensor
(
batch
[
0
][
'data'
]),
paddle
.
to_tensor
(
batch
[
0
][
'label'
])
paddle
.
to_tensor
(
batch
[
0
][
'label'
])
...
@@ -54,17 +52,17 @@ def classification_eval(evaler, epoch_id=0):
...
@@ -54,17 +52,17 @@ def classification_eval(evaler, epoch_id=0):
batch
[
0
]
=
paddle
.
to_tensor
(
batch
[
0
]).
astype
(
"float32"
)
batch
[
0
]
=
paddle
.
to_tensor
(
batch
[
0
]).
astype
(
"float32"
)
batch
[
1
]
=
batch
[
1
].
reshape
([
-
1
,
1
]).
astype
(
"int64"
)
batch
[
1
]
=
batch
[
1
].
reshape
([
-
1
,
1
]).
astype
(
"int64"
)
# image input
# image input
out
=
e
valer
.
model
(
batch
[
0
])
out
=
e
ngine
.
model
(
batch
[
0
])
# calc loss
# calc loss
if
e
valer
.
eval_loss_func
is
not
None
:
if
e
ngine
.
eval_loss_func
is
not
None
:
loss_dict
=
e
valer
.
eval_loss_func
(
out
,
batch
[
1
])
loss_dict
=
e
ngine
.
eval_loss_func
(
out
,
batch
[
1
])
for
key
in
loss_dict
:
for
key
in
loss_dict
:
if
key
not
in
output_info
:
if
key
not
in
output_info
:
output_info
[
key
]
=
AverageMeter
(
key
,
'7.5f'
)
output_info
[
key
]
=
AverageMeter
(
key
,
'7.5f'
)
output_info
[
key
].
update
(
loss_dict
[
key
].
numpy
()[
0
],
batch_size
)
output_info
[
key
].
update
(
loss_dict
[
key
].
numpy
()[
0
],
batch_size
)
# calc metric
# calc metric
if
e
valer
.
eval_metric_func
is
not
None
:
if
e
ngine
.
eval_metric_func
is
not
None
:
metric_dict
=
e
valer
.
eval_metric_func
(
out
,
batch
[
1
])
metric_dict
=
e
ngine
.
eval_metric_func
(
out
,
batch
[
1
])
if
paddle
.
distributed
.
get_world_size
()
>
1
:
if
paddle
.
distributed
.
get_world_size
()
>
1
:
for
key
in
metric_dict
:
for
key
in
metric_dict
:
paddle
.
distributed
.
all_reduce
(
paddle
.
distributed
.
all_reduce
(
...
@@ -97,18 +95,18 @@ def classification_eval(evaler, epoch_id=0):
...
@@ -97,18 +95,18 @@ def classification_eval(evaler, epoch_id=0):
])
])
logger
.
info
(
"[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}"
.
format
(
logger
.
info
(
"[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}"
.
format
(
epoch_id
,
iter_id
,
epoch_id
,
iter_id
,
len
(
e
valer
.
eval_dataloader
),
metric_msg
,
time_msg
,
ips_msg
))
len
(
e
ngine
.
eval_dataloader
),
metric_msg
,
time_msg
,
ips_msg
))
tic
=
time
.
time
()
tic
=
time
.
time
()
if
e
valer
.
use_dali
:
if
e
ngine
.
use_dali
:
e
valer
.
eval_dataloader
.
reset
()
e
ngine
.
eval_dataloader
.
reset
()
metric_msg
=
", "
.
join
([
metric_msg
=
", "
.
join
([
"{}: {:.5f}"
.
format
(
key
,
output_info
[
key
].
avg
)
for
key
in
output_info
"{}: {:.5f}"
.
format
(
key
,
output_info
[
key
].
avg
)
for
key
in
output_info
])
])
logger
.
info
(
"[Eval][Epoch {}][Avg]{}"
.
format
(
epoch_id
,
metric_msg
))
logger
.
info
(
"[Eval][Epoch {}][Avg]{}"
.
format
(
epoch_id
,
metric_msg
))
# do not try to save best eval.model
# do not try to save best eval.model
if
e
valer
.
eval_metric_func
is
None
:
if
e
ngine
.
eval_metric_func
is
None
:
return
-
1
return
-
1
# return 1st metric in the dict
# return 1st metric in the dict
return
output_info
[
metric_key
].
avg
return
output_info
[
metric_key
].
avg
ppcls/engine/evaluation/retrieval.py
浏览文件 @
af25e256
...
@@ -20,21 +20,21 @@ import paddle
...
@@ -20,21 +20,21 @@ import paddle
from
ppcls.utils
import
logger
from
ppcls.utils
import
logger
def
retrieval_eval
(
e
valer
,
epoch_id
=
0
):
def
retrieval_eval
(
e
ngine
,
epoch_id
=
0
):
e
valer
.
model
.
eval
()
e
ngine
.
model
.
eval
()
# step1. build gallery
# step1. build gallery
if
e
valer
.
gallery_query_dataloader
is
not
None
:
if
e
ngine
.
gallery_query_dataloader
is
not
None
:
gallery_feas
,
gallery_img_id
,
gallery_unique_id
=
cal_feature
(
gallery_feas
,
gallery_img_id
,
gallery_unique_id
=
cal_feature
(
e
valer
,
name
=
'gallery_query'
)
e
ngine
,
name
=
'gallery_query'
)
query_feas
,
query_img_id
,
query_query_id
=
gallery_feas
,
gallery_img_id
,
gallery_unique_id
query_feas
,
query_img_id
,
query_query_id
=
gallery_feas
,
gallery_img_id
,
gallery_unique_id
else
:
else
:
gallery_feas
,
gallery_img_id
,
gallery_unique_id
=
cal_feature
(
gallery_feas
,
gallery_img_id
,
gallery_unique_id
=
cal_feature
(
e
valer
,
name
=
'gallery'
)
e
ngine
,
name
=
'gallery'
)
query_feas
,
query_img_id
,
query_query_id
=
cal_feature
(
query_feas
,
query_img_id
,
query_query_id
=
cal_feature
(
e
valer
,
name
=
'query'
)
e
ngine
,
name
=
'query'
)
# step2. do evaluation
# step2. do evaluation
sim_block_size
=
e
valer
.
config
[
"Global"
].
get
(
"sim_block_size"
,
64
)
sim_block_size
=
e
ngine
.
config
[
"Global"
].
get
(
"sim_block_size"
,
64
)
sections
=
[
sim_block_size
]
*
(
len
(
query_feas
)
//
sim_block_size
)
sections
=
[
sim_block_size
]
*
(
len
(
query_feas
)
//
sim_block_size
)
if
len
(
query_feas
)
%
sim_block_size
:
if
len
(
query_feas
)
%
sim_block_size
:
sections
.
append
(
len
(
query_feas
)
%
sim_block_size
)
sections
.
append
(
len
(
query_feas
)
%
sim_block_size
)
...
@@ -45,7 +45,7 @@ def retrieval_eval(evaler, epoch_id=0):
...
@@ -45,7 +45,7 @@ def retrieval_eval(evaler, epoch_id=0):
image_id_blocks
=
paddle
.
split
(
query_img_id
,
num_or_sections
=
sections
)
image_id_blocks
=
paddle
.
split
(
query_img_id
,
num_or_sections
=
sections
)
metric_key
=
None
metric_key
=
None
if
e
valer
.
eval_loss_func
is
None
:
if
e
ngine
.
eval_loss_func
is
None
:
metric_dict
=
{
metric_key
:
0.
}
metric_dict
=
{
metric_key
:
0.
}
else
:
else
:
metric_dict
=
dict
()
metric_dict
=
dict
()
...
@@ -65,7 +65,7 @@ def retrieval_eval(evaler, epoch_id=0):
...
@@ -65,7 +65,7 @@ def retrieval_eval(evaler, epoch_id=0):
else
:
else
:
keep_mask
=
None
keep_mask
=
None
metric_tmp
=
e
valer
.
eval_metric_func
(
similarity_matrix
,
metric_tmp
=
e
ngine
.
eval_metric_func
(
similarity_matrix
,
image_id_blocks
[
block_idx
],
image_id_blocks
[
block_idx
],
gallery_img_id
,
keep_mask
)
gallery_img_id
,
keep_mask
)
...
@@ -88,32 +88,31 @@ def retrieval_eval(evaler, epoch_id=0):
...
@@ -88,32 +88,31 @@ def retrieval_eval(evaler, epoch_id=0):
return
metric_dict
[
metric_key
]
return
metric_dict
[
metric_key
]
def
cal_feature
(
e
valer
,
name
=
'gallery'
):
def
cal_feature
(
e
ngine
,
name
=
'gallery'
):
all_feas
=
None
all_feas
=
None
all_image_id
=
None
all_image_id
=
None
all_unique_id
=
None
all_unique_id
=
None
has_unique_id
=
False
has_unique_id
=
False
if
name
==
'gallery'
:
if
name
==
'gallery'
:
dataloader
=
e
valer
.
gallery_dataloader
dataloader
=
e
ngine
.
gallery_dataloader
elif
name
==
'query'
:
elif
name
==
'query'
:
dataloader
=
e
valer
.
query_dataloader
dataloader
=
e
ngine
.
query_dataloader
elif
name
==
'gallery_query'
:
elif
name
==
'gallery_query'
:
dataloader
=
e
valer
.
gallery_query_dataloader
dataloader
=
e
ngine
.
gallery_query_dataloader
else
:
else
:
raise
RuntimeError
(
"Only support gallery or query dataset"
)
raise
RuntimeError
(
"Only support gallery or query dataset"
)
max_iter
=
len
(
dataloader
)
-
1
if
platform
.
system
()
==
"Windows"
else
len
(
max_iter
=
len
(
dataloader
)
-
1
if
platform
.
system
()
==
"Windows"
else
len
(
dataloader
)
dataloader
)
dataloader_tmp
=
dataloader
if
evaler
.
use_dali
else
dataloader
()
for
idx
,
batch
in
enumerate
(
dataloader
):
# load is very time-consuming
for
idx
,
batch
in
enumerate
(
dataloader_tmp
):
# load is very time-consuming
if
idx
>=
max_iter
:
if
idx
>=
max_iter
:
break
break
if
idx
%
e
valer
.
config
[
"Global"
][
"print_batch_step"
]
==
0
:
if
idx
%
e
ngine
.
config
[
"Global"
][
"print_batch_step"
]
==
0
:
logger
.
info
(
logger
.
info
(
f
"
{
name
}
feature calculation process: [
{
idx
}
/
{
len
(
dataloader
)
}
]"
f
"
{
name
}
feature calculation process: [
{
idx
}
/
{
len
(
dataloader
)
}
]"
)
)
if
e
valer
.
use_dali
:
if
e
ngine
.
use_dali
:
batch
=
[
batch
=
[
paddle
.
to_tensor
(
batch
[
0
][
'data'
]),
paddle
.
to_tensor
(
batch
[
0
][
'data'
]),
paddle
.
to_tensor
(
batch
[
0
][
'label'
])
paddle
.
to_tensor
(
batch
[
0
][
'label'
])
...
@@ -123,20 +122,20 @@ def cal_feature(evaler, name='gallery'):
...
@@ -123,20 +122,20 @@ def cal_feature(evaler, name='gallery'):
if
len
(
batch
)
==
3
:
if
len
(
batch
)
==
3
:
has_unique_id
=
True
has_unique_id
=
True
batch
[
2
]
=
batch
[
2
].
reshape
([
-
1
,
1
]).
astype
(
"int64"
)
batch
[
2
]
=
batch
[
2
].
reshape
([
-
1
,
1
]).
astype
(
"int64"
)
out
=
e
valer
.
model
(
batch
[
0
],
batch
[
1
])
out
=
e
ngine
.
model
(
batch
[
0
],
batch
[
1
])
batch_feas
=
out
[
"features"
]
batch_feas
=
out
[
"features"
]
# do norm
# do norm
if
e
valer
.
config
[
"Global"
].
get
(
"feature_normalize"
,
True
):
if
e
ngine
.
config
[
"Global"
].
get
(
"feature_normalize"
,
True
):
feas_norm
=
paddle
.
sqrt
(
feas_norm
=
paddle
.
sqrt
(
paddle
.
sum
(
paddle
.
square
(
batch_feas
),
axis
=
1
,
keepdim
=
True
))
paddle
.
sum
(
paddle
.
square
(
batch_feas
),
axis
=
1
,
keepdim
=
True
))
batch_feas
=
paddle
.
divide
(
batch_feas
,
feas_norm
)
batch_feas
=
paddle
.
divide
(
batch_feas
,
feas_norm
)
# do binarize
# do binarize
if
e
valer
.
config
[
"Global"
].
get
(
"feature_binarize"
)
==
"round"
:
if
e
ngine
.
config
[
"Global"
].
get
(
"feature_binarize"
)
==
"round"
:
batch_feas
=
paddle
.
round
(
batch_feas
).
astype
(
"float32"
)
*
2.0
-
1.0
batch_feas
=
paddle
.
round
(
batch_feas
).
astype
(
"float32"
)
*
2.0
-
1.0
if
e
valer
.
config
[
"Global"
].
get
(
"feature_binarize"
)
==
"sign"
:
if
e
ngine
.
config
[
"Global"
].
get
(
"feature_binarize"
)
==
"sign"
:
batch_feas
=
paddle
.
sign
(
batch_feas
).
astype
(
"float32"
)
batch_feas
=
paddle
.
sign
(
batch_feas
).
astype
(
"float32"
)
if
all_feas
is
None
:
if
all_feas
is
None
:
...
@@ -150,8 +149,8 @@ def cal_feature(evaler, name='gallery'):
...
@@ -150,8 +149,8 @@ def cal_feature(evaler, name='gallery'):
if
has_unique_id
:
if
has_unique_id
:
all_unique_id
=
paddle
.
concat
([
all_unique_id
,
batch
[
2
]])
all_unique_id
=
paddle
.
concat
([
all_unique_id
,
batch
[
2
]])
if
e
valer
.
use_dali
:
if
e
ngine
.
use_dali
:
dataloader
_tmp
.
reset
()
dataloader
.
reset
()
if
paddle
.
distributed
.
get_world_size
()
>
1
:
if
paddle
.
distributed
.
get_world_size
()
>
1
:
feat_list
=
[]
feat_list
=
[]
...
...
ppcls/engine/train/train.py
浏览文件 @
af25e256
...
@@ -18,19 +18,16 @@ import paddle
...
@@ -18,19 +18,16 @@ import paddle
from
ppcls.engine.train.utils
import
update_loss
,
update_metric
,
log_info
from
ppcls.engine.train.utils
import
update_loss
,
update_metric
,
log_info
def
train_epoch
(
trainer
,
epoch_id
,
print_batch_step
):
def
train_epoch
(
engine
,
epoch_id
,
print_batch_step
):
tic
=
time
.
time
()
tic
=
time
.
time
()
for
iter_id
,
batch
in
enumerate
(
engine
.
train_dataloader
):
train_dataloader
=
trainer
.
train_dataloader
if
trainer
.
use_dali
else
trainer
.
train_dataloader
(
if
iter_id
>=
engine
.
max_iter
:
)
for
iter_id
,
batch
in
enumerate
(
train_dataloader
):
if
iter_id
>=
trainer
.
max_iter
:
break
break
if
iter_id
==
5
:
if
iter_id
==
5
:
for
key
in
trainer
.
time_info
:
for
key
in
engine
.
time_info
:
trainer
.
time_info
[
key
].
reset
()
engine
.
time_info
[
key
].
reset
()
trainer
.
time_info
[
"reader_cost"
].
update
(
time
.
time
()
-
tic
)
engine
.
time_info
[
"reader_cost"
].
update
(
time
.
time
()
-
tic
)
if
trainer
.
use_dali
:
if
engine
.
use_dali
:
batch
=
[
batch
=
[
paddle
.
to_tensor
(
batch
[
0
][
'data'
]),
paddle
.
to_tensor
(
batch
[
0
][
'data'
]),
paddle
.
to_tensor
(
batch
[
0
][
'label'
])
paddle
.
to_tensor
(
batch
[
0
][
'label'
])
...
@@ -38,43 +35,43 @@ def train_epoch(trainer, epoch_id, print_batch_step):
...
@@ -38,43 +35,43 @@ def train_epoch(trainer, epoch_id, print_batch_step):
batch_size
=
batch
[
0
].
shape
[
0
]
batch_size
=
batch
[
0
].
shape
[
0
]
batch
[
1
]
=
batch
[
1
].
reshape
([
-
1
,
1
]).
astype
(
"int64"
)
batch
[
1
]
=
batch
[
1
].
reshape
([
-
1
,
1
]).
astype
(
"int64"
)
trainer
.
global_step
+=
1
engine
.
global_step
+=
1
# image input
# image input
if
trainer
.
amp
:
if
engine
.
amp
:
with
paddle
.
amp
.
auto_cast
(
custom_black_list
=
{
with
paddle
.
amp
.
auto_cast
(
custom_black_list
=
{
"flatten_contiguous_range"
,
"greater_than"
"flatten_contiguous_range"
,
"greater_than"
}):
}):
out
=
forward
(
trainer
,
batch
)
out
=
forward
(
engine
,
batch
)
loss_dict
=
trainer
.
train_loss_func
(
out
,
batch
[
1
])
loss_dict
=
engine
.
train_loss_func
(
out
,
batch
[
1
])
else
:
else
:
out
=
forward
(
trainer
,
batch
)
out
=
forward
(
engine
,
batch
)
# calc loss
# calc loss
if
trainer
.
config
[
"DataLoader"
][
"Train"
][
"dataset"
].
get
(
if
engine
.
config
[
"DataLoader"
][
"Train"
][
"dataset"
].
get
(
"batch_transform_ops"
,
None
):
"batch_transform_ops"
,
None
):
loss_dict
=
trainer
.
train_loss_func
(
out
,
batch
[
1
:])
loss_dict
=
engine
.
train_loss_func
(
out
,
batch
[
1
:])
else
:
else
:
loss_dict
=
trainer
.
train_loss_func
(
out
,
batch
[
1
])
loss_dict
=
engine
.
train_loss_func
(
out
,
batch
[
1
])
# step opt and lr
# step opt and lr
if
trainer
.
amp
:
if
engine
.
amp
:
scaled
=
trainer
.
scaler
.
scale
(
loss_dict
[
"loss"
])
scaled
=
engine
.
scaler
.
scale
(
loss_dict
[
"loss"
])
scaled
.
backward
()
scaled
.
backward
()
trainer
.
scaler
.
minimize
(
trainer
.
optimizer
,
scaled
)
engine
.
scaler
.
minimize
(
engine
.
optimizer
,
scaled
)
else
:
else
:
loss_dict
[
"loss"
].
backward
()
loss_dict
[
"loss"
].
backward
()
trainer
.
optimizer
.
step
()
engine
.
optimizer
.
step
()
trainer
.
optimizer
.
clear_grad
()
engine
.
optimizer
.
clear_grad
()
trainer
.
lr_sch
.
step
()
engine
.
lr_sch
.
step
()
# below code just for logging
# below code just for logging
# update metric_for_logger
# update metric_for_logger
update_metric
(
trainer
,
out
,
batch
,
batch_size
)
update_metric
(
engine
,
out
,
batch
,
batch_size
)
# update_loss_for_logger
# update_loss_for_logger
update_loss
(
trainer
,
loss_dict
,
batch_size
)
update_loss
(
engine
,
loss_dict
,
batch_size
)
trainer
.
time_info
[
"batch_cost"
].
update
(
time
.
time
()
-
tic
)
engine
.
time_info
[
"batch_cost"
].
update
(
time
.
time
()
-
tic
)
if
iter_id
%
print_batch_step
==
0
:
if
iter_id
%
print_batch_step
==
0
:
log_info
(
trainer
,
batch_size
,
epoch_id
,
iter_id
)
log_info
(
engine
,
batch_size
,
epoch_id
,
iter_id
)
tic
=
time
.
time
()
tic
=
time
.
time
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录