# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

from ..context.device import DeviceType
from .controller import ControleMode, Controller

class CollectiveController(Controller):
    """Controller for collective (all-reduce style) distributed jobs.

    Builds one pod on this node; each replica (container) receives the
    environment variables a paddle trainer expects (PADDLE_TRAINER_ID,
    endpoint lists, device selection, ...).
    """

    def __init__(self, ctx):
        # Set before super().__init__ since the base class may already use it.
        # One of: None, 'tuner_only', 'run_only', 'tuner_and_run'.
        self._tuner_run_mode = None
        super().__init__(ctx)

    @classmethod
    def enable(cls, ctx):
        # collective is the default mode: enabled whenever a context exists
        if ctx:
            ctx.logger.debug(f"{cls.__name__} enabled")
            ctx.args.run_mode = ControleMode.COLLECTIVE
            return True
        else:
            return False

    def build_pod(self):
        """Build the containers of this pod.

        Returns None when only the tuner runs, otherwise the result of the
        chosen pod-building strategy (False requests a rebuild by callers
        such as the elastic controller).
        """
        skip_run = self._build_pod_with_tuner()
        if skip_run:
            return
        # With explicit ips/start_port and no master, ranks and endpoints can
        # be computed locally; otherwise peers are discovered via the master.
        if (
            self.ctx.args.master is None
            and self.ctx.args.start_port
            and self.ctx.args.ips
        ):
            return self._build_pod_with_args()
        else:
            return self._build_pod_with_master()

    def _build_pod_with_tuner(self):
        """Optionally add an auto-parallel tuner container.

        Returns True when tuning replaces the run ('tuner_only'), telling
        build_pod to skip building trainer containers; False otherwise.
        """
        auto_parallel_config = self.ctx.args.auto_parallel_config
        if auto_parallel_config is not None:
            if not os.path.exists(auto_parallel_config):
                self.ctx.logger.warning("auto_parallel_conf not exists!")
            if not auto_parallel_config.endswith(".json"):
                self.ctx.logger.warning(
                    "auto_parallel_config should be a json format file!"
                )

            with open(auto_parallel_config, 'r') as robj:
                # json.load parses the file object directly
                auto_parallel_data = json.load(robj)
                self._tuner_run_mode = auto_parallel_data.get(
                    "tuner_run_mode", 'tuner_and_run'
                )

            self.ctx.logger.info(f"tuner_run_mode is: {self._tuner_run_mode}")
            endpoint = f"127.0.0.1:{self.ctx.node.get_free_port()}"
            pod_replicas = self.pod_replicas()
            if self._tuner_run_mode in ['tuner_only', 'tuner_and_run']:
                # The tuner runs as a single-process "trainer" pinned to gpu 0.
                e = {
                    "PADDLE_AUTO_PARALLEL_CONFIG": self.ctx.args.auto_parallel_config,
                    "PADDLE_TRAINERS_NUM": "1",
                    "PADDLE_TRAINER_ENDPOINTS": endpoint,
                    "PADDLE_TRAINER_ID": "0",
                    "PADDLE_CURRENT_ENDPOINT": endpoint,
                    "FLAGS_selected_gpus": "0",
                    "PADDLE_AUTO_PARALLEL_STAGE": "tuner",
                    "PADDLE_GLOBAL_SIZE": "{}".format(
                        pod_replicas * int(self.ctx.args.nnodes)
                    ),
                    "PADDLE_LOCAL_SIZE": f"{pod_replicas}",
                }
                log_file = "tuner.log"
                self.add_container(envs=e, log_file=log_file, is_init=True)

                if self._tuner_run_mode == 'tuner_only':
                    return True
        return False

    def _build_pod_with_args(self):
        """Build containers from explicit --ips/--start_port arguments.

        All nodes are assumed to run the same number of replicas on
        consecutive ports, so every endpoint and rank is computed locally
        without contacting a master. Always returns True.
        """
        self.pod.replicas = self.pod_replicas()

        start_port = int(self.ctx.args.start_port)
        ips = self.ctx.args.ips.split(',')

        # One endpoint per replica per host, ports consecutive per host.
        job_endpoints = [
            f"{h}:{p + start_port}"
            for h in ips
            for p in range(self.pod.replicas)
        ]

        self.ctx.logger.debug(f"job endpoints: {job_endpoints}")

        # Global rank of this node's first replica; 0 if our ip is not listed.
        rank_offset = (
            ips.index(self.ctx.node.ip) * self.pod.replicas
            if self.ctx.node.ip in ips
            else 0
        )

        self.save_pod_log(job_endpoints)

        selected_dev_key = self.ctx.node.device.get_selected_device_key()
        selected_dev_list = self.ctx.node.device.get_selected_devices(
            self.ctx.args.devices
        )

        for i in range(self.pod.replicas):
            e = {
                "PADDLE_GLOBAL_SIZE": f"{len(job_endpoints)}",
                "PADDLE_LOCAL_SIZE": f"{self.pod.replicas}",
                "PADDLE_GLOBAL_RANK": f"{i + rank_offset}",
                "PADDLE_LOCAL_RANK": f"{i}",
                "PADDLE_NNODES": f"{len(ips)}",
                # compatible env
                "PADDLE_TRAINER_ENDPOINTS": ",".join(job_endpoints),
                "PADDLE_CURRENT_ENDPOINT": job_endpoints[i + rank_offset],
                "PADDLE_TRAINER_ID": f"{i + rank_offset}",
                "PADDLE_TRAINERS_NUM": f"{len(job_endpoints)}",
                "PADDLE_RANK_IN_NODE": str(i),
            }
            if self._tuner_run_mode is not None:
                e.update(
                    {
                        "PADDLE_AUTO_PARALLEL_CONFIG": self.ctx.args.auto_parallel_config,
                        "PADDLE_AUTO_PARALLEL_STAGE": "run",
                    }
                )
            if len(selected_dev_list) > 0:
                if self.ctx.node.device.dtype == DeviceType.CUSTOM_DEVICE:
                    e.update(self.ctx.node.device.get_custom_device_envs())
                if self.pod.replicas == 1:
                    # Single replica owns every selected device.
                    e.update({selected_dev_key: ",".join(selected_dev_list)})
                else:
                    e.update({selected_dev_key: selected_dev_list[i]})
            else:
                # No accelerator selected: fall back to the gloo CPU backend.
                e.update({'PADDLE_DISTRI_BACKEND': 'gloo'})

            log_file = f"workerlog.{i}"
            self.add_container(envs=e, log_file=log_file)

        return True

    def _build_pod_with_master(self):
        """Build containers after discovering peers through the master.

        Publishes this pod's endpoints to the master, waits for all peers,
        then derives global ranks from the gathered peer list. Returns False
        when peer sync produced no peers (caller should retry), True on
        success.
        """
        self.pod.replicas = self.pod_replicas()

        # rank will be reset when restart
        self.pod.rank = int(self.ctx.args.rank)

        port = self.ctx.node.get_free_port()

        # compatible
        endpoints = [
            f"{self.ctx.node.ip}:{p}"
            for p in self.ctx.node.get_free_ports(self.pod.replicas)
        ]

        data = json.dumps(
            {
                'name': self.pod.name,
                'rank': self.pod.rank,
                'replicas': self.pod.replicas,
                'dtype': self.ctx.node.device.dtype,
                'candidate': f'{self.ctx.node.ip}:{port}',
                'endpoints': ",".join(endpoints),
            }
        )

        peer_list, rank = self.master.sync_peers(
            f'/{self.job.id}/info',
            self.pod.name,
            data,
            self.job.replicas,
            self.pod.rank,
        )
        self.pod.rank = rank

        if len(peer_list) < 1:
            return False

        peer_list = [json.loads(i) for i in peer_list]

        self.ctx.logger.debug(f"sync peers done {peer_list}")
        self.save_pod_log(peer_list)

        global_size = sum(i['replicas'] for i in peer_list)
        rank_offset = sum(i['replicas'] for i in peer_list[:rank])
        # The new designed collective needs nothing but a master endpoint.
        collective_master = peer_list[0]['candidate']

        job_endpoints = [i['endpoints'] for i in peer_list]

        selected_dev_key = self.ctx.node.device.get_selected_device_key()
        selected_dev_list = self.ctx.node.device.get_selected_devices(
            self.ctx.args.devices
        )
        for i in range(self.pod.replicas):
            e = {
                "PADDLE_MASTER": collective_master,
                "PADDLE_GLOBAL_SIZE": f"{global_size}",
                "PADDLE_LOCAL_SIZE": f"{self.pod.replicas}",
                "PADDLE_GLOBAL_RANK": f"{i + rank_offset}",
                "PADDLE_LOCAL_RANK": f"{i}",
                "PADDLE_NNODES": f"{self.job.replicas}",
                # compatible env
                "PADDLE_TRAINER_ENDPOINTS": ",".join(job_endpoints),
                "PADDLE_CURRENT_ENDPOINT": endpoints[i],
                "PADDLE_TRAINER_ID": f"{i + rank_offset}",
                "PADDLE_TRAINERS_NUM": f"{global_size}",
                "PADDLE_RANK_IN_NODE": str(i),
            }
            if self._tuner_run_mode is not None:
                e.update(
                    {
                        "PADDLE_AUTO_PARALLEL_CONFIG": self.ctx.args.auto_parallel_config,
                        "PADDLE_AUTO_PARALLEL_STAGE": "run",
                    }
                )
            if len(selected_dev_list) > 0:
                if self.ctx.node.device.dtype == DeviceType.CUSTOM_DEVICE:
                    e.update(self.ctx.node.device.get_custom_device_envs())
                if self.pod.replicas == 1:
                    # Single replica owns every selected device.
                    e.update({selected_dev_key: ",".join(selected_dev_list)})
                else:
                    e.update({selected_dev_key: selected_dev_list[i]})
            else:
                # No accelerator selected: fall back to the gloo CPU backend.
                e.update({'PADDLE_DISTRI_BACKEND': 'gloo'})

            log_file = f"workerlog.{i}"
            self.add_container(envs=e, log_file=log_file)

        return True


