提交 ca2844fa 编写于 作者: S SunAhong1993

add pytorch aten_split_with_sizes

上级 6153f98e
...@@ -137,7 +137,8 @@ class PyTorchOpMapper(OpMapper): ...@@ -137,7 +137,8 @@ class PyTorchOpMapper(OpMapper):
graph.outputs = inputs_name graph.outputs = inputs_name
# 更新split参数 # 更新split参数
for layer in graph.layers.values(): for layer in graph.layers.values():
if layer.kernel == "paddle.split" and "num_or_sections" in layer.attrs and len(set(layer.attrs["num_or_sections"])) == 1: if layer.kernel == "paddle.split" and "num_or_sections" in layer.attrs \
and len(set(layer.attrs["num_or_sections"])) == 1:
layer.attrs["num_or_sections"] = self.split_len[layer.outputs[ layer.attrs["num_or_sections"] = self.split_len[layer.outputs[
0]] 0]]
return graph, graph_inputs return graph, graph_inputs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册