#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import os
from ..base.role_maker import MPISymetricRoleMaker
from .optimizer_factory import *
from google.protobuf import text_format
import paddle.fluid.optimizer as local_optimizer
import paddle.fluid as fluid


class Fleet(object):
    """
    Fleet is the base interface for parameter-server style distributed
    training. It holds a role maker that assigns the parameter-server or
    worker role to each MPI process and wraps the underlying
    fluid.core.Fleet object used to start servers, initialize workers and
    save models.
    """

    def __init__(self):
        self._opt_info = None  # for fleet only
        self.role_maker_ = None
        self.local_ip_ = 0
        self.is_initialized_ = False

    def init(self):
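        """
        Initialize the fleet: build an MPISymetricRoleMaker, generate the
        role of the current process, and create the underlying
        fluid.core.Fleet object. Later calls are no-ops once initialized.
        """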
        # TODO(guru4elephant)
        # this is a temporary solution
        # we will support more configurable RoleMaker for users in the future
        if not self.is_initialized_:
            self.role_maker_ = MPISymetricRoleMaker()
            self.role_maker_.generate_role()
            self._fleet_ptr = fluid.core.Fleet()
            self.is_initialized_ = True

    def stop(self):
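        """
        Stop distributed training: the first worker asks the servers to
        stop, then all processes synchronize and the role maker is
        finalized.
        """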
        self.role_maker_.barrier_worker()
        if self.role_maker_.is_first_worker():
            self._fleet_ptr.stop_server()
        self.role_maker_.barrier_worker()
        self.role_maker_.barrier_all()
        self.role_maker_.finalize()

    def init_pserver(self):
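        """
        Initialize and run a parameter server on this process. Requires
        that DistributedOptimizer.minimize() has been called so that the
        fleet descriptor is available; the server publishes its IP to all
        processes and then waits until every worker has started.
        """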
        if self._opt_info:
            if "fleet_desc" in self._opt_info:
                self._dist_desc_str = text_format.MessageToString(
                    self._opt_info["fleet_desc"])
                self._dist_desc = self._opt_info["fleet_desc"]
            else:
                print("You should run DistributedOptimizer.minimize() first")
                sys.exit(-1)
            self._fleet_ptr.init_server(self._dist_desc_str,
                                        self.role_maker_.get_rank())
            self.local_ip_ = self._fleet_ptr.run_server()
            self.role_maker_.barrier_all()
            self.all_ips_ = self.role_maker_.all_gather(self.local_ip_)

            self._fleet_ptr.gather_servers(self.all_ips_,
                                           self.role_maker_.get_size())
            # wait all workers start
            self.role_maker_.barrier_all()
        else:
            print("You should run DistributedOptimizer.minimize() first")
            sys.exit(-1)

    def init_worker(self):
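        """
        Initialize this process as a worker. Requires that
        DistributedOptimizer.minimize() has been called; the worker waits
        for all servers to start, gathers their IPs, and connects to them.
        """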
        if self._opt_info:
            if "fleet_desc" in self._opt_info:
                self._dist_desc_str = text_format.MessageToString(
                    self._opt_info["fleet_desc"])
                self._dist_desc = self._opt_info["fleet_desc"]
            else:
                print("You should run DistributedOptimizer.minimize() first")
                sys.exit(-1)
            self.role_maker_.barrier_all()  # wait for server starts
            self.all_ips_ = self.role_maker_.all_gather(self.local_ip_)
            self._fleet_ptr.init_worker(self._dist_desc_str, self.all_ips_,
                                        self.role_maker_.get_size(),
                                        self.role_maker_.get_rank())
            self.role_maker_.barrier_all()
            self.role_maker_.barrier_worker()
        else:
            print("You should run DistributedOptimizer.minimize() first")
            sys.exit(-1)

    def get_worker_num(self):
        return self.role_maker_.worker_num()

    def get_server_num(self):
        return self.role_maker_.server_num()

    def is_worker(self):
        return self.role_maker_.is_worker()

    def is_server(self):
        return self.role_maker_.is_server()

    def init_pserver_model(self):
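        """
        Initialize the model on the parameter servers. Only the first
        worker triggers the initialization; the other workers wait at a
        barrier.
        """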
        if self.role_maker_.is_first_worker():
            self._fleet_ptr.init_model()
        self.role_maker_.barrier_worker()

    def save_pserver_model(self, save_path):
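        """
        Save the model held by the parameter servers to save_path.
        """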
        self._fleet_ptr.save_model(save_path)

    def _set_opt_info(self, opt_info):
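        """
        Record the optimization info produced by
        DistributedOptimizer.minimize(); init_pserver() and init_worker()
        rely on it.
        """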
        self._opt_info = opt_info


class DistributedOptimizer(object):
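    """
    DistributedOptimizer wraps a local optimizer so that minimize() runs
    its distributed counterpart (currently only a distributed Adam) and
    registers the resulting optimization info with the global fleet
    instance.
    """
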
    def __init__(self, optimizer, dist_config={}):
        super(DistributedOptimizer, self).__init__()
        self._optimizer = optimizer
        self._optimizer_name = "Distributed%s" % optimizer.type.capitalize()
        if optimizer.type != "adam":
            sys.stderr.write(
                "Currently, distributed optimizer only supports Adam. "
                "The built-in distributed Adam will be configured for you. "
                "More optimizers will be supported in DistributedOptimizer.\n")
            self._optimizer_name = "DistributedAdam"

        self._distributed_optimizer = globals()[self._optimizer_name](optimizer)

    def backward(self,
                 loss,
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None,
                 callbacks=None):
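        """
        Currently a no-op in this class; use minimize() instead.
        """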
        pass

    def apply_gradients(self, params_grads):
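        """
        Currently a no-op in this class; use minimize() instead.
        """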
        pass

    def minimize(self,
                 loss,
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None):
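        """
        Run the wrapped distributed optimizer's minimize() on the given
        loss, register the returned opt_info with the global fleet
        instance so that init_pserver()/init_worker() can use it, and
        return the optimize ops and parameter/gradient pairs.
        """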
        optimize_ops, param_grads, opt_info = \
                      self._distributed_optimizer.minimize(
                          loss,
                          startup_program,
                          parameter_list,
                          no_grad_set)

        fleet_instance._set_opt_info(opt_info)
        return [optimize_ops, param_grads]


# this is a temporary solution
# TODO(guru4elephant)
# will make this more flexible for more Parameter Server Archs
fleet_instance = Fleet()

init = fleet_instance.init
stop = fleet_instance.stop
init_pserver = fleet_instance.init_pserver
init_worker = fleet_instance.init_worker
is_worker = fleet_instance.is_worker
is_server = fleet_instance.is_server
init_pserver_model = fleet_instance.init_pserver_model
save_pserver_model = fleet_instance.save_pserver_model
worker_num = fleet_instance.get_worker_num
server_num = fleet_instance.get_server_num
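
# A minimal usage sketch, assuming the job is launched with an MPI launcher
# such as mpirun so that MPISymetricRoleMaker can split the ranks into
# servers and workers, this package is imported as `fleet`, and `loss` is a
# user-defined fluid loss variable; all other names below are illustrative:
#
#     fleet.init()                     # build the MPI role maker
#     adam = fluid.optimizer.Adam(learning_rate=0.01)
#     optimizer = fleet.DistributedOptimizer(adam)
#     optimizer.minimize(loss)         # registers opt_info with the fleet
#     if fleet.is_server():
#         fleet.init_pserver()         # start serving parameters
#     elif fleet.is_worker():
#         fleet.init_worker()          # connect this worker to the servers
#         fleet.init_pserver_model()   # first worker initializes the model
#         # ... run training with an Executor here ...
#         fleet.save_pserver_model("./model_dir")
#     fleet.stop()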