未验证 提交 213427e4 编写于 作者: F feifei-111 提交者: GitHub

[error msg] update error msg of user call numpy api and set_state_dict in dy2static (#46128)

* add numpy api err msg

* fix bug

* fix unittest

* add set_state_dict err

* rewrite numpy_api_check

* add define

* change err msg

* fix test

* move import statement
上级 4ae37aee
......@@ -18,8 +18,10 @@ import sys
import traceback
import linecache
import re
import numpy as np
from paddle.fluid.dygraph.dygraph_to_static.origin_info import Location, OriginInfo, global_origin_info_map
from paddle.fluid.dygraph.dygraph_to_static.utils import _is_api_in_module_helper, RE_PYMODULE
ERROR_DATA = "Error data about original source code information and traceback."
......@@ -66,6 +68,7 @@ class TraceBackFrame(OriginInfo):
self.location = location
self.function_name = function_name
self.source_code = source_code
self.error_line = ''
def formated_message(self):
# self.source_code may be empty in some functions.
......@@ -85,6 +88,7 @@ class TraceBackFrameRange(OriginInfo):
self.location = location
self.function_name = function_name
self.source_code = []
self.error_line = ''
blank_count = []
begin_lineno = max(1, self.location.lineno - int(SOURCE_CODE_RANGE / 2))
......@@ -98,6 +102,7 @@ class TraceBackFrameRange(OriginInfo):
blank_count.append(len(line) - len(line_lstrip))
if i == self.location.lineno:
self.error_line = self.source_code[-1]
hint_msg = '~' * len(self.source_code[-1]) + ' <--- HERE'
self.source_code.append(hint_msg)
blank_count.append(blank_count[-1])
......@@ -170,6 +175,36 @@ class ErrorData(object):
setattr(new_exception, ERROR_DATA, self)
return new_exception
def numpy_api_check(self, format_exception, error_line):
if self.error_type is not TypeError:
return format_exception
tb = self.origin_traceback
func_str = None
for frame in tb:
searched_name = re.search(
r'({module})*{name}'.format(module=RE_PYMODULE,
name=frame.name), error_line)
if searched_name:
func_str = searched_name.group(0)
break
try:
module_result = eval("_is_api_in_module_helper({}, '{}')".format(
func_str, "numpy"))
is_numpy_api_err = module_result or (func_str.startswith("numpy.")
or func_str.startswith("np."))
except Exception:
is_numpy_api_err = False
if is_numpy_api_err and func_str:
return [
"TypeError: Code '{}' called numpy API {}, please use Paddle API to replace it."
.format(error_line, func_str),
" values will be changed to variables by dy2static, numpy api can not handle variables"
]
else:
return format_exception
def create_message(self):
"""
Creates a custom error message which includes trace stack with source code information of dygraph from user.
......@@ -180,6 +215,7 @@ class ErrorData(object):
header_message = "In transformed code:"
message_lines.append(header_message)
message_lines.append("")
error_line = None
# Simplify error value to improve readability if error is raised in runtime
if self.in_runtime:
......@@ -213,6 +249,7 @@ class ErrorData(object):
dygraph_func_info.source_code)
message_lines.append(traceback_frame.formated_message())
error_line = traceback_frame.error_line
message_lines.append("")
# Add paddle traceback after user code traceback
......@@ -230,6 +267,10 @@ class ErrorData(object):
# is gather than 1, for example, the error_type is IndentationError.
format_exception = traceback.format_exception_only(
self.error_type, self.error_value)
if error_line is not None:
format_exception = self.numpy_api_check(format_exception,
error_line)
error_message = [
" " * BLANK_COUNT_BEFORE_FILE_STR + line
for line in format_exception
......
......@@ -324,8 +324,8 @@ def is_numpy_api(node):
func_str, "numpy"))
# BUG: np.random.uniform doesn't have module and cannot be analyzed
# TODO: find a better way
if not module_result:
return func_str.startswith("numpy.") or func_str.startswith("np.")
return module_result or (func_str.startswith("numpy.")
or func_str.startswith("np."))
except Exception:
return False
......
......@@ -1538,13 +1538,18 @@ class Layer(object):
place = core.CUDAPlace(p.gpu_device_id())
t.set(ndarray, place)
try:
executor = Executor(_get_device())._default_executor
# restore parameter states
core._create_loaded_parameter(
[param for param, state in matched_param_state], global_scope(),
executor)
[param for param, state in matched_param_state],
global_scope(), executor)
for param, state in matched_param_state:
_set_var(param, state)
except ValueError as e:
raise ValueError(
"This error might happens in dy2static, while calling 'set_state_dict' dynamicly in 'forward', which is not supported. If you only need call 'set_state_dict' once, move it to '__init__'."
)
def to(self, device=None, dtype=None, blocking=None):
'''
......
......@@ -448,5 +448,63 @@ class TestKeyError(unittest.TestCase):
x = paddle.to_tensor([1])
func_ker_error(x)
@paddle.jit.to_static
def NpApiErr():
a = paddle.to_tensor([1,2])
b = np.sum(a.numpy())
print(b)
class TestNumpyApiErr(unittest.TestCase):
def test_numpy_api_err(self):
with self.assertRaises(TypeError) as e:
NpApiErr()
new_exception = e.exception
error_data = getattr(new_exception, error.ERROR_DATA, None)
self.assertIsInstance(error_data, error.ErrorData)
error_message = str(new_exception)
self.assertIn("values will be changed to variables by dy2static, numpy api can not handle variables", error_message)
class test_set_state_dict_err_layer(paddle.nn.Layer):
def __init__(self):
super(test_set_state_dict_err_layer, self).__init__()
self.linear = paddle.nn.Linear(5, 2)
@paddle.jit.to_static
def forward(self, x):
old_dict = self.state_dict()
wgt = old_dict['linear.weight']
drop_w = paddle.nn.functional.dropout(wgt)
old_dict['linear.weight'] = drop_w
# old_dict['linear.weight'][0][0] = 0.01
self.set_state_dict(old_dict)
y = self.linear(x)
return y
class TestSetStateDictErr(unittest.TestCase):
def test_set_state_dict_err(self):
with self.assertRaises(ValueError) as e:
layer = test_set_state_dict_err_layer()
x = paddle.to_tensor([1.,2.,3.,4.,5.])
y = layer(x)
new_exception = e.exception
error_data = getattr(new_exception, error.ERROR_DATA, None)
self.assertIsInstance(error_data, error.ErrorData)
error_message = str(new_exception)
self.assertIn("This error might happens in dy2static, while calling 'set_state_dict' dynamicly in 'forward', which is not supported. If you only need call 'set_state_dict' once, move it to '__init__'.", error_message)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册