graph.py 2.2 KB
Newer Older
Q
qiaolongfei 已提交
1
import json
S
superjom 已提交
2

Q
qiaolongfei 已提交
3 4
from google.protobuf.json_format import MessageToJson

S
superjom 已提交
5 6
import onnx

Q
qiaolongfei 已提交
7 8 9 10 11 12 13 14

def reorganize_inout(json_obj, key):
    """
    Flatten the graph's input or output entries in place.

    Each entry of ``json_obj[key]`` (a value-info dict as produced by
    ``MessageToJson``) is replaced by a flat dict with the keys ``name``,
    ``data_type`` and ``shape``.

    :param json_obj: the model's json obj
    :param key: "input" or "output"
    :return: None; ``json_obj`` is modified in place
    """
    # MessageToJson omits empty repeated fields, so the key may be missing.
    entries = json_obj.get(key)
    if not entries:
        return

    for index, var in enumerate(entries):
        # Non-tensor value types carry no 'tensorType'; fall back to empty.
        tensor_type = var.get('type', {}).get('tensorType', {})
        # A scalar has a shape message without 'dim'; an unknown-rank tensor
        # has no 'shape' at all. Both yield an empty shape list here.
        dims = tensor_type.get('shape', {}).get('dim', [])
        entries[index] = {
            'name': var['name'],
            'data_type': tensor_type.get('elemType'),
            # Concrete size if known, else the symbolic name, else None.
            # NOTE: protobuf's JSON mapping emits int64 values as strings,
            # so concrete sizes are presumably strings such as "224".
            'shape': [dim.get('dimValue', dim.get('dimParam'))
                      for dim in dims],
        }
Q
qiaolongfei 已提交
31 32


33
def add_edges(json_obj):
    """
    Build the graph's edge list and store it in ``json_obj['edges']``.

    For node ``i`` every non-empty input name yields an edge
    ``input -> 'node_i'`` and every non-empty output name yields an edge
    ``'node_i' -> output``. Edges are labelled ``'label_<k>'`` with ``k``
    counting up from 0 in creation order.

    :param json_obj: the model's json obj
    :return: None; ``json_obj`` is modified in place
    """
    # TODO(daming-lu): should try to de-duplicate node's out-edge
    # Currently it is counted twice: 1 as out-edge, 1 as in-edge
    edges = json_obj['edges'] = []

    def _append(source, target):
        # The label counter equals the number of edges created so far.
        edges.append({
            'source': source,
            'target': target,
            'label': 'label_' + str(len(edges))
        })

    # MessageToJson omits empty repeated fields: a graph without nodes has no
    # 'node' key, and a node without inputs/outputs lacks that key as well.
    for node_index, node in enumerate(json_obj.get('node', [])):
        node_id = 'node_' + str(node_index)

        # input edges; ONNX uses an empty name for an omitted optional input.
        for source in node.get('input', []):
            if source:
                _append(source, node_id)

        # output edges: every output, not just the first, so multi-output
        # ops keep all of their edges.
        for target in node.get('output', []):
            if target:
                _append(node_id, target)


Q
qiaolongfei 已提交
60 61 62 63 64 65 66 67 68 69
def load_model(model_pb_path):
    """
    Load an ONNX model file and return its graph serialized as JSON text.

    The graph's initializers are removed before serialization, its inputs
    and outputs are flattened with ``reorganize_inout``, and an ``edges``
    list is attached with ``add_edges``.

    :param model_pb_path: path to the serialized ONNX model
    :return: pretty-printed JSON string with sorted keys
    """
    loaded = onnx.load(model_pb_path)
    model_graph = loaded.graph
    # Clear the initializers (typically the stored weights) so they are not
    # included in the JSON produced below.
    del model_graph.initializer[:]

    graph_dict = json.loads(MessageToJson(model_graph))
    for section in ('input', 'output'):
        reorganize_inout(graph_dict, section)
    add_edges(graph_dict)

    return json.dumps(
        graph_dict,
        sort_keys=True,
        indent=4,
        separators=(',', ': '),
    )


Q
qiaolongfei 已提交
74 75 76 77
if __name__ == '__main__':
    import os
    import sys

    # Usage: graph.py [MODEL_PATH]
    # Without an argument, the mock squeezenet model in the 'mock' directory
    # next to this script is used (another mock there is
    # inception_v1_model.pb).
    current_path = os.path.abspath(os.path.dirname(sys.argv[0]))
    if len(sys.argv) > 1:
        model_path = sys.argv[1]
    else:
        model_path = os.path.join(current_path, 'mock', 'squeezenet_model.pb')

    print(load_model(model_path))