提交 fc6042a3 编写于 作者: S SunAhong1993

modify caffe

上级 b5281541
...@@ -527,10 +527,10 @@ class CaffeOpMapper(OpMapper): ...@@ -527,10 +527,10 @@ class CaffeOpMapper(OpMapper):
assert len( assert len(
node.inputs node.inputs
) >= 1, "The count of Concat node\'s input is not more than 1." ) >= 1, "The count of Concat node\'s input is not more than 1."
inputs_list = dict() inputs_list = list()
for i in range(len(node.inputs)): for i in range(len(node.inputs)):
input = self.graph.get_bottom_node(node, idx=i, copy=True) input = self.graph.get_bottom_node(node, idx=i, copy=True)
inputs_list[i] = self.get_input_name(input) inputs_list.append(self.get_input_name(input))
params = node.layer.concat_param params = node.layer.concat_param
axis = params.axis axis = params.axis
layer_attrs = {'axis': axis} layer_attrs = {'axis': axis}
...@@ -574,11 +574,9 @@ class CaffeOpMapper(OpMapper): ...@@ -574,11 +574,9 @@ class CaffeOpMapper(OpMapper):
mode_bool = params.channel_shared mode_bool = params.channel_shared
output_shape = node.output_shape[0] output_shape = node.output_shape[0]
if mode_bool: if mode_bool:
mode = 'all' num_parameters = 1
channel = None
else: else:
mode = 'channel' num_parameters = output_shape[1]
channel = output_shape[1]
data = node.data data = node.data
self.params[prelu_name + '._weight'] = np.squeeze(data[0]) self.params[prelu_name + '._weight'] = np.squeeze(data[0])
assert data is not None, "The parameter of {} (type is {}) is not set. You need to use python package of caffe to set the default value.".format( assert data is not None, "The parameter of {} (type is {}) is not set. You need to use python package of caffe to set the default value.".format(
...@@ -587,8 +585,7 @@ class CaffeOpMapper(OpMapper): ...@@ -587,8 +585,7 @@ class CaffeOpMapper(OpMapper):
"paddle.nn.PReLU", "paddle.nn.PReLU",
inputs={"input": self.get_input_name(input)}, inputs={"input": self.get_input_name(input)},
outputs=layer_outputs, outputs=layer_outputs,
channel=channel, num_parameters=num_parameters)
mode=string(mode))
def Eltwise(self, node): def Eltwise(self, node):
assert len( assert len(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册