提交 6153f98e 编写于 作者: S SunAhong1993

add pytorch aten_split_with_sizes

上级 edb4bde7
......@@ -3976,6 +3976,56 @@ def aten_softplus(mapper, graph, node):
return current_inputs, current_outputs
def aten_split_with_sizes(mapper, graph, node):
""" 构构造split的PaddleLayer。
TorchScript示例:
%1450 : Tensor[] = aten::split_with_sizes(%1446, %1750, %41)
参数含义:
%1450 (Tensor): 输出,split后的Tensor。
%1446 (Tensor): 需要获取split的Tensor。
%1750 (list): 子Tensor的数量列表。
%41 (int): 需要分割的维度。
"""
scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name]
layer_inputs = {}
layer_attrs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node)
# 获取当前节点输出的list
current_outputs = [output_name]
# 处理输入0,即%1446
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
layer_inputs["x"] = inputs_name[0]
# 处理输入1,即%1750
if inputs_name[1] in mapper.attrs:
layer_attrs["num_or_sections"] = mapper.attrs[inputs_name[1]]
else:
mapper._check_input(graph, inputs_node[1], inputs_name[1],
current_outputs, scope_name)
layer_inputs["num_or_sections"] = inputs_name[1]
current_inputs.append(inputs_name[1])
# 处理输入2,即%135
if inputs_name[2] in mapper.attrs:
layer_attrs["axis"] = mapper.attrs[inputs_name[2]]
else:
mapper._check_input(graph, inputs_node[2], inputs_name[2],
current_outputs, scope_name)
layer_inputs["axis"] = inputs_name[2]
current_inputs.append(inputs_name[2])
# 获取当前节点输入的list
current_inputs = list(layer_inputs.values())
graph.add_layer(
"paddle.split",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name,
**layer_attrs)
return current_inputs, current_outputs
def aten_sqrt(mapper, graph, node):
""" 构构造sqrt的PaddleLayer。
......
......@@ -137,7 +137,7 @@ class PyTorchOpMapper(OpMapper):
graph.outputs = inputs_name
# 更新split参数
for layer in graph.layers.values():
if layer.kernel == "paddle.split" and "num_or_sections" in layer.attrs:
if layer.kernel == "paddle.split" and "num_or_sections" in layer.attrs and len(set(layer.attrs["num_or_sections"])) == 1:
layer.attrs["num_or_sections"] = self.split_len[layer.outputs[
0]]
return graph, graph_inputs
......
......@@ -32,8 +32,8 @@ NN_KERNEL_NAME = {"paddle.nn.BatchNorm": "bn",
"paddle.nn.Softmax": "softmax",
"paddle.nn.Softplus": "softplus",
"paddle.nn.Tanh": "tanh",
"paddle.nn.AvgPool2D": "pool",
"paddle.nn.MaxPool2D": "pool",
"paddle.nn.AvgPool2D": "avgpool",
"paddle.nn.MaxPool2D": "maxpool",
"paddle.nn.Pad1D": "pad",
"paddle.nn.Pad2D": "pad",
"paddle.nn.Pad3D": "pad",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册