diff --git a/x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_op_mapper.py b/x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_op_mapper.py index 6893f7b1a07006860a05454469d72d9b6ae46d9e..118d998992ec49176e40f5f4b74b20a0a6abd9fe 100644 --- a/x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_op_mapper.py +++ b/x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_op_mapper.py @@ -137,7 +137,8 @@ class PyTorchOpMapper(OpMapper): graph.outputs = inputs_name # 更新split参数 for layer in graph.layers.values(): - if layer.kernel == "paddle.split" and "num_or_sections" in layer.attrs and len(set(layer.attrs["num_or_sections"])) == 1: + if layer.kernel == "paddle.split" and "num_or_sections" in layer.attrs \ + and len(set(layer.attrs["num_or_sections"])) == 1: layer.attrs["num_or_sections"] = self.split_len[layer.outputs[ 0]] return graph, graph_inputs