Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
284e2a67
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
1 年多 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
284e2a67
编写于
2月 22, 2023
作者:
G
gaotingquan
提交者:
Wei Shengyu
3月 10, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor: mv all dataloaders to engine.dataloader_dict
上级
efe0d45c
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
60 addition
and
47 deletion
+60
-47
ppcls/data/__init__.py
ppcls/data/__init__.py
+31
-2
ppcls/engine/engine.py
ppcls/engine/engine.py
+2
-10
ppcls/engine/evaluation/classification.py
ppcls/engine/evaluation/classification.py
+10
-13
ppcls/engine/train/regular_train_epoch.py
ppcls/engine/train/regular_train_epoch.py
+2
-13
ppcls/engine/train/utils.py
ppcls/engine/train/utils.py
+6
-5
ppcls/metric/__init__.py
ppcls/metric/__init__.py
+3
-2
ppcls/optimizer/__init__.py
ppcls/optimizer/__init__.py
+6
-2
未找到文件。
ppcls/data/__init__.py
浏览文件 @
284e2a67
...
@@ -187,10 +187,37 @@ def build(config, mode, use_dali=False, seed=None):
...
@@ -187,10 +187,37 @@ def build(config, mode, use_dali=False, seed=None):
collate_fn
=
batch_collate_fn
,
collate_fn
=
batch_collate_fn
,
worker_init_fn
=
init_fn
)
worker_init_fn
=
init_fn
)
total_samples
=
len
(
data_loader
.
dataset
)
if
not
use_dali
else
data_loader
.
size
max_iter
=
len
(
data_loader
)
-
1
if
platform
.
system
()
==
"Windows"
else
len
(
data_loader
)
data_loader
.
max_iter
=
max_iter
data_loader
.
total_samples
=
total_samples
logger
.
debug
(
"build data_loader({}) success..."
.
format
(
data_loader
))
logger
.
debug
(
"build data_loader({}) success..."
.
format
(
data_loader
))
return
data_loader
return
data_loader
# TODO(gaotingquan): perf
class
DataIterator
(
object
):
def
__init__
(
self
,
dataloader
,
use_dali
=
False
):
self
.
dataloader
=
dataloader
self
.
use_dali
=
use_dali
self
.
iterator
=
iter
(
dataloader
)
def
get_batch
(
self
):
# fetch data batch from dataloader
try
:
batch
=
next
(
self
.
iterator
)
except
Exception
:
# NOTE: reset DALI dataloader manually
if
self
.
use_dali
:
self
.
dataloader
.
reset
()
self
.
iterator
=
iter
(
self
.
dataloader
)
batch
=
next
(
self
.
iterator
)
return
batch
def
build_dataloader
(
engine
):
def
build_dataloader
(
engine
):
if
"class_num"
in
engine
.
config
[
"Global"
]:
if
"class_num"
in
engine
.
config
[
"Global"
]:
global_class_num
=
engine
.
config
[
"Global"
][
"class_num"
]
global_class_num
=
engine
.
config
[
"Global"
][
"class_num"
]
...
@@ -222,12 +249,15 @@ def build_dataloader(engine):
...
@@ -222,12 +249,15 @@ def build_dataloader(engine):
iter_per_epoch
=
len
(
train_dataloader
)
-
1
if
platform
.
system
(
iter_per_epoch
=
len
(
train_dataloader
)
-
1
if
platform
.
system
(
)
==
"Windows"
else
len
(
train_dataloader
)
)
==
"Windows"
else
len
(
train_dataloader
)
if
engine
.
config
[
"Global"
].
get
(
"iter_per_epoch"
,
None
):
if
engine
.
config
[
"Global"
].
get
(
"iter_per_epoch"
,
None
):
# TODO(gaotingquan): iter_per_epoch should be set in Dataloader.Train, not Global
# set max iteration per epoch mannualy, when training by iteration(s), such as XBM, FixMatch.
# set max iteration per epoch mannualy, when training by iteration(s), such as XBM, FixMatch.
iter_per_epoch
=
engine
.
config
[
"Global"
].
get
(
"iter_per_epoch"
)
iter_per_epoch
=
engine
.
config
[
"Global"
].
get
(
"iter_per_epoch"
)
iter_per_epoch
=
iter_per_epoch
//
engine
.
update_freq
*
engine
.
update_freq
iter_per_epoch
=
iter_per_epoch
//
engine
.
update_freq
*
engine
.
update_freq
engine
.
iter_per_epoch
=
iter_per_epoch
#
engine.iter_per_epoch = iter_per_epoch
train_dataloader
.
iter_per_epoch
=
iter_per_epoch
train_dataloader
.
iter_per_epoch
=
iter_per_epoch
dataloader_dict
[
"Train"
]
=
train_dataloader
dataloader_dict
[
"Train"
]
=
train_dataloader
# TODO(gaotingquan): set the iterator field in config, such as Dataloader.Train.convert_iterator=True
dataloader_dict
[
"TrainIter"
]
=
DataIterator
(
train_dataloader
,
use_dali
)
if
engine
.
config
[
"DataLoader"
].
get
(
'UnLabelTrain'
,
None
)
is
not
None
:
if
engine
.
config
[
"DataLoader"
].
get
(
'UnLabelTrain'
,
None
)
is
not
None
:
dataloader_dict
[
"UnLabelTrain"
]
=
build
(
dataloader_dict
[
"UnLabelTrain"
]
=
build
(
...
@@ -249,5 +279,4 @@ def build_dataloader(engine):
...
@@ -249,5 +279,4 @@ def build_dataloader(engine):
engine
.
config
[
"DataLoader"
][
"Eval"
],
"Gallery"
,
use_dali
)
engine
.
config
[
"DataLoader"
][
"Eval"
],
"Gallery"
,
use_dali
)
dataloader_dict
[
"Query"
]
=
build
(
dataloader_dict
[
"Query"
]
=
build
(
engine
.
config
[
"DataLoader"
][
"Eval"
],
"Query"
,
use_dali
)
engine
.
config
[
"DataLoader"
][
"Eval"
],
"Query"
,
use_dali
)
return
dataloader_dict
return
dataloader_dict
ppcls/engine/engine.py
浏览文件 @
284e2a67
...
@@ -63,7 +63,7 @@ class Engine(object):
...
@@ -63,7 +63,7 @@ class Engine(object):
# init train_func and eval_func
# init train_func and eval_func
self
.
train_epoch_func
=
build_train_epoch_func
(
self
.
config
)
self
.
train_epoch_func
=
build_train_epoch_func
(
self
.
config
)
self
.
eval_
epoch_
func
=
build_eval_func
(
self
.
config
)
self
.
eval_func
=
build_eval_func
(
self
.
config
)
# set device
# set device
self
.
_init_device
()
self
.
_init_device
()
...
@@ -73,12 +73,6 @@ class Engine(object):
...
@@ -73,12 +73,6 @@ class Engine(object):
# build dataloader
# build dataloader
self
.
dataloader_dict
=
build_dataloader
(
self
)
self
.
dataloader_dict
=
build_dataloader
(
self
)
self
.
train_dataloader
,
self
.
unlabel_train_dataloader
,
self
.
eval_dataloader
=
self
.
dataloader_dict
[
"Train"
],
self
.
dataloader_dict
[
"UnLabelTrain"
],
self
.
dataloader_dict
[
"Eval"
]
self
.
gallery_query_dataloader
,
self
.
gallery_dataloader
,
self
.
query_dataloader
=
self
.
dataloader_dict
[
"GalleryQuery"
],
self
.
dataloader_dict
[
"Gallery"
],
self
.
dataloader_dict
[
"Query"
]
# build loss
# build loss
self
.
train_loss_func
,
self
.
unlabel_train_loss_func
,
self
.
eval_loss_func
=
build_loss
(
self
.
train_loss_func
,
self
.
unlabel_train_loss_func
,
self
.
eval_loss_func
=
build_loss
(
...
@@ -94,9 +88,7 @@ class Engine(object):
...
@@ -94,9 +88,7 @@ class Engine(object):
self
.
_init_pretrained
()
self
.
_init_pretrained
()
# build optimizer
# build optimizer
self
.
optimizer
,
self
.
lr_sch
=
build_optimizer
(
self
.
optimizer
,
self
.
lr_sch
=
build_optimizer
(
self
)
self
.
config
,
self
.
train_dataloader
,
[
self
.
model
,
self
.
train_loss_func
])
# AMP training and evaluating
# AMP training and evaluating
self
.
_init_amp
()
self
.
_init_amp
()
...
...
ppcls/engine/evaluation/classification.py
浏览文件 @
284e2a67
...
@@ -35,13 +35,10 @@ def classification_eval(engine, epoch_id=0):
...
@@ -35,13 +35,10 @@ def classification_eval(engine, epoch_id=0):
print_batch_step
=
engine
.
config
[
"Global"
][
"print_batch_step"
]
print_batch_step
=
engine
.
config
[
"Global"
][
"print_batch_step"
]
tic
=
time
.
time
()
tic
=
time
.
time
()
total_samples
=
engine
.
dataloader_dict
[
"Eval"
].
total_samples
accum_samples
=
0
accum_samples
=
0
total_samples
=
len
(
max_iter
=
engine
.
dataloader_dict
[
"Eval"
].
max_iter
engine
.
eval_dataloader
.
for
iter_id
,
batch
in
enumerate
(
engine
.
dataloader_dict
[
"Eval"
]):
dataset
)
if
not
engine
.
use_dali
else
engine
.
eval_dataloader
.
size
max_iter
=
len
(
engine
.
eval_dataloader
)
-
1
if
platform
.
system
(
)
==
"Windows"
else
len
(
engine
.
eval_dataloader
)
for
iter_id
,
batch
in
enumerate
(
engine
.
eval_dataloader
):
if
iter_id
>=
max_iter
:
if
iter_id
>=
max_iter
:
break
break
if
iter_id
==
5
:
if
iter_id
==
5
:
...
@@ -61,9 +58,9 @@ def classification_eval(engine, epoch_id=0):
...
@@ -61,9 +58,9 @@ def classification_eval(engine, epoch_id=0):
"flatten_contiguous_range"
,
"greater_than"
"flatten_contiguous_range"
,
"greater_than"
},
},
level
=
engine
.
amp_level
):
level
=
engine
.
amp_level
):
out
=
engine
.
model
(
batch
[
0
]
)
out
=
engine
.
model
(
batch
)
else
:
else
:
out
=
engine
.
model
(
batch
[
0
]
)
out
=
engine
.
model
(
batch
)
# just for DistributedBatchSampler issue: repeat sampling
# just for DistributedBatchSampler issue: repeat sampling
current_samples
=
batch_size
*
paddle
.
distributed
.
get_world_size
()
current_samples
=
batch_size
*
paddle
.
distributed
.
get_world_size
()
...
@@ -95,7 +92,8 @@ def classification_eval(engine, epoch_id=0):
...
@@ -95,7 +92,8 @@ def classification_eval(engine, epoch_id=0):
paddle
.
distributed
.
all_gather
(
pred_list
,
out
)
paddle
.
distributed
.
all_gather
(
pred_list
,
out
)
preds
=
paddle
.
concat
(
pred_list
,
0
)
preds
=
paddle
.
concat
(
pred_list
,
0
)
if
accum_samples
>
total_samples
and
not
engine
.
use_dali
:
if
accum_samples
>
total_samples
and
not
engine
.
config
[
"Global"
].
get
(
"use_dali"
,
False
):
if
isinstance
(
preds
,
list
):
if
isinstance
(
preds
,
list
):
preds
=
[
preds
=
[
pred
[:
total_samples
+
current_samples
-
accum_samples
]
pred
[:
total_samples
+
current_samples
-
accum_samples
]
...
@@ -151,12 +149,11 @@ def classification_eval(engine, epoch_id=0):
...
@@ -151,12 +149,11 @@ def classification_eval(engine, epoch_id=0):
])
])
metric_msg
+=
", {}"
.
format
(
engine
.
eval_metric_func
.
avg_info
)
metric_msg
+=
", {}"
.
format
(
engine
.
eval_metric_func
.
avg_info
)
logger
.
info
(
"[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}"
.
format
(
logger
.
info
(
"[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}"
.
format
(
epoch_id
,
iter_id
,
epoch_id
,
iter_id
,
max_iter
,
metric_msg
,
time_msg
,
ips_msg
))
len
(
engine
.
eval_dataloader
),
metric_msg
,
time_msg
,
ips_msg
))
tic
=
time
.
time
()
tic
=
time
.
time
()
if
engine
.
use_dali
:
if
engine
.
config
[
"Global"
].
get
(
"use_dali"
,
False
)
:
engine
.
eval_dataloader
.
reset
()
engine
.
dataloader_dict
[
"Eval"
]
.
reset
()
if
"ATTRMetric"
in
engine
.
config
[
"Metric"
][
"Eval"
][
0
]:
if
"ATTRMetric"
in
engine
.
config
[
"Metric"
][
"Eval"
][
0
]:
metric_msg
=
", "
.
join
([
metric_msg
=
", "
.
join
([
...
...
ppcls/engine/train/regular_train_epoch.py
浏览文件 @
284e2a67
...
@@ -22,19 +22,8 @@ from ppcls.utils import profiler
...
@@ -22,19 +22,8 @@ from ppcls.utils import profiler
def
regular_train_epoch
(
engine
,
epoch_id
,
print_batch_step
):
def
regular_train_epoch
(
engine
,
epoch_id
,
print_batch_step
):
tic
=
time
.
time
()
tic
=
time
.
time
()
if
not
hasattr
(
engine
,
"train_dataloader_iter"
):
for
iter_id
in
range
(
engine
.
dataloader_dict
[
"Train"
].
iter_per_epoch
):
engine
.
train_dataloader_iter
=
iter
(
engine
.
train_dataloader
)
batch
=
engine
.
dataloader_dict
[
"TrainIter"
].
get_batch
()
for
iter_id
in
range
(
engine
.
iter_per_epoch
):
# fetch data batch from dataloader
try
:
batch
=
next
(
engine
.
train_dataloader_iter
)
except
Exception
:
# NOTE: reset DALI dataloader manually
if
engine
.
use_dali
:
engine
.
train_dataloader
.
reset
()
engine
.
train_dataloader_iter
=
iter
(
engine
.
train_dataloader
)
batch
=
next
(
engine
.
train_dataloader_iter
)
profiler
.
add_profiler_step
(
engine
.
config
[
"profiler_options"
])
profiler
.
add_profiler_step
(
engine
.
config
[
"profiler_options"
])
if
iter_id
==
5
:
if
iter_id
==
5
:
...
...
ppcls/engine/train/utils.py
浏览文件 @
284e2a67
...
@@ -54,13 +54,14 @@ def log_info(trainer, batch_size, epoch_id, iter_id):
...
@@ -54,13 +54,14 @@ def log_info(trainer, batch_size, epoch_id, iter_id):
ips_msg
=
"ips: {:.5f} samples/s"
.
format
(
ips_msg
=
"ips: {:.5f} samples/s"
.
format
(
batch_size
/
trainer
.
time_info
[
"batch_cost"
].
avg
)
batch_size
/
trainer
.
time_info
[
"batch_cost"
].
avg
)
eta_sec
=
(
eta_sec
=
(
(
trainer
.
config
[
"Global"
][
"epochs"
]
-
epoch_id
+
1
(
trainer
.
config
[
"Global"
][
"epochs"
]
-
epoch_id
+
1
)
*
)
*
trainer
.
dataloader_dict
[
"Train"
].
iter_per_epoch
-
iter_id
trainer
.
iter_per_epoch
-
iter_id
)
*
trainer
.
time_info
[
"batch_cost"
].
avg
)
*
trainer
.
time_info
[
"batch_cost"
].
avg
eta_msg
=
"eta: {:s}"
.
format
(
str
(
datetime
.
timedelta
(
seconds
=
int
(
eta_sec
))))
eta_msg
=
"eta: {:s}"
.
format
(
str
(
datetime
.
timedelta
(
seconds
=
int
(
eta_sec
))))
logger
.
info
(
"[Train][Epoch {}/{}][Iter: {}/{}]{}, {}, {}, {}, {}"
.
format
(
logger
.
info
(
"[Train][Epoch {}/{}][Iter: {}/{}]{}, {}, {}, {}, {}"
.
format
(
epoch_id
,
trainer
.
config
[
"Global"
][
"epochs"
],
iter_id
,
trainer
.
epoch_id
,
trainer
.
config
[
"Global"
][
iter_per_epoch
,
lr_msg
,
metric_msg
,
time_msg
,
ips_msg
,
eta_msg
))
"epochs"
],
iter_id
,
trainer
.
dataloader_dict
[
"Train"
]
.
iter_per_epoch
,
lr_msg
,
metric_msg
,
time_msg
,
ips_msg
,
eta_msg
))
for
i
,
lr
in
enumerate
(
trainer
.
lr_sch
):
for
i
,
lr
in
enumerate
(
trainer
.
lr_sch
):
logger
.
scaler
(
logger
.
scaler
(
...
...
ppcls/metric/__init__.py
浏览文件 @
284e2a67
...
@@ -70,8 +70,9 @@ def build_metrics(engine):
...
@@ -70,8 +70,9 @@ def build_metrics(engine):
if
mode
==
'train'
and
"Metric"
in
config
and
"Train"
in
config
[
if
mode
==
'train'
and
"Metric"
in
config
and
"Train"
in
config
[
"Metric"
]
and
config
[
"Metric"
][
"Train"
]:
"Metric"
]
and
config
[
"Metric"
][
"Train"
]:
metric_config
=
config
[
"Metric"
][
"Train"
]
metric_config
=
config
[
"Metric"
][
"Train"
]
if
hasattr
(
engine
.
train_dataloader
,
"collate_fn"
if
hasattr
(
engine
.
dataloader_dict
[
"Train"
],
)
and
engine
.
train_dataloader
.
collate_fn
is
not
None
:
"collate_fn"
)
and
engine
.
dataloader_dict
[
"Train"
].
collate_fn
is
not
None
:
for
m_idx
,
m
in
enumerate
(
metric_config
):
for
m_idx
,
m
in
enumerate
(
metric_config
):
if
"TopkAcc"
in
m
:
if
"TopkAcc"
in
m
:
msg
=
f
"Unable to calculate accuracy when using
\"
batch_transform_ops
\"
. The metric
\"
{
m
}
\"
has been removed."
msg
=
f
"Unable to calculate accuracy when using
\"
batch_transform_ops
\"
. The metric
\"
{
m
}
\"
has been removed."
...
...
ppcls/optimizer/__init__.py
浏览文件 @
284e2a67
...
@@ -45,11 +45,15 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
...
@@ -45,11 +45,15 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
# model_list is None in static graph
# model_list is None in static graph
def
build_optimizer
(
config
,
dataloader
,
model_list
=
None
):
def
build_optimizer
(
engine
):
if
engine
.
mode
!=
"train"
:
return
None
,
None
config
,
iter_per_epoch
,
model_list
=
engine
.
config
,
engine
.
dataloader_dict
[
"Train"
].
iter_per_epoch
,
[
engine
.
mode
,
engine
.
train_loss_func
]
optim_config
=
copy
.
deepcopy
(
config
[
"Optimizer"
])
optim_config
=
copy
.
deepcopy
(
config
[
"Optimizer"
])
epochs
=
config
[
"Global"
][
"epochs"
]
epochs
=
config
[
"Global"
][
"epochs"
]
update_freq
=
config
[
"Global"
].
get
(
"update_freq"
,
1
)
update_freq
=
config
[
"Global"
].
get
(
"update_freq"
,
1
)
step_each_epoch
=
dataloader
.
iter_per_epoch
//
update_freq
step_each_epoch
=
iter_per_epoch
//
update_freq
if
isinstance
(
optim_config
,
dict
):
if
isinstance
(
optim_config
,
dict
):
# convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}]
# convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}]
optim_name
=
optim_config
.
pop
(
"name"
)
optim_name
=
optim_config
.
pop
(
"name"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录