class CollectiveElasticController(CollectiveController):
    """Elastic variant of the collective controller, coordinated via etcd."""

    @classmethod
    def enable(cls, ctx):
        # Elastic mode requires an etcd-backed master for peer discovery.
        if ctx.args.master and ctx.args.master.startswith("etcd://"):
            ctx.logger.debug(f"{cls.__name__} enabled")
            ctx.args.run_mode = ControleMode.COLLECTIVE
            return True
        else:
            return False

    def register(self):
        """Register this pod's heartbeat with the master."""
        if self.job.id == 'default':
            self.ctx.logger.warning(
                'Using default job name may cause conflict, add --job_id in args'
            )

        self.master.register_heartbeat(self.job.id, self.pod.name)

    def run(self):
        """Elastic main loop: rebuild and redeploy the pod until the job
        finishes or the restart budget is exhausted."""
        # Non-elastic jobs are granted a 10x longer window to gather peers.
        timeout = int(self.ctx.args.elastic_timeout)
        if not self.job.elastic:
            timeout *= 10

        self.register()

        while self.pod.restart <= self.ctx.args.max_restart:
            self.build_job()

            self.ctx.logger.info("Waiting peer ready...")

            ready, replicas = self.master.wait_peer_ready(
                self.job.replicas_min, self.job.replicas_max, timeout
            )
            if not ready:
                self.ctx.logger.warning(f"peer not ready {self.job}")
                break
            self.job.replicas = replicas

            self.ctx.logger.debug(f"Run {self.job}")

            # A failed pod build means membership changed; retry the loop.
            if not self.build_pod():
                continue

            self.master.set_status(self.ctx.status.RUNNING)

            self.deploy_pod()

            # watch() returning True means the job completed (or failed hard).
            if self.watch():
                break

        self.ctx.logger.debug(f"Job done {self.job}")