Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
1989b660
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
1 年多 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1989b660
编写于
5月 17, 2022
作者:
C
cuicheng01
提交者:
GitHub
5月 17, 2022
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1 from littletomatodonkey/me/add_pdemo
fix convert weight
上级
7eef98da
afafb8f4
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
29 addition
and
43 deletion
+29
-43
ppcls/engine/engine.py
ppcls/engine/engine.py
+6
-5
ppcls/utils/convert_weights.py
ppcls/utils/convert_weights.py
+0
-31
ppcls/utils/save_load.py
ppcls/utils/save_load.py
+23
-7
未找到文件。
ppcls/engine/engine.py
浏览文件 @
1989b660
...
...
@@ -344,15 +344,15 @@ class Engine(object):
if
self
.
use_dali
:
self
.
train_dataloader
.
reset
()
metric_msg
=
", "
.
join
([
self
.
output_info
[
key
].
avg_info
for
key
in
self
.
output_info
])
metric_msg
=
", "
.
join
(
[
self
.
output_info
[
key
].
avg_info
for
key
in
self
.
output_info
])
logger
.
info
(
"[Train][Epoch {}/{}][Avg]{}"
.
format
(
epoch_id
,
self
.
config
[
"Global"
][
"epochs"
],
metric_msg
))
self
.
output_info
.
clear
()
# eval model and save model if possible
start_eval_epoch
=
self
.
config
[
"Global"
].
get
(
"start_eval_epoch"
,
0
)
-
1
start_eval_epoch
=
self
.
config
[
"Global"
].
get
(
"start_eval_epoch"
,
0
)
-
1
if
self
.
config
[
"Global"
][
"eval_during_train"
]
and
epoch_id
%
self
.
config
[
"Global"
][
"eval_interval"
]
==
0
and
epoch_id
>
start_eval_epoch
:
...
...
@@ -367,7 +367,8 @@ class Engine(object):
self
.
output_dir
,
model_name
=
self
.
config
[
"Arch"
][
"name"
],
prefix
=
"best_model"
,
loss
=
self
.
train_loss_func
)
loss
=
self
.
train_loss_func
,
save_student_model
=
True
)
logger
.
info
(
"[Eval][Epoch {}][best metric: {}]"
.
format
(
epoch_id
,
best_metric
[
"metric"
]))
logger
.
scaler
(
...
...
ppcls/utils/convert_weights.py
已删除
100644 → 0
浏览文件 @
7eef98da
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
paddle
def
convert_distill_weights
(
distill_weights_path
,
student_weights_path
):
assert
os
.
path
.
exists
(
distill_weights_path
),
\
"Given distill_weights_path {} not exist."
.
format
(
distill_weights_path
)
# Load teacher and student weights
all_params
=
paddle
.
load
(
distill_weights_path
)
# Extract student weights
s_params
=
{
key
[
len
(
"Student."
):]:
all_params
[
key
]
for
key
in
all_params
if
"Student."
in
key
}
# Save student weights
paddle
.
save
(
s_params
,
student_weights_path
)
ppcls/utils/save_load.py
浏览文件 @
1989b660
...
...
@@ -42,6 +42,14 @@ def _mkdir_if_not_exist(path):
raise
OSError
(
'Failed to mkdir {}'
.
format
(
path
))
def
_extract_student_weights
(
all_params
,
student_prefix
=
"Student."
):
s_params
=
{
key
[
len
(
student_prefix
):]:
all_params
[
key
]
for
key
in
all_params
if
student_prefix
in
key
}
return
s_params
def
load_dygraph_pretrain
(
model
,
path
=
None
):
if
not
(
os
.
path
.
isdir
(
path
)
or
os
.
path
.
exists
(
path
+
'.pdparams'
)):
raise
ValueError
(
"Model pretrain path {}.pdparams does not "
...
...
@@ -117,7 +125,7 @@ def init_model(config, net, optimizer=None, loss: paddle.nn.Layer=None):
else
:
# common load
load_dygraph_pretrain
(
net
,
path
=
pretrained_model
)
logger
.
info
(
"Finish load pretrained model from {}"
.
format
(
pretrained_model
))
pretrained_model
))
def
save_model
(
net
,
...
...
@@ -126,7 +134,8 @@ def save_model(net,
model_path
,
model_name
=
""
,
prefix
=
'ppcls'
,
loss
:
paddle
.
nn
.
Layer
=
None
):
loss
:
paddle
.
nn
.
Layer
=
None
,
save_student_model
=
False
):
"""
save model to the target path
"""
...
...
@@ -137,11 +146,18 @@ def save_model(net,
model_path
=
os
.
path
.
join
(
model_path
,
prefix
)
params_state_dict
=
net
.
state_dict
()
loss_state_dict
=
loss
.
state_dict
()
keys_inter
=
set
(
params_state_dict
.
keys
())
&
set
(
loss_state_dict
.
keys
())
assert
len
(
keys_inter
)
==
0
,
\
f
"keys in model and loss state_dict must be unique, but got intersection
{
keys_inter
}
"
params_state_dict
.
update
(
loss_state_dict
)
if
loss
is
not
None
:
loss_state_dict
=
loss
.
state_dict
()
keys_inter
=
set
(
params_state_dict
.
keys
())
&
set
(
loss_state_dict
.
keys
(
))
assert
len
(
keys_inter
)
==
0
,
\
f
"keys in model and loss state_dict must be unique, but got intersection
{
keys_inter
}
"
params_state_dict
.
update
(
loss_state_dict
)
if
save_student_model
:
s_params
=
_extract_student_weights
(
params_state_dict
)
if
len
(
s_params
)
>
0
:
paddle
.
save
(
s_params
,
model_path
+
"_student.pdparams"
)
paddle
.
save
(
params_state_dict
,
model_path
+
".pdparams"
)
paddle
.
save
([
opt
.
state_dict
()
for
opt
in
optimizer
],
model_path
+
".pdopt"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录