Commit 7a6000a0 authored by T typhoonzero

follow comments

Parent c7444501
@@ -385,7 +385,7 @@ class DistributeTranspiler:
# param is already created on global program
param_block = None
for p in self.param_grad_ep_mapping[endpoint]["params"]:
if same_or_split_var(p.name, opt_op.input(key)):
if same_or_split_var(p.name, opt_op.input(key)[0]):
param_block = p
break
if not param_block:
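Note on this hunk and the next: in fluid, `Operator.input(key)` returns the list of variable names bound to that input slot, so the added `[0]` picks the actual parameter name before matching it against the parameter blocks split onto this pserver. Below is a minimal sketch of that lookup; `endpoint_params`, `find_param_block`, and the body of `same_or_split_var` are assumptions for illustration, not the transpiler's exact code.

```python
# Sketch only; opt_op stands for a fluid optimize op and endpoint_params for
# self.param_grad_ep_mapping[endpoint]["params"]. Both names are assumptions.

def same_or_split_var(p_name, var_name):
    # Assumed behaviour: a split block such as "fc_0.w_0.block0"
    # still belongs to the original variable "fc_0.w_0".
    return p_name == var_name or p_name.startswith(var_name + ".block")

def find_param_block(endpoint_params, opt_op, key="Param"):
    # op.input(key) returns a list of names, e.g. ["fc_0.w_0"],
    # hence the [0] added in the patched line.
    param_name = opt_op.input(key)[0]
    for p in endpoint_params:
        if same_or_split_var(p.name, param_name):
            return p
    return None
```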
@@ -403,7 +403,7 @@ class DistributeTranspiler:
continue
# update accumulator variable shape
param_shape = new_inputs["Param"].shape
var = program.global_block().vars[opt_op.input(key)]
var = program.global_block().vars[opt_op.input(key)[0]]
new_shape = self._get_optimizer_input_shape(opt_op.type, key,
var.shape, param_shape)
tmpvar = program.global_block().create_var(
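The same `[0]` indexing is applied when the optimizer input is an accumulator (for example a moment tensor) rather than the parameter itself; its shape is then recomputed against the split parameter block via `_get_optimizer_input_shape`. The following is a rough, hypothetical stand-in for that method, assuming Adam-style moments simply mirror the parameter shape; the real method covers more optimizer types and input keys.

```python
# Hypothetical, simplified stand-in for _get_optimizer_input_shape: inputs
# that mirror the parameter are resized to the split block's shape.
def get_optimizer_input_shape(op_type, varkey, orig_shape, param_shape):
    if op_type == "adam" and varkey in ("Moment1", "Moment2"):
        return param_shape
    # Inputs this sketch does not model keep their original shape.
    return orig_shape

# Example: a 2048x1024 moment shrinks to the 512x1024 block kept on this pserver.
assert get_optimizer_input_shape("adam", "Moment1",
                                 (2048, 1024), (512, 1024)) == (512, 1024)
```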
@@ -440,20 +440,18 @@ class DistributeTranspiler:
else:
varlist = [var]
for var in varlist:
# TODO(typhoonzero): will remove below line later.
program.global_block().create_var(
name=var.name,
persistable=var.persistable,
dtype=var.dtype,
shape=var.shape)
try:
if not pserver_program.global_block().vars.has_key(var.name):
pserver_program.global_block().create_var(
name=var.name,
persistable=var.persistable,
dtype=var.dtype,
shape=var.shape)
except ValueError:
# create var if not created yet.
pass
outputs = self._get_output_map_from_op(self.program.global_block().vars,
opt_op)
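The last hunk drops the `try/except ValueError` around `create_var` in favour of an explicit existence check, so a variable is only created in the pserver program when it is not already present. `dict.has_key` is Python 2 only; a minimal sketch of the same guard in the Python 3 spelling follows (the helper name `clone_var_if_missing` is an assumption; `Block.vars` and `Block.create_var` are the fluid APIs used in the diff).

```python
# Sketch of the "create only if missing" pattern from the new code, written
# with `in` rather than the Python-2-only dict.has_key.
def clone_var_if_missing(dest_block, var):
    if var.name not in dest_block.vars:
        dest_block.create_var(
            name=var.name,
            persistable=var.persistable,
            dtype=var.dtype,
            shape=var.shape)
    return dest_block.vars[var.name]
```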