提交 72081907 编写于 作者: X xuwei06

More informative comment and error message for fetch_var()

上级 6695a204
......@@ -152,8 +152,10 @@ def fetch_var(name, scope=None, return_numpy=True):
"""
Fetch the value of the variable with the given name from the given scope
Args:
name(str): name of the variable
scope(core.Scope|None): scope object.
name(str): name of the variable. Typically, only persistable variables
can be found in the scope used for running the program.
scope(core.Scope|None): scope object. It should be the scope where
you pass to Executor.run() when running your program.
If None, global_scope() will be used.
return_numpy(bool): whether convert the tensor to numpy.ndarray
Returns:
......@@ -165,7 +167,10 @@ def fetch_var(name, scope=None, return_numpy=True):
assert isinstance(scope, core.Scope)
var = global_scope().find_var(name)
assert var is not None, "Cannot find '%s' in scope." % name
assert var is not None, (
"Cannot find " + name + " in scope. Perhaps you need to make the"
" variable persistable by using var.persistable = True in your"
" program.")
tensor = var.get_tensor()
if return_numpy:
tensor = as_numpy(tensor)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册