提交 7a6000a0 编写于 作者: T typhoonzero

follow comments

上级 c7444501
...@@ -385,7 +385,7 @@ class DistributeTranspiler: ...@@ -385,7 +385,7 @@ class DistributeTranspiler:
# param is already created on global program # param is already created on global program
param_block = None param_block = None
for p in self.param_grad_ep_mapping[endpoint]["params"]: for p in self.param_grad_ep_mapping[endpoint]["params"]:
if same_or_split_var(p.name, opt_op.input(key)): if same_or_split_var(p.name, opt_op.input(key)[0]):
param_block = p param_block = p
break break
if not param_block: if not param_block:
...@@ -403,7 +403,7 @@ class DistributeTranspiler: ...@@ -403,7 +403,7 @@ class DistributeTranspiler:
continue continue
# update accumulator variable shape # update accumulator variable shape
param_shape = new_inputs["Param"].shape param_shape = new_inputs["Param"].shape
var = program.global_block().vars[opt_op.input(key)] var = program.global_block().vars[opt_op.input(key)[0]]
new_shape = self._get_optimizer_input_shape(opt_op.type, key, new_shape = self._get_optimizer_input_shape(opt_op.type, key,
var.shape, param_shape) var.shape, param_shape)
tmpvar = program.global_block().create_var( tmpvar = program.global_block().create_var(
...@@ -440,20 +440,18 @@ class DistributeTranspiler: ...@@ -440,20 +440,18 @@ class DistributeTranspiler:
else: else:
varlist = [var] varlist = [var]
for var in varlist: for var in varlist:
# TODO(typhoonzero): will remove below line later.
program.global_block().create_var( program.global_block().create_var(
name=var.name, name=var.name,
persistable=var.persistable, persistable=var.persistable,
dtype=var.dtype, dtype=var.dtype,
shape=var.shape) shape=var.shape)
try: if not pserver_program.global_block().vars.has_key(var.name):
pserver_program.global_block().create_var( pserver_program.global_block().create_var(
name=var.name, name=var.name,
persistable=var.persistable, persistable=var.persistable,
dtype=var.dtype, dtype=var.dtype,
shape=var.shape) shape=var.shape)
except ValueError:
# create var if not created yet.
pass
outputs = self._get_output_map_from_op(self.program.global_block().vars, outputs = self._get_output_map_from_op(self.program.global_block().vars,
opt_op) opt_op)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册