提交 c70ea1cc 编写于 作者: T typhoonzero

add splitter

上级 ed55f1b9
...@@ -80,7 +80,6 @@ class DistributeTranspiler: ...@@ -80,7 +80,6 @@ class DistributeTranspiler:
# step3 # step3
send_inputs = [] send_inputs = []
send_outputs = []
for _, splited in var2splited.iteritems(): for _, splited in var2splited.iteritems():
send_inputs.extend(splited) send_inputs.extend(splited)
send_outputs = self._create_vars_from_blocklist(program, param_blocks) send_outputs = self._create_vars_from_blocklist(program, param_blocks)
......
def hash_name(varblocks, pserver_endpoints):
"""
:param varblocks: a list of VarBlock string indicating
sub blocks of variables
:return: a map of pserver endpoint -> varblock_str
"""
def _hash_block(block_str, total):
return hash(block_str) % total
ep2block = dict()
for varblock_str in varblocks:
if param.trainable is True and grad is not None:
server_id = _hash_block(varblock_str, len(pserver_endpoints))
server_for_param = pserver_endpoints[server_id]
if not ep2block.has_key(server_for_param):
ep2block[server_for_param] = []
ep2block[server_for_param].append(varblock_str)
return ep2block
def round_robin(varblocks, pserver_endpoints):
assert (len(varblocks) > len(pserver_endpoints))
ep2block = dict()
pserver_idx = 0
for varblock_str in varblocks:
if param.trainable is True:
server_for_param = pserver_endpoints[pserver_idx]
if not ep2block.has_key(server_for_param):
ep2block[server_for_param] = []
ep2block[server_for_param].append(varblock_str)
pserver_idx += 1
if pserver_idx >= len(pserver_endpoints):
pserver_idx = 0
return ep2block
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册