提交 902bdefd 编写于 作者: W wjj19950828

fixed aten::index

上级 f5f9799c
......@@ -2806,13 +2806,13 @@ def aten_hardswish(mapper, graph, node):
def aten_index(mapper, graph, node):
""" 构造选择元素的PaddleLayer。
TorchScript示例:
"""
TorchScript Code:
%1681 : Float = aten::index(%1653, %1680)
参数含义:
%1681 (Tensor): 输出,选择后的Tensor。
%1653 (Tensor): 需要选择的Tensor。
%1680 (int): 选择的索引。
Parameter meaning:
%1681 (Tensor): Output Tensor
%1653 (Tensor): Input Tensor
%1680 (int): Index
"""
scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0]
......@@ -2820,31 +2820,25 @@ def aten_index(mapper, graph, node):
layer_inputs = {}
layer_attrs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node)
# 获取当前节点输出的list
# output list
current_outputs = [output_name]
# 处理输入0,即%1653
# process Input Tensor
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0]
# 处理输入1,即%1680
# process Index
mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs,
scope_name)
layer_inputs["index"] = inputs_name[1]
# 获取当前节点输入的list
current_inputs = list(layer_inputs.values())
graph.add_layer(
"prim.getitem",
inputs={"list": layer_inputs["index"]},
outputs=[layer_inputs["index"]],
scope_name=scope_name,
index=0)
graph.add_layer(
"paddle.index_select",
inputs=layer_inputs,
inputs={"list": layer_inputs["x"]},
outputs=layer_outputs,
scope_name=scope_name,
**layer_attrs)
index=layer_inputs["index"])
return current_inputs, current_outputs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册