Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
e7e4f68b
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 1 年 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e7e4f68b
编写于
3月 14, 2023
作者:
T
Tingquan Gao
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Revert "refactor: build_train_func & build_eval_func"
This reverts commit
6bed0f57
.
上级
6245b64c
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
57 addition
and
51 deletion
+57
-51
ppcls/data/__init__.py
ppcls/data/__init__.py
+29
-15
ppcls/engine/engine.py
ppcls/engine/engine.py
+16
-6
ppcls/engine/evaluation/__init__.py
ppcls/engine/evaluation/__init__.py
+3
-12
ppcls/engine/train/__init__.py
ppcls/engine/train/__init__.py
+5
-15
ppcls/engine/train/train.py
ppcls/engine/train/train.py
+1
-1
ppcls/engine/train/train_progressive.py
ppcls/engine/train/train_progressive.py
+3
-2
未找到文件。
ppcls/data/__init__.py
浏览文件 @
e7e4f68b
...
...
@@ -88,7 +88,7 @@ def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
random
.
seed
(
worker_seed
)
def
build
(
config
,
mode
,
use_dali
=
False
,
seed
=
None
):
def
build
(
config
,
mode
,
device
,
use_dali
=
False
,
seed
=
None
):
assert
mode
in
[
'Train'
,
'Eval'
,
'Test'
,
'Gallery'
,
'Query'
,
'UnLabelTrain'
],
"Dataset mode should be Train, Eval, Test, Gallery, Query, UnLabelTrain"
...
...
@@ -167,7 +167,7 @@ def build(config, mode, use_dali=False, seed=None):
if
batch_sampler
is
None
:
data_loader
=
DataLoader
(
dataset
=
dataset
,
places
=
paddle
.
device
.
get_device
()
,
places
=
device
,
num_workers
=
num_workers
,
return_list
=
True
,
use_shared_memory
=
use_shared_memory
,
...
...
@@ -179,7 +179,7 @@ def build(config, mode, use_dali=False, seed=None):
else
:
data_loader
=
DataLoader
(
dataset
=
dataset
,
places
=
paddle
.
device
.
get_device
()
,
places
=
device
,
num_workers
=
num_workers
,
return_list
=
True
,
use_shared_memory
=
use_shared_memory
,
...
...
@@ -218,7 +218,11 @@ def build_dataloader(engine):
}
if
engine
.
mode
==
'train'
:
train_dataloader
=
build
(
engine
.
config
[
"DataLoader"
],
"Train"
,
use_dali
,
seed
=
None
)
engine
.
config
[
"DataLoader"
],
"Train"
,
engine
.
device
,
use_dali
,
seed
=
None
)
iter_per_epoch
=
len
(
train_dataloader
)
-
1
if
platform
.
system
(
)
==
"Windows"
else
len
(
train_dataloader
)
if
engine
.
config
[
"Global"
].
get
(
"iter_per_epoch"
,
None
):
...
...
@@ -231,23 +235,33 @@ def build_dataloader(engine):
if
engine
.
config
[
"DataLoader"
].
get
(
'UnLabelTrain'
,
None
)
is
not
None
:
dataloader_dict
[
"UnLabelTrain"
]
=
build
(
engine
.
config
[
"DataLoader"
],
"UnLabelTrain"
,
use_dali
,
seed
=
None
)
engine
.
config
[
"DataLoader"
],
"UnLabelTrain"
,
engine
.
device
,
use_dali
,
seed
=
None
)
if
engine
.
mode
==
"eval"
or
(
engine
.
mode
==
"train"
and
engine
.
config
[
"Global"
][
"eval_during_train"
]):
if
engine
.
config
[
"Global"
][
"eval_mode"
]
in
[
"classification"
,
"adaface"
]:
if
engine
.
eval_mode
in
[
"classification"
,
"adaface"
]:
dataloader_dict
[
"Eval"
]
=
build
(
engine
.
config
[
"DataLoader"
],
"Eval"
,
use_dali
,
seed
=
None
)
elif
engine
.
config
[
"Global"
][
"eval_mode"
]
==
"retrieval"
:
engine
.
config
[
"DataLoader"
],
"Eval"
,
engine
.
device
,
use_dali
,
seed
=
None
)
elif
engine
.
eval_mode
==
"retrieval"
:
if
len
(
engine
.
config
[
"DataLoader"
][
"Eval"
].
keys
())
==
1
:
key
=
list
(
engine
.
config
[
"DataLoader"
][
"Eval"
].
keys
())[
0
]
dataloader_dict
[
"GalleryQuery"
]
=
build
(
engine
.
config
[
"DataLoader"
][
"Eval"
],
key
,
use_dali
)
dataloader_dict
[
"GalleryQuery"
]
=
build_dataloader
(
engine
.
config
[
"DataLoader"
][
"Eval"
],
key
,
engine
.
device
,
use_dali
)
else
:
dataloader_dict
[
"Gallery"
]
=
build
(
engine
.
config
[
"DataLoader"
][
"Eval"
],
"Gallery"
,
use_dali
)
dataloader_dict
[
"Query"
]
=
build
(
engine
.
config
[
"DataLoader"
][
"Eval"
],
"Query"
,
use_dali
)
dataloader_dict
[
"Gallery"
]
=
build_dataloader
(
engine
.
config
[
"DataLoader"
][
"Eval"
],
"Gallery"
,
engine
.
device
,
use_dali
)
dataloader_dict
[
"Query"
]
=
build_dataloader
(
engine
.
config
[
"DataLoader"
][
"Eval"
],
"Query"
,
engine
.
device
,
use_dali
)
return
dataloader_dict
ppcls/engine/engine.py
浏览文件 @
e7e4f68b
...
...
@@ -39,8 +39,7 @@ from ppcls.utils import save_load
from
ppcls.data.utils.get_image_list
import
get_image_list
from
ppcls.data.postprocess
import
build_postprocess
from
ppcls.data
import
create_operators
from
.train
import
build_train_epoch_func
from
.evaluation
import
build_eval_func
from
ppcls.engine
import
train
as
train_method
from
ppcls.engine.train.utils
import
type_name
from
ppcls.engine
import
evaluation
from
ppcls.arch.gears.identity_head
import
IdentityHead
...
...
@@ -62,11 +61,22 @@ class Engine(object):
self
.
vdl_writer
=
self
.
_init_vdl
()
# init train_func and eval_func
self
.
train_epoch_func
=
build_train_epoch_func
(
self
.
config
)
self
.
eval_epoch_func
=
build_eval_func
(
self
.
config
)
self
.
train_mode
=
self
.
config
[
"Global"
].
get
(
"train_mode"
,
None
)
if
self
.
train_mode
is
None
:
self
.
train_epoch_func
=
train_method
.
train_epoch
else
:
self
.
train_epoch_func
=
getattr
(
train_method
,
"train_epoch_"
+
self
.
train_mode
)
self
.
eval_mode
=
self
.
config
[
"Global"
].
get
(
"eval_mode"
,
"classification"
)
assert
self
.
eval_mode
in
[
"classification"
,
"retrieval"
,
"adaface"
],
logger
.
error
(
"Invalid eval mode: {}"
.
format
(
self
.
eval_mode
))
self
.
eval_func
=
getattr
(
evaluation
,
self
.
eval_mode
+
"_eval"
)
# set device
self
.
_init_device
()
self
.
device
=
self
.
_init_device
()
# gradient accumulation
self
.
update_freq
=
self
.
config
[
"Global"
].
get
(
"update_freq"
,
1
)
...
...
@@ -385,7 +395,7 @@ class Engine(object):
assert
device
in
[
"cpu"
,
"gpu"
,
"xpu"
,
"npu"
,
"mlu"
,
"ascend"
]
logger
.
info
(
'train with paddle {} and device {}'
.
format
(
paddle
.
__version__
,
device
))
paddle
.
set_device
(
device
)
return
paddle
.
set_device
(
device
)
def
_init_pretrained
(
self
):
if
self
.
config
[
"Global"
][
"pretrained_model"
]
is
not
None
:
...
...
ppcls/engine/evaluation/__init__.py
浏览文件 @
e7e4f68b
...
...
@@ -12,15 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
.classification
import
classification_eval
from
.retrieval
import
retrieval_eval
from
.adaface
import
adaface_eval
def
build_eval_func
(
config
):
eval_mode
=
config
[
"Global"
].
get
(
"eval_mode"
,
None
)
if
eval_mode
is
None
:
config
[
"Global"
][
"eval_mode"
]
=
"classification"
return
classification_eval
else
:
return
getattr
(
sys
.
modules
[
__name__
],
eval_mode
+
"_eval"
)
from
ppcls.engine.evaluation.classification
import
classification_eval
from
ppcls.engine.evaluation.retrieval
import
retrieval_eval
from
ppcls.engine.evaluation.adaface
import
adaface_eval
\ No newline at end of file
ppcls/engine/train/__init__.py
浏览文件 @
e7e4f68b
...
...
@@ -11,18 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.train_metabin
import
train_epoch_metabin
from
.regular_train_epoch
import
regular_train_epoch
from
.train_fixmatch
import
train_epoch_fixmatch
from
.train_fixmatch_ccssl
import
train_epoch_fixmatch_ccssl
from
.train_progressive
import
train_epoch_progressive
def
build_train_epoch_func
(
config
):
train_mode
=
config
[
"Global"
].
get
(
"train_mode"
,
None
)
if
train_mode
is
None
:
config
[
"Global"
][
"train_mode"
]
=
"regular_train"
return
regular_train_epoch
else
:
return
getattr
(
sys
.
modules
[
__name__
],
"train_epoch_"
+
train_mode
)
from
ppcls.engine.train.train
import
train_epoch
from
ppcls.engine.train.train_fixmatch
import
train_epoch_fixmatch
from
ppcls.engine.train.train_fixmatch_ccssl
import
train_epoch_fixmatch_ccssl
from
ppcls.engine.train.train_progressive
import
train_epoch_progressive
from
ppcls.engine.train.train_metabin
import
train_epoch_metabin
ppcls/engine/train/
regular_train_epoch
.py
→
ppcls/engine/train/
train
.py
浏览文件 @
e7e4f68b
...
...
@@ -19,7 +19,7 @@ from ppcls.engine.train.utils import update_loss, update_metric, log_info, type_
from
ppcls.utils
import
profiler
def
regular_
train_epoch
(
engine
,
epoch_id
,
print_batch_step
):
def
train_epoch
(
engine
,
epoch_id
,
print_batch_step
):
tic
=
time
.
time
()
if
not
hasattr
(
engine
,
"train_dataloader_iter"
):
...
...
ppcls/engine/train/train_progressive.py
浏览文件 @
e7e4f68b
...
...
@@ -16,7 +16,8 @@ from __future__ import absolute_import, division, print_function
from
ppcls.data
import
build_dataloader
from
ppcls.engine.train.utils
import
type_name
from
ppcls.utils
import
logger
from
.regular_train_epoch
import
regular_train_epoch
from
.train
import
train_epoch
def
train_epoch_progressive
(
engine
,
epoch_id
,
print_batch_step
):
...
...
@@ -68,4 +69,4 @@ def train_epoch_progressive(engine, epoch_id, print_batch_step):
f
")"
)
# 3. Train one epoch as usual at current stage
regular_
train_epoch
(
engine
,
epoch_id
,
print_batch_step
)
train_epoch
(
engine
,
epoch_id
,
print_batch_step
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录