未验证 提交 8f762790 编写于 作者: H Huihuang Zheng 提交者: GitHub

[Dy2stat] Reduce Exception Type for Better Error Message (#29268)

Reduce exception type so that if covert_to_static failed, it reports right error message.
上级 61a8f287
......@@ -226,7 +226,7 @@ def convert_call(func):
# So descriptor mechanism is used to bound `self` instance on function to
# keep it as bound method.
setattr(func, 'forward', forward_func.__get__(func))
except Exception:
except (IOError, OSError, TypeError):
# NOTE: func.forward may have been decorated.
func_self = None if func_self else func_self
converted_call = func
......@@ -235,7 +235,7 @@ def convert_call(func):
call_func = func.__class__.__call__
converted_call = convert_to_static(call_func)
func_self = func
except Exception:
except (IOError, OSError, TypeError):
# NOTE:
# If `func` is a class which is being initialized, for example `convert_call(Foo)()`,
# it doesn't need to be transformed
......
......@@ -17,6 +17,39 @@ import paddle
import unittest
class CallNotExist(paddle.nn.Layer):
def __call__(self):
# call a non-exist API to trigger exception
return paddle.nn.not_exist_api
class ForwardNotExist(paddle.nn.Layer):
def forward(self):
return 0
net = ForwardNotExist()
setattr(net, "forward", "A string so that convert forward will fail")
class TestConvertCall(unittest.TestCase):
def test_class_exception(self):
@paddle.jit.to_static
def call_not_exist():
net = CallNotExist()
return net()
with self.assertRaises(AttributeError):
call_not_exist()
@paddle.jit.to_static
def forward_not_exist():
return net()
with self.assertRaises(TypeError):
forward_not_exist()
class TestConvertShapeCompare(unittest.TestCase):
def test_non_variable(self):
self.assertEqual(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册