未验证 提交 fb7b008a 编写于 作者: H Huihuang Zheng 提交者: GitHub

Add Support for Break and Continue in Dygraph to Static (#23067)

1. Add support for Break and Continue in Dygraph to Static
2. Also add support for gast.Not in NodeTestTransformer
3. Also add support for logical op transformation in LoopTransformer
上级 853f2e52
......@@ -14,28 +14,31 @@
from __future__ import print_function
import copy
import inspect
import textwrap
import astor
import copy
# gast is a generic AST to represent Python2 and Python3's Abstract Syntax Tree(AST).
# It provides a compatibility layer between the AST of various Python versions,
# as produced by ast.parse from the standard ast module.
# See details in https://github.com/serge-sans-paille/gast/
import gast
import inspect
import textwrap
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func
from paddle.fluid.dygraph.dygraph_to_static.utils import is_dygraph_api, is_to_variable
from paddle.fluid.dygraph.dygraph_to_static.utils import to_assign_node, to_static_ast, update_args_of_func
from paddle.fluid.dygraph.dygraph_to_static.utils import dygraph_class_to_static_api
from paddle.fluid.dygraph.dygraph_to_static.tensor_shape_transformer import TensorShapeTransformer
from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import BreakContinueTransformer
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IfElseTransformer
from paddle.fluid.dygraph.dygraph_to_static.list_transformer import ListTransformer
from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import LoopTransformer
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IfElseTransformer
from paddle.fluid.dygraph.dygraph_to_static.tensor_shape_transformer import TensorShapeTransformer
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api, is_dygraph_api, is_to_variable
from paddle.fluid.dygraph.dygraph_to_static.utils import to_assign_node, to_static_ast, update_args_of_func
from paddle.fluid.dygraph.dygraph_to_static.utils import dygraph_class_to_static_api
__all__ = ['DygraphToStaticAst', 'convert_to_static']
......@@ -74,12 +77,15 @@ class DygraphToStaticAst(gast.NodeTransformer):
# Transform list used in control flow
ListTransformer(node_wrapper).transform()
# Transform all if/else statement of Dygraph into Static Graph.
IfElseTransformer(node_wrapper).transform()
# Transform break/continue in loops
BreakContinueTransformer(node_wrapper).transform()
# Transform for loop and while loop
LoopTransformer(node_wrapper).transform()
# Transform all if/else statement of Dygraph into Static Graph.
IfElseTransformer(node_wrapper).transform()
def visit_FunctionDef(self, node):
if self.decorate_func_name is None:
self.decorate_func_name = node.name
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import gast
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import NodeTestTransformer
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import get_constant_variable_node
from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node
__all__ = ['BreakContinueTransformer']
BREAK_NAME_PREFIX = '__break'
CONTINUE_NAME_PREFIX = '__continue'
class ForToWhileTransformer(gast.NodeTransformer):
"""
Transform python for loop into while loop and add condition node in the
loop test
"""
def __init__(self, parent_node, loop_node, condition_node):
assert isinstance(
loop_node,
gast.For), "loop_node is not gast.For in ForToWhileTransformer"
self.parent_node = parent_node
self.loop_node = loop_node
self.condition_node = condition_node
def transform(self):
if hasattr(self.parent_node, 'body'):
body_list = self.parent_node.body
i = index_in_list(body_list, self.loop_node)
if i != -1:
new_stmts = self.get_for_stmt_nodes(body_list[i])
body_list[i:i + 1] = new_stmts
i += len(new_stmts)
return
if hasattr(self.parent_node, 'orelse'):
body_list = self.parent_node.orelse
i = index_in_list(body_list, self.loop_node)
if i != -1:
new_stmts = self.get_for_stmt_nodes(body_list[i])
body_list[i:i + 1] = new_stmts
i += len(new_stmts)
return
raise ValueError(
"parent_node doesn't contain the loop_node in ForToWhileTransformer")
def get_for_range_node(self, node):
if not isinstance(node.iter, gast.Call):
return None
if not isinstance(node.iter.func, gast.Name):
return None
if node.iter.func.id != "range":
return None
return node.iter
def get_for_args_stmts(self, iter_name, args_list):
'''
Returns 3 gast stmt nodes for argument.
1. Initailize of iterate variable
2. Condition for the loop
3. Statement for changing of iterate variable during the loop
'''
len_range_args = len(args_list)
assert len_range_args >= 1 and len_range_args <= 3, "range() function takes 1 to 3 arguments"
if len_range_args == 1:
init_stmt = get_constant_variable_node(iter_name, 0)
else:
init_stmt = gast.Assign(
targets=[
gast.Name(
id=iter_name,
ctx=gast.Store(),
annotation=None,
type_comment=None)
],
value=args_list[0])
range_max_node = args_list[0] if len_range_args == 1 else args_list[1]
step_node = args_list[2] if len_range_args == 3 else gast.Constant(
value=1, kind=None)
old_cond_stmt = gast.Compare(
left=gast.BinOp(
left=gast.Name(
id=iter_name,
ctx=gast.Load(),
annotation=None,
type_comment=None),
op=gast.Add(),
right=step_node),
ops=[gast.LtE()],
comparators=[range_max_node])
cond_stmt = gast.BoolOp(
op=gast.And(), values=[old_cond_stmt, self.condition_node])
change_stmt = gast.AugAssign(
target=gast.Name(
id=iter_name,
ctx=gast.Store(),
annotation=None,
type_comment=None),
op=gast.Add(),
value=step_node)
return init_stmt, cond_stmt, change_stmt
def get_for_stmt_nodes(self, node):
assert isinstance(
node, gast.For), "Input node is NOT gast.For in get_for_stmt_nodes"
# TODO: support non-range case
range_call_node = self.get_for_range_node(node)
if range_call_node is None:
return [node]
if not isinstance(node.target, gast.Name):
return [node]
iter_var_name = node.target.id
init_stmt, cond_stmt, change_stmt = self.get_for_args_stmts(
iter_var_name, range_call_node.args)
new_body = node.body
new_body.append(change_stmt)
while_node = gast.While(
test=cond_stmt, body=new_body, orelse=node.orelse)
return [init_stmt, while_node]
class BreakContinueTransformer(gast.NodeTransformer):
"""
Rewrite 'break' and 'continue' key words in a if-else python way to make
it equivalent to original control flow
The main idea of this class is:
1. Map the 'break/continue' stmt with an unique boolean variable V.
2. Find the first ancestor block containing this 'break/continue', a
block can be a node containing stmt list. We should remove all stmts
after the 'break/continue' and set the V to True here.
3. Add 'if V' for stmts in ancestor blocks between the first one
(exclusive) and the ancestor loop (inclusive)
4. For 'break' add break into condition of the loop. For 'continue',
set continue to False at the beginning of each loop
TODO: more details should be summarized as design document
"""
def __init__(self, wrapper_root):
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
self.ancestor_nodes = []
def transform(self):
self.visit(self.root)
def generic_visit(self, node):
# TODO: because we change ancestor nodes during visit_Break/Continue,
# not current node, so generic_visit of NodeTransformer will visit node
# which may be deleted. To prevent that node being added into
# transformed AST, I have to self-write a generic_visit, but this is
# NOT a good thing. Considering refactorying this whole class.
for field, value in gast.iter_fields(node):
if isinstance(value, list):
for item in value:
if isinstance(item, gast.AST):
self.visit(item)
elif isinstance(value, gast.AST):
self.visit(value)
def visit(self, node):
self.ancestor_nodes.append(node)
method = 'visit_' + node.__class__.__name__
visitor = getattr(self, method, self.generic_visit)
ret = visitor(node)
self.ancestor_nodes.pop()
return ret
def visit_Break(self, node):
loop_node_index = self._find_ancestor_loop_index(node)
assert loop_node_index != -1, "SyntaxError: 'break' outside loop"
loop_node = self.ancestor_nodes[loop_node_index]
# 1. Map the 'break/continue' stmt with an unique boolean variable V.
variable_name = unique_name.generate(BREAK_NAME_PREFIX)
# 2. Find the first ancestor block containing this 'break/continue', a
# block can be a node containing stmt list. We should remove all stmts
# after the 'break/continue' and set the V to True here.
first_block_index = self._remove_stmts_after_break_continue(
node, variable_name, loop_node_index)
# 3. Add 'if V' for stmts in ancestor blocks between the first one
# (exclusive) and the ancestor loop (inclusive)
self._replace_if_stmt(loop_node_index, first_block_index, variable_name)
# 4. For 'break' add break into condition of the loop.
assign_false_node = create_fill_constant_node(variable_name, False)
self._add_stmt_before_cur_node(loop_node_index, assign_false_node)
cond_var_node = gast.UnaryOp(
op=gast.Not(),
operand=gast.Name(
id=variable_name,
ctx=gast.Load(),
annotation=None,
type_comment=None))
if isinstance(loop_node, gast.While):
loop_node.test = gast.BoolOp(
op=gast.And(), values=[loop_node.test, cond_var_node])
elif isinstance(loop_node, gast.For):
parent_node = self.ancestor_nodes[loop_node_index - 1]
for_to_while = ForToWhileTransformer(parent_node, loop_node,
cond_var_node)
for_to_while.transform()
def visit_Continue(self, node):
loop_node_index = self._find_ancestor_loop_index(node)
assert loop_node_index != -1, "SyntaxError: 'continue' outside loop"
loop_node = self.ancestor_nodes[loop_node_index]
# 1. Map the 'break/continue' stmt with an unique boolean variable V.
variable_name = unique_name.generate(CONTINUE_NAME_PREFIX)
# 2. Find the first ancestor block containing this 'break/continue', a
# block can be a node containing stmt list. We should remove all stmts
# after the 'break/continue' and set the V to True here.
first_block_index = self._remove_stmts_after_break_continue(
node, variable_name, loop_node_index)
# 3. Add 'if V' for stmts in ancestor blocks between the first one
# (exclusive) and the ancestor loop (inclusive)
self._replace_if_stmt(loop_node_index, first_block_index, variable_name)
# 4. For 'continue', set continue to False at the beginning of each loop
assign_false_node = create_fill_constant_node(variable_name, False)
loop_node.body.insert(0, assign_false_node)
def _remove_stmts_after_break_continue(
self, break_continue_node, break_continue_name, loop_node_index):
for first_block_index in range(
len(self.ancestor_nodes) - 1, loop_node_index - 1, -1):
first_block = self.ancestor_nodes[first_block_index]
if hasattr(first_block,
"body") and self._replace_break_continue_in_stmt_list(
first_block.body, break_continue_node,
break_continue_name):
return first_block_index
if hasattr(first_block,
"orelse") and self._replace_break_continue_in_stmt_list(
first_block.orelse, break_continue_node,
break_continue_name):
return first_block_index
return first_block_index
def _replace_break_continue_in_stmt_list(
self, stmt_list, break_continue_node, break_continue_name):
i = index_in_list(stmt_list, break_continue_node)
if i == -1:
return False
assign_true_node = create_fill_constant_node(break_continue_name, True)
stmt_list[i:] = [assign_true_node]
return True
def _replace_if_stmt(self, loop_node_index, first_block_index,
break_continue_name):
for i in range(first_block_index - 1, loop_node_index - 1, -1):
cur_node = self.ancestor_nodes[i]
son_node = self.ancestor_nodes[i + 1]
if hasattr(cur_node,
'body') and self._replace_after_node_to_if_in_stmt_list(
cur_node.body, son_node, break_continue_name):
continue
if hasattr(
cur_node,
'orelse') and self._replace_after_node_to_if_in_stmt_list(
cur_node.orelse, son_node, break_continue_name):
continue
def _replace_after_node_to_if_in_stmt_list(self, stmt_list, node,
break_continue_name):
i = index_in_list(stmt_list, node)
if i == -1:
return False
if i == len(stmt_list) - 1:
# No need to add, we consider this as added successfully
return True
if_stmt = gast.If(test=gast.UnaryOp(
op=gast.Not(),
operand=gast.Name(
id=break_continue_name,
ctx=gast.Store(),
annotation=None,
type_comment=None)),
body=stmt_list[i + 1:],
orelse=[])
stmt_list[i + 1:] = []
stmt_list.append(if_stmt)
return True
def _add_stmt_before_cur_node(self, cur_node_index, stmt_node):
cur_node = self.ancestor_nodes[cur_node_index]
parent_node = self.ancestor_nodes[cur_node_index - 1]
if hasattr(parent_node,
"body") and self._add_stmt_into_list_before_node(
parent_node.body, cur_node, stmt_node):
return True
if hasattr(parent_node,
"orelse") and self._add_stmt_into_list_before_node(
parent_node.orelse, cur_node, stmt_node):
return True
return False
def _add_stmt_into_list_before_node(self, stmt_list, node, stmt_node):
i = index_in_list(stmt_list, node)
if i == -1:
return False
stmt_list.insert(i, stmt_node)
return True
def _find_ancestor_loop_index(self, node):
for i in range(len(self.ancestor_nodes) - 1, -1, -1):
if isinstance(self.ancestor_nodes[i], (gast.For, gast.While)):
return i
return -1
......@@ -310,8 +310,8 @@ class AutoTracer(object):
if not isinstance(loss_name, six.string_types):
raise ValueError(
"Type of input loss_name should type(str), but received {}."
.format(type(loss_name)))
"Type of input loss_name should type(str), but received {}.".
format(type(loss_name)))
self._loss_name = loss_name
def _add_optimizer(self):
......
......@@ -36,6 +36,7 @@ TRUE_FUNC_PREFIX = 'true_fn'
FALSE_FUNC_PREFIX = 'false_fn'
LOGIC_AND_PREFIX = 'logic_and'
LOGIC_OR_PREFIX = 'logic_or'
LOGIC_NOT_PREFIX = 'logic_not'
PLAIN_TENSOR_PREFIX = 'bool_tensor'
......@@ -129,7 +130,7 @@ def is_candidate_node(node):
"""
Nodes with specified type will be dependent on tensor.
"""
return isinstance(node, (gast.Compare, gast.BoolOp))
return isinstance(node, (gast.Compare, gast.BoolOp, gast.UnaryOp))
def compare_with_none(node):
......@@ -268,6 +269,21 @@ class NodeTestTransformer(gast.NodeTransformer):
def transform(self):
return self.visit(self.ast_root)
def visit_UnaryOp(self, node):
self.generic_visit(node)
if isinstance(node.op, gast.Not):
arg = ast_to_source_code(node.operand)
new_node_str = "fluid.layers.logical_not({})".format(arg)
# gast.parse returns Module(body=[expr(value=...)])
new_node = gast.parse(new_node_str).body[0].value
logic_tensor_name = unique_name.generate(LOGIC_NOT_PREFIX)
assign_name, assign_node = create_assign_node(logic_tensor_name,
new_node)
self._new_assign_nodes.append(assign_node)
return assign_name
return node
def visit_BoolOp(self, node):
for i, child in enumerate(node.values):
if not is_candidate_node(child):
......
......@@ -19,8 +19,10 @@ import gast
from collections import defaultdict
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import get_constant_variable_node
from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node
......@@ -62,6 +64,55 @@ def create_while_node(condition_name, body_name, loop_var_names):
return assign_node
class LogicalOpTransformer(gast.NodeTransformer):
"""
Transform python boolean op into Paddle logical op
"""
def __init__(self, node):
self.root = node
def transform(self):
return self.visit(self.root)
def visit_UnaryOp(self, node):
self.generic_visit(node)
if isinstance(node.op, gast.Not):
arg = ast_to_source_code(node.operand)
new_node_str = "fluid.layers.logical_not({})".format(arg)
# gast.parse returns Module(body=[expr(value=...)])
new_node = gast.parse(new_node_str).body[0].value
return new_node
return node
def visit_BoolOp(self, node):
self.generic_visit(node)
if isinstance(node.op, gast.And):
new_node = self._create_bool_op_node(node.values, 'and')
elif isinstance(node.op, gast.Or):
new_node = self._create_bool_op_node(node.values, 'or')
else:
raise TypeError(
"Only supports and/or syntax in control flow if statement.")
return new_node
def _create_bool_op_node(self, nodes, api_type):
assert len(
nodes
) > 1, "The length of BoolOp should be at least 2, but received {}.".format(
len(nodes))
if len(nodes) > 2:
# Creates logic_and/logic_or node recursively.
pre_assign_node = self._create_bool_op_node(nodes[:2], api_type)
nodes = [pre_assign_node] + nodes[2:]
args = [ast_to_source_code(child) for child in nodes]
new_node_str = "fluid.layers.logical_{}(x={}, y={})".format(
api_type, args[0], args[1])
# gast.parse return Module(body=[expr(...)])
new_node = gast.parse(new_node_str).body[0].value
return new_node
class NameVisitor(gast.NodeVisitor):
'''
Analysis name liveness for loop transformer
......@@ -89,8 +140,8 @@ class NameVisitor(gast.NodeVisitor):
return True
def get_loop_var_names(self, node):
assert isinstance(node, (gast.While,
gast.For)), "Input node is not gast loop node"
assert isinstance(
node, (gast.While, gast.For)), "Input node is not gast loop node"
loop_var_names = set()
create_var_names = set()
read_context = {type(gast.Load()), type(gast.AugLoad())}
......@@ -118,6 +169,9 @@ class NameVisitor(gast.NodeVisitor):
if self._is_call_func_name_node(node):
self.generic_visit(node)
return
if node.id == "False" or node.id == "True":
self.generic_visit(node)
return
self.current_seen_vars.add(node)
for loop_node in self.current_loop:
......@@ -390,6 +444,9 @@ class LoopTransformer(gast.NodeTransformer):
for name in loop_var_names:
new_stmts.append(to_static_variable_gast_node(name))
logical_op_transformer = LogicalOpTransformer(node.test)
cond_value_node = logical_op_transformer.transform()
condition_func_node = gast.FunctionDef(
name=unique_name.generate(WHILE_CONDITION_PREFIX),
args=gast.arguments(
......@@ -406,7 +463,7 @@ class LoopTransformer(gast.NodeTransformer):
kw_defaults=None,
kwarg=None,
defaults=[]),
body=[gast.Return(value=node.test)],
body=[gast.Return(value=cond_value_node)],
decorator_list=[],
returns=None,
type_comment=None)
......
......@@ -299,6 +299,14 @@ def create_funcDef_node(nodes, name, input_args, return_name_ids):
return func_def_node
def index_in_list(array_list, item):
try:
return array_list.index(item)
except ValueError:
# Item not in array_list
return -1
def ast_to_func(ast_root, func_name, delete_on_exit=True):
"""
Transform modified AST of decorated function into python callable object.
......
......@@ -25,13 +25,35 @@ __all__ = ['to_static_variable_gast_node', 'create_static_variable_gast_node']
def to_static_variable_gast_node(name):
func_code = "{} = fluid.dygraph.dygraph_to_static.variable_trans_func.to_static_variable({})".format(
name, name)
return gast.parse(func_code)
return gast.parse(func_code).body[0]
def create_static_variable_gast_node(name):
func_code = "{} = fluid.layers.data(name='{}', shape=[-1], dtype='float32')".format(
name, name)
return gast.parse(func_code)
return gast.parse(func_code).body[0]
def create_fill_constant_node(name, value):
func_code = "{} = fluid.layers.fill_constant(shape=[1], ".format(name)
if isinstance(value, bool):
func_code += "dtype='bool', value={})".format(value)
return gast.parse(func_code).body[0]
if isinstance(value, float):
func_code += "dtype='float64', value={})".format(value)
return gast.parse(func_code).body[0]
if six.PY2:
if isinstance(value, int):
func_code += "dtype='int32', value={})".format(value)
return gast.parse(func_code).body[0]
if isinstance(value, long):
func_code += "dtype='int64', value={})".format(value)
return gast.parse(func_code).body[0]
else:
if isinstance(value, int):
func_code += "dtype='int64', value={})".format(value)
return gast.parse(func_code).body[0]
def to_static_variable(x):
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.jit import dygraph_to_static_graph
SEED = 2020
np.random.seed(SEED)
def test_continue_in_for(x):
x = fluid.dygraph.to_variable(x)
for i in range(10):
x += 1
if i > 5:
continue
x += 10086
x += i
return x
def test_continue_in_for_at_end(x):
x = fluid.dygraph.to_variable(x)
for i in range(10):
x += 1
if i > 5:
continue
return x
def test_continue_in_while(x):
x = fluid.dygraph.to_variable(x)
i = fluid.layers.fill_constant(shape=[1], dtype='int32', value=0)
while i < 10:
i += 1
if i > 5:
continue
x += 10086
x += i
return x
def test_break_in_for(x):
x = fluid.dygraph.to_variable(x)
for i in range(10):
x += 1
if i > 5:
break
x += 10086
x += i
return x
def test_break_in_for_at_end(x):
x = fluid.dygraph.to_variable(x)
for i in range(10):
x += 1
if i > 5:
break
return x
def test_break_in_while(x):
x = fluid.dygraph.to_variable(x)
i = fluid.layers.fill_constant(shape=[1], dtype='int32', value=0)
while i < 10:
i += 1
if i > 5:
break
x += 10086
x += i
return x
def test_break_continue_in_for(x):
x = fluid.dygraph.to_variable(x)
for i in range(1, 10, 1):
if i <= 4:
x += 1
continue
else:
x += 10010
break
x += 10086
return x
def test_for_in_else(x):
x = fluid.dygraph.to_variable(x)
#
# TODO: Huihuang founds that if we put the for range in else body
# the testcase will fail. Enable this test case after fixing it.
#
#if False:
# pass
#else:
# for i in range(0, 10):
# if i > 5:
# x += 1
# break
# x += i
#
if False:
pass
else:
for i in range(0, 10):
x += 1
break
x += i
return x
class TestContinueInFor(unittest.TestCase):
def setUp(self):
self.input = np.zeros((1)).astype('int32')
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
self.init_dygraph_func()
def init_dygraph_func(self):
self.dygraph_func = test_continue_in_for
def run_dygraph_mode(self):
with fluid.dygraph.guard():
res = self.dygraph_func(self.input)
return res.numpy()
def run_static_mode(self):
main_program = fluid.Program()
with fluid.program_guard(main_program):
res = dygraph_to_static_graph(self.dygraph_func)(self.input)
exe = fluid.Executor(self.place)
static_res = exe.run(main_program, fetch_list=[res])
return static_res[0]
def test_transformed_static_result(self):
static_res = self.run_static_mode()
dygraph_res = self.run_dygraph_mode()
self.assertTrue(
np.allclose(dygraph_res, static_res),
msg='dygraph res is {}\nstatic_res is {}'.format(dygraph_res,
static_res))
class TestContinueInForAtEnd(TestContinueInFor):
def init_dygraph_func(self):
self.dygraph_func = test_continue_in_for_at_end
class TestBreakInFor(TestContinueInFor):
def init_dygraph_func(self):
self.dygraph_func = test_break_in_for
class TestBreakInForAtEnd(TestContinueInFor):
def init_dygraph_func(self):
self.dygraph_func = test_break_in_for_at_end
class TestBreakContinueInFor(TestContinueInFor):
def init_dygraph_func(self):
self.dygraph_func = test_break_continue_in_for
class TestForInElse(TestContinueInFor):
def init_dygraph_func(self):
self.dygraph_func = test_for_in_else
class TestContinueInWhile(TestContinueInFor):
def init_dygraph_func(self):
self.dygraph_func = test_continue_in_while
def test_transformed_static_result(self):
# TODO: while i < 10 in dygraph will be supported after PR22892
# so currently we just assert static result.
# remove this overrided function after PR22892 is merged
static_res = self.run_static_mode()
self.assertEqual(15, static_res[0])
class TestBreakInWhile(TestContinueInWhile):
def init_dygraph_func(self):
self.dygraph_func = test_break_in_while
def test_transformed_static_result(self):
# TODO: while i < 10 in dygraph will be supported after PR22892
# so currently we just assert static result.
# remove this overrided function after PR22892 is merged
static_res = self.run_static_mode()
self.assertEqual(15, static_res[0])
if __name__ == '__main__':
unittest.main()
......@@ -42,6 +42,14 @@ def for_loop_dyfunc(max_len):
return ret
def while_loop_bool_op(x):
i = fluid.dygraph.to_variable(x)
while (x >= 0 and x < 10) or x <= -1 or x < -3 or (x < -7 or x < -5):
i = i + x
x = x + 1
return i
class TestNameVisitor(unittest.TestCase):
def setUp(self):
self.loop_funcs = [while_loop_dyfunc, for_loop_dyfunc]
......@@ -67,12 +75,16 @@ class TestTransformWhileLoop(unittest.TestCase):
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
self.x = np.zeros(shape=(1), dtype=np.int32)
self._init_dyfunc()
def _init_dyfunc(self):
self.dyfunc = while_loop_dyfunc
def _run_static(self):
main_program = fluid.Program()
with fluid.program_guard(main_program):
x_var = fluid.layers.assign(self.x)
static_func = dygraph_to_static_graph(while_loop_dyfunc)
static_func = dygraph_to_static_graph(self.dyfunc)
out = static_func(x_var)
exe = fluid.Executor(self.place)
......@@ -81,7 +93,7 @@ class TestTransformWhileLoop(unittest.TestCase):
def _run_dygraph(self):
with fluid.dygraph.guard(self.place):
ret = while_loop_dyfunc(fluid.dygraph.to_variable(self.x))
ret = self.dyfunc(fluid.dygraph.to_variable(self.x))
return ret.numpy()
def test_ast_to_func(self):
......@@ -97,6 +109,11 @@ class TestTransformWhileLoop(unittest.TestCase):
# self.assertTrue(np.allclose(self._run_dygraph(), self._run_static()))
class TestWhileLoopBoolOp(TestTransformWhileLoop):
def _init_dyfunc(self):
self.dyfunc = while_loop_bool_op
class TestTransformForLoop(unittest.TestCase):
def setUp(self):
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list
class TestIndexInList(unittest.TestCase):
def test_index_in_list(self):
list_to_test = [1, 2, 3, 4, 5]
self.assertEqual(index_in_list(list_to_test, 4), 3)
self.assertEqual(index_in_list(list_to_test, 1), 0)
self.assertEqual(index_in_list(list_to_test, 5), 4)
self.assertEqual(index_in_list(list_to_test, 0), -1)
self.assertEqual(index_in_list(list_to_test, 6), -1)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import gast
import six
import unittest
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node
class TestVariableTransFunc(unittest.TestCase):
def test_create_fill_constant_node(self):
node = create_fill_constant_node("a", 1.0)
source = "a = fluid.layers.fill_constant(shape=[1], dtype='float64', value=1.0)"
self.assertEqual(ast_to_source_code(node).strip(), source)
node = create_fill_constant_node("b", True)
source = "b = fluid.layers.fill_constant(shape=[1], dtype='bool', value=True)"
self.assertEqual(ast_to_source_code(node).strip(), source)
if six.PY2:
node = create_fill_constant_node("c", 214)
source = "c = fluid.layers.fill_constant(shape=[1], dtype='int32', value=214)"
self.assertEqual(ast_to_source_code(node).strip(), source)
node = create_fill_constant_node("d", long(10086))
source = "d = fluid.layers.fill_constant(shape=[1], dtype='int64', value=10086)"
self.assertEqual(ast_to_source_code(node).strip(), source)
else:
node = create_fill_constant_node("c", 4293)
source = "c = fluid.layers.fill_constant(shape=[1], dtype='int64', value=4293)"
self.assertEqual(ast_to_source_code(node).strip(), source)
self.assertIsNone(create_fill_constant_node("e", None))
self.assertIsNone(create_fill_constant_node("e", []))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册