未验证 提交 1a8c9692 编写于 作者: F feifei-111 提交者: GitHub

[warning] added warning message in cond block when one branch returns variable...

[warning] added warning message in cond block when one branch returns variable and another returns None (#46031)

* [cherry-pick] Allow manaully set py_reader name in standalone executor (#45898) (#45931)

* Allow manaully set py_reader name in standalone executor
上级 97cdc7c4
...@@ -2653,15 +2653,22 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None): ...@@ -2653,15 +2653,22 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None):
# Merge ture and false output if they are not None # Merge ture and false output if they are not None
if return_names is None: if return_names is None:
is_dy2staic = False
return_names = ["no name"] * len(to_sequence(true_output)) return_names = ["no name"] * len(to_sequence(true_output))
else: else:
""" """
dy2static will set the return_names and expand the return values to UndefinedVar. dy2static will set the return_names and expand the return values to UndefinedVar.
""" """
is_dy2staic = True
# TODO: expand_undefined_var will replace None to Undefinedvar(), to fix cases like:
# a = None
# if condition:
# a = 1
# Because we can not use variable to express 'None'
true_output, false_output = expand_undefined_var( true_output, false_output = expand_undefined_var(
true_output, false_output, return_names) true_output, false_output, return_names)
true_output, false_output = change_none_to_undefinedvar(
true_output, false_output)
if len(to_sequence(true_output)) != len(to_sequence(false_output)): if len(to_sequence(true_output)) != len(to_sequence(false_output)):
raise ValueError( raise ValueError(
"true fn returns {} vars, but false fn returns {} vars, which is not equals" "true fn returns {} vars, but false fn returns {} vars, which is not equals"
...@@ -2677,6 +2684,28 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None): ...@@ -2677,6 +2684,28 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None):
"Incompatible return values of `{}` in true_fn and false_fn in cond: {}" "Incompatible return values of `{}` in true_fn and false_fn in cond: {}"
.format(return_name, e)) .format(return_name, e))
def check_ret_none(seq_true, seq_false, seq_names):
length = len(seq_true)
for i in range(length):
f_true = flatten(seq_true[i])
f_false = flatten(seq_false[i])
for idx in range(len(f_true)):
if f_true[idx] is None and f_false[idx] is not None or f_false[
idx] is None and f_true[idx] is not None:
warnings.warn(
"In cond : Var '{}' or part of it is set differently in ifelse branchs, "
"<{}, {}> in true branch and <{}, {}> in false branch. Set var to "
"'None' in ifelse block might lead to error.".format(
seq_names[i], type(f_true[idx]), f_true[idx],
type(f_false[idx]), f_false[idx]))
check_ret_none(to_sequence(true_output), to_sequence(false_output),
to_sequence(return_names))
if is_dy2staic:
true_output, false_output = change_none_to_undefinedvar(
true_output, false_output)
mask = cast(pred, dtype='int32') mask = cast(pred, dtype='int32')
merge_func = lambda name, false_var, true_var: select_input_with_buildin_type( merge_func = lambda name, false_var, true_var: select_input_with_buildin_type(
[false_var, true_var], mask, name) [false_var, true_var], mask, name)
...@@ -2716,16 +2745,31 @@ def expand_undefined_var(nest1, nest2, names): ...@@ -2716,16 +2745,31 @@ def expand_undefined_var(nest1, nest2, names):
return pack_sequence_as(seq, return pack_sequence_as(seq,
[UndefinedVar("padding") for i in flatten(seq)]) [UndefinedVar("padding") for i in flatten(seq)])
def map_fn(n1, n2, name): def map_fn(n1, n2, name, order):
if not name.startswith(RETURN_VALUE_PREFIX) and (isinstance( if not name.startswith(RETURN_VALUE_PREFIX) and (isinstance(
n1, UndefinedVar) or n1 is None): n1, UndefinedVar) or n1 is None):
if n1 is None and n2 is not None:
if order == 0:
warnings.warn(
"In cond : Var '{}' or part of it is set differently in ifelse branchs, "
"<{}, {}> in true branch and <{}, {}> in false branch. Set var to "
"'None' in ifelse block might lead to error.".format(
name, type(n1), n1, type(n2), n2))
else:
warnings.warn(
"In cond : Var '{}' or part of it is set differently in ifelse branchs, "
"<{}, {}> in true branch and <{}, {}> in false branch. Set var to "
"'None' in ifelse block might lead to error.".format(
name, type(n2), n2, type(n1), n1))
return pack_undefined_var_as(n2) return pack_undefined_var_as(n2)
return n1 return n1
nest1_out = list( nest1_out = list(
map(map_fn, to_sequence(nest1), to_sequence(nest2), to_sequence(names))) map(map_fn, to_sequence(nest1), to_sequence(nest2), to_sequence(names),
[0 for i in to_sequence(names)]))
nest2_out = list( nest2_out = list(
map(map_fn, to_sequence(nest2), to_sequence(nest1), to_sequence(names))) map(map_fn, to_sequence(nest2), to_sequence(nest1), to_sequence(names),
[1 for i in to_sequence(names)]))
if not is_sequence(nest1): nest1_out = nest1_out[0] if not is_sequence(nest1): nest1_out = nest1_out[0]
if not is_sequence(nest2): nest2_out = nest2_out[0] if not is_sequence(nest2): nest2_out = nest2_out[0]
return nest1_out, nest2_out return nest1_out, nest2_out
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import warnings
from paddle.fluid.dygraph.dygraph_to_static.program_translator import convert_to_static
from paddle.fluid.layers.control_flow import cond
@paddle.jit.to_static
def fun1():
a = paddle.to_tensor(1)
b = paddle.to_tensor(2)
if a > b:
b = paddle.to_tensor(3)
else:
b = None
def true_fn():
return [paddle.to_tensor(1), [paddle.to_tensor(2), paddle.to_tensor(3)]]
def false_fn():
return [paddle.to_tensor(3), [None, paddle.to_tensor(4)]]
class TestReturnNoneInIfelse(unittest.TestCase):
def test_dy2static_warning(self):
paddle.disable_static()
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
fun1()
flag = False
for warn in w:
if (
issubclass(warn.category, UserWarning)
) and "Set var to 'None' in ifelse block might lead to error." in str(
warn.message):
flag = True
break
self.assertTrue(flag)
def test_cond_warning(self):
paddle.enable_static()
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
a = paddle.to_tensor(1)
b = paddle.to_tensor(2)
cond(a < b, true_fn, false_fn, return_names=['ret1', 'ret2'])
flag = False
for warn in w:
if (
issubclass(warn.category, UserWarning)
) and "Set var to 'None' in ifelse block might lead to error." in str(
warn.message):
flag = True
break
self.assertTrue(flag)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册