Commit 3da16763 (unverified)
Authored by Guo Sheng on Mar 22, 2020; committed by GitHub on Mar 22, 2020

Merge pull request #9 from guoshengCS/add-load-finetune

Support for fine-tuning.

Parents: b7674284, f5cbe6d0

1 changed file with 89 additions and 38 deletions: model.py (+89, -38)
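The commit message is brief; as orientation, here is a minimal, hypothetical sketch of how the loading interface added in this commit (see the model.py diff below) might be used for fine-tuning. MyModel and the checkpoint prefix are placeholders, not part of the commit itself.

# Hypothetical usage of the new Model.load(path, skip_mismatch, reset_optimizer);
# MyModel and "pretrained/model" are placeholders.
model = MyModel()  # some subclass of hapi's Model

# Reads "pretrained/model.pdparams" (and "pretrained/model.pdopt" unless
# reset_optimizer=True). With skip_mismatch=True, parameters missing from the
# file or saved with a different shape are skipped with a warning instead of
# raising an error, which is the fine-tuning case this commit targets.
model.load("pretrained/model", skip_mismatch=True, reset_optimizer=True)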
model.py @ 3da16763
@@ -18,6 +18,8 @@ import inspect
 import os
 import pickle
 import numpy as np
+import six
+import warnings
 from collections import Iterable
 from collections import OrderedDict
@@ -157,7 +159,7 @@ class StaticGraphAdapter(object):
         return self._run(inputs, None)

     def parameters(self, *args, **kwargs):
-        return None
+        return super(Model, self.model).parameters(*args, **kwargs)

     def save(self, path):
         def _save(state, path):
@@ -191,39 +193,23 @@ class StaticGraphAdapter(object):
         _save(optim, optim_path)

-    def load(self, path):
-        def _load(path):
-            if not os.path.exists(path):
-                return
-            with open(path, 'rb') as f:
-                return pickle.load(f)
-
-        param_path = path + ".pdparams"
-        param_state = _load(param_path)
-        assert param_state, "failed to load parameters, please check path"
-
+    def load(self, param_state_pairs, optim_state):
         if self._executor is None:
             executor = fluid.Executor(fluid.CPUPlace())._default_executor
         else:
             executor = self._executor._default_executor

         # restore parameter states
         fluid.core._create_loaded_parameter(
-            list(self.model.state_dict().values()), global_scope(), executor)
-        for key, var in self.model.state_dict().items():
-            assert key in param_state, \
-                "parameter [{}] is not found in model file [{}]".format(
-                    key, param_path)
-            self._set_var(var, param_state[key])
+            [param for param, state in param_state_pairs],
+            global_scope(), executor)
+        for param, state in param_state_pairs:
+            self._set_var(param, state)

         # restore optimizer states
         # FIXME what if a different optimizer is used?
-        if not self.model._optimizer:
+        if not self.model._optimizer or not optim_state:
             return
-        optim_path = path + ".pdopt"
-        optim_state = _load(optim_path)
-        if optim_state is None:
-            return

         self._load_optimizer(optim_state, executor)

     def _load_optimizer(self, state, executor):
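The hunk above changes the adapter-level contract: StaticGraphAdapter.load no longer takes a file path but a list of (parameter, state) pairs plus an optional optimizer state dict. A rough, self-contained sketch of that data shape, with FakeParam and the names below as hypothetical stand-ins rather than hapi API:

import numpy as np

class FakeParam(object):
    # Stand-in for a framework parameter variable (name + shape only).
    def __init__(self, name, shape):
        self.name = name
        self.shape = shape

current_params = {"fc_0.w_0": FakeParam("fc_0.w_0", [4, 2])}   # like model.state_dict()
saved_state = {"fc_0.w_0": np.zeros([4, 2], dtype="float32")}  # like a loaded *.pdparams dict

# The pairs the adapters now consume instead of a path:
param_state_pairs = [(param, saved_state[name])
                     for name, param in current_params.items()
                     if name in saved_state]
# Model.load() then hands this list on, roughly:
#     self._adapter.load(param_state_pairs, optim_state)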
@@ -473,7 +459,7 @@ class DynamicGraphAdapter(object):
         if labels is not None:
             labels = [to_variable(l) for l in to_list(labels)]
         outputs = to_list(
             self.model.forward(*[to_variable(x) for x in inputs]))
         losses = self.model._loss_function(outputs, labels)
         final_loss = fluid.layers.sum(losses)
         final_loss.backward()
@@ -482,7 +468,7 @@ class DynamicGraphAdapter(object):
         metrics = []
         for metric in self.model._metrics:
             metric_outs = metric.add_metric_op(outputs, to_list(labels))
             m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
             metrics.append(m)
         return ([to_numpy(l) for l in losses], metrics) \
             if len(metrics) > 0 else [to_numpy(l) for l in losses]
@@ -494,7 +480,7 @@ class DynamicGraphAdapter(object):
         if labels is not None:
             labels = [to_variable(l) for l in to_list(labels)]
         outputs = to_list(
             self.model.forward(*[to_variable(x) for x in inputs]))
         if self.model._loss_function:
             losses = self.model._loss_function(outputs, labels)
@@ -504,7 +490,7 @@ class DynamicGraphAdapter(object):
         metrics = []
         for metric in self.model._metrics:
             metric_outs = metric.add_metric_op(outputs, labels)
             m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
             metrics.append(m)
         # To be consistent with static graph
@@ -531,10 +517,13 @@ class DynamicGraphAdapter(object):
         optim = self.model._optimizer.state_dict()
         fluid.save_dygraph(optim, path)

-    def load(self, path):
-        params, optim = fluid.load_dygraph(path)
-        self.model.set_dict(params)
-        if self.model._optimizer is None or optim is None:
+    def load(self, param_state_pairs, optim_state):
+        # restore parameter states
+        for param, state in param_state_pairs:
+            param.set_value(state)
+
+        # resotre optimizer states
+        if not self.model._optimizer or not optim_state:
             return

         # If optimizer performs set_dict when state vars haven't been created,
@@ -543,13 +532,13 @@ class DynamicGraphAdapter(object):
         # To contrive this when loading from static-graph saved states, extend
         # state dict to include keys named accoring to dygraph naming rules.
         # TODO: if len(self.model._optimizer._accumulators) > 0
-        converted_state = dict(optim)
+        converted_state = dict(optim_state)
         opt_unq_name = self.model._optimizer._name
         opt_cls_name = self.model._optimizer.__class__.__name__
         opt_name = opt_unq_name[:opt_unq_name.rfind("_")]  # remove suffix idx
         param_names = [param.name for param in self.model.parameters()]
         for var_name, state_var in sorted(
-                optim.items(), key=lambda x: len(x[0]), reverse=True):
+                optim_state.items(), key=lambda x: len(x[0]), reverse=True):
             if var_name in ["@LR_DECAY_COUNTER@", "global_step"]:
                 # NOTE: dygraph saved global_step is 1 larger than that in
                 # static-graph, since the time of global_step to increase is
@@ -597,7 +586,6 @@ class Model(fluid.dygraph.Layer):
         self._optimizer = None
         self._device = None
         self._device_ids = None
-        self._optimizer = None
         if in_dygraph_mode():
             self._adapter = DynamicGraphAdapter(self)
         else:
@@ -615,8 +603,71 @@ class Model(fluid.dygraph.Layer):
     def save(self, *args, **kwargs):
         return self._adapter.save(*args, **kwargs)

-    def load(self, *args, **kwargs):
-        return self._adapter.load(*args, **kwargs)
+    def load(self, path, skip_mismatch=False, reset_optimizer=False):
+        """
+        Load from files storing the model states and optimizer states. The file
+        for optimizer states is not necessary if no need to restore the optimizer.
+
+        NOTE: parameters are retrieved out from the file storing model states
+        accoring to their structured names.
+
+        For fine-tuning or transfer-learning models where some of the layers have
+        changed, keep parameters needed to restore have same structured names in
+        the pre-trained model and fine-tuning model.
+
+        Args:
+            path (str): The prefix of files storing the model states and
+                optimizer states. The files would be `path.pdparams` and
+                `path.pdopt` separately, and the latter is not necessary
+                when no need to restore.
+            skip_mismatch (bool): Whether to skip the loading of mismatch
+                parameter or raise an error when mismatch happens (not found
+                the parameter in file storing model states of or receives a
+                mismatch shape).
+            reset_optimizer (bool): If True, ignore the providing file storing
+                optimizer states and initialize optimizer states from scratch.
+                Otherwise, restore optimizer states from `path.pdopt` if
+                a optimizer has been set to the model. Default False.
+        """
+
+        def _load_state_from_path(path):
+            if not os.path.exists(path):
+                return
+            with open(path, 'rb') as f:
+                return pickle.load(f) if six.PY2 else pickle.load(
+                    f, encoding='latin1')
+
+        def _check_match(key, param):
+            state = param_state.get(key, None)
+            if state is None:
+                raise ValueError(
+                    "{} is not found in the providing file.".format(key))
+            if list(state.shape) != list(param.shape):
+                raise ValueError(
+                    "{} receives a shape {}, but the expected shape is {}.".
+                    format(key, list(state.shape), list(param.shape)))
+            return param, state
+
+        param_state = _load_state_from_path(path + ".pdparams")
+        assert param_state, "Failed to load parameters, please check path."
+
+        matched_param_state = []
+        for key, param in self.state_dict().items():
+            try:
+                match_res = _check_match(key, param)
+            except ValueError as err:
+                if skip_mismatch:
+                    warnings.warn(
+                        ("Skip loading for {}. ".format(key) + err.message))
+                    # reset optimizer when mismatch happens
+                    reset_optimizer = True
+                else:
+                    raise err
+            matched_param_state.append(match_res)
+
+        optim_state = None if reset_optimizer else _load_state_from_path(
+            path + ".pdopt")
+        return self._adapter.load(matched_param_state, optim_state)

     def parameters(self, *args, **kwargs):
         return self._adapter.parameters(*args, **kwargs)
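The skip_mismatch branch added to Model.load above is what lets a partially changed network reuse a pre-trained checkpoint. A self-contained sketch of the same name/shape check outside hapi, using a hypothetical check_match helper and plain numpy arrays:

import warnings

import numpy as np


def check_match(key, saved, expected_shape, skip_mismatch=False):
    # Mirrors the spirit of _check_match in the diff above (illustrative only):
    # a parameter is restored only if it exists in the saved file and its shape
    # matches the current parameter; otherwise warn (skip) or raise.
    if saved is None:
        msg = "{} is not found in the providing file.".format(key)
    elif list(saved.shape) != list(expected_shape):
        msg = "{} receives a shape {}, but the expected shape is {}.".format(
            key, list(saved.shape), list(expected_shape))
    else:
        return True
    if skip_mismatch:
        warnings.warn("Skip loading for {}. ".format(key) + msg)
        return False
    raise ValueError(msg)


# A fine-tuned classifier head whose output size changed from 1000 to 10
# is skipped with a warning instead of aborting the whole load:
check_match("fc.weight", np.zeros([512, 1000]), [512, 10], skip_mismatch=True)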