From 902bdefdbf7e3573fe2f77116e432361d7e017e7 Mon Sep 17 00:00:00 2001 From: wjj19950828 Date: Sat, 7 May 2022 15:51:03 +0800 Subject: [PATCH] fixed aten::index --- x2paddle/op_mapper/pytorch2paddle/aten.py | 30 +++++++++-------------- 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/x2paddle/op_mapper/pytorch2paddle/aten.py b/x2paddle/op_mapper/pytorch2paddle/aten.py index dce8581..1b7c5e7 100755 --- a/x2paddle/op_mapper/pytorch2paddle/aten.py +++ b/x2paddle/op_mapper/pytorch2paddle/aten.py @@ -2806,13 +2806,13 @@ def aten_hardswish(mapper, graph, node): def aten_index(mapper, graph, node): - """ 构造选择元素的PaddleLayer。 - TorchScript示例: + """ + TorchScript Code: %1681 : Float = aten::index(%1653, %1680) - 参数含义: - %1681 (Tensor): 输出,选择后的Tensor。 - %1653 (Tensor): 需要选择的Tensor。 - %1680 (int): 选择的索引。 + Parameter meaning: + %1681 (Tensor): Output Tensor + %1653 (Tensor): Input Tensor + %1680 (int): Index """ scope_name = mapper.normalize_scope_name(node) output_name = mapper._get_outputs_name(node)[0] @@ -2820,31 +2820,25 @@ def aten_index(mapper, graph, node): layer_inputs = {} layer_attrs = {} inputs_name, inputs_node = mapper._get_inputs_name(node) - # 获取当前节点输出的list + # output list current_outputs = [output_name] - # 处理输入0,即%1653 + # process Input Tensor mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) layer_inputs["x"] = inputs_name[0] - # 处理输入1,即%1680 + # process Index mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, scope_name) layer_inputs["index"] = inputs_name[1] - # 获取当前节点输入的list + current_inputs = list(layer_inputs.values()) graph.add_layer( "prim.getitem", - inputs={"list": layer_inputs["index"]}, - outputs=[layer_inputs["index"]], - scope_name=scope_name, - index=0) - graph.add_layer( - "paddle.index_select", - inputs=layer_inputs, + inputs={"list": layer_inputs["x"]}, outputs=layer_outputs, scope_name=scope_name, - **layer_attrs) + index=layer_inputs["index"]) return current_inputs, current_outputs -- GitLab