From b39b9ccfa68b3ddb5dd572a1daabe866bceadf7a Mon Sep 17 00:00:00 2001 From: WJJ1995 Date: Mon, 9 May 2022 12:36:05 +0800 Subject: [PATCH] Fixed aten index (#787) * add replication pad * update op_list.md * re-lint * fixed aten::index * rm useless code --- x2paddle/op_mapper/pytorch2paddle/aten.py | 31 +++++++++-------------- 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/x2paddle/op_mapper/pytorch2paddle/aten.py b/x2paddle/op_mapper/pytorch2paddle/aten.py index e9259de..0dbad61 100755 --- a/x2paddle/op_mapper/pytorch2paddle/aten.py +++ b/x2paddle/op_mapper/pytorch2paddle/aten.py @@ -2853,45 +2853,38 @@ def aten_hardswish(mapper, graph, node): def aten_index(mapper, graph, node): - """ 构造选择元素的PaddleLayer。 - TorchScript示例: + """ + TorchScript Code: %1681 : Float = aten::index(%1653, %1680) - 参数含义: - %1681 (Tensor): 输出,选择后的Tensor。 - %1653 (Tensor): 需要选择的Tensor。 - %1680 (int): 选择的索引。 + Parameter meaning: + %1681 (Tensor): Output Tensor + %1653 (Tensor): Input Tensor + %1680 (int): Index """ scope_name = mapper.normalize_scope_name(node) output_name = mapper._get_outputs_name(node)[0] layer_outputs = [output_name] layer_inputs = {} - layer_attrs = {} inputs_name, inputs_node = mapper._get_inputs_name(node) - # 获取当前节点输出的list + # output list current_outputs = [output_name] - # 处理输入0,即%1653 + # process Input Tensor mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) layer_inputs["x"] = inputs_name[0] - # 处理输入1,即%1680 + # process Index mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, scope_name) layer_inputs["index"] = inputs_name[1] - # 获取当前节点输入的list + current_inputs = list(layer_inputs.values()) graph.add_layer( "prim.getitem", - inputs={"list": layer_inputs["index"]}, - outputs=[layer_inputs["index"]], - scope_name=scope_name, - index=0) - graph.add_layer( - "paddle.index_select", - inputs=layer_inputs, + inputs={"list": layer_inputs["x"]}, outputs=layer_outputs, scope_name=scope_name, - **layer_attrs) + index=layer_inputs["index"]) return current_inputs, current_outputs -- GitLab