未验证 提交 1a2d3b5f 编写于 作者: G Guo Sheng 提交者: GitHub

Merge pull request #4 from guoshengCS/fix-save-load

To make save/load compatible between dygraph and static-graph.
...@@ -169,6 +169,9 @@ class StaticGraphAdapter(object): ...@@ -169,6 +169,9 @@ class StaticGraphAdapter(object):
base = os.path.basename(path) base = os.path.basename(path)
assert base != "", "path should be of 'dirname/filename' format" assert base != "", "path should be of 'dirname/filename' format"
dir_name = os.path.dirname(path)
if dir_name and not os.path.exists(dir_name):
os.makedirs(dir_name)
param_path = path + ".pdparams" param_path = path + ".pdparams"
_save(self.model.state_dict(), param_path) _save(self.model.state_dict(), param_path)
prog = self._progs.get('train', None) prog = self._progs.get('train', None)
...@@ -180,9 +183,7 @@ class StaticGraphAdapter(object): ...@@ -180,9 +183,7 @@ class StaticGraphAdapter(object):
is_belong_to_optimizer, prog.list_vars())} is_belong_to_optimizer, prog.list_vars())}
if not optim: if not optim:
return return
# HACK this is contrived, optimizer state is not the same for
# static/dynamic graph mode
optim['__static_graph_only__'] = True
_save(optim, optim_path) _save(optim, optim_path)
def load(self, path): def load(self, path):
...@@ -217,13 +218,11 @@ class StaticGraphAdapter(object): ...@@ -217,13 +218,11 @@ class StaticGraphAdapter(object):
optim_state = _load(optim_path) optim_state = _load(optim_path)
if optim_state is None: if optim_state is None:
return return
assert '__static_graph_only__' in optim_state, \
"optimizer saved in dygraph mode is not usable in static graph"
if self._executor is not None: if self._executor is not None:
self._load_optimizer(optim_state) self._load_optimizer(optim_state)
else: else:
self._lazy_load_optimizer = optim_state self._lazy_load_optimizer = optim_state
def _load_optimizer(self, state): def _load_optimizer(self, state):
prog = self._progs.get('train', None) prog = self._progs.get('train', None)
...@@ -234,10 +233,65 @@ class StaticGraphAdapter(object): ...@@ -234,10 +233,65 @@ class StaticGraphAdapter(object):
fluid.core._create_loaded_parameter( fluid.core._create_loaded_parameter(
optim, global_scope(), self._executor._default_executor) optim, global_scope(), self._executor._default_executor)
converted_state = dict(state)
for var in optim: for var in optim:
assert var.name in state, \ if var.name in ["@LR_DECAY_COUNTER@", "global_step"]:
# When using learning rate scheduler, dygraph would name the
# global step var as "global_step" to save, while static-graph
# would has a state var named as "@LR_DECAY_COUNTER@".
# NOTE: dygraph saved global_step is 1 larger than that in
# static-graph, since the time of global_step to increase is
# different.
state_val = (
np.array(converted_state.pop("global_step")) - 1
) if "global_step" in converted_state else converted_state.pop(
"@LR_DECAY_COUNTER@", None)
if state_val is not None:
converted_state[var.name] = state_val
elif var.name.startswith("learning_rate_"):
# When using static learning rate, static-graph would make it
# a persistable var named 'unique_name.generate("learning_rate")',
# However, dygraph wouldn't save it.
if var.name not in state: continue
else:
# moment and other accumulators
if var.name not in converted_state:
# try to convert from dygraph name
opt_name = self.model._optimizer._name
opt_cls_name = self.model._optimizer.__class__.__name__
opt_unq_name = None
for name in self.model._optimizer._accumulators.keys():
accum_name = name if opt_name is None else name[
len(opt_name) + 1:]
for param_name, state_var in self.model._optimizer._accumulators[
name].items():
if opt_unq_name is None:
# can not infer out the exact unique(opt_name),
# thus try to extract rather than generate
for state_key in sorted(state.keys(),
key=lambda x: len(x),
reverse=True):
prefix = param_name + "_" + (
opt_cls_name if opt_name is None else
opt_name) + "_"
if state_key.startswith(prefix):
prefix_offset = state_key[len(
prefix):].find("_") + len(prefix)
opt_unq_name = state_key[len(
param_name + "_"):prefix_offset]
# TODO: assert
# assert opt_unq_name is None
# gen(param.name + "_" + gen(opt_name) + "_" + accum_name)
# always end with "_0" since the unique optimizer._name
dy_state_name = (param_name + "_" + opt_unq_name +
"_" + accum_name + "_0")
converted_state[
state_var.name] = converted_state.pop(
dy_state_name)
assert var.name in converted_state, \
"variable [{}] is not in optimizer state file".format(var.name) "variable [{}] is not in optimizer state file".format(var.name)
self._set_var(var, state[var.name]) self._set_var(var, converted_state[var.name])
def _set_var(self, var, ndarray): def _set_var(self, var, ndarray):
t = global_scope().find_var(var.name).get_tensor() t = global_scope().find_var(var.name).get_tensor()
...@@ -289,6 +343,17 @@ class StaticGraphAdapter(object): ...@@ -289,6 +343,17 @@ class StaticGraphAdapter(object):
def _make_program(self, inputs): def _make_program(self, inputs):
prog = self._orig_prog.clone() prog = self._orig_prog.clone()
# change inputs to the same var in cloned program
inputs = fluid.layers.utils.map_structure(
lambda var: prog.global_block().var(var.name), inputs)
# NOTE: When defining learning rate scheduling in static-graph, ops to
# increase the global step var and calculate learning rate would be
# prepended into _orig_prog. test program maked by `_orig_prog.clone`
# also would include these ops. Thus must prune these ops in test
# program, otherwise the global step would be changed in test.
if self.mode != 'train':
for op in list(prog.global_block().ops):
prog.global_block()._remove_op(0)
if self.mode == 'train' and self.model._optimizer._learning_rate_map: if self.mode == 'train' and self.model._optimizer._learning_rate_map:
# HACK workaround learning rate map issue # HACK workaround learning rate map issue
lr_var = self.model._optimizer._learning_rate_map[self._orig_prog] lr_var = self.model._optimizer._learning_rate_map[self._orig_prog]
...@@ -451,7 +516,51 @@ class DynamicGraphAdapter(object): ...@@ -451,7 +516,51 @@ class DynamicGraphAdapter(object):
self.model.set_dict(params) self.model.set_dict(params)
if self.model._optimizer is None or optim is None: if self.model._optimizer is None or optim is None:
return return
self.model._optimizer.set_dict(optim)
# If optimizer performs set_dict when state vars haven't been created,
# which would happen when set_dict before minimize, the state would be
# stored in optimizer._accumulators_holder and loaded lazily.
# To contrive this when loading from static-graph saved states, extend
# state dict to include keys named accoring to dygraph naming rules.
# TODO: if len(self.model._optimizer._accumulators) > 0
converted_state = dict(optim)
opt_unq_name = self.model._optimizer._name
opt_cls_name = self.model._optimizer.__class__.__name__
opt_name = opt_unq_name[:opt_unq_name.rfind("_")] # remove suffix idx
param_names = [param.name for param in self.model.parameters()]
for var_name, state_var in sorted(optim.items(),
key=lambda x: len(x[0]),
reverse=True):
if var_name in ["@LR_DECAY_COUNTER@", "global_step"]:
# NOTE: dygraph saved global_step is 1 larger than that in
# static-graph, since the time of global_step to increase is
# different.
if var_name == "@LR_DECAY_COUNTER@":
converted_state["global_step"] = np.array(
converted_state.pop("@LR_DECAY_COUNTER@")) + 1
else:
# moment and other accumulators
# extend state dict to include promising dygraph names
for param_name in param_names:
if var_name.startswith(param_name + "_" + opt_name):
# when init optimizer with name
accum_name = var_name[len(param_name + "_" + opt_name +
"_"):]
elif var_name.startswith(param_name +
"_") and opt_name == opt_cls_name:
# when init optimizer without name
accum_name = var_name[len(param_name + "_"):]
else:
continue
# remove suffix idx
accum_name = accum_name[:accum_name.rfind("_")]
# state names always end with "_0" in dygraph because of the
# unique optimizer._name
dy_state_name = (param_name + "_" + opt_unq_name + "_" +
accum_name + "_0")
converted_state[dy_state_name] = state_var
self.model._optimizer.set_dict(converted_state)
class Model(fluid.dygraph.Layer): class Model(fluid.dygraph.Layer):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册