graph.py 14.3 KB
Newer Older
Q
qiaolongfei 已提交
1
import json
2
import os
S
superjom 已提交
3

Q
qiaolongfei 已提交
4
from google.protobuf.json_format import MessageToJson
5
from PIL import Image
Q
qiaolongfei 已提交
6

Y
Yan Chunwei 已提交
7
import graphviz_graph as gg
8
import onnx
S
superjom 已提交
9

Q
qiaolongfei 已提交
10

Q
qiaolongfei 已提交
11
def debug_print(json_obj):
    """Pretty-print ``json_obj`` as JSON (sorted keys, 4-space indent) to stdout."""
    text = json.dumps(json_obj, indent=4, separators=(',', ': '), sort_keys=True)
    print(text)
Q
qiaolongfei 已提交
14 15


Q
qiaolongfei 已提交
16 17 18 19 20 21 22
def reorganize_inout(json_obj, key):
    """
    Flatten the graph value-info entries stored under ``json_obj[key]`` in place.

    Each entry (as produced by ``MessageToJson`` on an ONNX graph) is replaced
    by a dict with exactly the keys ``name``, ``data_type`` and ``shape``.

    :param json_obj: the model's json obj; modified in place. ``json_obj[key]``
        is created as an empty list if missing (proto3 JSON omits empty lists).
    :param key: "input" or "output"
    :return: None
    """

    def dim_to_value(dim):
        # A dimension is either concrete (``dimValue``) or symbolic
        # (``dimParam``, e.g. a dynamic batch size); it may also be unset.
        if 'dimValue' in dim:
            return dim['dimValue']
        return dim.get('dimParam', '?')

    variables = json_obj.setdefault(key, [])
    for index, var in enumerate(variables):
        tensor_type = var['type']['tensorType']
        # Scalars have no dims, and proto3 JSON drops empty/default fields,
        # so 'shape', 'dim' and an UNDEFINED 'elemType' may all be absent.
        dims = tensor_type.get('shape', {}).get('dim', [])
        variables[index] = {
            'name': var['name'],
            'data_type': tensor_type.get('elemType', 'UNDEFINED'),
            'shape': [dim_to_value(dim) for dim in dims],
        }
Q
qiaolongfei 已提交
39 40


Q
qiaolongfei 已提交
41 42 43 44 45 46 47 48 49 50
def rename_model(model_json):
    """
    Rewrite names in ``model_json`` in place so they carry display information.

    Each graph input/output is renamed to
    ``"<name>\\ndata_type=<data_type>\\nshape=<shape>"`` and every reference to
    the old name in node input/output lists is updated. Each node is named
    ``"<index>\\n<opType>"``, with its original name appended on a third line
    when it had a non-empty one.

    Expects ``model_json['input']``/``['output']`` in the flattened form built
    by ``reorganize_inout`` (dicts with ``name``, ``data_type``, ``shape``).
    """

    def to_dim(dim):
        # Concrete dims are typically numeric strings; symbolic dims
        # (e.g. 'N') cannot be converted and are kept as-is.
        try:
            return int(dim)
        except (TypeError, ValueError):
            return dim

    def rename_edge(model_json, old_name, new_name):
        for node in model_json['node']:
            # Empty repeated fields are omitted from the JSON, so a node
            # (e.g. Constant) may have no 'input' key at all.
            for field in ('input', 'output'):
                names = node.get(field, [])
                for idx, name in enumerate(names):
                    if name == old_name:
                        names[idx] = new_name

    def rename_variables(model, variables):
        for variable in variables:
            old_name = variable['name']
            new_shape = [to_dim(dim) for dim in variable['shape']]
            new_name = (old_name + '\ndata_type=' + str(variable['data_type']) +
                        '\nshape=' + str(new_shape))
            variable['name'] = new_name
            rename_edge(model, old_name, new_name)

    rename_variables(model_json, model_json['input'])
    rename_variables(model_json, model_json['output'])

    # rename nodes: "<index>\n<opType>[\n<original name>]"
    for idx, node in enumerate(model_json['node']):
        new_name = str(idx) + '\n' + str(node['opType'])
        name = node.get('name', '')
        if name:
            new_name += '\n' + name
        node['name'] = new_name


