提交 8248159a 编写于 作者: S SunAhong1993

fix the bug

上级 ea037db6
......@@ -49,11 +49,13 @@ class Graph(object):
def _make_input_nodes(self):
for name, node in self.node_map.items():
name = name.replace('/', '_').replace('-', '_')
if len(node.inputs) == 0:
self.input_nodes.append(name)
def _make_output_nodes(self):
for name, node in self.node_map.items():
name = name.replace('/', '_').replace('-', '_')
if len(node.outputs) == 0:
self.output_nodes.append(name)
......
......@@ -97,7 +97,6 @@ class CaffeGraph(Graph):
phase_map = {0: 'train', 1: 'test'}
filtered_layer_names = set()
filtered_layers = []
print('The filter layer:')
for layer in layers:
phase = 'test'
if len(layer.include):
......@@ -116,7 +115,7 @@ class CaffeGraph(Graph):
assert layer.name not in filtered_layer_names
filtered_layer_names.add(layer.name)
else:
print(layer.name)
print('The filter layer:' + layer.name)
return filtered_layers
def build(self):
......
......@@ -18,19 +18,19 @@ def convolutiondepthwise_shape(input_shape,
[k_h, k_w] = [1, 1]
if isinstance(kernel_size, numbers.Number):
[k_h, k_w] = [kernel_size] * 2
elif isinstance(kernel_size, list) and len(kernel_size) > 0:
elif len(kernel_size) > 0:
k_h = kernel_h if kernel_h else kernel_size[0]
k_w = kernel_w if kernel_w else kernel_size[len(kernel_size) - 1]
[s_h, s_w] = [1, 1]
if isinstance(stride, numbers.Number):
[s_h, s_w] = [stride] * 2
elif isinstance(stride, list) and len(stride) > 0:
elif len(stride) > 0:
s_h = stride_h if stride_h else stride[0]
s_w = stride_w if stride_w else stride[len(stride) - 1]
[p_h, p_w] = [0, 0]
if isinstance(pad, numbers.Number):
[p_h, p_w] = [pad] * 2
elif isinstance(pad, list) and len(pad) > 0:
elif len(pad) > 0:
p_h = pad_h if pad_h else pad[0]
p_w = pad_w if pad_w else pad[len(pad) - 1]
dila_len = len(dilation)
......@@ -69,22 +69,23 @@ def convolutiondepthwise_layer(inputs,
stride_w=None,
input_shape=None,
name=None):
import numbers
[k_h, k_w] = [1, 1]
if isinstance(kernel_size, numbers.Number):
[k_h, k_w] = [kernel_size] * 2
elif isinstance(kernel_size, list) and len(kernel_size) > 0:
elif len(kernel_size) > 0:
k_h = kernel_h if kernel_h else kernel_size[0]
k_w = kernel_w if kernel_w else kernel_size[len(kernel_size) - 1]
[s_h, s_w] = [1, 1]
if isinstance(stride, numbers.Number):
[s_h, s_w] = [stride] * 2
elif isinstance(stride, list) and len(stride) > 0:
elif len(stride) > 0:
s_h = stride_h if stride_h else stride[0]
s_w = stride_w if stride_w else stride[len(stride) - 1]
[p_h, p_w] = [0, 0]
if isinstance(pad, numbers.Number):
[p_h, p_w] = [pad] * 2
elif isinstance(pad, list) and len(pad) > 0:
elif len(pad) > 0:
p_h = pad_h if pad_h else pad[0]
p_w = pad_w if pad_w else pad[len(pad) - 1]
input = inputs[0]
......
......@@ -123,22 +123,21 @@ class CaffeOpMapper(OpMapper):
[k_h, k_w] = [1, 1]
if isinstance(params.kernel_size, numbers.Number):
[k_h, k_w] = [params.kernel_size] * 2
elif isinstance(params.kernel_size,
list) and len(params.kernel_size) > 0:
elif len(params.kernel_size) > 0:
k_h = params.kernel_h if params.kernel_h else params.kernel_size[0]
k_w = params.kernel_w if params.kernel_w else params.kernel_size[
len(params.kernel_size) - 1]
[s_h, s_w] = [1, 1]
if isinstance(params.stride, numbers.Number):
[s_h, s_w] = [params.stride] * 2
elif isinstance(params.stride, list) and len(params.stride) > 0:
elif len(params.stride) > 0:
s_h = params.stride_h if params.stride_h else params.stride[0]
s_w = params.stride_w if params.stride_w else params.stride[
len(params.stride) - 1]
[p_h, p_w] = [0, 0]
if isinstance(params.pad, numbers.Number):
[p_h, p_w] = [params.pad] * 2
elif isinstance(params.pad, list) and len(params.pad) > 0:
elif len(params.pad) > 0:
p_h = params.pad_h if params.pad_h else params.pad[0]
p_w = params.pad_w if params.pad_w else params.pad[len(params.pad) -
1]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册