# pretty_printers.py -- GDB pretty printers for MegEngine types.
import sys
M
Megvii Engine Team 已提交
2

M
Megvii Engine Team 已提交
3 4 5 6 7
import gdb
import gdb.printing
import gdb.types


8 9 10 11 12 13
def dynamic_cast(val):
    """Cast a reference value to its dynamic (run-time) type.

    ``val`` is expected to be a ``gdb.Value`` whose type code is
    ``gdb.TYPE_CODE_REF``; the assertion rejects anything else.  The
    result is ``val`` cast to ``val.dynamic_type``.
    """
    is_reference = val.type.code == gdb.TYPE_CODE_REF
    assert is_reference
    return val.cast(val.dynamic_type)


M
Megvii Engine Team 已提交
14
def eval_on_val(val, eval_str):
    """Evaluate an expression suffix on ``val`` inside the debugged program.

    ``val`` is a ``gdb.Value``; if it is a reference it is first replaced
    by the value it refers to.  ``eval_str`` is the text appended to the
    object expression, e.g. ``".to_string().c_str()"``.

    The expression handed to ``gdb.parse_and_eval`` has the form
    ``(*(<pointer type>)<address>)<eval_str>``, i.e. the object is
    re-materialised from its address so that member calls can be made on
    it.  Returns whatever ``gdb.parse_and_eval`` returns.
    """
    if val.type.code == gdb.TYPE_CODE_REF:
        val = val.referenced_value()
    # NOTE(review): presumably ``val.address`` is None for values that are
    # not addressable, in which case the next line fails -- confirm.
    address = val.address
    expression = "(*({}){}){}".format(address.type, int(address), eval_str)
    return gdb.parse_and_eval(expression)


class SmallVectorPrinter:
    """Pretty printer showing a SmallVector as an array of its elements.

    Registered below for ``megdnn::SmallVector<...>`` and
    ``megdnn::SmallVectorImpl<...>``.
    """

    def __init__(self, val):
        # The element type is the first template argument; the stored
        # begin/end/capacity pointers are cast to pointers to that type so
        # pointer arithmetic counts elements.
        elem_ptr_type = val.type.template_argument(0).pointer()
        self.begin = val["m_begin_ptr"].cast(elem_ptr_type)
        self.end = val["m_end_ptr"].cast(elem_ptr_type)
        self.size = self.end - self.begin
        self.capacity = val["m_capacity_ptr"].cast(elem_ptr_type) - self.begin

    def to_string(self):
        """Return the one-line summary shown before the elements."""
        return "SmallVector of Size {}".format(self.size)

    def display_hint(self):
        """Tell gdb to lay the children out as array elements."""
        return "array"

    def children(self):
        """Yield ``("[i]", element)`` for every element in [begin, end)."""
        # range() needs a plain integer; self.size may be a gdb.Value.
        for i in range(int(self.size)):
            yield "[{}]".format(i), (self.begin + i).dereference()
M
Megvii Engine Team 已提交
41 42 43 44


class MaybePrinter:
    """Pretty printer for ``mgb::Maybe<...>`` based on its ``m_ptr`` member.

    A falsy ``m_ptr`` is shown as ``None``; otherwise the pointer is shown
    as ``Some <ptr>`` with the pointee as the single child ``[0]``.
    """

    def __init__(self, val):
        self.val = val["m_ptr"]

    def to_string(self):
        """Return ``"Some <ptr>"`` when a value is held, else ``"None"``."""
        if not self.val:
            return "None"
        return "Some {}".format(self.val)

    def display_hint(self):
        """Tell gdb to lay the children out as array elements."""
        return "array"

    def children(self):
        """Yield the pointee as ``"[0]"`` when ``m_ptr`` is truthy."""
        if self.val:
            yield "[0]", self.val.dereference()
M
Megvii Engine Team 已提交
59 60 61 62 63 64 65


class ToStringPrinter:
    """Pretty printer that shows the result of the object's ``to_string()``.

    The call ``.to_string().c_str()`` is evaluated in the debugged program
    through ``eval_on_val``.
    """

    def __init__(self, val):
        self.val = val

    def to_string(self):
        """Return the inferior's ``to_string()`` output as a Python string."""
        return eval_on_val(self.val, ".to_string().c_str()").string()
M
Megvii Engine Team 已提交
67 68 69 70 71 72 73


