提交 92313a99 编写于 作者: T typhoonzero

update

上级 d02b17e5
...@@ -278,6 +278,12 @@ class DistributeTranspiler: ...@@ -278,6 +278,12 @@ class DistributeTranspiler:
# we don't need to create them when grad arrives. # we don't need to create them when grad arrives.
# change client side var name to origin name by # change client side var name to origin name by
# removing ".trainer_%d" suffix # removing ".trainer_%d" suffix
suff_idx = v.name.find(".trainer_")
if suff_idx >= 0:
orig_var_name = v.name[:suff_idx]
else:
orig_var_name = v.name
# NOTE: single_trainer_var must be created for multi-trainer # NOTE: single_trainer_var must be created for multi-trainer
# case to merge grads from multiple trainers # case to merge grads from multiple trainers
single_trainer_var = \ single_trainer_var = \
...@@ -287,11 +293,6 @@ class DistributeTranspiler: ...@@ -287,11 +293,6 @@ class DistributeTranspiler:
type=v.type, type=v.type,
dtype=v.dtype, dtype=v.dtype,
shape=v.shape) shape=v.shape)
suff_idx = v.name.find(".trainer_")
if suff_idx >= 0:
orig_var_name = v.name[:suff_idx]
else:
orig_var_name = v.name
if self.trainers > 1: if self.trainers > 1:
for trainer_id in xrange(self.trainers): for trainer_id in xrange(self.trainers):
var = pserver_program.global_block().create_var( var = pserver_program.global_block().create_var(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册