Commit d02b17e5 authored by typhoonzero

fix dist transpiler bug

Parent: ddff83ff
@@ -278,6 +278,15 @@ class DistributeTranspiler:
             # we don't need to create them when grad arrives.
             # change client side var name to origin name by
             # removing ".trainer_%d" suffix
+            # NOTE: single_trainer_var must be created for multi-trainer
+            # case to merge grads from multiple trainers
+            single_trainer_var = \
+                pserver_program.global_block().create_var(
+                    name=orig_var_name,
+                    persistable=True,
+                    type=v.type,
+                    dtype=v.dtype,
+                    shape=v.shape)
             suff_idx = v.name.find(".trainer_")
             if suff_idx >= 0:
                 orig_var_name = v.name[:suff_idx]
@@ -293,12 +302,6 @@ class DistributeTranspiler:
                         shape=v.shape)
                     recv_inputs.append(var)
             else:
-                single_trainer_var = pserver_program.global_block().create_var(
-                    name=orig_var_name,
-                    persistable=True,
-                    type=v.type,
-                    dtype=v.dtype,
-                    shape=v.shape)
                 recv_inputs.append(single_trainer_var)
             # step3
...
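For context, the change moves the creation of `single_trainer_var` out of the single-trainer `else` branch so that the merged variable (named after the original, suffix-free gradient name) always exists on the pserver, where gradients from multiple trainers can be accumulated. The sketch below is a minimal, self-contained illustration of the ".trainer_%d" naming scheme the diff relies on; the helper names `strip_trainer_suffix` and `recv_var_names` are hypothetical and not part of PaddlePaddle's API.

```python
# Hypothetical illustration only; not PaddlePaddle code.

def strip_trainer_suffix(grad_name):
    """Recover the original variable name by removing a ".trainer_%d" suffix."""
    suff_idx = grad_name.find(".trainer_")
    return grad_name[:suff_idx] if suff_idx >= 0 else grad_name

def recv_var_names(grad_name, trainers):
    """Names of the variables a pserver receives gradients into.

    With multiple trainers, one variable per trainer is received and later
    merged into the single variable named after the original; with one
    trainer, the gradient goes directly into that single variable.
    """
    orig = strip_trainer_suffix(grad_name)
    if trainers > 1:
        return ["%s.trainer_%d" % (orig, i) for i in range(trainers)]
    return [orig]

assert strip_trainer_suffix("w@GRAD.trainer_0") == "w@GRAD"
assert recv_var_names("w@GRAD.trainer_0", 2) == ["w@GRAD.trainer_0",
                                                 "w@GRAD.trainer_1"]
assert recv_var_names("w@GRAD", 1) == ["w@GRAD"]
```

Under this scheme the merged, suffix-free variable must be created in both branches, which is what the fix ensures.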