class ReprPrinter:
    """Pretty printer that shows the result of the object's ``repr()``.

    The call ``.repr().c_str()`` is evaluated in the debugged program
    through ``eval_on_val``.
    """

    def __init__(self, val):
        self.val = val

    def to_string(self):
        """Return the inferior's ``repr()`` output as a Python string."""
        # eval_on_val already replaces a reference by the value it refers
        # to, so no separate unwrapping is needed here.
        return eval_on_val(self.val, ".repr().c_str()").string()
M
Megvii Engine Team 已提交
78 79 80 81 82 83 84 85 86


class HandlePrinter:
    """Pretty printer that views a handle pointer as a ``TensorInfo`` pointer.

    The incoming value is cast to a pointer to
    ``mgb::imperative::interpreter::intl::TensorInfo``; a falsy pointer is
    shown as an empty handle, otherwise the pointee is the single child.
    """

    def __init__(self, val):
        impl = gdb.lookup_type("mgb::imperative::interpreter::intl::TensorInfo")
        self.val = val.cast(impl.pointer())

    def to_string(self):
        """Return a summary naming the pointer, or ``"Empty Handle"``."""
        if not self.val:
            return "Empty Handle"
        return "Handle of TensorInfo at {}".format(self.val)

    def display_hint(self):
        """Tell gdb to lay the children out as array elements."""
        return "array"

    def children(self):
        """Yield the pointed-to ``TensorInfo`` as ``"[0]"`` when non-null."""
        if self.val:
            yield "[0]", self.val.dereference()
M
Megvii Engine Team 已提交
97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129


def print_small_tensor(device_nd):
    """Format the elements of a small one-dimensional tensor as a list string.

    ``device_nd`` is indexed like a ``gdb.Value``: the function reads
    ``m_storage.m_size``, ``m_storage.m_data._M_ptr`` and the ``ndim``,
    ``shape``, ``stride`` and ``dtype`` members of ``m_layout``.

    Returns one of the placeholder strings ``"<empty>"``, ``"<ndim > 1>"``,
    ``"<size too large>"`` (more than 64 elements along dimension 0) or
    ``"<dtype unsupported>"`` (anything but ``Float32`` / ``Int32``);
    otherwise ``str()`` of a Python list holding the ``shape[0]`` elements
    read at steps of ``stride[0]`` from the data pointer.
    """
    storage = device_nd["m_storage"]
    layout = device_nd["m_layout"]
    if storage["m_size"] == 0:
        return "<empty>"
    if layout["ndim"] > 1:
        return "<ndim > 1>"
    dim0 = layout["shape"][0]
    if dim0 > 64:
        return "<size too large>"
    raw_ptr = storage["m_data"]["_M_ptr"]
    dtype_name = layout["dtype"]["m_trait"]["name"].string()
    # dtype name -> (C type name used for the element pointer, Python type
    # used to convert each element).
    supported = {
        "Float32": ("float", float),
        "Int32": ("int", int),
    }
    if dtype_name not in supported:
        return "<dtype unsupported>"
    ctype_name, pytype = supported[dtype_name]
    # Only the gdb type that is actually needed is looked up.
    ptr = raw_ptr.cast(gdb.lookup_type(ctype_name).pointer())
    stride0 = int(layout["stride"][0])
    elements = [
        pytype((ptr + i * stride0).dereference()) for i in range(int(dim0))
    ]
    return str(elements)


class LogicalTensorDescPrinter:
    """Pretty printer for ``mgb::imperative::LogicalTensorDesc``.

    Shows the ``layout`` and ``comp_node`` members as they are, and the
    ``value`` member as the string produced by ``print_small_tensor``.
    """

    def __init__(self, val):
        self.layout = val["layout"]
        self.comp_node = val["comp_node"]
        self.value = val["value"]

    def to_string(self):
        """Return the fixed heading shown before the children."""
        return "LogicalTensorDesc"

    def children(self):
        """Yield the three members as ``(name, value)`` pairs."""
        yield "layout", self.layout
        yield "comp_node", self.comp_node
        yield "value", print_small_tensor(self.value)
M
Megvii Engine Team 已提交
141 142 143 144 145 146 147 148 149 150


class OpDefPrinter:
    """Pretty printer that shows an ``OpDef`` through its dynamic type.

    The heading is the name of the value's dynamic type and the children
    are the data members of that dynamic type.
    """

    def __init__(self, val):
        self.val = val

    def to_string(self):
        """Return the name of the value's dynamic type."""
        return self.val.dynamic_type.name

    def children(self):
        """Yield ``(field name, field value)`` for the concrete object.

        Base-class sub-objects, artificial fields and the ``sm_typeinfo``
        member are skipped.
        """
        # Re-interpret the object through a pointer to its dynamic type so
        # the fields of the concrete type become visible.
        concrete_ptr_type = self.val.dynamic_type.pointer()
        concrete_val = self.val.address.cast(concrete_ptr_type).dereference()
        for field in concrete_val.type.fields():
            if (
                field.is_base_class
                or field.artificial
                or field.name == "sm_typeinfo"
            ):
                continue
            yield field.name, concrete_val[field.name]


