diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py index e5314cf27295ee9e65ce645d5c9f234e5c5dcb13..4c90b4a8535aa5a10f1f8405abea327834c9fe1a 100644 --- a/python/paddle/v2/fluid/distribute_transpiler.py +++ b/python/paddle/v2/fluid/distribute_transpiler.py @@ -80,7 +80,6 @@ class DistributeTranspiler: # step3 send_inputs = [] - send_outputs = [] for _, splited in var2splited.iteritems(): send_inputs.extend(splited) send_outputs = self._create_vars_from_blocklist(program, param_blocks) diff --git a/python/paddle/v2/fluid/distributed_spliter.py b/python/paddle/v2/fluid/distributed_spliter.py new file mode 100644 index 0000000000000000000000000000000000000000..e7ba53390d48bb797a08e261e5615462cb9e6019 --- /dev/null +++ b/python/paddle/v2/fluid/distributed_spliter.py @@ -0,0 +1,38 @@ +def hash_name(varblocks, pserver_endpoints): + """ + :param varblocks: a list of VarBlock string indicating + sub blocks of variables + :return: a map of pserver endpoint -> varblock_str + """ + + def _hash_block(block_str, total): + return hash(block_str) % total + + ep2block = dict() + for varblock_str in varblocks: + if param.trainable is True and grad is not None: + server_id = _hash_block(varblock_str, len(pserver_endpoints)) + server_for_param = pserver_endpoints[server_id] + if not ep2block.has_key(server_for_param): + ep2block[server_for_param] = [] + ep2block[server_for_param].append(varblock_str) + + return ep2block + + +def round_robin(varblocks, pserver_endpoints): + assert (len(varblocks) > len(pserver_endpoints)) + + ep2block = dict() + pserver_idx = 0 + for varblock_str in varblocks: + if param.trainable is True: + server_for_param = pserver_endpoints[pserver_idx] + if not ep2block.has_key(server_for_param): + ep2block[server_for_param] = [] + ep2block[server_for_param].append(varblock_str) + + pserver_idx += 1 + if pserver_idx >= len(pserver_endpoints): + pserver_idx = 0 + return ep2block