Q
qiaolongfei 已提交
78
def get_links(model_json):
    """
    Return the data-flow links of the graph as ``{'source', 'target'}`` dicts.

    First come links from each graph input name to every node consuming it,
    then links from each producing node's name to every node consuming one of
    its outputs. Targets appear in node order, and each (source, consumer)
    pair is emitted once per matching tensor name. Nodes must carry a
    ``'name'`` key (presumably set by ``rename_model`` — confirm at callers).
    """
    nodes = model_json['node']

    # Index tensor name -> consuming nodes (node order, no duplicates) so each
    # lookup is O(1) instead of rescanning every node. Nodes without inputs
    # have no 'input' key in the JSON.
    consumers = {}
    for node in nodes:
        for name in node.get('input', []):
            users = consumers.setdefault(name, [])
            if not users or users[-1] is not node:
                users.append(node)

    links = [{'source': var['name'], 'target': node['name']}
             for var in model_json['input']
             for node in consumers.get(var['name'], [])]

    for source_node in nodes:
        for output in source_node.get('output', []):
            for target_node in consumers.get(output, []):
                links.append({
                    'source': source_node['name'],
                    'target': target_node['name']
                })

    return links
Q
qiaolongfei 已提交
97 98


Q
qiaolongfei 已提交
99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123
def get_node_links(model_json):
    """
    Build node-to-node adjacency keyed by the node's integer index.

    :return: a dict like
    {
        0: {"input": [], "output": [1]},
        1: {"input": [0], "output": [2]},
        ...
    }
    where "output" lists consumers of the node's outputs and "input" lists
    producers of its inputs. A pair is listed once per connecting tensor.
    """
    nodes = model_json['node']

    # init all nodes
    node_links = {idx: {'input': list(), 'output': list()}
                  for idx in range(len(nodes))}

    # Index tensor name -> consuming node indices (ascending, no duplicates).
    # Nodes without inputs/outputs omit those keys in the JSON.
    consumers = {}
    for idx, node in enumerate(nodes):
        for name in node.get('input', []):
            users = consumers.setdefault(name, [])
            if not users or users[-1] != idx:
                users.append(idx)

    for src_idx, node in enumerate(nodes):
        for out_name in node.get('output', []):
            for dst_idx in consumers.get(out_name, []):
                node_links[src_idx]['output'].append(dst_idx)
                node_links[dst_idx]['input'].append(src_idx)

    return node_links
Q
qiaolongfei 已提交
134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176


def add_level_to_node_links(node_links):
    """
    Add a ``'level'`` key to every entry of ``node_links`` in place.

    A node with no input nodes gets level 1; any other node gets one more
    than the deepest of its input nodes (longest path from a start node).
    Nodes are visited in ascending index order, so each producer must have a
    smaller index than its consumers (ONNX graphs are topologically sorted).

    :param node_links: dict as returned by ``get_node_links``.
    :return: None. Afterwards entries look like
    {
        0: {"input": [], "output": [1], "level": 1},
        1: {"input": [0], "output": [2], "level": 2}
    }
    :raises ValueError: if a node's input node has no level yet, i.e. the
        nodes are not in topological order.
    """
    # init level
    for links in node_links.values():
        links['level'] = None

    for idx in sorted(node_links):
        links = node_links[idx]
        # the start up op's level is 1
        if not links['input']:
            links['level'] = 1
            continue
        in_levels = [node_links[in_idx]['level'] for in_idx in links['input']]
        if None in in_levels:
            raise ValueError(
                "node %s consumes a node without a level; nodes must be "
                "topologically sorted" % idx)
        # Deepest input decides the level, regardless of input order.
        links['level'] = max(in_levels) + 1


