提交 fc6042a3 编写于 作者: S SunAhong1993

modify caffe

上级 b5281541
......@@ -527,10 +527,10 @@ class CaffeOpMapper(OpMapper):
assert len(
node.inputs
) >= 1, "The count of Concat node\'s input is not more than 1."
inputs_list = dict()
inputs_list = list()
for i in range(len(node.inputs)):
input = self.graph.get_bottom_node(node, idx=i, copy=True)
inputs_list[i] = self.get_input_name(input)
inputs_list.append(self.get_input_name(input))
params = node.layer.concat_param
axis = params.axis
layer_attrs = {'axis': axis}
......@@ -574,11 +574,9 @@ class CaffeOpMapper(OpMapper):
mode_bool = params.channel_shared
output_shape = node.output_shape[0]
if mode_bool:
mode = 'all'
channel = None
num_parameters = 1
else:
mode = 'channel'
channel = output_shape[1]
num_parameters = output_shape[1]
data = node.data
self.params[prelu_name + '._weight'] = np.squeeze(data[0])
assert data is not None, "The parameter of {} (type is {}) is not set. You need to use python package of caffe to set the default value.".format(
......@@ -587,8 +585,7 @@ class CaffeOpMapper(OpMapper):
"paddle.nn.PReLU",
inputs={"input": self.get_input_name(input)},
outputs=layer_outputs,
channel=channel,
mode=string(mode))
num_parameters=num_parameters)
def Eltwise(self, node):
assert len(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册