From 4ff2915d1f0dd2b902443d7828834ef56bdb2b60 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Wed, 11 Mar 2020 10:02:54 +0800 Subject: [PATCH] Refine code of IfElseTransformer and rename unittest files (#22930) + Refine code structure and move code related with control flow `if` into `ifelse_transformer.py` + Merge code of `ast_utils.py` into `utils.py` --- .../dygraph_to_static/ast_transformer.py | 90 +-------- .../{ast_utils.py => ifelse_transformer.py} | 184 ++++++++--------- .../dygraph_to_static/loop_transformer.py | 3 +- .../dygraph_to_static/static_analysis.py | 5 +- .../fluid/dygraph/dygraph_to_static/utils.py | 90 ++++++++- .../dygraph_to_static/test_ifelse_basic.py | 190 ++++++++++++++++++ .../fluid/tests/unittests/test_ast_util.py | 170 +--------------- ...raph_to_static_basic_api_transformation.py | 48 ++--- 8 files changed, 407 insertions(+), 373 deletions(-) rename python/paddle/fluid/dygraph/dygraph_to_static/{ast_utils.py => ifelse_transformer.py} (75%) create mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse_basic.py diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py index c0f7b30ca9a..24381ce245a 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py @@ -26,96 +26,20 @@ import astor import gast from paddle.fluid import unique_name +from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func +from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api, is_dygraph_api, is_to_variable +from paddle.fluid.dygraph.dygraph_to_static.utils import to_assign_node, to_static_ast, update_args_of_func +from paddle.fluid.dygraph.dygraph_to_static.utils import dygraph_class_to_static_api, create_api_shape_node from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import LoopTransformer -from .ast_utils import is_control_flow_if, create_cond_node, transform_if_else, ast_to_func -from .static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor -from .utils import * +from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IfElseTransformer +from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType +from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor __all__ = ['DygraphToStaticAst', 'convert_to_static'] DECORATOR_NAMES = ['dygraph_to_static_output', 'dygraph_to_static_graph'] -class IfElseTransformer(gast.NodeTransformer): - """ - Transform if/else statement of Dygraph into Static Graph. - """ - - def __init__(self, wrapper_root): - assert isinstance( - wrapper_root, AstNodeWrapper - ), "Type of input node should be AstNodeWrapper, but received %s ." % type( - wrapper_root) - self.root = wrapper_root.node - self.static_analysis_visitor = StaticAnalysisVisitor(self.root) - self.new_func_nodes = {} - - def transform(self): - """ - Main function to transform AST. - """ - self.visit(self.root) - self.after_visit(self.root) - - def visit_If(self, node): - assert isinstance(node, gast.If) - need_transform = is_control_flow_if(node.test, - self.static_analysis_visitor) - self.generic_visit(node) - if need_transform: - pred_node = node.test - true_func_node, false_func_node, return_name_ids = transform_if_else( - node, self.root) - # create layers.cond - new_node = create_cond_node(return_name_ids, pred_node, - true_func_node, false_func_node) - self.new_func_nodes[new_node] = [true_func_node, false_func_node] - return new_node - else: - return node - - def visit_Call(self, node): - # Remove `numpy()` statement, like `Tensor.numpy()[i]` -> `Tensor[i]` - # TODO: should be removed. it may be considered as basic api transformation. - if isinstance(node.func, gast.Attribute): - attribute = node.func - if attribute.attr == 'numpy': - node = attribute.value - return node - - def after_visit(self, node): - """ - This function will add some postprocessing operations with node. - It can be used to add the created `true_fn/false_fn` in front of - the node.body before they are called in cond layer. - """ - self._insert_func_nodes(node) - - def _insert_func_nodes(self, parent_node): - """ - Defined `true_func` and `false_func` will be inserted in front of corresponding - `layers.cond` statement instead of inserting them all into body of parent node. - Because private variables of class or other external scope will be modified. - For example, `self.var_dict["key"]`. In this case, nested structure of newly - defined functions is easier to understand. - """ - if not (self.new_func_nodes and hasattr(parent_node, 'body')): - return - idx = len(parent_node.body) - 1 - while idx >= 0: - child_node = parent_node.body[idx] - if child_node in self.new_func_nodes: - parent_node.body[idx:idx] = self.new_func_nodes[child_node] - idx = idx + len(self.new_func_nodes[child_node]) - 1 - del self.new_func_nodes[child_node] - else: - self._insert_func_nodes(child_node) - idx = idx - 1 - - def get_new_func_nodes(self): - return self.new_func_nodes - - class DygraphToStaticAst(gast.NodeTransformer): """ Main class to transform Dygraph to Static Graph diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py similarity index 75% rename from python/paddle/fluid/dygraph/dygraph_to_static/ast_utils.py rename to python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py index 42321101c30..c3918bf042a 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py @@ -14,25 +14,106 @@ from __future__ import print_function -import ast -import astor -import gast -import six import copy -import tempfile -import imp -import os -import atexit from collections import defaultdict +# gast is a generic AST to represent Python2 and Python3's Abstract Syntax Tree(AST). +# It provides a compatibility layer between the AST of various Python versions, +# as produced by ast.parse from the standard ast module. +# See details in https://github.com/serge-sans-paille/gast/ +import gast from paddle.fluid import unique_name -from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType, StaticAnalysisVisitor + from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api +from paddle.fluid.dygraph.dygraph_to_static.utils import create_funcDef_node +from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node +from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor +from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType TRUE_FUNC_PREFIX = 'true_fn' FALSE_FUNC_PREFIX = 'false_fn' +class IfElseTransformer(gast.NodeTransformer): + """ + Transform if/else statement of Dygraph into Static Graph. + """ + + def __init__(self, wrapper_root): + assert isinstance( + wrapper_root, AstNodeWrapper + ), "Type of input node should be AstNodeWrapper, but received %s ." % type( + wrapper_root) + self.root = wrapper_root.node + self.static_analysis_visitor = StaticAnalysisVisitor(self.root) + self.new_func_nodes = {} + + def transform(self): + """ + Main function to transform AST. + """ + self.visit(self.root) + self.after_visit(self.root) + + def visit_If(self, node): + assert isinstance(node, gast.If) + need_transform = is_control_flow_if(node.test, + self.static_analysis_visitor) + self.generic_visit(node) + if need_transform: + pred_node = node.test + true_func_node, false_func_node, return_name_ids = transform_if_else( + node, self.root) + # create layers.cond + new_node = create_cond_node(return_name_ids, pred_node, + true_func_node, false_func_node) + self.new_func_nodes[new_node] = [true_func_node, false_func_node] + return new_node + else: + return node + + def visit_Call(self, node): + # Remove `numpy()` statement, like `Tensor.numpy()[i]` -> `Tensor[i]` + # TODO: should be removed. it may be considered as basic api transformation. + if isinstance(node.func, gast.Attribute): + attribute = node.func + if attribute.attr == 'numpy': + node = attribute.value + return node + + def after_visit(self, node): + """ + This function will add some postprocessing operations with node. + It can be used to add the created `true_fn/false_fn` in front of + the node.body before they are called in cond layer. + """ + self._insert_func_nodes(node) + + def _insert_func_nodes(self, parent_node): + """ + Defined `true_func` and `false_func` will be inserted in front of corresponding + `layers.cond` statement instead of inserting them all into body of parent node. + Because private variables of class or other external scope will be modified. + For example, `self.var_dict["key"]`. In this case, nested structure of newly + defined functions is easier to understand. + """ + if not (self.new_func_nodes and hasattr(parent_node, 'body')): + return + idx = len(parent_node.body) - 1 + while idx >= 0: + child_node = parent_node.body[idx] + if child_node in self.new_func_nodes: + parent_node.body[idx:idx] = self.new_func_nodes[child_node] + idx = idx + len(self.new_func_nodes[child_node]) - 1 + del self.new_func_nodes[child_node] + else: + self._insert_func_nodes(child_node) + idx = idx - 1 + + def get_new_func_nodes(self): + return self.new_func_nodes + + class IsControlFlowIfVisitor(gast.NodeTransformer): """ Judge whether the node.test from Dygraph code dependent on paddle Tensor. @@ -112,7 +193,8 @@ class IsControlFlowIfVisitor(gast.NodeTransformer): if isinstance(node, gast.Compare): for child in [node.left, node.comparators]: # node.comparators is a list. - if isinstance(child, list): child = child[0] + if isinstance(child, list): + child = child[0] if (isinstance(child, gast.Constant) and child.value is None) or ( isinstance(child, gast.Name) and @@ -151,7 +233,8 @@ def get_name_ids(nodes, not_name_set=None, node_black_list=None): name_ids = defaultdict(list) for node in nodes: - if node_black_list and node in node_black_list: break + if node_black_list and node in node_black_list: + break if isinstance(node, gast.AST): # In two case, the ast.Name should be filtered. # 1. Function name like `my_func` of my_func(x) @@ -271,45 +354,6 @@ def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict): return return_ids, list(modified_vars - new_vars) -def generate_name_node(name_ids, ctx=gast.Load()): - """ - Generate list or gast.Tuple of ast.Name for Return statement. - """ - if isinstance(name_ids, six.string_types): - name_ids = [name_ids] - if not isinstance(name_ids, (list, tuple, set)): - raise TypeError('name_ids must be list or tuple or set, but received %s' - % type(type(name_ids))) - gast_names = [ - gast.Name( - id=name_id, ctx=ctx, annotation=None, type_comment=None) - for name_id in name_ids - ] - if len(gast_names) == 1: - name_node = gast_names[0] - else: - name_node = gast.Tuple(elts=gast_names, ctx=ctx) - return name_node - - -def create_funcDef_node(nodes, name, input_args, return_name_ids): - """ - Wrapper all statements of nodes into one ast.FunctionDef, which can be - called by ast.Call. - """ - nodes = copy.copy(nodes) - # add return statement - nodes.append(gast.Return(value=generate_name_node(return_name_ids))) - func_def_node = gast.FunctionDef( - name=name, - args=input_args, - body=nodes, - decorator_list=[], - returns=None, - type_comment=None) - return func_def_node - - def transform_if_else(node, root): """ Transform ast.If into control flow statement of Paddle static graph. @@ -384,43 +428,3 @@ def create_cond_node(return_name_ids, pred, true_func, false_func): assign_node = gast.Assign(targets=targets, value=cond_layer) return assign_node - - -def ast_to_func(ast_root, func_name, delete_on_exit=True): - """ - Transform modified AST of decorated function into python callable object. - """ - if not isinstance(ast_root, (gast.AST, ast.AST)): - raise TypeError( - "Type of ast_root should be gast.AST or ast.AST, but received %s." % - type(ast_root)) - if isinstance(ast_root, gast.AST): - ast_root = gast.gast_to_ast(ast_root) - source = astor.to_source(ast_root) - if six.PY2: - source = source.encode('utf-8') - f = tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) - else: - f = tempfile.NamedTemporaryFile( - mode='w', suffix='.py', delete=False, encoding='utf-8') - - # TODO(Aurelius84): more elegant way to transform ast into callable object - import_str = "import paddle\n" \ - "import paddle.fluid as fluid\n" \ - "import paddle.fluid.layers as layers\n" \ - "import numpy as np\n" \ - "import numpy\n" - with f: - module_name = os.path.basename(f.name[:-3]) - f.write(import_str) - f.write(source) - - if delete_on_exit: - atexit.register(lambda: os.remove(f.name)) - module = imp.load_source(module_name, f.name) - if not hasattr(module, func_name): - raise ValueError( - 'Function: %s doesn\'t exist in the Module transformed from AST.' % - func_name) - - return getattr(module, func_name), f.name diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py index 2314d3c6c95..10f170aaefc 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py @@ -19,8 +19,7 @@ import gast from collections import defaultdict from paddle.fluid import unique_name -from paddle.fluid.dygraph.dygraph_to_static.ast_utils import create_funcDef_node -from paddle.fluid.dygraph.dygraph_to_static.ast_utils import generate_name_node +from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable_gast_node diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py b/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py index 34d05bed32d..a0d8e3a5881 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py @@ -14,10 +14,7 @@ from __future__ import print_function -import astor import gast -import inspect -import six import warnings from .utils import is_paddle_api, is_dygraph_api, is_numpy_api @@ -109,7 +106,7 @@ class AstNodeWrapper(object): class AstVarScope(object): """ AstVarScope is a class holding the map from current scope variable to its - type. + type. """ SCOPE_TYPE_SCRIPT = 0 SCOPE_TYPE_FUNCTION = 1 diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index 0a72881c2c4..fba46f16ee0 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -14,9 +14,16 @@ from __future__ import print_function -import inspect -import gast +import ast import astor +import atexit +import copy +import gast +import imp +import inspect +import os +import six +import tempfile dygraph_class_to_static_api = { "CosineDecay": "cosine_decay", @@ -206,3 +213,82 @@ def create_api_shape_node(tensor_shape_node): args=[tensor_shape_node.value], keywords=[]) return api_shape_node + + +def generate_name_node(name_ids, ctx=gast.Load()): + """ + Generate list or gast.Tuple of ast.Name for Return statement. + """ + if isinstance(name_ids, six.string_types): + name_ids = [name_ids] + if not isinstance(name_ids, (list, tuple, set)): + raise TypeError('name_ids must be list or tuple or set, but received %s' + % type(type(name_ids))) + gast_names = [ + gast.Name( + id=name_id, ctx=ctx, annotation=None, type_comment=None) + for name_id in name_ids + ] + if len(gast_names) == 1: + name_node = gast_names[0] + else: + name_node = gast.Tuple(elts=gast_names, ctx=ctx) + return name_node + + +def create_funcDef_node(nodes, name, input_args, return_name_ids): + """ + Wrapper all statements of nodes into one ast.FunctionDef, which can be + called by ast.Call. + """ + nodes = copy.copy(nodes) + # add return statement + nodes.append(gast.Return(value=generate_name_node(return_name_ids))) + func_def_node = gast.FunctionDef( + name=name, + args=input_args, + body=nodes, + decorator_list=[], + returns=None, + type_comment=None) + return func_def_node + + +def ast_to_func(ast_root, func_name, delete_on_exit=True): + """ + Transform modified AST of decorated function into python callable object. + """ + if not isinstance(ast_root, (gast.AST, ast.AST)): + raise TypeError( + "Type of ast_root should be gast.AST or ast.AST, but received %s." % + type(ast_root)) + if isinstance(ast_root, gast.AST): + ast_root = gast.gast_to_ast(ast_root) + source = astor.to_source(ast_root) + if six.PY2: + source = source.encode('utf-8') + f = tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) + else: + f = tempfile.NamedTemporaryFile( + mode='w', suffix='.py', delete=False, encoding='utf-8') + + # TODO(Aurelius84): more elegant way to transform ast into callable object + import_str = "import paddle\n" \ + "import paddle.fluid as fluid\n" \ + "import paddle.fluid.layers as layers\n" \ + "import numpy as np\n" \ + "import numpy\n" + with f: + module_name = os.path.basename(f.name[:-3]) + f.write(import_str) + f.write(source) + + if delete_on_exit: + atexit.register(lambda: os.remove(f.name)) + module = imp.load_source(module_name, f.name) + if not hasattr(module, func_name): + raise ValueError( + 'Function: %s doesn\'t exist in the Module transformed from AST.' % + func_name) + + return getattr(module, func_name), f.name diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse_basic.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse_basic.py new file mode 100644 index 00000000000..8609af49aa6 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse_basic.py @@ -0,0 +1,190 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import textwrap +import gast +from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import get_name_ids, is_control_flow_if +from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor + + +class TestGetNameIds(unittest.TestCase): + """ + Test for parsing the ast.Name list from the ast.Nodes + """ + + def setUp(self): + self.source = """ + def test_fn(x): + return x+1 + """ + self.all_name_ids = {'x': [gast.Param()]} + + def test_get_name_ids(self): + source = textwrap.dedent(self.source) + root = gast.parse(source) + all_name_ids = get_name_ids([root]) + self.assertDictEqual( + self.transfer_dict(self.all_name_ids), + self.transfer_dict(all_name_ids)) + + def transfer_dict(self, name_ids_dict): + new_dict = {} + for name, ctxs in name_ids_dict.items(): + new_dict[name] = [type(ctx) for ctx in ctxs] + return new_dict + + +class TestGetNameIds2(TestGetNameIds): + def setUp(self): + self.source = """ + def test_fn(x, y): + a = 1 + x = y + a + if x > y: + z = x * x + z = z + a + else: + z = y * y + return z + """ + self.all_name_ids = { + 'x': [ + gast.Param(), gast.Store(), gast.Load(), gast.Load(), + gast.Load() + ], + 'a': [gast.Store(), gast.Load(), gast.Load()], + 'y': + [gast.Param(), gast.Load(), gast.Load(), gast.Load(), gast.Load()], + 'z': [gast.Store(), gast.Load(), gast.Store(), gast.Store()] + } + + +class TestGetNameIds3(TestGetNameIds): + def setUp(self): + self.source = """ + def test_fn(x, y): + z = 1 + if x > y: + z = x * x + z = z + y + return z + """ + self.all_name_ids = { + 'x': [gast.Param(), gast.Load(), gast.Load(), gast.Load()], + 'y': [gast.Param(), gast.Load(), gast.Load()], + 'z': [gast.Store(), gast.Store(), gast.Load(), gast.Store()] + } + + +class TestIsControlFlowIf(unittest.TestCase): + def test_expr(self): + # node is not ast.Compare + node = gast.parse("a + b") + self.assertFalse(is_control_flow_if(node.body[0].value)) + + def test_expr2(self): + node = gast.parse("a + x.numpy()[1]") + self.assertFalse(is_control_flow_if(node.body[0].value)) + + def test_is_None(self): + node = gast.parse("x is None") + self.assertFalse(is_control_flow_if(node.body[0].value)) + + def test_is_None2(self): + node = gast.parse("fluid.layers.sum(x) is None") + self.assertFalse(is_control_flow_if(node.body[0].value)) + + def test_is_None3(self): + node = gast.parse("fluid.layers.sum(x).numpy() != None") + self.assertFalse(is_control_flow_if(node.body[0].value)) + + def test_if(self): + node = gast.parse("x.numpy()[1] > 1") + self.assertTrue(is_control_flow_if(node.body[0].value)) + + def test_if_with_and(self): + node = gast.parse("x is not None and 1 < x.numpy()[1]") + self.assertTrue(is_control_flow_if(node.body[0].value)) + + def test_if_with_or(self): + node = gast.parse("1 < fluid.layers.sum(x).numpy()[2] or x+y < 0") + self.assertTrue(is_control_flow_if(node.body[0].value)) + + def test_shape(self): + code = """ + def foo(x): + batch_size = fluid.layers.shape(x) + if batch_size[0] > 16: + x = x + 1 + return x + """ + code = textwrap.dedent(code) + node = gast.parse(code) + visitor = StaticAnalysisVisitor(node) + test_node = node.body[0].body[1].test + self.assertTrue(is_control_flow_if(test_node, visitor)) + + def test_shape_with_andOr(self): + code = """ + def foo(x): + batch_size = fluid.layers.shape(x) + if x is not None and batch_size[0] > 16 or 2 > 1: + x = x + 1 + return x + """ + code = textwrap.dedent(code) + node = gast.parse(code) + visitor = StaticAnalysisVisitor(node) + test_node = node.body[0].body[1].test + self.assertTrue(is_control_flow_if(test_node, visitor)) + + def test_paddle_api(self): + code = """ + def foo(x): + if fluid.layers.shape(x)[0] > 16: + x = x + 1 + return x + """ + code = textwrap.dedent(code) + node = gast.parse(code) + visitor = StaticAnalysisVisitor(node) + test_node = node.body[0].body[0].test + self.assertTrue(is_control_flow_if(test_node, visitor)) + + def test_paddle_api_with_andOr(self): + code = """ + def foo(x): + if 2 > 1 and fluid.layers.shape(x)[0] > 16 or x is not None : + x = x + 1 + return x + """ + code = textwrap.dedent(code) + node = gast.parse(code) + visitor = StaticAnalysisVisitor(node) + test_node = node.body[0].body[0].test + self.assertTrue(is_control_flow_if(test_node, visitor)) + + def test_raise_error(self): + node = "a + b" + with self.assertRaises(Exception) as e: + self.assertRaises(TypeError, is_control_flow_if(node)) + self.assertTrue( + "Type of input node should be gast.AST" in str(e.exception)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_ast_util.py b/python/paddle/fluid/tests/unittests/test_ast_util.py index 2d984a0a521..5c8a77c2779 100644 --- a/python/paddle/fluid/tests/unittests/test_ast_util.py +++ b/python/paddle/fluid/tests/unittests/test_ast_util.py @@ -20,177 +20,11 @@ import gast import inspect import numpy as np import paddle.fluid as fluid -from paddle.fluid.dygraph.dygraph_to_static.ast_utils import get_name_ids, ast_to_func, is_control_flow_if -from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor +from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func from test_dygraph_to_static_basic import dyfunc_with_if_else, dyfunc_with_if_else2, nested_if_else -class TestGetNameIds(unittest.TestCase): - """ - Test for parsing the ast.Name list from the ast.Nodes - """ - - def setUp(self): - self.source = """ - def test_fn(x): - return x+1 - """ - self.all_name_ids = {'x': [gast.Param()]} - - def test_get_name_ids(self): - source = textwrap.dedent(self.source) - root = gast.parse(source) - all_name_ids = get_name_ids([root]) - self.assertDictEqual( - self.transfer_dict(self.all_name_ids), - self.transfer_dict(all_name_ids)) - - def transfer_dict(self, name_ids_dict): - new_dict = {} - for name, ctxs in name_ids_dict.items(): - new_dict[name] = [type(ctx) for ctx in ctxs] - return new_dict - - -class TestGetNameIds2(TestGetNameIds): - def setUp(self): - self.source = """ - def test_fn(x, y): - a = 1 - x = y + a - if x > y: - z = x * x - z = z + a - else: - z = y * y - return z - """ - self.all_name_ids = { - 'x': [ - gast.Param(), gast.Store(), gast.Load(), gast.Load(), - gast.Load() - ], - 'a': [gast.Store(), gast.Load(), gast.Load()], - 'y': - [gast.Param(), gast.Load(), gast.Load(), gast.Load(), gast.Load()], - 'z': [gast.Store(), gast.Load(), gast.Store(), gast.Store()] - } - - -class TestGetNameIds3(TestGetNameIds): - def setUp(self): - self.source = """ - def test_fn(x, y): - z = 1 - if x > y: - z = x * x - z = z + y - return z - """ - self.all_name_ids = { - 'x': [gast.Param(), gast.Load(), gast.Load(), gast.Load()], - 'y': [gast.Param(), gast.Load(), gast.Load()], - 'z': [gast.Store(), gast.Store(), gast.Load(), gast.Store()] - } - - -class TestIsControlFlowIf(unittest.TestCase): - def test_expr(self): - # node is not ast.Compare - node = gast.parse("a + b") - self.assertFalse(is_control_flow_if(node.body[0].value)) - - def test_expr2(self): - node = gast.parse("a + x.numpy()[1]") - self.assertFalse(is_control_flow_if(node.body[0].value)) - - def test_is_None(self): - node = gast.parse("x is None") - self.assertFalse(is_control_flow_if(node.body[0].value)) - - def test_is_None2(self): - node = gast.parse("fluid.layers.sum(x) is None") - self.assertFalse(is_control_flow_if(node.body[0].value)) - - def test_is_None3(self): - node = gast.parse("fluid.layers.sum(x).numpy() != None") - self.assertFalse(is_control_flow_if(node.body[0].value)) - - def test_if(self): - node = gast.parse("x.numpy()[1] > 1") - self.assertTrue(is_control_flow_if(node.body[0].value)) - - def test_if_with_and(self): - node = gast.parse("x is not None and 1 < x.numpy()[1]") - self.assertTrue(is_control_flow_if(node.body[0].value)) - - def test_if_with_or(self): - node = gast.parse("1 < fluid.layers.sum(x).numpy()[2] or x+y < 0") - self.assertTrue(is_control_flow_if(node.body[0].value)) - - def test_shape(self): - code = """ - def foo(x): - batch_size = fluid.layers.shape(x) - if batch_size[0] > 16: - x = x + 1 - return x - """ - code = textwrap.dedent(code) - node = gast.parse(code) - visitor = StaticAnalysisVisitor(node) - test_node = node.body[0].body[1].test - self.assertTrue(is_control_flow_if(test_node, visitor)) - - def test_shape_with_andOr(self): - code = """ - def foo(x): - batch_size = fluid.layers.shape(x) - if x is not None and batch_size[0] > 16 or 2 > 1: - x = x + 1 - return x - """ - code = textwrap.dedent(code) - node = gast.parse(code) - visitor = StaticAnalysisVisitor(node) - test_node = node.body[0].body[1].test - self.assertTrue(is_control_flow_if(test_node, visitor)) - - def test_paddle_api(self): - code = """ - def foo(x): - if fluid.layers.shape(x)[0] > 16: - x = x + 1 - return x - """ - code = textwrap.dedent(code) - node = gast.parse(code) - visitor = StaticAnalysisVisitor(node) - test_node = node.body[0].body[0].test - self.assertTrue(is_control_flow_if(test_node, visitor)) - - def test_paddle_api_with_andOr(self): - code = """ - def foo(x): - if 2 > 1 and fluid.layers.shape(x)[0] > 16 or x is not None : - x = x + 1 - return x - """ - code = textwrap.dedent(code) - node = gast.parse(code) - visitor = StaticAnalysisVisitor(node) - test_node = node.body[0].body[0].test - self.assertTrue(is_control_flow_if(test_node, visitor)) - - def test_raise_error(self): - node = "a + b" - with self.assertRaises(Exception) as e: - self.assertRaises(TypeError, is_control_flow_if(node)) - self.assertTrue( - "Type of input node should be gast.AST" in str(e.exception)) - - class TestAST2Func(unittest.TestCase): """ TestCase for the transformation from ast.AST into python callable function. @@ -211,7 +45,7 @@ class TestAST2Func(unittest.TestCase): self.assertEqual(func(x, y), self._ast2func(func)(x, y)) def test_ast2func_dygraph(self): - funcs = [dyfunc_with_if_else, dyfunc_with_if_else, nested_if_else] + funcs = [dyfunc_with_if_else, dyfunc_with_if_else2, nested_if_else] x_data = np.random.random([10, 16]).astype('float32') for func in funcs: with fluid.dygraph.guard(): diff --git a/python/paddle/fluid/tests/unittests/test_dygraph_to_static_basic_api_transformation.py b/python/paddle/fluid/tests/unittests/test_dygraph_to_static_basic_api_transformation.py index 38a10ecb5dc..e490513468c 100644 --- a/python/paddle/fluid/tests/unittests/test_dygraph_to_static_basic_api_transformation.py +++ b/python/paddle/fluid/tests/unittests/test_dygraph_to_static_basic_api_transformation.py @@ -79,10 +79,10 @@ def dyfunc_BilinearTensorProduct(layer1, layer2): input1_dim=5, input2_dim=4, output_dim=1000, - param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( - value=0.99)), - bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( - value=0.5))) + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.99)), + bias_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.5))) res = bilinearTensorProduct( fluid.dygraph.base.to_variable(layer1), @@ -95,10 +95,10 @@ def dyfunc_Conv2D(input): num_channels=3, num_filters=2, filter_size=3, - param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( - value=0.99)), - bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( - value=0.5)), ) + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.99)), + bias_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.5)), ) res = conv2d(input) return res @@ -108,10 +108,10 @@ def dyfunc_Conv3D(input): num_channels=3, num_filters=2, filter_size=3, - param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( - value=0.99)), - bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( - value=0.5)), ) + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.99)), + bias_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.5)), ) res = conv3d(input) return res @@ -122,10 +122,10 @@ def dyfunc_Conv2DTranspose(input): num_filters=12, filter_size=12, use_cudnn=False, - param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( - value=0.99)), - bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( - value=0.5)), ) + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.99)), + bias_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.5)), ) ret = conv2dTranspose(input) return ret @@ -136,10 +136,10 @@ def dyfunc_Conv3DTranspose(input): num_filters=12, filter_size=12, use_cudnn=False, - param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( - value=0.99)), - bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( - value=0.5)), ) + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.99)), + bias_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.5)), ) ret = conv3dTranspose(input) return ret @@ -149,10 +149,10 @@ def dyfunc_Linear(input): input_dim=10, output_dim=5, act='relu', - param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( - value=0.99)), - bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant( - value=0.5)), ) + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.99)), + bias_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.5)), ) res = fc(input) return res -- GitLab