提交 d02b17e5 编写于 作者: T typhoonzero

fix dist transpiler bug

上级 ddff83ff
......@@ -278,6 +278,15 @@ class DistributeTranspiler:
# we don't need to create them when grad arrives.
# change client side var name to origin name by
# removing ".trainer_%d" suffix
# NOTE: single_trainer_var must be created for multi-trainer
# case to merge grads from multiple trainers
single_trainer_var = \
pserver_program.global_block().create_var(
name=orig_var_name,
persistable=True,
type=v.type,
dtype=v.dtype,
shape=v.shape)
suff_idx = v.name.find(".trainer_")
if suff_idx >= 0:
orig_var_name = v.name[:suff_idx]
......@@ -293,12 +302,6 @@ class DistributeTranspiler:
shape=v.shape)
recv_inputs.append(var)
else:
single_trainer_var = pserver_program.global_block().create_var(
name=orig_var_name,
persistable=True,
type=v.type,
dtype=v.dtype,
shape=v.shape)
recv_inputs.append(single_trainer_var)
# step3
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册