提交 6ecd82a4 编写于 作者: W wjj19950828

fixed scale and crop op

上级 fef1d0bb
......@@ -346,7 +346,15 @@ def shape_argmax(layer, input_shape):
def shape_crop(layer, input_shape):
assert len(input_shape) == 2, "the number of crop's inputs must be 2"
return [input_shape[1]]
params = layer.crop_param
axis = params.axis
if axis < 0:
axis += len(input_shape[0])
if axis > 0:
crop_shape = input_shape[0][:axis] + input_shape[1][axis:]
else:
crop_shape = input_shape[1]
return [crop_shape]
def shape_flatten(layer, input_shape):
......
......@@ -768,7 +768,6 @@ class CaffeOpMapper():
node.data[1]).astype("float32")
params = node.layer.scale_param
axis = params.axis
inputs = []
if len(node.inputs) == 2:
input0 = self.graph.get_input_node(node, idx=0, copy=True)
input1 = self.graph.get_input_node(node, idx=1, copy=True)
......@@ -778,7 +777,7 @@ class CaffeOpMapper():
inputs_dict['x'] = input0_name
inputs_dict['y'] = input1_name
self.paddle_graph.add_layer(
"paddle.multiply",
"paddle.fluid.layers.elementwise_mul",
inputs=inputs_dict,
outputs=[node.layer_name + "_mul"],
axis=1)
......@@ -800,11 +799,17 @@ class CaffeOpMapper():
inputs=inputs_dict,
outputs=[node.layer_name + "_mul"])
else:
self.paddle_graph.add_layer(
"paddle.multiply",
inputs=inputs_dict,
outputs=[node.layer_name + "_mul"],
axis=axis)
if axis == -1 or axis == len(node.in_shapes[0]) - 1:
self.paddle_graph.add_layer(
"paddle.multiply",
inputs=inputs_dict,
outputs=[node.layer_name + "_mul"])
else:
self.paddle_graph.add_layer(
"paddle.fluid.layers.elementwise_mul",
inputs=inputs_dict,
outputs=[node.layer_name + "_mul"],
axis=axis)
self.paddle_graph.add_layer(
"self.create_parameter",
inputs={},
......@@ -933,11 +938,15 @@ class CaffeOpMapper():
) == len(offset), "invalid offset[%s] in crop layer" % (
str(offset))
offset_real = [0] * axis + offset
if axis > 0:
crop_shape = node.in_shapes[0][:axis] + node.in_shapes[1][axis:]
else:
crop_shape = node.in_shapes[1]
self.paddle_graph.add_layer(
"paddle.crop",
inputs={"x": input.name},
outputs=[node.layer_name],
shape=node.in_shapes[1],
shape=crop_shape,
offsets=list(offset_real))
def Flatten(self, node):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册