提交 4bd64781 编写于 作者: T tangwei12

fix distribued transpile when slice_var_up=False

上级 cebf7c60
...@@ -349,7 +349,10 @@ class Trainer(object): ...@@ -349,7 +349,10 @@ class Trainer(object):
with self._prog_and_scope_guard(): with self._prog_and_scope_guard():
t = distribute_transpiler.DistributeTranspiler() t = distribute_transpiler.DistributeTranspiler()
t.transpile( t.transpile(
self.trainer_id, pservers=pserver_endpoints, trainers=trainers) self.trainer_id,
pservers=pserver_endpoints,
trainers=trainers,
slice_var_up=False)
if training_role == "PSERVER": if training_role == "PSERVER":
if self.checkpoint_cfg: if self.checkpoint_cfg:
pserver_id = eplist.index(current_endpoint) pserver_id = eplist.index(current_endpoint)
......
...@@ -196,8 +196,6 @@ class DistributeTranspiler(object): ...@@ -196,8 +196,6 @@ class DistributeTranspiler(object):
# fc_b@GRAD_trainer_0, fc_b@GRAD_trainer_1 --> pserver2 # fc_b@GRAD_trainer_0, fc_b@GRAD_trainer_1 --> pserver2
# shuffle the map will avoid the uneven distribution above # shuffle the map will avoid the uneven distribution above
grad_var_mapping_items = self.grad_var_mapping.items() grad_var_mapping_items = self.grad_var_mapping.items()
if not slice_var_up:
np.random.shuffle(grad_var_mapping_items)
for orig_varname, splited_vars in grad_var_mapping_items: for orig_varname, splited_vars in grad_var_mapping_items:
eplist = ps_dispatcher.dispatch(splited_vars) eplist = ps_dispatcher.dispatch(splited_vars)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册