From 1b41cdfbb40f663185556d0e4c3bfb1fb0326e84 Mon Sep 17 00:00:00 2001 From: jiangjiajun Date: Tue, 16 Jul 2019 15:24:20 +0800 Subject: [PATCH] add topo demo for tf --- x2paddle/convert.py | 7 ++++++- x2paddle/core/graph.py | 24 +++++++++++++++++++++--- x2paddle/optimizer/tf_optimizer.py | 3 ++- x2paddle/parser/tf_parser.py | 3 ++- 4 files changed, 31 insertions(+), 6 deletions(-) diff --git a/x2paddle/convert.py b/x2paddle/convert.py index f897dc9..8d0e838 100644 --- a/x2paddle/convert.py +++ b/x2paddle/convert.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. from x2paddle.parser.tf_parser import TFParser +from x2paddle.optimizer.tf_optimizer import TFGraphOptimizer -parser = TFParser('/ssd2/Jason/github/X2Paddle/x2paddle/tests/frozen_darknet_yolov3_model.pb', +parser = TFParser('/ssd3/dltpsz/frozen_darknet_yolov3_model.pb', in_nodes=['inputs'], out_nodes=['output_boxes'], in_shapes=[[-1, 416, 416, 3]]) +optimizer = TFGraphOptimizer() +optimizer.remove_useless_node(parser.tf_graph) +parser.tf_graph.print() + diff --git a/x2paddle/core/graph.py b/x2paddle/core/graph.py index 18aa683..a2178ec 100644 --- a/x2paddle/core/graph.py +++ b/x2paddle/core/graph.py @@ -72,9 +72,6 @@ class Graph(object): self.topo_sort.append(node) idx += 1 - for i, tmp in enumerate(self.topo_sort): - print(tmp, self.node_map[tmp].layer_type, self.node_map[tmp].inputs) - def get_node(self, name): if name not in self.node_map: raise Exception("Graph doesn't have node [%s]." % name) @@ -86,3 +83,24 @@ class Graph(object): raise Exception("node[{}] not in graph".format(dst)) self.node_map[dst].inputs.append(src) self.node_map[src].outputs.append(dst) + + def remove_node(self, node_name): + if node_name not in self.node_map: + raise Exception("Node[{}] not in graph".format(node_name)) + inputs = self.node_map[node_name].inputs + outputs = self.node_map[node_name].outputs + for input in inputs: + idx = self.node_map[input].outputs.index(node_name) + del self.node_map[input].outputs[idx] + for output in outputs: + idx = self.node_map[input].inputs.index(node_name) + del self.node_map[input].inputs[idx] + del self.node_map[node_name] + + idx = self.topo_sort.index(node_name) + del self.topo_sort[idx] + + def print(self): + for i, tmp in enumerate(self.topo_sort): + print(tmp, self.node_map[tmp].layer_type, self.node_map[tmp].inputs) + diff --git a/x2paddle/optimizer/tf_optimizer.py b/x2paddle/optimizer/tf_optimizer.py index 9d2775c..3db266b 100644 --- a/x2paddle/optimizer/tf_optimizer.py +++ b/x2paddle/optimizer/tf_optimizer.py @@ -23,8 +23,9 @@ class TFGraphOptimizer(object): 'NoOp'] def remove_useless_node(self, graph): - for name, node in graph.node_map.items(): + for node_name, node in graph.node_map.items(): if node.layer_type in self.useless_op: + graph.remove_node(node_name) # TODO identity node remove diff --git a/x2paddle/parser/tf_parser.py b/x2paddle/parser/tf_parser.py index ef7b9fc..87503ff 100644 --- a/x2paddle/parser/tf_parser.py +++ b/x2paddle/parser/tf_parser.py @@ -44,7 +44,8 @@ class TFGraph(Graph): else: self.connect(in_node, layer_name) - super(TFGraph, self).build() + super(TFGraph, self).build() + class TFParser(object): def __init__(self, pb_model, in_nodes=None, out_nodes=None, in_shapes=None): -- GitLab