node.py 2.4 KB
Newer Older
D
dongdaxiang 已提交
1
import ps_pb2 as pslib
D
dongdaxiang 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

class Server(object):
    """Base class for server-side node descriptions.

    Holds no state of its own; concrete subclasses build their description
    in ``__init__`` and the ``add_*`` methods.
    """

    def __init__(self):
        """Initialise an empty server description (nothing to set up)."""


class Worker(object):
    """Base class for worker-side node descriptions.

    Holds no state of its own; concrete subclasses build their description
    in ``__init__`` and the ``add_*`` methods.
    """

    def __init__(self):
        """Initialise an empty worker description (nothing to set up)."""


class DownpourServer(Server):
    """Server-side Downpour description wrapping ``pslib.ServerParameter``.

    Each ``add_*_table`` call appends one entry to
    ``server_.downpour_server_param.downpour_table_param``; ``get_desc``
    returns the accumulated message.
    """

    def __init__(self):
        self.server_ = pslib.ServerParameter()

    @staticmethod
    def _numel(shape):
        """Return the product of the dimensions in ``shape`` (1 if empty).

        Used instead of a bare ``reduce``: ``reduce`` is not a builtin on
        Python 3 and is not imported by this module.
        """
        numel = 1
        for dim in shape:
            numel *= dim
        return numel

    def add_sparse_table(self, table_id, learning_rate,
                         slot_key_vars, slot_value_var):
        """Append a sparse table entry to the server description.

        Args:
            table_id: value written to the new entry's ``table_id``.
            learning_rate: learning rate stored in the entry's
                ``accessor.sparse_sgd_param``.
            slot_key_vars: not used by this method.
            slot_value_var: sequence of variables; the shape of the first
                element determines ``accessor.fea_dim``.

        Raises:
            IndexError: if ``slot_value_var`` is empty.
        """
        table = self.server_.downpour_server_param.downpour_table_param.add()
        table.table_id = table_id
        table.type = pslib.PS_SPARSE_TABLE
        table.accessor.accessor_class = "DownpourFeatureValueAccessor"
        # Sparse tables take their rate from sparse_sgd_param; dense tables
        # use dense_sgd_param.adam. These two assignments were swapped
        # between add_sparse_table and add_dense_table.
        table.accessor.sparse_sgd_param.learning_rate = learning_rate
        # abs() neutralises a -1 (unknown) dimension in the shape --
        # presumably the batch dimension; TODO confirm.
        table.accessor.fea_dim = abs(self._numel(slot_value_var[0].shape))

    def add_dense_table(self, table_id, learning_rate,
                        param_var, grad_var):
        """Append a dense table entry to the server description.

        Args:
            table_id: value written to the new entry's ``table_id``.
            learning_rate: learning rate stored in the entry's
                ``accessor.dense_sgd_param.adam``.
            param_var: iterable of parameter variables; ``accessor.fea_dim``
                is the sum of their element counts.
            grad_var: not used by this method.
        """
        table = self.server_.downpour_server_param.downpour_table_param.add()
        table.table_id = table_id
        table.type = pslib.PS_DENSE_TABLE
        table.accessor.accessor_class = "DownpourDenseValueAccessor"
        table.accessor.dense_sgd_param.adam.learning_rate = learning_rate
        # Total number of elements across every dense parameter (0 if none).
        table.accessor.fea_dim = sum(
            self._numel(param.shape) for param in param_var)

    def get_desc(self):
        """Return the underlying ``ServerParameter`` message."""
        return self.server_


class DownpourWorker(Worker):
    """Worker-side Downpour description wrapping
    ``pslib.DownpourTrainerParameter``.

    Args:
        window: stored as ``self.window`` and written to both
            ``pull_dense_per_batch`` and ``push_dense_per_batch``.
    """

    def __init__(self, window):
        self.window = window
        self.worker_ = pslib.DownpourTrainerParameter()
        trainer_param = self.worker_
        trainer_param.pull_dense_per_batch = window
        trainer_param.push_dense_per_batch = window

    def add_sparse_table(self, table_id, learning_rate,
                         slot_key_vars, slot_value_vars):
        """Append a sparse table entry listing slot key/value/grad names.

        ``learning_rate`` is accepted but not used by this method. Gradient
        names are the value names with an ``"@GRAD"`` suffix.
        """
        entry = self.worker_.sparse_table.add()
        entry.table_id = table_id
        key_names = [key_var.name for key_var in slot_key_vars]
        entry.slot_key.extend(key_names)
        value_names = [value_var.name for value_var in slot_value_vars]
        entry.slot_value.extend(value_names)
        grad_names = [value_var.name + "@GRAD"
                      for value_var in slot_value_vars]
        entry.slot_gradient.extend(grad_names)

    def add_dense_table(self, table_id, learning_rate,
                        param_vars, grad_vars):
        """Append a dense table entry listing parameter and gradient names.

        ``learning_rate`` is accepted but not used by this method.
        """
        entry = self.worker_.dense_table.add()
        entry.table_id = table_id
        param_names = [param.name for param in param_vars]
        entry.dense_variable_name.extend(param_names)
        grad_names = [grad.name for grad in grad_vars]
        entry.dense_gradient_variable_name.extend(grad_names)

    def get_desc(self):
        """Return the underlying ``DownpourTrainerParameter`` message."""
        return self.worker_