未验证 提交 d7e074c2 编写于 作者: W whs 提交者: GitHub

Fix function of graph wrapper. (#35)

上级 06eef983
...@@ -96,6 +96,7 @@ class VarWrapper(object): ...@@ -96,6 +96,7 @@ class VarWrapper(object):
def is_parameter(self): def is_parameter(self):
return isinstance(self._var, Parameter) return isinstance(self._var, Parameter)
class OpWrapper(object): class OpWrapper(object):
def __init__(self, op, graph): def __init__(self, op, graph):
assert isinstance(graph, GraphWrapper) assert isinstance(graph, GraphWrapper)
...@@ -270,7 +271,10 @@ class GraphWrapper(object): ...@@ -270,7 +271,10 @@ class GraphWrapper(object):
""" """
Get the variable by variable name. Get the variable by variable name.
""" """
return VarWrapper(self.program.global_block().var(name), self) for block in self.program.blocks:
if block.has_var(name):
return VarWrapper(block.var(name), self)
return None
def clone(self, for_test=False): def clone(self, for_test=False):
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册