提交 c8e49b65 编写于 作者: W wanghaoshuang

Fix load persistables in graph wrapper.

test=develop
上级 975aeee7
...@@ -489,8 +489,11 @@ class GraphWrapper(object): ...@@ -489,8 +489,11 @@ class GraphWrapper(object):
def if_exist(var): def if_exist(var):
return os.path.exists(os.path.join(path, var.name)) return os.path.exists(os.path.join(path, var.name))
io.load_vars( persistables = []
exe.exe, path, vars=self.persistables.values(), predicate=if_exist) for var in self.persistables.values():
if if_exist(var):
persistables.append(var)
io.load_vars(exe.exe, path, vars=persistables, predicate=if_exist)
def update_param_shape(self, scope): def update_param_shape(self, scope):
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册