未验证 提交 7aafc2ac 编写于 作者: J Jason 提交者: GitHub

Merge pull request #87 from SunAhong1993/develop

for new style prototxt and accelerate
...@@ -47,7 +47,7 @@ class CaffeResolver(object): ...@@ -47,7 +47,7 @@ class CaffeResolver(object):
class CaffeGraphNode(GraphNode): class CaffeGraphNode(GraphNode):
def __init__(self, layer, layer_name=None): def __init__(self, layer, type_str, layer_name=None):
if layer_name is None: if layer_name is None:
super(CaffeGraphNode, super(CaffeGraphNode,
self).__init__(layer, self).__init__(layer,
...@@ -56,7 +56,7 @@ class CaffeGraphNode(GraphNode): ...@@ -56,7 +56,7 @@ class CaffeGraphNode(GraphNode):
super(CaffeGraphNode, super(CaffeGraphNode,
self).__init__(layer, self).__init__(layer,
layer_name.replace('/', '_').replace('-', '_')) layer_name.replace('/', '_').replace('-', '_'))
self.layer_type = layer.type self.layer_type = type_str
self.fluid_code = FluidCode() self.fluid_code = FluidCode()
self.data = None self.data = None
...@@ -65,8 +65,9 @@ class CaffeGraphNode(GraphNode): ...@@ -65,8 +65,9 @@ class CaffeGraphNode(GraphNode):
class CaffeGraph(Graph): class CaffeGraph(Graph):
def __init__(self, model, params): def __init__(self, model, params, caffe_pb):
self.params = params self.params = params
self.caffe_pb = caffe_pb
super(CaffeGraph, self).__init__(model) super(CaffeGraph, self).__init__(model)
def filter_layers(self, layers): def filter_layers(self, layers):
...@@ -75,6 +76,9 @@ class CaffeGraph(Graph): ...@@ -75,6 +76,9 @@ class CaffeGraph(Graph):
filtered_layer_names = set() filtered_layer_names = set()
filtered_layers = [] filtered_layers = []
for layer in layers: for layer in layers:
if hasattr(layer, 'input'):
continue
type_str = self.get_layer_type(layer)
phase = 'test' phase = 'test'
if len(layer.include): if len(layer.include):
phase = phase_map[layer.include[0].phase] phase = phase_map[layer.include[0].phase]
...@@ -85,7 +89,7 @@ class CaffeGraph(Graph): ...@@ -85,7 +89,7 @@ class CaffeGraph(Graph):
# test-time networks. These are just ignored. We'll # test-time networks. These are just ignored. We'll
# filter them out here. # filter them out here.
if (not exclude) and (phase == 'test'): if (not exclude) and (phase == 'test'):
exclude = (layer.type == 'Dropout') exclude = (type_str == 'Dropout')
if not exclude: if not exclude:
filtered_layers.append(layer) filtered_layers.append(layer)
# Guard against dupes. # Guard against dupes.
...@@ -95,13 +99,85 @@ class CaffeGraph(Graph): ...@@ -95,13 +99,85 @@ class CaffeGraph(Graph):
print('The filter layer:' + layer.name) print('The filter layer:' + layer.name)
return filtered_layers return filtered_layers
def generate_input_layer(self, dims, index):
dim_str = ''
for dim in dims:
dim_str += 'dim: {}\n'.format(str(dim))
input_str = 'layer {\n'
input_str += 'name: \"{}\"\n '.format(str(self.model.input[index]))
input_str += 'type: "Input"\n'
input_str += 'top: \"{}\"\n'.format(str(self.model.input[index]))
input_str += 'input_param {\n'
input_str += 'shape {\n'
input_str += dim_str
input_str += '}}}'
input_str = str.encode(input_str)
net = self.caffe_pb.NetParameter()
text_format.Merge(input_str, net)
return net.layers or net.layer
def input2layers(self, input_layers=[]):
inputs_num = len(self.model.input)
if inputs_num != 0:
input_dims_num = len(self.model.input_dim)
if input_dims_num != 0:
if input_dims_num > 0 and input_dims_num != inputs_num * 4:
raise Error('invalid input_dim[%d] param in prototxt' %
(input_dims_num))
for i in range(inputs_num):
dims = self.model.input_dim[i * 4:(i + 1) * 4]
l = self.generate_input_layer(dims, i)
input_layers.append(l[0])
else:
for i in range(inputs_num):
dims = self.model.input_shape[i].dim[0:4]
l = self.generate_input_layer(dims, i)
input_layers.append(l[0])
def transform_input_layers(self, layers, input_layers=[]):
for layer in layers:
if hasattr(layer, 'input'):
input_dims_num = len(layers.input_dim)
if input_dims_num > 0 and input_dims_num != 4:
raise Error('invalid input_dim[%d] param in prototxt' %
(input_dims_num))
dims = self.model.input_dim[0:4]
l = self.generate_input_layer(dims, i)
input_layers.append(l[0])
def get_layer_type(self, layer):
if isinstance(layer.type, int):
enum_values = self.caffe_pb._V1LAYERPARAMETER_LAYERTYPE.values
vals = [val for val in enum_values if val.number == layer.type]
part = vals[0].name.split('_')
part = [s.capitalize() for s in part]
type_str = ''
type_str = type_str.join(part)
if 'relu' in type_str.lower():
type_str = type_str.replace('elu', 'eLU')
elif type_str.lower() == 'lrn':
type_str = 'LRN'
return type_str
else:
return layer.type
def build(self): def build(self):
layers = self.model.layers or self.model.layer layers = self.model.layers or self.model.layer
layers = self.filter_layers(layers) layers = self.filter_layers(layers)
input_layers = []
self.input2layers(input_layers)
self.transform_input_layers(layers, input_layers)
layers = input_layers + layers
top_layer = {} top_layer = {}
for layer in layers: for layer in layers:
self.node_map[layer.name] = CaffeGraphNode(layer) if hasattr(layer, 'input'):
continue
type_str = self.get_layer_type(layer)
self.node_map[layer.name] = CaffeGraphNode(layer, type_str)
for in_name in layer.bottom: for in_name in layer.bottom:
if in_name in top_layer: if in_name in top_layer:
self.connect(top_layer[in_name][-1], layer.name) self.connect(top_layer[in_name][-1], layer.name)
...@@ -146,19 +222,26 @@ class CaffeDecoder(object): ...@@ -146,19 +222,26 @@ class CaffeDecoder(object):
self.resolver = CaffeResolver(caffe_proto=caffe_proto) self.resolver = CaffeResolver(caffe_proto=caffe_proto)
self.net = self.resolver.NetParameter() self.net = self.resolver.NetParameter()
with open(proto_path, 'rb') as proto_file: with open(proto_path, 'rb') as proto_file:
proto_str = self.old2new(proto_file) proto_str = proto_file.read()
text_format.Merge(proto_str, self.net) text_format.Merge(proto_str, self.net)
self.load_using_pb() self.load_using_pb()
self.caffe_graph = CaffeGraph(self.net, self.params)
self.caffe_graph = CaffeGraph(self.net, self.params,
self.resolver.caffepb)
self.caffe_graph.build() self.caffe_graph.build()
def load_using_pb(self): def load_using_pb(self):
data = self.resolver.NetParameter() data = self.resolver.NetParameter()
data.MergeFromString(open(self.model_path, 'rb').read()) data.MergeFromString(open(self.model_path, 'rb').read())
pair = lambda layer: (layer.name, self.normalize_pb_data(layer)) pair = lambda layer: (layer.name, self.normalize_pb_data(layer))
layers = data.layers or data.layer layers = data.layers or data.layer
import time
start = time.time()
self.params = [pair(layer) for layer in layers if layer.blobs] self.params = [pair(layer) for layer in layers if layer.blobs]
end = time.time()
print('cost:', str(end - start))
def normalize_pb_data(self, layer): def normalize_pb_data(self, layer):
transformed = [] transformed = []
...@@ -171,72 +254,8 @@ class CaffeDecoder(object): ...@@ -171,72 +254,8 @@ class CaffeDecoder(object):
c_i = blob.channels c_i = blob.channels
h = blob.height h = blob.height
w = blob.width w = blob.width
data = np.array(blob.data, dtype=np.float32).reshape(c_o, c_i, h, w) data = np.asarray(list(blob.data),
dtype=np.float32).reshape(c_o, c_i, h, w)
transformed.append(data) transformed.append(data)
return transformed return transformed
def old2new(self, proto_file):
part1_str = ''
part2_str = ''
part3_str = ''
is_input = False
dims = []
line = proto_file.readline()
print('Check if it is a new style of caffe...')
while line:
l_str = bytes.decode(line)
if l_str.replace(' ', '').startswith('input:'):
part2_str += 'layer {\n'
part2_str += (
' name: ' +
l_str.strip().replace(' ', '').split('input:')[-1] + '\n')
part2_str += ' type: \"Input\"\n'
part2_str += (
' top: ' +
l_str.strip().replace(' ', '').split('input:')[-1] + '\n')
is_input = True
line = proto_file.readline()
continue
elif l_str.replace(' ', '').startswith('input_dim:'):
dims.append(
int(l_str.strip().replace(' ', '').split('input_dim:')[-1]))
if len(dims) == 4:
part2_str += ' input_param { shape: { dim: ' + str(dims[0]) + \
' dim: ' + str(dims[1]) + \
' dim: ' + str(dims[2]) + \
' dim: ' + str(dims[3]) + ' } }\n'
dims = []
part2_str += '}\n'
line = proto_file.readline()
if bytes.decode(line).replace(' ', '').startswith('}'):
line = proto_file.readline()
continue
elif l_str.replace(' ', '').startswith('input_shape'):
part2_str += l_str.replace('input_shape',
'input_param { shape: ')
l_str = bytes.decode(proto_file.readline())
while l_str:
if '}' in l_str:
part2_str += l_str + '\n}\n}'
break
else:
part2_str += l_str
l_str = bytes.decode(proto_file.readline())
line = proto_file.readline()
continue
if not is_input:
part1_str += bytes.decode(line)
else:
part3_str += bytes.decode(line)
line = proto_file.readline()
out = part1_str + part2_str + part3_str
layer_str = 'layer{'
part = out.split(layer_str)
if len(part) == 1:
layer_str = 'layer {'
part = out.split(layer_str)
for i in range(len(part)):
if part[i].strip().replace(' ', '') == '' or part[i].count(':') > 1:
continue
out = out.replace(layer_str + part[i], part[i].replace(' ', ''))
return str.encode(out)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册