Commit 79a08599 authored by Yang Zhang

Unify save/load somewhat

Parameters are bound to the model, not the program, which is critical for transfer learning to work.
Parameters should be reusable between static and dynamic graph modes.
Optimizer state is not reusable across the two modes.
Parent f55efcfc
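For context, here is a minimal sketch (not part of the commit) of what the unified parameter file looks like on disk. The new static-graph save() in the diff below pickles a plain dict mapping parameter names to numpy arrays, so a ".pdparams" file can be inspected, or reused from the other graph mode as the commit message intends, without rebuilding the original program. The checkpoint path used here is hypothetical.

import pickle

# Sketch only: a ".pdparams" file written by the new save() is a pickled dict
# mapping parameter names to numpy arrays. "checkpoints/mnist" is a made-up path.
with open("checkpoints/mnist.pdparams", "rb") as f:
    state = pickle.load(f)
for name, value in state.items():
    print(name, value.shape, value.dtype)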
@@ -15,12 +15,16 @@
from __future__ import absolute_import
import inspect
import os
import pickle
from collections import OrderedDict
import numpy as np
from paddle import fluid
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import in_dygraph_mode, Variable
from paddle.fluid.executor import global_scope
from paddle.fluid.io import is_belong_to_optimizer
from paddle.fluid.dygraph.base import to_variable
__all__ = ['Model', 'shape_hints']
@@ -110,15 +114,88 @@ class StaticGraphAdapter(object):
        return None

    def save(self, path):
        def _save(state, path):
            def to_numpy(var):
                if not isinstance(var, Variable):
                    return var
                t = global_scope().find_var(var.name).get_tensor()
                return np.array(t)

            if not state:
                return
            state = {k: to_numpy(v) for k, v in state.items()}
            with open(path, 'wb') as f:
                pickle.dump(state, f)

        base = os.path.basename(path)
        assert base != "", "path should be of 'dirname/filename' format"
        param_path = path + ".pdparams"
        _save(self.model.state_dict(), param_path)
        prog = self._progs.get('train', None)
        if prog is None or self.model._optimizer is None:
            print("optimizer not initialized, save parameters only")
        prog = self._main_prog
        fluid.save(prog, path)
            return
        # XXX `optimizer.state_dict()` only work in dygraph mode
        optim_path = path + ".pdopt"
        optim = {p.name: p for p in filter(
            is_belong_to_optimizer, prog.list_vars())}
        # HACK this is contrived, optimizer state is not the same for
        # static/dynamic graph mode
        optim['__static_graph_only__'] = True
        _save(optim, optim_path)
    def load(self, path):
        prog = self._main_prog
        fluid.load(prog, path, self._executor)

        def _load(path):
            if not os.path.exists(path):
                return
            with open(path, 'rb') as f:
                return pickle.load(f)

        def set_var(var, ndarray):
            t = global_scope().find_var(var.name).get_tensor()
            p = t._place()
            if p.is_cpu_place():
                place = fluid.CPUPlace()
            elif p.is_cuda_pinned_place():
                place = fluid.CUDAPinnedPlace()
            else:
                p = fluid.core.Place()
                p.set_place(t._place())
                place = fluid.CUDAPlace(p.gpu_device_id())
            t.set(ndarray, place)

        param_path = path + ".pdparams"
        params = _load(param_path)
        assert params, "failed to load parameters, please check path"
        for key, var in self.model.state_dict().items():
            assert key in params, \
                "parameter [{}] is not found in model file [{}]".format(
                    key, param_path)
            set_var(var, params[key])

        # FIXME what if a different optimizer is used?
        if not self.model._optimizer:
            return
        prog = self._progs.get('train', None)
        optim = list(filter(is_belong_to_optimizer, prog.list_vars()))
        if not optim:
            return
        optim_path = path + ".pdopt"
        optim_state = _load(optim_path)
        if optim_state is None:
            return
        assert '__static_graph_only__' in optim_state, \
            "optimizer saved in dygraph mode is not usable in static graph"
        fluid.core._create_loaded_parameter(
            optim, global_scope(), self._executor._default_executor)
        for var in optim:
            assert var.name in optim_state, \
                "variable [{}] is not found in model file [{}]".format(
                    var.name, optim_path)
            set_var(var, optim_state[var.name])
    def _run(self, inputs, labels=None, device='CPU', device_ids=None):
        inputs = to_list(inputs)
@@ -293,9 +370,7 @@ class DynamicGraphAdapter(object):
    def save(self, path):
        params = self.model.state_dict()
        fluid.save_dygraph(params, path)
        if self.model._optimizer is None:
            print("model does not have an optimizer, save parameters only")
            return
        if self.model._optimizer.state_dict():
            optim = self.model._optimizer.state_dict()
@@ -304,8 +379,7 @@ class DynamicGraphAdapter(object):
    def load(self, path):
        params, optim = fluid.load_dygraph(path)
        self.model.set_dict(params)
        if optim is None:
            print("optimizer state file not found, load parameters only")
        if self.model._optimizer is None or optim is None:
            return
        self.model._optimizer.set_dict(optim)
......
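As an illustrative follow-up (a sketch, not from the commit): the '__static_graph_only__' marker written into the static-graph ".pdopt" file is exactly what the loader above asserts on, so a quick look at that key tells whether an optimizer-state file can be consumed in static graph mode. The file path below is hypothetical.

import pickle

# Sketch: peek at an optimizer-state file saved by the static-graph adapter.
# "checkpoints/mnist.pdopt" is a made-up path.
with open("checkpoints/mnist.pdopt", "rb") as f:
    optim_state = pickle.load(f)

if optim_state.get("__static_graph_only__"):
    print("optimizer state was saved from static graph mode")
else:
    print("no static-graph marker; the static-graph loader would reject this file")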