Commit 4e9c403e authored by Yang Zhang

Fix weight loading

Parent 50d2a350
@@ -114,7 +114,6 @@ class StaticGraphAdapter(object):
         # so we need to keep track of the parameters already created
         self._startup_prog = fluid.default_startup_program()
         self._orig_prog = fluid.default_main_program()
-        self._lazy_load_path = None
         self._label_vars = {}  # label variables
         self._endpoints = {}
@@ -183,12 +182,6 @@ class StaticGraphAdapter(object):
         _save(optim, optim_path)
 
     def load(self, path):
-        if self._executor is None:
-            self._lazy_load_path = path
-        else:
-            self._do_load(path)
-
-    def _do_load(self, path):
         def _load(path):
             if not os.path.exists(path):
                 return
@@ -210,13 +203,22 @@ class StaticGraphAdapter(object):
             t.set(ndarray, place)
 
         param_path = path + ".pdparams"
-        params = _load(param_path)
-        assert params, "failed to load parameters, please check path"
+        param_state = _load(param_path)
+        assert param_state, "failed to load parameters, please check path"
+
+        if self._executor is None:
+            executor = fluid.Executor(fluid.CPUPlace())._default_executor
+        else:
+            executor = self._executor._default_executor
+
+        fluid.core._create_loaded_parameter(
+            list(self.model.state_dict().values()), global_scope(), executor)
+
         for key, var in self.model.state_dict().items():
-            assert key in params, \
+            assert key in param_state, \
                 "parameter [{}] is not found in model file [{}]".format(
                     key, param_path)
-            set_var(var, params[key])
+            set_var(var, param_state[key])
 
         # FIXME what if a different optimizer is used?
         if not self.model._optimizer:
@@ -234,7 +236,7 @@ class StaticGraphAdapter(object):
             "optimizer saved in dygraph mode is not usable in static graph"
 
         fluid.core._create_loaded_parameter(
-            optim, global_scope(), self._executor._default_executor)
+            optim, global_scope(), executor)
 
         for var in optim:
             assert var.name in optim_state, \
@@ -334,10 +336,16 @@ class StaticGraphAdapter(object):
         # therefore startup program only needs to run once
         if self._executor is None:
             self._executor = fluid.Executor(places[0])
-            self._executor.run(self._startup_prog)
-            if self._lazy_load_path is not None:
-                self._do_load(self._lazy_load_path)
-                self._lazy_load_path = None
+            # XXX incremental initialization
+            uninitialized = []
+            for var_py in self._startup_prog.list_vars():
+                var = fluid.global_scope().find_var(var_py.name)
+                if var and var.get_tensor()._is_initialized():
+                    continue
+                uninitialized.append(var_py)
+            if uninitialized:
+                startup_prog = self._startup_prog._prune(uninitialized)
+                self._executor.run(startup_prog)
 
         if len(device_ids) < 2:
             return prog
......
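For readers unfamiliar with the pattern, the sketch below (not part of the commit) illustrates the "incremental initialization" idea used in the last hunk, assuming a PaddlePaddle 1.x fluid environment; the toy network, parameter values, and variable names are made up for illustration. Variables whose tensors already exist and are initialized in the global scope (for example, weights written there by load() via fluid.core._create_loaded_parameter and set_var) are skipped, and the startup program is pruned so only the remaining variables have their initializer ops run.

```python
# Sketch only: incremental initialization with a pruned startup program,
# assuming PaddlePaddle 1.x fluid APIs (fluid.data requires >= 1.6).
import numpy as np
import paddle.fluid as fluid

main_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = fluid.data(name="x", shape=[None, 4], dtype="float32")
    y = fluid.layers.fc(x, size=2)

exe = fluid.Executor(fluid.CPUPlace())

# Pretend one parameter was already loaded from a checkpoint: create it in the
# global scope and fill its tensor, which is roughly what the adapter's load()
# achieves with fluid.core._create_loaded_parameter + set_var.
param = main_prog.global_block().all_parameters()[0]
tensor = fluid.global_scope().var(param.name).get_tensor()
tensor.set(np.ones(param.shape, dtype="float32"), fluid.CPUPlace())

# Incremental initialization: collect startup variables that are still
# uninitialized and prune the startup program down to just those.
uninitialized = []
for var_py in startup_prog.list_vars():
    scope_var = fluid.global_scope().find_var(var_py.name)
    if scope_var and scope_var.get_tensor()._is_initialized():
        continue
    uninitialized.append(var_py)
if uninitialized:
    exe.run(startup_prog._prune(uninitialized))

# The "loaded" parameter keeps its preset values; only the remaining variables
# were initialized by their initializer ops.
print(np.array(fluid.global_scope().find_var(param.name).get_tensor()))
```

Because the pruned startup program no longer contains initializer ops for already-initialized variables, running it after load() no longer overwrites the loaded weights, which is what the removed lazy-load mechanism was working around.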