提交 ac001819 编写于 作者: J jiangjiajun

modify caffe

上级 c0dabbac
...@@ -27,34 +27,12 @@ class CaffeOpMapper(OpMapper): ...@@ -27,34 +27,12 @@ class CaffeOpMapper(OpMapper):
self.weights = dict() self.weights = dict()
resolver = decoder.resolver resolver = decoder.resolver
self.used_custom_layers = {} self.used_custom_layers = {}
self.inputs = self.graph.input_nodes
self.outputs = self.graph.output_nodes
if resolver.has_pycaffe(): if resolver.has_pycaffe():
self.did_use_pb = False self.did_use_pb = False
else: else:
self.did_use_pb = True self.did_use_pb = True
def op_checker(self):
unsupported_ops = set()
for node_name in self.graph.topo_sort:
node = self.graph.get_node(node_name)
op = node.layer_type
if not hasattr(self, op) and op not in custom_layers:
unsupported_ops.add(op)
if len(unsupported_ops) == 0:
return True
else:
print("There are {} ops not supported yet, list as below".format(
len(unsupported_ops)))
for op in unsupported_ops:
print(op)
return False
def run(self):
print("Total nodes: {}".format(len(self.graph.topo_sort))) print("Total nodes: {}".format(len(self.graph.topo_sort)))
# check if ops in model are all supported
if not self.op_checker():
raise Exception("Model are not supported yet.")
for node_name in self.graph.topo_sort: for node_name in self.graph.topo_sort:
node = self.graph.get_node(node_name) node = self.graph.get_node(node_name)
op = node.layer_type op = node.layer_type
...@@ -67,13 +45,22 @@ class CaffeOpMapper(OpMapper): ...@@ -67,13 +45,22 @@ class CaffeOpMapper(OpMapper):
self.deal_custom_layer(node) self.deal_custom_layer(node)
else: else:
raise Exception("Model are not supported yet.") raise Exception("Model are not supported yet.")
for key in self.used_custom_layers:
self.net_code.append(self.used_custom_layers[key])
for i in range(len(self.graph.topo_sort)): def op_checker(self):
node_name = self.graph.topo_sort[i] unsupported_ops = set()
for node_name in self.graph.topo_sort:
node = self.graph.get_node(node_name) node = self.graph.get_node(node_name)
self.net_code += node.fluid_code.gen_codes() op = node.layer_type
if not hasattr(self, op) and op not in custom_layers:
unsupported_ops.add(op)
if len(unsupported_ops) == 0:
return True
else:
print("There are {} ops not supported yet, list as below".format(
len(unsupported_ops)))
for op in unsupported_ops:
print(op)
return False
def set_shape(self, node, is_fluid_op=True): def set_shape(self, node, is_fluid_op=True):
inputs = node.inputs inputs = node.inputs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册