diff --git a/x2paddle/convert.py b/x2paddle/convert.py index f897dc9bda711a3936c055f5a3822320f7f21674..8d0e838018e8b6a5337d8d0bd70613e7b7944f23 100644 --- a/x2paddle/convert.py +++ b/x2paddle/convert.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. from x2paddle.parser.tf_parser import TFParser +from x2paddle.optimizer.tf_optimizer import TFGraphOptimizer -parser = TFParser('/ssd2/Jason/github/X2Paddle/x2paddle/tests/frozen_darknet_yolov3_model.pb', +parser = TFParser('/ssd3/dltpsz/frozen_darknet_yolov3_model.pb', in_nodes=['inputs'], out_nodes=['output_boxes'], in_shapes=[[-1, 416, 416, 3]]) +optimizer = TFGraphOptimizer() +optimizer.remove_useless_node(parser.tf_graph) +parser.tf_graph.print() + diff --git a/x2paddle/core/graph.py b/x2paddle/core/graph.py index 18aa6831bf313ea19d629c2941357d607d468021..a2178ec1f837c59a009a2999dbc94ccf58061839 100644 --- a/x2paddle/core/graph.py +++ b/x2paddle/core/graph.py @@ -72,9 +72,6 @@ class Graph(object): self.topo_sort.append(node) idx += 1 - for i, tmp in enumerate(self.topo_sort): - print(tmp, self.node_map[tmp].layer_type, self.node_map[tmp].inputs) - def get_node(self, name): if name not in self.node_map: raise Exception("Graph doesn't have node [%s]." % name) @@ -86,3 +83,24 @@ class Graph(object): raise Exception("node[{}] not in graph".format(dst)) self.node_map[dst].inputs.append(src) self.node_map[src].outputs.append(dst) + + def remove_node(self, node_name): + if node_name not in self.node_map: + raise Exception("Node[{}] not in graph".format(node_name)) + inputs = self.node_map[node_name].inputs + outputs = self.node_map[node_name].outputs + for input in inputs: + idx = self.node_map[input].outputs.index(node_name) + del self.node_map[input].outputs[idx] + for output in outputs: + idx = self.node_map[input].inputs.index(node_name) + del self.node_map[input].inputs[idx] + del self.node_map[node_name] + + idx = self.topo_sort.index(node_name) + del self.topo_sort[idx] + + def print(self): + for i, tmp in enumerate(self.topo_sort): + print(tmp, self.node_map[tmp].layer_type, self.node_map[tmp].inputs) + diff --git a/x2paddle/optimizer/tf_optimizer.py b/x2paddle/optimizer/tf_optimizer.py index 9d2775ce239e8d6516410409f2cc5a15e0a6fec9..3db266b61df4cc8e664b7796f1039756eb85b61b 100644 --- a/x2paddle/optimizer/tf_optimizer.py +++ b/x2paddle/optimizer/tf_optimizer.py @@ -23,8 +23,9 @@ class TFGraphOptimizer(object): 'NoOp'] def remove_useless_node(self, graph): - for name, node in graph.node_map.items(): + for node_name, node in graph.node_map.items(): if node.layer_type in self.useless_op: + graph.remove_node(node_name) # TODO identity node remove diff --git a/x2paddle/parser/tf_parser.py b/x2paddle/parser/tf_parser.py index ef7b9fc201792c6623dacc3bc0936bbd333138ba..87503ffbcf7e46f84d90f1acea79f94122654fcd 100644 --- a/x2paddle/parser/tf_parser.py +++ b/x2paddle/parser/tf_parser.py @@ -44,7 +44,8 @@ class TFGraph(Graph): else: self.connect(in_node, layer_name) - super(TFGraph, self).build() + super(TFGraph, self).build() + class TFParser(object): def __init__(self, pb_model, in_nodes=None, out_nodes=None, in_shapes=None):