Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
8002ccf4
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 2 年 前同步成功
通知
118
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8002ccf4
编写于
3月 14, 2023
作者:
T
Tingquan Gao
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Revert "support ShiTu"
This reverts commit
9beb154b
.
上级
78652070
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
184 addition
and
205 deletion
+184
-205
ppcls/data/__init__.py
ppcls/data/__init__.py
+8
-9
ppcls/engine/engine.py
ppcls/engine/engine.py
+0
-1
ppcls/engine/evaluation/__init__.py
ppcls/engine/evaluation/__init__.py
+6
-6
ppcls/engine/evaluation/retrieval.py
ppcls/engine/evaluation/retrieval.py
+166
-186
ppcls/engine/train/__init__.py
ppcls/engine/train/__init__.py
+3
-2
ppcls/engine/train/train_progressive.py
ppcls/engine/train/train_progressive.py
+1
-1
未找到文件。
ppcls/data/__init__.py
浏览文件 @
8002ccf4
...
@@ -88,15 +88,14 @@ def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
...
@@ -88,15 +88,14 @@ def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
random
.
seed
(
worker_seed
)
random
.
seed
(
worker_seed
)
def
build_dataloader
(
config
,
*
mode
,
seed
=
None
):
def
build_dataloader
(
config
,
mode
,
seed
=
None
):
dataloader_config
=
config
[
"DataLoader"
]
assert
mode
in
[
for
m
in
mode
:
assert
m
in
[
'Train'
,
'Eval'
,
'Test'
,
'Gallery'
,
'Query'
,
'UnLabelTrain'
'Train'
,
'Eval'
,
'Test'
,
'Gallery'
,
'Query'
,
'UnLabelTrain'
],
"Dataset mode should be Train, Eval, Test, Gallery, Query, UnLabelTrain"
],
"Dataset mode should be Train, Eval, Test, Gallery, Query, UnLabelTrain"
assert
m
in
dataloader_config
.
keys
(),
"{} config not in yaml"
.
format
(
m
)
assert
mode
in
config
[
"DataLoader"
].
keys
(),
"{} config not in yaml"
.
format
(
dataloader_config
=
dataloader_config
[
m
]
mode
)
dataloader_config
=
config
[
"DataLoader"
][
mode
]
class_num
=
config
[
"Arch"
].
get
(
"class_num"
,
None
)
class_num
=
config
[
"Arch"
].
get
(
"class_num"
,
None
)
epochs
=
config
[
"Global"
][
"epochs"
]
epochs
=
config
[
"Global"
][
"epochs"
]
use_dali
=
config
[
"Global"
].
get
(
"use_dali"
,
False
)
use_dali
=
config
[
"Global"
].
get
(
"use_dali"
,
False
)
...
...
ppcls/engine/engine.py
浏览文件 @
8002ccf4
...
@@ -22,7 +22,6 @@ from paddle import nn
...
@@ -22,7 +22,6 @@ from paddle import nn
import
numpy
as
np
import
numpy
as
np
import
random
import
random
from
..utils.amp
import
AMPForwardDecorator
from
ppcls.utils
import
logger
from
ppcls.utils
import
logger
from
ppcls.utils.logger
import
init_logger
from
ppcls.utils.logger
import
init_logger
from
ppcls.utils.config
import
print_config
from
ppcls.utils.config
import
print_config
...
...
ppcls/engine/evaluation/__init__.py
浏览文件 @
8002ccf4
...
@@ -13,17 +13,17 @@
...
@@ -13,17 +13,17 @@
# limitations under the License.
# limitations under the License.
from
.classification
import
ClassEval
from
.classification
import
ClassEval
from
.retrieval
import
RetrievalE
val
from
.retrieval
import
retrieval_e
val
from
.adaface
import
adaface_eval
from
.adaface
import
adaface_eval
def
build_eval_func
(
config
,
mode
,
model
):
def
build_eval_func
(
config
,
mode
,
model
):
if
mode
not
in
[
"eval"
,
"train"
]:
if
mode
not
in
[
"eval"
,
"train"
]:
return
None
return
None
task
=
config
[
"Global"
].
get
(
"task"
,
"classification"
)
eval_mode
=
config
[
"Global"
].
get
(
"eval_mode"
,
None
)
if
task
==
"classification"
:
if
eval_mode
is
None
:
config
[
"Global"
][
"eval_mode"
]
=
"classification"
return
ClassEval
(
config
,
mode
,
model
)
return
ClassEval
(
config
,
mode
,
model
)
elif
task
==
"retrieval"
:
return
RetrievalEval
(
config
,
mode
,
model
)
else
:
else
:
raise
Exception
()
return
getattr
(
sys
.
modules
[
__name__
],
eval_mode
+
"_eval"
)(
config
,
mode
,
model
)
ppcls/engine/evaluation/retrieval.py
浏览文件 @
8002ccf4
...
@@ -21,50 +21,25 @@ import numpy as np
...
@@ -21,50 +21,25 @@ import numpy as np
import
paddle
import
paddle
import
scipy
import
scipy
from
...utils.misc
import
AverageMeter
from
ppcls.utils
import
all_gather
,
logger
from
...utils
import
all_gather
,
logger
from
...data
import
build_dataloader
from
...loss
import
build_loss
from
...metric
import
build_metrics
class
RetrievalEval
(
object
):
def
__init__
(
self
,
config
,
mode
,
model
):
self
.
config
=
config
self
.
model
=
model
self
.
print_batch_step
=
self
.
config
[
"Global"
][
"print_batch_step"
]
self
.
use_dali
=
self
.
config
[
"Global"
].
get
(
"use_dali"
,
False
)
self
.
eval_metric_func
=
build_metrics
(
self
.
config
,
"Eval"
)
self
.
eval_loss_func
=
build_loss
(
self
.
config
,
"Eval"
)
self
.
output_info
=
dict
()
self
.
gallery_query_dataloader
=
None
if
len
(
self
.
config
[
"DataLoader"
][
"Eval"
].
keys
())
==
1
:
self
.
gallery_query_dataloader
=
build_dataloader
(
self
.
config
,
"Eval"
)
else
:
self
.
gallery_dataloader
=
build_dataloader
(
self
.
config
,
"Eval"
,
"Gallery"
)
self
.
query_dataloader
=
build_dataloader
(
self
.
config
,
"Eval"
,
"Query"
)
def
__call__
(
self
,
epoch_id
=
0
):
self
.
model
.
eval
()
def
retrieval_eval
(
engine
,
epoch_id
=
0
):
engine
.
model
.
eval
()
# step1. prepare query and gallery features
# step1. prepare query and gallery features
if
self
.
gallery_query_dataloader
is
not
None
:
if
engine
.
gallery_query_dataloader
is
not
None
:
gallery_feat
,
gallery_label
,
gallery_camera
=
self
.
compute_feature
(
gallery_feat
,
gallery_label
,
gallery_camera
=
compute_feature
(
"gallery_query"
)
engine
,
"gallery_query"
)
query_feat
,
query_label
,
query_camera
=
gallery_feat
,
gallery_label
,
gallery_camera
query_feat
,
query_label
,
query_camera
=
gallery_feat
,
gallery_label
,
gallery_camera
else
:
else
:
gallery_feat
,
gallery_label
,
gallery_camera
=
self
.
compute_feature
(
gallery_feat
,
gallery_label
,
gallery_camera
=
compute_feature
(
"gallery"
)
engine
,
"gallery"
)
query_feat
,
query_label
,
query_camera
=
self
.
compute_feature
(
query_feat
,
query_label
,
query_camera
=
compute_feature
(
engine
,
"query"
)
"query"
)
# step2. split features into feature blocks for saving memory
# step2. split features into feature blocks for saving memory
num_query
=
len
(
query_feat
)
num_query
=
len
(
query_feat
)
block_size
=
self
.
config
[
"Global"
].
get
(
"sim_block_size"
,
64
)
block_size
=
engine
.
config
[
"Global"
].
get
(
"sim_block_size"
,
64
)
sections
=
[
block_size
]
*
(
num_query
//
block_size
)
sections
=
[
block_size
]
*
(
num_query
//
block_size
)
if
num_query
%
block_size
>
0
:
if
num_query
%
block_size
>
0
:
sections
.
append
(
num_query
%
block_size
)
sections
.
append
(
num_query
%
block_size
)
...
@@ -76,15 +51,15 @@ class RetrievalEval(object):
...
@@ -76,15 +51,15 @@ class RetrievalEval(object):
metric_key
=
None
metric_key
=
None
# step3. compute metric
# step3. compute metric
if
self
.
eval_loss_func
is
None
:
if
engine
.
eval_loss_func
is
None
:
metric_dict
=
{
metric_key
:
0.0
}
metric_dict
=
{
metric_key
:
0.0
}
else
:
else
:
use_reranking
=
self
.
config
[
"Global"
].
get
(
"re_ranking"
,
False
)
use_reranking
=
engine
.
config
[
"Global"
].
get
(
"re_ranking"
,
False
)
logger
.
info
(
f
"re_ranking=
{
use_reranking
}
"
)
logger
.
info
(
f
"re_ranking=
{
use_reranking
}
"
)
if
use_reranking
:
if
use_reranking
:
# compute distance matrix
# compute distance matrix
distmat
=
compute_re_ranking_dist
(
distmat
=
compute_re_ranking_dist
(
query_feat
,
gallery_feat
,
self
.
config
[
"Global"
].
get
(
query_feat
,
gallery_feat
,
engine
.
config
[
"Global"
].
get
(
"feature_normalize"
,
True
),
20
,
6
,
0.3
)
"feature_normalize"
,
True
),
20
,
6
,
0.3
)
# exclude illegal distance
# exclude illegal distance
if
query_camera
is
not
None
:
if
query_camera
is
not
None
:
...
@@ -92,12 +67,11 @@ class RetrievalEval(object):
...
@@ -92,12 +67,11 @@ class RetrievalEval(object):
label_mask
=
query_label
!=
gallery_label
.
t
()
label_mask
=
query_label
!=
gallery_label
.
t
()
keep_mask
=
label_mask
|
camera_mask
keep_mask
=
label_mask
|
camera_mask
distmat
=
keep_mask
.
astype
(
query_feat
.
dtype
)
*
distmat
+
(
distmat
=
keep_mask
.
astype
(
query_feat
.
dtype
)
*
distmat
+
(
~
keep_mask
).
astype
(
query_feat
.
dtype
)
*
(
distmat
.
max
()
+
~
keep_mask
).
astype
(
query_feat
.
dtype
)
*
(
distmat
.
max
()
+
1
)
1
)
else
:
else
:
keep_mask
=
None
keep_mask
=
None
# compute metric with all samples
# compute metric with all samples
metric_dict
=
self
.
eval_metric_func
(
-
distmat
,
query_label
,
metric_dict
=
engine
.
eval_metric_func
(
-
distmat
,
query_label
,
gallery_label
,
keep_mask
)
gallery_label
,
keep_mask
)
else
:
else
:
metric_dict
=
defaultdict
(
float
)
metric_dict
=
defaultdict
(
float
)
...
@@ -116,13 +90,13 @@ class RetrievalEval(object):
...
@@ -116,13 +90,13 @@ class RetrievalEval(object):
else
:
else
:
keep_mask
=
None
keep_mask
=
None
# compute metric by block
# compute metric by block
metric_block
=
self
.
eval_metric_func
(
metric_block
=
engine
.
eval_metric_func
(
distmat
,
query_label_blocks
[
block_idx
],
gallery_label
,
distmat
,
query_label_blocks
[
block_idx
],
gallery_label
,
keep_mask
)
keep_mask
)
# accumulate metric
# accumulate metric
for
key
in
metric_block
:
for
key
in
metric_block
:
metric_dict
[
key
]
+=
metric_block
[
metric_dict
[
key
]
+=
metric_block
[
key
]
*
block_feat
.
shape
[
key
]
*
block_feat
.
shape
[
0
]
/
num_query
0
]
/
num_query
metric_info_list
=
[]
metric_info_list
=
[]
for
key
,
value
in
metric_dict
.
items
():
for
key
,
value
in
metric_dict
.
items
():
...
@@ -134,13 +108,14 @@ class RetrievalEval(object):
...
@@ -134,13 +108,14 @@ class RetrievalEval(object):
return
metric_dict
[
metric_key
]
return
metric_dict
[
metric_key
]
def
compute_feature
(
self
,
name
=
"gallery"
):
def
compute_feature
(
engine
,
name
=
"gallery"
):
if
name
==
"gallery"
:
if
name
==
"gallery"
:
dataloader
=
self
.
gallery_dataloader
dataloader
=
engine
.
gallery_dataloader
elif
name
==
"query"
:
elif
name
==
"query"
:
dataloader
=
self
.
query_dataloader
dataloader
=
engine
.
query_dataloader
elif
name
==
"gallery_query"
:
elif
name
==
"gallery_query"
:
dataloader
=
self
.
gallery_query_dataloader
dataloader
=
engine
.
gallery_query_dataloader
else
:
else
:
raise
ValueError
(
raise
ValueError
(
f
"Only support gallery or query or gallery_query dataset, but got
{
name
}
"
f
"Only support gallery or query or gallery_query dataset, but got
{
name
}
"
...
@@ -151,7 +126,7 @@ class RetrievalEval(object):
...
@@ -151,7 +126,7 @@ class RetrievalEval(object):
all_camera
=
[]
all_camera
=
[]
has_camera
=
False
has_camera
=
False
for
idx
,
batch
in
enumerate
(
dataloader
):
# load is very time-consuming
for
idx
,
batch
in
enumerate
(
dataloader
):
# load is very time-consuming
if
idx
%
self
.
print_batch_step
==
0
:
if
idx
%
engine
.
config
[
"Global"
][
"print_batch_step"
]
==
0
:
logger
.
info
(
logger
.
info
(
f
"
{
name
}
feature calculation process: [
{
idx
}
/
{
len
(
dataloader
)
}
]"
f
"
{
name
}
feature calculation process: [
{
idx
}
/
{
len
(
dataloader
)
}
]"
)
)
...
@@ -161,14 +136,20 @@ class RetrievalEval(object):
...
@@ -161,14 +136,20 @@ class RetrievalEval(object):
if
len
(
batch
)
>=
3
:
if
len
(
batch
)
>=
3
:
has_camera
=
True
has_camera
=
True
batch
[
2
]
=
batch
[
2
].
reshape
([
-
1
,
1
]).
astype
(
"int64"
)
batch
[
2
]
=
batch
[
2
].
reshape
([
-
1
,
1
]).
astype
(
"int64"
)
if
engine
.
amp
and
engine
.
amp_eval
:
out
=
self
.
model
(
batch
)
with
paddle
.
amp
.
auto_cast
(
custom_black_list
=
{
"flatten_contiguous_range"
,
"greater_than"
},
level
=
engine
.
amp_level
):
out
=
engine
.
model
(
batch
[
0
])
else
:
out
=
engine
.
model
(
batch
[
0
])
if
"Student"
in
out
:
if
"Student"
in
out
:
out
=
out
[
"Student"
]
out
=
out
[
"Student"
]
# get features
# get features
if
self
.
config
[
"Global"
].
get
(
"retrieval_feature_from"
,
if
engine
.
config
[
"Global"
].
get
(
"retrieval_feature_from"
,
"features"
)
==
"features"
:
"features"
)
==
"features"
:
# use output from neck as feature
# use output from neck as feature
batch_feat
=
out
[
"features"
]
batch_feat
=
out
[
"features"
]
...
@@ -177,14 +158,13 @@ class RetrievalEval(object):
...
@@ -177,14 +158,13 @@ class RetrievalEval(object):
batch_feat
=
out
[
"backbone"
]
batch_feat
=
out
[
"backbone"
]
# do norm(optional)
# do norm(optional)
if
self
.
config
[
"Global"
].
get
(
"feature_normalize"
,
True
):
if
engine
.
config
[
"Global"
].
get
(
"feature_normalize"
,
True
):
batch_feat
=
paddle
.
nn
.
functional
.
normalize
(
batch_feat
,
p
=
2
)
batch_feat
=
paddle
.
nn
.
functional
.
normalize
(
batch_feat
,
p
=
2
)
# do binarize(optional)
# do binarize(optional)
if
self
.
config
[
"Global"
].
get
(
"feature_binarize"
)
==
"round"
:
if
engine
.
config
[
"Global"
].
get
(
"feature_binarize"
)
==
"round"
:
batch_feat
=
paddle
.
round
(
batch_feat
).
astype
(
batch_feat
=
paddle
.
round
(
batch_feat
).
astype
(
"float32"
)
*
2.0
-
1.0
"float32"
)
*
2.0
-
1.0
elif
engine
.
config
[
"Global"
].
get
(
"feature_binarize"
)
==
"sign"
:
elif
self
.
config
[
"Global"
].
get
(
"feature_binarize"
)
==
"sign"
:
batch_feat
=
paddle
.
sign
(
batch_feat
).
astype
(
"float32"
)
batch_feat
=
paddle
.
sign
(
batch_feat
).
astype
(
"float32"
)
if
paddle
.
distributed
.
get_world_size
()
>
1
:
if
paddle
.
distributed
.
get_world_size
()
>
1
:
...
@@ -198,7 +178,7 @@ class RetrievalEval(object):
...
@@ -198,7 +178,7 @@ class RetrievalEval(object):
if
has_camera
:
if
has_camera
:
all_camera
.
append
(
batch
[
2
])
all_camera
.
append
(
batch
[
2
])
if
self
.
use_dali
:
if
engine
.
use_dali
:
dataloader
.
reset
()
dataloader
.
reset
()
all_feat
=
paddle
.
concat
(
all_feat
)
all_feat
=
paddle
.
concat
(
all_feat
)
...
@@ -208,7 +188,7 @@ class RetrievalEval(object):
...
@@ -208,7 +188,7 @@ class RetrievalEval(object):
else
:
else
:
all_camera
=
None
all_camera
=
None
# discard redundant padding sample(s) at the end
# discard redundant padding sample(s) at the end
total_samples
=
dataloader
.
size
if
self
.
use_dali
else
len
(
total_samples
=
dataloader
.
size
if
engine
.
use_dali
else
len
(
dataloader
.
dataset
)
dataloader
.
dataset
)
all_feat
=
all_feat
[:
total_samples
]
all_feat
=
all_feat
[:
total_samples
]
all_label
=
all_label
[:
total_samples
]
all_label
=
all_label
[:
total_samples
]
...
...
ppcls/engine/train/__init__.py
浏览文件 @
8002ccf4
...
@@ -22,8 +22,9 @@ from .train_progressive import train_epoch_progressive
...
@@ -22,8 +22,9 @@ from .train_progressive import train_epoch_progressive
def
build_train_func
(
config
,
mode
,
model
,
eval_func
):
def
build_train_func
(
config
,
mode
,
model
,
eval_func
):
if
mode
!=
"train"
:
if
mode
!=
"train"
:
return
None
return
None
task
=
config
[
"Global"
].
get
(
"task"
,
"classification"
)
train_mode
=
config
[
"Global"
].
get
(
"task"
,
None
)
if
task
==
"classification"
or
task
==
"retrieval"
:
if
train_mode
is
None
:
config
[
"Global"
][
"task"
]
=
"classification"
return
ClassTrainer
(
config
,
model
,
eval_func
)
return
ClassTrainer
(
config
,
model
,
eval_func
)
else
:
else
:
return
getattr
(
sys
.
modules
[
__name__
],
"train_epoch_"
+
train_mode
)(
return
getattr
(
sys
.
modules
[
__name__
],
"train_epoch_"
+
train_mode
)(
...
...
ppcls/engine/train/train_progressive.py
浏览文件 @
8002ccf4
...
@@ -15,7 +15,7 @@ from __future__ import absolute_import, division, print_function
...
@@ -15,7 +15,7 @@ from __future__ import absolute_import, division, print_function
from
ppcls.data
import
build_dataloader
from
ppcls.data
import
build_dataloader
from
ppcls.utils
import
logger
,
type_name
from
ppcls.utils
import
logger
,
type_name
from
.
classification
import
ClassTrainer
from
.
regular_train_epoch
import
regular_train_epoch
def
train_epoch_progressive
(
engine
,
epoch_id
,
print_batch_step
):
def
train_epoch_progressive
(
engine
,
epoch_id
,
print_batch_step
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录