Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
934de965
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 1 年 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
934de965
编写于
6月 05, 2021
作者:
W
weishengyu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
rename cam -> unique
上级
16718f08
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
23 addition
and
23 deletion
+23
-23
ppcls/engine/trainer.py
ppcls/engine/trainer.py
+23
-23
未找到文件。
ppcls/engine/trainer.py
浏览文件 @
934de965
...
@@ -329,22 +329,22 @@ class Trainer(object):
...
@@ -329,22 +329,22 @@ class Trainer(object):
self
.
model
.
eval
()
self
.
model
.
eval
()
cum_similarity_matrix
=
None
cum_similarity_matrix
=
None
# step1. build gallery
# step1. build gallery
gallery_feas
,
gallery_img_id
,
gallery_
camera
_id
=
self
.
_cal_feature
(
gallery_feas
,
gallery_img_id
,
gallery_
unique
_id
=
self
.
_cal_feature
(
name
=
'gallery'
)
name
=
'gallery'
)
query_feas
,
query_img_id
,
query_
camera
_id
=
self
.
_cal_feature
(
query_feas
,
query_img_id
,
query_
query
_id
=
self
.
_cal_feature
(
name
=
'query'
)
name
=
'query'
)
gallery_img_id
=
gallery_img_id
gallery_img_id
=
gallery_img_id
# if gallery_
camera
_id is not None:
# if gallery_
unique
_id is not None:
# gallery_
camera_id = gallery_camera
_id
# gallery_
unique_id = gallery_unique
_id
# step2. do evaluation
# step2. do evaluation
sim_block_size
=
self
.
config
[
"Global"
].
get
(
"sim_block_size"
,
64
)
sim_block_size
=
self
.
config
[
"Global"
].
get
(
"sim_block_size"
,
64
)
sections
=
[
sim_block_size
]
*
(
len
(
query_feas
)
//
sim_block_size
)
sections
=
[
sim_block_size
]
*
(
len
(
query_feas
)
//
sim_block_size
)
if
len
(
query_feas
)
%
sim_block_size
:
if
len
(
query_feas
)
%
sim_block_size
:
sections
.
append
(
len
(
query_feas
)
%
sim_block_size
)
sections
.
append
(
len
(
query_feas
)
%
sim_block_size
)
fea_blocks
=
paddle
.
split
(
query_feas
,
num_or_sections
=
sections
)
fea_blocks
=
paddle
.
split
(
query_feas
,
num_or_sections
=
sections
)
if
query_
camera
_id
is
not
None
:
if
query_
query
_id
is
not
None
:
camera
_id_blocks
=
paddle
.
split
(
query
_id_blocks
=
paddle
.
split
(
query_
camera
_id
,
num_or_sections
=
sections
)
query_
query
_id
,
num_or_sections
=
sections
)
image_id_blocks
=
paddle
.
split
(
image_id_blocks
=
paddle
.
split
(
query_img_id
,
num_or_sections
=
sections
)
query_img_id
,
num_or_sections
=
sections
)
metric_key
=
None
metric_key
=
None
...
@@ -352,14 +352,14 @@ class Trainer(object):
...
@@ -352,14 +352,14 @@ class Trainer(object):
for
block_idx
,
block_fea
in
enumerate
(
fea_blocks
):
for
block_idx
,
block_fea
in
enumerate
(
fea_blocks
):
similarity_matrix
=
paddle
.
matmul
(
similarity_matrix
=
paddle
.
matmul
(
block_fea
,
gallery_feas
,
transpose_y
=
True
)
block_fea
,
gallery_feas
,
transpose_y
=
True
)
if
query_
camera
_id
is
not
None
:
if
query_
query
_id
is
not
None
:
camera_id_block
=
camera
_id_blocks
[
block_idx
]
query_id_block
=
query
_id_blocks
[
block_idx
]
camera_id_mask
=
(
camera_id_block
!=
gallery_camera
_id
.
t
())
query_id_mask
=
(
query_id_block
!=
gallery_unique
_id
.
t
())
image_id_block
=
image_id_blocks
[
block_idx
]
image_id_block
=
image_id_blocks
[
block_idx
]
image_id_mask
=
(
image_id_block
!=
gallery_img_id
.
t
())
image_id_mask
=
(
image_id_block
!=
gallery_img_id
.
t
())
keep_mask
=
paddle
.
logical_or
(
camera
_id_mask
,
image_id_mask
)
keep_mask
=
paddle
.
logical_or
(
query
_id_mask
,
image_id_mask
)
similarity_matrix
=
similarity_matrix
*
keep_mask
.
astype
(
similarity_matrix
=
similarity_matrix
*
keep_mask
.
astype
(
"float32"
)
"float32"
)
if
cum_similarity_matrix
is
None
:
if
cum_similarity_matrix
is
None
:
...
@@ -388,7 +388,7 @@ class Trainer(object):
...
@@ -388,7 +388,7 @@ class Trainer(object):
def
_cal_feature
(
self
,
name
=
'gallery'
):
def
_cal_feature
(
self
,
name
=
'gallery'
):
all_feas
=
None
all_feas
=
None
all_image_id
=
None
all_image_id
=
None
all_
camera
_id
=
None
all_
unique
_id
=
None
if
name
==
'gallery'
:
if
name
==
'gallery'
:
dataloader
=
self
.
gallery_dataloader
dataloader
=
self
.
gallery_dataloader
elif
name
==
'query'
:
elif
name
==
'query'
:
...
@@ -396,13 +396,13 @@ class Trainer(object):
...
@@ -396,13 +396,13 @@ class Trainer(object):
else
:
else
:
raise
RuntimeError
(
"Only support gallery or query dataset"
)
raise
RuntimeError
(
"Only support gallery or query dataset"
)
has_
cam
_id
=
False
has_
unique
_id
=
False
for
idx
,
batch
in
enumerate
(
dataloader
(
for
idx
,
batch
in
enumerate
(
dataloader
(
)):
# load is very time-consuming
)):
# load is very time-consuming
batch
=
[
paddle
.
to_tensor
(
x
)
for
x
in
batch
]
batch
=
[
paddle
.
to_tensor
(
x
)
for
x
in
batch
]
batch
[
1
]
=
batch
[
1
].
reshape
([
-
1
,
1
])
batch
[
1
]
=
batch
[
1
].
reshape
([
-
1
,
1
])
if
len
(
batch
)
==
3
:
if
len
(
batch
)
==
3
:
has_
cam
_id
=
True
has_
unique
_id
=
True
batch
[
2
]
=
batch
[
2
].
reshape
([
-
1
,
1
])
batch
[
2
]
=
batch
[
2
].
reshape
([
-
1
,
1
])
out
=
self
.
model
(
batch
[
0
],
batch
[
1
])
out
=
self
.
model
(
batch
[
0
],
batch
[
1
])
batch_feas
=
out
[
"features"
]
batch_feas
=
out
[
"features"
]
...
@@ -416,30 +416,30 @@ class Trainer(object):
...
@@ -416,30 +416,30 @@ class Trainer(object):
if
all_feas
is
None
:
if
all_feas
is
None
:
all_feas
=
batch_feas
all_feas
=
batch_feas
if
has_
cam
_id
:
if
has_
unique
_id
:
all_
camera
_id
=
batch
[
2
]
all_
unique
_id
=
batch
[
2
]
all_image_id
=
batch
[
1
]
all_image_id
=
batch
[
1
]
else
:
else
:
all_feas
=
paddle
.
concat
([
all_feas
,
batch_feas
])
all_feas
=
paddle
.
concat
([
all_feas
,
batch_feas
])
all_image_id
=
paddle
.
concat
([
all_image_id
,
batch
[
1
]])
all_image_id
=
paddle
.
concat
([
all_image_id
,
batch
[
1
]])
if
has_
cam
_id
:
if
has_
unique
_id
:
all_
camera_id
=
paddle
.
concat
([
all_camera
_id
,
batch
[
2
]])
all_
unique_id
=
paddle
.
concat
([
all_unique
_id
,
batch
[
2
]])
if
paddle
.
distributed
.
get_world_size
()
>
1
:
if
paddle
.
distributed
.
get_world_size
()
>
1
:
feat_list
=
[]
feat_list
=
[]
img_id_list
=
[]
img_id_list
=
[]
cam
_id_list
=
[]
unique
_id_list
=
[]
paddle
.
distributed
.
all_gather
(
feat_list
,
all_feas
)
paddle
.
distributed
.
all_gather
(
feat_list
,
all_feas
)
paddle
.
distributed
.
all_gather
(
img_id_list
,
all_image_id
)
paddle
.
distributed
.
all_gather
(
img_id_list
,
all_image_id
)
all_feas
=
paddle
.
concat
(
feat_list
,
axis
=
0
)
all_feas
=
paddle
.
concat
(
feat_list
,
axis
=
0
)
all_image_id
=
paddle
.
concat
(
img_id_list
,
axis
=
0
)
all_image_id
=
paddle
.
concat
(
img_id_list
,
axis
=
0
)
if
has_
cam
_id
:
if
has_
unique
_id
:
paddle
.
distributed
.
all_gather
(
cam_id_list
,
all_camera
_id
)
paddle
.
distributed
.
all_gather
(
unique_id_list
,
all_unique
_id
)
all_
camera_id
=
paddle
.
concat
(
cam
_id_list
,
axis
=
0
)
all_
unique_id
=
paddle
.
concat
(
unique
_id_list
,
axis
=
0
)
logger
.
info
(
"Build {} done, all feat shape: {}, begin to eval.."
.
logger
.
info
(
"Build {} done, all feat shape: {}, begin to eval.."
.
format
(
name
,
all_feas
.
shape
))
format
(
name
,
all_feas
.
shape
))
return
all_feas
,
all_image_id
,
all_
camera
_id
return
all_feas
,
all_image_id
,
all_
unique
_id
@
paddle
.
no_grad
()
@
paddle
.
no_grad
()
def
infer
(
self
,
):
def
infer
(
self
,
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录