未验证 提交 135b62a4 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2stat] Refine code of DygraphToStaticAst (#28103)

* refine code of DygraphToStaticAst

* add __init__ function
上级 6dd64b0a
...@@ -47,6 +47,9 @@ class DygraphToStaticAst(gast.NodeTransformer): ...@@ -47,6 +47,9 @@ class DygraphToStaticAst(gast.NodeTransformer):
Main class to transform Dygraph to Static Graph Main class to transform Dygraph to Static Graph
""" """
def __init__(self):
self.translator_logger = logging_utils.TranslatorLogger()
def get_static_ast(self, root): def get_static_ast(self, root):
# save root for some analysis may need global AST # save root for some analysis may need global AST
self.root = root self.root = root
...@@ -57,71 +60,37 @@ class DygraphToStaticAst(gast.NodeTransformer): ...@@ -57,71 +60,37 @@ class DygraphToStaticAst(gast.NodeTransformer):
self.transfer_from_node_type(self.static_analysis_root) self.transfer_from_node_type(self.static_analysis_root)
return self.static_analysis_root return self.static_analysis_root
def _apply(self, transformer, node_wrapper, log_level):
transformer(node_wrapper).transform()
self.translator_logger.log_transformed_code(log_level, self.root,
transformer.__name__)
def transfer_from_node_type(self, node_wrapper): def transfer_from_node_type(self, node_wrapper):
translator_logger = logging_utils.TranslatorLogger() self.translator_logger.log(
translator_logger.log(
1, "Source code: \n{}".format(ast_to_source_code(self.root))) 1, "Source code: \n{}".format(ast_to_source_code(self.root)))
# Generic transformation # Generic transformation
self.visit(node_wrapper.node) self.visit(node_wrapper.node)
# Transform basic api of dygraph to static graph and get feed_name_to_arg_name transformers = [
BasicApiTransformer(node_wrapper).transform() BasicApiTransformer, # Basic Api
translator_logger.log_transformed_code(1, self.root, TensorShapeTransformer, # Tensor.shape -> layers.shape(Tensor)
"BasicApiTransformer") ListTransformer, # List used in control flow
BreakContinueTransformer, # break/continue in loops
# Transform Tensor.shape into fluid.layers.shape(Tensor) ReturnTransformer, # return in functions
TensorShapeTransformer(node_wrapper).transform() LogicalTransformer, # logical and/or/not
translator_logger.log_transformed_code(2, self.root, LoopTransformer, # for/while -> while_op
"TensorShapeTransformer") IfElseTransformer, # if/else -> cond_op
AssertTransformer, # assert statement
# Transform list used in control flow PrintTransformer, # print statement
ListTransformer(node_wrapper).transform() CallTransformer, # transform call recursively
translator_logger.log_transformed_code(3, self.root, "ListTransformer") CastTransformer, # type casting statement
]
# Transform break/continue in loops
BreakContinueTransformer(node_wrapper).transform() for index, transformer in enumerate(transformers):
translator_logger.log_transformed_code(4, self.root, self._apply(transformer, node_wrapper, log_level=index + 1)
"BreakContinueTransformer")
self.translator_logger.log_transformed_code(
# Transform return in functions logging_utils.LOG_AllTransformer, self.root, "All Transformers")
ReturnTransformer(node_wrapper).transform()
translator_logger.log_transformed_code(5, self.root,
"ReturnTransformer")
# Transform logical and/or/not
LogicalTransformer(node_wrapper).transform()
translator_logger.log_transformed_code(6, self.root,
"LogicalTransformer")
# Transform for loop and while loop
LoopTransformer(node_wrapper).transform()
translator_logger.log_transformed_code(7, self.root, "LoopTransformer")
# Transform all if/else statement of Dygraph into Static Graph.
IfElseTransformer(node_wrapper).transform()
translator_logger.log_transformed_code(8, self.root,
"IfElseTransformer")
# Transform python assert statement
AssertTransformer(node_wrapper).transform()
translator_logger.log_transformed_code(9, self.root,
"AssertTransformer")
# Transform all python print statement
PrintTransformer(node_wrapper).transform()
translator_logger.log_transformed_code(10, self.root,
"PrintTransformer")
# Transform call recursively
CallTransformer(node_wrapper).transform()
translator_logger.log_transformed_code(11, self.root, "CallTransformer")
# Transform python type casting statement
CastTransformer(node_wrapper).transform()
translator_logger.log_transformed_code(12, self.root, "CastTransformer")
translator_logger.log_transformed_code(logging_utils.LOG_AllTransformer,
self.root, "All Transformers")
def visit_FunctionDef(self, node): def visit_FunctionDef(self, node):
if self.decorate_func_name is None: if self.decorate_func_name is None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册