error.py 5.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import traceback

from paddle.fluid.dygraph.dygraph_to_static.origin_info import Location, OriginInfo, global_origin_info_map

# Name of the attribute under which an ErrorData instance is stored on an
# exception object (used with setattr on the exception).
ERROR_DATA = ("Error data about original source code information"
              " and traceback.")


def attach_error_data(error, in_runtime=False):
    """
    Attaches error data about original source code information and traceback to an error.

    The data is built from ``sys.exc_info()``, so this function is meant to be
    called while the exception is being handled (inside an ``except`` block).
    If no exception is being handled, the type, value and ``__traceback__`` of
    ``error`` itself are used instead.

    Args:
        error(Exception): A native error.
        in_runtime(bool): `error` is raised in runtime if in_runtime is True, otherwise in compile time.

    Returns:
        The same `error` object, with an ``ErrorData`` instance stored under the
        attribute named by ``ERROR_DATA``.
    """
    e_type, e_value, e_traceback = sys.exc_info()
    if e_type is None:
        # Called outside an ``except`` block: fall back to the error itself
        # instead of recording ``None`` as the error type.
        e_type, e_value = type(error), error
        e_traceback = getattr(error, "__traceback__", None)

    # Skip the outermost traceback entry, i.e. the frame whose ``except``
    # clause is handling the exception.
    tb = traceback.extract_tb(e_traceback)[1:]

    error_data = ErrorData(e_type, e_value, tb, global_origin_info_map)
    error_data.in_runtime = in_runtime

    setattr(error, ERROR_DATA, error_data)

    return error


class TraceBackFrame(OriginInfo):
    """
    One plain traceback frame: its location, the enclosing function's name and
    the source line.

    NOTE(review): ``OriginInfo.__init__`` is not called here; presumably the
    inherited methods only rely on the three attributes set below — confirm.
    """

    def __init__(self, location, function_name, source_code):
        # Assigned left to right, in the same order as the parameters.
        self.location, self.function_name, self.source_code = (
            location, function_name, source_code)


class ErrorData(object):
    """
    Error data attached to an exception which is raised in un-transformed code.

    Args:
        error_type(type): Class of the original exception.
        error_value(Exception): The original exception instance.
        origin_traceback(list): Traceback entries of the original exception; each
            entry is unpacked as ``(filepath, lineno, funcname, code)``.
        origin_info_map(dict): Maps a ``Location.line_location`` to information
            about the original dygraph code; values must provide ``formated_message()``.
    """

    def __init__(self, error_type, error_value, origin_traceback,
                 origin_info_map):
        self.error_type = error_type
        self.error_value = error_value
        self.origin_traceback = origin_traceback
        self.origin_info_map = origin_info_map
        # Callers set this to True when the error was raised at runtime rather
        # than while transforming/compiling the code.
        self.in_runtime = False

    def create_exception(self):
        """
        Creates a new exception of the original type whose message is
        ``create_message()``, with this ErrorData attached under ``ERROR_DATA``.

        NOTE(review): assumes ``error_type`` accepts a single message argument.
        """
        message = self.create_message()
        new_exception = self.error_type(message)
        setattr(new_exception, ERROR_DATA, self)
        return new_exception

    def create_message(self):
        """
        Creates a custom error message which includes the trace stack with source code information of dygraph from the user.

        Returns:
            str: The message, starting with the header line "In user code:".
        """
        message_lines = []

        # Step1: Adds header message to prompt users that the following is the original information.
        header_message = "In user code:"
        message_lines.append(header_message)
        message_lines.append("")

        # Simplify error value to improve readability if error is raised in runtime
        if self.in_runtime:
            self._simplify_error_value()
            message_lines.append(str(self.error_value))
            return '\n'.join(message_lines)

        # Step2: Optimizes stack information with source code information of dygraph from user.
        for filepath, lineno, funcname, code in self.origin_traceback:
            loc = Location(filepath, lineno)

            # TODO(liym27): more information to prompt users that this is the original information.
            # Replaces trace stack information about transformed static code with original dygraph code.
            traceback_frame = self.origin_info_map.get(loc.line_location, None)
            if not traceback_frame:
                traceback_frame = TraceBackFrame(loc, funcname, code)

            message_lines.append(traceback_frame.formated_message())

        # Step3: Adds error message like "TypeError: dtype must be int32, but received float32".
        # format_exception_only() can return several lines (for SyntaxError the
        # "SyntaxError: ..." line comes last), so keep all of them, not only the first.
        exception_lines = traceback.format_exception_only(self.error_type,
                                                          self.error_value)
        error_message = " " * 4 + "".join(exception_lines).strip("\n")
        message_lines.append(error_message)

        return '\n'.join(message_lines)

    def _simplify_error_value(self):
        """
        Simplifies error value to improve readability if error is raised in runtime.

        Keeps only the lines after the first line equal (ignoring leading spaces)
        to "outputs = static_func(*inputs)". If that marker is absent (e.g. the
        message format changed, or the value was already simplified by an
        earlier call), the error value is left unchanged.

        NOTE(liym27): The op callstack information about transformed static code has been replaced with original dygraph code.

        TODO(liym27):
            1. Need a more robust way because the code of start_trace may change.
            2. Set the switch to determine whether to simplify error_value
        """
        assert self.in_runtime is True

        error_value_lines = str(self.error_value).split("\n")
        error_value_lines_strip = [mes.lstrip(" ") for mes in error_value_lines]

        start_trace = "outputs = static_func(*inputs)"
        try:
            start_idx = error_value_lines_strip.index(start_trace)
        except ValueError:
            # Never let message formatting fail while reporting the real error.
            return
        error_value_lines = error_value_lines[start_idx + 1:]

        error_value_str = '\n'.join(error_value_lines)
        self.error_value = self.error_type(error_value_str)