提交 39534800 编写于 作者: W wjj19950828

deal with comments

上级 6ecd82a4
......@@ -768,6 +768,8 @@ class CaffeOpMapper():
node.data[1]).astype("float32")
params = node.layer.scale_param
axis = params.axis
if axis < 0:
axis += len(node.in_shapes[0])
if len(node.inputs) == 2:
input0 = self.graph.get_input_node(node, idx=0, copy=True)
input1 = self.graph.get_input_node(node, idx=1, copy=True)
......@@ -776,11 +778,6 @@ class CaffeOpMapper():
inputs_dict = {}
inputs_dict['x'] = input0_name
inputs_dict['y'] = input1_name
self.paddle_graph.add_layer(
"paddle.fluid.layers.elementwise_mul",
inputs=inputs_dict,
outputs=[node.layer_name + "_mul"],
axis=1)
else:
self.paddle_graph.add_layer(
"self.create_parameter",
......@@ -793,23 +790,17 @@ class CaffeOpMapper():
inputs_dict = {}
inputs_dict['x'] = input0_name
inputs_dict['y'] = node.layer_name + "_cparam1"
if len(node.in_shapes[0]) == 2:
self.paddle_graph.add_layer(
"paddle.multiply",
inputs=inputs_dict,
outputs=[node.layer_name + "_mul"])
else:
if axis == -1 or axis == len(node.in_shapes[0]) - 1:
self.paddle_graph.add_layer(
"paddle.multiply",
inputs=inputs_dict,
outputs=[node.layer_name + "_mul"])
else:
self.paddle_graph.add_layer(
"paddle.fluid.layers.elementwise_mul",
inputs=inputs_dict,
outputs=[node.layer_name + "_mul"],
axis=axis)
if axis == len(node.in_shapes[0]) - 1:
self.paddle_graph.add_layer(
"paddle.multiply",
inputs=inputs_dict,
outputs=[node.layer_name + "_mul"])
else:
self.paddle_graph.add_layer(
"paddle.fluid.layers.elementwise_mul",
inputs=inputs_dict,
outputs=[node.layer_name + "_mul"],
axis=axis)
self.paddle_graph.add_layer(
"self.create_parameter",
inputs={},
......@@ -820,12 +811,10 @@ class CaffeOpMapper():
inputs_dict['x'] = node.layer_name + "_mul"
inputs_dict['y'] = node.layer_name + "_cparam2"
output_shape = node.out_shapes[0]
if axis == -1:
if axis == len(output_shape) - 1:
self.paddle_graph.add_layer(
"paddle.add", inputs=inputs_dict, outputs=[node.layer_name])
else:
if axis < 0:
axis = axis + len(output_shape)
param2_shape = self.params[node.layer_name + "_cparam2"].shape
param2_shape_len = len(param2_shape)
diff_len = len(output_shape) - axis - param2_shape_len
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册