未验证 提交 6e0aa776 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Stat]Enhance Python if-else by pruning usless no_return variable (#43880)

上级 6cb24967
......@@ -209,9 +209,44 @@ def convert_ifelse(pred, true_fn, false_fn, get_args, set_args,
out = _run_paddle_cond(pred, true_fn, false_fn, get_args, set_args,
return_name_ids)
else:
out = _run_py_ifelse(pred, true_fn, false_fn)
out = _run_py_ifelse(pred, true_fn, false_fn, get_args, set_args,
return_name_ids)
return _remove_no_value_return_var(out)
return out
def _run_paddle_cond(pred, true_fn, false_fn, get_args, set_args,
return_name_ids):
"""
Paddle cond API will evaluate both ture_fn and false_fn codes.
"""
pred = cast_bool_if_necessary(pred)
init_args = get_args()
def new_true_fn():
set_args(init_args)
outs = true_fn()
_check_no_undefined_var(outs, return_name_ids, 'if_body')
return outs
def new_false_fn():
set_args(init_args)
outs = false_fn()
_check_no_undefined_var(outs, return_name_ids, 'else_body')
return outs
cond_outs = control_flow.cond(pred, new_true_fn, new_false_fn)
return _recover_args_state(cond_outs, get_args, set_args, return_name_ids)
def _run_py_ifelse(pred, true_fn, false_fn, get_args, set_args,
return_name_ids):
"""
Evaluate python original branch function if-else.
"""
py_outs = true_fn() if pred else false_fn()
py_outs = _remove_no_value_return_var(py_outs)
return _recover_args_state(py_outs, get_args, set_args, return_name_ids)
def _remove_no_value_return_var(out):
......@@ -258,50 +293,33 @@ def _check_no_undefined_var(outs, names, branch_name):
.format(name, branch_name))
def _run_paddle_cond(pred, true_fn, false_fn, get_args, set_args,
return_name_ids):
def _recover_args_state(outs, get_args, set_args, return_name_ids):
"""
Paddle cond API will evaluate both ture_fn and false_fn codes.
"""
pred = cast_bool_if_necessary(pred)
init_args = get_args()
def new_true_fn():
set_args(init_args)
outs = true_fn()
_check_no_undefined_var(outs, return_name_ids, 'if_body')
return outs
def new_false_fn():
set_args(init_args)
outs = false_fn()
_check_no_undefined_var(outs, return_name_ids, 'else_body')
return outs
Currently we support variant length of early return statement by padding
_no_return_value.
cond_outs = control_flow.cond(pred, new_true_fn, new_false_fn)
# TODO(dev): We shall consider to evaluate whether should support this for Python if-else?
"""
# IfExpr's return_name_ids maybe None
if return_name_ids is None:
return cond_outs
return outs
init_args = get_args()
# recover args state
num_outs = len(return_name_ids)
num_args = 1 if not isinstance(init_args, tuple) else len(init_args)
assert num_outs <= num_args
if num_args == 1:
final_outs = cond_outs
final_outs = outs
else:
cond_outs = (cond_outs, ) if num_outs == 1 else cond_outs
final_outs = cond_outs + init_args[num_outs:]
outs = (outs, ) if num_outs == 1 else outs
final_outs = outs + init_args[num_outs:]
set_args(final_outs)
return final_outs
def _run_py_ifelse(pred, true_fn, false_fn):
return true_fn() if pred else false_fn()
def convert_len(var):
"""
Returns variable(length) from shape ops based on var.type
......
......@@ -201,6 +201,28 @@ def test_return_without_paddle_cond(x):
return y
def two_value(x):
return x * 2, x + 1
def diff_return_hepler(x):
if False:
y = x + 1
z = x - 1
return y, z
else:
return two_value(x)
@to_static
def test_diff_return(x):
x = paddle.to_tensor(x)
y, z = diff_return_hepler(x)
if y.shape[0] > 1:
y = y + 1
return y, z
class TestReturnBase(unittest.TestCase):
def setUp(self):
......@@ -255,6 +277,12 @@ class TestReturnIf(TestReturnBase):
self.dygraph_func = test_return_if
class TestReturnIfDiff(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_diff_return
class TestReturnIfElse(TestReturnBase):
def init_dygraph_func(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册