distribute_planner.py 1.7 KB
Newer Older
T
typhoonzero 已提交
1 2 3 4 5 6 7
import framework
from backward import append_backward_ops
from regularizer import append_regularization_ops
import optimizer
from layer_helper import LayerHelper


T
update  
typhoonzero 已提交
8 9 10 11 12 13 14
def hash_name_to_server(params_grads, pserver_endpoints):
    """
    Assign each trainable parameter and its gradient to a parameter server
    picked by hashing the parameter name.

    A pair is skipped when ``param.trainable`` is not exactly ``True`` or
    when its gradient is ``None``.

    :param params_grads: iterable of ``(param, grad)`` pairs; every ``param``
                    must expose ``name`` and ``trainable`` attributes.
    :param pserver_endpoints: indexable sequence of pserver endpoints. Each
                    endpoint is used as a dict key, so it must be hashable.
    :return: a map of pserver endpoint ->
                    params -> [param list]
                    grads  -> [grad list]
             Endpoints that receive no parameter are absent from the map.
    """

    def _hash_param(param_name, total):
        # NOTE(review): on Python 3 the builtin hash() of a str is salted per
        # process unless PYTHONHASHSEED is fixed, so separate processes may
        # disagree on the placement computed here. Kept as-is to preserve the
        # existing placement; confirm whether a stable digest is needed.
        return hash(param_name) % total

    param_grad_map = dict()
    for param, grad in params_grads:
        if param.trainable is True and grad is not None:
            server_id = _hash_param(param.name, len(pserver_endpoints))
            server_for_param = pserver_endpoints[server_id]
            # Use `in` rather than dict.has_key(), which no longer exists on
            # Python 3.
            if server_for_param not in param_grad_map:
                param_grad_map[server_for_param] = {"params": [], "grads": []}
            param_grad_map[server_for_param]["params"].append(param)
            param_grad_map[server_for_param]["grads"].append(grad)

    return param_grad_map
T
typhoonzero 已提交
30 31


T
wip  
typhoonzero 已提交
32
def round_robin(parameters, pserver_endpoints):
    """
    Distribute trainable parameters over the parameter servers in turn.

    Parameters whose ``trainable`` attribute is not exactly ``True`` are
    skipped and do not advance the rotation.

    :param parameters: sized iterable of parameters; every item must expose a
                    ``trainable`` attribute.
    :param pserver_endpoints: indexable sequence of pserver endpoints. Each
                    endpoint is used as a dict key, so it must be hashable.
    :return: a map of pserver endpoint ->
                    params -> [param list]
                    grads  -> [param list, see the note in the body]
             Endpoints that receive no parameter are absent from the map.
    :raises AssertionError: when there are not strictly more parameters than
                    endpoints (the check is stripped under ``python -O``).
    """
    assert (len(parameters) > len(pserver_endpoints))

    param_grad_map = dict()
    pserver_idx = 0
    for param in parameters:
        if param.trainable is not True:
            continue

        server_for_param = pserver_endpoints[pserver_idx]
        # Use `in` rather than dict.has_key(), which no longer exists on
        # Python 3.
        if server_for_param not in param_grad_map:
            param_grad_map[server_for_param] = {"params": [], "grads": []}

        param_grad_map[server_for_param]["params"].append(param)
        # NOTE(review): the parameter itself is stored under "grads" because
        # this function receives no gradients. This looks like a placeholder;
        # confirm what callers expect before changing it.
        param_grad_map[server_for_param]["grads"].append(param)

        # Move to the next endpoint, wrapping back to the first one.
        pserver_idx = (pserver_idx + 1) % len(pserver_endpoints)
    return param_grad_map