提交 20f391d3 编写于 作者: Q qiaolongfei

add level_to_all

上级 8d223360
......@@ -5,119 +5,8 @@ from google.protobuf.json_format import MessageToJson
import onnx
class Node(object):
def __init__(self):
pass
def to_json(self):
raise NotImplementedError
class Operator(Node):
def __init__(self, json_obj):
self.json_obj = json_obj
self.renamed = False
self.in_nodes = []
self.out_nodes = []
@property
def name(self):
return self.json_obj['name']
@property
def inputs(self):
return self.json_obj['input']
@property
def outputs(self):
return self.json_obj['output']
def sync_inout_name(self):
pass
def rename(self, node_id):
if not self.renamed:
self.renamed = True
self.json_obj['name'] = node_id + '\n' + self.name
else:
raise Exception("Operator " + self.name + " has already been renamed")
def to_json(self):
return self.json_obj
class Variable(Node):
def __init__(self, json_obj):
"""
:param json_obj:
{
"data_type": "FLOAT",
"name": "conv1_w_0",
"shape": [
"64",
"3",
"3",
"3"
]
}
"""
self.renamed = False
self.json_obj = json_obj
self.in_nodes = []
self.out_nodes = []
@property
def name(self):
return self.json_obj['name']
@property
def data_type(self):
return self.json_obj['data_type']
@property
def shape(self):
return [int(dim) for dim in self.json_obj['shape']]
def sync_inout_name(self):
pass
def rename(self):
if not self.renamed:
self.renamed = True
new_name = self.name + '\ndata_type=' + str(self.data_type) + '\nshape=' + str(self.shape)
self.json_obj['name'] = new_name
else:
raise Exception("Variable " + self.name + " has already been renamed")
def to_json(self):
return self.json_obj
class Edge(object):
def __init__(self, name):
self.__name = name
self.__from_node = None
self.__to_node = None
@property
def name(self):
return self.__name
@property
def from_node(self):
return self.from_node
@property
def to_node(self):
return self.to_node
def set_from_node(self, node):
assert self.__from_node is None
self.__from_node = node
def set_to_node(self, node):
assert self.__to_node is None
self.__to_node
def debug_print(json_obj):
print(json.dumps(json_obj, sort_keys=True, indent=4, separators=(',', ': ')))
def reorganize_inout(json_obj, key):
......@@ -145,43 +34,6 @@ def reorganize_inout(json_obj, key):
json_obj[key][index] = var_new
def to_structure_data(model_json):
operators = [Operator(node) for node in model_json['node']]
inputs = [Variable(input) for input in model_json['input']]
outputs = [Variable(output for output in model_json['output'])]
edges = dict()
# consturct all edges
def get_edge(edges, name):
assert isinstance(edges, dict)
if name not in edges:
edges[name] = Edge(name)
return edges[name]
for input in inputs:
edge = get_edge(edges, input.name)
assert edge.from_node is None
edge.from_node = input
for output in outputs:
edge = get_edge(edges, output.name)
assert edge.to_node is None
edge.to_node = output
for operator in operators:
for input_name in operator.inputs:
edge = get_edge(edges, input_name)
assert edge.to_node is None
edge.to_node = operator
for output_name in operator.outputs:
edge = get_edge(edges, output_name)
assert edge.from_node is None
edge.from_node = operator
# rename node
def rename_model(model_json):
def rename_edge(model_json, old_name, new_name):
for node in model_json['node']:
......@@ -192,7 +44,7 @@ def rename_model(model_json):
outputs = node['output']
for idx in range(len(outputs)):
if outputs[idx] == old_name:
outputs[idx] == new_name
outputs[idx] = new_name
def rename_variables(model, variables):
for variable in variables:
......@@ -238,6 +90,170 @@ def add_links(model_json):
model_json['links'] = links
def get_node_links(model_json):
"""
:return:
{
"0": {
"input": [],
"output": [
1
]
},
"1": {
"input": [
0
],
"output": [
2
]
}
}
"""
node_links = dict()
nodes = model_json['node']
# init all nodes
for idx in range(len(nodes)):
node_links[idx] = {'input': set(), 'output': set()}
for src_idx in range(len(nodes)):
for out_name in nodes[src_idx]['output']:
for dst_idx in range(len(nodes)):
if out_name in nodes[dst_idx]['input']:
node_links[src_idx]['output'].add(dst_idx)
node_links[dst_idx]['input'].add(src_idx)
# change set to list for json can not serialize set
new_node_links = dict()
for key in node_links:
new_node_links[key] = {'input': list(node_links[key]['input']),
'output': list(node_links[key]['output'])}
return new_node_links
def add_level_to_node_links(node_links):
"""
:return:
{
"0": {
"input": [],
"output": [
1
],
"level": 1
},
"1": {
"input": [
0
],
"output": [
2
],
"level": 2
}
}
"""
# init level
for key in node_links:
node_links[key]['level'] = None
for idx in range(len(node_links)):
# the start up op's level is 1
if len(node_links[idx]['input']) == 0:
node_links[idx]['level'] = 1
else:
cur_level = node_links[idx]['level']
for in_idx in node_links[idx]['input']:
in_level = node_links[in_idx]['level']
assert in_level is not None
if cur_level is None or in_level >= cur_level:
node_links[idx]['level'] = in_level + 1
# debug_print(node_links)
def get_level_to_all(node_links, model_json):
"""
level_to_nodes {level -> [node_1, node_2]}
"""
level_to_nodes = dict()
for idx in node_links:
level = node_links[idx]['level']
if level not in level_to_nodes:
level_to_nodes[level] = list()
level_to_nodes[level].append(idx)
# debug_print(level_to_nodes)
"""
input_to_level {idx -> level}
level_to_inputs {level -> [input1, input2]}
"""
nodes = model_json['node']
input_to_level = dict()
inputs = model_json['input']
for in_idx in range(len(inputs)):
in_name = inputs[in_idx]['name']
for node_idx in range(len(nodes)):
if in_name in nodes[node_idx]['input']:
node_level = node_links[node_idx]['level']
in_level = node_level - 1
if in_idx not in input_to_level:
input_to_level[in_idx] = in_level
elif input_to_level[in_idx] > in_level:
input_to_level[in_idx] = in_level
level_to_inputs = dict()
for in_idx in input_to_level:
level = input_to_level[in_idx]
if level not in level_to_inputs:
level_to_inputs[level] = list()
level_to_inputs[level].append(in_idx)
# debug_print(level_to_inputs)
# get output level
output_to_level = dict()
outputs = model_json['output']
for out_idx in range(len(outputs)):
out_name = outputs[out_idx]['name']
for node_idx in range(len(nodes)):
if out_name in nodes[node_idx]['output']:
node_level = node_links[node_idx]['level']
out_level = node_level + 1
if out_level not in output_to_level:
output_to_level[out_idx] = out_level
else:
raise Exception("output " + out_name + "have multiple source")
level_to_outputs = dict()
for out_idx in output_to_level:
level = output_to_level[out_idx]
if level not in level_to_outputs:
level_to_outputs[level] = list()
level_to_outputs[level].append(out_idx)
# debug_print(level_to_outputs)
level_to_all = dict()
def init_level(level):
if level not in level_to_all:
level_to_all[level] = {'nodes': list(), 'inputs': list(), 'outputs': list()}
# merge all levels
for level in level_to_nodes:
init_level(level)
level_to_all[level]['nodes'] = level_to_nodes[level]
for level in level_to_inputs:
init_level(level)
level_to_all[level]['inputs'] = level_to_inputs[level]
for level in level_to_outputs:
init_level(level)
level_to_all[level]['outputs'] = level_to_outputs[level]
debug_print(level_to_all)
return level_to_all
def add_edges(json_obj):
# TODO(daming-lu): should try to de-duplicate node's out-edge
# Currently it is counted twice: 1 as out-edge, 1 as in-edge
......@@ -266,9 +282,6 @@ def add_edges(json_obj):
return json_obj
def add_coordinate(model_json):
def transform_for_echars(model_json):
opItemStyle = {
"normal": {
......@@ -343,12 +356,16 @@ def load_model(model_pb_path):
# to json string
json_str = MessageToJson(model.graph)
json_obj = json.loads(json_str)
reorganize_inout(json_obj, 'input')
reorganize_inout(json_obj, 'output')
rename_model(json_obj)
add_links(json_obj)
return json.dumps(json_obj, sort_keys=True, indent=4, separators=(',', ': '))
model_json = json.loads(json_str)
reorganize_inout(model_json, 'input')
reorganize_inout(model_json, 'output')
rename_model(model_json)
add_links(model_json)
debug_print(model_json)
node_links = get_node_links(model_json)
add_level_to_node_links(node_links)
get_level_to_all(node_links, model_json)
return json.dumps(model_json, sort_keys=True, indent=4, separators=(',', ': '))
if __name__ == '__main__':
......@@ -357,4 +374,4 @@ if __name__ == '__main__':
current_path = os.path.abspath(os.path.dirname(sys.argv[0]))
# json_str = load_model(current_path + "/mock/inception_v1_model.pb")
json_str = load_model(current_path + "/mock/squeezenet_model.pb")
print(json_str)
# print(json_str)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册