From 11e6c207b40da9364a1d848cbf9552a9481cb2a9 Mon Sep 17 00:00:00 2001 From: yeliang2258 <30516196+yeliang2258@users.noreply.github.com> Date: Wed, 1 Sep 2021 22:16:13 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0aten::sum=20op=20(#669)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix directly_map, Eltwise and crop * add aten::sum op * delet useless info --- x2paddle/op_mapper/pytorch2paddle/aten.py | 39 +++++++++++++++++++++++ 1 file changed, 39 insertions(+) mode change 100644 => 100755 x2paddle/op_mapper/pytorch2paddle/aten.py diff --git a/x2paddle/op_mapper/pytorch2paddle/aten.py b/x2paddle/op_mapper/pytorch2paddle/aten.py old mode 100644 new mode 100755 index 84c0600..4da7ec5 --- a/x2paddle/op_mapper/pytorch2paddle/aten.py +++ b/x2paddle/op_mapper/pytorch2paddle/aten.py @@ -32,6 +32,45 @@ dtype_dict = { } +def aten_sum(mapper, graph, node): + """ 构造获取元素求和的paddlelayer。 + TorchScript示例: + %x_gap.15 : Tensor = aten::sum(%x.58, %2166, %1450, %1453) + 参数含义: + %x_gap.15 (Tensor): 求和后的Tensor。 + %n.3 (Tensor): 求和前的Tensor。 + %2166:axis + %1450:keepdim + %1453:dtype + """ + scope_name = mapper.normalize_scope_name(node) + output_name = mapper._get_outputs_name(node)[0] + layer_outputs = [output_name] + layer_inputs = {} + layer_attrs = {} + inputs_name, inputs_node = mapper._get_inputs_name(node) + # 获取当前节点输出的list + current_outputs = [output_name] + # 处理输入0,即%n.3 + mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, + scope_name) + layer_inputs["x"] = inputs_name[0] + # 获取当前节点输入的list + current_inputs = list(layer_inputs.values()) + if inputs_name[1] in mapper.attrs: + layer_attrs["axis"] = mapper.attrs[inputs_name[1]] + if inputs_name[2] in mapper.attrs: + layer_attrs["keepdim"] = mapper.attrs[inputs_name[2]] + if inputs_name[3] in mapper.attrs: + layer_attrs["dtype"] = mapper.attrs[inputs_name[3]] + graph.add_layer( + "paddle.sum", + inputs=layer_inputs, + outputs=layer_outputs, + scope_name=scope_name, + **layer_attrs) + return current_inputs, current_outputs + def aten_abs(mapper, graph, node): """ 构造获取绝对值的PaddleLayer。 TorchScript示例: -- GitLab