From 002681942fec43b24e49bde71dd82954666f4e02 Mon Sep 17 00:00:00 2001 From: Huihuang Zheng Date: Fri, 30 Apr 2021 18:04:31 +0800 Subject: [PATCH] [Dy2stat] Fix to_tensor Bug Reported from QA (#32701) Dy2stat failed when user writes return paddle.to_tensor(xxx), the reason is that visit_Expr doesn't work when the Expr is in return. Some other statements may trigger same bug. To fix it, we re-wrote a transformer to transform paddle.to_tensor to paddle.assign for all Call nodes. --- .../basic_api_transformer.py | 33 +++++++++++++++---- .../test_basic_api_transformation.py | 6 ++-- 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/basic_api_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/basic_api_transformer.py index 198c2920eec..5ea1fdfac09 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/basic_api_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/basic_api_transformer.py @@ -33,10 +33,11 @@ class BasicApiTransformer(gast.NodeTransformer): self.root = wrapper_root.node self.class_node_dict = {} - self.name_to_tensor_shape = {} - def transform(self): + to_tensor_transformer = ToTensorTransformer(self.root) + to_tensor_transformer.transform() self.visit(self.root) + return self.wrapper_root def visit_Assign(self, node): @@ -62,11 +63,6 @@ class BasicApiTransformer(gast.NodeTransformer): def _visit_Call(self, node): assert isinstance(node, gast.Call) - # Replace API `to_variable` with `fluid.layers.assign` - if is_to_variable(node): - node = to_assign_node(node) - return node - func_name = astor.to_source(gast.gast_to_ast(node.func)) if self._is_dygraph_forward(func_name): @@ -102,6 +98,29 @@ class BasicApiTransformer(gast.NodeTransformer): return False +class ToTensorTransformer(gast.NodeTransformer): + """ + Class to transform paddle.to_tensor and paddle.to_variable to paddle.assign + """ + + def __init__(self, node): + assert isinstance( + node, gast.AST + ), "Input non-gast.AST node for the initialization of ToTensorTransformer." + self.root = node + + def transform(self): + self.visit(self.root) + return self.root + + def visit_Call(self, node): + assert isinstance(node, gast.Call) + if is_to_variable(node): + node = to_assign_node(node) + self.generic_visit(node) + return node + + def is_to_variable(node): assert isinstance(node, gast.Call) api_name = utils.ast_to_source_code(node.func).strip() diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_basic_api_transformation.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_basic_api_transformation.py index 630b804f9a2..ea745ad6614 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_basic_api_transformation.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_basic_api_transformation.py @@ -64,13 +64,11 @@ def dyfunc_int_to_tensor(x): def dyfunc_float_to_tensor(x): - res = paddle.to_tensor(2.0) - return res + return paddle.to_tensor(2.0) def dyfunc_bool_to_tensor(x): - res = paddle.to_tensor(True) - return res + return paddle.to_tensor(True) class TestDygraphBasicApi_ToVariable(unittest.TestCase): -- GitLab