未验证 提交 9fbe2e36 编写于 作者: 武毅 提交者: GitHub

Merge pull request #8676 from typhoonzero/fix_renamevar_again

Fix rename de-ref bug
......@@ -279,7 +279,6 @@ class DistributeTranspiler:
type=v.type,
dtype=v.dtype,
shape=v.shape)
print("create origin var: ", orig_var_name)
for trainer_id in xrange(self.trainers):
var = pserver_program.global_block().create_var(
name="%s.trainer_%d" % (orig_var_name, trainer_id),
......@@ -288,7 +287,6 @@ class DistributeTranspiler:
dtype=v.dtype,
shape=v.shape)
recv_inputs.append(var)
print("create per trainer var: ", var.name)
# step3
optimize_block = pserver_program.create_block(0)
# step 4
......
......@@ -773,7 +773,7 @@ class Block(object):
stop_gradient = v.stop_gradient
else:
raise ValueError("unsupported var type: %s", type(v))
orig_var_type = v.type
self.desc.rename_var(name, new_name)
# NOTE: v is destroyed by C++ after calling rename_var.
d = self.desc.find_var(new_name)
......@@ -782,6 +782,7 @@ class Block(object):
self,
d.shape(),
d.dtype(),
type=orig_var_type,
name=new_name,
stop_gradient=stop_gradient,
trainable=trainable,
......@@ -792,7 +793,7 @@ class Block(object):
elif var_type == "Variable":
var = Variable(
self,
type=v.type,
type=orig_var_type,
name=new_name,
error_clip=error_clip,
stop_gradient=stop_gradient)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册