Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
b56237dc
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 1 年 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b56237dc
编写于
6月 04, 2021
作者:
W
weishengyu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update trainer
上级
ed098b3c
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
48 addition
and
48 deletion
+48
-48
ppcls/engine/trainer.py
ppcls/engine/trainer.py
+48
-48
未找到文件。
ppcls/engine/trainer.py
浏览文件 @
b56237dc
...
...
@@ -158,7 +158,6 @@ class Trainer(object):
for
epoch_id
in
range
(
best_metric
[
"epoch"
]
+
1
,
self
.
config
[
"Global"
][
"epochs"
]
+
1
):
acc
=
0.0
self
.
model
.
train
()
for
iter_id
,
batch
in
enumerate
(
self
.
train_dataloader
()):
batch_size
=
batch
[
0
].
shape
[
0
]
batch
[
1
]
=
paddle
.
to_tensor
(
batch
[
1
].
numpy
().
astype
(
"int64"
)
...
...
@@ -241,34 +240,34 @@ class Trainer(object):
@
paddle
.
no_grad
()
def
eval
(
self
,
epoch_id
=
0
):
if
self
.
eval_dataloader
is
None
:
self
.
eval_dataloader
=
build_dataloader
(
self
.
config
[
"DataLoader"
],
"Eval"
,
self
.
device
)
if
self
.
gallery_dataloader
is
None
:
self
.
gallery_dataloader
=
build_dataloader
(
self
.
config
[
"DataLoader"
],
"Gallery"
,
self
.
device
)
if
self
.
query_dataloader
is
None
:
self
.
query_dataloader
=
build_dataloader
(
self
.
config
[
"DataLoader"
],
"Query"
,
self
.
device
)
# build train loss and metric info
if
self
.
eval_loss_func
is
None
:
self
.
eval_loss_func
=
self
.
_build_loss_info
(
self
.
config
[
"Loss"
],
"eval"
)
if
self
.
eval_metric_func
is
None
:
self
.
eval_metric_func
=
self
.
_build_metric_info
(
self
.
config
[
"Metric"
],
"eval"
)
self
.
model
.
eval
()
if
self
.
eval_mode
==
"classification"
:
self
.
eval_cls
(
epoch_id
)
if
self
.
eval_dataloader
is
None
:
self
.
eval_dataloader
=
build_dataloader
(
self
.
config
[
"DataLoader"
],
"Eval"
,
self
.
device
)
eval_result
=
self
.
eval_cls
(
epoch_id
)
elif
self
.
eval_mode
==
"retrieval"
:
self
.
eval_retrieval
(
epoch_id
)
if
self
.
gallery_dataloader
is
None
:
self
.
gallery_dataloader
=
build_dataloader
(
self
.
config
[
"DataLoader"
],
"Gallery"
,
self
.
device
)
if
self
.
query_dataloader
is
None
:
self
.
query_dataloader
=
build_dataloader
(
self
.
config
[
"DataLoader"
],
"Query"
,
self
.
device
)
# build train loss and metric info
if
self
.
eval_loss_func
is
None
:
self
.
eval_loss_func
=
self
.
_build_loss_info
(
self
.
config
[
"Loss"
],
"eval"
)
if
self
.
eval_metric_func
is
None
:
self
.
eval_metric_func
=
self
.
_build_metric_info
(
self
.
config
[
"Metric"
],
"eval"
)
eval_result
=
self
.
eval_retrieval
(
epoch_id
)
else
:
logger
.
warning
(
"Invalid eval mode: {}"
.
format
(
self
.
eval_mode
))
eval_result
=
None
self
.
model
.
train
()
return
eval_result
def
eval_cls
(
self
,
epoch_id
=
0
):
output_info
=
dict
()
...
...
@@ -332,9 +331,8 @@ class Trainer(object):
return
output_info
[
metric_key
].
avg
def
eval_retrieval
(
self
,
epoch_id
=
0
):
output_info
=
dict
()
self
.
model
.
eval
()
cum_similarity_matrix
=
None
# step1. build gallery
gallery_feas
,
gallery_img_id
,
gallery_camera_id
=
self
.
_cal_feature
(
name
=
'gallery'
)
...
...
@@ -342,7 +340,7 @@ class Trainer(object):
name
=
'query'
)
gallery_img_id
=
paddle
.
to_tensor
([
gallery_img_id
]).
t
()
if
gallery_camera_id
is
not
None
:
gallery_camera_id
=
paddle
.
to_tensor
(
gallery_camera_id
).
t
()
gallery_camera_id
=
paddle
.
to_tensor
(
[
gallery_camera_id
]
).
t
()
query_img_id
=
paddle
.
to_tensor
(
query_img_id
)
if
query_camera_id
is
not
None
:
query_camera_id
=
paddle
.
to_tensor
(
query_camera_id
)
...
...
@@ -352,35 +350,37 @@ class Trainer(object):
if
not
len
(
query_feas
)
%
sim_block_size
:
sections
.
append
(
len
(
query_feas
)
%
sim_block_size
)
fea_blocks
=
paddle
.
split
(
query_feas
,
num_or_sections
=
sections
)
camera_id_blocks
=
paddle
.
split
(
query_camera_id
,
num_or_sections
=
sections
)
if
query_camera_id
is
not
None
:
camera_id_blocks
=
paddle
.
split
(
query_camera_id
,
num_or_sections
=
sections
)
image_id_blocks
=
paddle
.
split
(
query_img_id
,
num_or_sections
=
sections
)
metric_key
=
None
for
block_idx
,
block_fea
in
enumerate
(
fea_blocks
):
similarit
ies
_matrix
=
paddle
.
matmul
(
similarit
y
_matrix
=
paddle
.
matmul
(
block_fea
,
gallery_feas
,
transpose_y
=
True
)
image_id_block
=
image_id_blocks
[
block_idx
]
image_id_mask
=
(
image_id_block
==
gallery_img_id
)
similarities_matrix
=
similarities_matrix
.
masked_select
(
image_id_mask
)
camera_id_block
=
camera_id_blocks
[
block_idx
]
camera_id_mask
=
(
camera_id_block
==
gallery_camera_id
)
similarities_matrix
=
similarities_matrix
.
masked_select
(
camera_id_mask
)
image_id_mask
=
(
image_id_block
!=
gallery_img_id
)
similarity_matrix
=
similarity_matrix
.
masked_select
(
image_id_mask
)
if
query_camera_id
is
not
None
:
camera_id_block
=
camera_id_blocks
[
block_idx
]
camera_id_mask
=
(
camera_id_block
!=
gallery_camera_id
)
similarity_matrix
=
similarity_matrix
.
masked_select
(
camera_id_mask
)
if
similarity_matrix
is
None
:
cum_similarity_matrix
=
similarity_matrix
else
:
cum_similarity_matrix
=
paddle
.
concat
(
cum_similarity_matrix
,
similarity_matrix
)
# calc metric
if
self
.
eval_metric_func
is
not
None
:
metric_dict
=
self
.
eval_metric_func
(
similarities_matrix
,
image_id_block
)
for
key
in
metric_dict
:
if
metric_key
is
None
:
metric_key
=
key
if
not
key
in
output_info
:
output_info
[
key
]
=
AverageMeter
(
key
,
'7.5f'
)
# calc metric
if
self
.
eval_metric_func
is
not
None
:
metric_dict
=
self
.
eval_metric_func
(
cum_similarity_matrix
,
query_img_id
,
gallery_img_id
)
else
:
metric_dict
=
{
metric_key
:
0.
}
output_info
[
key
].
update
(
metric_dict
[
key
].
numpy
()[
0
],
len
(
image_id_block
))
return
metric_dict
[
metric_key
]
def
_cal_feature
(
self
,
name
=
'gallery'
):
all_feas
=
None
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录