162 163
class SpanPrinter:
    """Pretty printer showing a ``mgb::imperative::Span<...>`` as an array.

    The element count is ``m_end - m_begin`` and the children are the
    elements reached by offsetting ``m_begin``.
    """

    def __init__(self, val):
        self.begin = val["m_begin"]
        self.end = val["m_end"]
        self.size = self.end - self.begin

    def to_string(self):
        """Return the one-line summary shown before the elements."""
        return "Span of Size {}".format(self.size)

    def display_hint(self):
        """Tell gdb to lay the children out as array elements."""
        return "array"

    def children(self):
        """Yield ``("[i]", element)`` for every element in [begin, end)."""
        # range() needs a plain integer; self.size may be a gdb.Value.
        for i in range(int(self.size)):
            yield "[{}]".format(i), (self.begin + i).dereference()
177 178


179 180
# Build the "MegEngine" regexp-based printer collection and register it for
# the current objfile.  Under Python 2 nothing is registered and a notice is
# printed instead.
if sys.version_info.major > 2:
    pp = gdb.printing.RegexpCollectionPrettyPrinter("MegEngine")
    # megdnn
    pp.add_printer(
        "megdnn::SmallVectorImpl",
        "^megdnn::SmallVector(Impl)?<.*>$",
        SmallVectorPrinter,
    )
    pp.add_printer("megdnn::TensorLayout", "^megdnn::TensorLayout$", ToStringPrinter)
    pp.add_printer("megdnn::TensorShape", "^megdnn::TensorShape$", ToStringPrinter)
    # megbrain
    pp.add_printer("mgb::CompNode", "^mgb::CompNode$", ToStringPrinter)
    pp.add_printer("mgb::Maybe", "^mgb::Maybe<.*>$", MaybePrinter)
    # imperative
    pp.add_printer(
        "mgb::imperative::LogicalTensorDesc",
        "^mgb::imperative::LogicalTensorDesc$",
        LogicalTensorDescPrinter,
    )
    pp.add_printer("mgb::imperative::OpDef", "^mgb::imperative::OpDef$", OpDefPrinter)
    pp.add_printer(
        "mgb::imperative::Subgraph", "^mgb::imperative::Subgraph$", ReprPrinter
    )
    pp.add_printer(
        "mgb::imperative::EncodedSubgraph",
        "^mgb::imperative::EncodedSubgraph$",
        ReprPrinter,
    )
    # imperative dispatch
    pp.add_printer(
        "mgb::imperative::ValueRef", "^mgb::imperative::ValueRef$", ToStringPrinter
    )
    pp.add_printer("mgb::imperative::Span", "^mgb::imperative::Span<.*>$", SpanPrinter)
    gdb.printing.register_pretty_printer(gdb.current_objfile(), pp)
else:
    print("skip import pretty printers")
M
Megvii Engine Team 已提交
215 216 217 218 219 220 221 222 223 224 225 226 227 228


def override_pretty_printer_for(val):
    """Pick a printer for pointers that should be shown through their pointee.

    After stripping typedefs, a non-null pointer to
    ``mgb::imperative::OpDef`` gets an ``OpDefPrinter`` for the pointee,
    and a non-null pointer to
    ``mgb::imperative::interpreter::Interpreter::HandleImpl`` gets a
    ``HandlePrinter`` for the pointer itself.  Every other value yields
    ``None``.
    """
    stripped = val.type.strip_typedefs()
    # Only non-null pointers are overridden; `not val` is evaluated only
    # once the value is known to be a pointer.
    if stripped.code != gdb.TYPE_CODE_PTR or not val:
        return None
    pointee_name = str(stripped.target().strip_typedefs())
    if pointee_name == "mgb::imperative::OpDef":
        return OpDefPrinter(val.dereference())
    if pointee_name == "mgb::imperative::interpreter::Interpreter::HandleImpl":
        return HandlePrinter(val)
    return None


229 230
# Install the pointer-override lookup function in gdb's global printer list
# (skipped under Python 2, like the registration above).
if sys.version_info[0] >= 3:
    gdb.pretty_printers.append(override_pretty_printer_for)