提交 82d92498 编写于 作者: T typhoonzero

update dist transpiler

上级 dc7073de
...@@ -221,7 +221,7 @@ class DistributeTranspiler: ...@@ -221,7 +221,7 @@ class DistributeTranspiler:
if len(splited_vars) <= 1: if len(splited_vars) <= 1:
continue continue
orig_var = program.global_block().vars[varname] orig_var = program.global_block().vars[varname]
if orig_var == core.VarDesc.VarType.SELECTED_ROWS: if orig_var.type == core.VarDesc.VarType.SELECTED_ROWS:
height_sections = [] height_sections = []
for v in splited_vars: for v in splited_vars:
height_sections.append(v.shape[0]) height_sections.append(v.shape[0])
...@@ -230,7 +230,7 @@ class DistributeTranspiler: ...@@ -230,7 +230,7 @@ class DistributeTranspiler:
inputs={"X": orig_var}, inputs={"X": orig_var},
outputs={"Out": splited_vars}, outputs={"Out": splited_vars},
attrs={"height_sections": height_sections}) attrs={"height_sections": height_sections})
elif orig_var == core.VarDesc.VarType.LOD_TENSOR: elif orig_var.type == core.VarDesc.VarType.LOD_TENSOR:
sections = [] sections = []
for v in splited_vars: for v in splited_vars:
sections.append(v.shape[0]) sections.append(v.shape[0])
...@@ -470,8 +470,7 @@ class DistributeTranspiler: ...@@ -470,8 +470,7 @@ class DistributeTranspiler:
# Append the recv op # Append the recv op
pserver_program.global_block().append_op( pserver_program.global_block().append_op(
type="recv", type="recv",
inputs={"RX": self.param_grad_ep_mapping[endpoint]["grads"] inputs={},
}, # grads to recv
outputs={}, outputs={},
attrs={ attrs={
"OptimizeBlock": optimize_sub_program.global_block(), "OptimizeBlock": optimize_sub_program.global_block(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册