提交 82d92498 编写于 作者: T typhoonzero

update dist transpiler

上级 dc7073de
......@@ -221,7 +221,7 @@ class DistributeTranspiler:
if len(splited_vars) <= 1:
continue
orig_var = program.global_block().vars[varname]
if orig_var == core.VarDesc.VarType.SELECTED_ROWS:
if orig_var.type == core.VarDesc.VarType.SELECTED_ROWS:
height_sections = []
for v in splited_vars:
height_sections.append(v.shape[0])
......@@ -230,7 +230,7 @@ class DistributeTranspiler:
inputs={"X": orig_var},
outputs={"Out": splited_vars},
attrs={"height_sections": height_sections})
elif orig_var == core.VarDesc.VarType.LOD_TENSOR:
elif orig_var.type == core.VarDesc.VarType.LOD_TENSOR:
sections = []
for v in splited_vars:
sections.append(v.shape[0])
......@@ -470,8 +470,7 @@ class DistributeTranspiler:
# Append the recv op
pserver_program.global_block().append_op(
type="recv",
inputs={"RX": self.param_grad_ep_mapping[endpoint]["grads"]
}, # grads to recv
inputs={},
outputs={},
attrs={
"OptimizeBlock": optimize_sub_program.global_block(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册