未验证 提交 f02a4da6 编写于 作者: T tangwei12 提交者: GitHub

Merge pull request #12152 from seiriosPlus/dis_acc_fix

slice_var_up=False fix
...@@ -31,6 +31,7 @@ Steps to transpile pserver: ...@@ -31,6 +31,7 @@ Steps to transpile pserver:
from __future__ import print_function from __future__ import print_function
import math import math
import random
import numpy as np import numpy as np
from ps_dispatcher import RoundRobin, HashName, PSDispatcher from ps_dispatcher import RoundRobin, HashName, PSDispatcher
...@@ -197,7 +198,8 @@ class DistributeTranspiler(object): ...@@ -197,7 +198,8 @@ class DistributeTranspiler(object):
# shuffle the map will avoid the uneven distribution above # shuffle the map will avoid the uneven distribution above
grad_var_mapping_items = self.grad_var_mapping.items() grad_var_mapping_items = self.grad_var_mapping.items()
if not slice_var_up: if not slice_var_up:
np.random.shuffle(grad_var_mapping_items) random.seed(self.trainer_num)
random.shuffle(grad_var_mapping_items)
for orig_varname, splited_vars in grad_var_mapping_items: for orig_varname, splited_vars in grad_var_mapping_items:
eplist = ps_dispatcher.dispatch(splited_vars) eplist = ps_dispatcher.dispatch(splited_vars)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册