PaddlePaddle / hapi
Commit 3da16763
Authored Mar 22, 2020 by Guo Sheng; committed by GitHub on Mar 22, 2020
Merge pull request #9 from guoshengCS/add-load-finetune
Support for fine-tuning.
Parents: b7674284, f5cbe6d0
Showing 1 changed file with 89 additions and 38 deletions.

model.py (+89 / -38)
@@ -18,6 +18,8 @@ import inspect
 import os
 import pickle
 import numpy as np
+import six
+import warnings
 from collections import Iterable
 from collections import OrderedDict
@@ -157,7 +159,7 @@ class StaticGraphAdapter(object):
         return self._run(inputs, None)

     def parameters(self, *args, **kwargs):
-        return None
+        return super(Model, self.model).parameters(*args, **kwargs)

     def save(self, path):
         def _save(state, path):
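With this change, `parameters()` on the static-graph adapter delegates to the underlying `fluid.dygraph.Layer` implementation instead of returning None, so trainable parameters can be inspected in either execution mode. A minimal sketch of what this enables, assuming Paddle 1.x-era APIs (`SimpleNet` is illustrative, not part of this diff):

    import paddle.fluid as fluid
    from model import Model  # this repo's model.py

    class SimpleNet(Model):
        def __init__(self):
            super(SimpleNet, self).__init__()
            self.fc = fluid.dygraph.Linear(10, 2)

        def forward(self, x):
            return self.fc(x)

    net = SimpleNet()
    # Previously this returned None under static graph; now it forwards to
    # Layer.parameters and yields the fc weight and bias in both modes.
    for p in net.parameters():
        print(p.name, p.shape)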
@@ -191,39 +193,23 @@ class StaticGraphAdapter(object):
         _save(optim, optim_path)

-    def load(self, path):
-        def _load(path):
-            if not os.path.exists(path):
-                return
-            with open(path, 'rb') as f:
-                return pickle.load(f)
-
-        param_path = path + ".pdparams"
-        param_state = _load(param_path)
-        assert param_state, "failed to load parameters, please check path"
-
+    def load(self, param_state_pairs, optim_state):
         if self._executor is None:
             executor = fluid.Executor(fluid.CPUPlace())._default_executor
         else:
             executor = self._executor._default_executor

         # restore parameter states
         fluid.core._create_loaded_parameter(
-            list(self.model.state_dict().values()), global_scope(), executor)
-
-        for key, var in self.model.state_dict().items():
-            assert key in param_state, \
-                "parameter [{}] is not found in model file [{}]".format(
-                    key, param_path)
-            self._set_var(var, param_state[key])
+            [param for param, state in param_state_pairs],
+            global_scope(), executor)
+        for param, state in param_state_pairs:
+            self._set_var(param, state)

         # restore optimizer states
         # FIXME what if a different optimizer is used?
-        if not self.model._optimizer:
-            return
-
-        optim_path = path + ".pdopt"
-        optim_state = _load(optim_path)
-        if optim_state is None:
+        if not self.model._optimizer or not optim_state:
             return
+
         self._load_optimizer(optim_state, executor)

     def _load_optimizer(self, state, executor):
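Note the changed contract: the static-graph adapter's `load` no longer reads files or matches names itself; it receives `(parameter, state)` pairs that `Model.load` has already matched, plus an optional optimizer state dict. A toy version of the matching step that produces such pairs, using plain numpy arrays as stand-ins (not this repo's actual matching code, which appears further down in `Model.load`):

    import numpy as np

    def match_states(model_state_dict, file_state_dict):
        # Pair each model parameter with the saved state of the same
        # structured name; a mismatch is silently skipped here, whereas
        # Model.load raises or warns depending on skip_mismatch.
        pairs = []
        for key, param in model_state_dict.items():
            state = file_state_dict.get(key)
            if state is None or list(state.shape) != list(param.shape):
                continue
            pairs.append((param, state))
        return pairs

    model_sd = {"fc.w_0": np.zeros((4, 2)), "fc.b_0": np.zeros((2,))}
    file_sd = {"fc.w_0": np.ones((4, 2))}  # "fc.b_0" absent, as in fine-tuning
    print(match_states(model_sd, file_sd))  # only the matching pair survives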
@@ -473,7 +459,7 @@ class DynamicGraphAdapter(object):
         if labels is not None:
             labels = [to_variable(l) for l in to_list(labels)]
         outputs = to_list(
-            self.model.forward(*[to_variable(x) for x in inputs]))
+            self.model.forward(*[to_variable(x) for x in inputs]))
         losses = self.model._loss_function(outputs, labels)
         final_loss = fluid.layers.sum(losses)
         final_loss.backward()
@@ -482,7 +468,7 @@ class DynamicGraphAdapter(object):
         metrics = []
         for metric in self.model._metrics:
             metric_outs = metric.add_metric_op(outputs, to_list(labels))
-            m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
+            m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
             metrics.append(m)

         return ([to_numpy(l) for l in losses], metrics) \
             if len(metrics) > 0 else [to_numpy(l) for l in losses]
@@ -494,7 +480,7 @@ class DynamicGraphAdapter(object):
         if labels is not None:
             labels = [to_variable(l) for l in to_list(labels)]
         outputs = to_list(
-            self.model.forward(*[to_variable(x) for x in inputs]))
+            self.model.forward(*[to_variable(x) for x in inputs]))
         if self.model._loss_function:
             losses = self.model._loss_function(outputs, labels)
@@ -504,7 +490,7 @@ class DynamicGraphAdapter(object):
         metrics = []
         for metric in self.model._metrics:
             metric_outs = metric.add_metric_op(outputs, labels)
-            m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
+            m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
             metrics.append(m)

         # To be consistent with static graph
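Both evaluation paths drive metrics through the same two-step protocol: `add_metric_op` derives per-batch values from outputs and labels, then `update` consumes their numpy form to advance the running state. A simplified, self-contained illustration of that protocol (a stand-in, not this repo's Metric base class):

    import numpy as np

    class ToyAccuracy(object):
        # Simplified stand-in for the metric objects iterated above.
        def __init__(self):
            self.correct = 0.0
            self.total = 0

        def add_metric_op(self, outputs, labels):
            # per-batch computation on model outputs and labels
            pred = np.argmax(outputs, axis=-1)
            return (pred == labels).astype('float32')

        def update(self, corrects):
            # fold numpy values into the running state, return current value
            self.correct += float(np.sum(corrects))
            self.total += corrects.size
            return self.correct / self.total

    metric = ToyAccuracy()
    batch_out = np.array([[0.1, 0.9], [0.8, 0.2]])
    corrects = metric.add_metric_op(batch_out, np.array([1, 1]))
    print(metric.update(corrects))  # running accuracy: 0.5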
@@ -531,10 +517,13 @@ class DynamicGraphAdapter(object):
             optim = self.model._optimizer.state_dict()
             fluid.save_dygraph(optim, path)

-    def load(self, path):
-        params, optim = fluid.load_dygraph(path)
-        self.model.set_dict(params)
-        if self.model._optimizer is None or optim is None:
+    def load(self, param_state_pairs, optim_state):
+        # restore parameter states
+        for param, state in param_state_pairs:
+            param.set_value(state)
+
+        # restore optimizer states
+        if not self.model._optimizer or not optim_state:
             return

         # If optimizer performs set_dict when state vars haven't been created,
@@ -543,13 +532,13 @@ class DynamicGraphAdapter(object):
         # To contrive this when loading from static-graph saved states, extend
         # state dict to include keys named according to dygraph naming rules.
         # TODO: if len(self.model._optimizer._accumulators) > 0
-        converted_state = dict(optim)
+        converted_state = dict(optim_state)
         opt_unq_name = self.model._optimizer._name
         opt_cls_name = self.model._optimizer.__class__.__name__
         opt_name = opt_unq_name[:opt_unq_name.rfind("_")]  # remove suffix idx
         param_names = [param.name for param in self.model.parameters()]
         for var_name, state_var in sorted(
-                optim.items(), key=lambda x: len(x[0]), reverse=True):
+                optim_state.items(), key=lambda x: len(x[0]), reverse=True):
             if var_name in ["@LR_DECAY_COUNTER@", "global_step"]:
                 # NOTE: dygraph saved global_step is 1 larger than that in
                 # static-graph, since the time of global_step to increase is
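The renaming above bridges static-graph optimizer state keys and dygraph naming rules: the optimizer's unique name carries a numeric suffix that must be stripped, and keys are visited longest-first so a bare parameter name never matches inside a longer accumulator key. The string handling in isolation, with made-up names:

    # Made-up names, only to illustrate the key handling above.
    opt_unq_name = "SGD_0"  # unique optimizer name carries an index suffix
    opt_name = opt_unq_name[:opt_unq_name.rfind("_")]
    print(opt_name)  # -> "SGD"

    # Longest-first iteration: "linear_0.w_0_velocity_0" is processed before
    # the bare parameter name "linear_0.w_0" can shadow it.
    state_keys = ["linear_0.w_0", "linear_0.w_0_velocity_0", "global_step"]
    for key in sorted(state_keys, key=len, reverse=True):
        print(key)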
@@ -597,7 +586,6 @@ class Model(fluid.dygraph.Layer):
         self._optimizer = None
         self._device = None
         self._device_ids = None
-        self._optimizer = None
         if in_dygraph_mode():
             self._adapter = DynamicGraphAdapter(self)
         else:
@@ -615,8 +603,71 @@ class Model(fluid.dygraph.Layer):
     def save(self, *args, **kwargs):
         return self._adapter.save(*args, **kwargs)

-    def load(self, *args, **kwargs):
-        return self._adapter.load(*args, **kwargs)
+    def load(self, path, skip_mismatch=False, reset_optimizer=False):
+        """
+        Load from files storing the model states and optimizer states. The
+        file for optimizer states is not necessary if there is no need to
+        restore the optimizer.
+
+        NOTE: parameters are retrieved from the file storing model states
+        according to their structured names.
+
+        For fine-tuning or transfer-learning models where some of the layers
+        have changed, make sure the parameters to be restored have the same
+        structured names in the pre-trained model and the fine-tuning model.
+
+        Args:
+            path (str): The prefix of files storing the model states and
+                optimizer states. The files would be `path.pdparams` and
+                `path.pdopt` separately, and the latter is not necessary
+                when there is no need to restore the optimizer.
+            skip_mismatch (bool): Whether to skip a mismatched parameter
+                or raise an error when a mismatch happens (the parameter
+                is not found in the file storing model states, or has a
+                mismatched shape).
+            reset_optimizer (bool): If True, ignore the provided file storing
+                optimizer states and initialize optimizer states from scratch.
+                Otherwise, restore optimizer states from `path.pdopt` if
+                an optimizer has been set on the model. Default False.
+        """
+
+        def _load_state_from_path(path):
+            if not os.path.exists(path):
+                return
+            with open(path, 'rb') as f:
+                return pickle.load(f) if six.PY2 else pickle.load(
+                    f, encoding='latin1')
+
+        def _check_match(key, param):
+            state = param_state.get(key, None)
+            if state is None:
+                raise ValueError(
+                    "{} is not found in the provided file.".format(key))
+            if list(state.shape) != list(param.shape):
+                raise ValueError(
+                    "{} receives a shape {}, but the expected shape is {}.".
+                    format(key, list(state.shape), list(param.shape)))
+            return param, state
+
+        param_state = _load_state_from_path(path + ".pdparams")
+        assert param_state, "Failed to load parameters, please check path."
+
+        matched_param_state = []
+        for key, param in self.state_dict().items():
+            try:
+                match_res = _check_match(key, param)
+            except ValueError as err:
+                if skip_mismatch:
+                    warnings.warn(
+                        ("Skip loading for {}. ".format(key) + err.message))
+                    # reset optimizer when mismatch happens
+                    reset_optimizer = True
+                else:
+                    raise err
+            matched_param_state.append(match_res)
+
+        optim_state = None if reset_optimizer else _load_state_from_path(
+            path + ".pdopt")
+        return self._adapter.load(matched_param_state, optim_state)

     def parameters(self, *args, **kwargs):
         return self._adapter.parameters(*args, **kwargs)
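Taken together, the new `Model.load` supports the fine-tuning workflow named in the commit message: restore a checkpoint into a model whose head has changed, skipping the parameters that no longer match. A usage sketch, assuming Paddle 1.x dygraph APIs; `Classifier` and the checkpoint prefix are hypothetical, and note that `err.message` in the skip path above is Python 2 style:

    import paddle.fluid as fluid
    from model import Model  # this repo's model.py

    class Classifier(Model):
        # Hypothetical model: only the head size differs between runs.
        def __init__(self, num_classes):
            super(Classifier, self).__init__()
            self.feature = fluid.dygraph.Linear(784, 256)
            self.head = fluid.dygraph.Linear(256, num_classes)

        def forward(self, x):
            return self.head(self.feature(x))

    with fluid.dygraph.guard():
        pretrained = Classifier(num_classes=10)
        pretrained.save("checkpoints/base")  # writes checkpoints/base.pdparams

        finetune = Classifier(num_classes=5)  # head shapes no longer match
        # skip_mismatch=True warns and skips the 10-class head parameters
        # while restoring `feature`; reset_optimizer=True ignores any .pdopt.
        finetune.load("checkpoints/base",
                      skip_mismatch=True,
                      reset_optimizer=True)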