diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py
index 74f946acedb27f396296417286bc827d2fb5adea..de53a56468485a47e7d764a03b0b4398a4b76f25 100644
--- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py
@@ -18,6 +18,7 @@ from __future__ import print_function
 # It provides a compatibility layer between the AST of various Python versions,
 # as produced by ast.parse from the standard ast module.
 # See details in https://github.com/serge-sans-paille/gast/
+import os
 from paddle.utils import gast
 from paddle.fluid.dygraph.dygraph_to_static.assert_transformer import AssertTransformer
 from paddle.fluid.dygraph.dygraph_to_static.basic_api_transformer import BasicApiTransformer
@@ -44,6 +45,18 @@ __all__ = ['DygraphToStaticAst']
 
 DECORATOR_NAMES = ['declarative', 'to_static', 'dygraph_to_static_func']
 
+def apply_optimization(transformers):
+    """
+    Judge whether to apply optimized transformation, such as BreakTransformOptimizer.
+    And not all optimized transformations are applied by default. It's controlled by
+    'export FLAGS_optim_transformation=1'
+    """
+    flag = str(
+        os.environ.get('FLAGS_optim_transformation')) in ['1', 'True', 'true']
+    if flag:
+        transformers.insert(3, BreakTransformOptimizer)
+
+
 class DygraphToStaticAst(gast.NodeTransformer):
     """
     Main class to transform Dygraph to Static Graph
@@ -77,7 +90,6 @@ class DygraphToStaticAst(gast.NodeTransformer):
             BasicApiTransformer,  # Basic Api
             TensorShapeTransformer,  # Tensor.shape -> layers.shape(Tensor)
             ListTransformer,  # List used in control flow
-            BreakTransformOptimizer,  # optimize transfromation of break in loops
             BreakContinueTransformer,  # break/continue in loops
             ReturnTransformer,  # return in functions
             LogicalTransformer,  # logical and/or/not
@@ -90,6 +102,8 @@ class DygraphToStaticAst(gast.NodeTransformer):
             GradTransformer,  # transform paddle.grad to paddle.gradients
         ]
 
+        apply_optimization(transformers)
+
         for index, transformer in enumerate(transformers):
             self._apply(transformer, node_wrapper, log_level=index + 1)
 
diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py
index 56e9cabbef485dad17908327801b3e6a79db0ede..78d97a3884aedf79dccaa40099a25c202d38fcd4 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py
@@ -122,10 +122,12 @@ def for_loop_dyfunc_not_support(max_len):
 
 
 def for_break_single_return(max_len):
+    x = 0
     for i in range(3):
         if i == 2:
             break
-    return i
+        x += 1
+    return x
 
 
 def while_loop_bool_op(x):
@@ -324,6 +326,7 @@ class TestTransformWhileLoop(unittest.TestCase):
     def test_ast_to_func(self):
         static_numpy = self._run_static()
         dygraph_numpy = self._run_dygraph()
+        print(static_numpy, dygraph_numpy)
         self.assertTrue(np.allclose(dygraph_numpy, static_numpy))