提交 ed89b7b7 编写于 作者: T typhoonzero

dist train use split_by_ref

上级 0c6eef3e
......@@ -824,7 +824,7 @@ class DistributeTranspiler:
for v in splited_vars:
sections.append(v.shape[0])
program.global_block().append_op(
type="split",
type="split_byref",
inputs={"X": orig_var},
outputs={"Out": splited_vars},
attrs={"sections": sections} # assume split evenly
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册