# node.py
# Builders for the Downpour parameter-server descriptors (server and worker)
# backed by the protobuf messages in the ps_pb2 module.
import ps_pb2 as pslib

class Server(object):
    """Common base for parameter-server node descriptors.

    Carries no state of its own; concrete servers build their own protobuf.
    """

    def __init__(self):
        # Intentionally no per-instance state at this level.
        super(Server, self).__init__()


class Worker(object):
    """Common base for parameter-server worker descriptors.

    Carries no state of its own; concrete workers build their own protobuf.
    """

    def __init__(self):
        # Intentionally no per-instance state at this level.
        super(Worker, self).__init__()


class DownpourServer(Server):
    """Builder for the Downpour parameter-server descriptor.

    Wraps a ``pslib.ServerParameter`` message and appends one entry to
    ``downpour_server_param.downpour_table_param`` per registered table.
    """

    def __init__(self):
        self.server_ = pslib.ServerParameter()

    def add_sparse_table(self, table_id, learning_rate,
                         slot_key, slot_value_var, slot_grad_var):
        """Register a sparse table on the server descriptor.

        Args:
            table_id: Value written to ``table.table_id``.
            learning_rate: Learning rate for the sparse SGD rule.
            slot_key: Not used by this method; kept for interface
                compatibility.
            slot_value_var: Sequence of variables; ``shape[1]`` of the
                first one becomes ``accessor.fea_dim``.
            slot_grad_var: Not used by this method; kept for interface
                compatibility.
        """
        table = self.server_.downpour_server_param.downpour_table_param.add()
        table.table_id = table_id
        # Enum values of a generated protobuf module are module attributes;
        # the bare name is undefined in this file and raised NameError.
        table.type = pslib.PS_SPARSE_TABLE
        table.accessor.accessor_class = "DownpourFeatureValueAccessor"
        # Sparse tables are tuned through sparse_sgd_param (dense tables
        # use dense_sgd_param.adam); these were previously swapped.
        table.accessor.sparse_sgd_param.learning_rate = learning_rate
        table.accessor.fea_dim = slot_value_var[0].shape[1]

    def add_dense_table(self, table_id, learning_rate,
                        param_var, grad_var):
        """Register a dense table on the server descriptor.

        Args:
            table_id: Value written to ``table.table_id``.
            learning_rate: Learning rate for the dense Adam rule.
            param_var: Not used by this method; kept for interface
                compatibility.
            grad_var: Not used by this method; kept for interface
                compatibility.
        """
        table = self.server_.downpour_server_param.downpour_table_param.add()
        table.table_id = table_id
        table.type = pslib.PS_DENSE_TABLE
        table.accessor.accessor_class = "DownpourDenseValueAccessor"
        table.accessor.dense_sgd_param.adam.learning_rate = learning_rate
        # NOTE(review): fea_dim is fixed at 1 rather than derived from the
        # shapes in param_var -- confirm this is what the server expects.
        table.accessor.fea_dim = 1

    def get_desc(self):
        """Return the underlying ``pslib.ServerParameter`` message."""
        return self.server_


class DownpourWorker(Worker):
    """Builder for the Downpour worker (trainer) descriptor.

    Wraps a ``pslib.DownpourTrainerParameter`` message and appends sparse
    and dense table entries under ``downpour_worker_param``.
    """

    def __init__(self, window):
        """Create the descriptor.

        Args:
            window: Stored as ``self.window`` and written to both
                ``pull_dense_per_batch`` and ``push_dense_per_batch``.
        """
        self.window = window
        self.worker_ = pslib.DownpourTrainerParameter()
        self.worker_.pull_dense_per_batch = window
        self.worker_.push_dense_per_batch = window
        # The former unconditional print(self.worker_) was debug output and
        # has been dropped; call get_desc() to inspect the message.

    def add_sparse_table(self, table_id,
                         slot_keys, slot_value_vars, slot_grad_vars):
        """Register a sparse table on the worker descriptor.

        Args:
            table_id: Value written to ``table.table_id``.
            slot_keys: Values appended to the table's ``slot`` field.
            slot_value_vars: Not used by this method; kept for interface
                compatibility.
            slot_grad_vars: Variables whose ``.name`` values are recorded
                as the table's gradient names.
        """
        table = self.worker_.downpour_worker_param.sparse_table.add()
        table.table_id = table_id
        table.slot.extend(slot_keys)
        # Protobuf messages have no ``extend`` method, so calling it on
        # self.worker_ always raised AttributeError; the gradient names
        # belong on this table entry's repeated field.
        # NOTE(review): field name ``slot_gradient`` is assumed -- confirm
        # against ps.proto.
        table.slot_gradient.extend([grad.name for grad in slot_grad_vars])

    def add_dense_table(self, table_id, param_vars, grad_vars):
        """Register a dense table on the worker descriptor.

        Args:
            table_id: Value written to ``table.table_id``.
            param_vars: Variables whose ``.name`` values are appended to
                ``dense_variable_name``.
            grad_vars: Variables whose ``.name`` values are appended to
                ``dense_gradient_variable_name``.
        """
        table = self.worker_.downpour_worker_param.dense_table.add()
        table.table_id = table_id
        table.dense_variable_name.extend([p.name for p in param_vars])
        table.dense_gradient_variable_name.extend([g.name for g in grad_vars])

    def get_desc(self):
        """Return the underlying ``pslib.DownpourTrainerParameter`` message."""
        return self.worker_