From 6e0aa776013bc35e9c770e5dd9bfedf906a5cd80 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Tue, 28 Jun 2022 20:43:02 +0800 Subject: [PATCH] [Dy2Stat]Enhance Python if-else by pruning usless no_return variable (#43880) --- .../dygraph_to_static/convert_operators.py | 76 ++++++++++++------- .../dygraph_to_static/test_return.py | 28 +++++++ 2 files changed, 75 insertions(+), 29 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py index bf97362ab7c..cbb4655f354 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py @@ -209,9 +209,44 @@ def convert_ifelse(pred, true_fn, false_fn, get_args, set_args, out = _run_paddle_cond(pred, true_fn, false_fn, get_args, set_args, return_name_ids) else: - out = _run_py_ifelse(pred, true_fn, false_fn) + out = _run_py_ifelse(pred, true_fn, false_fn, get_args, set_args, + return_name_ids) - return _remove_no_value_return_var(out) + return out + + +def _run_paddle_cond(pred, true_fn, false_fn, get_args, set_args, + return_name_ids): + """ + Paddle cond API will evaluate both ture_fn and false_fn codes. + """ + pred = cast_bool_if_necessary(pred) + init_args = get_args() + + def new_true_fn(): + set_args(init_args) + outs = true_fn() + _check_no_undefined_var(outs, return_name_ids, 'if_body') + return outs + + def new_false_fn(): + set_args(init_args) + outs = false_fn() + _check_no_undefined_var(outs, return_name_ids, 'else_body') + return outs + + cond_outs = control_flow.cond(pred, new_true_fn, new_false_fn) + return _recover_args_state(cond_outs, get_args, set_args, return_name_ids) + + +def _run_py_ifelse(pred, true_fn, false_fn, get_args, set_args, + return_name_ids): + """ + Evaluate python original branch function if-else. + """ + py_outs = true_fn() if pred else false_fn() + py_outs = _remove_no_value_return_var(py_outs) + return _recover_args_state(py_outs, get_args, set_args, return_name_ids) def _remove_no_value_return_var(out): @@ -258,50 +293,33 @@ def _check_no_undefined_var(outs, names, branch_name): .format(name, branch_name)) -def _run_paddle_cond(pred, true_fn, false_fn, get_args, set_args, - return_name_ids): +def _recover_args_state(outs, get_args, set_args, return_name_ids): """ - Paddle cond API will evaluate both ture_fn and false_fn codes. - """ - pred = cast_bool_if_necessary(pred) - init_args = get_args() - - def new_true_fn(): - set_args(init_args) - outs = true_fn() - _check_no_undefined_var(outs, return_name_ids, 'if_body') - return outs - - def new_false_fn(): - set_args(init_args) - outs = false_fn() - _check_no_undefined_var(outs, return_name_ids, 'else_body') - return outs + Currently we support variant length of early return statement by padding + _no_return_value. - cond_outs = control_flow.cond(pred, new_true_fn, new_false_fn) + # TODO(dev): We shall consider to evaluate whether should support this for Python if-else? + """ # IfExpr's return_name_ids maybe None if return_name_ids is None: - return cond_outs + return outs + init_args = get_args() # recover args state num_outs = len(return_name_ids) num_args = 1 if not isinstance(init_args, tuple) else len(init_args) assert num_outs <= num_args if num_args == 1: - final_outs = cond_outs + final_outs = outs else: - cond_outs = (cond_outs, ) if num_outs == 1 else cond_outs - final_outs = cond_outs + init_args[num_outs:] + outs = (outs, ) if num_outs == 1 else outs + final_outs = outs + init_args[num_outs:] set_args(final_outs) return final_outs -def _run_py_ifelse(pred, true_fn, false_fn): - return true_fn() if pred else false_fn() - - def convert_len(var): """ Returns variable(length) from shape ops based on var.type diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_return.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_return.py index 07e3fe518c2..a5a6b146769 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_return.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_return.py @@ -201,6 +201,28 @@ def test_return_without_paddle_cond(x): return y +def two_value(x): + return x * 2, x + 1 + + +def diff_return_hepler(x): + if False: + y = x + 1 + z = x - 1 + return y, z + else: + return two_value(x) + + +@to_static +def test_diff_return(x): + x = paddle.to_tensor(x) + y, z = diff_return_hepler(x) + if y.shape[0] > 1: + y = y + 1 + return y, z + + class TestReturnBase(unittest.TestCase): def setUp(self): @@ -255,6 +277,12 @@ class TestReturnIf(TestReturnBase): self.dygraph_func = test_return_if +class TestReturnIfDiff(TestReturnBase): + + def init_dygraph_func(self): + self.dygraph_func = test_diff_return + + class TestReturnIfElse(TestReturnBase): def init_dygraph_func(self): -- GitLab