提交 a829051a 编写于 作者: J jiangjiajun

structure modify

上级 86170c37
......@@ -61,7 +61,6 @@ class Graph(object):
num_inputs = dict()
for name, node in self.node_map.items():
num_inputs[name] = len(node.inputs)
print(len(self.node_map))
self.topo_sort = self.input_nodes[:]
idx = 0
......
......@@ -23,6 +23,21 @@ class OpMapper(object):
self.net_code = list()
self.weights = dict()
def op_checker(self):
unsupported_ops = set()
for node_name in self.graph.topo_sort:
node = self.graph.get_node(node_name)
op = node.layer_type
if not hasattr(self, op):
unsupported_ops.add(op)
if len(unsupported_ops) == 0:
return True
else:
print("There are {} ops not supported yet, list as below")
for op in unsupported_ops:
print(op)
return False
def add_codes(self, codes, indent=0):
if isinstance(codes, list):
for code in codes:
......
......@@ -31,6 +31,11 @@ class CaffeOpMapper(OpMapper):
def run(self):
print("Total nodes: {}".format(len(self.graph.topo_sort)))
# check if ops in model are all supported
if not self.op_checker():
raise Exception("Model are not supported yet.")
for node_name in self.graph.topo_sort:
node = self.graph.get_node(node_name)
op = node.layer_type
......
......@@ -27,6 +27,11 @@ class TFOpMapper(OpMapper):
def run(self):
print("Total nodes: {}".format(len(self.graph.topo_sort)))
# check if ops in model are all supported
if not self.op_checker():
raise Exception("Model are not supported yet.")
for node_name in self.graph.topo_sort:
node = self.graph.get_node(node_name)
op = node.layer_type
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册