提交 6e83c003 编写于 作者: Y Yancey1989

Registry var type infer in split_selected_rows op

上级 e84615ba
...@@ -276,6 +276,7 @@ class DistributeTranspiler: ...@@ -276,6 +276,7 @@ class DistributeTranspiler:
pserver_program.global_block().create_var( pserver_program.global_block().create_var(
name=orig_var_name, name=orig_var_name,
persistable=True, persistable=True,
type=v.type,
dtype=v.dtype, dtype=v.dtype,
shape=v.shape) shape=v.shape)
print("create origin var: ", orig_var_name) print("create origin var: ", orig_var_name)
...@@ -283,6 +284,7 @@ class DistributeTranspiler: ...@@ -283,6 +284,7 @@ class DistributeTranspiler:
var = pserver_program.global_block().create_var( var = pserver_program.global_block().create_var(
name="%s.trainer_%d" % (orig_var_name, trainer_id), name="%s.trainer_%d" % (orig_var_name, trainer_id),
persistable=False, persistable=False,
type=v.type,
dtype=v.dtype, dtype=v.dtype,
shape=v.shape) shape=v.shape)
recv_inputs.append(var) recv_inputs.append(var)
...@@ -551,11 +553,12 @@ class DistributeTranspiler: ...@@ -551,11 +553,12 @@ class DistributeTranspiler:
type="sum", type="sum",
inputs={"X": vars2merge}, inputs={"X": vars2merge},
outputs={"Out": merged_var}) outputs={"Out": merged_var})
optimize_block.append_op( if not merged_var.type == core.VarDesc.VarType.SELECTED_ROWS:
type="scale", optimize_block.append_op(
inputs={"X": merged_var}, type="scale",
outputs={"Out": merged_var}, inputs={"X": merged_var},
attrs={"scale": 1.0 / float(self.trainers)}) outputs={"Out": merged_var},
attrs={"scale": 1.0 / float(self.trainers)})
new_inputs[key] = merged_var new_inputs[key] = merged_var
elif key == "Param": elif key == "Param":
# param is already created on global program # param is already created on global program
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册