提交 1b41cdfb 编写于 作者: J jiangjiajun

add topo demo for tf

上级 70945df5
......@@ -12,7 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.parser.tf_parser import TFParser
from x2paddle.optimizer.tf_optimizer import TFGraphOptimizer
parser = TFParser('/ssd2/Jason/github/X2Paddle/x2paddle/tests/frozen_darknet_yolov3_model.pb',
parser = TFParser('/ssd3/dltpsz/frozen_darknet_yolov3_model.pb',
in_nodes=['inputs'], out_nodes=['output_boxes'],
in_shapes=[[-1, 416, 416, 3]])
optimizer = TFGraphOptimizer()
optimizer.remove_useless_node(parser.tf_graph)
parser.tf_graph.print()
......@@ -72,9 +72,6 @@ class Graph(object):
self.topo_sort.append(node)
idx += 1
for i, tmp in enumerate(self.topo_sort):
print(tmp, self.node_map[tmp].layer_type, self.node_map[tmp].inputs)
def get_node(self, name):
if name not in self.node_map:
raise Exception("Graph doesn't have node [%s]." % name)
......@@ -86,3 +83,24 @@ class Graph(object):
raise Exception("node[{}] not in graph".format(dst))
self.node_map[dst].inputs.append(src)
self.node_map[src].outputs.append(dst)
def remove_node(self, node_name):
if node_name not in self.node_map:
raise Exception("Node[{}] not in graph".format(node_name))
inputs = self.node_map[node_name].inputs
outputs = self.node_map[node_name].outputs
for input in inputs:
idx = self.node_map[input].outputs.index(node_name)
del self.node_map[input].outputs[idx]
for output in outputs:
idx = self.node_map[input].inputs.index(node_name)
del self.node_map[input].inputs[idx]
del self.node_map[node_name]
idx = self.topo_sort.index(node_name)
del self.topo_sort[idx]
def print(self):
for i, tmp in enumerate(self.topo_sort):
print(tmp, self.node_map[tmp].layer_type, self.node_map[tmp].inputs)
......@@ -23,8 +23,9 @@ class TFGraphOptimizer(object):
'NoOp']
def remove_useless_node(self, graph):
for name, node in graph.node_map.items():
for node_name, node in graph.node_map.items():
if node.layer_type in self.useless_op:
graph.remove_node(node_name)
# TODO identity node remove
......
......@@ -44,7 +44,8 @@ class TFGraph(Graph):
else:
self.connect(in_node, layer_name)
super(TFGraph, self).build()
super(TFGraph, self).build()
class TFParser(object):
def __init__(self, pb_model, in_nodes=None, out_nodes=None, in_shapes=None):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册