device_worker.py 5.0 KB
Newer Older
1
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
D
dongdaxiang 已提交
14
import sys
15

16 17
# Public names exported by this module; DeviceWorkerFactory is not listed.
__all__ = ['DeviceWorker', 'Hogwild', 'DownpourSGD']

18 19 20

class DeviceWorker(object):
    """Base class for device workers.

    A device worker stores the program and the fleet descriptor it has been
    configured with, and writes its settings into a trainer descriptor via
    ``gen_worker_desc``.
    """

    def __init__(self):
        self.program_ = None
        # Initialized here so that subclasses reading the fleet descriptor see
        # an explicit None instead of failing with AttributeError when
        # set_fleet_desc() was never called.
        self.fleet_desc_ = None

    def set_fleet_desc(self, fleet_desc):
        """Set the fleet descriptor consulted by ``gen_worker_desc``."""
        self.fleet_desc_ = fleet_desc

    def set_program(self, program):
        """Set the program this worker is configured with."""
        self.program_ = program

    def gen_worker_desc(self, trainer_desc):
        """Write this worker's settings into *trainer_desc*.

        The base implementation does nothing; subclasses override it.
        """
        pass


class Hogwild(DeviceWorker):
    """Device worker that tags the trainer as ``HogwildWorker``.

    No configuration beyond the worker name is written.
    """

    def __init__(self):
        super(Hogwild, self).__init__()

    def gen_worker_desc(self, trainer_desc):
        """Set the worker name on *trainer_desc* (mutated in place)."""
        worker_name = "HogwildWorker"
        trainer_desc.device_worker_name = worker_name


D
dongdaxiang 已提交
41
class DownpourSGD(DeviceWorker):
    """Device worker that writes a ``DownpourWorker`` configuration.

    Table ids and variable names are copied from the fleet descriptor given to
    ``set_fleet_desc``; per-program table routing is read from
    ``program._fleet_opt["program_configs"]`` of the program given to
    ``set_program``.
    """

    def __init__(self):
        super(DownpourSGD, self).__init__()

    def gen_worker_desc(self, trainer_desc):
        """Fill *trainer_desc* with the Downpour worker configuration.

        Only the first dense table and first sparse table of
        ``trainer_param``, and the first ``downpour_table_param`` of the
        server descriptor, are used.

        Args:
            trainer_desc: trainer descriptor message; mutated in place.

        Prints a message and exits the process with status -1 if no program
        has been configured.
        """
        # Validate before touching trainer_desc or the program.
        if self.program_ is None:
            print("program of current device worker is not configured")
            sys.exit(-1)

        fleet_desc = self.fleet_desc_
        trainer_param = fleet_desc.trainer_param
        dense_param = trainer_param.dense_table[0]
        sparse_param = trainer_param.sparse_table[0]

        trainer_desc.device_worker_name = "DownpourWorker"

        # Dense-parameter pull settings; device_num mirrors thread_num.
        pull_thread = trainer_desc.pull_dense_param
        pull_thread.device_num = trainer_desc.thread_num
        pull_dense_table = pull_thread.dense_table.add()
        pull_dense_table.dense_value_name.extend(
            dense_param.dense_variable_name)
        pull_dense_table.table_id = dense_param.table_id

        downpour = trainer_desc.downpour_param

        # Sparse table of the downpour parameter.
        sparse_table = downpour.sparse_table.add()
        sparse_table.table_id = sparse_param.table_id
        sparse_table.sparse_key_name.extend(sparse_param.slot_key)
        sparse_table.sparse_value_name.extend(sparse_param.slot_value)
        sparse_table.sparse_grad_name.extend(sparse_param.slot_gradient)
        accessor = (fleet_desc.server_param.downpour_server_param
                    .downpour_table_param[0].accessor)
        # emb_dim is the accessor's fea_dim minus 2 and fea_dim is set back to
        # emb_dim + 2. NOTE(review): the meaning of the 2 extra dims is not
        # visible in this file -- confirm against the accessor definition.
        sparse_table.emb_dim = accessor.fea_dim - 2
        sparse_table.fea_dim = sparse_table.emb_dim + 2
        # TODO(guru4elephant): hard code here, need to improve
        sparse_table.label_var_name = "click"

        # Dense table of the downpour parameter.
        dense_table = downpour.dense_table.add()
        dense_table.table_id = dense_param.table_id
        dense_table.dense_value_name.extend(dense_param.dense_variable_name)
        dense_table.dense_grad_name.extend(
            dense_param.dense_gradient_variable_name)
        downpour.skip_ops.extend(trainer_param.skip_op)

        # program_configs is keyed by str(id(program)). Look up this
        # program's entry directly; iterating with a loop variable named
        # program_id would shadow the id computed here, and comparing a
        # config to its own key never matches.
        program_id = str(id(self.program_))
        program_configs = self.program_._fleet_opt["program_configs"]
        if program_id in program_configs:
            config = program_configs[program_id]
            pc = downpour.program_config.add()
            pc.program_id = program_id
            pc.push_sparse_table_id.extend(config["push_sparse"])
            pc.push_dense_table_id.extend(config["push_dense"])
            pc.pull_sparse_table_id.extend(config["pull_sparse"])
            pc.pull_dense_table_id.extend(config["pull_dense"])


class DeviceWorkerFactory(object):
    """Create device workers from their type names."""

    def create_device_worker(self, worker_type):
        """Instantiate the device worker class named *worker_type*.

        The name is first looked up as ``worker_type.capitalize()`` (so
        ``"hogwild"`` resolves to ``Hogwild``). ``str.capitalize`` lower-cases
        every character after the first, so mixed-case class names such as
        ``DownpourSGD`` can never be produced by it; in that case a
        case-insensitive match against the classes in this module's
        namespace is used as a fallback.

        Args:
            worker_type: worker type name, matched case-insensitively.

        Returns:
            A new instance of the matching class.

        Raises:
            KeyError: if no matching class exists.
        """
        namespace = globals()
        classname = worker_type.capitalize()
        if classname not in namespace:
            wanted = worker_type.lower()
            for name, obj in namespace.items():
                if isinstance(obj, type) and name.lower() == wanted:
                    classname = name
                    break
        return namespace[classname]()