未验证 提交 dab5e5d8 编写于 作者: A Aurelius84 提交者: GitHub

Add IsControlFlowIfVisitor in ast_transformer (#22709)

* add is_control_flow_if  test=develop
上级 cdf5f6fb
...@@ -52,8 +52,9 @@ class IfElseTransformer(gast.NodeTransformer): ...@@ -52,8 +52,9 @@ class IfElseTransformer(gast.NodeTransformer):
def visit_If(self, node): def visit_If(self, node):
assert isinstance(node, gast.If) assert isinstance(node, gast.If)
need_transform = is_control_flow_if(node.test)
self.generic_visit(node) self.generic_visit(node)
if is_control_flow_if(node.test): if need_transform:
pred_node = node.test pred_node = node.test
true_func_node, false_func_node, return_name_ids = transform_if_else( true_func_node, false_func_node, return_name_ids = transform_if_else(
node, self.root) node, self.root)
......
...@@ -27,8 +27,53 @@ from collections import defaultdict ...@@ -27,8 +27,53 @@ from collections import defaultdict
from paddle.fluid import unique_name from paddle.fluid import unique_name
TRUE_FUNC_PRFIX = 'true_fn' TRUE_FUNC_PREFIX = 'true_fn'
FALSE_FUNC_PRFIX = 'false_fn' FALSE_FUNC_PREFIX = 'false_fn'
class IsControlFlowIfVisitor(gast.NodeTransformer):
"""
Judge whether the node.test from Dygraph code dependent on paddle Tensor.
If does, it should satisfy:
1. must involve at least one var whose type is Tensor.
2. the Tensor var should call `.numpy()[]` interface or Tensor.shape is [1].
3. involve Tensor.shape[i] and the shape[i] is unknown in compile time.
The following examples should not be considered as control_flow_if:
1. `if Tensor_var` or `if Tensor_var is None`
2. if Tensor.shape[i] is determined with fixed value (not -1 or None)
Note: pred in ConditionalBlock require variable, which means all vars should be Tensor
or transformed into Tensor, like fill_constant(shape=[1], dtype='int32', value=Tensor.shape[i]).
TODO: 1. need to deal with `tensor.shape[i]` which need to eval the data of shape[i],
because reshape_op may be called before this statement.
"""
def __init__(self, node):
self.node = node
self.is_control_flow = False
def ast_visit(self):
self.visit(self.node)
return self.is_control_flow
def visit_Compare(self, node):
for child in gast.walk(node):
if isinstance(child, gast.Subscript):
self._visit_Subscript(child)
return node
def _visit_Subscript(self, node):
self.generic_visit(node)
if isinstance(node.value, gast.Call):
self._visit_Call(node.value)
return node
def _visit_Call(self, node):
assert isinstance(node, gast.Call)
if isinstance(node.func, gast.Attribute):
attr_node = node.func
self.is_control_flow = (attr_node.attr == 'numpy')
def is_control_flow_if(node): def is_control_flow_if(node):
...@@ -36,7 +81,10 @@ def is_control_flow_if(node): ...@@ -36,7 +81,10 @@ def is_control_flow_if(node):
Determine whether the node is a plain python `if statement` or Determine whether the node is a plain python `if statement` or
control flow in Paddle. control flow in Paddle.
""" """
return True assert isinstance(
node, gast.AST
), "Type of input node should be gast.AST, but received %s." % type(node)
return IsControlFlowIfVisitor(node).ast_visit()
def get_name_ids(nodes, not_name_set=None, node_black_list=None): def get_name_ids(nodes, not_name_set=None, node_black_list=None):
...@@ -228,12 +276,12 @@ def transform_if_else(node, root): ...@@ -228,12 +276,12 @@ def transform_if_else(node, root):
true_func_node = create_funcDef_node( true_func_node = create_funcDef_node(
node.body, node.body,
name=unique_name.generate(TRUE_FUNC_PRFIX), name=unique_name.generate(TRUE_FUNC_PREFIX),
input_args=parse_cond_args(if_name_ids, modified_name_ids), input_args=parse_cond_args(if_name_ids, modified_name_ids),
return_name_ids=return_name_ids) return_name_ids=return_name_ids)
false_func_node = create_funcDef_node( false_func_node = create_funcDef_node(
node.orelse, node.orelse,
name=unique_name.generate(FALSE_FUNC_PRFIX), name=unique_name.generate(FALSE_FUNC_PREFIX),
input_args=parse_cond_args(else_name_ids, modified_name_ids), input_args=parse_cond_args(else_name_ids, modified_name_ids),
return_name_ids=return_name_ids) return_name_ids=return_name_ids)
...@@ -309,7 +357,7 @@ def ast_to_func(ast_root, func_name, delete_on_exit=True): ...@@ -309,7 +357,7 @@ def ast_to_func(ast_root, func_name, delete_on_exit=True):
f = tempfile.NamedTemporaryFile( f = tempfile.NamedTemporaryFile(
mode='w', suffix='.py', delete=False, encoding='utf-8') mode='w', suffix='.py', delete=False, encoding='utf-8')
# TODO(Aurelius84): more elegent way to transform ast into callable object # TODO(Aurelius84): more elegant way to transform ast into callable object
import_str = "import paddle\n" \ import_str = "import paddle\n" \
"import paddle.fluid as fluid\n" \ "import paddle.fluid as fluid\n" \
"import paddle.fluid.layers as layers\n" "import paddle.fluid.layers as layers\n"
......
...@@ -187,7 +187,7 @@ class AstVarEnv(object): ...@@ -187,7 +187,7 @@ class AstVarEnv(object):
def exit_scope(self): def exit_scope(self):
assert self.cur_scope.parent_scope is not None, "Call exit_scope in "\ assert self.cur_scope.parent_scope is not None, "Call exit_scope in "\
"AstVarEnv when current scope doens't have parent scope." "AstVarEnv when current scope doesn't have parent scope."
self.cur_scope = self.cur_scope.parent_scope self.cur_scope = self.cur_scope.parent_scope
return self.cur_scope return self.cur_scope
......
...@@ -20,7 +20,9 @@ import gast ...@@ -20,7 +20,9 @@ import gast
import inspect import inspect
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static.ast_utils import get_name_ids, ast_to_func from paddle.fluid.dygraph.dygraph_to_static.ast_utils import get_name_ids, ast_to_func, is_control_flow_if
from test_dygraph_to_static_basic import dyfunc_with_if_else, dyfunc_with_if_else2, nested_if_else
class TestGetNameIds(unittest.TestCase): class TestGetNameIds(unittest.TestCase):
...@@ -92,22 +94,46 @@ class TestGetNameIds3(TestGetNameIds): ...@@ -92,22 +94,46 @@ class TestGetNameIds3(TestGetNameIds):
} }
def dyfunc_with_if_else(x_v): class TestIsControlFlowIf(unittest.TestCase):
if fluid.layers.mean(x_v).numpy()[0] > 5: def test_expr(self):
x_v = x_v - 1 # node is not ast.Compare
else: node = gast.parse("a + b")
x_v = x_v + 1 self.assertFalse(is_control_flow_if(node))
return x_v
def test_expr2(self):
node = gast.parse("a + x.numpy()[1]")
self.assertFalse(is_control_flow_if(node))
def test_is_None(self):
node = gast.parse("x is None")
self.assertFalse(is_control_flow_if(node))
def test_is_None2(self):
node = gast.parse("fluid.layers.sum(x) is None")
self.assertFalse(is_control_flow_if(node))
def test_is_None3(self):
node = gast.parse("fluid.layers.sum(x).numpy() != None")
self.assertFalse(is_control_flow_if(node))
def dyfunc_with_if_else2(x): def test_if(self):
i, j = 0, 0 node = gast.parse("x.numpy()[1] > 1")
if fluid.layers.reduce_mean(x).numpy()[0] > x.numpy()[i][j]: self.assertTrue(is_control_flow_if(node))
y = fluid.layers.relu(x)
else: def test_if_with_and(self):
x_pow = fluid.layers.pow(x, 2) node = gast.parse("x is not None and 1 < x.numpy()[1]")
y = fluid.layers.tanh(x_pow) self.assertTrue(is_control_flow_if(node))
return y
def test_if_with_or(self):
node = gast.parse("1 < fluid.layers.sum(x).numpy()[2] or x+y < 0")
self.assertTrue(is_control_flow_if(node))
def test_raise_error(self):
node = "a + b"
with self.assertRaises(Exception) as e:
self.assertRaises(TypeError, is_control_flow_if(node))
self.assertTrue(
"Type of input node should be gast.AST" in str(e.exception))
class TestAST2Func(unittest.TestCase): class TestAST2Func(unittest.TestCase):
...@@ -130,13 +156,14 @@ class TestAST2Func(unittest.TestCase): ...@@ -130,13 +156,14 @@ class TestAST2Func(unittest.TestCase):
self.assertEqual(func(x, y), self._ast2func(func)(x, y)) self.assertEqual(func(x, y), self._ast2func(func)(x, y))
def test_ast2func_dygraph(self): def test_ast2func_dygraph(self):
func = dyfunc_with_if_else funcs = [dyfunc_with_if_else, dyfunc_with_if_else, nested_if_else]
x_data = np.random.random([10, 16]).astype('float32') x_data = np.random.random([10, 16]).astype('float32')
with fluid.dygraph.guard(): for func in funcs:
x_v = fluid.dygraph.to_variable(x_data) with fluid.dygraph.guard():
true_ret = func(x_v).numpy() x_v = fluid.dygraph.to_variable(x_data)
test_ret = self._ast2func(func)(x_v).numpy() true_ret = func(x_v).numpy()
self.assertTrue((true_ret == test_ret).all()) test_ret = self._ast2func(func)(x_v).numpy()
self.assertTrue((true_ret == test_ret).all())
def test_ast2func_static(self): def test_ast2func_static(self):
def func(x): def func(x):
......
...@@ -22,18 +22,30 @@ from paddle.fluid.dygraph.jit import dygraph_to_static_output ...@@ -22,18 +22,30 @@ from paddle.fluid.dygraph.jit import dygraph_to_static_output
np.random.seed(1) np.random.seed(1)
if fluid.is_compiled_with_cuda():
place = fluid.CUDAPlace(0)
else:
place = fluid.CPUPlace()
def dyfunc_with_if_else(x_v):
def dyfunc_with_if_else(x_v, label=None):
if fluid.layers.mean(x_v).numpy()[0] > 5: if fluid.layers.mean(x_v).numpy()[0] > 5:
x_v = x_v - 1 x_v = x_v - 1
else: else:
x_v = x_v + 1 x_v = x_v + 1
# plain if in python
if label is not None:
loss = fluid.layers.cross_entropy(x_v, label)
return loss
return x_v return x_v
def dyfunc_with_if_else2(x): def dyfunc_with_if_else2(x, col=100):
i, j = 0, 0 row = 0
if fluid.layers.reduce_mean(x).numpy()[0] > x.numpy()[i][j]: # plain if in python
if abs(col) > x.shape[-1]:
col = -1
if fluid.layers.reduce_mean(x).numpy()[0] > x.numpy()[row][col]:
y = fluid.layers.relu(x) y = fluid.layers.relu(x)
else: else:
x_pow = fluid.layers.pow(x, 2) x_pow = fluid.layers.pow(x, 2)
...@@ -42,9 +54,12 @@ def dyfunc_with_if_else2(x): ...@@ -42,9 +54,12 @@ def dyfunc_with_if_else2(x):
def nested_if_else(x_v): def nested_if_else(x_v):
batch_size = x_v.shape[0] batch_size = 16
feat_size = x_v.shape[-1] feat_size = x_v.shape[-1]
bias = fluid.layers.fill_constant([feat_size], dtype='float32', value=1) bias = fluid.layers.fill_constant([feat_size], dtype='float32', value=1)
# plain if in python
if x_v.shape[0] != batch_size:
batch_size = x_v.shape[0]
if fluid.layers.mean(x_v).numpy()[0] < 0: if fluid.layers.mean(x_v).numpy()[0] < 0:
y = x_v + bias y = x_v + bias
w = fluid.layers.fill_constant([feat_size], dtype='float32', value=10) w = fluid.layers.fill_constant([feat_size], dtype='float32', value=10)
...@@ -78,12 +93,12 @@ class TestDygraphIfElse(unittest.TestCase): ...@@ -78,12 +93,12 @@ class TestDygraphIfElse(unittest.TestCase):
x_v = fluid.layers.assign(self.x) x_v = fluid.layers.assign(self.x)
# Transform into static graph # Transform into static graph
out = dygraph_to_static_output(self.dyfunc)(x_v) out = dygraph_to_static_output(self.dyfunc)(x_v)
exe = fluid.Executor(fluid.CPUPlace()) exe = fluid.Executor(place)
ret = exe.run(main_program, fetch_list=out) ret = exe.run(main_program, fetch_list=out)
return ret return ret
def _run_dygraph(self): def _run_dygraph(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard(place):
x_v = fluid.dygraph.to_variable(self.x) x_v = fluid.dygraph.to_variable(self.x)
ret = self.dyfunc(x_v) ret = self.dyfunc(x_v)
return ret.numpy() return ret.numpy()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册