未验证 提交 72e1eb6b 编写于 作者: X xiongkun 提交者: GitHub

[CherryPick] Cherry pick #45916 #46031 #47299 (#47610)

* [ Dy2Static ] Fix bugs when select inputs meeting different shape or undefined-var (#45916)

* fix select_input with different shape errors:
1. select_input_with_buildin_type directly return non-undefinedvar branch when meeting undefined var
2. the output shape of select_input is inferred from inputs.

* reverse the logic in select_input

* [warning] added warning message in cond block when one branch returns variable and another returns None (#46031)

* [cherry-pick] Allow manaully set py_reader name in standalone executor (#45898) (#45931)

* Allow manaully set py_reader name in standalone executor

* [BugFix] while cond receives dict as input (#47299)

* fix bugs while cond receives dict as input

* add unittest

* change flatten -> _is_sequence_except_dict

* code format
Co-authored-by: Nfeifei-111 <wuzhanfei@baidu.com>
上级 cfee9c13
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import warnings
from paddle.fluid.dygraph.dygraph_to_static.program_translator import (
convert_to_static,
)
from paddle.fluid.layers.control_flow import cond
@paddle.jit.to_static
def fun1():
a = paddle.to_tensor(1)
b = paddle.to_tensor(2)
if a > b:
b = paddle.to_tensor(3)
else:
b = None
def true_fn():
return [paddle.to_tensor(1), [paddle.to_tensor(2), paddle.to_tensor(3)]]
def false_fn():
return [paddle.to_tensor(3), [None, paddle.to_tensor(4)]]
class TestReturnNoneInIfelse(unittest.TestCase):
def test_dy2static_warning(self):
paddle.disable_static()
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
fun1()
flag = False
for warn in w:
if (
issubclass(warn.category, UserWarning)
) and "Set var to 'None' in ifelse block might lead to error." in str(
warn.message
):
flag = True
break
self.assertTrue(flag)
def test_cond_warning(self):
paddle.enable_static()
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
a = paddle.to_tensor(1)
b = paddle.to_tensor(2)
cond(a < b, true_fn, false_fn, return_names=['ret1', 'ret2'])
flag = False
for warn in w:
if (
issubclass(warn.category, UserWarning)
) and "Set var to 'None' in ifelse block might lead to error." in str(
warn.message
):
flag = True
break
self.assertTrue(flag)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册