提交 e103f756 作者: Yang Zhang

Lazy load optimizer state in static graph mode

上级 887474e9
......@@ -122,6 +122,8 @@ class StaticGraphAdapter(object):
self._progs = {}
self._compiled_progs = {}
self._lazy_load_optimizer = None
# parse shape hints
self._input_desc = OrderedDict([
(n, None) for n in extract_args(self.model.forward) if n != 'self'
......@@ -188,20 +190,6 @@ class StaticGraphAdapter(object):
with open(path, 'rb') as f:
return pickle.load(f)
def set_var(var, ndarray):
    """Write `ndarray` into the global-scope tensor named `var.name`.

    The value is written to the place the tensor already reports, so the
    data stays on the device where the tensor currently lives.
    """
    tensor = global_scope().find_var(var.name).get_tensor()
    current = tensor._place()

    # Host-side tensors map directly onto a concrete place object.
    if current.is_cpu_place():
        tensor.set(ndarray, fluid.CPUPlace())
        return
    if current.is_cuda_pinned_place():
        tensor.set(ndarray, fluid.CUDAPinnedPlace())
        return

    # Otherwise go through a generic Place to read the GPU device id and
    # build the matching CUDAPlace.
    generic = fluid.core.Place()
    generic.set_place(tensor._place())
    tensor.set(ndarray, fluid.CUDAPlace(generic.gpu_device_id()))
param_path = path + ".pdparams"
param_state = _load(param_path)
assert param_state, "failed to load parameters, please check path"
......@@ -218,16 +206,11 @@ class StaticGraphAdapter(object):
assert key in param_state, \
"parameter [{}] is not found in model file [{}]".format(
key, param_path)
set_var(var, param_state[key])
self._set_var(var, param_state[key])
# FIXME what if a different optimizer is used?
if not self.model._optimizer:
return
prog = self._progs.get('train', None)
optim = list(filter(is_belong_to_optimizer, prog.list_vars()))
if not optim:
return
optim_path = path + ".pdopt"
optim_state = _load(optim_path)
if optim_state is None:
......@@ -235,14 +218,38 @@ class StaticGraphAdapter(object):
assert '__static_graph_only__' in optim_state, \
"optimizer saved in dygraph mode is not usable in static graph"
if self._executor is not None:
self._load_optimizer(optim_state)
else:
self._lazy_load_optimizer = optim_state
def _load_optimizer(self, state):
    """Restore the optimizer variables of the 'train' program from `state`.

    Args:
        state: mapping indexed by optimizer variable name (`var.name`);
            each entry is handed to `self._set_var` as the value to load.

    Raises:
        AssertionError: if an optimizer variable of the train program has
            no entry in `state`.
    """
    # NOTE(review): `prog` is None when no 'train' program has been
    # registered in `self._progs`; the call below would then fail.
    # Presumably callers only invoke this once that program exists -- confirm.
    prog = self._progs.get('train', None)
    optim = list(filter(is_belong_to_optimizer, prog.list_vars()))
    if not optim:
        # The train program has no optimizer variables to restore.
        return

    # presumably materializes the optimizer variables in the global scope so
    # that `_set_var` can find them -- confirm against the fluid core API
    fluid.core._create_loaded_parameter(
        optim, global_scope(), self._executor._default_executor)

    for var in optim:
        assert var.name in state, \
            "variable [{}] is not in optimizer state file".format(var.name)
        self._set_var(var, state[var.name])
def _set_var(self, var, ndarray):
    """Write `ndarray` into the global-scope tensor named `var.name`.

    The value is written to the place the tensor already reports, so the
    data stays on the device where the tensor currently lives.
    """
    tensor = global_scope().find_var(var.name).get_tensor()
    current = tensor._place()

    # Host-side tensors map directly onto a concrete place object.
    if current.is_cpu_place():
        tensor.set(ndarray, fluid.CPUPlace())
        return
    if current.is_cuda_pinned_place():
        tensor.set(ndarray, fluid.CUDAPinnedPlace())
        return

    # Otherwise go through a generic Place to read the GPU device id and
    # build the matching CUDAPlace.
    generic = fluid.core.Place()
    generic.set_place(tensor._place())
    tensor.set(ndarray, fluid.CUDAPlace(generic.gpu_device_id()))
def _run(self, inputs, labels=None, device='CPU', device_ids=None):
inputs = to_list(inputs)
......@@ -349,6 +356,10 @@ class StaticGraphAdapter(object):
startup_prog = self._startup_prog._prune(uninitialized)
self._executor.run(startup_prog)
if self.mode == 'train' and self._lazy_load_optimizer:
self._load_optimizer(self._lazy_load_optimizer)
self._lazy_load_optimizer = None
compiled_prog = fluid.CompiledProgram(prog)
if len(device_ids) > 1:
loss_name = None
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册