# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json

from ..context.device import DeviceType
from .controller import ControleMode, Controller

class CollectiveController(Controller):
    """Controller that launches a pod of collective workers.

    The pod is built either from a static topology (``--ips`` plus
    ``--start_port``, with no master) or by exchanging peer information
    through the configured master.
    """

    @classmethod
    def enable(cls, ctx):
        """Select collective run mode for ``ctx``.

        Collective is the default mode, so any truthy ``ctx`` enables it.

        Returns:
            bool: True (after setting ``ctx.args.run_mode``) when ``ctx`` is
            truthy, False otherwise.
        """
        # collective is the default mode
        if not ctx:
            return False

        ctx.logger.debug(f"{cls.__name__} enabled")
        ctx.args.run_mode = ControleMode.COLLECTIVE
        return True

    def build_pod(self):
        """Build the local pod, choosing the static or master-based path.

        The static path is used only when no master is given and both
        ``start_port`` and ``ips`` are set.

        Returns:
            bool: whether the pod was built.
        """
        if (
            self.ctx.args.master is None
            and self.ctx.args.start_port
            and self.ctx.args.ips
        ):
            return self._build_pod_with_args()
        else:
            return self._build_pod_with_master()

    def _device_envs(self, local_rank, selected_dev_key, selected_dev_list):
        """Return the device-binding env vars for worker ``local_rank``.

        Shared by both pod builders so their device handling cannot drift.

        Args:
            local_rank (int): index of the worker within this pod.
            selected_dev_key (str): env var name under which the selected
                device(s) are exported.
            selected_dev_list (list): devices selected on this node; may be
                empty.

        Returns:
            dict: env vars to merge into the worker's environment.
        """
        if not selected_dev_list:
            # No device selected: fall back to the gloo backend.
            return {'PADDLE_DISTRI_BACKEND': 'gloo'}

        envs = {}
        if self.ctx.node.device.dtype == DeviceType.CUSTOM_DEVICE:
            envs.update(self.ctx.node.device.get_custom_device_envs())
        if self.pod.replicas == 1:
            # A single worker is given every selected device.
            envs.update({selected_dev_key: ",".join(selected_dev_list)})
        else:
            envs.update({selected_dev_key: selected_dev_list[local_rank]})
        return envs

    def _build_pod_with_args(self):
        """Build the pod from the static topology given by ``--ips``.

        Every host listed in ``--ips`` runs ``pod.replicas`` workers on
        consecutive ports starting at ``--start_port``. Global ranks are
        assigned host by host, in the order the hosts are listed.

        Returns:
            bool: always True.
        """
        self.pod.replicas = self.pod_replicas()

        start_port = int(self.ctx.args.start_port)
        # Tolerate "ip1, ip2" and trailing commas: a padded entry would never
        # match ``node.ip`` and the rank offset would silently fall back to 0.
        ips = [
            ip.strip() for ip in self.ctx.args.ips.split(',') if ip.strip()
        ]

        job_endpoints = [
            f"{h}:{p + start_port}"
            for h in ips
            for p in range(self.pod.replicas)
        ]

        self.ctx.logger.debug(f"job endpoints: {job_endpoints}")

        # Workers on the k-th listed host start at global rank k * replicas.
        # NOTE(review): a node whose ip is not listed falls back to offset 0.
        rank_offset = (
            ips.index(self.ctx.node.ip) * self.pod.replicas
            if self.ctx.node.ip in ips
            else 0
        )

        self.save_pod_log(job_endpoints)

        selected_dev_key = self.ctx.node.device.get_selected_device_key()
        selected_dev_list = self.ctx.node.device.get_selected_devices(
            self.ctx.args.devices
        )

        for i in range(self.pod.replicas):
            e = {
                "PADDLE_GLOBAL_SIZE": f"{len(job_endpoints)}",
                "PADDLE_LOCAL_SIZE": f"{self.pod.replicas}",
                "PADDLE_GLOBAL_RANK": f"{i + rank_offset}",
                "PADDLE_LOCAL_RANK": f"{i}",
                "PADDLE_NNODES": f"{len(ips)}",
                # compatible env
                "PADDLE_TRAINER_ENDPOINTS": ",".join(job_endpoints),
                "PADDLE_CURRENT_ENDPOINT": job_endpoints[i + rank_offset],
                "PADDLE_TRAINER_ID": f"{i + rank_offset}",
                "PADDLE_TRAINERS_NUM": f"{len(job_endpoints)}",
                "PADDLE_RANK_IN_NODE": str(i),
            }
            e.update(self._device_envs(i, selected_dev_key, selected_dev_list))

            log_file = f"workerlog.{i}"
            self.add_container(envs=e, log_file=log_file)

        return True

    def _build_pod_with_master(self):
        """Build the pod by exchanging peer information through the master.

        This pod publishes its name, rank, replica count, device type, a
        candidate master endpoint and its worker endpoints as JSON via
        ``master.sync_peers``, then derives global ranks from the returned
        peer list.

        Returns:
            bool: False if no peer information was returned, else True.
        """
        self.pod.replicas = self.pod_replicas()

        # rank will be reset when restart
        self.pod.rank = int(self.ctx.args.rank)

        port = self.ctx.node.get_free_port()

        # compatible
        endpoints = [
            f"{self.ctx.node.ip}:{p}"
            for p in self.ctx.node.get_free_ports(self.pod.replicas)
        ]

        data = json.dumps(
            {
                'name': self.pod.name,
                'rank': self.pod.rank,
                'replicas': self.pod.replicas,
                'dtype': self.ctx.node.device.dtype,
                'candidate': f'{self.ctx.node.ip}:{port}',
                'endpoints': ",".join(endpoints),
            }
        )

        peer_list, rank = self.master.sync_peers(
            f'/{self.job.id}/info',
            self.pod.name,
            data,
            self.job.replicas,
            self.pod.rank,
        )
        self.pod.rank = rank

        if not peer_list:
            return False

        peer_list = [json.loads(peer) for peer in peer_list]

        self.ctx.logger.debug(f"sync peers done {peer_list}")
        self.save_pod_log(peer_list)

        global_size = sum(peer['replicas'] for peer in peer_list)
        # Pods ordered before this one own all global ranks below ours.
        rank_offset = sum(peer['replicas'] for peer in peer_list[:rank])
        # The newly designed collective needs nothing but a master endpoint;
        # the first peer's candidate endpoint is used for it.
        collective_master = peer_list[0]['candidate']

        job_endpoints = [peer['endpoints'] for peer in peer_list]

        self.pod.reset()
        selected_dev_key = self.ctx.node.device.get_selected_device_key()
        selected_dev_list = self.ctx.node.device.get_selected_devices(
            self.ctx.args.devices
        )
        for i in range(self.pod.replicas):
            e = {
                "PADDLE_MASTER": collective_master,
                "PADDLE_GLOBAL_SIZE": f"{global_size}",
                "PADDLE_LOCAL_SIZE": f"{self.pod.replicas}",
                "PADDLE_GLOBAL_RANK": f"{i + rank_offset}",
                "PADDLE_LOCAL_RANK": f"{i}",
                "PADDLE_NNODES": f"{self.job.replicas}",
                # compatible env
                "PADDLE_TRAINER_ENDPOINTS": ",".join(job_endpoints),
                "PADDLE_CURRENT_ENDPOINT": endpoints[i],
                "PADDLE_TRAINER_ID": f"{i + rank_offset}",
                "PADDLE_TRAINERS_NUM": f"{global_size}",
                "PADDLE_RANK_IN_NODE": str(i),
            }
            e.update(self._device_envs(i, selected_dev_key, selected_dev_list))

            log_file = f"workerlog.{i}"
            self.add_container(envs=e, log_file=log_file)

        return True


