未验证 提交 9fbe2e36 编写于 作者: 武毅 提交者: GitHub

Merge pull request #8676 from typhoonzero/fix_renamevar_again

Fix rename de-ref bug
...@@ -279,7 +279,6 @@ class DistributeTranspiler: ...@@ -279,7 +279,6 @@ class DistributeTranspiler:
type=v.type, type=v.type,
dtype=v.dtype, dtype=v.dtype,
shape=v.shape) shape=v.shape)
print("create origin var: ", orig_var_name)
for trainer_id in xrange(self.trainers): for trainer_id in xrange(self.trainers):
var = pserver_program.global_block().create_var( var = pserver_program.global_block().create_var(
name="%s.trainer_%d" % (orig_var_name, trainer_id), name="%s.trainer_%d" % (orig_var_name, trainer_id),
...@@ -288,7 +287,6 @@ class DistributeTranspiler: ...@@ -288,7 +287,6 @@ class DistributeTranspiler:
dtype=v.dtype, dtype=v.dtype,
shape=v.shape) shape=v.shape)
recv_inputs.append(var) recv_inputs.append(var)
print("create per trainer var: ", var.name)
# step3 # step3
optimize_block = pserver_program.create_block(0) optimize_block = pserver_program.create_block(0)
# step 4 # step 4
......
...@@ -773,7 +773,7 @@ class Block(object): ...@@ -773,7 +773,7 @@ class Block(object):
stop_gradient = v.stop_gradient stop_gradient = v.stop_gradient
else: else:
raise ValueError("unsupported var type: %s", type(v)) raise ValueError("unsupported var type: %s", type(v))
orig_var_type = v.type
self.desc.rename_var(name, new_name) self.desc.rename_var(name, new_name)
# NOTE: v is destroyed by C++ after calling rename_var. # NOTE: v is destroyed by C++ after calling rename_var.
d = self.desc.find_var(new_name) d = self.desc.find_var(new_name)
...@@ -782,6 +782,7 @@ class Block(object): ...@@ -782,6 +782,7 @@ class Block(object):
self, self,
d.shape(), d.shape(),
d.dtype(), d.dtype(),
type=orig_var_type,
name=new_name, name=new_name,
stop_gradient=stop_gradient, stop_gradient=stop_gradient,
trainable=trainable, trainable=trainable,
...@@ -792,7 +793,7 @@ class Block(object): ...@@ -792,7 +793,7 @@ class Block(object):
elif var_type == "Variable": elif var_type == "Variable":
var = Variable( var = Variable(
self, self,
type=v.type, type=orig_var_type,
name=new_name, name=new_name,
error_clip=error_clip, error_clip=error_clip,
stop_gradient=stop_gradient) stop_gradient=stop_gradient)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册