#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .node import DownpourServer
from .node import DownpourWorker
from ..backward import append_backward
import ps_pb2 as pslib
from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table
from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table_inputs
from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table_outputs
from google.protobuf import text_format


class DownpourSGD(object):
    r"""
    Distributed optimizer based on Downpour stochastic gradient descent,
    a standard implementation of Google's Downpour SGD described in
    "Large Scale Distributed Deep Networks".

    Args:
        learning_rate (float): the learning rate used to update parameters. \
        Can be a float value.
        window (int): window size forwarded to the DownpourWorker. Default is 1.

    Examples:
        .. code-block:: python

             opt = fluid.DistributedOptimizer(sgd_opt)
             opt.minimize()

             downpour_sgd = fluid.distributed.DownpourSGD(learning_rate=0.2)
             downpour_sgd.minimize([cost])
    """

    def __init__(self, learning_rate=0.001, window=1):
        # todo(guru4elephant): add more optimizers here as argument
        # todo(guru4elephant): make learning_rate as a variable
        self.learning_rate_ = learning_rate
        self.window_ = window
        self.type = "downpour"
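        # Suffixes that identify data_norm statistics; parameters and
        # gradients whose names end with these are routed to a separate
        # dense table in minimize().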
        self.data_norm_name = [
            ".batch_size", ".batch_square_sum", ".batch_sum",
            ".batch_size@GRAD", ".batch_square_sum@GRAD", ".batch_sum@GRAD"
        ]

    def minimize(self,
                 losses,
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None):
        """
        DownpourSGD is a distributed optimizer, so users can call minimize to
        generate both the backward operators and the optimization operators
        within this function.

        Args:
            losses (list of Variable): loss variables defined by the user
            startup_program (Program): startup program defined by the user
            parameter_list (list of str): parameter names defined by the user
            no_grad_set (set): variables defined by the user for which no
            gradients need to be computed
        Returns:
            (None, param_grads_list)
            param_grads_list: for each loss, the (parameter, gradient) pairs
            produced by append_backward. The parameter server protobuf desc
            and the worker_skipped_ops are attached to each program as
            program._fleet_opt.
        """
        if not isinstance(losses, list):
            raise ValueError('losses should be a list, e.g. [model.cost]')
        table_name = find_distributed_lookup_table(losses[0].block.program)
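        # Collect the input slots and output embeddings of the distributed
        # lookup table; they parameterize the sparse table registered below.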
        prefetch_slots = find_distributed_lookup_table_inputs(
            losses[0].block.program, table_name)
        prefetch_slots_emb = find_distributed_lookup_table_outputs(
            losses[0].block.program, table_name)

        ps_param = pslib.PSParameter()
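        # Build the server-side and worker-side descriptors of the Downpour
        # runtime; their protobuf descs are copied into ps_param further down.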
        server = DownpourServer()
        worker = DownpourWorker(self.window_)
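        # Register the lookup table as sparse table 0 on both the server and
        # the worker so that sparse pull/push refer to the same table id.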
        sparse_table_index = 0
        server.add_sparse_table(sparse_table_index, self.learning_rate_,
                                prefetch_slots, prefetch_slots_emb)
        worker.add_sparse_table(sparse_table_index, self.learning_rate_,
                                prefetch_slots, prefetch_slots_emb)
        dense_table_index = 1
        program_configs = []
        param_grads_list = []
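        # One program_config per loss: backward ops are appended to that
        # loss's program, and the sparse/dense table ids it pulls from and
        # pushes to are recorded in the config.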
        for loss_index in range(len(losses)):
            program_config = ps_param.trainer_param.program_config.add()
            program_config.program_id = str(
                id(losses[loss_index].block.program))
            program_config.pull_sparse_table_id.extend([sparse_table_index])
            program_config.push_sparse_table_id.extend([sparse_table_index])
            params_grads = sorted(
                append_backward(losses[loss_index], parameter_list,
                                no_grad_set),
                key=lambda x: x[0].name)
            param_grads_list.append(params_grads)
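            # Separate data_norm parameters/gradients from the regular ones;
            # they get their own dense table via add_data_norm_table below.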
            params = []
            grads = []
            data_norm_params = []
            data_norm_grads = []
            for i in params_grads:
                is_data_norm_data = False
                for data_norm_name in self.data_norm_name:
                    if i[0].name.endswith(data_norm_name):
                        is_data_norm_data = True
                        data_norm_params.append(i[0])
                if not is_data_norm_data:
                    params.append(i[0])
            for i in params_grads:
                is_data_norm_data = False
                for data_norm_grad in self.data_norm_name:
                    if i[0].name.endswith(data_norm_grad):
                        is_data_norm_data = True
                        data_norm_grads.append(i[1])
                if not is_data_norm_data:
                    grads.append(i[1])
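            # Regular dense parameters of this program share one dense table
            # on both the server and the worker.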
            server.add_dense_table(dense_table_index, self.learning_rate_,
                                   params, grads)
            worker.add_dense_table(dense_table_index, self.learning_rate_,
                                   params, grads)
            program_config.pull_dense_table_id.extend([dense_table_index])
            program_config.push_dense_table_id.extend([dense_table_index])
            if len(data_norm_params) != 0 and len(data_norm_grads) != 0:
                dense_table_index += 1
                server.add_data_norm_table(dense_table_index,
                                           self.learning_rate_,
                                           data_norm_params, data_norm_grads)
                worker.add_dense_table(dense_table_index, self.learning_rate_,
                                       data_norm_params, data_norm_grads)
                program_config.pull_dense_table_id.extend([dense_table_index])
                program_config.push_dense_table_id.extend([dense_table_index])
            dense_table_index += 1
            program_configs.append(program_config)
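        # Fill the PSParameter protobuf from the server/worker descriptors.
        # CopyFrom replaces trainer_param, so the program_configs collected
        # above are appended again right after.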
        ps_param.server_param.CopyFrom(server.get_desc())
        ps_param.trainer_param.CopyFrom(worker.get_desc())
        for program_config in program_configs:
            ps_param.trainer_param.program_config.extend([program_config])
        # Todo(guru4elephant): figure out how to support more sparse parameters
        # currently only support lookup_table
        worker_skipped_ops = ["lookup_table", "lookup_table_grad"]
        ps_param.trainer_param.skip_op.extend(worker_skipped_ops)

        # all fleet operations should be defined in operators in the future
        # we want to return an object here containing:
        # 1) worker execution strategy
        # 2) pserver execution strategy
        # 3) fleet configurations
        # 4) skipped operators in runtime
        # 5) distributed optimization
        opt_info = {}
        opt_info["trainer"] = "DistMultiTrainer"
        opt_info["device_worker"] = "DownpourSGD"
        opt_info["optimizer"] = "DownpourSGD"
        opt_info["fleet_desc"] = ps_param
        opt_info["worker_skipped_ops"] = worker_skipped_ops

        for loss in losses:
            loss.block.program._fleet_opt = opt_info

        return None, param_grads_list