提交 ac5dd3f5 编写于 作者: S SunAhong1993

add caffe parser v1

上级 7be4d883
...@@ -26,9 +26,3 @@ optimizer.run(parser.tf_graph) ...@@ -26,9 +26,3 @@ optimizer.run(parser.tf_graph)
emitter = TFEmitter(parser) emitter = TFEmitter(parser)
emitter.run() emitter.run()
from x2paddle.parser.caffe_parser import CaffeParser
parser = CaffeParser(
'/home/sunyanfang01/X2Paddle/x2paddle/alexnet.prototxt',
'/home/sunyanfang01/X2Paddle/x2paddle/bvlc_alexnet.caffemodel')
...@@ -27,7 +27,7 @@ class CaffeResolver(object): ...@@ -27,7 +27,7 @@ class CaffeResolver(object):
def import_caffepb(self): def import_caffepb(self):
p = os.path.realpath(__file__) p = os.path.realpath(__file__)
p = os.path.dirname(p) p = os.path.dirname(p)
p = os.path.join(p, '../proto') p = os.path.join(p, './proto')
sys.path.insert(0, p) sys.path.insert(0, p)
import caffe_pb2 import caffe_pb2
return caffe_pb2 return caffe_pb2
...@@ -68,12 +68,8 @@ class CaffeGraphNode(GraphNode): ...@@ -68,12 +68,8 @@ class CaffeGraphNode(GraphNode):
class CaffeGraph(Graph): class CaffeGraph(Graph):
def __init__(self, resolver, model, params): def __init__(self, model, params):
self.params = params self.params = params
if resolver.has_pycaffe():
self.did_use_pb = False
else:
self.did_use_pb = True
super(CaffeGraph, self).__init__(model) super(CaffeGraph, self).__init__(model)
def filter_layers(self, layers): def filter_layers(self, layers):
...@@ -103,44 +99,6 @@ class CaffeGraph(Graph): ...@@ -103,44 +99,6 @@ class CaffeGraph(Graph):
print(layer.name) print(layer.name)
return filtered_layers return filtered_layers
def adjust_parameters(self, node, data):
if not self.did_use_pb:
return data
# When using the protobuf-backend, each parameter initially has four dimensions.
# In certain cases (like FC layers), we want to eliminate the singleton dimensions.
# This implementation takes care of the common cases. However, it does leave the
# potential for future issues.
# The Caffe-backend does not suffer from this problem.
data = list(data)
squeeze_indices = [1] # Squeeze biases.
if node.layer_type == 'InnerProduct':
squeeze_indices.append(0) # Squeeze FC.
for idx in squeeze_indices:
if idx >= len(data):
continue
d = data[idx]
assert len(
d.shape
) == 4, 'invalid shape[%s] from caffe when adjust_parameters' % (
str(d.shape))
shape_old = d.shape
sq_axis = None
if idx == 0:
sq_axis = (0, 1)
elif idx == 1:
sq_axis = (0, 1, 2)
else:
continue
data[idx] = np.squeeze(d, axis=sq_axis)
shape_new = data[idx].shape
return data
def build(self): def build(self):
layers = self.model.layers or self.model.layer layers = self.model.layers or self.model.layer
layers = self.filter_layers(layers) layers = self.filter_layers(layers)
...@@ -167,22 +125,26 @@ class CaffeGraph(Graph): ...@@ -167,22 +125,26 @@ class CaffeGraph(Graph):
data.name = self.model.input[0] data.name = self.model.input[0]
data.top[0] = self.model.input[0] data.top[0] = self.model.input[0]
top_layer = {}
for layer in layers: for layer in layers:
self.node_map[layer.name] = CaffeGraphNode(layer) self.node_map[layer.name] = CaffeGraphNode(layer)
for in_name in layer.bottom:
for layer_name, node in self.node_map.items(): if in_name in top_layer:
for in_node in node.layer.bottom: self.connect(top_layer[in_name][-1], layer.name)
if in_node in self.node_map:
self.connect(in_node, layer_name)
else: else:
raise Exception( raise Exception(
'input[{}] of node[{}] does not exist in node_map'. 'input[{}] of node[{}] does not exist in node_map'.
format(in_node, layer_name)) format(in_name, layer.name))
for out_name in layer.top:
if out_name not in top_layer:
top_layer[out_name] = [layer.name]
else:
top_layer[out_name].append(layer.name)
for layer_name, data in self.params: for layer_name, data in self.params:
if layer_name in self.node_map: if layer_name in self.node_map:
node = self.node_map[layer_name] node = self.node_map[layer_name]
node.set_params(self.adjust_parameters(node, data)) node.set_params(data)
else: else:
notice('Ignoring parameters for non-existent layer: %s' % \ notice('Ignoring parameters for non-existent layer: %s' % \
layer_name) layer_name)
...@@ -201,7 +163,7 @@ class CaffeParser(object): ...@@ -201,7 +163,7 @@ class CaffeParser(object):
text_format.Merge(proto_str, self.net) text_format.Merge(proto_str, self.net)
self.load() self.load()
self.caffe_graph = CaffeGraph(self.resolver, self.net, self.params) self.caffe_graph = CaffeGraph(self.net, self.params)
self.caffe_graph.build() self.caffe_graph.build()
def load(self): def load(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册