提交 0a209d49 编写于 作者: S SunAhong1993

fix the satic bug

上级 45bcf759
......@@ -190,7 +190,7 @@ class CaffeOpMapper(OpMapper):
def get_input_name(self, node):
if hasattr(node, "index"):
return node.layer_name + "[{}]".format(node.index)
return "{}_{}".format(node.layer_name, node.index)
else:
return node.layer_name
......@@ -423,9 +423,11 @@ class CaffeOpMapper(OpMapper):
if slice_dim != 1 and axis == 1:
axis = slice_dim
output_shape = node.output_shape
sections_list = []
for s in output_shape:
sections_list = list()
outputs_list = list()
for i, s in enumerate(output_shape):
sections_list.append(s[axis])
outputs_list.append("{}_{}".format(node.layer_name, i))
layer_attrs = {
'num_or_sections': sections_list,
'dim': axis,
......@@ -434,7 +436,7 @@ class CaffeOpMapper(OpMapper):
self.paddle_graph.add_layer(
kernel="fluid.layers.split",
inputs={"input": self.get_input_name(input)},
outputs=[node.layer_name],
outputs=outputs_list,
**layer_attrs)
def Concat(self, node):
......@@ -958,7 +960,7 @@ class CaffeOpMapper(OpMapper):
kwargs[k]["top_k"] = v.top_k
kwargs[k]["eta"] = v.eta
self.paddle_graph.add_layer(
kernel="combination_layer:{}".format(op),
kernel="custom_layer:{}".format(op),
inputs={"inputs": inputs_list},
outputs=[node.layer_name],
**kwargs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册