未验证 提交 9598b19c 编写于 作者: 0 0x45f 提交者: GitHub

[Dy2Stat]Allow ifelse return buildin type in paddle cond (#37888)

* allow ifelse return `int` in paddle cond

* add test and refine code

* polish code, add test

* code format
上级 099cb75a
...@@ -248,20 +248,6 @@ def _remove_no_value_return_var(out): ...@@ -248,20 +248,6 @@ def _remove_no_value_return_var(out):
def _run_paddle_cond(pred, true_fn, false_fn, true_args, false_args, def _run_paddle_cond(pred, true_fn, false_fn, true_args, false_args,
return_vars): return_vars):
return_var_ids = [id(var) for var in return_vars]
# NOTE 1: Returned vars of Paddle op `control_flow.cond` must be Paddle Tensors
# NOTE 2: Here uses id(var) not var, because `if var in return_var` use operator `==`,
# which will call `fluid.layers.equal` and causes error when var in return_vars is not initialized.
true_args = [
to_static_variable(var) if id(var) in return_var_ids else var
for var in true_args
]
false_args = [
to_static_variable(var) if id(var) in return_var_ids else var
for var in false_args
]
pred = cast_bool_if_necessary(pred) pred = cast_bool_if_necessary(pred)
return control_flow.cond(pred, lambda: true_fn(*true_args), return control_flow.cond(pred, lambda: true_fn(*true_args),
lambda: false_fn(*false_args)) lambda: false_fn(*false_args))
......
...@@ -102,6 +102,41 @@ def select_input(inputs, mask): ...@@ -102,6 +102,41 @@ def select_input(inputs, mask):
return out return out
def select_input_with_buildin_type(inputs, mask):
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable
support_ret_buildin_type = (bool, float, six.integer_types)
false_var, true_var = inputs
if isinstance(false_var, Variable) and isinstance(true_var, Variable):
return select_input(inputs, mask)
elif (isinstance(false_var, (support_ret_buildin_type)) and
isinstance(false_var, type(true_var))):
if false_var == true_var:
return false_var
else:
inputs = [
to_static_variable(false_var), to_static_variable(true_var)
]
# Deal with the situations like this: false_var is int and true_var is Variable
elif ((isinstance(false_var, support_ret_buildin_type) and
isinstance(true_var, Variable)) or
(isinstance(true_var, support_ret_buildin_type) and
isinstance(false_var, Variable))):
inputs = [to_static_variable(false_var), to_static_variable(true_var)]
warnings.warn(
"Return results from different branches in cond are not same type: "
"false_var returned by fasle_fn is '{}' and true_var of true_fn is "
"'{}'".format(type(false_var), type(true_var)))
else:
raise TypeError(
"Unsupported return type of true_fn and false_fn in cond: false_var "
"returned by fasle_fn is '{}' and true_var of true_fn is '{}'".
format(type(false_var), type(true_var)))
return select_input(inputs, mask)
def split_lod_tensor(input, mask, level=0): def split_lod_tensor(input, mask, level=0):
""" """
This function takes in an input that contains the complete lod information, This function takes in an input that contains the complete lod information,
...@@ -2282,8 +2317,8 @@ class ConditionalBlock(object): ...@@ -2282,8 +2317,8 @@ class ConditionalBlock(object):
def copy_var_to_parent_block(var, layer_helper): def copy_var_to_parent_block(var, layer_helper):
if var is None: if not isinstance(var, Variable):
return None return var
prog = layer_helper.main_program prog = layer_helper.main_program
parent_idx = prog.current_block().parent_idx parent_idx = prog.current_block().parent_idx
assert parent_idx >= 0, "Got wrong parent block index when assigning var to parent scope in control_flow" assert parent_idx >= 0, "Got wrong parent block index when assigning var to parent scope in control_flow"
...@@ -2466,7 +2501,7 @@ def cond(pred, true_fn=None, false_fn=None, name=None): ...@@ -2466,7 +2501,7 @@ def cond(pred, true_fn=None, false_fn=None, name=None):
format(e)) format(e))
mask = cast(pred, dtype='int32') mask = cast(pred, dtype='int32')
merge_func = lambda false_var, true_var : select_input([false_var, true_var], mask) merge_func = lambda false_var, true_var : select_input_with_buildin_type([false_var, true_var], mask)
merged_output = map_structure(merge_func, false_output, true_output) merged_output = map_structure(merge_func, false_output, true_output)
return merged_output return merged_output
......
...@@ -340,3 +340,53 @@ def if_tensor_case(x): ...@@ -340,3 +340,53 @@ def if_tensor_case(x):
x += 1 x += 1
return x return x
def dyfunc_ifelse_ret_int1(x):
index = 0
pred = paddle.to_tensor([1])
if pred:
y = x[index] + 1
index = index + 1
return y, index
else:
y = x[index] + 2
index = index + 1
return y, index
def dyfunc_ifelse_ret_int2(x):
index = 0
pred = paddle.to_tensor([1])
if pred:
y = x[index] + 1
index = index + 1
return y, index
else:
y = x[index] + 2
index = index + 1
return y
def dyfunc_ifelse_ret_int3(x):
index = 0
pred = paddle.to_tensor([1])
if pred:
y = x[index] + 1
index = index + 1
return index
else:
y = x[index] + 2
return y
def dyfunc_ifelse_ret_int4(x):
index = 0
pred = paddle.to_tensor([1])
if pred:
y = x[index] + 1
index = index + 1
return 'unsupport ret'
else:
y = x[index] + 2
return y
...@@ -365,5 +365,60 @@ class TestNewVarCreateInOneBranch(unittest.TestCase): ...@@ -365,5 +365,60 @@ class TestNewVarCreateInOneBranch(unittest.TestCase):
self.assertEqual(paddle.jit.to_static(case_func)(True), -2) self.assertEqual(paddle.jit.to_static(case_func)(True), -2)
class TestDy2StIfElseRetInt1(unittest.TestCase):
def setUp(self):
self.x = np.random.random([5]).astype('float32')
self.dyfunc = dyfunc_ifelse_ret_int1
self.out = self.get_dy2stat_out()
def get_dy2stat_out(self):
ProgramTranslator().enable(True)
static_func = paddle.jit.to_static(self.dyfunc)
out = static_func(self.x)
ProgramTranslator().enable(False)
return out
def test_ast_to_func(self):
self.assertIsInstance(self.out[0], paddle.Tensor)
self.assertIsInstance(self.out[1], int)
class TestDy2StIfElseRetInt2(TestDy2StIfElseRetInt1):
def setUp(self):
self.x = np.random.random([5]).astype('float32')
self.dyfunc = dyfunc_ifelse_ret_int2
self.out = self.get_dy2stat_out()
def test_ast_to_func(self):
self.assertIsInstance(self.out[0], paddle.Tensor)
self.assertIsInstance(self.out[1], paddle.Tensor)
class TestDy2StIfElseRetInt3(TestDy2StIfElseRetInt1):
def setUp(self):
self.x = np.random.random([5]).astype('float32')
self.dyfunc = dyfunc_ifelse_ret_int3
self.out = self.get_dy2stat_out()
def test_ast_to_func(self):
self.assertIsInstance(self.out, paddle.Tensor)
class TestDy2StIfElseRetInt4(TestDy2StIfElseRetInt1):
def setUp(self):
self.x = np.random.random([5]).astype('float32')
self.dyfunc = dyfunc_ifelse_ret_int4
def test_ast_to_func(self):
with self.assertRaises(TypeError):
ProgramTranslator().enable(True)
static_func = paddle.jit.to_static(self.dyfunc)
out = static_func(self.x)
def __del__(self):
ProgramTranslator().enable(False)
super(TestDy2StIfElseRetInt4, self).__del__()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册