未验证 提交 489ba078 编写于 作者: Y yeliang2258 提交者: GitHub

fix directly_map, Eltwise and crop (#663)

上级 80f23a85
......@@ -168,12 +168,21 @@ class CaffeOpMapper():
return False
def directly_map(self, node):
inputs = node.layer.input
assert len(inputs) == 1, 'directly_map error with multi inputs'
assert len(node.layer.bottom) == 1, 'directly_map error with multi inputs'
op_info = self.directly_map_ops[node.layer_type]
input = self.graph.get_input_node(node, 0)
paddle_op = op_info[0]
if paddle_op.startswith("paddle.nn"):
if paddle_op.startswith("paddle.nn.layer"):
op_name = paddle_op[16:].lower()
op_name = name_generator(op_name, self.nn_name2id)
output_name = node.name
layer_outputs = [op_name, output_name]
self.paddle_graph.add_layer(
kernel=paddle_op,
inputs={"x": input.name},
outputs=layer_outputs)
else:
if paddle_op.startswith("paddle.nn") and "layer" not in paddle_op:
op_name = paddle_op[10:].lower()
op_name = name_generator(op_name, self.nn_name2id)
output_name = node.name
......@@ -610,8 +619,29 @@ class CaffeOpMapper():
num_parameters=num_parameters)
def Eltwise(self, node):
if len(node.layer.bottom) == 3 and node.layer.eltwise_param.operation == 1:
inputs_dict = {}
input0 = self.graph.get_input_node(node, idx=0, copy=True)
input1 = self.graph.get_input_node(node, idx=1, copy=True)
input2 = self.graph.get_input_node(node, idx=2, copy=True)
input0_name = input0.name
input1_name = input1.name
input2_name = input2.name
inputs_dict['x'] = input0_name
inputs_dict['y'] = input1_name
self.paddle_graph.add_layer(
"paddle.add", inputs=inputs_dict,
outputs=[node.layer_name+"_1"])
inputs_dict = {}
inputs_dict['x'] = node.layer_name+"_1"
inputs_dict['y'] = input2_name
self.paddle_graph.add_layer(
"paddle.add", inputs=inputs_dict,
outputs=[node.layer_name])
return
assert len(
node.inputs) == 2, "The count of Eltwise node\'s input is not 2."
node.layer.bottom) == 2, "The count of Eltwise node\'s input is not 2."
params = node.layer.eltwise_param
mode = params.operation
inputs = []
......@@ -894,7 +924,9 @@ class CaffeOpMapper():
axis += len(input_shape)
offset_real = [0] * len(input_shape)
if hasattr(params, "offset") and len(params.offset) > 0:
offset = list(params.offset)
offset_origin = list(params.offset)
if len(offset_origin)==1 :
offset = offset_origin * (len(input_shape) - axis)
assert (len(input_shape) - axis
) == len(offset), "invalid offset[%s] in crop layer" % (
str(offset))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册