Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
3a1276d3
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
1 年多 前同步成功
通知
116
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3a1276d3
编写于
4月 19, 2022
作者:
H
HydrogenSulfate
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
train_loss_func only used in train mode
上级
24abea15
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
15 addition
and
12 deletion
+15
-12
ppcls/engine/engine.py
ppcls/engine/engine.py
+5
-4
ppcls/optimizer/__init__.py
ppcls/optimizer/__init__.py
+8
-7
ppcls/utils/save_load.py
ppcls/utils/save_load.py
+2
-1
未找到文件。
ppcls/engine/engine.py
浏览文件 @
3a1276d3
...
...
@@ -214,17 +214,17 @@ class Engine(object):
if
self
.
config
[
"Global"
][
"pretrained_model"
]
is
not
None
:
if
self
.
config
[
"Global"
][
"pretrained_model"
].
startswith
(
"http"
):
load_dygraph_pretrain_from_url
(
[
self
.
model
,
self
.
train_loss_func
],
[
self
.
model
,
getattr
(
self
,
'train_loss_func'
,
None
)
],
self
.
config
[
"Global"
][
"pretrained_model"
])
else
:
load_dygraph_pretrain
(
[
self
.
model
,
self
.
train_loss_func
],
[
self
.
model
,
getattr
(
self
,
'train_loss_func'
,
None
)
],
self
.
config
[
"Global"
][
"pretrained_model"
])
# build optimizer
if
self
.
mode
==
'train'
:
self
.
optimizer
,
self
.
lr_sch
=
build_optimizer
(
self
.
config
[
"Optimizer"
]
,
self
.
config
[
"Global"
][
"epochs"
],
self
.
config
,
self
.
config
[
"Global"
][
"epochs"
],
len
(
self
.
train_dataloader
),
[
self
.
model
,
self
.
train_loss_func
])
...
...
@@ -259,7 +259,8 @@ class Engine(object):
if
self
.
config
[
"Global"
][
"distributed"
]:
dist
.
init_parallel_env
()
self
.
model
=
paddle
.
DataParallel
(
self
.
model
)
if
len
(
self
.
train_loss_func
.
parameters
())
>
0
:
if
self
.
mode
==
'train'
and
len
(
self
.
train_loss_func
.
parameters
(
))
>
0
:
self
.
train_loss_func
=
paddle
.
DataParallel
(
self
.
train_loss_func
)
# build postprocess for infer
...
...
ppcls/optimizer/__init__.py
浏览文件 @
3a1276d3
...
...
@@ -45,19 +45,20 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
# model_list is None in static graph
def
build_optimizer
(
config
,
epochs
,
step_each_epoch
,
model_list
=
None
):
config
=
copy
.
deepcopy
(
config
)
if
isinstance
(
config
,
dict
):
# convert to [{optim_name1: {scope: xxx, **optim_cfg}}, {optim_name2: {scope: xxx, **optim_cfg}}, ...]
optim_name
=
config
.
Optimizer
.
pop
(
'name'
)
config
:
List
[
Dict
[
str
,
Dict
]]
=
[{
optim_config
=
config
[
"Optimizer"
]
if
isinstance
(
optim_config
,
dict
):
# convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}]
optim_name
=
optim_config
.
pop
(
"name"
)
optim_config
:
List
[
Dict
[
str
,
Dict
]]
=
[{
optim_name
:
{
'scope'
:
config
.
Arch
.
name
,
'scope'
:
config
[
"Arch"
].
get
(
"name"
)
,
**
config
.
Optimizer
optim_config
}
}]
optim_list
=
[]
lr_list
=
[]
for
optim_item
in
config
:
for
optim_item
in
optim_
config
:
# optim_cfg = {optim_name1: {scope: xxx, **optim_cfg}}
# step1 build lr
optim_name
=
optim_item
.
keys
()[
0
]
# get optim_name1
...
...
ppcls/utils/save_load.py
浏览文件 @
3a1276d3
...
...
@@ -49,7 +49,8 @@ def load_dygraph_pretrain(model, path=None):
param_state_dict
=
paddle
.
load
(
path
+
".pdparams"
)
if
isinstance
(
model
,
list
):
for
m
in
model
:
m
.
set_dict
(
param_state_dict
)
if
hasattr
(
m
,
'set_dict'
):
m
.
set_dict
(
param_state_dict
)
else
:
model
.
set_dict
(
param_state_dict
)
return
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录