提交 900f411d 编写于 作者: T typhoonzero

fix dist transpiler bug

上级 d380ad06
...@@ -225,7 +225,7 @@ class DistributeTranspiler: ...@@ -225,7 +225,7 @@ class DistributeTranspiler:
if len(splited_vars) <= 1: if len(splited_vars) <= 1:
continue continue
orig_var = program.global_block().vars[varname] orig_var = program.global_block().vars[varname]
if orig_var == core.VarDesc.VarType.SELECTED_ROWS: if orig_var.type == core.VarDesc.VarType.SELECTED_ROWS:
height_sections = [] height_sections = []
for v in splited_vars: for v in splited_vars:
height_sections.append(v.shape[0]) height_sections.append(v.shape[0])
...@@ -234,7 +234,7 @@ class DistributeTranspiler: ...@@ -234,7 +234,7 @@ class DistributeTranspiler:
inputs={"X": orig_var}, inputs={"X": orig_var},
outputs={"Out": splited_vars}, outputs={"Out": splited_vars},
attrs={"height_sections": height_sections}) attrs={"height_sections": height_sections})
elif orig_var == core.VarDesc.VarType.LOD_TENSOR: elif orig_var.type == core.VarDesc.VarType.LOD_TENSOR:
sections = [] sections = []
for v in splited_vars: for v in splited_vars:
sections.append(v.shape[0]) sections.append(v.shape[0])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册