Unverified commit 00ce09e6 authored by Aurelius84, committed by GitHub


[Dy2Stat]Add apply_optimization for @to_static and remove BreakTransformOptimizer by default (#43320)

* [Dy2Stat]Add apply_optimization for @to_static and remove BreakTransformOptimizer by default

* fix unittest
Parent 04294f80
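The flag-gated optimization is easiest to see from the user side. Below is a minimal sketch (assuming a standard PaddlePaddle install; the function foo and its body are hypothetical) of enabling the optimized transformation added in this patch before converting a function with @to_static:

import os

# Opt in to the optimized break transformation; apply_optimization (added in
# this patch) reads this flag from the environment when the decorated
# function is transformed.
os.environ['FLAGS_optim_transformation'] = '1'

import paddle


@paddle.jit.to_static
def foo(x):
    # Hypothetical toy function with a break inside a loop.
    total = x
    for i in range(3):
        if i == 2:
            break
        total = total + i
    return total


print(foo(paddle.to_tensor(0.0)))

Without the flag, the function still converts normally; BreakTransformOptimizer is simply left out of the transformer pipeline.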
@@ -18,6 +18,7 @@ from __future__ import print_function
# It provides a compatibility layer between the AST of various Python versions,
# as produced by ast.parse from the standard ast module.
# See details in https://github.com/serge-sans-paille/gast/
import os
from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.assert_transformer import AssertTransformer
from paddle.fluid.dygraph.dygraph_to_static.basic_api_transformer import BasicApiTransformer
@@ -44,6 +45,18 @@ __all__ = ['DygraphToStaticAst']
DECORATOR_NAMES = ['declarative', 'to_static', 'dygraph_to_static_func']
def apply_optimization(transformers):
    """
    Judge whether to apply optimized transformation, such as BreakTransformOptimizer.
    And not all optimized transformations are applied by default. It's controlled by
    'export FLAGS_optim_transformation=1'
    """
    flag = str(
        os.environ.get('FLAGS_optim_transformation')) in ['1', 'True', 'true']
    if flag:
        transformers.insert(3, BreakTransformOptimizer)


class DygraphToStaticAst(gast.NodeTransformer):
    """
    Main class to transform Dygraph to Static Graph
@@ -77,7 +90,6 @@ class DygraphToStaticAst(gast.NodeTransformer):
            BasicApiTransformer,  # Basic Api
            TensorShapeTransformer,  # Tensor.shape -> layers.shape(Tensor)
            ListTransformer,  # List used in control flow
            BreakTransformOptimizer,  # optimize transformation of break in loops
            BreakContinueTransformer,  # break/continue in loops
            ReturnTransformer,  # return in functions
            LogicalTransformer,  # logical and/or/not
@@ -90,6 +102,8 @@ class DygraphToStaticAst(gast.NodeTransformer):
            GradTransformer,  # transform paddle.grad to paddle.gradients
        ]

        apply_optimization(transformers)

        for index, transformer in enumerate(transformers):
            self._apply(transformer, node_wrapper, log_level=index + 1)
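Since apply_optimization mutates the transformer list in place, whether BreakTransformOptimizer joins the pipeline now depends only on the environment flag at transform time. A standalone sketch of the same pattern, using placeholder string names ('T0'..'T4') instead of the real transformer classes:

import os


def apply_optimization(transformers):
    # Same pattern as the helper in this patch: insert the optimizer at
    # index 3 only when FLAGS_optim_transformation is a truthy string.
    flag = str(
        os.environ.get('FLAGS_optim_transformation')) in ['1', 'True', 'true']
    if flag:
        transformers.insert(3, 'BreakTransformOptimizer')


pipeline = ['T0', 'T1', 'T2', 'T3', 'T4']  # placeholder transformer names
os.environ['FLAGS_optim_transformation'] = '1'
apply_optimization(pipeline)
print(pipeline)  # ['T0', 'T1', 'T2', 'BreakTransformOptimizer', 'T3', 'T4']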
@@ -122,10 +122,12 @@ def for_loop_dyfunc_not_support(max_len):
def for_break_single_return(max_len):
    x = 0
    for i in range(3):
        if i == 2:
            break
        return i
        x += 1
    return x
def while_loop_bool_op(x):
@@ -324,6 +326,7 @@ class TestTransformWhileLoop(unittest.TestCase):
    def test_ast_to_func(self):
        static_numpy = self._run_static()
        dygraph_numpy = self._run_dygraph()
        print(static_numpy, dygraph_numpy)
        self.assertTrue(np.allclose(dygraph_numpy, static_numpy))