提交 ac001819 编写于 作者: J jiangjiajun

modify caffe

上级 c0dabbac
......@@ -27,34 +27,12 @@ class CaffeOpMapper(OpMapper):
self.weights = dict()
resolver = decoder.resolver
self.used_custom_layers = {}
self.inputs = self.graph.input_nodes
self.outputs = self.graph.output_nodes
if resolver.has_pycaffe():
self.did_use_pb = False
else:
self.did_use_pb = True
def op_checker(self):
unsupported_ops = set()
for node_name in self.graph.topo_sort:
node = self.graph.get_node(node_name)
op = node.layer_type
if not hasattr(self, op) and op not in custom_layers:
unsupported_ops.add(op)
if len(unsupported_ops) == 0:
return True
else:
print("There are {} ops not supported yet, list as below".format(
len(unsupported_ops)))
for op in unsupported_ops:
print(op)
return False
def run(self):
print("Total nodes: {}".format(len(self.graph.topo_sort)))
# check if ops in model are all supported
if not self.op_checker():
raise Exception("Model are not supported yet.")
for node_name in self.graph.topo_sort:
node = self.graph.get_node(node_name)
op = node.layer_type
......@@ -67,13 +45,22 @@ class CaffeOpMapper(OpMapper):
self.deal_custom_layer(node)
else:
raise Exception("Model are not supported yet.")
for key in self.used_custom_layers:
self.net_code.append(self.used_custom_layers[key])
for i in range(len(self.graph.topo_sort)):
node_name = self.graph.topo_sort[i]
def op_checker(self):
unsupported_ops = set()
for node_name in self.graph.topo_sort:
node = self.graph.get_node(node_name)
self.net_code += node.fluid_code.gen_codes()
op = node.layer_type
if not hasattr(self, op) and op not in custom_layers:
unsupported_ops.add(op)
if len(unsupported_ops) == 0:
return True
else:
print("There are {} ops not supported yet, list as below".format(
len(unsupported_ops)))
for op in unsupported_ops:
print(op)
return False
def set_shape(self, node, is_fluid_op=True):
inputs = node.inputs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册