未验证 提交 00ce09e6 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Stat]Add apply_optimization for @to_static and remove...

[Dy2Stat]Add apply_optimization for @to_static and remove BreakTransformOptimizer by default (#43320)

* [Dy2Stat]Add apply_optimization for @to_static and remove BreakTransformOptimizer by default

* fix unittest
上级 04294f80
...@@ -18,6 +18,7 @@ from __future__ import print_function ...@@ -18,6 +18,7 @@ from __future__ import print_function
# It provides a compatibility layer between the AST of various Python versions, # It provides a compatibility layer between the AST of various Python versions,
# as produced by ast.parse from the standard ast module. # as produced by ast.parse from the standard ast module.
# See details in https://github.com/serge-sans-paille/gast/ # See details in https://github.com/serge-sans-paille/gast/
import os
from paddle.utils import gast from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.assert_transformer import AssertTransformer from paddle.fluid.dygraph.dygraph_to_static.assert_transformer import AssertTransformer
from paddle.fluid.dygraph.dygraph_to_static.basic_api_transformer import BasicApiTransformer from paddle.fluid.dygraph.dygraph_to_static.basic_api_transformer import BasicApiTransformer
...@@ -44,6 +45,18 @@ __all__ = ['DygraphToStaticAst'] ...@@ -44,6 +45,18 @@ __all__ = ['DygraphToStaticAst']
DECORATOR_NAMES = ['declarative', 'to_static', 'dygraph_to_static_func'] DECORATOR_NAMES = ['declarative', 'to_static', 'dygraph_to_static_func']
def apply_optimization(transformers):
"""
Judge wheter to apply optimized transformation, such as BreakTransformOptimizer.
And not all optimized transformations are applied by default. It's controlled by
'export FLAGS_optim_transformation=1'
"""
flag = str(
os.environ.get('FLAGS_optim_transformation')) in ['1', 'True', 'true']
if flag:
transformers.insert(3, BreakTransformOptimizer)
class DygraphToStaticAst(gast.NodeTransformer): class DygraphToStaticAst(gast.NodeTransformer):
""" """
Main class to transform Dygraph to Static Graph Main class to transform Dygraph to Static Graph
...@@ -77,7 +90,6 @@ class DygraphToStaticAst(gast.NodeTransformer): ...@@ -77,7 +90,6 @@ class DygraphToStaticAst(gast.NodeTransformer):
BasicApiTransformer, # Basic Api BasicApiTransformer, # Basic Api
TensorShapeTransformer, # Tensor.shape -> layers.shape(Tensor) TensorShapeTransformer, # Tensor.shape -> layers.shape(Tensor)
ListTransformer, # List used in control flow ListTransformer, # List used in control flow
BreakTransformOptimizer, # optimize transfromation of break in loops
BreakContinueTransformer, # break/continue in loops BreakContinueTransformer, # break/continue in loops
ReturnTransformer, # return in functions ReturnTransformer, # return in functions
LogicalTransformer, # logical and/or/not LogicalTransformer, # logical and/or/not
...@@ -90,6 +102,8 @@ class DygraphToStaticAst(gast.NodeTransformer): ...@@ -90,6 +102,8 @@ class DygraphToStaticAst(gast.NodeTransformer):
GradTransformer, # transform paddle.grad to paddle.gradients GradTransformer, # transform paddle.grad to paddle.gradients
] ]
apply_optimization(transformers)
for index, transformer in enumerate(transformers): for index, transformer in enumerate(transformers):
self._apply(transformer, node_wrapper, log_level=index + 1) self._apply(transformer, node_wrapper, log_level=index + 1)
......
...@@ -122,10 +122,12 @@ def for_loop_dyfunc_not_support(max_len): ...@@ -122,10 +122,12 @@ def for_loop_dyfunc_not_support(max_len):
def for_break_single_return(max_len): def for_break_single_return(max_len):
x = 0
for i in range(3): for i in range(3):
if i == 2: if i == 2:
break break
return i x += 1
return x
def while_loop_bool_op(x): def while_loop_bool_op(x):
...@@ -324,6 +326,7 @@ class TestTransformWhileLoop(unittest.TestCase): ...@@ -324,6 +326,7 @@ class TestTransformWhileLoop(unittest.TestCase):
def test_ast_to_func(self): def test_ast_to_func(self):
static_numpy = self._run_static() static_numpy = self._run_static()
dygraph_numpy = self._run_dygraph() dygraph_numpy = self._run_dygraph()
print(static_numpy, dygraph_numpy)
self.assertTrue(np.allclose(dygraph_numpy, static_numpy)) self.assertTrue(np.allclose(dygraph_numpy, static_numpy))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册