Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
f525cea0
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 1 年 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f525cea0
编写于
2月 27, 2023
作者:
G
gaotingquan
提交者:
Wei Shengyu
3月 10, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
replace the arg engine with config
上级
e4a3e1bb
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
29 addition
and
31 deletion
+29
-31
ppcls/data/__init__.py
ppcls/data/__init__.py
+28
-30
ppcls/engine/engine.py
ppcls/engine/engine.py
+1
-1
未找到文件。
ppcls/data/__init__.py
浏览文件 @
f525cea0
...
...
@@ -220,23 +220,21 @@ class DataIterator(object):
return
batch
def
build_dataloader
(
engin
e
):
if
"class_num"
in
engine
.
config
[
"Global"
]:
global_class_num
=
engine
.
config
[
"Global"
][
"class_num"
]
def
build_dataloader
(
config
,
mod
e
):
if
"class_num"
in
config
[
"Global"
]:
global_class_num
=
config
[
"Global"
][
"class_num"
]
if
"class_num"
not
in
config
[
"Arch"
]:
engine
.
config
[
"Arch"
][
"class_num"
]
=
global_class_num
config
[
"Arch"
][
"class_num"
]
=
global_class_num
msg
=
f
"The Global.class_num will be deprecated. Please use Arch.class_num instead. Arch.class_num has been set to
{
global_class_num
}
."
else
:
msg
=
"The Global.class_num will be deprecated. Please use Arch.class_num instead. The Global.class_num has been ignored."
logger
.
warning
(
msg
)
class_num
=
engine
.
config
[
"Arch"
].
get
(
"class_num"
,
None
)
engine
.
config
[
"DataLoader"
].
update
({
"class_num"
:
class_num
})
engine
.
config
[
"DataLoader"
].
update
({
"epochs"
:
engine
.
config
[
"Global"
][
"epochs"
]
})
class_num
=
config
[
"Arch"
].
get
(
"class_num"
,
None
)
config
[
"DataLoader"
].
update
({
"class_num"
:
class_num
})
config
[
"DataLoader"
].
update
({
"epochs"
:
config
[
"Global"
][
"epochs"
]})
use_dali
=
engine
.
use_dali
use_dali
=
config
[
"Global"
].
get
(
"use_dali"
,
False
)
dataloader_dict
=
{
"Train"
:
None
,
"UnLabelTrain"
:
None
,
...
...
@@ -245,37 +243,37 @@ def build_dataloader(engine):
"Gallery"
:
None
,
"GalleryQuery"
:
None
}
if
engine
.
mode
==
'train'
:
if
mode
==
'train'
:
train_dataloader
=
build
(
engine
.
config
[
"DataLoader"
],
"Train"
,
use_dali
,
seed
=
None
)
config
[
"DataLoader"
],
"Train"
,
use_dali
,
seed
=
None
)
if
engine
.
config
[
"DataLoader"
][
"Train"
].
get
(
"max_iter"
,
None
):
if
config
[
"DataLoader"
][
"Train"
].
get
(
"max_iter"
,
None
):
# set max iteration per epoch mannualy, when training by iteration(s), such as XBM, FixMatch.
max_iter
=
engine
.
config
[
"Train"
].
get
(
"max_iter"
)
max_iter
=
train_dataloader
.
max_iter
//
engine
.
update_freq
*
engine
.
update_freq
max_iter
=
config
[
"Train"
].
get
(
"max_iter"
)
update_freq
=
config
[
"Global"
].
get
(
"update_freq"
,
1
)
max_iter
=
train_dataloader
.
max_iter
//
update_freq
*
update_freq
train_dataloader
.
max_iter
=
max_iter
if
engine
.
config
[
"DataLoader"
][
"Train"
].
get
(
"convert_iterator"
,
True
):
if
config
[
"DataLoader"
][
"Train"
].
get
(
"convert_iterator"
,
True
):
train_dataloader
=
DataIterator
(
train_dataloader
,
use_dali
)
dataloader_dict
[
"Train"
]
=
train_dataloader
if
engine
.
config
[
"DataLoader"
].
get
(
'UnLabelTrain'
,
None
)
is
not
None
:
if
config
[
"DataLoader"
].
get
(
'UnLabelTrain'
,
None
)
is
not
None
:
dataloader_dict
[
"UnLabelTrain"
]
=
build
(
engine
.
config
[
"DataLoader"
],
"UnLabelTrain"
,
use_dali
,
seed
=
None
)
config
[
"DataLoader"
],
"UnLabelTrain"
,
use_dali
,
seed
=
None
)
if
engine
.
mode
==
"eval"
or
(
engine
.
mode
==
"train"
and
engine
.
config
[
"Global"
][
"eval_during_train"
]):
if
engine
.
config
[
"Global"
][
"eval_mode"
]
in
[
"classification"
,
"adaface"
]:
if
mode
==
"eval"
or
(
mode
==
"train"
and
config
[
"Global"
][
"eval_during_train"
]):
if
config
[
"Global"
][
"eval_mode"
]
in
[
"classification"
,
"adaface"
]:
dataloader_dict
[
"Eval"
]
=
build
(
engine
.
config
[
"DataLoader"
],
"Eval"
,
use_dali
,
seed
=
None
)
elif
engine
.
config
[
"Global"
][
"eval_mode"
]
==
"retrieval"
:
if
len
(
engine
.
config
[
"DataLoader"
][
"Eval"
].
keys
())
==
1
:
key
=
list
(
engine
.
config
[
"DataLoader"
][
"Eval"
].
keys
())[
0
]
config
[
"DataLoader"
],
"Eval"
,
use_dali
,
seed
=
None
)
elif
config
[
"Global"
][
"eval_mode"
]
==
"retrieval"
:
if
len
(
config
[
"DataLoader"
][
"Eval"
].
keys
())
==
1
:
key
=
list
(
config
[
"DataLoader"
][
"Eval"
].
keys
())[
0
]
dataloader_dict
[
"GalleryQuery"
]
=
build
(
engine
.
config
[
"DataLoader"
][
"Eval"
],
key
,
use_dali
)
config
[
"DataLoader"
][
"Eval"
],
key
,
use_dali
)
else
:
dataloader_dict
[
"Gallery"
]
=
build
(
engine
.
config
[
"DataLoader"
][
"Eval"
],
"Gallery"
,
use_dali
)
dataloader_dict
[
"Query"
]
=
build
(
engine
.
config
[
"DataLoader"
][
"Eval"
],
"Query"
,
use_dali
)
config
[
"DataLoader"
][
"Eval"
],
"Gallery"
,
use_dali
)
dataloader_dict
[
"Query"
]
=
build
(
config
[
"DataLoader"
][
"Eval"
],
"Query"
,
use_dali
)
return
dataloader_dict
ppcls/engine/engine.py
浏览文件 @
f525cea0
...
...
@@ -76,7 +76,7 @@ class Engine(object):
# build dataloader
self
.
use_dali
=
self
.
config
[
"Global"
].
get
(
"use_dali"
,
False
)
self
.
dataloader_dict
=
build_dataloader
(
self
)
self
.
dataloader_dict
=
build_dataloader
(
self
.
config
,
mode
)
# build loss
self
.
train_loss_func
,
self
.
unlabel_train_loss_func
,
self
.
eval_loss_func
=
build_loss
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录