未验证 提交 00268194 编写于 作者: H Huihuang Zheng 提交者: GitHub

[Dy2stat] Fix to_tensor Bug Reported from QA (#32701)

Dy2stat failed when user writes return paddle.to_tensor(xxx), the reason is that visit_Expr doesn't work when the Expr is in return. Some other statements may trigger same bug. To fix it, we re-wrote a transformer to transform paddle.to_tensor to paddle.assign for all Call nodes.
上级 0a0f3244
...@@ -33,10 +33,11 @@ class BasicApiTransformer(gast.NodeTransformer): ...@@ -33,10 +33,11 @@ class BasicApiTransformer(gast.NodeTransformer):
self.root = wrapper_root.node self.root = wrapper_root.node
self.class_node_dict = {} self.class_node_dict = {}
self.name_to_tensor_shape = {}
def transform(self): def transform(self):
to_tensor_transformer = ToTensorTransformer(self.root)
to_tensor_transformer.transform()
self.visit(self.root) self.visit(self.root)
return self.wrapper_root return self.wrapper_root
def visit_Assign(self, node): def visit_Assign(self, node):
...@@ -62,11 +63,6 @@ class BasicApiTransformer(gast.NodeTransformer): ...@@ -62,11 +63,6 @@ class BasicApiTransformer(gast.NodeTransformer):
def _visit_Call(self, node): def _visit_Call(self, node):
assert isinstance(node, gast.Call) assert isinstance(node, gast.Call)
# Replace API `to_variable` with `fluid.layers.assign`
if is_to_variable(node):
node = to_assign_node(node)
return node
func_name = astor.to_source(gast.gast_to_ast(node.func)) func_name = astor.to_source(gast.gast_to_ast(node.func))
if self._is_dygraph_forward(func_name): if self._is_dygraph_forward(func_name):
...@@ -102,6 +98,29 @@ class BasicApiTransformer(gast.NodeTransformer): ...@@ -102,6 +98,29 @@ class BasicApiTransformer(gast.NodeTransformer):
return False return False
class ToTensorTransformer(gast.NodeTransformer):
"""
Class to transform paddle.to_tensor and paddle.to_variable to paddle.assign
"""
def __init__(self, node):
assert isinstance(
node, gast.AST
), "Input non-gast.AST node for the initialization of ToTensorTransformer."
self.root = node
def transform(self):
self.visit(self.root)
return self.root
def visit_Call(self, node):
assert isinstance(node, gast.Call)
if is_to_variable(node):
node = to_assign_node(node)
self.generic_visit(node)
return node
def is_to_variable(node): def is_to_variable(node):
assert isinstance(node, gast.Call) assert isinstance(node, gast.Call)
api_name = utils.ast_to_source_code(node.func).strip() api_name = utils.ast_to_source_code(node.func).strip()
......
...@@ -64,13 +64,11 @@ def dyfunc_int_to_tensor(x): ...@@ -64,13 +64,11 @@ def dyfunc_int_to_tensor(x):
def dyfunc_float_to_tensor(x): def dyfunc_float_to_tensor(x):
res = paddle.to_tensor(2.0) return paddle.to_tensor(2.0)
return res
def dyfunc_bool_to_tensor(x): def dyfunc_bool_to_tensor(x):
res = paddle.to_tensor(True) return paddle.to_tensor(True)
return res
class TestDygraphBasicApi_ToVariable(unittest.TestCase): class TestDygraphBasicApi_ToVariable(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册