Commit f5cbe6d0 authored by: guosheng

Update model.load to use skip_mismatch.

Parent: a7d677e5
@@ -17,8 +17,9 @@ from __future__ import absolute_import
 import inspect
 import os
 import pickle
-import six
 import numpy as np
+import six
+import warnings
 from collections import Iterable
 from collections import OrderedDict
@@ -192,46 +193,23 @@ class StaticGraphAdapter(object):
         _save(optim, optim_path)
 
-    def load(self, path, reset_optimizer=False, parameters=[]):
-        def _load(path):
-            if not os.path.exists(path):
-                return
-            with open(path, 'rb') as f:
-                return pickle.load(f) if six.PY2 else pickle.load(
-                    f, encoding='latin1')
-
-        param_path = path + ".pdparams"
-        param_state = _load(param_path)
-        assert param_state, "failed to load parameters, please check path"
-
+    def load(self, param_state_pairs, optim_state):
         if self._executor is None:
             executor = fluid.Executor(fluid.CPUPlace())._default_executor
         else:
             executor = self._executor._default_executor
 
-        param_names = [param.name for param in parameters]
-        fluid.core._create_loaded_parameter(
-            list(parameters), global_scope(), executor)
-
-        for key, var in self.model.state_dict().items():
-            if not param_names or var.name in param_names:
-                assert key in param_state, \
-                    "parameter [{}] is not found in model file [{}]".format(
-                        key, param_path)
-                self._set_var(var, param_state[key])
-        if reset_optimizer or parameters:
-            return
+        # restore parameter states
+        fluid.core._create_loaded_parameter(
+            [param for param, state in param_state_pairs],
+            global_scope(), executor)
+        for param, state in param_state_pairs:
+            self._set_var(param, state)
 
+        # restore optimizer states
         # FIXME what if a different optimizer is used?
-        if not self.model._optimizer:
-            return
-
-        optim_path = path + ".pdopt"
-        optim_state = _load(optim_path)
-        if optim_state is None:
+        if not self.model._optimizer or not optim_state:
             return
 
         self._load_optimizer(optim_state, executor)
 
     def _load_optimizer(self, state, executor):
@@ -539,22 +517,13 @@ class DynamicGraphAdapter(object):
         optim = self.model._optimizer.state_dict()
         fluid.save_dygraph(optim, path)
 
-    def load(self, path, reset_optimizer=False, parameters=[]):
-        param_state, optim_state = fluid.load_dygraph(path)
-
-        param_names = [param.name for param in parameters]
-        for key, var in self.model.state_dict().items():
-            if not param_names or var.name in param_names:
-                assert key in param_state, \
-                    "parameter [{}] is not found in model file [{}]".format(
-                        key, path + ".pdparams")
-                var.set_value(param_state[key])
-        if reset_optimizer or parameters:
-            return
+    def load(self, param_state_pairs, optim_state):
+        # restore parameter states
+        for param, state in param_state_pairs:
+            param.set_value(state)
 
-        if self.model._optimizer is None or optim_state is None:
+        # restore optimizer states
+        if not self.model._optimizer or not optim_state:
             return
 
         # If optimizer performs set_dict when state vars haven't been created,
@@ -617,7 +586,6 @@ class Model(fluid.dygraph.Layer):
         self._optimizer = None
         self._device = None
         self._device_ids = None
-        self._optimizer = None
         if in_dygraph_mode():
             self._adapter = DynamicGraphAdapter(self)
         else:
@@ -635,79 +603,71 @@
     def save(self, *args, **kwargs):
         return self._adapter.save(*args, **kwargs)
 
-    def load(self, path, reset_optimizer=False, layers=None, weights=None):
+    def load(self, path, skip_mismatch=False, reset_optimizer=False):
         """
         Load from files storing the model states and optimizer states. The file
         for optimizer states is not necessary if no need to restore the optimizer.
 
-        `layers` and `weights` are useful for fine-tuning or transfer-learning
-        models where some of the layers have changed. If provided, only
-        parameters included in layers and weights would be loaded, and optimizer
-        would be reset. If both are None, make no effect and load all parameters.
-
-        NOTE: parameters are restored based on names, which are decided by the
-        network's topology if not given by `param_attr` explicitly. This means
-        the architecture should be the same as when the weights were saved.
-        Layers that don't have parameters are not taken into account in the
-        topological ordering, thus could be added or removed casually.
+        NOTE: parameters are retrieved from the file storing model states
+        according to their structured names.
+
+        For fine-tuning or transfer-learning models where some of the layers
+        have changed, keep the parameters that need to be restored under the
+        same structured names in the pre-trained model and the fine-tuning
+        model.
 
         Args:
             path (str): The prefix of files storing the model states and
                 optimizer states. The files would be `path.pdparams` and
                 `path.pdopt` separately, and the latter is not necessary
                 when no need to restore.
+            skip_mismatch (bool): Whether to skip loading a mismatched
+                parameter or raise an error when a mismatch happens (the
+                parameter is not found in the file storing model states, or
+                it is stored with a mismatched shape). Default False.
             reset_optimizer (bool): If True, ignore the providing file storing
                 optimizer states and initialize optimizer states from scratch.
                 Otherwise, restore optimizer states from `path.pdopt` if
                 a optimizer has been set to the model. Default False.
-            layers (list|Layer|str|None): The layers to be restored. All
-                parameters in these layers would be loaded. `layers` is
-                composed of instances of Layer or string. A string corresponded
-                layer is the one whose `full_name()` equals to the string.
-                If None, make no effect to load. Default None.
-            weights (list|Parameter|str|None): The parameters to be loaded.
-                `weights` is composed of instances of Parameter or string.
-                A string corresponded parameter is the one whose name equals to
-                the string. If None, make no effect to load. Default None.
         """
-        load_param_vars = set()
-        if layers is not None:
-            model_layers = self.sublayers()
-            model_layers_dict = dict((layer.full_name(), layer)
-                                     for layer in model_layers)
-            for i, layer in enumerate(to_list(layers)):
-                if isinstance(layer, fluid.dygraph.Layer):
-                    assert layer in model_layers, (
-                        "The #%d layer in layers is not in model." % i)
-                    load_param_vars.update(layer.state_dict().values())
-                elif isinstance(layer, six.string_types):
-                    assert layer in model_layers_dict, (
-                        "The #%d layer in layers is not in model." % i)
-                    load_param_vars.update(model_layers_dict[layer].state_dict(
-                    ).values())
-                else:
-                    raise TypeError(
-                        "The value in layers should be string or Layer.")
-        if weights is not None:
-            model_weights = self.parameters()
-            model_weights_dict = dict((weight.name, weight)
-                                      for weight in model_weights)
-            param_type = fluid.framework.ParamBase if in_dygraph_mode(
-            ) else fluid.framework.Parameter
-            for i, weight in enumerate(to_list(weights)):
-                if isinstance(weight, param_type):
-                    # var== has been overwrited, thus do not use `weight in`
-                    assert weight.name in model_weights_dict, (
-                        "The #%d weight in weights is not in model." % i)
-                    load_param_vars.add(weight)
-                elif isinstance(weight, six.string_types):
-                    assert weight in model_weights_dict, (
-                        "The #%d weight in weights is not in model." % i)
-                    load_param_vars.add(model_weights_dict[weight])
-                else:
-                    raise TypeError(
-                        "The value in weights should be string or %s." %
-                        param_type.__name__)
-        return self._adapter.load(path, reset_optimizer, list(load_param_vars))
+
+        def _load_state_from_path(path):
+            if not os.path.exists(path):
+                return
+            with open(path, 'rb') as f:
+                return pickle.load(f) if six.PY2 else pickle.load(
+                    f, encoding='latin1')
+
+        def _check_match(key, param):
+            state = param_state.get(key, None)
+            if state is None:
+                raise ValueError(
+                    "{} is not found in the providing file.".format(key))
+            if list(state.shape) != list(param.shape):
+                raise ValueError(
+                    "{} receives a shape {}, but the expected shape is {}.".
+                    format(key, list(state.shape), list(param.shape)))
+            return param, state
+
+        param_state = _load_state_from_path(path + ".pdparams")
+        assert param_state, "Failed to load parameters, please check path."
+
+        matched_param_state = []
+        for key, param in self.state_dict().items():
+            try:
+                match_res = _check_match(key, param)
+            except ValueError as err:
+                if skip_mismatch:
+                    warnings.warn(
+                        ("Skip loading for {}. ".format(key) + err.message))
+                    # reset optimizer when mismatch happens
+                    reset_optimizer = True
+                else:
+                    raise err
+            matched_param_state.append(match_res)
+
+        optim_state = None if reset_optimizer else _load_state_from_path(
+            path + ".pdopt")
+        return self._adapter.load(matched_param_state, optim_state)
 
     def parameters(self, *args, **kwargs):
         return self._adapter.parameters(*args, **kwargs)
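For context, a brief hedged usage sketch of the load() interface this commit introduces. The SimpleNet definition, the "mnist_checkpoint" prefix, and the `from model import Model` import path are illustrative assumptions about the surrounding hapi code, not something stated in the diff.

# Usage sketch (assumes the hapi Model class lives in model.py, and that
# "mnist_checkpoint.pdparams" was produced by an earlier model.save() call).
import paddle.fluid as fluid
from model import Model  # the class modified above; import path assumed


class SimpleNet(Model):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self._fc = fluid.dygraph.Linear(784, 10)

    def forward(self, x):
        return self._fc(x)


with fluid.dygraph.guard():
    model = SimpleNet()

    # Strict loading (default): raises ValueError if any parameter in
    # model.state_dict() is missing from "mnist_checkpoint.pdparams" or is
    # stored with a different shape.
    model.load("mnist_checkpoint")

    # Tolerant loading for fine-tuning: mismatched parameters are skipped
    # with a warning, and the optimizer state is reset instead of restored.
    model.load("mnist_checkpoint", skip_mismatch=True)

    # Ignore "mnist_checkpoint.pdopt" and start optimizer states from scratch.
    model.load("mnist_checkpoint", reset_optimizer=True)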