From 135b62a4ecef1ef4ce8e4dc910d2e801af91abe9 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Tue, 20 Oct 2020 16:34:54 +0800 Subject: [PATCH] [Dy2stat] Refine code of DygraphToStaticAst (#28103) * refine code of DygraphToStaticAst * add __init__ function --- .../dygraph_to_static/ast_transformer.py | 89 ++++++------------- 1 file changed, 29 insertions(+), 60 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py index 5050067e48..2c59a66f22 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py @@ -47,6 +47,9 @@ class DygraphToStaticAst(gast.NodeTransformer): Main class to transform Dygraph to Static Graph """ + def __init__(self): + self.translator_logger = logging_utils.TranslatorLogger() + def get_static_ast(self, root): # save root for some analysis may need global AST self.root = root @@ -57,71 +60,37 @@ class DygraphToStaticAst(gast.NodeTransformer): self.transfer_from_node_type(self.static_analysis_root) return self.static_analysis_root + def _apply(self, transformer, node_wrapper, log_level): + transformer(node_wrapper).transform() + self.translator_logger.log_transformed_code(log_level, self.root, + transformer.__name__) + def transfer_from_node_type(self, node_wrapper): - translator_logger = logging_utils.TranslatorLogger() - translator_logger.log( + self.translator_logger.log( 1, "Source code: \n{}".format(ast_to_source_code(self.root))) # Generic transformation self.visit(node_wrapper.node) - # Transform basic api of dygraph to static graph and get feed_name_to_arg_name - BasicApiTransformer(node_wrapper).transform() - translator_logger.log_transformed_code(1, self.root, - "BasicApiTransformer") - - # Transform Tensor.shape into fluid.layers.shape(Tensor) - TensorShapeTransformer(node_wrapper).transform() - translator_logger.log_transformed_code(2, self.root, - "TensorShapeTransformer") - - # Transform list used in control flow - ListTransformer(node_wrapper).transform() - translator_logger.log_transformed_code(3, self.root, "ListTransformer") - - # Transform break/continue in loops - BreakContinueTransformer(node_wrapper).transform() - translator_logger.log_transformed_code(4, self.root, - "BreakContinueTransformer") - - # Transform return in functions - ReturnTransformer(node_wrapper).transform() - translator_logger.log_transformed_code(5, self.root, - "ReturnTransformer") - - # Transform logical and/or/not - LogicalTransformer(node_wrapper).transform() - translator_logger.log_transformed_code(6, self.root, - "LogicalTransformer") - - # Transform for loop and while loop - LoopTransformer(node_wrapper).transform() - translator_logger.log_transformed_code(7, self.root, "LoopTransformer") - - # Transform all if/else statement of Dygraph into Static Graph. - IfElseTransformer(node_wrapper).transform() - translator_logger.log_transformed_code(8, self.root, - "IfElseTransformer") - - # Transform python assert statement - AssertTransformer(node_wrapper).transform() - translator_logger.log_transformed_code(9, self.root, - "AssertTransformer") - - # Transform all python print statement - PrintTransformer(node_wrapper).transform() - translator_logger.log_transformed_code(10, self.root, - "PrintTransformer") - - # Transform call recursively - CallTransformer(node_wrapper).transform() - translator_logger.log_transformed_code(11, self.root, "CallTransformer") - - # Transform python type casting statement - CastTransformer(node_wrapper).transform() - translator_logger.log_transformed_code(12, self.root, "CastTransformer") - - translator_logger.log_transformed_code(logging_utils.LOG_AllTransformer, - self.root, "All Transformers") + transformers = [ + BasicApiTransformer, # Basic Api + TensorShapeTransformer, # Tensor.shape -> layers.shape(Tensor) + ListTransformer, # List used in control flow + BreakContinueTransformer, # break/continue in loops + ReturnTransformer, # return in functions + LogicalTransformer, # logical and/or/not + LoopTransformer, # for/while -> while_op + IfElseTransformer, # if/else -> cond_op + AssertTransformer, # assert statement + PrintTransformer, # print statement + CallTransformer, # transform call recursively + CastTransformer, # type casting statement + ] + + for index, transformer in enumerate(transformers): + self._apply(transformer, node_wrapper, log_level=index + 1) + + self.translator_logger.log_transformed_code( + logging_utils.LOG_AllTransformer, self.root, "All Transformers") def visit_FunctionDef(self, node): if self.decorate_func_name is None: -- GitLab