提交 c236c3c2 编写于 作者: S SunAhong1993

fix the caffe name bug

上级 f1999b34
...@@ -104,6 +104,8 @@ class CaffeGraph(Graph): ...@@ -104,6 +104,8 @@ class CaffeGraph(Graph):
if not exclude: if not exclude:
filtered_layers.append(layer) filtered_layers.append(layer)
# Guard against dupes. # Guard against dupes.
if layer.name in filtered_layer_names:
layer.name += "_0"
assert layer.name not in filtered_layer_names assert layer.name not in filtered_layer_names
filtered_layer_names.add(layer.name) filtered_layer_names.add(layer.name)
else: else:
...@@ -224,7 +226,7 @@ class CaffeGraph(Graph): ...@@ -224,7 +226,7 @@ class CaffeGraph(Graph):
assert input_node_name in self.node_map, 'The {} isn\'t a valid node'.format( assert input_node_name in self.node_map, 'The {} isn\'t a valid node'.format(
name) name)
input_node = self.node_map[input_node_name] input_node = self.node_map[input_node_name]
if len(input_node.layer.top) > 1: if len(input_node.layer.top) > 1 and input_node.layer_type != "Input":
need_idx = list(input_node.layer.top).index(node.layer.bottom[idx]) need_idx = list(input_node.layer.top).index(node.layer.bottom[idx])
name = input_node_name + ':' + str(need_idx) name = input_node_name + ':' + str(need_idx)
else: else:
......
...@@ -81,7 +81,6 @@ class CaffeOpMapper(OpMapper): ...@@ -81,7 +81,6 @@ class CaffeOpMapper(OpMapper):
input_shape.append(last_node.output_shape[idx]) input_shape.append(last_node.output_shape[idx])
node.input_shape = input_shape node.input_shape = input_shape
func_name = 'shape_' + node.layer_type.lower() func_name = 'shape_' + node.layer_type.lower()
if is_fluid_op: if is_fluid_op:
node.output_shape = getattr(caffe_shape, func_name)(node.layer, node.output_shape = getattr(caffe_shape, func_name)(node.layer,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册