未验证 提交 16e74f11 编写于 作者: A Aurelius84 提交者: GitHub

fix is_controw_flow bug with `if Tensor.numpy()` (#23251)

上级 be2ac9cc
...@@ -62,7 +62,6 @@ class IfElseTransformer(gast.NodeTransformer): ...@@ -62,7 +62,6 @@ class IfElseTransformer(gast.NodeTransformer):
self.after_visit(self.root) self.after_visit(self.root)
def visit_If(self, node): def visit_If(self, node):
assert isinstance(node, gast.If)
if_condition_visitor = IfConditionVisitor(node.test, if_condition_visitor = IfConditionVisitor(node.test,
self.static_analysis_visitor) self.static_analysis_visitor)
need_transform = if_condition_visitor.is_control_flow() need_transform = if_condition_visitor.is_control_flow()
...@@ -88,6 +87,22 @@ class IfElseTransformer(gast.NodeTransformer): ...@@ -88,6 +87,22 @@ class IfElseTransformer(gast.NodeTransformer):
node = attribute.value node = attribute.value
return node return node
def visit_IfExp(self, node):
"""
Transformation with `true_fn(x) if Tensor > 0 else false_fn(x)`
"""
if_condition_visitor = IfConditionVisitor(node.test,
self.static_analysis_visitor)
need_transform = if_condition_visitor.is_control_flow()
self.generic_visit(node)
if need_transform:
pred_node, new_assign_nodes = if_condition_visitor.transform()
new_node = create_cond_node(None, pred_node, node.body, node.orelse,
True)
return new_node
else:
return node
def after_visit(self, node): def after_visit(self, node):
""" """
This function will add some postprocessing operations with node. This function will add some postprocessing operations with node.
...@@ -130,7 +145,12 @@ def is_candidate_node(node): ...@@ -130,7 +145,12 @@ def is_candidate_node(node):
""" """
Nodes with specified type will be dependent on tensor. Nodes with specified type will be dependent on tensor.
""" """
return isinstance(node, (gast.Compare, gast.BoolOp, gast.UnaryOp)) is_compare_node = isinstance(node,
(gast.Compare, gast.BoolOp, gast.UnaryOp))
# TODO(Aurelius84): `.numpy()` may be an customized function,
# and should consider a more elegant way to solve this problem.
has_numpy_attr = ".numpy()" in ast_to_source_code(node)
return is_compare_node or has_numpy_attr
def compare_with_none(node): def compare_with_none(node):
...@@ -223,6 +243,7 @@ class IsControlFlowVisitor(gast.NodeVisitor): ...@@ -223,6 +243,7 @@ class IsControlFlowVisitor(gast.NodeVisitor):
self.is_control_flow_num += 1 self.is_control_flow_num += 1
def visit_Call(self, node): def visit_Call(self, node):
self._visit_Call(node)
if is_paddle_api(node): if is_paddle_api(node):
self.is_control_flow_num += 1 self.is_control_flow_num += 1
return node return node
...@@ -238,8 +259,7 @@ class IsControlFlowVisitor(gast.NodeVisitor): ...@@ -238,8 +259,7 @@ class IsControlFlowVisitor(gast.NodeVisitor):
return node return node
def _is_node_with_tensor(self, node, name_id): def _is_node_with_tensor(self, node, name_id):
tensor_types = set( tensor_types = {NodeVarType.TENSOR, NodeVarType.PADDLE_RETURN_TYPES}
[NodeVarType.TENSOR, NodeVarType.PADDLE_RETURN_TYPES])
# Look up the node_var_type_map by name_id. # Look up the node_var_type_map by name_id.
if self.node_var_type_map: if self.node_var_type_map:
if name_id and isinstance(name_id, six.string_types): if name_id and isinstance(name_id, six.string_types):
...@@ -261,7 +281,9 @@ class IsControlFlowVisitor(gast.NodeVisitor): ...@@ -261,7 +281,9 @@ class IsControlFlowVisitor(gast.NodeVisitor):
class NodeTestTransformer(gast.NodeTransformer): class NodeTestTransformer(gast.NodeTransformer):
def __init__(self, ast_node, compare_nodes_with_tensor=set()): def __init__(self, ast_node, compare_nodes_with_tensor=None):
if compare_nodes_with_tensor is None:
compare_nodes_with_tensor = set()
self.ast_root = ast_node self.ast_root = ast_node
self._compare_nodes_with_tensor = compare_nodes_with_tensor self._compare_nodes_with_tensor = compare_nodes_with_tensor
self._new_assign_nodes = [] self._new_assign_nodes = []
...@@ -269,6 +291,15 @@ class NodeTestTransformer(gast.NodeTransformer): ...@@ -269,6 +291,15 @@ class NodeTestTransformer(gast.NodeTransformer):
def transform(self): def transform(self):
return self.visit(self.ast_root) return self.visit(self.ast_root)
def visit_Call(self, node):
# self.generic_visit(node)
# Remove `numpy()` statement, like `Tensor.numpy()[i]` -> `Tensor[i]`
if isinstance(node.func, gast.Attribute):
attribute = node.func
if attribute.attr == 'numpy':
node = attribute.value
return node
def visit_UnaryOp(self, node): def visit_UnaryOp(self, node):
self.generic_visit(node) self.generic_visit(node)
if isinstance(node.op, gast.Not): if isinstance(node.op, gast.Not):
...@@ -297,6 +328,7 @@ class NodeTestTransformer(gast.NodeTransformer): ...@@ -297,6 +328,7 @@ class NodeTestTransformer(gast.NodeTransformer):
if compare_with_none( if compare_with_none(
node) or node not in self._compare_nodes_with_tensor: node) or node not in self._compare_nodes_with_tensor:
return self._create_bool_node(node) return self._create_bool_node(node)
self.generic_visit(node)
return node return node
def _create_bool_node(self, node): def _create_bool_node(self, node):
...@@ -656,46 +688,43 @@ def transform_if_else(node, root): ...@@ -656,46 +688,43 @@ def transform_if_else(node, root):
return true_func_node, false_func_node, return_name_ids return true_func_node, false_func_node, return_name_ids
def create_cond_node(return_name_ids, pred, true_func, false_func): def create_cond_node(return_name_ids,
pred,
true_func,
false_func,
is_if_expr=False):
""" """
Create `fluid.layers.cond(pred, true_fn, false_fn)` to replace Create `fluid.layers.cond(pred, true_fn, false_fn)` to replace
original `python if/else` statement. original `python if/else` statement.
""" """
def create_lambda_node(func_or_expr_node, is_if_expr=False):
body = func_or_expr_node
if not is_if_expr:
body = gast.Call(
func=gast.Name(
id=func_or_expr_node.name,
ctx=gast.Load(),
annotation=None,
type_comment=None),
args=[func_or_expr_node.args],
keywords=[])
lambda_node = gast.Lambda(
args=gast.arguments(
args=[],
posonlyargs=[],
vararg=None,
kwonlyargs=[],
kw_defaults=None,
kwarg=None,
defaults=[]),
body=body)
return lambda_node
cond_api = gast.parse('fluid.layers.cond').body[0].value cond_api = gast.parse('fluid.layers.cond').body[0].value
true_func_lambda = gast.Lambda( true_func_lambda = create_lambda_node(true_func, is_if_expr)
args=gast.arguments( false_func_lambda = create_lambda_node(false_func, is_if_expr)
args=[],
posonlyargs=[],
vararg=None,
kwonlyargs=[],
kw_defaults=None,
kwarg=None,
defaults=[]),
body=gast.Call(
func=gast.Name(
id=true_func.name,
ctx=gast.Load(),
annotation=None,
type_comment=None),
args=[true_func.args],
keywords=[]))
false_func_lambda = gast.Lambda(
args=gast.arguments(
args=[],
posonlyargs=[],
vararg=None,
kwonlyargs=[],
kw_defaults=None,
kwarg=None,
defaults=[]),
body=gast.Call(
func=gast.Name(
id=false_func.name,
ctx=gast.Load(),
annotation=None,
type_comment=None),
args=[false_func.args],
keywords=[]))
cond_layer = gast.Call( cond_layer = gast.Call(
func=cond_api, func=cond_api,
args=[pred, true_func_lambda, false_func_lambda], args=[pred, true_func_lambda, false_func_lambda],
......
...@@ -58,7 +58,8 @@ def nested_if_else(x_v): ...@@ -58,7 +58,8 @@ def nested_if_else(x_v):
bias = fluid.layers.fill_constant([feat_size], dtype='float32', value=1) bias = fluid.layers.fill_constant([feat_size], dtype='float32', value=1)
if x_v.shape[0] != batch_size: if x_v.shape[0] != batch_size:
batch_size = x_v.shape[0] batch_size = x_v.shape[0]
if fluid.layers.mean(x_v).numpy()[0] < 0: # if tensor.shape is [1], now support to compare with numpy.
if fluid.layers.mean(x_v).numpy() < 0:
y = x_v + bias y = x_v + bias
w = fluid.layers.fill_constant([feat_size], dtype='float32', value=10) w = fluid.layers.fill_constant([feat_size], dtype='float32', value=10)
if y.numpy()[0] < 10: if y.numpy()[0] < 10:
......
...@@ -89,7 +89,13 @@ class MainNetWithDict(fluid.dygraph.Layer): ...@@ -89,7 +89,13 @@ class MainNetWithDict(fluid.dygraph.Layer):
dtype='float32', dtype='float32',
value=0), value=0),
} }
max_len = input.shape[0] if input.shape[0] != max_len else max_len # TODO(Aurelius84): The following code will be converted into:
# max_len = layers.cond(layers.shape(input)[0] != max_len,
# lambda: layers.shape(input)[0], lambda: max_len)
# But max_len should be wrapped into tensor, which is not supported.
# Comment out this line of code for now.
# max_len = input.shape[0] if input.shape[0] != max_len else max_len
out = input out = input
for i in range(max_len): for i in range(max_len):
out = self.sub_net(out, cache) out = self.sub_net(out, cache)
......
...@@ -84,6 +84,38 @@ class TestDygraphIfElse5(TestDygraphIfElse): ...@@ -84,6 +84,38 @@ class TestDygraphIfElse5(TestDygraphIfElse):
self.dyfunc = nested_if_else_3 self.dyfunc = nested_if_else_3
def dyfunc_ifExp_with_while(x):
y = [x]
def add_fn(x):
x = x + 1
return x
def cond(i, ten, y):
return i < ten
def map_func(func, tensor_list):
return [func(x) for x in tensor_list]
def body(i, ten, y):
# It will be converted into `layers.cond` as followed.
# map_func(lambda x: fluid.layers.cond(i==0, lambda: x, lambda: add_fn(x), y)
y = map_func(lambda x: x if i == 0 else add_fn(x), y)
i += 1
return [i, ten, y]
i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
ten = fluid.layers.fill_constant(shape=[1], dtype='int64', value=10)
i, ten, y = fluid.layers.while_loop(cond, body, [i, ten, y])
return y[0]
class TestDygraphIfElse6(TestDygraphIfElse):
def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
self.dyfunc = dyfunc_ifExp_with_while
class TestDygraphIfElseWithAndOr(TestDygraphIfElse): class TestDygraphIfElseWithAndOr(TestDygraphIfElse):
def setUp(self): def setUp(self):
self.x = np.random.random([10, 16]).astype('float32') self.x = np.random.random([10, 16]).astype('float32')
......
...@@ -135,7 +135,15 @@ class TestIsControlFlowIf(unittest.TestCase): ...@@ -135,7 +135,15 @@ class TestIsControlFlowIf(unittest.TestCase):
self.check_false_case("a+b") self.check_false_case("a+b")
def test_expr2(self): def test_expr2(self):
self.check_false_case("a + x.numpy()[1]") # x is a Tensor.
node = gast.parse("a + x.numpy()")
node_test = node.body[0].value
if_visitor = IfConditionVisitor(node_test)
self.assertTrue(if_visitor.is_control_flow())
# No transformation will be applied.
new_node, assign_nodes = if_visitor.transform()
self.assertTrue(len(assign_nodes) == 0)
def test_is_None(self): def test_is_None(self):
self.check_false_case("x is None") self.check_false_case("x is None")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册