提交 4db8d282 编写于 作者: S SunAhong1993

fix

上级 ca2844fa
...@@ -138,7 +138,7 @@ class PyTorchOpMapper(OpMapper): ...@@ -138,7 +138,7 @@ class PyTorchOpMapper(OpMapper):
# 更新split参数 # 更新split参数
for layer in graph.layers.values(): for layer in graph.layers.values():
if layer.kernel == "paddle.split" and "num_or_sections" in layer.attrs \ if layer.kernel == "paddle.split" and "num_or_sections" in layer.attrs \
and len(set(layer.attrs["num_or_sections"])) == 1: and not isinstance(layer.attrs["num_or_sections"], int) and len(set(layer.attrs["num_or_sections"])) == 1:
layer.attrs["num_or_sections"] = self.split_len[layer.outputs[ layer.attrs["num_or_sections"] = self.split_len[layer.outputs[
0]] 0]]
return graph, graph_inputs return graph, graph_inputs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册