#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

# Names exported by ``from <module> import *``.
__all__ = ['DeviceWorker', 'Hogwild', 'DownpourSGD']


class DeviceWorker(object):
    """Base class holding the state shared by all device workers.

    An instance stores a program and a fleet description; subclasses
    override ``gen_worker_desc`` to write their configuration into a
    trainer description.
    """

    def __init__(self):
        self.program_ = None
        # Created up front (like ``program_``) so that reading it before
        # ``set_fleet_desc`` is called yields None instead of raising
        # AttributeError.
        self.fleet_desc_ = None

    def set_fleet_desc(self, fleet_desc):
        """Store the fleet description later read by ``gen_worker_desc``.

        Args:
            fleet_desc: fleet description object; stored as-is.
        """
        self.fleet_desc_ = fleet_desc

    def set_program(self, program):
        """Store the program this worker is associated with.

        Args:
            program: program object; stored as-is.
        """
        self.program_ = program

    def gen_worker_desc(self, trainer_desc):
        """Fill ``trainer_desc`` with worker settings.

        The base implementation intentionally does nothing; subclasses
        override it.

        Args:
            trainer_desc: trainer description to be filled in place.
        """
        pass


class Hogwild(DeviceWorker):
    """Device worker that selects the "HogwildWorker" implementation."""

    def __init__(self):
        super(Hogwild, self).__init__()

    def gen_worker_desc(self, trainer_desc):
        """Set the worker name on ``trainer_desc``.

        Args:
            trainer_desc: trainer description; its ``device_worker_name``
                attribute is overwritten with ``"HogwildWorker"``.
        """
        trainer_desc.device_worker_name = "HogwildWorker"


class DownpourSGD(DeviceWorker):
    """Device worker that selects and configures the "DownpourWorker".

    ``gen_worker_desc`` reads the fleet description set through
    ``set_fleet_desc`` and the program set through ``set_program``.
    """

    def __init__(self):
        super(DownpourSGD, self).__init__()

    def gen_worker_desc(self, trainer_desc):
        """Fill ``trainer_desc`` with the Downpour worker configuration.

        Writes the pull-dense settings, one sparse table, the table ids
        used by the current program and the matching dense tables.

        Args:
            trainer_desc: trainer description filled in place.
                NOTE(review): presumably a protobuf message (repeated
                fields are used via ``add``/``extend``) -- confirm.

        Exits the process with status -1 (``sys.exit``) when no program
        has been configured.
        """
        # Validate first so trainer_desc is never left half-filled.
        if self.program_ is None:
            print("program of current device worker is not configured")
            sys.exit(-1)

        trainer_param = self.fleet_desc_.trainer_param
        first_dense = trainer_param.dense_table[0]
        first_sparse = trainer_param.sparse_table[0]

        trainer_desc.device_worker_name = "DownpourWorker"

        # Pull-dense settings: only the first dense table is used here.
        pull_thread = trainer_desc.pull_dense_param
        pull_thread.device_num = trainer_desc.thread_num
        pull_table = pull_thread.dense_table.add()
        pull_table.dense_value_name.extend(first_dense.dense_variable_name)
        pull_table.table_id = first_dense.table_id

        # Sparse table: only the first sparse table is used here.
        downpour = trainer_desc.downpour_param
        sparse_table = downpour.sparse_table.add()
        sparse_table.table_id = first_sparse.table_id
        sparse_table.sparse_key_name.extend(first_sparse.slot_key)
        sparse_table.sparse_value_name.extend(first_sparse.slot_value)
        sparse_table.sparse_grad_name.extend(first_sparse.slot_gradient)
        server_param = self.fleet_desc_.server_param.downpour_server_param
        accessor = server_param.downpour_table_param[0].accessor
        sparse_table.emb_dim = accessor.fea_dim - 2
        sparse_table.fea_dim = sparse_table.emb_dim + 2
        # TODO(guru4elephant): hard code here, need to improve
        sparse_table.label_var_name = "click"

        # Program configs are keyed by the stringified id() of the program.
        program_id = str(id(self.program_))
        program_configs = self.program_._fleet_opt["program_configs"]

        # Ids of every dense table the current program pushes to or
        # pulls from.
        dense_table_ids = set()
        if program_id in program_configs:
            config = program_configs[program_id]
            push_dense = list(config["push_dense"])
            pull_dense = list(config["pull_dense"])

            pc = downpour.program_config.add()
            pc.program_id = program_id
            pc.push_sparse_table_id.extend(config["push_sparse"])
            pc.push_dense_table_id.extend(push_dense)
            pc.pull_sparse_table_id.extend(config["pull_sparse"])
            pc.pull_dense_table_id.extend(pull_dense)

            dense_table_ids.update(push_dense)
            dense_table_ids.update(pull_dense)

        for table in trainer_param.dense_table:
            if table.table_id not in dense_table_ids:
                continue
            dense_table = downpour.dense_table.add()
            dense_table.table_id = table.table_id
            dense_table.dense_value_name.extend(table.dense_variable_name)
            dense_table.dense_grad_name.extend(
                table.dense_gradient_variable_name)

        # Copied exactly once: doing this per matching dense table would
        # duplicate the entries with several tables and drop them with
        # none.
        downpour.skip_ops.extend(trainer_param.skip_op)


class DeviceWorkerFactory(object):
    """Creates device worker instances from a worker type name."""

    def create_device_worker(self, worker_type):
        """Instantiate the class in this module named by ``worker_type``.

        The name is resolved in three steps: the exact name, then the
        ``str.capitalize()`` form (the historical behaviour, so
        ``"hogwild"`` still maps to ``Hogwild``), then a case-insensitive
        match against the classes defined in this module.

        Args:
            worker_type: class name such as ``"Hogwild"``, ``"hogwild"``
                or ``"DownpourSGD"``.

        Returns:
            A new instance of the matching class, built with no arguments.

        Raises:
            KeyError: if no class in this module matches ``worker_type``.
        """
        module_scope = globals()

        # capitalize() lower-cases everything after the first letter
        # ("DownpourSGD" -> "Downpoursgd"), so on its own it can never
        # resolve a mixed-case class name; try the exact name first.
        for candidate in (worker_type, worker_type.capitalize()):
            worker_cls = module_scope.get(candidate)
            if isinstance(worker_cls, type):
                return worker_cls()

        wanted = worker_type.lower()
        for name, obj in list(module_scope.items()):
            if isinstance(obj, type) and name.lower() == wanted:
                return obj()

        raise KeyError(worker_type)