downpour.py
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .node import DownpourServer
from .node import DownpourWorker
from ..backward import append_backward
import ps_pb2 as pslib
from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table
from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table_inputs
from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table_outputs
from google.protobuf import text_format

class DownpourSGD(object):
    """
    Distributed optimizer of downpour stochastic gradient descent
    Standard implementation of Google's Downpour SGD
    in Large Scale Distributed Deep Networks

    Args:
        learning_rate (float): the learning rate used to update parameters. \
        Can be a float value
    Examples:
        .. code-block:: python
    
             downpour_sgd = fluid.distributed.DownpourSGD(learning_rate=0.2)
             downpour_sgd.minimize(cost)
    """
    def __init__(self, learning_rate=0.001, window=1):
        # todo(guru4elephant): add more optimizers here as argument
        # todo(guru4elephant): make learning_rate as a variable
        self.learning_rate_ = learning_rate
        self.window_ = window
        self.type = "downpour"

    def minimize(self, loss, startup_program=None,
                 parameter_list=None, no_grad_set=None):
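        """
        Build the backward pass for ``loss`` and generate the Downpour
        parameter-server configuration for it.

        Args:
            loss (Variable): loss variable to minimize
            startup_program (Program|None): startup program, not used here
            parameter_list (list|None): parameters to update; all trainable
                parameters if None
            no_grad_set (set|None): variables that should not receive gradients

        Returns:
            list: [ps_param, worker_skipped_ops], the PSParameter proto
            describing the server/worker tables and the op types the trainer
            should skip
        """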
        params_grads = sorted(append_backward(
            loss, parameter_list, no_grad_set), key=lambda x:x[0].name)
        table_name = find_distributed_lookup_table(loss.block.program)
        prefetch_slots = find_distributed_lookup_table_inputs(
            loss.block.program, table_name)
        prefetch_slots_emb = find_distributed_lookup_table_outputs(
            loss.block.program, table_name)
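        # input slot variables and output embedding variables of the
        # distributed lookup table, used below to configure the sparse table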
        server = DownpourServer()
        # window is communication strategy
        worker = DownpourWorker(self.window_)
        # Todo(guru4elephant): support multiple tables definitions
        # currently support one big sparse table
        sparse_table_index = 0
        # currently merge all dense parameters into one dense table
        dense_table_index = 1
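        # split the (param, grad) pairs, sorted by parameter name above,
        # into parallel lists for dense table registration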
        params = []
        grads = []
        for i in params_grads:
            params.append(i[0])
            grads.append(i[1])
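        # register the sparse lookup table and the dense parameter table
        # with both the server and the worker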
        server.add_sparse_table(sparse_table_index, self.learning_rate_,
                                prefetch_slots, prefetch_slots_emb)
        server.add_dense_table(dense_table_index, self.learning_rate_, 
                               params, grads)
        worker.add_sparse_table(sparse_table_index, self.learning_rate_,
                                prefetch_slots, prefetch_slots_emb)
        worker.add_dense_table(dense_table_index, self.learning_rate_, 
                               params, grads)
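        # collect the server and worker descriptions into a single
        # PSParameter message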
        ps_param = pslib.PSParameter()
        ps_param.server_param.CopyFrom(server.get_desc())
        ps_param.trainer_param.CopyFrom(worker.get_desc())
        # Todo(guru4elephant): figure out how to support more sparse parameters
        # currently only support lookup_table
        worker_skipped_ops = ["lookup_table", "lookup_table_grad"]
        ps_param.trainer_param.skip_op.extend(worker_skipped_ops)
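        # human-readable text-format rendering of the generated config;
        # only the proto and the skipped ops are returned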
        ps_param_str = text_format.MessageToString(ps_param)
        return [ps_param, worker_skipped_ops]