未验证 提交 fe01ba6a 编写于 作者: 0 0x45f 提交者: GitHub

remove no_value using var.name (#36513)

* remove no_value using var.name

* fix unit test for CI

* fix unit test

* add test case

* fix test case

* add more test case
上级 51c97d9f
......@@ -20,6 +20,7 @@ from paddle.fluid.layers import array_length, array_read, array_write, create_ar
from paddle.fluid.layers import assign, fill_constant, slice, reduce_all, reduce_any
from paddle.fluid.layers import cast, control_flow, logical_and, logical_not, logical_or, nn
from paddle.fluid.layers.control_flow import cond, while_loop, less_than, increment
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_VAR_NAME
def convert_while_loop(cond, body, loop_vars):
......@@ -204,10 +205,45 @@ def convert_ifelse(pred, true_fn, false_fn, true_args, false_args, return_vars):
"""
if isinstance(pred, Variable):
return _run_paddle_cond(pred, true_fn, false_fn, true_args, false_args,
return_vars)
out = _run_paddle_cond(pred, true_fn, false_fn, true_args, false_args,
return_vars)
else:
return _run_py_ifelse(pred, true_fn, false_fn, true_args, false_args)
out = _run_py_ifelse(pred, true_fn, false_fn, true_args, false_args)
return _remove_no_value_return_var(out)
def _remove_no_value_return_var(out):
if out and isinstance(out, tuple):
processed_out = out
align_ret = out[0]
if isinstance(align_ret, tuple):
for index, item in enumerate(align_ret):
if isinstance(item, Variable) and (
RETURN_NO_VALUE_VAR_NAME in item.name):
# return None
if index == 0:
processed_out = (None, ) + out[1:]
elif index == 1:
processed_out = align_ret[:1] + out[1:]
else:
processed_out = (align_ret[:index], ) + out[1:]
break
for index, item in enumerate(processed_out):
if isinstance(item, Variable) and (
RETURN_NO_VALUE_VAR_NAME in item.name):
processed_out = processed_out[:index]
if not processed_out:
return None
elif len(processed_out) == 1:
return processed_out[0]
else:
return processed_out
else:
return out
def _run_paddle_cond(pred, true_fn, false_fn, true_args, false_args,
......
......@@ -93,14 +93,14 @@ def create_fill_constant_node(name, value):
func_code = "{} = paddle.fluid.layers.fill_constant(shape=[1], ".format(
name)
if isinstance(value, bool):
func_code += "dtype='bool', value={})".format(value)
func_code += "dtype='bool', value={}, name='{}')".format(value, name)
return gast.parse(func_code).body[0]
if isinstance(value, float):
func_code += "dtype='float64', value={})".format(value)
func_code += "dtype='float64', value={}, name='{}')".format(value, name)
return gast.parse(func_code).body[0]
if isinstance(value, int):
func_code += "dtype='int64', value={})".format(value)
func_code += "dtype='int64', value={}, name='{}')".format(value, name)
return gast.parse(func_code).body[0]
......
......@@ -261,5 +261,100 @@ class TestChooseShapeAttrOrApiWithLayer(unittest.TestCase):
self.assertTrue(np.array_equal(out.numpy(), x.numpy()))
class TestIfElseNoValue(unittest.TestCase):
def test_else_ret_none(self):
input_x = paddle.to_tensor([[1, 2, 3], [4, 5, 6]])
@paddle.jit.to_static
def with_common_value(x, use_cache=False):
if use_cache:
y = x + 1
z = x + 2
return y, z
else:
c = x + 1
z = x - 1
return None
@paddle.jit.to_static
def without_common_value(x, use_cache=False):
if use_cache:
y = x + 1
z = x + 2
return y, z
else:
c = x + 1
return None
out = with_common_value(input_x, False)
self.assertIsNone(out)
out = without_common_value(input_x, False)
self.assertIsNone(out)
def test_else_ret_c(self):
input_x = paddle.to_tensor([[1, 2, 3], [4, 5, 6]])
@paddle.jit.to_static
def with_common_value(x, use_cache=False):
if use_cache:
y = x + 1
z = x + 2
return y, z
else:
c = x + 1
z = x - 1
return c
@paddle.jit.to_static
def without_common_value(x, use_cache=False):
if use_cache:
y = x + 1
z = x + 2
return y, z
else:
c = x + 1
return c
out = with_common_value(input_x, False)
self.assertListEqual(paddle.tolist(out), paddle.tolist(input_x + 1))
out = without_common_value(input_x, False)
self.assertListEqual(paddle.tolist(out), paddle.tolist(input_x + 1))
y, z = with_common_value(input_x, True)
self.assertListEqual(paddle.tolist(y), paddle.tolist(input_x + 1))
self.assertListEqual(paddle.tolist(z), paddle.tolist(input_x + 2))
def test_else_ret_cz(self):
input_x = paddle.to_tensor([[1, 2, 3], [4, 5, 6]])
@paddle.jit.to_static
def with_common_value(x, use_cache=False):
if use_cache:
y = x + 1
z = x + 2
return y, z, 1
else:
c = x + 1
z = x - 1
return c, z
@paddle.jit.to_static
def without_common_value(x, use_cache=False):
if use_cache:
y = x + 1
z = x + 2
return y, z, 1
else:
c = x + 1
d = x - 1
return c, d
c, z = with_common_value(input_x, False)
self.assertListEqual(paddle.tolist(c), paddle.tolist(input_x + 1))
self.assertListEqual(paddle.tolist(z), paddle.tolist(input_x - 1))
c, d = without_common_value(input_x, False)
self.assertListEqual(paddle.tolist(c), paddle.tolist(input_x + 1))
self.assertListEqual(paddle.tolist(d), paddle.tolist(input_x - 1))
if __name__ == '__main__':
unittest.main()
......@@ -64,7 +64,7 @@ def get_source_code(func):
class StaticCode1():
def dyfunc_with_if_else(x_v, label=None):
__return_value_init_0 = paddle.fluid.layers.fill_constant(
shape=[1], dtype='float64', value=0.0)
shape=[1], dtype='float64', value=0.0, name='__return_value_init_0')
__return_value_0 = __return_value_init_0
def true_fn_0(x_v):
......@@ -116,7 +116,7 @@ class StaticCode2():
# TODO: Transform return statement
def dyfunc_with_if_else(x_v, label=None):
__return_value_init_1 = paddle.fluid.layers.fill_constant(
shape=[1], dtype='float64', value=0.0)
shape=[1], dtype='float64', value=0.0, name='__return_value_init_1')
__return_value_1 = __return_value_init_1
def true_fn_3(x_v):
......
......@@ -50,16 +50,22 @@ class TestDataLayerNotCheck(unittest.TestCase):
class TestVariableTransFunc(unittest.TestCase):
def test_create_fill_constant_node(self):
node = create_fill_constant_node("a", 1.0)
source = "a = paddle.fluid.layers.fill_constant(shape=[1], dtype='float64', value=1.0)"
self.assertEqual(ast_to_source_code(node).strip(), source)
source = "a = paddle.fluid.layers.fill_constant(shape=[1], dtype='float64', value=1.0, name='a')"
self.assertEqual(
ast_to_source_code(node).replace('\n', '').replace(' ', ''),
source.replace(' ', ''))
node = create_fill_constant_node("b", True)
source = "b = paddle.fluid.layers.fill_constant(shape=[1], dtype='bool', value=True)"
self.assertEqual(ast_to_source_code(node).strip(), source)
source = "b = paddle.fluid.layers.fill_constant(shape=[1], dtype='bool', value=True, name='b')"
self.assertEqual(
ast_to_source_code(node).replace('\n', '').replace(' ', ''),
source.replace(' ', ''))
node = create_fill_constant_node("c", 4293)
source = "c = paddle.fluid.layers.fill_constant(shape=[1], dtype='int64', value=4293)"
self.assertEqual(ast_to_source_code(node).strip(), source)
source = "c = paddle.fluid.layers.fill_constant(shape=[1], dtype='int64', value=4293, name='c')"
self.assertEqual(
ast_to_source_code(node).replace('\n', '').replace(' ', ''),
source.replace(' ', ''))
self.assertIsNone(create_fill_constant_node("e", None))
self.assertIsNone(create_fill_constant_node("e", []))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册