def _group_by_level(item_to_level):
    """Invert ``{item: level}`` into ``{level: [items...]}``, keeping item order."""
    level_to_items = dict()
    for item, level in item_to_level.items():
        level_to_items.setdefault(level, []).append(item)
    return level_to_items


def get_level_to_all(node_links, model_json):
    """
    Group node, graph-input and graph-output indices by level.

    Node levels come from ``node_links[idx]['level']``. A graph input is
    placed one level before the shallowest node consuming it; a graph output
    one level after the node producing it.

    :return: e.g.
    {
        35: {
            "inputs": [38, 39],
            "nodes": [46],
            "outputs": []
        },
        ...
    }
    :raises ValueError: if a graph output is produced by more than one node.
    """
    nodes = model_json['node']

    level_to_nodes = _group_by_level(
        {idx: links['level'] for idx, links in node_links.items()})

    # input_to_level {input idx -> level}; nodes without inputs have no
    # 'input' key in the JSON.
    input_to_level = dict()
    for in_idx, var in enumerate(model_json['input']):
        in_name = var['name']
        for node_idx, node in enumerate(nodes):
            if in_name in node.get('input', []):
                in_level = node_links[node_idx]['level'] - 1
                if in_idx not in input_to_level or input_to_level[in_idx] > in_level:
                    input_to_level[in_idx] = in_level

    # output_to_level {output idx -> level}
    output_to_level = dict()
    for out_idx, var in enumerate(model_json['output']):
        out_name = var['name']
        for node_idx, node in enumerate(nodes):
            if out_name in node.get('output', []):
                # Duplicate detection must key on the output index.
                if out_idx in output_to_level:
                    raise ValueError("output " + out_name +
                                     " has multiple sources")
                output_to_level[out_idx] = node_links[node_idx]['level'] + 1

    # merge all levels (node levels first, then input and output levels)
    level_to_all = dict()
    groups = (('nodes', level_to_nodes),
              ('inputs', _group_by_level(input_to_level)),
              ('outputs', _group_by_level(output_to_level)))
    for field, grouped in groups:
        for level, items in grouped.items():
            entry = level_to_all.setdefault(
                level, {'nodes': list(), 'inputs': list(), 'outputs': list()})
            entry[field] = items

    return level_to_all


Q
qiaolongfei 已提交
272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304
def level_to_coordinate(level_to_all):
    """
    Lay out each level as one row of a grid.

    Row ``level`` sits at ``y = 10 + 100 * level``. Within a row, items are
    placed left to right (nodes, then inputs, then outputs) at
    ``x = 100 + 100 * column``.

    :return: three dicts (nodes, inputs, outputs) mapping index -> {"x", "y"}.
    """
    origin_x, step_x = 100, 100
    origin_y, step_y = 10, 100

    def place(column, row):
        return {"x": int(origin_x + column * step_x),
                "y": int(origin_y + row * step_y)}

    node_pos, input_pos, output_pos = dict(), dict(), dict()

    for level, members in level_to_all.items():
        row = ([(node_pos, i) for i in members['nodes']] +
               [(input_pos, i) for i in members['inputs']] +
               [(output_pos, i) for i in members['outputs']])
        for column, (table, item_idx) in enumerate(row):
            table[item_idx] = place(column, level)

    return node_pos, input_pos, output_pos


305
def add_edges(json_obj):
    """
    Attach an ``'edges'`` list describing data flow around every node.

    For node ``i`` this emits one edge ``<input name> -> 'node_i'`` per input,
    then one edge ``'node_i' -> <output name>`` per output. Labels are
    ``label_0``, ``label_1``, ... in emission order.

    :param json_obj: model dict with a ``'node'`` list; modified in place.
    :return: the same ``json_obj``.
    """
    # TODO(daming-lu): should try to de-duplicate node's out-edge
    # Currently it is counted twice: 1 as out-edge, 1 as in-edge
    edges = []
    json_obj['edges'] = edges

    def append_edge(source, target):
        # Each label is unique because it is derived from the edge count.
        edges.append({
            'source': source,
            'target': target,
            'label': 'label_' + str(len(edges))
        })

    for node_index, cur_node in enumerate(json_obj['node']):
        node_id = 'node_' + str(node_index)

        # input edges (nodes without inputs omit the key in the JSON)
        for source in cur_node.get('input', []):
            append_edge(source, node_id)

        # output edges: link every output, not only the first, so
        # multi-output ops (e.g. Split) keep all their edges.
        for target in cur_node.get('output', []):
            append_edge(node_id, target)

    return json_obj
