diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/basic_api_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/basic_api_transformer.py index 198c2920eec7fdc81c6a06b27c9ed64f9754ec75..5ea1fdfac0928ad465fc7e29813fe42182047c6a 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/basic_api_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/basic_api_transformer.py @@ -33,10 +33,11 @@ class BasicApiTransformer(gast.NodeTransformer): self.root = wrapper_root.node self.class_node_dict = {} - self.name_to_tensor_shape = {} - def transform(self): + to_tensor_transformer = ToTensorTransformer(self.root) + to_tensor_transformer.transform() self.visit(self.root) + return self.wrapper_root def visit_Assign(self, node): @@ -62,11 +63,6 @@ class BasicApiTransformer(gast.NodeTransformer): def _visit_Call(self, node): assert isinstance(node, gast.Call) - # Replace API `to_variable` with `fluid.layers.assign` - if is_to_variable(node): - node = to_assign_node(node) - return node - func_name = astor.to_source(gast.gast_to_ast(node.func)) if self._is_dygraph_forward(func_name): @@ -102,6 +98,29 @@ class BasicApiTransformer(gast.NodeTransformer): return False +class ToTensorTransformer(gast.NodeTransformer): + """ + Class to transform paddle.to_tensor and paddle.to_variable to paddle.assign + """ + + def __init__(self, node): + assert isinstance( + node, gast.AST + ), "Input non-gast.AST node for the initialization of ToTensorTransformer." + self.root = node + + def transform(self): + self.visit(self.root) + return self.root + + def visit_Call(self, node): + assert isinstance(node, gast.Call) + if is_to_variable(node): + node = to_assign_node(node) + self.generic_visit(node) + return node + + def is_to_variable(node): assert isinstance(node, gast.Call) api_name = utils.ast_to_source_code(node.func).strip() diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_basic_api_transformation.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_basic_api_transformation.py index 630b804f9a2fbe3326a7b9c7b9757f1cba8c444c..ea745ad661425381811b2405362ce254b0403fe1 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_basic_api_transformation.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_basic_api_transformation.py @@ -64,13 +64,11 @@ def dyfunc_int_to_tensor(x): def dyfunc_float_to_tensor(x): - res = paddle.to_tensor(2.0) - return res + return paddle.to_tensor(2.0) def dyfunc_bool_to_tensor(x): - res = paddle.to_tensor(True) - return res + return paddle.to_tensor(True) class TestDygraphBasicApi_ToVariable(unittest.TestCase):