未验证 提交 16e74f11 编写于 作者: A Aurelius84 提交者: GitHub

fix is_controw_flow bug with `if Tensor.numpy()` (#23251)

上级 be2ac9cc
......@@ -62,7 +62,6 @@ class IfElseTransformer(gast.NodeTransformer):
self.after_visit(self.root)
def visit_If(self, node):
assert isinstance(node, gast.If)
if_condition_visitor = IfConditionVisitor(node.test,
self.static_analysis_visitor)
need_transform = if_condition_visitor.is_control_flow()
......@@ -88,6 +87,22 @@ class IfElseTransformer(gast.NodeTransformer):
node = attribute.value
return node
def visit_IfExp(self, node):
"""
Transformation with `true_fn(x) if Tensor > 0 else false_fn(x)`
"""
if_condition_visitor = IfConditionVisitor(node.test,
self.static_analysis_visitor)
need_transform = if_condition_visitor.is_control_flow()
self.generic_visit(node)
if need_transform:
pred_node, new_assign_nodes = if_condition_visitor.transform()
new_node = create_cond_node(None, pred_node, node.body, node.orelse,
True)
return new_node
else:
return node
def after_visit(self, node):
"""
This function will add some postprocessing operations with node.
......@@ -130,7 +145,12 @@ def is_candidate_node(node):
"""
Nodes with specified type will be dependent on tensor.
"""
return isinstance(node, (gast.Compare, gast.BoolOp, gast.UnaryOp))
is_compare_node = isinstance(node,
(gast.Compare, gast.BoolOp, gast.UnaryOp))
# TODO(Aurelius84): `.numpy()` may be an customized function,
# and should consider a more elegant way to solve this problem.
has_numpy_attr = ".numpy()" in ast_to_source_code(node)
return is_compare_node or has_numpy_attr
def compare_with_none(node):
......@@ -223,6 +243,7 @@ class IsControlFlowVisitor(gast.NodeVisitor):
self.is_control_flow_num += 1
def visit_Call(self, node):
self._visit_Call(node)
if is_paddle_api(node):
self.is_control_flow_num += 1
return node
......@@ -238,8 +259,7 @@ class IsControlFlowVisitor(gast.NodeVisitor):
return node
def _is_node_with_tensor(self, node, name_id):
tensor_types = set(
[NodeVarType.TENSOR, NodeVarType.PADDLE_RETURN_TYPES])
tensor_types = {NodeVarType.TENSOR, NodeVarType.PADDLE_RETURN_TYPES}
# Look up the node_var_type_map by name_id.
if self.node_var_type_map:
if name_id and isinstance(name_id, six.string_types):
......@@ -261,7 +281,9 @@ class IsControlFlowVisitor(gast.NodeVisitor):
class NodeTestTransformer(gast.NodeTransformer):
def __init__(self, ast_node, compare_nodes_with_tensor=set()):
def __init__(self, ast_node, compare_nodes_with_tensor=None):
if compare_nodes_with_tensor is None:
compare_nodes_with_tensor = set()
self.ast_root = ast_node
self._compare_nodes_with_tensor = compare_nodes_with_tensor
self._new_assign_nodes = []
......@@ -269,6 +291,15 @@ class NodeTestTransformer(gast.NodeTransformer):
def transform(self):
return self.visit(self.ast_root)
def visit_Call(self, node):
# self.generic_visit(node)
# Remove `numpy()` statement, like `Tensor.numpy()[i]` -> `Tensor[i]`
if isinstance(node.func, gast.Attribute):
attribute = node.func
if attribute.attr == 'numpy':
node = attribute.value
return node
def visit_UnaryOp(self, node):
self.generic_visit(node)
if isinstance(node.op, gast.Not):
......@@ -297,6 +328,7 @@ class NodeTestTransformer(gast.NodeTransformer):
if compare_with_none(
node) or node not in self._compare_nodes_with_tensor:
return self._create_bool_node(node)
self.generic_visit(node)
return node
def _create_bool_node(self, node):
......@@ -656,30 +688,29 @@ def transform_if_else(node, root):
return true_func_node, false_func_node, return_name_ids
def create_cond_node(return_name_ids, pred, true_func, false_func):
def create_cond_node(return_name_ids,
pred,
true_func,
false_func,
is_if_expr=False):
"""
Create `fluid.layers.cond(pred, true_fn, false_fn)` to replace
original `python if/else` statement.
"""
cond_api = gast.parse('fluid.layers.cond').body[0].value
true_func_lambda = gast.Lambda(
args=gast.arguments(
args=[],
posonlyargs=[],
vararg=None,
kwonlyargs=[],
kw_defaults=None,
kwarg=None,
defaults=[]),
body=gast.Call(
def create_lambda_node(func_or_expr_node, is_if_expr=False):
body = func_or_expr_node
if not is_if_expr:
body = gast.Call(
func=gast.Name(
id=true_func.name,
id=func_or_expr_node.name,
ctx=gast.Load(),
annotation=None,
type_comment=None),
args=[true_func.args],
keywords=[]))
false_func_lambda = gast.Lambda(
args=[func_or_expr_node.args],
keywords=[])
lambda_node = gast.Lambda(
args=gast.arguments(
args=[],
posonlyargs=[],
......@@ -688,14 +719,12 @@ def create_cond_node(return_name_ids, pred, true_func, false_func):
kw_defaults=None,
kwarg=None,
defaults=[]),
body=gast.Call(
func=gast.Name(
id=false_func.name,
ctx=gast.Load(),
annotation=None,
type_comment=None),
args=[false_func.args],
keywords=[]))
body=body)
return lambda_node
cond_api = gast.parse('fluid.layers.cond').body[0].value
true_func_lambda = create_lambda_node(true_func, is_if_expr)
false_func_lambda = create_lambda_node(false_func, is_if_expr)
cond_layer = gast.Call(
func=cond_api,
args=[pred, true_func_lambda, false_func_lambda],
......
......@@ -58,7 +58,8 @@ def nested_if_else(x_v):
bias = fluid.layers.fill_constant([feat_size], dtype='float32', value=1)
if x_v.shape[0] != batch_size:
batch_size = x_v.shape[0]
if fluid.layers.mean(x_v).numpy()[0] < 0:
# if tensor.shape is [1], now support to compare with numpy.
if fluid.layers.mean(x_v).numpy() < 0:
y = x_v + bias
w = fluid.layers.fill_constant([feat_size], dtype='float32', value=10)
if y.numpy()[0] < 10:
......
......@@ -89,7 +89,13 @@ class MainNetWithDict(fluid.dygraph.Layer):
dtype='float32',
value=0),
}
max_len = input.shape[0] if input.shape[0] != max_len else max_len
# TODO(Aurelius84): The following code will be converted into:
# max_len = layers.cond(layers.shape(input)[0] != max_len,
# lambda: layers.shape(input)[0], lambda: max_len)
# But max_len should be wrapped into tensor, which is not supported.
# Comment out this line of code for now.
# max_len = input.shape[0] if input.shape[0] != max_len else max_len
out = input
for i in range(max_len):
out = self.sub_net(out, cache)
......
......@@ -84,6 +84,38 @@ class TestDygraphIfElse5(TestDygraphIfElse):
self.dyfunc = nested_if_else_3
def dyfunc_ifExp_with_while(x):
y = [x]
def add_fn(x):
x = x + 1
return x
def cond(i, ten, y):
return i < ten
def map_func(func, tensor_list):
return [func(x) for x in tensor_list]
def body(i, ten, y):
# It will be converted into `layers.cond` as followed.
# map_func(lambda x: fluid.layers.cond(i==0, lambda: x, lambda: add_fn(x), y)
y = map_func(lambda x: x if i == 0 else add_fn(x), y)
i += 1
return [i, ten, y]
i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
ten = fluid.layers.fill_constant(shape=[1], dtype='int64', value=10)
i, ten, y = fluid.layers.while_loop(cond, body, [i, ten, y])
return y[0]
class TestDygraphIfElse6(TestDygraphIfElse):
def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
self.dyfunc = dyfunc_ifExp_with_while
class TestDygraphIfElseWithAndOr(TestDygraphIfElse):
def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
......
......@@ -135,7 +135,15 @@ class TestIsControlFlowIf(unittest.TestCase):
self.check_false_case("a+b")
def test_expr2(self):
self.check_false_case("a + x.numpy()[1]")
# x is a Tensor.
node = gast.parse("a + x.numpy()")
node_test = node.body[0].value
if_visitor = IfConditionVisitor(node_test)
self.assertTrue(if_visitor.is_control_flow())
# No transformation will be applied.
new_node, assign_nodes = if_visitor.transform()
self.assertTrue(len(assign_nodes) == 0)
def test_is_None(self):
self.check_false_case("x is None")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册