Q
qiaolongfei 已提交
331

332

Q
qiaolongfei 已提交
333
def to_IR_json(model_pb_path):
    """
    Load an ONNX model file and convert its graph to a JSON-like dict.

    The graph's initializers (constant tensors, e.g. weights) are removed
    before serialization, and the graph inputs/outputs are then flattened
    by ``reorganize_inout``.

    :param model_pb_path: path to the serialized ONNX model.
    :return: the graph as a dict.
    """
    model = onnx.load(model_pb_path)
    # Drop the initializer tensors so only the graph structure is serialized.
    del model.graph.initializer[:]

    model_json = json.loads(MessageToJson(model.graph))
    for key in ('input', 'output'):
        reorganize_inout(model_json, key)
    return model_json


def load_model(model_pb_path):
    """Load an ONNX model as a JSON-like dict and attach its ``'edges'`` list."""
    return add_edges(to_IR_json(model_pb_path))


class GraphPreviewGenerator(object):
    '''
    Generate a graph image for ONNX proto.

    ``model_json`` is expected in the form returned by ``load_model``: a dict
    with ``name``, ``input``/``output`` (flattened by ``reorganize_inout``),
    ``node`` and ``edges`` entries.
    '''

    def __init__(self, model_json):
        self.model = model_json
        # init graphviz graph
        self.graph = gg.Graph(
            self.model['name'],
            layout="dot",
            concentrate="true",
            rankdir="TB", )

        self.op_rank = self.graph.rank_group('same', 2)
        self.param_rank = self.graph.rank_group('same', 1)
        self.arg_rank = self.graph.rank_group('same', 0)

    def __call__(self, path='temp.dot', show=False):
        '''
        Populate the graph from the model and emit it.

        Graph inputs/outputs become "param" nodes, ONNX nodes become "op"
        nodes keyed ``node_<index>``, and any edge endpoint not seen yet
        becomes an "arg" node. Edges touching an arg node are drawn dashed.

        :param path: path passed to ``graph.display`` / ``graph.show``.
        :param show: call ``graph.show`` instead of ``graph.display``.
        '''
        self.nodes = {}
        self.params = set()
        self.ops = set()
        self.args = set()

        for item in self.model['input'] + self.model['output']:
            node = self.add_param(**item)
            self.nodes[item['name']] = node
            self.params.add(item['name'])

        for node_index, item in enumerate(self.model['node']):
            node = self.add_op(**item)
            name = "node_" + str(node_index)
            self.nodes[name] = node
            self.ops.add(name)

        for item in self.model['edges']:
            source = item['source']
            target = item['target']

            # Intermediate tensors are not pre-registered; create them lazily.
            for endpoint in (source, target):
                if endpoint not in self.nodes:
                    self.nodes[endpoint] = self.add_arg(endpoint)
                    self.args.add(endpoint)

            if source in self.args or target in self.args:
                self.add_edge(style="dashed,bold", color="#aaaaaa", **item)
            else:
                self.add_edge(style="bold", color="#aaaaaa", **item)

        if not show:
            self.graph.display(path)
        else:
            self.graph.show(path)

    def add_param(self, name, data_type, shape):
        '''
        Add a graph input/output as a three-row HTML-like table node
        (name, data type, shape) and return the created node.
        '''
        import html

        # Escape text so names containing '<', '>' or '&' cannot break the
        # HTML-like label markup.
        name_text = html.escape(str(name), quote=False)
        type_text = html.escape(str(data_type), quote=False)
        # Dims may be strings or ints; stringify so the join works for both.
        shape_text = html.escape('x'.join(str(dim) for dim in shape),
                                 quote=False)
        label = '\n'.join([
            '<<table cellpadding="5">',
            '  <tr>',
            '    <td bgcolor="#2b787e">',
            '    <b>',
            name_text,
            '    </b>',
            '    </td>',
            '  </tr>',
            '  <tr>',
            '    <td>',
            type_text,
            '    </td>',
            '  </tr>',
            '  <tr>',
            '    <td>',
            '[%s]' % shape_text,
            '    </td>',
            '  </tr>',
            '</table>>',
        ])
        return self.graph.node(
            label,
            prefix="param",
            shape="none",
            style="rounded,filled,bold",
            width="1.3",
            color="#148b97",
            fontcolor="#ffffff",
            fontname="Arial")

    def add_op(self, opType, **kwargs):
        '''Add an ONNX node as a box labelled with its op type; extra keys are ignored.'''
        return self.graph.node(
            "<<B>%s</B>>" % opType,
            prefix="op",
            shape="box",
            style="rounded, filled, bold",
            color="#303A3A",
            fontname="Arial",
            fontcolor="#ffffff",
            width="1.3",
            height="0.84", )

    def add_arg(self, name):
        '''Add an intermediate tensor (not a graph input/output) as a grey box.'''
        return self.graph.node(
            gg.crepr(name),
            prefix="arg",
            shape="box",
            style="rounded,filled,bold",
            fontname="Arial",
            fontcolor="#999999",
            color="#dddddd")

    def add_edge(self, source, target, label, **kwargs):
        '''Connect two registered nodes by key; ``label`` is accepted but not drawn.'''
        source = self.nodes[source]
        target = self.nodes[target]
        return self.graph.edge(source, target, **kwargs)


