提交 8cd5e764 编写于 作者: S SunAhong1993

reduce code

上级 983669f5
......@@ -99,6 +99,23 @@ class CaffeGraph(Graph):
print('The filter layer:' + layer.name)
return filtered_layers
def generate_input_layer(self, dims, index):
dim_str = ''
for dim in dims:
dim_str += 'dim: {}\n'.format(str(dim))
input_str = 'layer {\n'
input_str += 'name: \"{}\"\n '.format(str(self.model.input[index]))
input_str += 'type: "Input"\n'
input_str += 'top: \"{}\"\n'.format(str(self.model.input[index]))
input_str += 'input_param {\n'
input_str += 'shape {\n'
input_str += dim_str
input_str += '}}}'
input_str = str.encode(input_str)
net = self.caffe_pb.NetParameter()
text_format.Merge(input_str, net)
return net.layers or net.layer
def input2layers(self, input_layers=[]):
inputs_num = len(self.model.input)
if inputs_num != 0:
......@@ -109,44 +126,12 @@ class CaffeGraph(Graph):
(input_dims_num))
for i in range(inputs_num):
dims = self.model.input_dim[i * 4:(i + 1) * 4]
dim_str = ''
for dim in dims:
dim_str += 'dim: {}\n'.format(str(dim))
input_str = 'layer {\n'
input_str += 'name: \"{}\"\n '.format(
str(self.model.input[i]))
input_str += 'type: "Input"\n'
input_str += 'top: \"{}\"\n'.format(str(
self.model.input[i]))
input_str += 'input_param {\n'
input_str += 'shape {\n'
input_str += dim_str
input_str += '}}}'
input_str = str.encode(input_str)
net = self.caffe_pb.NetParameter()
text_format.Merge(input_str, net)
l = net.layers or net.layer
l = self.generate_input_layer(dims, i)
input_layers.append(l[0])
else:
for i in range(inputs_num):
dims = self.model.input_shape[i].dim[0:4]
dim_str = ''
for dim in dims:
dim_str += 'dim: {}\n'.format(str(dim))
input_str = 'layer {\n'
input_str += 'name: \"{}\"\n '.format(
str(self.model.input[i]))
input_str += 'type: "Input"\n'
input_str += 'top: \"{}\"\n'.format(str(
self.model.input[i]))
input_str += 'input_param {\n'
input_str += 'shape {\n'
input_str += dim_str
input_str += '}}}'
input_str = str.encode(input_str)
net = self.caffe_pb.NetParameter()
text_format.Merge(input_str, net)
l = net.layers or net.layer
l = self.generate_input_layer(dims, i)
input_layers.append(l[0])
def transform_input_layers(self, layers, input_layers=[]):
......@@ -157,21 +142,7 @@ class CaffeGraph(Graph):
raise Error('invalid input_dim[%d] param in prototxt' %
(input_dims_num))
dims = self.model.input_dim[0:4]
dim_str = ''
for dim in dims:
dim_str += 'dim: {}\n'.format(str(dim))
input_str = 'layer {\n'
input_str += 'name: \"{}\"\n '.format(str(self.model.input[i]))
input_str += 'type: "Input"\n'
input_str += 'top: \"{}\"\n'.format(str(self.model.input[i]))
input_str += 'input_param {\n'
input_str += 'shape {\n'
input_str += dim_str
input_str += '}}}'
input_str = str.encode(input_str)
net = self.caffe_pb.NetParameter()
text_format.Merge(input_str, net)
l = net.layers or net.layer
l = self.generate_input_layer(dims, i)
input_layers.append(l[0])
def get_layer_type(self, layer):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册