提交 39534800 编写于 作者: W wjj19950828

deal with comments

上级 6ecd82a4
...@@ -768,6 +768,8 @@ class CaffeOpMapper(): ...@@ -768,6 +768,8 @@ class CaffeOpMapper():
node.data[1]).astype("float32") node.data[1]).astype("float32")
params = node.layer.scale_param params = node.layer.scale_param
axis = params.axis axis = params.axis
if axis < 0:
axis += len(node.in_shapes[0])
if len(node.inputs) == 2: if len(node.inputs) == 2:
input0 = self.graph.get_input_node(node, idx=0, copy=True) input0 = self.graph.get_input_node(node, idx=0, copy=True)
input1 = self.graph.get_input_node(node, idx=1, copy=True) input1 = self.graph.get_input_node(node, idx=1, copy=True)
...@@ -776,11 +778,6 @@ class CaffeOpMapper(): ...@@ -776,11 +778,6 @@ class CaffeOpMapper():
inputs_dict = {} inputs_dict = {}
inputs_dict['x'] = input0_name inputs_dict['x'] = input0_name
inputs_dict['y'] = input1_name inputs_dict['y'] = input1_name
self.paddle_graph.add_layer(
"paddle.fluid.layers.elementwise_mul",
inputs=inputs_dict,
outputs=[node.layer_name + "_mul"],
axis=1)
else: else:
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
"self.create_parameter", "self.create_parameter",
...@@ -793,13 +790,7 @@ class CaffeOpMapper(): ...@@ -793,13 +790,7 @@ class CaffeOpMapper():
inputs_dict = {} inputs_dict = {}
inputs_dict['x'] = input0_name inputs_dict['x'] = input0_name
inputs_dict['y'] = node.layer_name + "_cparam1" inputs_dict['y'] = node.layer_name + "_cparam1"
if len(node.in_shapes[0]) == 2: if axis == len(node.in_shapes[0]) - 1:
self.paddle_graph.add_layer(
"paddle.multiply",
inputs=inputs_dict,
outputs=[node.layer_name + "_mul"])
else:
if axis == -1 or axis == len(node.in_shapes[0]) - 1:
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
"paddle.multiply", "paddle.multiply",
inputs=inputs_dict, inputs=inputs_dict,
...@@ -820,12 +811,10 @@ class CaffeOpMapper(): ...@@ -820,12 +811,10 @@ class CaffeOpMapper():
inputs_dict['x'] = node.layer_name + "_mul" inputs_dict['x'] = node.layer_name + "_mul"
inputs_dict['y'] = node.layer_name + "_cparam2" inputs_dict['y'] = node.layer_name + "_cparam2"
output_shape = node.out_shapes[0] output_shape = node.out_shapes[0]
if axis == -1: if axis == len(output_shape) - 1:
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
"paddle.add", inputs=inputs_dict, outputs=[node.layer_name]) "paddle.add", inputs=inputs_dict, outputs=[node.layer_name])
else: else:
if axis < 0:
axis = axis + len(output_shape)
param2_shape = self.params[node.layer_name + "_cparam2"].shape param2_shape = self.params[node.layer_name + "_cparam2"].shape
param2_shape_len = len(param2_shape) param2_shape_len = len(param2_shape)
diff_len = len(output_shape) - axis - param2_shape_len diff_len = len(output_shape) - axis - param2_shape_len
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册