def draw_graph(model_pb_path, image_dir, num_trials=10):
    """
    Render the model several times and return the narrowest resulting image.

    Each trial builds a fresh ``GraphPreviewGenerator`` that writes
    ``temp-<i>.dot`` into *image_dir*. The matching ``temp-<i>.jpg`` is then
    opened; it is presumably produced by the generator's display step
    (TODO confirm against graphviz_graph).

    :param model_pb_path: path to the ONNX model file.
    :param image_dir: directory for the .dot/.jpg files.
    :param num_trials: number of renderings to try (default 10).
    :return: path of the image with the smallest width, or None if no image
        could be opened.
    """
    model_json = load_model(model_pb_path)
    best_image = None
    min_width = None
    for i in range(num_trials):
        # Layouts may differ between runs; keep the narrowest one.
        g = GraphPreviewGenerator(model_json)
        dot_path = os.path.join(image_dir, "temp-%d.dot" % i)
        image_path = os.path.join(image_dir, "temp-%d.jpg" % i)
        g(dot_path)

        try:
            # Context manager closes the file handle Image.open keeps.
            with Image.open(image_path) as im:
                width = im.size[0]
        except OSError:
            # Missing or unreadable image: skip this trial (best effort).
            continue
        # Compare widths as ints; storing the whole size tuple would make the
        # next comparison raise TypeError.
        if min_width is None or width < min_width:
            min_width = width
            best_image = image_path
    return best_image
Q
qiaolongfei 已提交
487 488


Q
qiaolongfei 已提交
489 490 491
if __name__ == '__main__':
    import sys

    # Resolve the mock model relative to the directory of this script.
    script_dir = os.path.abspath(os.path.dirname(sys.argv[0]))
    model = load_model(script_dir + "/mock/inception_v1_model.pb")
    # Alternative mock models that can be substituted above:
    #   script_dir + "/mock/squeezenet_model.pb"
    #   './mock/shufflenet/model.pb'
    debug_print(model)
    assert model

    generator = GraphPreviewGenerator(model)
    generator('./temp.dot', show=False)