未验证 提交 dab5e5d8 编写于 作者: A Aurelius84 提交者: GitHub

Add IsControlFlowIfVisitor in ast_transformer (#22709)

* add is_control_flow_if  test=develop
上级 cdf5f6fb
......@@ -52,8 +52,9 @@ class IfElseTransformer(gast.NodeTransformer):
def visit_If(self, node):
assert isinstance(node, gast.If)
need_transform = is_control_flow_if(node.test)
self.generic_visit(node)
if is_control_flow_if(node.test):
if need_transform:
pred_node = node.test
true_func_node, false_func_node, return_name_ids = transform_if_else(
node, self.root)
......
......@@ -27,8 +27,53 @@ from collections import defaultdict
from paddle.fluid import unique_name
TRUE_FUNC_PRFIX = 'true_fn'
FALSE_FUNC_PRFIX = 'false_fn'
TRUE_FUNC_PREFIX = 'true_fn'
FALSE_FUNC_PREFIX = 'false_fn'
class IsControlFlowIfVisitor(gast.NodeTransformer):
"""
Judge whether the node.test from Dygraph code dependent on paddle Tensor.
If does, it should satisfy:
1. must involve at least one var whose type is Tensor.
2. the Tensor var should call `.numpy()[]` interface or Tensor.shape is [1].
3. involve Tensor.shape[i] and the shape[i] is unknown in compile time.
The following examples should not be considered as control_flow_if:
1. `if Tensor_var` or `if Tensor_var is None`
2. if Tensor.shape[i] is determined with fixed value (not -1 or None)
Note: pred in ConditionalBlock require variable, which means all vars should be Tensor
or transformed into Tensor, like fill_constant(shape=[1], dtype='int32', value=Tensor.shape[i]).
TODO: 1. need to deal with `tensor.shape[i]` which need to eval the data of shape[i],
because reshape_op may be called before this statement.
"""
def __init__(self, node):
self.node = node
self.is_control_flow = False
def ast_visit(self):
self.visit(self.node)
return self.is_control_flow
def visit_Compare(self, node):
for child in gast.walk(node):
if isinstance(child, gast.Subscript):
self._visit_Subscript(child)
return node
def _visit_Subscript(self, node):
self.generic_visit(node)
if isinstance(node.value, gast.Call):
self._visit_Call(node.value)
return node
def _visit_Call(self, node):
assert isinstance(node, gast.Call)
if isinstance(node.func, gast.Attribute):
attr_node = node.func
self.is_control_flow = (attr_node.attr == 'numpy')
def is_control_flow_if(node):
......@@ -36,7 +81,10 @@ def is_control_flow_if(node):
Determine whether the node is a plain python `if statement` or
control flow in Paddle.
"""
return True
assert isinstance(
node, gast.AST
), "Type of input node should be gast.AST, but received %s." % type(node)
return IsControlFlowIfVisitor(node).ast_visit()
def get_name_ids(nodes, not_name_set=None, node_black_list=None):
......@@ -228,12 +276,12 @@ def transform_if_else(node, root):
true_func_node = create_funcDef_node(
node.body,
name=unique_name.generate(TRUE_FUNC_PRFIX),
name=unique_name.generate(TRUE_FUNC_PREFIX),
input_args=parse_cond_args(if_name_ids, modified_name_ids),
return_name_ids=return_name_ids)
false_func_node = create_funcDef_node(
node.orelse,
name=unique_name.generate(FALSE_FUNC_PRFIX),
name=unique_name.generate(FALSE_FUNC_PREFIX),
input_args=parse_cond_args(else_name_ids, modified_name_ids),
return_name_ids=return_name_ids)
......@@ -309,7 +357,7 @@ def ast_to_func(ast_root, func_name, delete_on_exit=True):
f = tempfile.NamedTemporaryFile(
mode='w', suffix='.py', delete=False, encoding='utf-8')
# TODO(Aurelius84): more elegent way to transform ast into callable object
# TODO(Aurelius84): more elegant way to transform ast into callable object
import_str = "import paddle\n" \
"import paddle.fluid as fluid\n" \
"import paddle.fluid.layers as layers\n"
......
......@@ -187,7 +187,7 @@ class AstVarEnv(object):
def exit_scope(self):
assert self.cur_scope.parent_scope is not None, "Call exit_scope in "\
"AstVarEnv when current scope doens't have parent scope."
"AstVarEnv when current scope doesn't have parent scope."
self.cur_scope = self.cur_scope.parent_scope
return self.cur_scope
......
......@@ -20,7 +20,9 @@ import gast
import inspect
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static.ast_utils import get_name_ids, ast_to_func
from paddle.fluid.dygraph.dygraph_to_static.ast_utils import get_name_ids, ast_to_func, is_control_flow_if
from test_dygraph_to_static_basic import dyfunc_with_if_else, dyfunc_with_if_else2, nested_if_else
class TestGetNameIds(unittest.TestCase):
......@@ -92,22 +94,46 @@ class TestGetNameIds3(TestGetNameIds):
}
def dyfunc_with_if_else(x_v):
if fluid.layers.mean(x_v).numpy()[0] > 5:
x_v = x_v - 1
else:
x_v = x_v + 1
return x_v
class TestIsControlFlowIf(unittest.TestCase):
def test_expr(self):
# node is not ast.Compare
node = gast.parse("a + b")
self.assertFalse(is_control_flow_if(node))
def test_expr2(self):
node = gast.parse("a + x.numpy()[1]")
self.assertFalse(is_control_flow_if(node))
def dyfunc_with_if_else2(x):
i, j = 0, 0
if fluid.layers.reduce_mean(x).numpy()[0] > x.numpy()[i][j]:
y = fluid.layers.relu(x)
else:
x_pow = fluid.layers.pow(x, 2)
y = fluid.layers.tanh(x_pow)
return y
def test_is_None(self):
node = gast.parse("x is None")
self.assertFalse(is_control_flow_if(node))
def test_is_None2(self):
node = gast.parse("fluid.layers.sum(x) is None")
self.assertFalse(is_control_flow_if(node))
def test_is_None3(self):
node = gast.parse("fluid.layers.sum(x).numpy() != None")
self.assertFalse(is_control_flow_if(node))
def test_if(self):
node = gast.parse("x.numpy()[1] > 1")
self.assertTrue(is_control_flow_if(node))
def test_if_with_and(self):
node = gast.parse("x is not None and 1 < x.numpy()[1]")
self.assertTrue(is_control_flow_if(node))
def test_if_with_or(self):
node = gast.parse("1 < fluid.layers.sum(x).numpy()[2] or x+y < 0")
self.assertTrue(is_control_flow_if(node))
def test_raise_error(self):
node = "a + b"
with self.assertRaises(Exception) as e:
self.assertRaises(TypeError, is_control_flow_if(node))
self.assertTrue(
"Type of input node should be gast.AST" in str(e.exception))
class TestAST2Func(unittest.TestCase):
......@@ -130,8 +156,9 @@ class TestAST2Func(unittest.TestCase):
self.assertEqual(func(x, y), self._ast2func(func)(x, y))
def test_ast2func_dygraph(self):
func = dyfunc_with_if_else
funcs = [dyfunc_with_if_else, dyfunc_with_if_else, nested_if_else]
x_data = np.random.random([10, 16]).astype('float32')
for func in funcs:
with fluid.dygraph.guard():
x_v = fluid.dygraph.to_variable(x_data)
true_ret = func(x_v).numpy()
......
......@@ -22,18 +22,30 @@ from paddle.fluid.dygraph.jit import dygraph_to_static_output
np.random.seed(1)
if fluid.is_compiled_with_cuda():
place = fluid.CUDAPlace(0)
else:
place = fluid.CPUPlace()
def dyfunc_with_if_else(x_v):
def dyfunc_with_if_else(x_v, label=None):
if fluid.layers.mean(x_v).numpy()[0] > 5:
x_v = x_v - 1
else:
x_v = x_v + 1
# plain if in python
if label is not None:
loss = fluid.layers.cross_entropy(x_v, label)
return loss
return x_v
def dyfunc_with_if_else2(x):
i, j = 0, 0
if fluid.layers.reduce_mean(x).numpy()[0] > x.numpy()[i][j]:
def dyfunc_with_if_else2(x, col=100):
row = 0
# plain if in python
if abs(col) > x.shape[-1]:
col = -1
if fluid.layers.reduce_mean(x).numpy()[0] > x.numpy()[row][col]:
y = fluid.layers.relu(x)
else:
x_pow = fluid.layers.pow(x, 2)
......@@ -42,9 +54,12 @@ def dyfunc_with_if_else2(x):
def nested_if_else(x_v):
batch_size = x_v.shape[0]
batch_size = 16
feat_size = x_v.shape[-1]
bias = fluid.layers.fill_constant([feat_size], dtype='float32', value=1)
# plain if in python
if x_v.shape[0] != batch_size:
batch_size = x_v.shape[0]
if fluid.layers.mean(x_v).numpy()[0] < 0:
y = x_v + bias
w = fluid.layers.fill_constant([feat_size], dtype='float32', value=10)
......@@ -78,12 +93,12 @@ class TestDygraphIfElse(unittest.TestCase):
x_v = fluid.layers.assign(self.x)
# Transform into static graph
out = dygraph_to_static_output(self.dyfunc)(x_v)
exe = fluid.Executor(fluid.CPUPlace())
exe = fluid.Executor(place)
ret = exe.run(main_program, fetch_list=out)
return ret
def _run_dygraph(self):
with fluid.dygraph.guard():
with fluid.dygraph.guard(place):
x_v = fluid.dygraph.to_variable(self.x)
ret = self.dyfunc(x_v)
return ret.numpy()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册