Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
0055ca2f
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 1 年 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0055ca2f
编写于
3月 14, 2023
作者:
T
Tingquan Gao
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Revert "debug"
This reverts commit
9e683d0d
.
上级
753270ab
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
11 addition
and
11 deletion
+11
-11
ppcls/engine/engine.py
ppcls/engine/engine.py
+10
-10
ppcls/engine/evaluation/__init__.py
ppcls/engine/evaluation/__init__.py
+1
-1
未找到文件。
ppcls/engine/engine.py
浏览文件 @
0055ca2f
...
@@ -60,17 +60,17 @@ class Engine(object):
...
@@ -60,17 +60,17 @@ class Engine(object):
# build model
# build model
self
.
model
=
build_model
(
self
.
config
,
self
.
mode
)
self
.
model
=
build_model
(
self
.
config
,
self
.
mode
)
# load_pretrain
self
.
_init_pretrained
()
self
.
_init_amp
()
# init train_func and eval_func
# init train_func and eval_func
self
.
eval
=
build_eval_func
(
self
.
eval
=
build_eval_func
(
self
.
config
,
mode
=
self
.
mode
,
model
=
self
.
model
)
self
.
config
,
mode
=
self
.
mode
,
model
=
self
.
model
)
self
.
train
=
build_train_func
(
self
.
train
=
build_train_func
(
self
.
config
,
mode
=
self
.
mode
,
model
=
self
.
model
,
eval_func
=
self
.
eval
)
self
.
config
,
mode
=
self
.
mode
,
model
=
self
.
model
,
eval_func
=
self
.
eval
)
# load_pretrain
self
.
_init_pretrained
()
self
.
_init_amp
()
# for distributed
# for distributed
self
.
_init_dist
()
self
.
_init_dist
()
...
@@ -197,11 +197,11 @@ class Engine(object):
...
@@ -197,11 +197,11 @@ class Engine(object):
if
self
.
config
[
"Global"
][
"pretrained_model"
]
is
not
None
:
if
self
.
config
[
"Global"
][
"pretrained_model"
]
is
not
None
:
if
self
.
config
[
"Global"
][
"pretrained_model"
].
startswith
(
"http"
):
if
self
.
config
[
"Global"
][
"pretrained_model"
].
startswith
(
"http"
):
load_dygraph_pretrain_from_url
(
load_dygraph_pretrain_from_url
(
[
self
.
model
,
getattr
(
self
.
train
,
"loss_func"
,
None
)],
[
self
.
model
,
getattr
(
self
,
'train_loss_func'
,
None
)],
self
.
config
[
"Global"
][
"pretrained_model"
])
self
.
config
[
"Global"
][
"pretrained_model"
])
else
:
else
:
load_dygraph_pretrain
(
load_dygraph_pretrain
(
[
self
.
model
,
getattr
(
self
.
train
,
"loss_func"
,
None
)],
[
self
.
model
,
getattr
(
self
,
'train_loss_func'
,
None
)],
self
.
config
[
"Global"
][
"pretrained_model"
])
self
.
config
[
"Global"
][
"pretrained_model"
])
def
_init_amp
(
self
):
def
_init_amp
(
self
):
...
@@ -257,10 +257,10 @@ class Engine(object):
...
@@ -257,10 +257,10 @@ class Engine(object):
if
self
.
config
[
"Global"
][
"distributed"
]:
if
self
.
config
[
"Global"
][
"distributed"
]:
dist
.
init_parallel_env
()
dist
.
init_parallel_env
()
self
.
model
=
paddle
.
DataParallel
(
self
.
model
)
self
.
model
=
paddle
.
DataParallel
(
self
.
model
)
if
self
.
mode
==
'train'
and
len
(
self
.
train
.
loss_func
.
parameters
(
if
self
.
mode
==
'train'
and
len
(
self
.
train
_
loss_func
.
parameters
(
))
>
0
:
))
>
0
:
self
.
train
.
loss_func
=
paddle
.
DataParallel
(
self
.
train
_
loss_func
=
paddle
.
DataParallel
(
self
.
train
.
loss_func
)
self
.
train
_
loss_func
)
class
ExportModel
(
TheseusLayer
):
class
ExportModel
(
TheseusLayer
):
...
...
ppcls/engine/evaluation/__init__.py
浏览文件 @
0055ca2f
...
@@ -20,7 +20,7 @@ from .adaface import adaface_eval
...
@@ -20,7 +20,7 @@ from .adaface import adaface_eval
def
build_eval_func
(
config
,
mode
,
model
):
def
build_eval_func
(
config
,
mode
,
model
):
if
mode
not
in
[
"eval"
,
"train"
]:
if
mode
not
in
[
"eval"
,
"train"
]:
return
None
return
None
task
=
config
[
"Global"
].
get
(
"
eval_mode
"
,
"classification"
)
task
=
config
[
"Global"
].
get
(
"
task
"
,
"classification"
)
if
task
==
"classification"
:
if
task
==
"classification"
:
return
ClassEval
(
config
,
mode
,
model
)
return
ClassEval
(
config
,
mode
,
model
)
elif
task
==
"retrieval"
:
elif
task
==
"retrieval"
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录