提交 5faebab3 编写于 作者: T typhoonzero

Done, need support selectedrows

上级 5d901d00
...@@ -148,7 +148,7 @@ class DistributeTranspiler: ...@@ -148,7 +148,7 @@ class DistributeTranspiler:
concat = program.global_block().append_op( concat = program.global_block().append_op(
type="concat", type="concat",
inputs={"X": splited_var}, inputs={"X": splited_var},
outputs={"Out": orig_param}, outputs={"Out": [orig_param]},
attrs={"axis": 0}) attrs={"axis": 0})
def _create_vars_from_blocklist(self, program, block_list): def _create_vars_from_blocklist(self, program, block_list):
...@@ -420,7 +420,6 @@ class DistributeTranspiler: ...@@ -420,7 +420,6 @@ class DistributeTranspiler:
else: else:
self._append_pserver_non_opt_ops(optimize_sub_program, self._append_pserver_non_opt_ops(optimize_sub_program,
pserver_program, opt_op) pserver_program, opt_op)
print("****subprogram", optimize_sub_program)
pserver_program.global_block().append_op( pserver_program.global_block().append_op(
type="recv", type="recv",
inputs={"RX": self.param_grad_ep_mapping[endpoint]["grads"] inputs={"RX": self.param_grad_ep_mapping[endpoint]["grads"]
...@@ -463,7 +462,6 @@ class DistributeTranspiler: ...@@ -463,7 +462,6 @@ class DistributeTranspiler:
pserver_vars = pserver_program.global_block().vars pserver_vars = pserver_program.global_block().vars
created_var_map = dict() created_var_map = dict()
for _, var in pserver_vars.iteritems(): for _, var in pserver_vars.iteritems():
print("create var for startup", var.name, var.shape)
tmpvar = s_prog.global_block().create_var( tmpvar = s_prog.global_block().create_var(
name=var.name, name=var.name,
persistable=var.persistable, persistable=var.persistable,
...@@ -485,21 +483,11 @@ class DistributeTranspiler: ...@@ -485,21 +483,11 @@ class DistributeTranspiler:
op_on_pserver = True op_on_pserver = True
new_outputs[key] = pserver_vars[var.name] new_outputs[key] = pserver_vars[var.name]
# newname, _ = _get_splited_name_and_shape(var.name)
# if newname:
# print("updating output", newname, created_var_map[newname])
# new_outputs[key] = created_var_map[newname]
# else:
# print("no update output", key, var)
# new_outputs[key] = var
# if var.name in created_var_map or \
# newname:
# op_on_pserver = True
if op_on_pserver: if op_on_pserver:
if op.type in ["gaussian_random", "fill_constant"]: if op.type in [
"gaussian_random", "fill_constant", "uniform_random"
]:
op.attrs["shape"] = new_outputs["Out"].shape op.attrs["shape"] = new_outputs["Out"].shape
print("updated shape", op.attrs["shape"])
s_prog.global_block().append_op( s_prog.global_block().append_op(
type=op.type, type=op.type,
inputs=op.inputs, inputs=op.inputs,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册