PaddlePaddle / PaddleClas

Commit c4f38454 (unverified)
Authored Sep 26, 2021 by Wei Shengyu; committed via GitHub on Sep 26, 2021
Parents: 9f13876e, da259314

Merge pull request #1205 from weisy11/develop

add MixDataset, MixSampler and PKSampler
Showing 11 changed files with 322 additions and 82 deletions (+322, -82):

  ppcls/configs/Logo/ResNet50_ReID.yaml             +6   -6
  ppcls/configs/Products/ResNet50_vd_Inshop.yaml    +4   -4
  ppcls/configs/Vehicle/ResNet50_ReID.yaml          +4   -4
  ppcls/data/__init__.py                            +3   -0
  ppcls/data/dataloader/__init__.py                 +9   -0
  ppcls/data/dataloader/mix_dataset.py              +49  -0
  ppcls/data/dataloader/mix_sampler.py              +79  -0
  ppcls/data/dataloader/pk_sampler.py               +106 -0
  ppcls/engine/evaluation/classification.py         +15  -17
  ppcls/engine/evaluation/retrieval.py              +22  -23
  ppcls/engine/train/train.py                       +25  -28
ppcls/configs/Logo/ResNet50_ReID.yaml

@@ -54,7 +54,7 @@ Optimizer:
   momentum: 0.9
   lr:
     name: Cosine
-    learning_rate: 0.01
+    learning_rate: 0.04
   regularizer:
     name: 'L2'
     coeff: 0.0001
@@ -84,10 +84,10 @@ DataLoader:
         - RandomErasing:
             EPSILON: 0.5
     sampler:
-      name: DistributedRandomIdentitySampler
+      name: PKSampler
       batch_size: 128
-      num_instances: 2
-      drop_last: False
+      sample_per_id: 2
+      drop_last: True
     loader:
       num_workers: 6
@@ -97,7 +97,7 @@ DataLoader:
       dataset:
         name: LogoDataset
         image_root: "dataset/LogoDet-3K-crop/val/"
-        cls_label_path: "dataset/LogoDet-3K-crop/LogoDet-3K+query.txt"
+        cls_label_path: "dataset/LogoDet-3K-crop/LogoDet-3K+val.txt"
         transform_ops:
           - DecodeImage:
               to_rgb: True
@@ -122,7 +122,7 @@ DataLoader:
       dataset:
         name: LogoDataset
         image_root: "dataset/LogoDet-3K-crop/train/"
-        cls_label_path: "dataset/LogoDet-3K-crop/LogoDet-3K+gallery.txt"
+        cls_label_path: "dataset/LogoDet-3K-crop/LogoDet-3K+train.txt"
         transform_ops:
           - DecodeImage:
               to_rgb: True
ppcls/configs/Products/ResNet50_vd_Inshop.yaml

@@ -54,7 +54,7 @@ Optimizer:
   momentum: 0.9
   lr:
     name: MultiStepDecay
-    learning_rate: 0.01
+    learning_rate: 0.04
     milestones: [30, 60, 70, 80, 90, 100]
     gamma: 0.5
     verbose: False
@@ -90,10 +90,10 @@ DataLoader:
             r1: 0.3
             mean: [0., 0., 0.]
     sampler:
-      name: DistributedRandomIdentitySampler
+      name: PKSampler
       batch_size: 64
-      num_instances: 2
-      drop_last: False
+      sample_per_id: 2
+      drop_last: True
       shuffle: True
     loader:
       num_workers: 4
ppcls/configs/Vehicle/ResNet50_ReID.yaml

@@ -53,7 +53,7 @@ Optimizer:
   momentum: 0.9
   lr:
     name: Cosine
-    learning_rate: 0.01
+    learning_rate: 0.04
   regularizer:
     name: 'L2'
     coeff: 0.0005
@@ -88,10 +88,10 @@ DataLoader:
             mean: [0., 0., 0.]
     sampler:
-      name: DistributedRandomIdentitySampler
+      name: PKSampler
       batch_size: 128
-      num_instances: 2
-      drop_last: False
+      sample_per_id: 2
+      drop_last: True
       shuffle: True
     loader:
       num_workers: 6
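All three retrieval configs make the same pair of changes: the initial learning rate rises from 0.01 to 0.04, and the sampler block swaps DistributedRandomIdentitySampler (with num_instances) for the new PKSampler (with sample_per_id). A minimal sketch of what such a sampler block constructs at runtime; train_dataset is a placeholder for whatever dataset the config builds, and must expose the labels attribute that PKSampler asserts on (see pk_sampler.py below):

from ppcls.data.dataloader.pk_sampler import PKSampler

# Rough equivalent of the new YAML sampler block; train_dataset is a
# placeholder, not a name from the commit.
sampler = PKSampler(
    dataset=train_dataset,
    batch_size=128,     # P * K examples per batch
    sample_per_id=2,    # K: instances drawn per identity
    drop_last=True)

# With batch_size=128 and sample_per_id=2, each batch covers
# 128 // 2 = 64 distinct identities (P).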
ppcls/data/__init__.py

@@ -26,9 +26,12 @@ from ppcls.data.dataloader.common_dataset import create_operators
 from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild
 from ppcls.data.dataloader.logo_dataset import LogoDataset
 from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
+from ppcls.data.dataloader.mix_dataset import MixDataset

 # sampler
 from ppcls.data.dataloader.DistributedRandomIdentitySampler import DistributedRandomIdentitySampler
+from ppcls.data.dataloader.pk_sampler import PKSampler
+from ppcls.data.dataloader.mix_sampler import MixSampler

 from ppcls.data import preprocess
 from ppcls.data.preprocess import transform
ppcls/data/dataloader/__init__.py

+from ppcls.data.dataloader.imagenet_dataset import ImageNetDataset
+from ppcls.data.dataloader.multilabel_dataset import MultiLabelDataset
+from ppcls.data.dataloader.common_dataset import create_operators
+from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild
+from ppcls.data.dataloader.logo_dataset import LogoDataset
+from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
+from ppcls.data.dataloader.mix_dataset import MixDataset
+from ppcls.data.dataloader.mix_sampler import MixSampler
+from ppcls.data.dataloader.pk_sampler import PKSampler
ppcls/data/dataloader/mix_dataset.py (new file, 0 → 100644)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import os

from paddle.io import Dataset

from .. import dataloader


class MixDataset(Dataset):
    def __init__(self, datasets_config):
        super().__init__()
        self.dataset_list = []
        start_idx = 0
        end_idx = 0
        for config_i in datasets_config:
            dataset_name = config_i.pop('name')
            dataset = getattr(dataloader, dataset_name)(**config_i)
            end_idx += len(dataset)
            self.dataset_list.append([end_idx, start_idx, dataset])
            start_idx = end_idx

        self.length = end_idx

    def __getitem__(self, idx):
        for dataset_i in self.dataset_list:
            if dataset_i[0] > idx:
                dataset_i_idx = idx - dataset_i[1]
                return dataset_i[2][dataset_i_idx]

    def __len__(self):
        return self.length

    def get_dataset_list(self):
        return self.dataset_list
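MixDataset concatenates several sub-datasets and maps a global index back to the sub-dataset that owns it. A usage sketch, assuming illustrative configs: the LogoDataset paths come from the config above, while the ImageNetDataset entry and its paths are hypothetical; 'name' must match a class exported by ppcls.data.dataloader, and the remaining keys are forwarded to that class's constructor.

from ppcls.data.dataloader.mix_dataset import MixDataset

datasets_config = [
    {"name": "LogoDataset",
     "image_root": "dataset/LogoDet-3K-crop/train/",
     "cls_label_path": "dataset/LogoDet-3K-crop/LogoDet-3K+train.txt"},
    {"name": "ImageNetDataset",          # hypothetical second entry
     "image_root": "dataset/ILSVRC2012/",
     "cls_label_path": "dataset/ILSVRC2012/train_list.txt"},
]
mixed = MixDataset(datasets_config)

# len(mixed) is the sum of the sub-dataset lengths; an index below
# len(first) resolves into the first dataset, and anything above it is
# shifted into the second via the recorded [end_idx, start_idx] bounds.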
ppcls/data/dataloader/mix_sampler.py (new file, 0 → 100644)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division

from paddle.io import DistributedBatchSampler, Sampler

from ppcls.utils import logger
from ppcls.data.dataloader.mix_dataset import MixDataset
from ppcls.data import dataloader


class MixSampler(DistributedBatchSampler):
    def __init__(self, dataset, batch_size, sample_configs, iter_per_epoch):
        super().__init__(dataset, batch_size)
        assert isinstance(dataset,
                          MixDataset), "MixSampler only support MixDataset"
        self.sampler_list = []
        self.batch_size = batch_size
        self.start_list = []
        self.length = iter_per_epoch
        dataset_list = dataset.get_dataset_list()
        batch_size_left = self.batch_size
        self.iter_list = []
        for i, config_i in enumerate(sample_configs):
            self.start_list.append(dataset_list[i][1])
            sample_method = config_i.pop("name")
            ratio_i = config_i.pop("ratio")
            if i < len(sample_configs) - 1:
                batch_size_i = int(self.batch_size * ratio_i)
                batch_size_left -= batch_size_i
            else:
                batch_size_i = batch_size_left
            assert batch_size_i <= len(dataset_list[i][2])
            config_i["batch_size"] = batch_size_i
            if sample_method == "DistributedBatchSampler":
                sampler_i = DistributedBatchSampler(dataset_list[i][2],
                                                    **config_i)
            else:
                sampler_i = getattr(dataloader, sample_method)(
                    dataset_list[i][2], **config_i)
            self.sampler_list.append(sampler_i)
            self.iter_list.append(iter(sampler_i))
            self.length += len(dataset_list[i][2]) * ratio_i
        self.iter_counter = 0

    def __iter__(self):
        while self.iter_counter < self.length:
            batch = []
            for i, iter_i in enumerate(self.iter_list):
                batch_i = next(iter_i, None)
                if batch_i is None:
                    iter_i = iter(self.sampler_list[i])
                    self.iter_list[i] = iter_i
                    batch_i = next(iter_i, None)
                    assert batch_i is not None, "dataset {} return None".format(i)
                batch += [idx + self.start_list[i] for idx in batch_i]
            if len(batch) == self.batch_size:
                self.iter_counter += 1
                yield batch
            else:
                logger.info("Some dataset reaches end")
        self.iter_counter = 0

    def __len__(self):
        return self.length
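MixSampler draws one slice of each global batch from every sub-dataset of a MixDataset, sized by a per-dataset ratio, and offsets the yielded indices by each sub-dataset's start index. A configuration sketch; the ratio values, sub-sampler choices, and iter_per_epoch value here are illustrative, not taken from the commit.

from ppcls.data.dataloader.mix_sampler import MixSampler

# 'name' and 'ratio' are popped by MixSampler; the computed per-dataset
# batch_size plus the remaining keys go to the sub-sampler constructor.
sample_configs = [
    {"name": "PKSampler", "ratio": 0.5, "sample_per_id": 2},
    {"name": "DistributedBatchSampler", "ratio": 0.5, "shuffle": True},
]
sampler = MixSampler(
    dataset=mixed,              # the MixDataset from the sketch above
    batch_size=128,
    sample_configs=sample_configs,
    iter_per_epoch=1000)

# First slice: int(128 * 0.5) = 64 indices from sub-dataset 0 via PKSampler.
# The last slice always takes whatever batch size is left (here also 64),
# so rounding never loses examples; each index is then shifted by its
# sub-dataset's start offset so it addresses the concatenated MixDataset.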
ppcls/data/dataloader/pk_sampler.py (new file, 0 → 100644)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division

from collections import defaultdict

import numpy as np
import random

from paddle.io import DistributedBatchSampler

from ppcls.utils import logger


class PKSampler(DistributedBatchSampler):
    """
    First, randomly sample P identities.
    Then for each identity randomly sample K instances.
    Therefore the batch size is P*K, and the sampler is called PKSampler.

    Args:
        dataset (paddle.io.Dataset): list of (img_path, pid, cam_id).
        sample_per_id (int): number of instances per identity in a batch.
        batch_size (int): number of examples in a batch.
        shuffle (bool): whether to shuffle indices order before generating
            batch indices. Default False.
    """

    def __init__(self,
                 dataset,
                 batch_size,
                 sample_per_id,
                 shuffle=True,
                 drop_last=True,
                 sample_method="sample_avg_prob"):
        super().__init__(
            dataset, batch_size, shuffle=shuffle, drop_last=drop_last)
        assert batch_size % sample_per_id == 0, \
            "PKSampler configs error, Sample_per_id must be a divisor of batch_size."
        assert hasattr(self.dataset,
                       "labels"), "Dataset must have labels attribute."
        self.sample_per_label = sample_per_id
        self.label_dict = defaultdict(list)
        self.sample_method = sample_method
        for idx, label in enumerate(self.dataset.labels):
            self.label_dict[label].append(idx)
        self.label_list = list(self.label_dict)
        assert len(self.label_list) * self.sample_per_label > self.batch_size, \
            "batch size should be smaller than "
        if self.sample_method == "id_avg_prob":
            self.prob_list = np.array([1 / len(self.label_list)] *
                                      len(self.label_list))
        elif self.sample_method == "sample_avg_prob":
            counter = []
            for label_i in self.label_list:
                counter.append(len(self.label_dict[label_i]))
            self.prob_list = np.array(counter) / sum(counter)
        else:
            logger.error(
                "PKSampler only support id_avg_prob and sample_avg_prob sample method, "
                "but receive {}.".format(self.sample_method))
        if sum(np.abs(self.prob_list - 1) > 0.00000001):
            self.prob_list[-1] = 1 - sum(self.prob_list[:-1])
            if self.prob_list[-1] > 1 or self.prob_list[-1] < 0:
                logger.error("PKSampler prob list error")
            else:
                logger.info(
                    "PKSampler: sum of prob list not equal to 1, change the last prob"
                )

    def __iter__(self):
        label_per_batch = self.batch_size // self.sample_per_label
        if self.shuffle:
            np.random.RandomState(self.epoch).shuffle(self.label_list)
        for i in range(len(self)):
            batch_index = []
            batch_label_list = np.random.choice(
                self.label_list,
                size=label_per_batch,
                replace=False,
                p=self.prob_list)
            for label_i in batch_label_list:
                label_i_indexes = self.label_dict[label_i]
                if self.sample_per_label <= len(label_i_indexes):
                    batch_index.extend(
                        np.random.choice(
                            label_i_indexes,
                            size=self.sample_per_label,
                            replace=False))
                else:
                    batch_index.extend(
                        np.random.choice(
                            label_i_indexes,
                            size=self.sample_per_label,
                            replace=True))
            if not self.drop_last or len(batch_index) == self.batch_size:
                yield batch_index
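A toy run of the P*K logic, assuming a minimal stand-in dataset; ToyDataset here is illustrative and not part of the commit.

import paddle
from paddle.io import Dataset

from ppcls.data.dataloader.pk_sampler import PKSampler


class ToyDataset(Dataset):
    """Minimal dataset exposing the labels attribute PKSampler requires."""

    def __init__(self):
        self.labels = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3]  # 4 identities

    def __getitem__(self, idx):
        return paddle.zeros([3, 224, 224]), self.labels[idx]

    def __len__(self):
        return len(self.labels)


sampler = PKSampler(ToyDataset(), batch_size=4, sample_per_id=2)
for batch_index in sampler:
    # Each batch holds 4 // 2 = 2 identities with 2 instances apiece; an
    # identity with fewer samples than sample_per_id would be drawn with
    # replacement, per the else branch above.
    print(batch_index)
    break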
ppcls/engine/evaluation/classification.py

@@ -22,7 +22,7 @@ from ppcls.utils.misc import AverageMeter
 from ppcls.utils import logger


-def classification_eval(evaler, epoch_id=0):
+def classification_eval(engine, epoch_id=0):
     output_info = dict()
     time_info = {
         "batch_cost": AverageMeter(
@@ -30,21 +30,19 @@ def classification_eval(evaler, epoch_id=0):
         "reader_cost": AverageMeter(
             "reader_cost", ".5f", postfix=" s,"),
     }
-    print_batch_step = evaler.config["Global"]["print_batch_step"]
+    print_batch_step = engine.config["Global"]["print_batch_step"]

     metric_key = None
     tic = time.time()
-    eval_dataloader = evaler.eval_dataloader if evaler.use_dali else evaler.eval_dataloader(
-    )
-    max_iter = len(evaler.eval_dataloader) - 1 if platform.system(
-    ) == "Windows" else len(evaler.eval_dataloader)
-    for iter_id, batch in enumerate(eval_dataloader):
+    max_iter = len(engine.eval_dataloader) - 1 if platform.system(
+    ) == "Windows" else len(engine.eval_dataloader)
+    for iter_id, batch in enumerate(engine.eval_dataloader):
         if iter_id >= max_iter:
             break
         if iter_id == 5:
             for key in time_info:
                 time_info[key].reset()
-        if evaler.use_dali:
+        if engine.use_dali:
             batch = [
                 paddle.to_tensor(batch[0]['data']),
                 paddle.to_tensor(batch[0]['label'])
@@ -54,17 +52,17 @@ def classification_eval(evaler, epoch_id=0):
         batch[0] = paddle.to_tensor(batch[0]).astype("float32")
         batch[1] = batch[1].reshape([-1, 1]).astype("int64")
         # image input
-        out = evaler.model(batch[0])
+        out = engine.model(batch[0])
         # calc loss
-        if evaler.eval_loss_func is not None:
-            loss_dict = evaler.eval_loss_func(out, batch[1])
+        if engine.eval_loss_func is not None:
+            loss_dict = engine.eval_loss_func(out, batch[1])
             for key in loss_dict:
                 if key not in output_info:
                     output_info[key] = AverageMeter(key, '7.5f')
                 output_info[key].update(loss_dict[key].numpy()[0], batch_size)
         # calc metric
-        if evaler.eval_metric_func is not None:
-            metric_dict = evaler.eval_metric_func(out, batch[1])
+        if engine.eval_metric_func is not None:
+            metric_dict = engine.eval_metric_func(out, batch[1])
             if paddle.distributed.get_world_size() > 1:
                 for key in metric_dict:
                     paddle.distributed.all_reduce(
@@ -97,18 +95,18 @@ def classification_eval(evaler, epoch_id=0):
                 ])
             logger.info("[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}".format(
                 epoch_id, iter_id,
-                len(evaler.eval_dataloader), metric_msg, time_msg, ips_msg))
+                len(engine.eval_dataloader), metric_msg, time_msg, ips_msg))

         tic = time.time()
-    if evaler.use_dali:
-        evaler.eval_dataloader.reset()
+    if engine.use_dali:
+        engine.eval_dataloader.reset()
     metric_msg = ", ".join([
         "{}: {:.5f}".format(key, output_info[key].avg) for key in output_info
     ])
     logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))

     # do not try to save best eval.model
-    if evaler.eval_metric_func is None:
+    if engine.eval_metric_func is None:
         return -1
     # return 1st metric in the dict
     return output_info[metric_key].avg
ppcls/engine/evaluation/retrieval.py

@@ -20,21 +20,21 @@ import paddle
 from ppcls.utils import logger


-def retrieval_eval(evaler, epoch_id=0):
-    evaler.model.eval()
+def retrieval_eval(engine, epoch_id=0):
+    engine.model.eval()
     # step1. build gallery
-    if evaler.gallery_query_dataloader is not None:
+    if engine.gallery_query_dataloader is not None:
         gallery_feas, gallery_img_id, gallery_unique_id = cal_feature(
-            evaler, name='gallery_query')
+            engine, name='gallery_query')
         query_feas, query_img_id, query_query_id = gallery_feas, gallery_img_id, gallery_unique_id
     else:
         gallery_feas, gallery_img_id, gallery_unique_id = cal_feature(
-            evaler, name='gallery')
+            engine, name='gallery')
         query_feas, query_img_id, query_query_id = cal_feature(
-            evaler, name='query')
+            engine, name='query')

     # step2. do evaluation
-    sim_block_size = evaler.config["Global"].get("sim_block_size", 64)
+    sim_block_size = engine.config["Global"].get("sim_block_size", 64)
     sections = [sim_block_size] * (len(query_feas) // sim_block_size)
     if len(query_feas) % sim_block_size:
         sections.append(len(query_feas) % sim_block_size)
@@ -45,7 +45,7 @@ def retrieval_eval(evaler, epoch_id=0):
     image_id_blocks = paddle.split(query_img_id, num_or_sections=sections)

     metric_key = None
-    if evaler.eval_loss_func is None:
+    if engine.eval_loss_func is None:
         metric_dict = {metric_key: 0.}
     else:
         metric_dict = dict()
@@ -65,7 +65,7 @@ def retrieval_eval(evaler, epoch_id=0):
         else:
             keep_mask = None
-        metric_tmp = evaler.eval_metric_func(similarity_matrix,
+        metric_tmp = engine.eval_metric_func(similarity_matrix,
                                              image_id_blocks[block_idx],
                                              gallery_img_id, keep_mask)
@@ -88,32 +88,31 @@ def retrieval_eval(evaler, epoch_id=0):
     return metric_dict[metric_key]


-def cal_feature(evaler, name='gallery'):
+def cal_feature(engine, name='gallery'):
     all_feas = None
     all_image_id = None
     all_unique_id = None
     has_unique_id = False

     if name == 'gallery':
-        dataloader = evaler.gallery_dataloader
+        dataloader = engine.gallery_dataloader
     elif name == 'query':
-        dataloader = evaler.query_dataloader
+        dataloader = engine.query_dataloader
     elif name == 'gallery_query':
-        dataloader = evaler.gallery_query_dataloader
+        dataloader = engine.gallery_query_dataloader
     else:
         raise RuntimeError("Only support gallery or query dataset")

     max_iter = len(dataloader) - 1 if platform.system() == "Windows" else len(
         dataloader)
-    dataloader_tmp = dataloader if evaler.use_dali else dataloader()
-    for idx, batch in enumerate(dataloader_tmp):  # load is very time-consuming
+    for idx, batch in enumerate(dataloader):  # load is very time-consuming
         if idx >= max_iter:
             break
-        if idx % evaler.config["Global"]["print_batch_step"] == 0:
+        if idx % engine.config["Global"]["print_batch_step"] == 0:
             logger.info(
                 f"{name} feature calculation process: [{idx}/{len(dataloader)}]"
             )
-        if evaler.use_dali:
+        if engine.use_dali:
             batch = [
                 paddle.to_tensor(batch[0]['data']),
                 paddle.to_tensor(batch[0]['label'])
@@ -123,20 +122,20 @@ def cal_feature(evaler, name='gallery'):
         if len(batch) == 3:
             has_unique_id = True
             batch[2] = batch[2].reshape([-1, 1]).astype("int64")
-        out = evaler.model(batch[0], batch[1])
+        out = engine.model(batch[0], batch[1])
         batch_feas = out["features"]

         # do norm
-        if evaler.config["Global"].get("feature_normalize", True):
+        if engine.config["Global"].get("feature_normalize", True):
             feas_norm = paddle.sqrt(
                 paddle.sum(paddle.square(batch_feas), axis=1, keepdim=True))
             batch_feas = paddle.divide(batch_feas, feas_norm)

         # do binarize
-        if evaler.config["Global"].get("feature_binarize") == "round":
+        if engine.config["Global"].get("feature_binarize") == "round":
             batch_feas = paddle.round(batch_feas).astype("float32") * 2.0 - 1.0

-        if evaler.config["Global"].get("feature_binarize") == "sign":
+        if engine.config["Global"].get("feature_binarize") == "sign":
             batch_feas = paddle.sign(batch_feas).astype("float32")

         if all_feas is None:
@@ -150,8 +149,8 @@ def cal_feature(evaler, name='gallery'):
         if has_unique_id:
             all_unique_id = paddle.concat([all_unique_id, batch[2]])

-    if evaler.use_dali:
-        dataloader_tmp.reset()
+    if engine.use_dali:
+        dataloader.reset()

     if paddle.distributed.get_world_size() > 1:
         feat_list = []
ppcls/engine/train/train.py

@@ -18,19 +18,16 @@ import paddle
 from ppcls.engine.train.utils import update_loss, update_metric, log_info


-def train_epoch(trainer, epoch_id, print_batch_step):
+def train_epoch(engine, epoch_id, print_batch_step):
     tic = time.time()
-    train_dataloader = trainer.train_dataloader if trainer.use_dali else trainer.train_dataloader(
-    )
-    for iter_id, batch in enumerate(train_dataloader):
-        if iter_id >= trainer.max_iter:
+    for iter_id, batch in enumerate(engine.train_dataloader):
+        if iter_id >= engine.max_iter:
             break
         if iter_id == 5:
-            for key in trainer.time_info:
-                trainer.time_info[key].reset()
-        trainer.time_info["reader_cost"].update(time.time() - tic)
-        if trainer.use_dali:
+            for key in engine.time_info:
+                engine.time_info[key].reset()
+        engine.time_info["reader_cost"].update(time.time() - tic)
+        if engine.use_dali:
             batch = [
                 paddle.to_tensor(batch[0]['data']),
                 paddle.to_tensor(batch[0]['label'])
@@ -38,43 +35,43 @@ def train_epoch(trainer, epoch_id, print_batch_step):
         batch_size = batch[0].shape[0]
         batch[1] = batch[1].reshape([-1, 1]).astype("int64")

-        trainer.global_step += 1
+        engine.global_step += 1
         # image input
-        if trainer.amp:
+        if engine.amp:
             with paddle.amp.auto_cast(custom_black_list={
                     "flatten_contiguous_range", "greater_than"
             }):
-                out = forward(trainer, batch)
-                loss_dict = trainer.train_loss_func(out, batch[1])
+                out = forward(engine, batch)
+                loss_dict = engine.train_loss_func(out, batch[1])
         else:
-            out = forward(trainer, batch)
+            out = forward(engine, batch)

             # calc loss
-            if trainer.config["DataLoader"]["Train"]["dataset"].get(
+            if engine.config["DataLoader"]["Train"]["dataset"].get(
                     "batch_transform_ops", None):
-                loss_dict = trainer.train_loss_func(out, batch[1:])
+                loss_dict = engine.train_loss_func(out, batch[1:])
             else:
-                loss_dict = trainer.train_loss_func(out, batch[1])
+                loss_dict = engine.train_loss_func(out, batch[1])

         # step opt and lr
-        if trainer.amp:
-            scaled = trainer.scaler.scale(loss_dict["loss"])
+        if engine.amp:
+            scaled = engine.scaler.scale(loss_dict["loss"])
             scaled.backward()
-            trainer.scaler.minimize(trainer.optimizer, scaled)
+            engine.scaler.minimize(engine.optimizer, scaled)
         else:
             loss_dict["loss"].backward()
-            trainer.optimizer.step()
-        trainer.optimizer.clear_grad()
-        trainer.lr_sch.step()
+            engine.optimizer.step()
+        engine.optimizer.clear_grad()
+        engine.lr_sch.step()

         # below code just for logging
         # update metric_for_logger
-        update_metric(trainer, out, batch, batch_size)
+        update_metric(engine, out, batch, batch_size)
         # update_loss_for_logger
-        update_loss(trainer, loss_dict, batch_size)
-        trainer.time_info["batch_cost"].update(time.time() - tic)
+        update_loss(engine, loss_dict, batch_size)
+        engine.time_info["batch_cost"].update(time.time() - tic)
         if iter_id % print_batch_step == 0:
-            log_info(trainer, batch_size, epoch_id, iter_id)
+            log_info(engine, batch_size, epoch_id, iter_id)
         tic = time.time()