PaddlePaddle / hapi

Commit b6de4543
Authored Mar 17, 2020 by guosheng

Support for fine-tuning.

Parent: 358f7852

Showing 1 changed file with 109 additions and 18 deletions.

model.py  +109 -18
@@ -17,6 +17,7 @@ from __future__ import absolute_import
 import inspect
 import os
 import pickle
+import six
 from collections import OrderedDict

 import numpy as np
@@ -133,7 +134,7 @@ class StaticGraphAdapter(object):
         return self._run(inputs, None)

     def parameters(self, *args, **kwargs):
-        return None
+        return super(Model, self.model).parameters(*args, **kwargs)

     def save(self, path):
         def _save(state, path):
@@ -167,12 +168,13 @@ class StaticGraphAdapter(object):
             _save(optim, optim_path)

-    def load(self, path):
+    def load(self, path, reset_optimizer=False, parameters=[]):
         def _load(path):
             if not os.path.exists(path):
                 return
             with open(path, 'rb') as f:
-                return pickle.load(f)
+                return pickle.load(f) if six.PY2 else pickle.load(
+                    f, encoding='latin1')

         param_path = path + ".pdparams"
         param_state = _load(param_path)
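The two-branch pickle call above is a Python 2/3 compatibility measure: checkpoints pickled under Python 2 store numpy buffers as byte strings, and `encoding='latin1'` lets Python 3 unpickle them without attempting ASCII decoding. A minimal standalone sketch of the same pattern (the helper name is made up):

    import pickle
    import six

    def load_py2_compatible(path):
        # Hypothetical helper mirroring _load above: reads a pickle that may
        # have been written by Python 2. Decoding as latin1 preserves the raw
        # bytes of numpy buffers instead of raising UnicodeDecodeError.
        with open(path, 'rb') as f:
            if six.PY2:
                return pickle.load(f)
            return pickle.load(f, encoding='latin1')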
@@ -183,14 +185,20 @@ class StaticGraphAdapter(object):
         else:
             executor = self._executor._default_executor
+        param_names = [param.name for param in parameters]
         fluid.core._create_loaded_parameter(
-            list(self.model.state_dict().values()), global_scope(), executor)
+            list(parameters), global_scope(), executor)

         for key, var in self.model.state_dict().items():
-            assert key in param_state, \
-                "parameter [{}] is not found in model file [{}]".format(
-                    key, param_path)
-            self._set_var(var, param_state[key])
+            if not param_names or var.name in param_names:
+                assert key in param_state, \
+                    "parameter [{}] is not found in model file [{}]".format(
+                        key, param_path)
+                self._set_var(var, param_state[key])
+
+        if reset_optimizer or parameters:
+            return

         # FIXME what if a different optimizer is used?
         if not self.model._optimizer:
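The new `if not param_names or var.name in param_names` guard gives `parameters` opt-in semantics: the empty default restores every variable, while a non-empty list restores only the named subset, and `if reset_optimizer or parameters: return` then skips optimizer restoration entirely. A tiny sketch of the predicate, with made-up variable names (the dygraph adapter further below applies the same filter):

    def should_restore(var_name, param_names):
        # Empty filter means "restore everything"; otherwise restore only
        # the requested subset of variables.
        return not param_names or var_name in param_names

    state = {"fc_0.w_0": None, "fc_0.b_0": None}  # hypothetical state dict
    assert [k for k in state if should_restore(k, [])] == list(state)
    assert [k for k in state if should_restore(k, ["fc_0.w_0"])] == ["fc_0.w_0"]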
@@ -429,7 +437,7 @@ class DynamicGraphAdapter(object):
         inputs = to_list(inputs)
         if labels is not None:
             labels = to_list(labels)
-        outputs = self.model.forward(*[to_variable(x) for x in inputs])
+        outputs = self.model.forward(*[to_variable(x) for x in inputs])
         losses = self.model._loss_function(outputs, labels)
         final_loss = fluid.layers.sum(losses)
         final_loss.backward()
@@ -444,7 +452,7 @@ class DynamicGraphAdapter(object):
         inputs = to_list(inputs)
         if labels is not None:
             labels = to_list(labels)
-        outputs = self.model.forward(*[to_variable(x) for x in inputs])
+        outputs = self.model.forward(*[to_variable(x) for x in inputs])
         if self.model._loss_function:
             losses = self.model._loss_function(outputs, labels)
@@ -475,10 +483,22 @@ class DynamicGraphAdapter(object):
             optim = self.model._optimizer.state_dict()
             fluid.save_dygraph(optim, path)

-    def load(self, path):
-        params, optim = fluid.load_dygraph(path)
-        self.model.set_dict(params)
-        if self.model._optimizer is None or optim is None:
+    def load(self, path, reset_optimizer=False, parameters=[]):
+        param_state, optim_state = fluid.load_dygraph(path)
+        param_names = [param.name for param in parameters]
+        for key, var in self.model.state_dict().items():
+            if not param_names or var.name in param_names:
+                assert key in param_state, \
+                    "parameter [{}] is not found in model file [{}]".format(
+                        key, path + ".pdparams")
+                var.set_value(param_state[key])
+
+        if reset_optimizer or parameters:
+            return
+        if self.model._optimizer is None or optim_state is None:
             return

         # If optimizer performs set_dict when state vars haven't been created,
@@ -487,13 +507,13 @@ class DynamicGraphAdapter(object):
         # To contrive this when loading from static-graph saved states, extend
         # state dict to include keys named accoring to dygraph naming rules.
         # TODO: if len(self.model._optimizer._accumulators) > 0
-        converted_state = dict(optim)
+        converted_state = dict(optim_state)
         opt_unq_name = self.model._optimizer._name
         opt_cls_name = self.model._optimizer.__class__.__name__
         opt_name = opt_unq_name[:opt_unq_name.rfind("_")]  # remove suffix idx
         param_names = [param.name for param in self.model.parameters()]
         for var_name, state_var in sorted(
-                optim.items(), key=lambda x: len(x[0]), reverse=True):
+                optim_state.items(), key=lambda x: len(x[0]), reverse=True):
             if var_name in ["@LR_DECAY_COUNTER@", "global_step"]:
                 # NOTE: dygraph saved global_step is 1 larger than that in
                 # static-graph, since the time of global_step to increase is
@@ -560,8 +580,79 @@ class Model(fluid.dygraph.Layer):
     def save(self, *args, **kwargs):
         return self._adapter.save(*args, **kwargs)

-    def load(self, *args, **kwargs):
-        return self._adapter.load(*args, **kwargs)
+    def load(self, path, reset_optimizer=False, layers=None, weights=None):
+        """
+        Load from files storing the model states and optimizer states. The file
+        for optimizer states is not necessary if no need to restore the optimizer.
+
+        `layers` and `weights` are useful for fine-tuning or transfer-learning
+        models where some of the layers have changed. If provided, only
+        parameters included in layers and weights would be loaded, and optimizer
+        would be reset. If both are None, make no effect and load all parameters.
+
+        NOTE: parameters are restored based on names, which are decided by the
+        network's topology if not given by `param_attr` explicitly. This means
+        the architecture should be the same as when the weights were saved.
+        Layers that don't have parameters are not taken into account in the
+        topological ordering, thus could be added or removed casually.
+
+        Args:
+            path (str): The prefix of files storing the model states and
+                optimizer states. The files would be `path.pdparams` and
+                `path.pdopt` separately, and the latter is not necessary
+                when no need to restore.
+            reset_optimizer (bool): If True, ignore the providing file storing
+                optimizer states and initialize optimizer states from scratch.
+                Otherwise, restore optimizer states from `path.pdopt` if
+                a optimizer has been set to the model. Default False.
+            layers (list|Layer|str|None): The layers to be restored. All
+                parameters in these layers would be loaded. `layers` is
+                composed of instances of Layer or string. A string corresponded
+                layer is the one whose `full_name()` equals to the string.
+                If None, make no effect to load. Default None.
+            weights (list|Parameter|str|None): The parameters to be loaded.
+                `weights` is composed of instances of Parameter or string.
+                A string corresponded parameter is the one whose name equals to
+                the string. If None, make no effect to load. Default None.
+        """
+        load_param_vars = set()
+        if layers is not None:
+            model_layers = self.sublayers()
+            model_layers_dict = dict(
+                (layer.full_name(), layer) for layer in model_layers)
+            for i, layer in enumerate(to_list(layers)):
+                if isinstance(layer, fluid.dygraph.Layer):
+                    assert layer in model_layers, (
+                        "The #%d layer in layers is not in model." % i)
+                    load_param_vars.update(layer.state_dict().values())
+                elif isinstance(layer, six.string_types):
+                    assert layer in model_layers_dict, (
+                        "The #%d layer in layers is not in model." % i)
+                    load_param_vars.update(
+                        model_layers_dict[layer].state_dict().values())
+                else:
+                    raise TypeError(
+                        "The value in layers should be string or Layer.")
+        if weights is not None:
+            model_weights = self.parameters()
+            model_weights_dict = dict(
+                (weight.name, weight) for weight in model_weights)
+            param_type = fluid.framework.ParamBase if in_dygraph_mode(
+            ) else fluid.framework.Parameter
+            for i, weight in enumerate(to_list(weights)):
+                if isinstance(weight, param_type):
+                    # var== has been overwrited, thus do not use `weight in`
+                    assert weight.name in model_weights_dict, (
+                        "The #%d weight in weights is not in model." % i)
+                    load_param_vars.add(weight)
+                elif isinstance(weight, six.string_types):
+                    assert weight in model_weights_dict, (
+                        "The #%d weight in weights is not in model." % i)
+                    load_param_vars.add(model_weights_dict[weight])
+                else:
+                    raise TypeError(
+                        "The value in weights should be string or %s." %
+                        param_type.__name__)
+        return self._adapter.load(path, reset_optimizer,
+                                  list(load_param_vars))

     def prepare(self, optimizer=None,
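Taken together with the adapter changes above, the new signature lets a fine-tuning script restore only part of a pretrained checkpoint while freshly added layers keep their initialization. A hedged usage sketch, assuming dygraph mode and a hypothetical two-part network (none of these names come from this commit), with a checkpoint previously written by `model.save("pretrained")`:

    import paddle.fluid as fluid
    from model import Model  # the module changed in this commit

    fluid.enable_dygraph()

    class MyModel(Model):  # hypothetical Model subclass
        def __init__(self):
            super(MyModel, self).__init__()
            self.backbone = fluid.dygraph.Linear(4, 8)
            self.head = fluid.dygraph.Linear(8, 2)

        def forward(self, x):
            return self.head(self.backbone(x))

    model = MyModel()
    # Restore only the backbone; the optimizer state, if any, is reset
    # because a parameter subset was requested.
    model.load("pretrained", layers=[model.backbone])
    # Equivalent selections by name:
    model.load("pretrained", layers=[model.backbone.full_name()])
    model.load("pretrained", weights=[model.backbone.weight.name])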