distributed_spliter.py 1.3 KB
Newer Older
T
typhoonzero 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38
def hash_name(varblocks, pserver_endpoints):
    """
    Assign each variable block to a parameter server by hashing its name.

    The hash is deterministic across interpreter processes. Every trainer
    and pserver that calls this function with the same inputs computes the
    same placement.

    :param varblocks: a list of VarBlock strings indicating
                      sub blocks of variables
    :param pserver_endpoints: a list of pserver endpoint strings
    :return: a dict mapping pserver endpoint -> list of varblock strings
             assigned to it, in input order; endpoints that receive no
             block are absent from the dict
    :raises ValueError: if varblocks is non-empty but pserver_endpoints
                        is empty
    """
    import zlib

    def _hash_block(block_str, total):
        # The built-in hash() of a str is salted per process in Python 3
        # (PYTHONHASHSEED), so separate processes would disagree on the
        # placement. CRC32 is stable; the mask keeps the value unsigned on
        # interpreters where crc32 can return a negative int.
        digest = zlib.crc32(block_str.encode("utf-8")) & 0xffffffff
        return digest % total

    ep2block = {}
    if not varblocks:
        return ep2block
    if not pserver_endpoints:
        raise ValueError("pserver_endpoints must not be empty")

    num_servers = len(pserver_endpoints)
    for varblock_str in varblocks:
        server_for_param = pserver_endpoints[_hash_block(varblock_str,
                                                         num_servers)]
        ep2block.setdefault(server_for_param, []).append(varblock_str)

    return ep2block


def round_robin(varblocks, pserver_endpoints):
    """
    Assign variable blocks to parameter servers in round-robin order.

    Block ``i`` goes to ``pserver_endpoints[i % len(pserver_endpoints)]``.

    :param varblocks: a list of VarBlock strings; must contain at least
                      as many entries as pserver_endpoints
    :param pserver_endpoints: a list of pserver endpoint strings
    :return: a dict mapping pserver endpoint -> list of varblock strings
             assigned to it, in input order
    :raises AssertionError: if there are fewer varblocks than endpoints
    :raises ValueError: if varblocks is non-empty but pserver_endpoints
                        is empty
    """
    # Exactly one block per server is a valid distribution, hence ">=".
    assert len(varblocks) >= len(pserver_endpoints), (
        "need at least as many varblocks as pserver endpoints")

    ep2block = {}
    if not varblocks:
        return ep2block
    if not pserver_endpoints:
        raise ValueError("pserver_endpoints must not be empty")

    num_servers = len(pserver_endpoints)
    for idx, varblock_str in enumerate(varblocks):
        server_for_param = pserver_endpoints[idx % num_servers]
        ep2block.setdefault(server_for_param, []).append(varblock_str)
    return ep2block