diff --git a/server/visualdl/graph.py b/server/visualdl/graph.py index 555d2ee407b9955aaed3ef7493b181b1b7659a1a..66c542ebe067835b97ae62f509210243611c0c9b 100644 --- a/server/visualdl/graph.py +++ b/server/visualdl/graph.py @@ -5,119 +5,8 @@ from google.protobuf.json_format import MessageToJson import onnx -class Node(object): - def __init__(self): - pass - - def to_json(self): - raise NotImplementedError - - -class Operator(Node): - def __init__(self, json_obj): - self.json_obj = json_obj - self.renamed = False - self.in_nodes = [] - self.out_nodes = [] - - @property - def name(self): - return self.json_obj['name'] - - @property - def inputs(self): - return self.json_obj['input'] - - @property - def outputs(self): - return self.json_obj['output'] - - def sync_inout_name(self): - pass - - def rename(self, node_id): - if not self.renamed: - self.renamed = True - self.json_obj['name'] = node_id + '\n' + self.name - else: - raise Exception("Operator " + self.name + " has already been renamed") - - def to_json(self): - return self.json_obj - - -class Variable(Node): - def __init__(self, json_obj): - """ - :param json_obj: - { - "data_type": "FLOAT", - "name": "conv1_w_0", - "shape": [ - "64", - "3", - "3", - "3" - ] - } - """ - self.renamed = False - self.json_obj = json_obj - self.in_nodes = [] - self.out_nodes = [] - - @property - def name(self): - return self.json_obj['name'] - - @property - def data_type(self): - return self.json_obj['data_type'] - - @property - def shape(self): - return [int(dim) for dim in self.json_obj['shape']] - - def sync_inout_name(self): - pass - - def rename(self): - if not self.renamed: - self.renamed = True - new_name = self.name + '\ndata_type=' + str(self.data_type) + '\nshape=' + str(self.shape) - self.json_obj['name'] = new_name - else: - raise Exception("Variable " + self.name + " has already been renamed") - - def to_json(self): - return self.json_obj - - -class Edge(object): - def __init__(self, name): - self.__name = name - self.__from_node = None - self.__to_node = None - - @property - def name(self): - return self.__name - - @property - def from_node(self): - return self.from_node - - @property - def to_node(self): - return self.to_node - - def set_from_node(self, node): - assert self.__from_node is None - self.__from_node = node - - def set_to_node(self, node): - assert self.__to_node is None - self.__to_node +def debug_print(json_obj): + print(json.dumps(json_obj, sort_keys=True, indent=4, separators=(',', ': '))) def reorganize_inout(json_obj, key): @@ -145,43 +34,6 @@ def reorganize_inout(json_obj, key): json_obj[key][index] = var_new -def to_structure_data(model_json): - operators = [Operator(node) for node in model_json['node']] - inputs = [Variable(input) for input in model_json['input']] - outputs = [Variable(output for output in model_json['output'])] - - edges = dict() - - # consturct all edges - def get_edge(edges, name): - assert isinstance(edges, dict) - if name not in edges: - edges[name] = Edge(name) - return edges[name] - - for input in inputs: - edge = get_edge(edges, input.name) - assert edge.from_node is None - edge.from_node = input - - for output in outputs: - edge = get_edge(edges, output.name) - assert edge.to_node is None - edge.to_node = output - - for operator in operators: - for input_name in operator.inputs: - edge = get_edge(edges, input_name) - assert edge.to_node is None - edge.to_node = operator - for output_name in operator.outputs: - edge = get_edge(edges, output_name) - assert edge.from_node is None - edge.from_node = operator - - # rename node - - def rename_model(model_json): def rename_edge(model_json, old_name, new_name): for node in model_json['node']: @@ -192,7 +44,7 @@ def rename_model(model_json): outputs = node['output'] for idx in range(len(outputs)): if outputs[idx] == old_name: - outputs[idx] == new_name + outputs[idx] = new_name def rename_variables(model, variables): for variable in variables: @@ -238,6 +90,170 @@ def add_links(model_json): model_json['links'] = links +def get_node_links(model_json): + """ + :return: + { + "0": { + "input": [], + "output": [ + 1 + ] + }, + "1": { + "input": [ + 0 + ], + "output": [ + 2 + ] + } + } + """ + node_links = dict() + nodes = model_json['node'] + + # init all nodes + for idx in range(len(nodes)): + node_links[idx] = {'input': set(), 'output': set()} + + for src_idx in range(len(nodes)): + for out_name in nodes[src_idx]['output']: + for dst_idx in range(len(nodes)): + if out_name in nodes[dst_idx]['input']: + node_links[src_idx]['output'].add(dst_idx) + node_links[dst_idx]['input'].add(src_idx) + + # change set to list for json can not serialize set + new_node_links = dict() + for key in node_links: + new_node_links[key] = {'input': list(node_links[key]['input']), + 'output': list(node_links[key]['output'])} + return new_node_links + + +def add_level_to_node_links(node_links): + """ + :return: + { + "0": { + "input": [], + "output": [ + 1 + ], + "level": 1 + }, + "1": { + "input": [ + 0 + ], + "output": [ + 2 + ], + "level": 2 + } + } + """ + # init level + for key in node_links: + node_links[key]['level'] = None + for idx in range(len(node_links)): + # the start up op's level is 1 + if len(node_links[idx]['input']) == 0: + node_links[idx]['level'] = 1 + else: + cur_level = node_links[idx]['level'] + for in_idx in node_links[idx]['input']: + in_level = node_links[in_idx]['level'] + assert in_level is not None + if cur_level is None or in_level >= cur_level: + node_links[idx]['level'] = in_level + 1 + # debug_print(node_links) + + +def get_level_to_all(node_links, model_json): + """ + level_to_nodes {level -> [node_1, node_2]} + """ + level_to_nodes = dict() + for idx in node_links: + level = node_links[idx]['level'] + if level not in level_to_nodes: + level_to_nodes[level] = list() + level_to_nodes[level].append(idx) + # debug_print(level_to_nodes) + + + """ + input_to_level {idx -> level} + level_to_inputs {level -> [input1, input2]} + """ + nodes = model_json['node'] + + input_to_level = dict() + inputs = model_json['input'] + for in_idx in range(len(inputs)): + in_name = inputs[in_idx]['name'] + for node_idx in range(len(nodes)): + if in_name in nodes[node_idx]['input']: + node_level = node_links[node_idx]['level'] + in_level = node_level - 1 + if in_idx not in input_to_level: + input_to_level[in_idx] = in_level + elif input_to_level[in_idx] > in_level: + input_to_level[in_idx] = in_level + + level_to_inputs = dict() + for in_idx in input_to_level: + level = input_to_level[in_idx] + if level not in level_to_inputs: + level_to_inputs[level] = list() + level_to_inputs[level].append(in_idx) + + # debug_print(level_to_inputs) + + # get output level + output_to_level = dict() + outputs = model_json['output'] + for out_idx in range(len(outputs)): + out_name = outputs[out_idx]['name'] + for node_idx in range(len(nodes)): + if out_name in nodes[node_idx]['output']: + node_level = node_links[node_idx]['level'] + out_level = node_level + 1 + if out_level not in output_to_level: + output_to_level[out_idx] = out_level + else: + raise Exception("output " + out_name + "have multiple source") + level_to_outputs = dict() + for out_idx in output_to_level: + level = output_to_level[out_idx] + if level not in level_to_outputs: + level_to_outputs[level] = list() + level_to_outputs[level].append(out_idx) + # debug_print(level_to_outputs) + + level_to_all = dict() + + def init_level(level): + if level not in level_to_all: + level_to_all[level] = {'nodes': list(), 'inputs': list(), 'outputs': list()} + # merge all levels + for level in level_to_nodes: + init_level(level) + level_to_all[level]['nodes'] = level_to_nodes[level] + for level in level_to_inputs: + init_level(level) + level_to_all[level]['inputs'] = level_to_inputs[level] + for level in level_to_outputs: + init_level(level) + level_to_all[level]['outputs'] = level_to_outputs[level] + + debug_print(level_to_all) + + return level_to_all + + def add_edges(json_obj): # TODO(daming-lu): should try to de-duplicate node's out-edge # Currently it is counted twice: 1 as out-edge, 1 as in-edge @@ -266,9 +282,6 @@ def add_edges(json_obj): return json_obj -def add_coordinate(model_json): - - def transform_for_echars(model_json): opItemStyle = { "normal": { @@ -343,12 +356,16 @@ def load_model(model_pb_path): # to json string json_str = MessageToJson(model.graph) - json_obj = json.loads(json_str) - reorganize_inout(json_obj, 'input') - reorganize_inout(json_obj, 'output') - rename_model(json_obj) - add_links(json_obj) - return json.dumps(json_obj, sort_keys=True, indent=4, separators=(',', ': ')) + model_json = json.loads(json_str) + reorganize_inout(model_json, 'input') + reorganize_inout(model_json, 'output') + rename_model(model_json) + add_links(model_json) + debug_print(model_json) + node_links = get_node_links(model_json) + add_level_to_node_links(node_links) + get_level_to_all(node_links, model_json) + return json.dumps(model_json, sort_keys=True, indent=4, separators=(',', ': ')) if __name__ == '__main__': @@ -357,4 +374,4 @@ if __name__ == '__main__': current_path = os.path.abspath(os.path.dirname(sys.argv[0])) # json_str = load_model(current_path + "/mock/inception_v1_model.pb") json_str = load_model(current_path + "/mock/squeezenet_model.pb") - print(json_str) + # print(json_str)