未验证 提交 a8afb98e 编写于 作者: W whs 提交者: GitHub

Merge pull request #16928 from wanghaoshuang/cp_slim_ck

[cherry-pick] Fix load persistables in graph wrapper.
...@@ -489,8 +489,11 @@ class GraphWrapper(object): ...@@ -489,8 +489,11 @@ class GraphWrapper(object):
def if_exist(var): def if_exist(var):
return os.path.exists(os.path.join(path, var.name)) return os.path.exists(os.path.join(path, var.name))
io.load_vars( persistables = []
exe.exe, path, vars=self.persistables.values(), predicate=if_exist) for var in self.persistables.values():
if if_exist(var):
persistables.append(var)
io.load_vars(exe.exe, path, vars=persistables, predicate=if_exist)
def update_param_shape(self, scope): def update_param_shape(self, scope):
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册