Commit b6de4543 authored by: G guosheng

Support for fine-tuning.

Parent 358f7852
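The change extends `Model.load` so that a checkpoint can be reused for fine-tuning: optimizer states can be skipped and only a subset of the parameters restored. A minimal usage sketch, assuming a hypothetical `Model` subclass `MyModel` whose pretrained backbone is exposed as `model.backbone` (these names are illustrative and not part of this commit):

model = MyModel()                       # hypothetical Model subclass
model.prepare(optimizer=optimizer)      # set up training as usual

# Resume training: restore all parameters plus optimizer states from
# checkpoint/final.pdparams and checkpoint/final.pdopt.
model.load('checkpoint/final')

# Restore all parameters but start the optimizer from scratch.
model.load('checkpoint/final', reset_optimizer=True)

# Fine-tuning: restore only the backbone's parameters; the new head and
# the optimizer are initialized from scratch.
model.load('checkpoint/final', layers=[model.backbone])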
......@@ -17,6 +17,7 @@ from __future__ import absolute_import
import inspect
import os
import pickle
import six
from collections import OrderedDict
import numpy as np
......@@ -133,7 +134,7 @@ class StaticGraphAdapter(object):
return self._run(inputs, None)
def parameters(self, *args, **kwargs):
return None
return super(Model, self.model).parameters(*args, **kwargs)
def save(self, path):
def _save(state, path):
......@@ -167,12 +168,13 @@ class StaticGraphAdapter(object):
_save(optim, optim_path)
def load(self, path):
def load(self, path, reset_optimizer=False, parameters=[]):
def _load(path):
if not os.path.exists(path):
return
with open(path, 'rb') as f:
return pickle.load(f)
return pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
param_path = path + ".pdparams"
param_state = _load(param_path)
......@@ -183,14 +185,20 @@ class StaticGraphAdapter(object):
else:
executor = self._executor._default_executor
param_names = [param.name for param in parameters]
fluid.core._create_loaded_parameter(
list(self.model.state_dict().values()), global_scope(), executor)
list(parameters), global_scope(), executor)
for key, var in self.model.state_dict().items():
assert key in param_state, \
"parameter [{}] is not found in model file [{}]".format(
key, param_path)
self._set_var(var, param_state[key])
if not param_names or var.name in param_names:
assert key in param_state, \
"parameter [{}] is not found in model file [{}]".format(
key, param_path)
self._set_var(var, param_state[key])
if reset_optimizer or parameters:
return
# FIXME what if a different optimizer is used?
if not self.model._optimizer:
......@@ -429,7 +437,7 @@ class DynamicGraphAdapter(object):
inputs = to_list(inputs)
if labels is not None:
labels = to_list(labels)
outputs = self.model.forward(*[to_variable(x) for x in inputs])
outputs = self.model.forward(* [to_variable(x) for x in inputs])
losses = self.model._loss_function(outputs, labels)
final_loss = fluid.layers.sum(losses)
final_loss.backward()
......@@ -444,7 +452,7 @@ class DynamicGraphAdapter(object):
inputs = to_list(inputs)
if labels is not None:
labels = to_list(labels)
outputs = self.model.forward(*[to_variable(x) for x in inputs])
outputs = self.model.forward(* [to_variable(x) for x in inputs])
if self.model._loss_function:
losses = self.model._loss_function(outputs, labels)
......@@ -475,10 +483,22 @@ class DynamicGraphAdapter(object):
optim = self.model._optimizer.state_dict()
fluid.save_dygraph(optim, path)
def load(self, path):
params, optim = fluid.load_dygraph(path)
self.model.set_dict(params)
if self.model._optimizer is None or optim is None:
def load(self, path, reset_optimizer=False, parameters=[]):
param_state, optim_state = fluid.load_dygraph(path)
param_names = [param.name for param in parameters]
for key, var in self.model.state_dict().items():
if not param_names or var.name in param_names:
assert key in param_state, \
"parameter [{}] is not found in model file [{}]".format(
key, path + ".pdparams")
var.set_value(param_state[key])
if reset_optimizer or parameters:
return
if self.model._optimizer is None or optim_state is None:
return
# If optimizer performs set_dict when state vars haven't been created,
......@@ -487,13 +507,13 @@ class DynamicGraphAdapter(object):
# To work around this when loading from static-graph saved states, extend
# state dict to include keys named according to dygraph naming rules.
# TODO: if len(self.model._optimizer._accumulators) > 0
converted_state = dict(optim)
converted_state = dict(optim_state)
opt_unq_name = self.model._optimizer._name
opt_cls_name = self.model._optimizer.__class__.__name__
opt_name = opt_unq_name[:opt_unq_name.rfind("_")] # remove suffix idx
param_names = [param.name for param in self.model.parameters()]
for var_name, state_var in sorted(
optim.items(), key=lambda x: len(x[0]), reverse=True):
optim_state.items(), key=lambda x: len(x[0]), reverse=True):
if var_name in ["@LR_DECAY_COUNTER@", "global_step"]:
# NOTE: dygraph saved global_step is 1 larger than that in
# static-graph, since the time of global_step to increase is
......@@ -560,8 +580,79 @@ class Model(fluid.dygraph.Layer):
def save(self, *args, **kwargs):
return self._adapter.save(*args, **kwargs)
def load(self, *args, **kwargs):
return self._adapter.load(*args, **kwargs)
def load(self, path, reset_optimizer=False, layers=None, weights=None):
"""
Load from files storing the model states and optimizer states. The file
for optimizer states is not necessary if there is no need to restore the
optimizer.
`layers` and `weights` are useful for fine-tuning or transfer-learning
models where some of the layers have changed. If provided, only the
parameters included in `layers` and `weights` would be loaded, and the
optimizer would be reset. If both are None, all parameters would be loaded.
NOTE: parameters are restored based on names, which are decided by the
network's topology if not given explicitly by `param_attr`. This means
the architecture should be the same as when the weights were saved.
Layers that don't have parameters are not taken into account in the
topological ordering, thus can be added or removed freely.
Args:
path (str): The prefix of files storing the model states and
optimizer states. The files would be `path.pdparams` and
`path.pdopt` respectively, and the latter is not necessary
if there is no need to restore the optimizer.
reset_optimizer (bool): If True, ignore the provided file storing
optimizer states and initialize optimizer states from scratch.
Otherwise, restore optimizer states from `path.pdopt` if an
optimizer has been set on the model. Default False.
layers (list|Layer|str|None): The layers to be restored. All
parameters in these layers would be loaded. `layers` is
composed of Layer instances or strings. A string corresponds
to the layer whose `full_name()` equals the string.
If None, it imposes no restriction on loading. Default None.
weights (list|Parameter|str|None): The parameters to be loaded.
`weights` is composed of Parameter instances or strings.
A string corresponds to the parameter whose name equals the
string. If None, it imposes no restriction on loading. Default None.
"""
load_param_vars = set()
if layers is not None:
model_layers = self.sublayers()
model_layers_dict = dict((layer.full_name(), layer)
for layer in model_layers)
for i, layer in enumerate(to_list(layers)):
if isinstance(layer, fluid.dygraph.Layer):
assert layer in model_layers, (
"The #%d layer in layers is not in model." % i)
load_param_vars.update(layer.state_dict().values())
elif isinstance(layer, six.string_types):
assert layer in model_layers_dict, (
"The #%d layer in layers is not in model." % i)
load_param_vars.update(model_layers_dict[layer].state_dict(
).values())
else:
raise TypeError(
"The value in layers should be string or Layer.")
if weights is not None:
model_weights = self.parameters()
model_weights_dict = dict((weight.name, weight)
for weight in model_weights)
param_type = fluid.framework.ParamBase if in_dygraph_mode(
) else fluid.framework.Parameter
for i, weight in enumerate(to_list(weights)):
if isinstance(weight, param_type):
# `==` has been overridden for variables, thus do not use `weight in`
assert weight.name in model_weights_dict, (
"The #%d weight in weights is not in model." % i)
load_param_vars.add(weight)
elif isinstance(weight, six.string_types):
assert weight in model_weights_dict, (
"The #%d weight in weights is not in model." % i)
load_param_vars.add(model_weights_dict[weight])
else:
raise TypeError(
"The value in weights should be string or %s." %
param_type.__name__)
return self._adapter.load(path, reset_optimizer, list(load_param_vars))
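Layers and weights can also be selected by name, which is handy when the pretrained sub-layers are not kept as attributes. A sketch; the layer attribute, layer names and parameter names below are hypothetical and depend on how the network was constructed:

# Restore only the layers whose `full_name()` matches the given strings;
# passing `layers` (or `weights`) also skips restoring the optimizer.
model.load('checkpoint/final', layers=['conv2d_0', 'conv2d_1'])

# Restore individual parameters by their `name` attribute.
model.load('checkpoint/final', weights=['linear_0.w_0', 'linear_0.b_0'])

# Layer/Parameter instances and name strings may be mixed in the same list.
model.load('checkpoint/final', layers=[model.backbone, 'linear_0'])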
def prepare(self,
optimizer=None,
......