From 6153f98e787305413ffe21c1453be483bd859f6e Mon Sep 17 00:00:00 2001 From: SunAhong1993 Date: Tue, 2 Feb 2021 16:14:03 +0800 Subject: [PATCH] add pytorch aten_split_with_sizes --- .../op_mapper/dygraph/pytorch2paddle/aten.py | 50 +++++++++++++++++++ .../pytorch2paddle/pytorch_op_mapper.py | 2 +- .../layer_code_generator.py | 4 +- 3 files changed, 53 insertions(+), 3 deletions(-) diff --git a/x2paddle/op_mapper/dygraph/pytorch2paddle/aten.py b/x2paddle/op_mapper/dygraph/pytorch2paddle/aten.py index 9e37289..d6c6493 100644 --- a/x2paddle/op_mapper/dygraph/pytorch2paddle/aten.py +++ b/x2paddle/op_mapper/dygraph/pytorch2paddle/aten.py @@ -3976,6 +3976,56 @@ def aten_softplus(mapper, graph, node): return current_inputs, current_outputs +def aten_split_with_sizes(mapper, graph, node): + """ 构构造split的PaddleLayer。 + + TorchScript示例: + %1450 : Tensor[] = aten::split_with_sizes(%1446, %1750, %41) + 参数含义: + %1450 (Tensor): 输出,split后的Tensor。 + %1446 (Tensor): 需要获取split的Tensor。 + %1750 (list): 子Tensor的数量列表。 + %41 (int): 需要分割的维度。 + """ + scope_name = mapper.normalize_scope_name(node) + output_name = mapper._get_outputs_name(node)[0] + layer_outputs = [output_name] + layer_inputs = {} + layer_attrs = {} + inputs_name, inputs_node = mapper._get_inputs_name(node) + # 获取当前节点输出的list + current_outputs = [output_name] + # 处理输入0,即%1446 + mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) + layer_inputs["x"] = inputs_name[0] + # 处理输入1,即%1750 + if inputs_name[1] in mapper.attrs: + layer_attrs["num_or_sections"] = mapper.attrs[inputs_name[1]] + else: + mapper._check_input(graph, inputs_node[1], inputs_name[1], + current_outputs, scope_name) + layer_inputs["num_or_sections"] = inputs_name[1] + current_inputs.append(inputs_name[1]) + # 处理输入2,即%135 + if inputs_name[2] in mapper.attrs: + layer_attrs["axis"] = mapper.attrs[inputs_name[2]] + else: + mapper._check_input(graph, inputs_node[2], inputs_name[2], + current_outputs, scope_name) + layer_inputs["axis"] = inputs_name[2] + current_inputs.append(inputs_name[2]) + # 获取当前节点输入的list + current_inputs = list(layer_inputs.values()) + + graph.add_layer( + "paddle.split", + inputs=layer_inputs, + outputs=layer_outputs, + scope_name=scope_name, + **layer_attrs) + return current_inputs, current_outputs + + def aten_sqrt(mapper, graph, node): """ 构构造sqrt的PaddleLayer。 diff --git a/x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_op_mapper.py b/x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_op_mapper.py index 5f43e33..6893f7b 100644 --- a/x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_op_mapper.py +++ b/x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_op_mapper.py @@ -137,7 +137,7 @@ class PyTorchOpMapper(OpMapper): graph.outputs = inputs_name # 更新split参数 for layer in graph.layers.values(): - if layer.kernel == "paddle.split" and "num_or_sections" in layer.attrs: + if layer.kernel == "paddle.split" and "num_or_sections" in layer.attrs and len(set(layer.attrs["num_or_sections"])) == 1: layer.attrs["num_or_sections"] = self.split_len[layer.outputs[ 0]] return graph, graph_inputs diff --git a/x2paddle/optimizer/pytorch_code_optimizer/layer_code_generator.py b/x2paddle/optimizer/pytorch_code_optimizer/layer_code_generator.py index dc89e7f..a4c368c 100644 --- a/x2paddle/optimizer/pytorch_code_optimizer/layer_code_generator.py +++ b/x2paddle/optimizer/pytorch_code_optimizer/layer_code_generator.py @@ -32,8 +32,8 @@ NN_KERNEL_NAME = {"paddle.nn.BatchNorm": "bn", "paddle.nn.Softmax": "softmax", "paddle.nn.Softplus": "softplus", "paddle.nn.Tanh": "tanh", - "paddle.nn.AvgPool2D": "pool", - "paddle.nn.MaxPool2D": "pool", + "paddle.nn.AvgPool2D": "avgpool", + "paddle.nn.MaxPool2D": "maxpool", "paddle.nn.Pad1D": "pad", "paddle.nn.Pad2D": "pad", "paddle.nn.Pad3D": "pad", -- GitLab