未验证 提交 d5c2ceb3 编写于 作者: W WJJ1995 提交者: GitHub

fixed pytorch split op (#648)

上级 489ba078
......@@ -5130,7 +5130,7 @@ def aten_split(mapper, graph, node):
%160 (Tensor): 输出,分割后的矩阵。
%159 (Tensor): 需要分割的Tensor。
%135 (int): 分割的数量。
%723 (int): 轴。
%123 (int): 轴。
"""
scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0]
......@@ -5155,7 +5155,25 @@ def aten_split(mapper, graph, node):
if "[]" in str(input_type):
layer_inputs["num_or_sections"] = inputs_name[1]
else:
layer_attrs["num_or_sections"] = mapper.attrs[inputs_name[1]] + 1
index = mapper.attrs[inputs_name[2]]
graph.add_layer(
"prim.shape",
inputs={"input": inputs_name[0]},
outputs=[inputs_name[0] + '_shape'],
scope_name=scope_name)
graph.add_layer(
"prim.getitem",
inputs={"list": inputs_name[0] + '_shape'},
outputs=[inputs_name[0] + '_dim'],
scope_name=scope_name,
index=index)
graph.add_layer(
"prim.floordiv",
inputs={'x': inputs_name[0] + '_dim',
'y': inputs_name[1]},
outputs=[inputs_name[1] + '_div'],
scope_name=scope_name)
layer_attrs["num_or_sections"] = inputs_name[1] + '_div'
# 获取当前节点输入的list
current_inputs = list(layer_inputs.values())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册