From 213427e47a7fa5551864f461cff8c680d82174fe Mon Sep 17 00:00:00 2001 From: feifei-111 Date: Thu, 22 Sep 2022 16:42:54 +0800 Subject: [PATCH] [error msg] update error msg of user call numpy api and set_state_dict in dy2static (#46128) * add numpy api err msg * fix bug * fix unittest * add set_state_dict err * rewrite numpy_api_check * add define * change err msg * fix test * move import statement --- .../fluid/dygraph/dygraph_to_static/error.py | 41 +++++++++++++ .../fluid/dygraph/dygraph_to_static/utils.py | 4 +- python/paddle/fluid/dygraph/layers.py | 19 +++--- .../unittests/dygraph_to_static/test_error.py | 58 +++++++++++++++++++ 4 files changed, 113 insertions(+), 9 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/error.py b/python/paddle/fluid/dygraph/dygraph_to_static/error.py index 93670758dae..3ff66bd2ee3 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/error.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/error.py @@ -18,8 +18,10 @@ import sys import traceback import linecache import re +import numpy as np from paddle.fluid.dygraph.dygraph_to_static.origin_info import Location, OriginInfo, global_origin_info_map +from paddle.fluid.dygraph.dygraph_to_static.utils import _is_api_in_module_helper, RE_PYMODULE ERROR_DATA = "Error data about original source code information and traceback." @@ -66,6 +68,7 @@ class TraceBackFrame(OriginInfo): self.location = location self.function_name = function_name self.source_code = source_code + self.error_line = '' def formated_message(self): # self.source_code may be empty in some functions. @@ -85,6 +88,7 @@ class TraceBackFrameRange(OriginInfo): self.location = location self.function_name = function_name self.source_code = [] + self.error_line = '' blank_count = [] begin_lineno = max(1, self.location.lineno - int(SOURCE_CODE_RANGE / 2)) @@ -98,6 +102,7 @@ class TraceBackFrameRange(OriginInfo): blank_count.append(len(line) - len(line_lstrip)) if i == self.location.lineno: + self.error_line = self.source_code[-1] hint_msg = '~' * len(self.source_code[-1]) + ' <--- HERE' self.source_code.append(hint_msg) blank_count.append(blank_count[-1]) @@ -170,6 +175,36 @@ class ErrorData(object): setattr(new_exception, ERROR_DATA, self) return new_exception + def numpy_api_check(self, format_exception, error_line): + if self.error_type is not TypeError: + return format_exception + + tb = self.origin_traceback + func_str = None + for frame in tb: + searched_name = re.search( + r'({module})*{name}'.format(module=RE_PYMODULE, + name=frame.name), error_line) + if searched_name: + func_str = searched_name.group(0) + break + try: + module_result = eval("_is_api_in_module_helper({}, '{}')".format( + func_str, "numpy")) + is_numpy_api_err = module_result or (func_str.startswith("numpy.") + or func_str.startswith("np.")) + except Exception: + is_numpy_api_err = False + + if is_numpy_api_err and func_str: + return [ + "TypeError: Code '{}' called numpy API {}, please use Paddle API to replace it." + .format(error_line, func_str), + " values will be changed to variables by dy2static, numpy api can not handle variables" + ] + else: + return format_exception + def create_message(self): """ Creates a custom error message which includes trace stack with source code information of dygraph from user. @@ -180,6 +215,7 @@ class ErrorData(object): header_message = "In transformed code:" message_lines.append(header_message) message_lines.append("") + error_line = None # Simplify error value to improve readability if error is raised in runtime if self.in_runtime: @@ -213,6 +249,7 @@ class ErrorData(object): dygraph_func_info.source_code) message_lines.append(traceback_frame.formated_message()) + error_line = traceback_frame.error_line message_lines.append("") # Add paddle traceback after user code traceback @@ -230,6 +267,10 @@ class ErrorData(object): # is gather than 1, for example, the error_type is IndentationError. format_exception = traceback.format_exception_only( self.error_type, self.error_value) + if error_line is not None: + format_exception = self.numpy_api_check(format_exception, + error_line) + error_message = [ " " * BLANK_COUNT_BEFORE_FILE_STR + line for line in format_exception diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index 05938aa4b0f..531f9724bb4 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -324,8 +324,8 @@ def is_numpy_api(node): func_str, "numpy")) # BUG: np.random.uniform doesn't have module and cannot be analyzed # TODO: find a better way - if not module_result: - return func_str.startswith("numpy.") or func_str.startswith("np.") + return module_result or (func_str.startswith("numpy.") + or func_str.startswith("np.")) except Exception: return False diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index 9f36f1cd37e..9b3ebdd7b9a 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -1538,13 +1538,18 @@ class Layer(object): place = core.CUDAPlace(p.gpu_device_id()) t.set(ndarray, place) - executor = Executor(_get_device())._default_executor - # restore parameter states - core._create_loaded_parameter( - [param for param, state in matched_param_state], global_scope(), - executor) - for param, state in matched_param_state: - _set_var(param, state) + try: + executor = Executor(_get_device())._default_executor + # restore parameter states + core._create_loaded_parameter( + [param for param, state in matched_param_state], + global_scope(), executor) + for param, state in matched_param_state: + _set_var(param, state) + except ValueError as e: + raise ValueError( + "This error might happens in dy2static, while calling 'set_state_dict' dynamicly in 'forward', which is not supported. If you only need call 'set_state_dict' once, move it to '__init__'." + ) def to(self, device=None, dtype=None, blocking=None): ''' diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py index 97f0cf99b5f..c0ca1c1af02 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py @@ -448,5 +448,63 @@ class TestKeyError(unittest.TestCase): x = paddle.to_tensor([1]) func_ker_error(x) + +@paddle.jit.to_static +def NpApiErr(): + a = paddle.to_tensor([1,2]) + b = np.sum(a.numpy()) + print(b) + +class TestNumpyApiErr(unittest.TestCase): + def test_numpy_api_err(self): + with self.assertRaises(TypeError) as e: + NpApiErr() + + new_exception = e.exception + + error_data = getattr(new_exception, error.ERROR_DATA, None) + self.assertIsInstance(error_data, error.ErrorData) + + error_message = str(new_exception) + + self.assertIn("values will be changed to variables by dy2static, numpy api can not handle variables", error_message) + + +class test_set_state_dict_err_layer(paddle.nn.Layer): + def __init__(self): + super(test_set_state_dict_err_layer, self).__init__() + self.linear = paddle.nn.Linear(5, 2) + + @paddle.jit.to_static + def forward(self, x): + old_dict = self.state_dict() + wgt = old_dict['linear.weight'] + drop_w = paddle.nn.functional.dropout(wgt) + old_dict['linear.weight'] = drop_w + # old_dict['linear.weight'][0][0] = 0.01 + self.set_state_dict(old_dict) + + y = self.linear(x) + + return y + + +class TestSetStateDictErr(unittest.TestCase): + def test_set_state_dict_err(self): + with self.assertRaises(ValueError) as e: + layer = test_set_state_dict_err_layer() + x = paddle.to_tensor([1.,2.,3.,4.,5.]) + y = layer(x) + + new_exception = e.exception + + error_data = getattr(new_exception, error.ERROR_DATA, None) + self.assertIsInstance(error_data, error.ErrorData) + + error_message = str(new_exception) + + self.assertIn("This error might happens in dy2static, while calling 'set_state_dict' dynamicly in 'forward', which is not supported. If you only need call 'set_state_dict' once, move it to '__init__'.", error_message) + + if __name__ == '__main__': unittest.main() -- GitLab