提交 63ba03c9 编写于 作者: W WuHaobo

polish save_load

上级 83decd42
...@@ -46,7 +46,7 @@ def _mkdir_if_not_exist(path): ...@@ -46,7 +46,7 @@ def _mkdir_if_not_exist(path):
def _load_state(path): def _load_state(path):
print("path: ", path) logger.info("path: {}".format(path))
if os.path.exists(path + '.pdopt'): if os.path.exists(path + '.pdopt'):
# XXX another hack to ignore the optimizer state # XXX another hack to ignore the optimizer state
tmp = tempfile.mkdtemp() tmp = tempfile.mkdtemp()
...@@ -55,12 +55,12 @@ def _load_state(path): ...@@ -55,12 +55,12 @@ def _load_state(path):
state = fluid.io.load_program_state(dst) state = fluid.io.load_program_state(dst)
shutil.rmtree(tmp) shutil.rmtree(tmp)
else: else:
print("path: ", path) logger.info("path: {}".format(path))
state = fluid.io.load_program_state(path) state = fluid.io.load_program_state(path)
return state return state
def load_params(exe, prog, path, ignore_params=[]): def load_params(exe, prog, path, ignore_params=None):
""" """
Load model from the given path. Load model from the given path.
Args: Args:
...@@ -103,6 +103,7 @@ def load_params(exe, prog, path, ignore_params=[]): ...@@ -103,6 +103,7 @@ def load_params(exe, prog, path, ignore_params=[]):
if k in state: if k in state:
logger.warning('variable {} not used'.format(k)) logger.warning('variable {} not used'.format(k))
del state[k] del state[k]
fluid.io.set_program_state(prog, state) fluid.io.set_program_state(prog, state)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册