提交 2827607f 编写于 作者: Y yi.wu

fix startup program shape

上级 6fa56b9d
......@@ -365,7 +365,6 @@ class DistributeTranspiler:
else:
self._append_pserver_non_opt_ops(optimize_sub_program, opt_op)
print("####", optimize_sub_program)
pserver_program.global_block().append_op(
type="recv",
inputs={"RX": self.param_grad_ep_mapping[endpoint]["grads"]
......@@ -407,7 +406,6 @@ class DistributeTranspiler:
# 1. create vars
created_var_map = dict()
for var in params:
print("%%%% append var", var.name, var.shape)
tmpvar = s_prog.global_block().create_var(
name=var.name,
persistable=True,
......@@ -430,6 +428,8 @@ class DistributeTranspiler:
if var.name in created_var_map:
var_on_pserver = True
if var_on_pserver:
# gaussian_random use attr to determine tensor shape
op.attrs["shape"] = new_outputs["Out"].shape
s_prog.global_block().append_op(
type=op.type,
inputs=op.inputs,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册