Unverified commit 3da16763 authored by Guo Sheng, committed by GitHub

Merge pull request #9 from guoshengCS/add-load-finetune

Support for fine-tuning.
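This change enables restoring a pre-trained checkpoint into a model whose layers have partially changed. A minimal sketch of the intended flow, assuming `model` is an instance of a user-defined `Model` subclass and the checkpoint prefix is hypothetical:

```python
def finetune_load(model, pretrained_prefix="pretrained/ckpt"):
    # Restore every parameter whose structured name and shape still match the
    # checkpoint, skip the rest with a warning (skip_mismatch=True), and start
    # optimizer states from scratch instead of reading <prefix>.pdopt.
    model.load(pretrained_prefix, skip_mismatch=True, reset_optimizer=True)
```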
......@@ -18,6 +18,8 @@ import inspect
import os
import pickle
import numpy as np
import six
import warnings
from collections import Iterable
from collections import OrderedDict
......@@ -157,7 +159,7 @@ class StaticGraphAdapter(object):
return self._run(inputs, None)
def parameters(self, *args, **kwargs):
return None
return super(Model, self.model).parameters(*args, **kwargs)
def save(self, path):
def _save(state, path):
......@@ -191,39 +193,23 @@ class StaticGraphAdapter(object):
_save(optim, optim_path)
def load(self, path):
def _load(path):
if not os.path.exists(path):
return
with open(path, 'rb') as f:
return pickle.load(f)
param_path = path + ".pdparams"
param_state = _load(param_path)
assert param_state, "failed to load parameters, please check path"
def load(self, param_state_pairs, optim_state):
if self._executor is None:
executor = fluid.Executor(fluid.CPUPlace())._default_executor
else:
executor = self._executor._default_executor
# restore parameter states
fluid.core._create_loaded_parameter(
list(self.model.state_dict().values()), global_scope(), executor)
for key, var in self.model.state_dict().items():
assert key in param_state, \
"parameter [{}] is not found in model file [{}]".format(
key, param_path)
self._set_var(var, param_state[key])
[param for param, state in param_state_pairs],
global_scope(), executor)
for param, state in param_state_pairs:
self._set_var(param, state)
# restore optimizer states
# FIXME what if a different optimizer is used?
if not self.model._optimizer:
return
optim_path = path + ".pdopt"
optim_state = _load(optim_path)
if optim_state is None:
if not self.model._optimizer or not optim_state:
return
self._load_optimizer(optim_state, executor)
def _load_optimizer(self, state, executor):
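With this change the static-graph adapter no longer reads files itself: `Model.load` (further below) unpickles the states, matches them against the model's `state_dict`, and hands the adapter ready-made `(parameter, state)` pairs. A runnable sketch of that contract, using plain numpy arrays in place of fluid variables (all names and shapes are illustrative):

```python
import numpy as np

# Stand-ins for self.model.state_dict() and the unpickled .pdparams content.
model_state = {"fc_0.w_0": np.zeros((4, 2)), "fc_0.b_0": np.zeros((2,))}
loaded_state = {"fc_0.w_0": np.ones((4, 2)), "fc_0.b_0": np.ones((2,))}

# The caller pairs each parameter with its loaded state...
param_state_pairs = [(model_state[key], loaded_state[key])
                     for key in model_state if key in loaded_state]

# ...and the adapter only assigns them (the real code calls self._set_var).
for param, state in param_state_pairs:
    param[...] = state
```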
......@@ -473,7 +459,7 @@ class DynamicGraphAdapter(object):
if labels is not None:
labels = [to_variable(l) for l in to_list(labels)]
outputs = to_list(
self.model.forward(*[to_variable(x) for x in inputs]))
self.model.forward(* [to_variable(x) for x in inputs]))
losses = self.model._loss_function(outputs, labels)
final_loss = fluid.layers.sum(losses)
final_loss.backward()
......@@ -482,7 +468,7 @@ class DynamicGraphAdapter(object):
metrics = []
for metric in self.model._metrics:
metric_outs = metric.add_metric_op(outputs, to_list(labels))
m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
m = metric.update(* [to_numpy(m) for m in to_list(metric_outs)])
metrics.append(m)
return ([to_numpy(l) for l in losses], metrics) \
if len(metrics) > 0 else [to_numpy(l) for l in losses]
......@@ -494,7 +480,7 @@ class DynamicGraphAdapter(object):
if labels is not None:
labels = [to_variable(l) for l in to_list(labels)]
outputs = to_list(
self.model.forward(*[to_variable(x) for x in inputs]))
self.model.forward(* [to_variable(x) for x in inputs]))
if self.model._loss_function:
losses = self.model._loss_function(outputs, labels)
......@@ -504,7 +490,7 @@ class DynamicGraphAdapter(object):
metrics = []
for metric in self.model._metrics:
metric_outs = metric.add_metric_op(outputs, labels)
m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
m = metric.update(* [to_numpy(m) for m in to_list(metric_outs)])
metrics.append(m)
# To be consistent with static graph
......@@ -531,10 +517,13 @@ class DynamicGraphAdapter(object):
optim = self.model._optimizer.state_dict()
fluid.save_dygraph(optim, path)
def load(self, path):
params, optim = fluid.load_dygraph(path)
self.model.set_dict(params)
if self.model._optimizer is None or optim is None:
def load(self, param_state_pairs, optim_state):
# restore parameter states
for param, state in param_state_pairs:
param.set_value(state)
# restore optimizer states
if not self.model._optimizer or not optim_state:
return
# If optimizer performs set_dict when state vars haven't been created,
......@@ -543,13 +532,13 @@ class DynamicGraphAdapter(object):
# To work around this when loading from static-graph saved states, extend
# the state dict to include keys named according to dygraph naming rules.
# TODO: if len(self.model._optimizer._accumulators) > 0
converted_state = dict(optim)
converted_state = dict(optim_state)
opt_unq_name = self.model._optimizer._name
opt_cls_name = self.model._optimizer.__class__.__name__
opt_name = opt_unq_name[:opt_unq_name.rfind("_")] # remove suffix idx
param_names = [param.name for param in self.model.parameters()]
for var_name, state_var in sorted(
optim.items(), key=lambda x: len(x[0]), reverse=True):
optim_state.items(), key=lambda x: len(x[0]), reverse=True):
if var_name in ["@LR_DECAY_COUNTER@", "global_step"]:
# NOTE: dygraph saved global_step is 1 larger than that in
# static-graph, since the time of global_step to increase is
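The loop above renames static-graph optimizer state entries so that the dygraph `set_dict` can find them. A small sketch of the name handling it relies on, with illustrative values:

```python
# The dygraph optimizer's unique name carries a trailing index, e.g. "adam_0"
# (an illustrative value of self.model._optimizer._name); stripping the last
# "_<idx>" gives the prefix used when building dygraph-style state names.
opt_unq_name = "adam_0"
opt_name = opt_unq_name[:opt_unq_name.rfind("_")]
assert opt_name == "adam"

# Entries are visited longest-name-first, mirroring the sorted(...) call above.
optim_state = {"global_step": 10, "fc_0.w_0_moment1_0": 0.0}  # illustrative
for var_name, state_var in sorted(optim_state.items(),
                                  key=lambda x: len(x[0]), reverse=True):
    pass  # the per-name conversion happens here in the real code
```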
......@@ -597,7 +586,6 @@ class Model(fluid.dygraph.Layer):
self._optimizer = None
self._device = None
self._device_ids = None
self._optimizer = None
if in_dygraph_mode():
self._adapter = DynamicGraphAdapter(self)
else:
......@@ -615,8 +603,71 @@ class Model(fluid.dygraph.Layer):
def save(self, *args, **kwargs):
return self._adapter.save(*args, **kwargs)
def load(self, *args, **kwargs):
return self._adapter.load(*args, **kwargs)
def load(self, path, skip_mismatch=False, reset_optimizer=False):
"""
Load from files storing the model states and optimizer states. The file
for optimizer states is not necessary if there is no need to restore the
optimizer.
NOTE: parameters are retrieved from the file storing model states
according to their structured names.
For fine-tuning or transfer-learning models where some layers have
changed, make sure the parameters that need to be restored keep the same
structured names in both the pre-trained model and the fine-tuning model.
Args:
path (str): The prefix of files storing the model states and
optimizer states. The files would be `path.pdparams` and
`path.pdopt` respectively, and the latter is not necessary
when there is no need to restore the optimizer.
skip_mismatch (bool): Whether to skip loading mismatched parameters
or raise an error when a mismatch happens (a parameter is not
found in the file storing model states or receives a
mismatched shape). Default False.
reset_optimizer (bool): If True, ignore the provided file storing
optimizer states and initialize optimizer states from scratch.
Otherwise, restore optimizer states from `path.pdopt` if an
optimizer has been set to the model. Default False.
"""
def _load_state_from_path(path):
if not os.path.exists(path):
return
with open(path, 'rb') as f:
return pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
def _check_match(key, param):
state = param_state.get(key, None)
if state is None:
raise ValueError(
"{} is not found in the providing file.".format(key))
if list(state.shape) != list(param.shape):
raise ValueError(
"{} receives a shape {}, but the expected shape is {}.".
format(key, list(state.shape), list(param.shape)))
return param, state
param_state = _load_state_from_path(path + ".pdparams")
assert param_state, "Failed to load parameters, please check path."
matched_param_state = []
for key, param in self.state_dict().items():
try:
match_res = _check_match(key, param)
except ValueError as err:
if skip_mismatch:
warnings.warn(
("Skip loading for {}. ".format(key) + err.message))
# reset optimizer when mismatch happens
reset_optimizer = True
else:
raise err
matched_param_state.append(match_res)
optim_state = None if reset_optimizer else _load_state_from_path(
path + ".pdopt")
return self._adapter.load(matched_param_state, optim_state)
def parameters(self, *args, **kwargs):
return self._adapter.parameters(*args, **kwargs)
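A runnable sketch of the matching rule that `load` applies, using plain dicts and numpy arrays in place of the model's `state_dict` and the unpickled `.pdparams` content (all names and shapes are illustrative):

```python
import warnings

import numpy as np

# fc_0.w_0 matches; fc_1.w_0 has a different shape in the "file" state.
model_params = {"fc_0.w_0": np.zeros((4, 2)), "fc_1.w_0": np.zeros((2, 8))}
file_state = {"fc_0.w_0": np.ones((4, 2)), "fc_1.w_0": np.ones((2, 4))}

def check_match(key, param, state_dict):
    state = state_dict.get(key, None)
    if state is None:
        raise ValueError("{} is not found in the provided file.".format(key))
    if list(state.shape) != list(param.shape):
        raise ValueError("{} receives a shape {}, but the expected shape is {}."
                         .format(key, list(state.shape), list(param.shape)))
    return param, state

skip_mismatch = True
matched = []
for key, param in model_params.items():
    try:
        matched.append(check_match(key, param, file_state))
    except ValueError as err:
        if skip_mismatch:
            warnings.warn("Skip loading for {}. {}".format(key, err))
        else:
            raise
# matched now holds only the fc_0.w_0 pair; fc_1.w_0 was skipped with a warning.
```

With an actual `Model` instance, the same behavior is reached through `model.load(path, skip_mismatch=True)`; as in the code above, skipping a mismatch also resets the optimizer states.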
......