未验证 提交 269470d6 编写于 作者: L liym27 提交者: GitHub

[Dynamic-to-Static] Remove unnecessary variables of the arguments in true_func/false_func (#28722)

上级 7d32e100
......@@ -245,23 +245,51 @@ def get_name_ids(nodes, end_node=None):
return name_visitor.name_ids
def parse_cond_args(var_ids_dict, return_ids=None, ctx=gast.Load):
def parse_cond_args(parent_ids_dict,
var_ids_dict,
modified_ids_dict=None,
ctx=gast.Load):
"""
Find out the ast.Name.id list of input by analyzing node's AST information.
"""
name_ids = [
# 1. filter the var fit the ctx
arg_name_ids = [
var_id for var_id, var_ctx in six.iteritems(var_ids_dict)
if isinstance(var_ctx[0], ctx)
]
if return_ids:
new_args = set(return_ids) - set(name_ids)
name_ids.extend(list(new_args))
name_ids.sort()
# 2. args should contain modified var ids in if-body or else-body
# case:
#
# ```
# if b < 1:
# z = y
# else:
# z = x
# ```
#
# In the above case, `z` should be in the args of cond()
if modified_ids_dict:
arg_name_ids = set(arg_name_ids) | set(modified_ids_dict)
# 3. args should not contain the vars not in parent ids
# case :
#
# ```
# x = 1
# if x > y:
# z = [v for v in range(i)]
# ```
#
# In the above case, `v` should not be in the args of cond()
arg_name_ids = list(set(arg_name_ids) & set(parent_ids_dict))
arg_name_ids.sort()
args = [
gast.Name(
id=name_id, ctx=gast.Load(), annotation=None, type_comment=None)
for name_id in name_ids
for name_id in arg_name_ids
]
arguments = gast.arguments(
args=args,
......@@ -412,7 +440,7 @@ def transform_if_else(node, root):
all_name_ids = get_name_ids([root])
for name in all_name_ids:
before_var_names_ids = parent_name_ids.get(name, []) + \
body_name_ids.get(name, []) + orelse_name_ids.get(name, [])
body_name_ids.get(name, []) + orelse_name_ids.get(name, [])
# Note: context of node.Name like gast.Load is a concrete object which has unique id different from other gast.Load
# E.g. ctx of `x` can be [<gast.Load object at 0x142a33c90>, <gast.Load object at 0x142a51950>, <gast.Param object at 0x1407d8250>]
after_var_names_ids = [
......@@ -444,12 +472,14 @@ def transform_if_else(node, root):
true_func_node = create_funcDef_node(
node.body,
name=unique_name.generate(TRUE_FUNC_PREFIX),
input_args=parse_cond_args(body_name_ids, modified_name_ids),
input_args=parse_cond_args(parent_name_ids, body_name_ids,
modified_name_ids),
return_name_ids=return_name_ids)
false_func_node = create_funcDef_node(
node.orelse,
name=unique_name.generate(FALSE_FUNC_PREFIX),
input_args=parse_cond_args(orelse_name_ids, modified_name_ids),
input_args=parse_cond_args(parent_name_ids, orelse_name_ids,
modified_name_ids),
return_name_ids=return_name_ids)
return create_new_vars_in_parent_stmts, true_func_node, false_func_node, return_name_ids
......
......@@ -14,6 +14,7 @@
from __future__ import print_function
import paddle
import paddle.fluid as fluid
......@@ -99,6 +100,16 @@ def dyfunc_with_if_else3(x):
return x
def dyfunc_with_if_else_with_list_geneator(x):
if 10 > 5:
y = paddle.add_n(
[paddle.full(
shape=[2], fill_value=v) for v in range(5)])
else:
y = x
return y
def nested_if_else(x_v):
batch_size = 16
feat_size = x_v.shape[-1]
......
......@@ -69,6 +69,12 @@ class TestDygraphIfElse3(TestDygraphIfElse):
self.dyfunc = dyfunc_with_if_else3
class TestDygraphIfElseWithListGenerator(TestDygraphIfElse):
def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
self.dyfunc = dyfunc_with_if_else_with_list_geneator
class TestDygraphNestedIfElse(TestDygraphIfElse):
def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册