device_worker.py 4.4 KB
Newer Older
1
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
D
dongdaxiang 已提交
14
import sys
15

16 17
# Names exported by `from ... import *`; DeviceWorkerFactory is not listed.
__all__ = ['DeviceWorker', 'Hogwild', 'DownpourSGD']

18 19 20

class DeviceWorker(object):
    """Base class for device workers.

    A device worker holds the program it runs and the fleet descriptor, and
    writes its own settings into a trainer descriptor through
    ``gen_worker_desc``. Concrete workers override ``gen_worker_desc``.
    """

    def __init__(self):
        self.program_ = None
        # Initialized here so that reading it before set_fleet_desc() yields
        # None instead of raising AttributeError.
        self.fleet_desc_ = None

    def set_fleet_desc(self, fleet_desc):
        """Store the fleet descriptor used when generating the worker desc."""
        self.fleet_desc_ = fleet_desc

    def set_program(self, program):
        """Store the program this worker runs."""
        self.program_ = program

    def gen_worker_desc(self, trainer_desc):
        """Write this worker's settings into ``trainer_desc``.

        The base implementation does nothing; subclasses override it.
        """
        pass


class Hogwild(DeviceWorker):
    """Device worker that selects the "HogwildWorker" implementation."""

    def __init__(self):
        # No state of its own; everything is set up by DeviceWorker.
        super(Hogwild, self).__init__()

    def gen_worker_desc(self, trainer_desc):
        """Record the worker name on ``trainer_desc``; nothing else is set."""
        worker_name = "HogwildWorker"
        trainer_desc.device_worker_name = worker_name


D
dongdaxiang 已提交
41
class DownpourSGD(DeviceWorker):
    """Device worker that selects the "DownpourWorker" implementation.

    Both ``set_program`` and ``set_fleet_desc`` must be called before
    ``gen_worker_desc``.
    """

    def __init__(self):
        super(DownpourSGD, self).__init__()

    def gen_worker_desc(self, trainer_desc):
        """Fill ``trainer_desc`` with the DownpourWorker configuration.

        Copies into ``trainer_desc``:
          * the sparse/dense table ids this program pushes and pulls, taken
            from ``program_._fleet_opt["program_configs"]``;
          * the dense tables (among those ids) for the pull-dense thread and
            for the downpour parameters;
          * the first sparse table's slot names and dimensions;
          * the trainer's skip ops (added exactly once).

        Args:
            trainer_desc: descriptor mutated in place; presumably a protobuf
                TrainerDesc message -- confirm against the caller.

        Exits the process with ``sys.exit(-1)`` when the program or the fleet
        descriptor has not been configured.
        """
        if self.program_ is None:
            print("program of current device worker is not configured")
            sys.exit(-1)
        # getattr keeps this check valid even for instances whose base
        # __init__ did not create fleet_desc_.
        fleet_desc = getattr(self, "fleet_desc_", None)
        if fleet_desc is None:
            print("fleet desc of current device worker is not configured")
            sys.exit(-1)

        program_id = str(id(self.program_))
        program_configs = self.program_._fleet_opt["program_configs"]
        downpour = trainer_desc.downpour_param

        # Dense tables this program pushes or pulls; only these are wired up
        # for pulling and as downpour dense tables below.
        dense_table_set = set()
        if program_id in program_configs:
            config = program_configs[program_id]
            # Materialize once: each list is used twice below.
            push_dense = list(config["push_dense"])
            pull_dense = list(config["pull_dense"])
            pc = downpour.program_config.add()
            pc.program_id = program_id
            pc.push_sparse_table_id.extend(config["push_sparse"])
            pc.push_dense_table_id.extend(push_dense)
            pc.pull_sparse_table_id.extend(config["pull_sparse"])
            pc.pull_dense_table_id.extend(pull_dense)
            dense_table_set.update(push_dense)
            dense_table_set.update(pull_dense)

        trainer_desc.device_worker_name = "DownpourWorker"
        trainer_param = fleet_desc.trainer_param

        pull_thread = trainer_desc.pull_dense_param
        pull_thread.device_num = trainer_desc.thread_num
        for table in trainer_param.dense_table:
            if table.table_id in dense_table_set:
                dense_table = pull_thread.dense_table.add()
                dense_table.dense_value_name.extend(table.dense_variable_name)
                dense_table.table_id = table.table_id

        # Only the first sparse table is configured (IndexError if none).
        sparse_conf = trainer_param.sparse_table[0]
        sparse_table = downpour.sparse_table.add()
        sparse_table.table_id = sparse_conf.table_id
        sparse_table.sparse_key_name.extend(sparse_conf.slot_key)
        sparse_table.sparse_value_name.extend(sparse_conf.slot_value)
        sparse_table.sparse_grad_name.extend(sparse_conf.slot_gradient)
        # The accessor's fea_dim is treated as embedding dim + 2; presumably
        # the 2 extra values are per-feature statistics -- confirm.
        server_param = fleet_desc.server_param.downpour_server_param
        sparse_table.emb_dim = \
            server_param.downpour_table_param[0].accessor.fea_dim - 2
        sparse_table.fea_dim = sparse_table.emb_dim + 2
        # TODO(guru4elephant): hard code here, need to improve
        sparse_table.label_var_name = "click"

        for table in trainer_param.dense_table:
            if table.table_id in dense_table_set:
                dense_table = downpour.dense_table.add()
                dense_table.table_id = table.table_id
                dense_table.dense_value_name.extend(table.dense_variable_name)
                dense_table.dense_grad_name.extend(
                    table.dense_gradient_variable_name)
        # Skip ops belong to the worker as a whole. Previously this ran inside
        # the loop above: the ops were duplicated per dense table and dropped
        # entirely when no dense table matched.
        downpour.skip_ops.extend(trainer_param.skip_op)

107 108 109 110 111

class DeviceWorkerFactory(object):
    """Create device workers from a worker-type name."""

    def create_device_worker(self, worker_type):
        """Instantiate the device worker class named ``worker_type``.

        The name is matched case-insensitively against the DeviceWorker
        subclasses defined in this module, so "hogwild", "Hogwild" and
        "DownpourSGD" all resolve.

        Args:
            worker_type: name of the worker class, in any letter case.

        Returns:
            A new instance of the matching DeviceWorker subclass.

        Raises:
            KeyError: if no DeviceWorker subclass in this module has that name.
        """
        # str.capitalize() would map "DownpourSGD" to "Downpoursgd" (no such
        # class), so compare lower-cased names instead. Restricting the lookup
        # to DeviceWorker subclasses also keeps arbitrary module globals from
        # being instantiated.
        wanted = worker_type.lower()
        for name, obj in globals().items():
            if (isinstance(obj, type) and issubclass(obj, DeviceWorker)
                    and obj is not DeviceWorker and name.lower() == wanted):
                return obj()
        raise KeyError("unknown device worker type: %s" % worker_type)