未验证 提交 9881738e 编写于 作者: L liym27 提交者: GitHub

[Dynamic-to-Static ErrorMessage] Support dy2stat error message when call...

[Dynamic-to-Static ErrorMessage] Support dy2stat error message when call jit.save and polish error message (#28713)

* Support dy2stat error message when call jit.save;

* Polish dy2stat error message:
  (1) the original dygraph code is marked with (* user code *) ; 
  (2) "In user code:" -> "In transformed code:"
上级 f1074e3b
...@@ -79,6 +79,11 @@ class TraceBackFrame(OriginInfo): ...@@ -79,6 +79,11 @@ class TraceBackFrame(OriginInfo):
self.function_name = function_name self.function_name = function_name
self.source_code = source_code self.source_code = source_code
def formated_message(self):
return ' File "{}", line {}, in {}\n\t{}'.format(
self.location.filepath, self.location.lineno, self.function_name,
self.source_code.lstrip())
class ErrorData(object): class ErrorData(object):
""" """
...@@ -106,7 +111,7 @@ class ErrorData(object): ...@@ -106,7 +111,7 @@ class ErrorData(object):
message_lines = [] message_lines = []
# Step1: Adds header message to prompt users that the following is the original information. # Step1: Adds header message to prompt users that the following is the original information.
header_message = "In user code:" header_message = "In transformed code:"
message_lines.append(header_message) message_lines.append(header_message)
message_lines.append("") message_lines.append("")
......
...@@ -69,9 +69,10 @@ class OriginInfo(object): ...@@ -69,9 +69,10 @@ class OriginInfo(object):
self.location, self.source_code, self.function_name) self.location, self.source_code, self.function_name)
def formated_message(self): def formated_message(self):
return ' File "{}", line {}, in {}\n\t{}'.format( flag_for_origin_info = "(* user code *)"
return ' File "{}", line {}, in {} {}\n\t{}'.format(
self.location.filepath, self.location.lineno, self.function_name, self.location.filepath, self.location.lineno, self.function_name,
self.source_code.lstrip()) flag_for_origin_info, self.source_code.lstrip())
def as_frame(self): def as_frame(self):
return (self.location.filepath, self.location.lineno, return (self.location.filepath, self.location.lineno,
......
...@@ -609,6 +609,9 @@ class ConcreteProgram(object): ...@@ -609,6 +609,9 @@ class ConcreteProgram(object):
except BaseException as e: except BaseException as e:
# NOTE: If e is raised in compile time, e should be attached to ERROR_DATA here. # NOTE: If e is raised in compile time, e should be attached to ERROR_DATA here.
error.attach_error_data(e) error.attach_error_data(e)
error_data = getattr(e, error.ERROR_DATA, None)
if error_data:
error_data.raise_new_exception()
raise raise
if outputs is not None: if outputs is not None:
......
...@@ -48,83 +48,174 @@ def func_error_in_compile_time_2(x): ...@@ -48,83 +48,174 @@ def func_error_in_compile_time_2(x):
@paddle.jit.to_static @paddle.jit.to_static
def func_error_in_runtime(x, iter_num=3): def func_error_in_runtime(x):
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
two = fluid.layers.fill_constant(shape=[1], value=2, dtype="int32") two = fluid.layers.fill_constant(shape=[1], value=2, dtype="int32")
x = fluid.layers.reshape(x, shape=[1, two]) x = fluid.layers.reshape(x, shape=[1, two])
return x return x
class TestErrorInCompileTime(unittest.TestCase): @unwrap
@paddle.jit.to_static()
def func_decorated_by_other_1():
return 1
@paddle.jit.to_static()
@unwrap
def func_decorated_by_other_2():
return 1
class LayerErrorInCompiletime(fluid.dygraph.Layer):
def __init__(self, fc_size=20):
super(LayerErrorInCompiletime, self).__init__()
self._linear = fluid.dygraph.Linear(fc_size, fc_size)
@paddle.jit.to_static(
input_spec=[paddle.static.InputSpec(
shape=[20, 20], dtype='float32')])
def forward(self, x):
y = self._linear(x)
z = fluid.layers.fill_constant(shape=[1, 2], value=9, dtype="int")
out = fluid.layers.mean(y[z])
return out
class TestFlags(unittest.TestCase):
def setUp(self):
self.reset_flags_to_default()
def reset_flags_to_default(self):
# Reset flags to use defaut value
# 1. A flag to set whether to open the dygraph2static error reporting module
os.environ[error.DISABLE_ERROR_ENV_NAME] = str(
error.DEFAULT_DISABLE_NEW_ERROR)
disable_error = int(os.getenv(error.DISABLE_ERROR_ENV_NAME, 999))
self.assertEqual(disable_error, 0)
# 2. A flag to set whether to display the simplified error stack
os.environ[error.SIMPLIFY_ERROR_ENV_NAME] = str(
error.DEFAULT_SIMPLIFY_NEW_ERROR)
simplify_error = int(os.getenv(error.SIMPLIFY_ERROR_ENV_NAME, 999))
self.assertEqual(simplify_error, 1)
def _test_set_flag(self, flag_name, set_value):
os.environ[flag_name] = str(set_value)
new_value = int(os.getenv(error.DISABLE_ERROR_ENV_NAME, 999))
self.assertEqual(new_value, set_value)
def test_translator_disable_new_error(self):
self._test_set_flag(error.DISABLE_ERROR_ENV_NAME, 1)
def test_translator_simplify_new_error(self):
self._test_set_flag(error.SIMPLIFY_ERROR_ENV_NAME, 0)
class TestErrorBase(unittest.TestCase):
def setUp(self): def setUp(self):
self.set_func()
self.set_input() self.set_input()
self.set_func()
self.set_func_call()
self.filepath = inspect.getfile(unwrap(self.func_call))
self.set_exception_type() self.set_exception_type()
self.set_message()
self.prog_trans = paddle.jit.ProgramTranslator() self.prog_trans = paddle.jit.ProgramTranslator()
self.simplify_error = 1
self.disable_error = 0
def set_func(self):
self.func = func_error_in_compile_time
def set_exception_type(self):
self.exception_type = TypeError
def set_input(self): def set_input(self):
self.input = np.ones([3, 2]) self.input = np.ones([3, 2])
def set_message(self): def set_func(self):
self.expected_message = \ raise NotImplementedError("Error test should implement set_func")
['File "{}", line 35, in func_error_in_compile_time'.format(self.filepath),
'inner_func()',
'File "{}", line 28, in inner_func'.format(self.filepath),
'fluid.layers.fill_constant(shape=[1, 2], value=9, dtype="int")',
]
def _test_create_message(self, error_data): def set_func_call(self):
self.filepath = inspect.getfile(unwrap(self.func)) raise NotImplementedError("Error test should implement set_func_call")
self.set_message()
error_message = error_data.create_message()
self.assertIn('In user code:', error_message) def set_exception_type(self):
for m in self.expected_message: raise NotImplementedError(
self.assertIn(m, error_message) "Error test should implement set_exception_type")
def _test_attach_and_raise_new_exception(self, func_call): def set_message(self):
raise NotImplementedError("Error test should implement set_message")
def reset_flags_to_default(self):
os.environ[error.DISABLE_ERROR_ENV_NAME] = str(
error.DEFAULT_DISABLE_NEW_ERROR)
os.environ[error.SIMPLIFY_ERROR_ENV_NAME] = str(
error.DEFAULT_SIMPLIFY_NEW_ERROR)
def disable_new_error(self):
os.environ[error.DISABLE_ERROR_ENV_NAME] = str(
1 - error.DEFAULT_DISABLE_NEW_ERROR)
def _test_new_error_message(self, new_exception, disable_new_error=0):
error_message = str(new_exception)
if disable_new_error:
# If disable new error, 'In user code:' should not in error_message.
self.assertNotIn('In transformed code:', error_message)
else:
# 1. 'In user code:' must be in error_message because it indicates that
# this is an optimized error message
self.assertIn('In transformed code:', error_message)
# 2. Check whether the converted static graph code is mapped to the dygraph code.
for m in self.expected_message:
self.assertIn(m, error_message)
def _test_raise_new_exception(self, disable_new_error=0):
paddle.disable_static() paddle.disable_static()
with self.assertRaises(self.exception_type) as cm:
func_call()
exception = cm.exception
error_data = getattr(exception, error.ERROR_DATA, None) if disable_new_error:
self.disable_new_error()
else:
self.reset_flags_to_default()
# 1. Check whether the new exception type is the same as the old one
with self.assertRaises(self.exception_type) as new_cm:
self.func_call()
new_exception = new_cm.exception
# 2. Check whether the new_exception is attached ErrorData to indicate that this is a new exception
error_data = getattr(new_exception, error.ERROR_DATA, None)
self.assertIsInstance(error_data, error.ErrorData) self.assertIsInstance(error_data, error.ErrorData)
self._test_create_message(error_data)
def test_static_layer_call(self): # 3. Check whether the error message is optimized
# NOTE: self.func(self.input) is the StaticLayer().__call__(self.input) self._test_new_error_message(new_exception, disable_new_error)
call_dy2static = lambda: self.func(self.input)
self.set_flags(0) # Situation 1: Call StaticLayer.__call__ to use Dynamic-to-Static
self._test_attach_and_raise_new_exception(call_dy2static) class TestErrorStaticLayerCallInCompiletime(TestErrorBase):
def set_func(self):
self.func = func_error_in_compile_time
def set_input(self):
self.input = np.ones([3, 2])
def test_program_translator_get_output(self): def set_exception_type(self):
call_dy2static = lambda : self.prog_trans.get_output(unwrap(self.func), self.input) self.exception_type = TypeError
self.set_flags(0) def set_message(self):
self._test_attach_and_raise_new_exception(call_dy2static) self.expected_message = \
['File "{}", line 35, in func_error_in_compile_time'.format(self.filepath),
'inner_func()',
'File "{}", line 28, in inner_func'.format(self.filepath),
'fluid.layers.fill_constant(shape=[1, 2], value=9, dtype="int")',
]
def set_flags(self, disable_error=0, simplify_error=1): def set_func_call(self):
os.environ[error.DISABLE_ERROR_ENV_NAME] = str(disable_error) # NOTE: self.func(self.input) is the StaticLayer().__call__(self.input)
self.disable_error = int(os.getenv(error.DISABLE_ERROR_ENV_NAME, 0)) self.func_call = lambda: self.func(self.input)
self.assertEqual(self.disable_error, disable_error)
os.environ[error.SIMPLIFY_ERROR_ENV_NAME] = str(simplify_error) def test_error(self):
self.simplify_error = int(os.getenv(error.SIMPLIFY_ERROR_ENV_NAME, 1)) for disable_new_error in [0, 1]:
self.assertEqual(self.simplify_error, simplify_error) self._test_raise_new_exception(disable_new_error)
class TestErrorInCompileTime2(TestErrorInCompileTime): class TestErrorStaticLayerCallInCompiletime_2(
TestErrorStaticLayerCallInCompiletime):
def set_func(self): def set_func(self):
self.func = func_error_in_compile_time_2 self.func = func_error_in_compile_time_2
...@@ -132,7 +223,6 @@ class TestErrorInCompileTime2(TestErrorInCompileTime): ...@@ -132,7 +223,6 @@ class TestErrorInCompileTime2(TestErrorInCompileTime):
self.exception_type = ValueError self.exception_type = ValueError
def set_message(self): def set_message(self):
self.expected_message = \ self.expected_message = \
[ [
'File "{}", line 46, in func_error_in_compile_time_2'.format(self.filepath), 'File "{}", line 46, in func_error_in_compile_time_2'.format(self.filepath),
...@@ -140,7 +230,7 @@ class TestErrorInCompileTime2(TestErrorInCompileTime): ...@@ -140,7 +230,7 @@ class TestErrorInCompileTime2(TestErrorInCompileTime):
] ]
class TestErrorInRuntime(TestErrorInCompileTime): class TestErrorStaticLayerCallInRuntime(TestErrorStaticLayerCallInCompiletime):
def set_func(self): def set_func(self):
self.func = func_error_in_runtime self.func = func_error_in_runtime
...@@ -154,33 +244,50 @@ class TestErrorInRuntime(TestErrorInCompileTime): ...@@ -154,33 +244,50 @@ class TestErrorInRuntime(TestErrorInCompileTime):
'x = fluid.layers.reshape(x, shape=[1, two])' 'x = fluid.layers.reshape(x, shape=[1, two])'
] ]
def _test_create_message(self, error_data):
self.filepath = inspect.getfile(unwrap(self.func))
self.set_message()
with self.assertRaises(ValueError): # Situation 2: Call ProgramTranslator().get_output(...) to use Dynamic-to-Static
error_data.create_message() class TestErrorGetOutputInCompiletime(TestErrorStaticLayerCallInCompiletime):
def set_func_call(self):
self.func_call = lambda : self.prog_trans.get_output(unwrap(self.func), self.input)
error_data.in_runtime = False
error_message = error_data.create_message()
self.assertIn('In user code:', error_message) class TestErrorGetOutputInCompiletime_2(
for m in self.expected_message: TestErrorStaticLayerCallInCompiletime_2):
self.assertIn(m, error_message) def set_func_call(self):
self.func_call = lambda : self.prog_trans.get_output(unwrap(self.func), self.input)
@unwrap class TestErrorGetOutputInRuntime(TestErrorStaticLayerCallInRuntime):
@paddle.jit.to_static() def set_func_call(self):
def func_decorated_by_other_1(): self.func_call = lambda : self.prog_trans.get_output(unwrap(self.func), self.input)
return 1
@paddle.jit.to_static() class TestJitSaveInCompiletime(TestErrorBase):
@unwrap def setUp(self):
def func_decorated_by_other_2(): self.reset_flags_to_default()
return 1 self.set_func_call()
self.filepath = inspect.getfile(unwrap(self.func_call))
self.set_exception_type()
self.set_message()
def set_exception_type(self):
self.exception_type = TypeError
def set_message(self):
self.expected_message = \
['File "{}", line 80, in forward'.format(self.filepath),
'fluid.layers.fill_constant(shape=[1, 2], value=9, dtype="int")',
]
def set_func_call(self):
layer = LayerErrorInCompiletime()
self.func_call = lambda : paddle.jit.save(layer, path="./test_dy2stat_error/model")
def test_error(self):
self._test_raise_new_exception()
# Situation 4: NotImplementedError
class TestErrorInOther(unittest.TestCase): class TestErrorInOther(unittest.TestCase):
def test(self): def test(self):
paddle.disable_static() paddle.disable_static()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册