diff --git a/model.py b/model.py index f4439fc93a1de29bd561095f8da0b0db49e2a33c..16f308251cd712f376f88c55440b4446f25f4e5c 100644 --- a/model.py +++ b/model.py @@ -114,7 +114,6 @@ class StaticGraphAdapter(object): # so we need to keep track of the parameters already created self._startup_prog = fluid.default_startup_program() self._orig_prog = fluid.default_main_program() - self._lazy_load_path = None self._label_vars = {} # label variables self._endpoints = {} @@ -183,12 +182,6 @@ class StaticGraphAdapter(object): _save(optim, optim_path) def load(self, path): - if self._executor is None: - self._lazy_load_path = path - else: - self._do_load(path) - - def _do_load(self, path): def _load(path): if not os.path.exists(path): return @@ -210,13 +203,22 @@ class StaticGraphAdapter(object): t.set(ndarray, place) param_path = path + ".pdparams" - params = _load(param_path) - assert params, "failed to load parameters, please check path" + param_state = _load(param_path) + assert param_state, "failed to load parameters, please check path" + + if self._executor is None: + executor = fluid.Executor(fluid.CPUPlace())._default_executor + else: + executor = self._executor._default_executor + + fluid.core._create_loaded_parameter( + list(self.model.state_dict().values()), global_scope(), executor) + for key, var in self.model.state_dict().items(): - assert key in params, \ + assert key in param_state, \ "parameter [{}] is not found in model file [{}]".format( key, param_path) - set_var(var, params[key]) + set_var(var, param_state[key]) # FIXME what if a different optimizer is used? if not self.model._optimizer: @@ -234,7 +236,7 @@ class StaticGraphAdapter(object): "optimizer saved in dygraph mode is not usable in static graph" fluid.core._create_loaded_parameter( - optim, global_scope(), self._executor._default_executor) + optim, global_scope(), executor) for var in optim: assert var.name in optim_state, \ @@ -334,10 +336,16 @@ class StaticGraphAdapter(object): # therefore startup program only needs to run once if self._executor is None: self._executor = fluid.Executor(places[0]) - self._executor.run(self._startup_prog) - if self._lazy_load_path is not None: - self._do_load(self._lazy_load_path) - self._lazy_load_path = None + # XXX incremental initialization + uninitialized = [] + for var_py in self._startup_prog.list_vars(): + var = fluid.global_scope().find_var(var_py.name) + if var and var.get_tensor()._is_initialized(): + continue + uninitialized.append(var_py) + if uninitialized: + startup_prog = self._startup_prog._prune(uninitialized) + self._executor.run(startup_prog) if len(device_ids) < 2: return prog