未验证 提交 b9d8bbe4 编写于 作者: 武毅 提交者: GitHub

Merge pull request #9558 from typhoonzero/fix_dist_transpile_one_trainer

fix single pserver single trainer error
...@@ -276,20 +276,25 @@ class DistributeTranspiler: ...@@ -276,20 +276,25 @@ class DistributeTranspiler:
suff_idx = v.name.find(".trainer_") suff_idx = v.name.find(".trainer_")
if suff_idx >= 0: if suff_idx >= 0:
orig_var_name = v.name[:suff_idx] orig_var_name = v.name[:suff_idx]
pserver_program.global_block().create_var( else:
orig_var_name = v.name
single_trainer_var = pserver_program.global_block().create_var(
name=orig_var_name, name=orig_var_name,
persistable=True, persistable=True,
type=v.type, type=v.type,
dtype=v.dtype, dtype=v.dtype,
shape=v.shape) shape=v.shape)
for trainer_id in xrange(self.trainers): if self.trainers > 1:
var = pserver_program.global_block().create_var( for trainer_id in xrange(self.trainers):
name="%s.trainer_%d" % (orig_var_name, trainer_id), var = pserver_program.global_block().create_var(
persistable=False, name="%s.trainer_%d" % (orig_var_name, trainer_id),
type=v.type, persistable=False,
dtype=v.dtype, type=v.type,
shape=v.shape) dtype=v.dtype,
recv_inputs.append(var) shape=v.shape)
recv_inputs.append(var)
else:
recv_inputs.append(single_trainer_var)
# step3 # step3
optimize_block = pserver_program.create_block(0) optimize_block = pserver_program.create_block(0)
...@@ -511,8 +516,11 @@ class DistributeTranspiler: ...@@ -511,8 +516,11 @@ class DistributeTranspiler:
def _append_split_op(self, program, gradblocks): def _append_split_op(self, program, gradblocks):
# Split variables that need to be split and append respective ops # Split variables that need to be split and append respective ops
add_suffix = False
if self.trainers > 1:
add_suffix = True
var_mapping = self._create_vars_from_blocklist( var_mapping = self._create_vars_from_blocklist(
program, gradblocks, add_trainer_suffix=True) program, gradblocks, add_trainer_suffix=add_suffix)
for varname, splited_vars in var_mapping.iteritems(): for varname, splited_vars in var_mapping.iteritems():
# variable that don't need to split have empty splited_vars # variable that don't need to split have empty splited_vars
if len(splited_vars) <= 1: if len(splited_vars) <= 1:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册