# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .controller import Controller, ControleMode
from ..context.device import DeviceType

import json


class CollectiveController(Controller):

    @classmethod
    def enable(cls, ctx):
        # collective is the default mode
        if ctx:
            ctx.logger.debug("{} enabled".format(cls.__name__))
            ctx.args.run_mode = ControleMode.COLLECTIVE
            return True
        else:
            return False

    def build_pod(self):
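        # Build the pod either from explicit --ips/--start_port arguments
        # (when no master is given) or through peer discovery via the master.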
        if self.ctx.args.master is None and self.ctx.args.start_port and self.ctx.args.ips:
            return self._build_pod_with_args()
        else:
            return self._build_pod_with_master()

    def _build_pod_with_args(self):
        self.pod.replicas = self.pod_replicas()

        start_port = int(self.ctx.args.start_port)
        ips = self.ctx.args.ips.split(',')

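        # The endpoint list assumes every node in --ips runs `replicas`
        # processes on consecutive ports starting at --start_port.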
        job_endpoints = [
            f"{h}:{p+start_port}" for h in ips for p in range(self.pod.replicas)
        ]

        self.ctx.logger.debug("job endpoints: {}".format(job_endpoints))

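        # This node's global rank offset is its index in --ips times the number
        # of local replicas; a node not listed in --ips starts at offset 0.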
        rank_offset = (
            ips.index(self.ctx.node.ip) * self.pod.replicas
            if self.ctx.node.ip in ips
            else 0
        )

        self.save_pod_log(job_endpoints)

        selected_dev_key = self.ctx.node.device.get_selected_device_key()
        selected_dev_list = self.ctx.node.device.get_selected_devices(
            self.ctx.args.devices)

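        # One container per local replica, each with its own PADDLE_* envs and
        # either a dedicated device or the gloo backend as a CPU fallback.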
        for i in range(self.pod.replicas):
            e = {
                "PADDLE_GLOBAL_SIZE": "{}".format(len(job_endpoints)),
                "PADDLE_LOCAL_SIZE": "{}".format(self.pod.replicas),
                "PADDLE_GLOBAL_RANK": "{}".format(i + rank_offset),
                "PADDLE_LOCAL_RANK": "{}".format(i),
                "PADDLE_NNODES": "{}".format(len(ips)),
                ## compatible env
                "PADDLE_TRAINER_ENDPOINTS": ",".join(job_endpoints),
                "PADDLE_CURRENT_ENDPOINT": job_endpoints[i + rank_offset],
                "PADDLE_TRAINER_ID": "{}".format(i + rank_offset),
                "PADDLE_TRAINERS_NUM": "{}".format(len(job_endpoints)),
                "PADDLE_RANK_IN_NODE": str(i),
            }
            if len(selected_dev_list) > 0:
                if self.ctx.node.device.dtype == DeviceType.CUSTOM_DEVICE:
                    e.update(self.ctx.node.device.get_custom_device_envs())
                if self.pod.replicas == 1:
                    e.update({selected_dev_key: ",".join(selected_dev_list)})
                else:
                    e.update({selected_dev_key: selected_dev_list[i]})
            else:
                e.update({'PADDLE_DISTRI_BACKEND': 'gloo'})

            log_file = f"workerlog.{i}"
            self.add_container(envs=e, log_file=log_file)

        return True

    def _build_pod_with_master(self):
        self.pod.replicas = self.pod_replicas()

        # rank will be reset on restart
        self.pod.rank = int(self.ctx.args.rank)

        port = self.ctx.node.get_free_port()

        # endpoints kept for the compatible PADDLE_TRAINER_ENDPOINTS/PADDLE_CURRENT_ENDPOINT envs
        endpoints = [
            "{}:{}".format(self.ctx.node.ip, p)
            for p in self.ctx.node.get_free_ports(self.pod.replicas)
        ]

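        # Publish this pod's info to the master: name, rank, replica count,
        # device type, a candidate master endpoint and its trainer endpoints.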
        data = json.dumps({
            'name': self.pod.name,
            'rank': self.pod.rank,
            'replicas': self.pod.replicas,
            'dtype': self.ctx.node.device.dtype,
            'candidate': '{}:{}'.format(self.ctx.node.ip, port),
            'endpoints': ",".join(endpoints),
        })

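        # sync_peers gathers the info published by every pod of the job and
        # returns the full peer list together with this pod's assigned rank.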
        peer_list, rank = self.master.sync_peers('/{}/info'.format(self.job.id),
                                                 self.pod.name, data,
                                                 self.job.replicas,
                                                 self.pod.rank)
        self.pod.rank = rank

        if len(peer_list) < 1:
            return False

        peer_list = [json.loads(i) for i in peer_list]

        self.ctx.logger.debug("sync peers done {}".format(peer_list))
        self.save_pod_log(peer_list)

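        # global_size is the total number of replicas in the job; rank_offset
        # counts the replicas on peers ranked before this pod.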
        global_size = sum([i['replicas'] for i in peer_list])
        rank_offset = sum([i['replicas'] for i in peer_list[:rank]])
        # The newly designed collective mode needs nothing but a master endpoint.
        collective_master = peer_list[0]['candidate']

        job_endpoints = [i['endpoints'] for i in peer_list]

        self.pod.reset()
        selected_dev_key = self.ctx.node.device.get_selected_device_key()
        selected_dev_list = self.ctx.node.device.get_selected_devices(
            self.ctx.args.devices)
        for i in range(self.pod.replicas):
            e = {
                "PADDLE_MASTER": collective_master,
                "PADDLE_GLOBAL_SIZE": "{}".format(global_size),
                "PADDLE_LOCAL_SIZE": "{}".format(self.pod.replicas),
                "PADDLE_GLOBAL_RANK": "{}".format(i + rank_offset),
                "PADDLE_LOCAL_RANK": "{}".format(i),
                "PADDLE_NNODES": "{}".format(self.job.replicas),
                ## compatible env
                "PADDLE_TRAINER_ENDPOINTS": ",".join(job_endpoints),
                "PADDLE_CURRENT_ENDPOINT": endpoints[i],
                "PADDLE_TRAINER_ID": "{}".format(i + rank_offset),
                "PADDLE_TRAINERS_NUM": "{}".format(global_size),
                "PADDLE_RANK_IN_NODE": str(i),
            }
            if len(selected_dev_list) > 0:
                if self.ctx.node.device.dtype == DeviceType.CUSTOM_DEVICE:
                    e.update(self.ctx.node.device.get_custom_device_envs())
                if self.pod.replicas == 1:
                    e.update({selected_dev_key: ",".join(selected_dev_list)})
                else:
                    e.update({selected_dev_key: selected_dev_list[i]})
            else:
                e.update({'PADDLE_DISTRI_BACKEND': 'gloo'})

            # log_file = "{}.{}.{}.log".format(self.job.id, self.pod.name, i)
            log_file = f"workerlog.{i}"
            self.add_container(envs=e, log_file=log_file)

        return True


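# Elastic variant of the collective controller, enabled when --master points
# to an etcd service; peers are re-synced and the pod rebuilt on each restart.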
class CollectiveElasticController(CollectiveController):

    @classmethod
    def enable(cls, ctx):
        if ctx.args.master and ctx.args.master.startswith("etcd://"):
            ctx.logger.debug("{} enabled".format(cls.__name__))
            ctx.args.run_mode = ControleMode.COLLECTIVE
            return True
        else:
            return False

    def register(self):
        if self.job.id == 'default':
            self.ctx.logger.warning(
                'Using default job name may cause conflict, add --job_id in args'
            )

        self.master.register_heartbeat(self.job.id, self.pod.name)

    def run(self):
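        # Elastic loop: register with the master, then repeatedly sync peers,
        # rebuild and redeploy the pod until the job finishes or the restart
        # budget is exhausted.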

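        # Elastic jobs use --elastic_timeout as-is; non-elastic jobs get a 10x
        # longer window to wait for peers.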
        timeout = int(self.ctx.args.elastic_timeout)
        timeout = timeout if self.job.elastic else timeout * 10
        self.register()

        while self.pod.restart <= self.ctx.args.max_restart:

            self.build_job()

            self.ctx.logger.info("Waiting peer ready...")

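            # Wait until between replicas_min and replicas_max peers are ready,
            # or give up when the timeout expires.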
            ok, replicas = self.master.wait_peer_ready(self.job.replicas_min,
                                                       self.job.replicas_max,
                                                       timeout)
            if ok:
                self.job.replicas = replicas
            else:
                self.ctx.logger.warning("peer not ready {}".format(self.job))
                break

            self.ctx.logger.debug("Run {}".format(self.job))

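            # build_pod returns False when the peer list is empty; retry then.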
            if not self.build_pod():
                continue

            self.master.set_status(self.ctx.status.RUNNING)

            self.deploy_pod()

            if self.watch():
                break

        self.ctx.logger.debug("Job done {}".format(self.job))