# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import signal
import sys

from paddle.distributed.launch.job.container import Container
from paddle.distributed.launch.job.job import Job
from paddle.distributed.launch.job.pod import Pod

from .master import Master
from .watcher import Watcher


class ControleMode:
    '''
    String constants naming the controller modes.

    NOTE(review): the class name is misspelled ("Controle"); it is kept
    unchanged because code outside this file may refer to it by this name.
    '''

    COLLECTIVE = "collective"
    PS = "ps"
    IPU = "ipu"
    RPC = "rpc"


class ControllerBase:
    '''
    Shared lifecycle logic for launch controllers.

    The constructor wires up signal handling and creates the master,
    watcher, job and pod objects from the launch context; ``run`` then
    drives build -> deploy -> watch.  ``build_job``, ``build_pod`` and
    ``save_pod_env`` are called here but not defined here, so a subclass
    has to provide them.
    '''

    def __init__(self, ctx):
        '''
        Install the signal handlers and create the objects the controller
        drives.

        Args:
            ctx: launch context.  Read here: ``ctx.args.nnodes``,
                ``ctx.args.run_mode``, ``ctx.args.job_id`` and, in
                auto-tuner mode, ``ctx.run_best`` and
                ``ctx.max_time_per_task``.
        '''
        for signum in (signal.SIGTERM, signal.SIGABRT, signal.SIGINT):
            signal.signal(signum, self.signal_handler)

        if ctx.is_auto_tuner_mode():
            if not ctx.run_best:
                # Per-task timeout: SIGALRM is delivered after
                # max_time_per_task seconds and is handled by the handler
                # that stops the pod without exiting the process.
                signal.signal(signal.SIGALRM, self.not_exit_signal_handler)
                signal.alarm(ctx.max_time_per_task)
            else:
                # alarm(0) cancels any alarm that is still pending.
                signal.alarm(0)

        self.ctx = ctx
        self.master = Master.factory(self.ctx)

        self.watcher = Watcher(self.ctx)

        self.job = Job(
            nnodes=self.ctx.args.nnodes,
            mode=self.ctx.args.run_mode,
            jid=self.ctx.args.job_id,
        )
        self.pod = Pod()

        # Publish the pod name through the context environment.
        self.ctx.set_envs({"POD_NAME": self.pod.name})

        self.join_server = None

    def deploy_pod(self):
        '''
        Log the pod, save its environment, mark the context status as
        running and deploy the pod.

        Raises:
            AssertionError: if the pod has neither containers nor init
                containers.
        '''
        assert (
            len(self.pod.containers) + len(self.pod.init_containers) > 0
        ), "No container in the pod"

        self.ctx.logger.info(f"Run {self.pod}")
        # Only the first container of each kind is logged at debug level.
        if len(self.pod.init_containers) > 0:
            self.ctx.logger.debug(self.pod.init_containers[0])
        if len(self.pod.containers) > 0:
            self.ctx.logger.debug(self.pod.containers[0])

        self.save_pod_env()
        self.ctx.status.run()
        self.pod.deploy()

    def run(self):
        '''
        Build the job and the pod, deploy the pod, then watch it.

        The result of ``watch`` is discarded here.
        '''
        self.build_job()
        self.build_pod()

        self.deploy_pod()

        self.watch()

    def watch(self) -> bool:
        '''
        watch self and peer status, return true to exit

        Returns:
            True when the pod completed, when it failed with
            ``elastic_level <= 0``, or when a peer failure is detected with
            ``elastic_level == -1``.  False on the other failure paths,
            after the pod has been stopped.  If the loop ends because
            ``ctx.status.is_done()`` became true, the method falls through
            and implicitly returns None.

        NOTE(review): the self-failure branch tests ``elastic_level <= 0``
        while the peer-failure branch tests ``elastic_level == -1``;
        confirm the asymmetry is intended.
        '''
        # TODO(kuizhiqing) unify ctx.status and master status

        self.ctx.logger.info(f"Watching {self.pod}")

        while not self.ctx.status.is_done():
            status = self.pod.watch(timeout=2)

            # The pod logs are printed on every poll.
            self.pod.logs()

            # completed
            if status == self.ctx.status.COMPLETED:
                self.ctx.status.complete()

                self.master.set_status(status)

                # Keep calling logs() until it returns a falsy value
                # (presumably: no more output to flush -- TODO confirm).
                while self.pod.logs():
                    pass

                self.ctx.logger.info(f"Pod {status}")
                return True

            # self failure
            elif status == self.ctx.status.FAILED:
                self.ctx.status.fail()

                self.master.set_status(status)
                self.master.restart_peer()

                fc = self.pod.failed_container()
                self.ctx.logger.info(f"Pod {status}")
                self.ctx.logger.error(f"Container failed !!!\n{fc[0]}")
                self.ctx.logger.info(
                    "------------------------- ERROR LOG DETAIL -------------------------"
                )
                fc[0].tail()

                if self.ctx.args.elastic_level <= 0:
                    self.pod.stop(timeout=3)
                    return True
                else:
                    self.pod.stop(timeout=30)
                    return False

            # peer failure
            if (
                self.ctx.status.is_restarting()
                and self.master.get_status() != self.ctx.status.COMPLETED
            ):
                # when peer failure, stop peer
                if self.ctx.args.elastic_level == -1:
                    self.pod.stop(timeout=3)
                    return True

                self.pod.stop(timeout=30)
                return False

    def stop(self, sigint=None):
        '''
        Stop the watcher, the master and the pod, in that order.

        Args:
            sigint: accepted for the signal handlers' convenience; not
                used in this method.
        '''
        self.ctx.logger.debug("Controller stop")

        self.watcher.stop()

        self.master.stop()
        self.pod.stop(timeout=30)

    def finalize(self, exit=True):
        '''
        Join the pod, stop the master and report the pod's exit code.

        Args:
            exit: when true, terminate the process with the pod's exit
                code via ``sys.exit``.
        '''
        self.pod.join()
        self.master.stop()

        self.ctx.logger.info(f"Exit code {self.pod.exit_code}")
        if exit:
            sys.exit(self.pod.exit_code)

    def _shutdown(self, sigint):
        '''Record the signal, mark the status done and stop everything.'''
        self.ctx.logger.info(f"Terminating with signal {sigint}")

        # The presence of this attribute is how the handlers recognise a
        # repeated signal.
        self.sigint = sigint
        self.ctx.status.done()
        self.stop(sigint=sigint)
        self.ctx.logger.info(f"Exit with signal {sigint}")

    def signal_handler(self, sigint, frame):
        '''
        Handler for SIGTERM/SIGABRT/SIGINT: shut down, then exit the
        process with the signal number.

        If a signal was already recorded, only the pod is stopped (10 s
        timeout) before exiting.
        '''
        if hasattr(self, 'sigint'):
            self.ctx.logger.info("Force quit in 10 seconds...")
            self.pod.stop(timeout=10)
            sys.exit(sigint)

        self._shutdown(sigint)
        sys.exit(sigint)

    def not_exit_signal_handler(self, sigint, frame):
        '''
        Handler that performs the same shutdown as ``signal_handler`` but
        never exits the process (installed for SIGALRM in auto-tuner mode).

        If a signal was already recorded, the pod is first stopped with a
        10 s timeout and the regular shutdown sequence still runs
        afterwards.
        '''
        if hasattr(self, 'sigint'):
            self.ctx.logger.info("Force quit in 10 seconds...")
            self.pod.stop(timeout=10)

        self._shutdown(sigint)