class CollectiveElasticController(CollectiveController):
    """Collective controller that rendezvouses through an etcd master.

    Enabled when ``--master`` is an ``etcd://`` URL. ``run`` repeats the
    build/deploy/watch cycle while ``pod.restart <= max_restart``.
    """

    @classmethod
    def enable(cls, ctx):
        """Select elastic collective mode when the master is an etcd URL.

        Returns:
            bool: True (after setting ``ctx.args.run_mode``) if enabled,
            False otherwise.
        """
        # Tolerate a falsy ctx, consistent with CollectiveController.enable.
        master = ctx.args.master if ctx else None
        if master and master.startswith("etcd://"):
            ctx.logger.debug(f"{cls.__name__} enabled")
            ctx.args.run_mode = ControleMode.COLLECTIVE
            return True
        else:
            return False

    def register(self):
        """Register this pod's heartbeat with the master under the job id."""
        if self.job.id == 'default':
            self.ctx.logger.warning(
                'Using default job name may cause conflict, add --job_id in args'
            )

        self.master.register_heartbeat(self.job.id, self.pod.name)

    def run(self):
        """Run the job, rebuilding and redeploying the pod on each cycle.

        Each iteration rebuilds the job, waits for peers, builds and deploys
        the pod, then watches it. The loop ends when ``watch`` returns True,
        when peers are not ready within the timeout, or once
        ``pod.restart`` exceeds ``max_restart``.
        """
        timeout = int(self.ctx.args.elastic_timeout)
        # Non-elastic jobs wait ten times longer for their peers.
        timeout = timeout if self.job.elastic else timeout * 10
        self.register()

        while self.pod.restart <= self.ctx.args.max_restart:
            self.build_job()

            self.ctx.logger.info("Waiting peer ready...")

            ok, replicas = self.master.wait_peer_ready(
                self.job.replicas_min, self.job.replicas_max, timeout
            )
            if ok:
                self.job.replicas = replicas
            else:
                self.ctx.logger.warning(f"peer not ready {self.job}")
                break

            self.ctx.logger.debug(f"Run {self.job}")

            # The pod could not be built (e.g. no peer info came back);
            # retry the rendezvous.
            # NOTE(review): this path does not advance pod.restart here.
            if not self.build_pod():
                continue

            self.master.set_status(self.ctx.status.RUNNING)

            self.deploy_pod()

            if self.watch():
                break

        self.ctx.logger.debug(f"Job done {self.job}")