PaddlePaddle / PaddleClas
Commit 5d06a88a
Authored Mar 14, 2023 by Tingquan Gao

Revert "refactor: simplify engine"

This reverts commit 376d83d4.

Parent: 6aabb94d
Showing 6 changed files with 288 additions and 339 deletions (+288 −339).
ppcls/data/__init__.py       +1    −79
ppcls/engine/engine.py       +269  −192
ppcls/loss/__init__.py       +9    −23
ppcls/metric/__init__.py     +3    −35
ppcls/optimizer/__init__.py  +2    −5
ppcls/utils/logger.py        +4    −5
ppcls/data/__init__.py

@@ -15,8 +15,6 @@
 import inspect
 import copy
 import random
 import platform
 import paddle
 import numpy as np
-import paddle.distributed as dist
-

@@ -88,7 +86,7 @@ def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
     random.seed(worker_seed)


-def build(config, mode, device, use_dali=False, seed=None):
+def build_dataloader(config, mode, device, use_dali=False, seed=None):
     assert mode in [
         'Train', 'Eval', 'Test', 'Gallery', 'Query', 'UnLabelTrain'
     ], "Dataset mode should be Train, Eval, Test, Gallery, Query, UnLabelTrain"

@@ -189,79 +187,3 @@ def build(config, mode, device, use_dali=False, seed=None):
     logger.debug("build data_loader({}) success...".format(data_loader))
     return data_loader
-
-
-def build_dataloader(engine):
-    if "class_num" in engine.config["Global"]:
-        global_class_num = engine.config["Global"]["class_num"]
-        if "class_num" not in config["Arch"]:
-            engine.config["Arch"]["class_num"] = global_class_num
-            msg = f"The Global.class_num will be deprecated. Please use Arch.class_num instead. Arch.class_num has been set to {global_class_num}."
-        else:
-            msg = "The Global.class_num will be deprecated. Please use Arch.class_num instead. The Global.class_num has been ignored."
-        logger.warning(msg)
-
-    class_num = engine.config["Arch"].get("class_num", None)
-    engine.config["DataLoader"].update({"class_num": class_num})
-    engine.config["DataLoader"].update({
-        "epochs": engine.config["Global"]["epochs"]
-    })
-
-    use_dali = engine.config['Global'].get("use_dali", False)
-    dataloader_dict = {
-        "Train": None,
-        "UnLabelTrain": None,
-        "Eval": None,
-        "Query": None,
-        "Gallery": None,
-        "GalleryQuery": None
-    }
-    if engine.mode == 'train':
-        train_dataloader = build(
-            engine.config["DataLoader"], "Train", engine.device, use_dali,
-            seed=None)
-        iter_per_epoch = len(train_dataloader) - 1 if platform.system(
-        ) == "Windows" else len(train_dataloader)
-        if engine.config["Global"].get("iter_per_epoch", None):
-            # set max iteration per epoch mannualy, when training by iteration(s), such as XBM, FixMatch.
-            iter_per_epoch = engine.config["Global"].get("iter_per_epoch")
-        iter_per_epoch = iter_per_epoch // engine.update_freq * engine.update_freq
-        engine.iter_per_epoch = iter_per_epoch
-        train_dataloader.iter_per_epoch = iter_per_epoch
-        dataloader_dict["Train"] = train_dataloader
-
-    if engine.config["DataLoader"].get('UnLabelTrain', None) is not None:
-        dataloader_dict["UnLabelTrain"] = build(
-            engine.config["DataLoader"], "UnLabelTrain", engine.device,
-            use_dali, seed=None)
-
-    if engine.mode == "eval" or (engine.mode == "train" and
-                                 engine.config["Global"]["eval_during_train"]):
-        if engine.eval_mode in ["classification", "adaface"]:
-            dataloader_dict["Eval"] = build(
-                engine.config["DataLoader"], "Eval", engine.device, use_dali,
-                seed=None)
-        elif engine.eval_mode == "retrieval":
-            if len(engine.config["DataLoader"]["Eval"].keys()) == 1:
-                key = list(engine.config["DataLoader"]["Eval"].keys())[0]
-                dataloader_dict["GalleryQuery"] = build_dataloader(
-                    engine.config["DataLoader"]["Eval"], key, engine.device,
-                    use_dali)
-            else:
-                dataloader_dict["Gallery"] = build_dataloader(
-                    engine.config["DataLoader"]["Eval"], "Gallery",
-                    engine.device, use_dali)
-                dataloader_dict["Query"] = build_dataloader(
-                    engine.config["DataLoader"]["Eval"], "Query",
-                    engine.device, use_dali)
-
-    return dataloader_dict
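For orientation, a minimal sketch of how the restored builder is called after this revert. The config fragment, dataset paths, and device setup below are illustrative placeholders, not taken from this commit:

import paddle
from ppcls.data import build_dataloader

# Hypothetical config mirroring the shape of a PaddleClas YAML "DataLoader"
# section; paths, batch size, and worker count are placeholders.
dataloader_config = {
    "Train": {
        "dataset": {
            "name": "ImageNetDataset",
            "image_root": "./dataset/ILSVRC2012/",
            "cls_label_path": "./dataset/ILSVRC2012/train_list.txt",
        },
        "sampler": {
            "name": "DistributedBatchSampler",
            "batch_size": 64,
            "shuffle": True,
            "drop_last": False,
        },
        "loader": {"num_workers": 4, "use_shared_memory": True},
    },
}

device = paddle.set_device("gpu")
# After the revert, callers pass the DataLoader config and a mode string
# directly, rather than handing the whole engine object to build_dataloader.
train_loader = build_dataloader(
    dataloader_config, "Train", device, use_dali=False, seed=None)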
ppcls/engine/engine.py

This diff is collapsed in the captured page (+269 −192); its contents are not shown.
ppcls/loss/__init__.py

@@ -51,7 +51,7 @@ from .metabinloss import IntraDomainScatterLoss
 class CombinedLoss(nn.Layer):
     def __init__(self, config_list):
         super().__init__()
-        loss_func = []
+        self.loss_func = []
         self.loss_weight = []
         assert isinstance(config_list, list), (
             'operator config should be a list')

@@ -63,9 +63,8 @@ class CombinedLoss(nn.Layer):
             assert "weight" in param, "weight must be in param, but param just contains {}".format(
                 param.keys())
             self.loss_weight.append(param.pop("weight"))
-            loss_func.append(eval(name)(**param))
-        self.loss_func = nn.LayerList(loss_func)
-        logger.debug("build loss {} success.".format(loss_func))
+            self.loss_func.append(eval(name)(**param))
+        self.loss_func = nn.LayerList(self.loss_func)

     def __call__(self, input, batch):
         loss_dict = {}

@@ -84,22 +83,9 @@ class CombinedLoss(nn.Layer):
         return loss_dict


-def build_loss(config, mode="train"):
-    train_loss_func, unlabel_train_loss_func, eval_loss_func = None, None, None
-    if mode == "train":
-        label_loss_info = config["Loss"]["Train"]
-        if label_loss_info:
-            train_loss_func = CombinedLoss(copy.deepcopy(label_loss_info))
-        unlabel_loss_info = config.get("UnLabelLoss", {}).get("Train", None)
-        if unlabel_loss_info:
-            unlabel_train_loss_func = CombinedLoss(
-                copy.deepcopy(unlabel_loss_info))
-    if mode == "eval" or (mode == "train" and
-                          config["Global"]["eval_during_train"]):
-        loss_config = config.get("Loss", None)
-        if loss_config is not None:
-            loss_config = loss_config.get("Eval")
-            if loss_config is not None:
-                eval_loss_func = CombinedLoss(copy.deepcopy(loss_config))
-    return train_loss_func, unlabel_train_loss_func, eval_loss_func
+def build_loss(config):
+    if config is None:
+        return None
+    module_class = CombinedLoss(copy.deepcopy(config))
+    logger.debug("build loss {} success.".format(module_class))
+    return module_class
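A short illustrative call of the restored build_loss. The loss list below follows the single-key-dict-with-weight shape that CombinedLoss asserts on above; CELoss is one of the losses this module imports:

from ppcls.loss import build_loss

# One loss entry per dict: {name: {params, incl. the mandatory "weight"}},
# as required by the asserts in CombinedLoss.__init__.
loss_config = [{"CELoss": {"weight": 1.0}}]

# Returns a CombinedLoss instance, or None when no loss is configured.
train_loss_func = build_loss(loss_config)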
ppcls/metric/__init__.py

@@ -65,38 +65,6 @@ class CombinedMetrics(AvgMetrics):
             metric.reset()


-def build_metrics(engine):
-    config, mode = engine.config, engine.mode
-    if mode == 'train' and "Metric" in config and "Train" in config[
-            "Metric"] and config["Metric"]["Train"]:
-        metric_config = config["Metric"]["Train"]
-        if hasattr(engine.train_dataloader, "collate_fn"
-                   ) and engine.train_dataloader.collate_fn is not None:
-            for m_idx, m in enumerate(metric_config):
-                if "TopkAcc" in m:
-                    msg = f"Unable to calculate accuracy when using \"batch_transform_ops\". The metric \"{m}\" has been removed."
-                    logger.warning(msg)
-                    metric_config.pop(m_idx)
-        train_metric_func = CombinedMetrics(copy.deepcopy(metric_config))
-    else:
-        train_metric_func = None
-
-    if mode == "eval" or (mode == "train" and
-                          config["Global"]["eval_during_train"]):
-        eval_mode = config["Global"].get("eval_mode", "classification")
-        if eval_mode == "classification":
-            if "Metric" in config and "Eval" in config["Metric"]:
-                eval_metric_func = CombinedMetrics(
-                    copy.deepcopy(config["Metric"]["Eval"]))
-            else:
-                eval_metric_func = None
-        elif eval_mode == "retrieval":
-            if "Metric" in config and "Eval" in config["Metric"]:
-                metric_config = config["Metric"]["Eval"]
-            else:
-                metric_config = [{"name": "Recallk", "topk": (1, 5)}]
-            eval_metric_func = CombinedMetrics(copy.deepcopy(metric_config))
-    else:
-        eval_metric_func = None
-
-    return train_metric_func, eval_metric_func
+def build_metrics(config):
+    metrics_list = CombinedMetrics(copy.deepcopy(config))
+    return metrics_list
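Similarly, a minimal illustrative call of the restored build_metrics. TopkAcc is one of the metrics this module imports, and the list shape matches what a Metric.Eval section of a PaddleClas config would hold:

from ppcls.metric import build_metrics

# Illustrative metric list, e.g. the value of config["Metric"]["Eval"].
metric_config = [{"TopkAcc": {"topk": [1, 5]}}]

# After the revert, callers pass the metric list directly instead of the
# engine object.
eval_metric_func = build_metrics(metric_config)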
ppcls/optimizer/__init__.py

@@ -45,11 +45,8 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
 # model_list is None in static graph
-def build_optimizer(config, dataloader, model_list=None):
-    optim_config = copy.deepcopy(config["Optimizer"])
-    epochs = config["Global"]["epochs"]
-    update_freq = config["Global"].get("update_freq", 1)
-    step_each_epoch = dataloader.iter_per_epoch // update_freq
+def build_optimizer(config, epochs, step_each_epoch, model_list=None):
+    optim_config = copy.deepcopy(config)
     if isinstance(optim_config, dict):
         # convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}]
         optim_name = optim_config.pop("name")
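A sketch of the restored calling convention, assuming (as in the engine code this commit restores) that build_optimizer returns an optimizer and learning-rate-scheduler pair; config, model, and train_loader stand in for objects built elsewhere:

# config, model, and train_loader are assumed to exist already; the point is
# the restored signature: the caller extracts the Optimizer sub-config and
# computes epochs/step_each_epoch itself, instead of passing a dataloader.
optimizer, lr_sch = build_optimizer(
    config["Optimizer"],
    epochs=config["Global"]["epochs"],
    step_each_epoch=len(train_loader),
    model_list=[model])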
ppcls/utils/logger.py

@@ -22,15 +22,16 @@ import paddle.distributed as dist
 _logger = None


-def init_logger(config, mode="train", name='ppcls', log_level=logging.INFO):
+def init_logger(name='ppcls', log_file=None, log_level=logging.INFO):
     """Initialize and get a logger by name.
     If the logger has not been initialized, this method will initialize the
     logger by adding one or two handlers, otherwise the initialized logger will
     be directly returned. During initialization, a StreamHandler will always be
-    added.
+    added. If `log_file` is specified a FileHandler will also be added.
     Args:
-        config(dict): Training config.
         name (str): Logger name.
+        log_file (str | None): The log filename. If specified, a FileHandler
+            will be added to the logger.
         log_level (int): The logger level. Note that only the process of
             rank 0 is affected, and other processes will set the level to
             "Error" thus be silent most of the time.

@@ -62,8 +63,6 @@ def init_logger(config, mode="train", name='ppcls', log_level=logging.INFO):
     if init_flag:
         _logger.addHandler(stream_handler)

-    log_file = os.path.join(config['Global']['output_dir'],
-                            config["Arch"]["name"], f"{mode}.log")
     if log_file is not None and dist.get_rank() == 0:
         log_file_folder = os.path.split(log_file)[0]
         os.makedirs(log_file_folder, exist_ok=True)
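Finally, an illustrative call of the restored init_logger API: the caller now assembles the log path itself, instead of init_logger deriving it from config['Global']['output_dir'], the architecture name, and the mode. The output path below is a placeholder:

import os
from ppcls.utils import logger

# Hypothetical output location; any writable path works. Rank-0 gating and
# directory creation happen inside init_logger, per the hunk above.
log_file = os.path.join("./output", "ResNet50", "train.log")
logger.init_logger(name='ppcls', log_file=log_file)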