class Controller(ControllerBase):
    '''
    Controller API for customization
    '''

    def build_job(self):
        '''
        build job fill the job info.

        This default implementation only logs the job.
        '''
        self.ctx.logger.info(self.job)

    def build_pod(self) -> bool:
        '''
        build pod includes creating containers etc.

        Return True if succeed

        Raises:
            NotImplementedError: always; subclasses must override.
        '''
        raise NotImplementedError

    def _get_entrypoint(self):
        '''
        Build the command line for the training script.

        A script ending in ``.py`` is run with the current interpreter in
        unbuffered mode (``-u``); anything else is executed directly.  The
        training script arguments are appended in both cases.
        '''
        script = self.ctx.args.training_script
        if script.endswith('.py'):
            entrypoint = [sys.executable, "-u", script]
        else:
            entrypoint = [script]

        entrypoint.extend(self.ctx.args.training_script_args)
        return entrypoint

    def _get_out_err_file(self, out=None, err=None):
        '''
        Resolve the stdout/stderr file names against ``log_dir``.

        Each truthy name is joined onto ``ctx.args.log_dir`` unless
        ``log_dir`` is the empty string.  When ``err`` is falsy the stderr
        file falls back to the stdout file.

        Returns:
            tuple ``(out, err)``.
        '''
        log_dir = self.ctx.args.log_dir
        if out and log_dir != "":
            out = os.path.join(log_dir, out)
        if err and log_dir != "":
            err = os.path.join(log_dir, err)
        return out, (err or out)

    def new_container(
        self, entrypoint=None, envs=None, use_ctx_env=True, out=None, err=None
    ):
        '''
        Create a container without adding it to the pod.

        Args:
            entrypoint: command list; defaults to ``_get_entrypoint()``.
            envs: extra environment variables applied on top of the base
                environment; ``None`` means no extras.
            use_ctx_env: start from ``ctx.get_envs()`` when true, from an
                empty environment otherwise.
            out: stdout file name, resolved by ``_get_out_err_file``.
            err: stderr file name, resolved by ``_get_out_err_file``.

        Returns:
            the new ``Container``.
        '''
        c = Container(
            entrypoint=(entrypoint or self._get_entrypoint()),
            env=(self.ctx.get_envs() if use_ctx_env else {}),
            overwrite_log=self.ctx.args.log_overwrite,
        )
        c.outfile, c.errfile = self._get_out_err_file(out, err)
        # A None default replaces the former shared mutable ``{}`` default.
        c.update_env(envs if envs is not None else {})
        return c

    def add_container(
        self,
        container=None,
        entrypoint=None,
        envs=None,
        log_file=None,
        is_init=False,
    ):
        '''
        Add a container to the pod, creating it first when none is given.

        Args:
            container: an existing container; when falsy a new one is built
                from ``entrypoint``, ``envs`` and ``log_file``.
            entrypoint: passed to ``new_container``.
            envs: passed to ``new_container``; ``None`` means no extras.
            log_file: used as both the stdout and the stderr file name.
            is_init: add as an init container instead of a regular one.
        '''
        if not container:
            container = self.new_container(
                entrypoint=entrypoint, envs=envs, out=log_file, err=log_file
            )

        if is_init:
            self.pod.add_init_container(container)
        else:
            self.pod.add_container(container)

    def pod_replicas(self):
        '''
        how many process/container should be run in pod

        Taken from ``nproc_per_node`` when set, otherwise from the number
        of comma-separated entries in ``devices``, otherwise from
        ``ctx.node.device.count``.
        '''

        if self.ctx.args.nproc_per_node:
            return int(self.ctx.args.nproc_per_node)
        elif self.ctx.args.devices:
            return len(self.ctx.args.devices.split(','))
        else:
            return self.ctx.node.device.count

    def save_pod_log(self, info):
        '''
        save_pod_log append *info* to the log file of pod.name

        Does nothing when ``log_dir`` is not set.  When the file is empty
        on open, a dump of ``os.environ`` is written first.  Failures are
        logged through the context logger instead of being raised.
        '''
        if not self.ctx.args.log_dir:
            return

        f = os.path.join(
            self.ctx.args.log_dir,
            f'{self.job.id}.{self.pod.name}.log',
        )
        try:
            os.makedirs(os.path.dirname(f), exist_ok=True)
            with open(f, 'a+') as fd:
                # Position 0 right after opening for append means the file
                # is empty.
                if fd.tell() == 0:
                    fd.write(str(os.environ))
                    fd.write("\n")
                fd.write(str(info))
                fd.write("\n")
        except Exception as e:
            # Best effort: logging must not break the launch.
            self.ctx.logger.error(f"save log failed because {e}")

    def save_pod_env(self):
        '''
        Write one environment file per init container and per container.

        Does nothing when ``log_dir`` is not set.

        Raises:
            AssertionError: if the pod has neither containers nor init
                containers.
        '''
        assert (
            len(self.pod.containers) + len(self.pod.init_containers) > 0
        ), "No container in the pod"

        if not self.ctx.args.log_dir:
            return

        for c in self.pod.init_containers:
            self._save_container_env(c, is_init=True)

        for c in self.pod.containers:
            self._save_container_env(c)

    def _save_container_env(self, container, is_init=False):
        '''
        Write the container's environment as sorted ``KEY=VALUE`` lines.

        The file is ``envlog.init.<rank>`` for init containers and
        ``envlog.<rank>`` otherwise, inside ``log_dir``, opened with the
        container's ``log_mode``.  Failures are logged, not raised.
        '''
        f = os.path.join(
            self.ctx.args.log_dir,
            f'envlog.init.{container.rank}'
            if is_init
            else f'envlog.{container.rank}',
        )
        try:
            os.makedirs(os.path.dirname(f), exist_ok=True)
            with open(f, container.log_mode) as fd:
                for k, v in sorted(container.env.items()):
                    fd.write(f"{k}={v}\n")
        except Exception as e:
            # Best effort: a failed env dump must not break the launch.
            self.ctx.logger.error(f"save pod env log failed because {e}")