未验证 提交 4ff2915d 编写于 作者: A Aurelius84 提交者: GitHub

Refine code of IfElseTransformer and rename unittest files (#22930)

+ Refine code structure and move code related with control flow `if` into `ifelse_transformer.py`
+ Merge code of `ast_utils.py`  into `utils.py` 
上级 3d8571e8
......@@ -26,96 +26,20 @@ import astor
import gast
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api, is_dygraph_api, is_to_variable
from paddle.fluid.dygraph.dygraph_to_static.utils import to_assign_node, to_static_ast, update_args_of_func
from paddle.fluid.dygraph.dygraph_to_static.utils import dygraph_class_to_static_api, create_api_shape_node
from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import LoopTransformer
from .ast_utils import is_control_flow_if, create_cond_node, transform_if_else, ast_to_func
from .static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor
from .utils import *
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IfElseTransformer
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
__all__ = ['DygraphToStaticAst', 'convert_to_static']
DECORATOR_NAMES = ['dygraph_to_static_output', 'dygraph_to_static_graph']
class IfElseTransformer(gast.NodeTransformer):
"""
Transform if/else statement of Dygraph into Static Graph.
"""
def __init__(self, wrapper_root):
assert isinstance(
wrapper_root, AstNodeWrapper
), "Type of input node should be AstNodeWrapper, but received %s ." % type(
wrapper_root)
self.root = wrapper_root.node
self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
self.new_func_nodes = {}
def transform(self):
"""
Main function to transform AST.
"""
self.visit(self.root)
self.after_visit(self.root)
def visit_If(self, node):
assert isinstance(node, gast.If)
need_transform = is_control_flow_if(node.test,
self.static_analysis_visitor)
self.generic_visit(node)
if need_transform:
pred_node = node.test
true_func_node, false_func_node, return_name_ids = transform_if_else(
node, self.root)
# create layers.cond
new_node = create_cond_node(return_name_ids, pred_node,
true_func_node, false_func_node)
self.new_func_nodes[new_node] = [true_func_node, false_func_node]
return new_node
else:
return node
def visit_Call(self, node):
# Remove `numpy()` statement, like `Tensor.numpy()[i]` -> `Tensor[i]`
# TODO: should be removed. it may be considered as basic api transformation.
if isinstance(node.func, gast.Attribute):
attribute = node.func
if attribute.attr == 'numpy':
node = attribute.value
return node
def after_visit(self, node):
"""
This function will add some postprocessing operations with node.
It can be used to add the created `true_fn/false_fn` in front of
the node.body before they are called in cond layer.
"""
self._insert_func_nodes(node)
def _insert_func_nodes(self, parent_node):
"""
Defined `true_func` and `false_func` will be inserted in front of corresponding
`layers.cond` statement instead of inserting them all into body of parent node.
Because private variables of class or other external scope will be modified.
For example, `self.var_dict["key"]`. In this case, nested structure of newly
defined functions is easier to understand.
"""
if not (self.new_func_nodes and hasattr(parent_node, 'body')):
return
idx = len(parent_node.body) - 1
while idx >= 0:
child_node = parent_node.body[idx]
if child_node in self.new_func_nodes:
parent_node.body[idx:idx] = self.new_func_nodes[child_node]
idx = idx + len(self.new_func_nodes[child_node]) - 1
del self.new_func_nodes[child_node]
else:
self._insert_func_nodes(child_node)
idx = idx - 1
def get_new_func_nodes(self):
return self.new_func_nodes
class DygraphToStaticAst(gast.NodeTransformer):
"""
Main class to transform Dygraph to Static Graph
......
......@@ -14,25 +14,106 @@
from __future__ import print_function
import ast
import astor
import gast
import six
import copy
import tempfile
import imp
import os
import atexit
from collections import defaultdict
# gast is a generic AST to represent Python2 and Python3's Abstract Syntax Tree(AST).
# It provides a compatibility layer between the AST of various Python versions,
# as produced by ast.parse from the standard ast module.
# See details in https://github.com/serge-sans-paille/gast/
import gast
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType, StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api
from paddle.fluid.dygraph.dygraph_to_static.utils import create_funcDef_node
from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType
TRUE_FUNC_PREFIX = 'true_fn'
FALSE_FUNC_PREFIX = 'false_fn'
class IfElseTransformer(gast.NodeTransformer):
"""
Transform if/else statement of Dygraph into Static Graph.
"""
def __init__(self, wrapper_root):
assert isinstance(
wrapper_root, AstNodeWrapper
), "Type of input node should be AstNodeWrapper, but received %s ." % type(
wrapper_root)
self.root = wrapper_root.node
self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
self.new_func_nodes = {}
def transform(self):
"""
Main function to transform AST.
"""
self.visit(self.root)
self.after_visit(self.root)
def visit_If(self, node):
assert isinstance(node, gast.If)
need_transform = is_control_flow_if(node.test,
self.static_analysis_visitor)
self.generic_visit(node)
if need_transform:
pred_node = node.test
true_func_node, false_func_node, return_name_ids = transform_if_else(
node, self.root)
# create layers.cond
new_node = create_cond_node(return_name_ids, pred_node,
true_func_node, false_func_node)
self.new_func_nodes[new_node] = [true_func_node, false_func_node]
return new_node
else:
return node
def visit_Call(self, node):
# Remove `numpy()` statement, like `Tensor.numpy()[i]` -> `Tensor[i]`
# TODO: should be removed. it may be considered as basic api transformation.
if isinstance(node.func, gast.Attribute):
attribute = node.func
if attribute.attr == 'numpy':
node = attribute.value
return node
def after_visit(self, node):
"""
This function will add some postprocessing operations with node.
It can be used to add the created `true_fn/false_fn` in front of
the node.body before they are called in cond layer.
"""
self._insert_func_nodes(node)
def _insert_func_nodes(self, parent_node):
"""
Defined `true_func` and `false_func` will be inserted in front of corresponding
`layers.cond` statement instead of inserting them all into body of parent node.
Because private variables of class or other external scope will be modified.
For example, `self.var_dict["key"]`. In this case, nested structure of newly
defined functions is easier to understand.
"""
if not (self.new_func_nodes and hasattr(parent_node, 'body')):
return
idx = len(parent_node.body) - 1
while idx >= 0:
child_node = parent_node.body[idx]
if child_node in self.new_func_nodes:
parent_node.body[idx:idx] = self.new_func_nodes[child_node]
idx = idx + len(self.new_func_nodes[child_node]) - 1
del self.new_func_nodes[child_node]
else:
self._insert_func_nodes(child_node)
idx = idx - 1
def get_new_func_nodes(self):
return self.new_func_nodes
class IsControlFlowIfVisitor(gast.NodeTransformer):
"""
Judge whether the node.test from Dygraph code dependent on paddle Tensor.
......@@ -112,7 +193,8 @@ class IsControlFlowIfVisitor(gast.NodeTransformer):
if isinstance(node, gast.Compare):
for child in [node.left, node.comparators]:
# node.comparators is a list.
if isinstance(child, list): child = child[0]
if isinstance(child, list):
child = child[0]
if (isinstance(child, gast.Constant) and
child.value is None) or (
isinstance(child, gast.Name) and
......@@ -151,7 +233,8 @@ def get_name_ids(nodes, not_name_set=None, node_black_list=None):
name_ids = defaultdict(list)
for node in nodes:
if node_black_list and node in node_black_list: break
if node_black_list and node in node_black_list:
break
if isinstance(node, gast.AST):
# In two case, the ast.Name should be filtered.
# 1. Function name like `my_func` of my_func(x)
......@@ -271,45 +354,6 @@ def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict):
return return_ids, list(modified_vars - new_vars)
def generate_name_node(name_ids, ctx=gast.Load()):
"""
Generate list or gast.Tuple of ast.Name for Return statement.
"""
if isinstance(name_ids, six.string_types):
name_ids = [name_ids]
if not isinstance(name_ids, (list, tuple, set)):
raise TypeError('name_ids must be list or tuple or set, but received %s'
% type(type(name_ids)))
gast_names = [
gast.Name(
id=name_id, ctx=ctx, annotation=None, type_comment=None)
for name_id in name_ids
]
if len(gast_names) == 1:
name_node = gast_names[0]
else:
name_node = gast.Tuple(elts=gast_names, ctx=ctx)
return name_node
def create_funcDef_node(nodes, name, input_args, return_name_ids):
"""
Wrapper all statements of nodes into one ast.FunctionDef, which can be
called by ast.Call.
"""
nodes = copy.copy(nodes)
# add return statement
nodes.append(gast.Return(value=generate_name_node(return_name_ids)))
func_def_node = gast.FunctionDef(
name=name,
args=input_args,
body=nodes,
decorator_list=[],
returns=None,
type_comment=None)
return func_def_node
def transform_if_else(node, root):
"""
Transform ast.If into control flow statement of Paddle static graph.
......@@ -384,43 +428,3 @@ def create_cond_node(return_name_ids, pred, true_func, false_func):
assign_node = gast.Assign(targets=targets, value=cond_layer)
return assign_node
def ast_to_func(ast_root, func_name, delete_on_exit=True):
"""
Transform modified AST of decorated function into python callable object.
"""
if not isinstance(ast_root, (gast.AST, ast.AST)):
raise TypeError(
"Type of ast_root should be gast.AST or ast.AST, but received %s." %
type(ast_root))
if isinstance(ast_root, gast.AST):
ast_root = gast.gast_to_ast(ast_root)
source = astor.to_source(ast_root)
if six.PY2:
source = source.encode('utf-8')
f = tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False)
else:
f = tempfile.NamedTemporaryFile(
mode='w', suffix='.py', delete=False, encoding='utf-8')
# TODO(Aurelius84): more elegant way to transform ast into callable object
import_str = "import paddle\n" \
"import paddle.fluid as fluid\n" \
"import paddle.fluid.layers as layers\n" \
"import numpy as np\n" \
"import numpy\n"
with f:
module_name = os.path.basename(f.name[:-3])
f.write(import_str)
f.write(source)
if delete_on_exit:
atexit.register(lambda: os.remove(f.name))
module = imp.load_source(module_name, f.name)
if not hasattr(module, func_name):
raise ValueError(
'Function: %s doesn\'t exist in the Module transformed from AST.' %
func_name)
return getattr(module, func_name), f.name
......@@ -19,8 +19,7 @@ import gast
from collections import defaultdict
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.ast_utils import create_funcDef_node
from paddle.fluid.dygraph.dygraph_to_static.ast_utils import generate_name_node
from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable_gast_node
......
......@@ -14,10 +14,7 @@
from __future__ import print_function
import astor
import gast
import inspect
import six
import warnings
from .utils import is_paddle_api, is_dygraph_api, is_numpy_api
......
......@@ -14,9 +14,16 @@
from __future__ import print_function
import inspect
import gast
import ast
import astor
import atexit
import copy
import gast
import imp
import inspect
import os
import six
import tempfile
dygraph_class_to_static_api = {
"CosineDecay": "cosine_decay",
......@@ -206,3 +213,82 @@ def create_api_shape_node(tensor_shape_node):
args=[tensor_shape_node.value],
keywords=[])
return api_shape_node
def generate_name_node(name_ids, ctx=gast.Load()):
"""
Generate list or gast.Tuple of ast.Name for Return statement.
"""
if isinstance(name_ids, six.string_types):
name_ids = [name_ids]
if not isinstance(name_ids, (list, tuple, set)):
raise TypeError('name_ids must be list or tuple or set, but received %s'
% type(type(name_ids)))
gast_names = [
gast.Name(
id=name_id, ctx=ctx, annotation=None, type_comment=None)
for name_id in name_ids
]
if len(gast_names) == 1:
name_node = gast_names[0]
else:
name_node = gast.Tuple(elts=gast_names, ctx=ctx)
return name_node
def create_funcDef_node(nodes, name, input_args, return_name_ids):
"""
Wrapper all statements of nodes into one ast.FunctionDef, which can be
called by ast.Call.
"""
nodes = copy.copy(nodes)
# add return statement
nodes.append(gast.Return(value=generate_name_node(return_name_ids)))
func_def_node = gast.FunctionDef(
name=name,
args=input_args,
body=nodes,
decorator_list=[],
returns=None,
type_comment=None)
return func_def_node
def ast_to_func(ast_root, func_name, delete_on_exit=True):
"""
Transform modified AST of decorated function into python callable object.
"""
if not isinstance(ast_root, (gast.AST, ast.AST)):
raise TypeError(
"Type of ast_root should be gast.AST or ast.AST, but received %s." %
type(ast_root))
if isinstance(ast_root, gast.AST):
ast_root = gast.gast_to_ast(ast_root)
source = astor.to_source(ast_root)
if six.PY2:
source = source.encode('utf-8')
f = tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False)
else:
f = tempfile.NamedTemporaryFile(
mode='w', suffix='.py', delete=False, encoding='utf-8')
# TODO(Aurelius84): more elegant way to transform ast into callable object
import_str = "import paddle\n" \
"import paddle.fluid as fluid\n" \
"import paddle.fluid.layers as layers\n" \
"import numpy as np\n" \
"import numpy\n"
with f:
module_name = os.path.basename(f.name[:-3])
f.write(import_str)
f.write(source)
if delete_on_exit:
atexit.register(lambda: os.remove(f.name))
module = imp.load_source(module_name, f.name)
if not hasattr(module, func_name):
raise ValueError(
'Function: %s doesn\'t exist in the Module transformed from AST.' %
func_name)
return getattr(module, func_name), f.name
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import textwrap
import gast
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import get_name_ids, is_control_flow_if
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
class TestGetNameIds(unittest.TestCase):
"""
Test for parsing the ast.Name list from the ast.Nodes
"""
def setUp(self):
self.source = """
def test_fn(x):
return x+1
"""
self.all_name_ids = {'x': [gast.Param()]}
def test_get_name_ids(self):
source = textwrap.dedent(self.source)
root = gast.parse(source)
all_name_ids = get_name_ids([root])
self.assertDictEqual(
self.transfer_dict(self.all_name_ids),
self.transfer_dict(all_name_ids))
def transfer_dict(self, name_ids_dict):
new_dict = {}
for name, ctxs in name_ids_dict.items():
new_dict[name] = [type(ctx) for ctx in ctxs]
return new_dict
class TestGetNameIds2(TestGetNameIds):
def setUp(self):
self.source = """
def test_fn(x, y):
a = 1
x = y + a
if x > y:
z = x * x
z = z + a
else:
z = y * y
return z
"""
self.all_name_ids = {
'x': [
gast.Param(), gast.Store(), gast.Load(), gast.Load(),
gast.Load()
],
'a': [gast.Store(), gast.Load(), gast.Load()],
'y':
[gast.Param(), gast.Load(), gast.Load(), gast.Load(), gast.Load()],
'z': [gast.Store(), gast.Load(), gast.Store(), gast.Store()]
}
class TestGetNameIds3(TestGetNameIds):
def setUp(self):
self.source = """
def test_fn(x, y):
z = 1
if x > y:
z = x * x
z = z + y
return z
"""
self.all_name_ids = {
'x': [gast.Param(), gast.Load(), gast.Load(), gast.Load()],
'y': [gast.Param(), gast.Load(), gast.Load()],
'z': [gast.Store(), gast.Store(), gast.Load(), gast.Store()]
}
class TestIsControlFlowIf(unittest.TestCase):
def test_expr(self):
# node is not ast.Compare
node = gast.parse("a + b")
self.assertFalse(is_control_flow_if(node.body[0].value))
def test_expr2(self):
node = gast.parse("a + x.numpy()[1]")
self.assertFalse(is_control_flow_if(node.body[0].value))
def test_is_None(self):
node = gast.parse("x is None")
self.assertFalse(is_control_flow_if(node.body[0].value))
def test_is_None2(self):
node = gast.parse("fluid.layers.sum(x) is None")
self.assertFalse(is_control_flow_if(node.body[0].value))
def test_is_None3(self):
node = gast.parse("fluid.layers.sum(x).numpy() != None")
self.assertFalse(is_control_flow_if(node.body[0].value))
def test_if(self):
node = gast.parse("x.numpy()[1] > 1")
self.assertTrue(is_control_flow_if(node.body[0].value))
def test_if_with_and(self):
node = gast.parse("x is not None and 1 < x.numpy()[1]")
self.assertTrue(is_control_flow_if(node.body[0].value))
def test_if_with_or(self):
node = gast.parse("1 < fluid.layers.sum(x).numpy()[2] or x+y < 0")
self.assertTrue(is_control_flow_if(node.body[0].value))
def test_shape(self):
code = """
def foo(x):
batch_size = fluid.layers.shape(x)
if batch_size[0] > 16:
x = x + 1
return x
"""
code = textwrap.dedent(code)
node = gast.parse(code)
visitor = StaticAnalysisVisitor(node)
test_node = node.body[0].body[1].test
self.assertTrue(is_control_flow_if(test_node, visitor))
def test_shape_with_andOr(self):
code = """
def foo(x):
batch_size = fluid.layers.shape(x)
if x is not None and batch_size[0] > 16 or 2 > 1:
x = x + 1
return x
"""
code = textwrap.dedent(code)
node = gast.parse(code)
visitor = StaticAnalysisVisitor(node)
test_node = node.body[0].body[1].test
self.assertTrue(is_control_flow_if(test_node, visitor))
def test_paddle_api(self):
code = """
def foo(x):
if fluid.layers.shape(x)[0] > 16:
x = x + 1
return x
"""
code = textwrap.dedent(code)
node = gast.parse(code)
visitor = StaticAnalysisVisitor(node)
test_node = node.body[0].body[0].test
self.assertTrue(is_control_flow_if(test_node, visitor))
def test_paddle_api_with_andOr(self):
code = """
def foo(x):
if 2 > 1 and fluid.layers.shape(x)[0] > 16 or x is not None :
x = x + 1
return x
"""
code = textwrap.dedent(code)
node = gast.parse(code)
visitor = StaticAnalysisVisitor(node)
test_node = node.body[0].body[0].test
self.assertTrue(is_control_flow_if(test_node, visitor))
def test_raise_error(self):
node = "a + b"
with self.assertRaises(Exception) as e:
self.assertRaises(TypeError, is_control_flow_if(node))
self.assertTrue(
"Type of input node should be gast.AST" in str(e.exception))
if __name__ == '__main__':
unittest.main()
......@@ -20,177 +20,11 @@ import gast
import inspect
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static.ast_utils import get_name_ids, ast_to_func, is_control_flow_if
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func
from test_dygraph_to_static_basic import dyfunc_with_if_else, dyfunc_with_if_else2, nested_if_else
class TestGetNameIds(unittest.TestCase):
"""
Test for parsing the ast.Name list from the ast.Nodes
"""
def setUp(self):
self.source = """
def test_fn(x):
return x+1
"""
self.all_name_ids = {'x': [gast.Param()]}
def test_get_name_ids(self):
source = textwrap.dedent(self.source)
root = gast.parse(source)
all_name_ids = get_name_ids([root])
self.assertDictEqual(
self.transfer_dict(self.all_name_ids),
self.transfer_dict(all_name_ids))
def transfer_dict(self, name_ids_dict):
new_dict = {}
for name, ctxs in name_ids_dict.items():
new_dict[name] = [type(ctx) for ctx in ctxs]
return new_dict
class TestGetNameIds2(TestGetNameIds):
def setUp(self):
self.source = """
def test_fn(x, y):
a = 1
x = y + a
if x > y:
z = x * x
z = z + a
else:
z = y * y
return z
"""
self.all_name_ids = {
'x': [
gast.Param(), gast.Store(), gast.Load(), gast.Load(),
gast.Load()
],
'a': [gast.Store(), gast.Load(), gast.Load()],
'y':
[gast.Param(), gast.Load(), gast.Load(), gast.Load(), gast.Load()],
'z': [gast.Store(), gast.Load(), gast.Store(), gast.Store()]
}
class TestGetNameIds3(TestGetNameIds):
def setUp(self):
self.source = """
def test_fn(x, y):
z = 1
if x > y:
z = x * x
z = z + y
return z
"""
self.all_name_ids = {
'x': [gast.Param(), gast.Load(), gast.Load(), gast.Load()],
'y': [gast.Param(), gast.Load(), gast.Load()],
'z': [gast.Store(), gast.Store(), gast.Load(), gast.Store()]
}
class TestIsControlFlowIf(unittest.TestCase):
def test_expr(self):
# node is not ast.Compare
node = gast.parse("a + b")
self.assertFalse(is_control_flow_if(node.body[0].value))
def test_expr2(self):
node = gast.parse("a + x.numpy()[1]")
self.assertFalse(is_control_flow_if(node.body[0].value))
def test_is_None(self):
node = gast.parse("x is None")
self.assertFalse(is_control_flow_if(node.body[0].value))
def test_is_None2(self):
node = gast.parse("fluid.layers.sum(x) is None")
self.assertFalse(is_control_flow_if(node.body[0].value))
def test_is_None3(self):
node = gast.parse("fluid.layers.sum(x).numpy() != None")
self.assertFalse(is_control_flow_if(node.body[0].value))
def test_if(self):
node = gast.parse("x.numpy()[1] > 1")
self.assertTrue(is_control_flow_if(node.body[0].value))
def test_if_with_and(self):
node = gast.parse("x is not None and 1 < x.numpy()[1]")
self.assertTrue(is_control_flow_if(node.body[0].value))
def test_if_with_or(self):
node = gast.parse("1 < fluid.layers.sum(x).numpy()[2] or x+y < 0")
self.assertTrue(is_control_flow_if(node.body[0].value))
def test_shape(self):
code = """
def foo(x):
batch_size = fluid.layers.shape(x)
if batch_size[0] > 16:
x = x + 1
return x
"""
code = textwrap.dedent(code)
node = gast.parse(code)
visitor = StaticAnalysisVisitor(node)
test_node = node.body[0].body[1].test
self.assertTrue(is_control_flow_if(test_node, visitor))
def test_shape_with_andOr(self):
code = """
def foo(x):
batch_size = fluid.layers.shape(x)
if x is not None and batch_size[0] > 16 or 2 > 1:
x = x + 1
return x
"""
code = textwrap.dedent(code)
node = gast.parse(code)
visitor = StaticAnalysisVisitor(node)
test_node = node.body[0].body[1].test
self.assertTrue(is_control_flow_if(test_node, visitor))
def test_paddle_api(self):
code = """
def foo(x):
if fluid.layers.shape(x)[0] > 16:
x = x + 1
return x
"""
code = textwrap.dedent(code)
node = gast.parse(code)
visitor = StaticAnalysisVisitor(node)
test_node = node.body[0].body[0].test
self.assertTrue(is_control_flow_if(test_node, visitor))
def test_paddle_api_with_andOr(self):
code = """
def foo(x):
if 2 > 1 and fluid.layers.shape(x)[0] > 16 or x is not None :
x = x + 1
return x
"""
code = textwrap.dedent(code)
node = gast.parse(code)
visitor = StaticAnalysisVisitor(node)
test_node = node.body[0].body[0].test
self.assertTrue(is_control_flow_if(test_node, visitor))
def test_raise_error(self):
node = "a + b"
with self.assertRaises(Exception) as e:
self.assertRaises(TypeError, is_control_flow_if(node))
self.assertTrue(
"Type of input node should be gast.AST" in str(e.exception))
class TestAST2Func(unittest.TestCase):
"""
TestCase for the transformation from ast.AST into python callable function.
......@@ -211,7 +45,7 @@ class TestAST2Func(unittest.TestCase):
self.assertEqual(func(x, y), self._ast2func(func)(x, y))
def test_ast2func_dygraph(self):
funcs = [dyfunc_with_if_else, dyfunc_with_if_else, nested_if_else]
funcs = [dyfunc_with_if_else, dyfunc_with_if_else2, nested_if_else]
x_data = np.random.random([10, 16]).astype('float32')
for func in funcs:
with fluid.dygraph.guard():
......
......@@ -79,10 +79,10 @@ def dyfunc_BilinearTensorProduct(layer1, layer2):
input1_dim=5,
input2_dim=4,
output_dim=1000,
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(
value=0.99)),
bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(
value=0.5)))
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.99)),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.5)))
res = bilinearTensorProduct(
fluid.dygraph.base.to_variable(layer1),
......@@ -95,10 +95,10 @@ def dyfunc_Conv2D(input):
num_channels=3,
num_filters=2,
filter_size=3,
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(
value=0.99)),
bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(
value=0.5)), )
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.99)),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.5)), )
res = conv2d(input)
return res
......@@ -108,10 +108,10 @@ def dyfunc_Conv3D(input):
num_channels=3,
num_filters=2,
filter_size=3,
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(
value=0.99)),
bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(
value=0.5)), )
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.99)),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.5)), )
res = conv3d(input)
return res
......@@ -122,10 +122,10 @@ def dyfunc_Conv2DTranspose(input):
num_filters=12,
filter_size=12,
use_cudnn=False,
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(
value=0.99)),
bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(
value=0.5)), )
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.99)),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.5)), )
ret = conv2dTranspose(input)
return ret
......@@ -136,10 +136,10 @@ def dyfunc_Conv3DTranspose(input):
num_filters=12,
filter_size=12,
use_cudnn=False,
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(
value=0.99)),
bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(
value=0.5)), )
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.99)),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.5)), )
ret = conv3dTranspose(input)
return ret
......@@ -149,10 +149,10 @@ def dyfunc_Linear(input):
input_dim=10,
output_dim=5,
act='relu',
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(
value=0.99)),
bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(
value=0.5)), )
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.99)),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.5)), )
res = fc(input)
return res
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册